Source code for anomsmith.workflows.eval.survival_metrics

"""Evaluation metrics for survival analysis models.

Provides metrics for assessing survival model performance, including
concordance index (C-index) for ranking ability.
"""

import logging
from typing import TYPE_CHECKING

import numpy as np
import pandas as pd

from anomsmith.constants import DEFAULT_SURVIVAL_PROBABILITY_AT_MEDIAN_TTF

try:
    from lifelines.utils import concordance_index as lifelines_c_index

    LIFELINES_AVAILABLE = True
except ImportError:
    LIFELINES_AVAILABLE = False
    lifelines_c_index = None  # type: ignore

try:
    from pycox.evaluation import EvalSurv

    PYCOX_AVAILABLE = True
except ImportError:
    PYCOX_AVAILABLE = False
    EvalSurv = None  # type: ignore

if TYPE_CHECKING:
    try:
        from timesmith.typing import SeriesLike
    except ImportError:
        SeriesLike = None

logger = logging.getLogger(__name__)


def _median_survival_time_from_step_function(col: pd.Series):
    """First time t where S(t) <= p (default median survival read-off)."""
    p = DEFAULT_SURVIVAL_PROBABILITY_AT_MEDIAN_TTF
    return col.index[col <= p][0] if (col <= p).any() else col.index[-1]


[docs] def compute_concordance_index( durations: np.ndarray | pd.Series, risk_scores: np.ndarray | pd.Series, events: np.ndarray | pd.Series | None = None, ) -> float: """Compute concordance index (C-index) for survival model evaluation. C-index measures how well a model ranks survival times. A score of 0.5 implies random ordering; 1.0 implies perfect prediction. Uses lifelines if available, otherwise computes manually. Args: durations: Actual time-to-event values (n_samples,) risk_scores: Predicted risk scores (n_samples,) - higher = higher risk events: Event indicators (1 = event occurred, 0 = censored), optional Returns: C-index between 0.0 and 1.0 Examples: >>> c_index = compute_concordance_index(true_durations, risk_scores, events) >>> print(f"C-index: {c_index:.3f}") """ durations_array = np.asarray(durations) risk_array = np.asarray(risk_scores) if events is None: events_array = np.ones(len(durations_array)) # All events observed else: events_array = np.asarray(events) # Use lifelines if available (more robust) if LIFELINES_AVAILABLE: c_index = lifelines_c_index(durations_array, -risk_array, events_array) # type: ignore return float(c_index) # Manual computation (simplified version) # C-index = proportion of concordant pairs n_samples = len(durations_array) if n_samples < 2: return 0.5 concordant = 0 comparable = 0 for i in range(n_samples): if events_array[i] == 0: # Skip censored observations continue for j in range(i + 1, n_samples): # Two observations are comparable if: # - Both have events, or # - One has event and the other's duration > event duration if events_array[j] == 1: # Both have events - always comparable comparable += 1 if ( durations_array[i] < durations_array[j] and risk_array[i] > risk_array[j] ) or ( durations_array[i] > durations_array[j] and risk_array[i] < risk_array[j] ): concordant += 1 elif durations_array[j] > durations_array[i]: # j is censored but after i's event - comparable comparable += 1 if risk_array[i] > risk_array[j]: concordant += 1 if comparable == 0: return 0.5 c_index = concordant / comparable return float(c_index)
[docs] def evaluate_survival_model( surv_df: pd.DataFrame, durations: np.ndarray | pd.Series, events: np.ndarray | pd.Series | None = None, risk_scores: np.ndarray | pd.Series | None = None, ) -> dict[str, float]: """Evaluate survival model performance. Computes comprehensive metrics for survival model evaluation. Args: surv_df: Survival function DataFrame (rows = time points, cols = samples) durations: Actual time-to-event values (n_samples,) events: Event indicators (1 = event occurred, 0 = censored), optional risk_scores: Optional risk scores for C-index (if None, computed from survival) Returns: Dictionary with evaluation metrics: - c_index: Concordance index - mean_absolute_error: Mean absolute error in predicted vs actual durations - median_survival_error: Error in median survival predictions Examples: >>> surv_df = model.predict_survival_function(X_test) >>> metrics = evaluate_survival_model(surv_df, durations_test, events_test) >>> print(f"C-index: {metrics['c_index']:.3f}") """ durations_array = np.asarray(durations) events_array = np.asarray(events) if events is not None else np.ones(len(durations)) # Compute risk scores from survival if not provided if risk_scores is None: # Use negative median survival time as risk score median_survival = surv_df.apply(_median_survival_time_from_step_function) risk_scores_array = -median_survival.values else: risk_scores_array = np.asarray(risk_scores) # Compute C-index c_index = compute_concordance_index( durations_array, risk_scores_array, events_array ) # Compute MAE (using median survival as prediction) median_survival = surv_df.apply(_median_survival_time_from_step_function) mae = np.mean(np.abs(median_survival.values - durations_array)) # Use pycox evaluation if available (more comprehensive) if PYCOX_AVAILABLE: try: eval_surv = EvalSurv( # type: ignore surv_df, durations_array, events_array, censor_surv="km", # Kaplan-Meier for censoring ) c_index_td = eval_surv.concordance_td() # Time-dependent C-index return { "c_index": float(c_index), "c_index_td": float(c_index_td), # Time-dependent version "mean_absolute_error": float(mae), "median_survival_error": float(mae), } except Exception as e: logger.warning(f"Error using pycox evaluation: {e}, using manual metrics") return { "c_index": float(c_index), "mean_absolute_error": float(mae), "median_survival_error": float(mae), }