How to get the probability density function from CoxPHSurvivalAnalysis in scikit-survival?


I am using sksurv.linear_model.CoxPHSurvivalAnalysis to fit a Cox PH regression, and I would like to recover the density function f(t). The sksurv class has methods to predict the survival function S(t) = 1 - F(t), where F(t) is the cumulative distribution function, and the cumulative hazard function H(t), but it doesn't seem to produce the density function.

My use case has no censoring, so here is an example:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sksurv.linear_model import CoxPHSurvivalAnalysis

data = np.random.randint(5,30,size=10)
X_train = pd.DataFrame(data, columns=['covariate'])

# Each float is copied into both fields: 'status' is True (event observed)
# for every nonzero time, and 'target' holds the time itself
y_train = np.array(np.random.randint(0,100,size=10)/100,dtype=[('status',bool),('target',float)])

estimator = CoxPHSurvivalAnalysis()
estimator.fit(X_train,y_train)

X_test = pd.DataFrame({'covariate':[12,2]})
chf = estimator.predict_cumulative_hazard_function(X_test)
sf = estimator.predict_survival_function(X_test)  # S(t), not the CDF

fig, ax = plt.subplots(1,2)
for fn_h, fn_s in zip(chf, sf):
    ax[0].step(fn_h.x,fn_h(fn_h.x),where='post')
    ax[1].step(fn_s.x,fn_s(fn_s.x),where='post')

ax[0].set_title('Cumulative Hazard Functions')
ax[1].set_title('Survival Functions')
plt.show()


[Figure: cumulative hazard functions (left) and survival functions (right) for the two test samples]

How can I also access and plot the density function?
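A side note on the example's target array: building it from a single float array relies on NumPy copying each value into both fields, so `status` is `True` for every nonzero time and a time of exactly 0 would silently be marked as censored. A more explicit sketch, assuming the same field names, fills the two fields separately. `make_uncensored_target` is a hypothetical helper; `sksurv.util.Surv.from_arrays` should be able to build the same kind of array.

```python
import numpy as np

# Hypothetical helper (not part of sksurv): build an all-uncensored survival
# target as a NumPy structured array with the field names used in the question.
def make_uncensored_target(times):
    times = np.asarray(times, dtype=float)
    y = np.empty(times.shape[0], dtype=[("status", bool), ("target", float)])
    y["status"] = True   # every event was observed (no censoring)
    y["target"] = times  # observed event times
    return y

rng = np.random.default_rng(0)
# Draw strictly positive times so no sample can be mistaken for censored
y_train = make_uncensored_target(rng.uniform(0.01, 1.0, size=10))
print(y_train)
```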

Best answer, by user3046211:

The probability density function (PDF) can be obtained from the cumulative distribution function (CDF) as:

f(t) = dF(t)/dt
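Because the Cox model's predicted survival function is a step function, F(t) = 1 - S(t) only jumps at the event times, so dF/dt does not exist in the ordinary sense there. A sketch of the discrete counterpart computes the probability mass F puts at each event time, assuming sorted event times and the matching survival values (for example `sf.x` and `sf(sf.x)` for one returned step function) and assuming S = 1 before the first event time. `event_time_masses` is a hypothetical helper.

```python
import numpy as np

# Hypothetical helper: probability mass of F(t) = 1 - S(t) at each event time.
# `times` are sorted event times, `surv` the values S(t_i) of a right-continuous
# step function; S is assumed to equal 1 before the first event time.
def event_time_masses(times, surv):
    times = np.asarray(times, dtype=float)
    surv = np.asarray(surv, dtype=float)
    # S(t_{i-1}), with S = 1 before the first event time
    surv_before = np.concatenate(([1.0], surv[:-1]))
    # Jump of F at t_i: F(t_i) - F(t_{i-1}) = S(t_{i-1}) - S(t_i)
    return times, surv_before - surv

times, masses = event_time_masses([0.1, 0.3, 0.7], [0.8, 0.5, 0.2])
print(times, masses)
```

The masses sum to 1 - S(t_last), the probability of an event by the last observed time.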

In survival analysis (SA), the PDF f(t) can also be expressed in terms of the survival function S(t) and the hazard function h(t):

f(t) = h(t) x S(t)

where S(t) = 1 - F(t) and h(t) = -(dS(t)/dt) / S(t) = dH(t)/dt
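These identities can be checked numerically on a distribution with a closed-form hazard. The sketch below uses a Weibull distribution with example parameters (shape k = 1.5, scale s = 2.0, chosen only for illustration), for which H(t) = (t/s)^k and S(t) = exp(-H(t)), and compares finite-difference derivatives against the closed forms.

```python
import numpy as np

# Weibull distribution with shape k and scale s (example values):
#   H(t) = (t / s)**k,  S(t) = exp(-H(t)),  F(t) = 1 - S(t)
#   closed-form hazard h(t) = (k / s) * (t / s)**(k - 1)
k, s = 1.5, 2.0
t = np.linspace(0.1, 5.0, 500)

H = (t / s) ** k
S = np.exp(-H)
F = 1.0 - S
h_closed = (k / s) * (t / s) ** (k - 1)

# Numerical derivatives (second-order finite differences, also at the edges)
h_from_H = np.gradient(H, t, edge_order=2)        # h(t) = dH/dt
h_from_S = -np.gradient(S, t, edge_order=2) / S   # h(t) = -(dS/dt) / S(t)
f_from_F = np.gradient(F, t, edge_order=2)        # f(t) = dF/dt
f_from_hS = h_closed * S                          # f(t) = h(t) * S(t)

print(np.max(np.abs(h_from_H - h_closed)))
print(np.max(np.abs(h_from_S - h_closed)))
print(np.max(np.abs(f_from_F - f_from_hS)))
```

All three maximum differences should be small, limited only by the finite-difference step.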

So the PDF can be expressed as: f(t) = (dH(t)/dt) x S(t)
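With a step-function CHF, this formula has a discrete counterpart. Assuming S(t) = exp(-H(t)) at the event times (which should hold for a Breslow-type baseline estimate, but treat it as an assumption), the exact mass at event time t_i is S(t_{i-1}) (1 - exp(-ΔH_i)), and ΔH_i x S(t_{i-1}) approximates it when the increments ΔH_i are small. The CHF values below are made up for illustration.

```python
import numpy as np

# Made-up cumulative hazard values H(t_i) at sorted event times; the survival
# function is assumed to satisfy S(t) = exp(-H(t)).
H = np.array([0.1, 0.4, 0.9, 1.6])
S = np.exp(-H)

# Increments of H at each event time (H = 0 before the first event)
dH = np.diff(H, prepend=0.0)
S_before = np.exp(-(H - dH))  # S just before each event time

exact_mass = S_before - S     # exact jump of F = 1 - S
approx_mass = dH * S_before   # discrete analogue of (dH/dt) * S(t)

print(exact_mass)
print(approx_mass)
```

Since 1 - exp(-x) <= x, the approximation always overestimates the exact mass slightly, and the gap shrinks as the increments get smaller.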

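Now, to compute the hazard function h(t) we need the derivative of the cumulative hazard function (CHF) H(t). Since sksurv returns the CHF as a step function known only at discrete event times, it has to be smoothed before it can be differentiated. One option is InterpolatedUnivariateSpline from SciPy: it fits a smooth spline through the CHF points, and the spline's derivative approximates h(t). Keep in mind that this is only an approximation; a cubic interpolating spline is not guaranteed to be monotone, so the estimated hazard (and therefore the PDF) can dip below zero between points. Here's a slight modification of the code from the question: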

# Import the necessary libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sksurv.linear_model import CoxPHSurvivalAnalysis
from scipy.interpolate import InterpolatedUnivariateSpline

# Define a function to compute the probability density function (pdf) 
# from the cumulative hazard function (chf) and survival function (sf).
def compute_pdf_from_chf_and_sf(chf, sf):
    # The hazard function is the derivative of the cumulative hazard function.
    # InterpolatedUnivariateSpline (cubic by default, so it needs at least 4 event
    # times) passes a smooth curve through every CHF point; differentiating that
    # curve gives an approximation of the hazard function.
    chf_spline = InterpolatedUnivariateSpline(chf.x, chf(chf.x))
    hazard_function = chf_spline.derivative()(chf.x)
    
    # The pdf can be computed using the formula: pdf(t) = hazard(t) * survival(t)
    pdf = hazard_function * sf(chf.x)
    return chf.x, pdf

# Generate random data for demonstration purposes
# Here, we create a random dataset with one covariate and survival times.

np.random.seed(42)  # Fix the seed for reproducibility
data = np.random.randint(5, 30, size=10)
X_train = pd.DataFrame(data, columns=['covariate'])
# Each time is copied into both fields: 'status' is True for every nonzero time
y_train = np.array(np.random.randint(0, 100, size=10)/100, dtype=[('status', bool), ('target', float)])

# Initialize and fit the Cox Proportional Hazards model
estimator = CoxPHSurvivalAnalysis()
estimator.fit(X_train, y_train)

# Predict for new data points
X_test = pd.DataFrame({'covariate': [12, 2]})
cumulative_hazard_functions = estimator.predict_cumulative_hazard_function(X_test)
survival_functions = estimator.predict_survival_function(X_test)

# Plot the Cumulative Hazard, Survival, and PDF side by side in a single row
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
for chf, sf in zip(cumulative_hazard_functions, survival_functions):
    # Compute the pdf using our defined function
    times, pdf_values = compute_pdf_from_chf_and_sf(chf, sf)
    
    # Plotting the cumulative hazard function
    axes[0].step(chf.x, chf(chf.x), where='post')
    
    # Plotting the survival function
    axes[1].step(sf.x, sf(sf.x), where='post')
    
    # Plotting the probability density function
    axes[2].step(times, pdf_values, where='post')

# Setting titles for each subplot
axes[0].set_title('Cumulative Hazard Functions')
axes[1].set_title('Survival Functions')
axes[2].set_title('Probability Density Functions')

# Display the plots
plt.tight_layout()
plt.show()


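which results in:

[Figure: cumulative hazard functions, survival functions, and probability density functions (PDF) for the two test samples]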
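Whether the spline can overshoot is easy to check on your own data. The sketch below uses made-up, non-decreasing CHF values and compares the cubic spline's derivative with that of SciPy's PchipInterpolator. PCHIP is a swapped-in alternative, not what the answer uses: it preserves monotonicity of the data, so for a non-decreasing CHF its derivative (the hazard estimate) should stay non-negative, while the cubic spline's derivative may not.

```python
import numpy as np
from scipy.interpolate import InterpolatedUnivariateSpline, PchipInterpolator

# Made-up cumulative hazard values at event times (non-decreasing, like a CHF)
t = np.array([0.05, 0.10, 0.12, 0.40, 0.45, 0.90])
H = np.array([0.05, 0.20, 0.60, 0.70, 1.50, 1.60])

grid = np.linspace(t[0], t[-1], 200)

# Cubic interpolating spline (k=3 by default): smooth, but not shape-preserving
spline_hazard = InterpolatedUnivariateSpline(t, H).derivative()(grid)

# PCHIP keeps the interpolant monotone between data points, so its derivative
# should not go negative for non-decreasing data
pchip_hazard = PchipInterpolator(t, H).derivative()(grid)

print("min spline hazard:", spline_hazard.min())
print("min PCHIP hazard:", pchip_hazard.min())
```

The trade-off is that PCHIP is only once continuously differentiable, so its hazard estimate is less smooth than the spline's.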

Reference: Machine Learning for Survival Analysis: A Survey