"""Survival analysis workflows for predictive maintenance.
Integrates survival models with anomsmith's health state and decision
policy framework for comprehensive predictive maintenance.
"""
import logging
from typing import TYPE_CHECKING
import numpy as np
import pandas as pd
from anomsmith.constants import (
DEFAULT_RUL_HEALTHY_THRESHOLD,
DEFAULT_RUL_WARNING_THRESHOLD,
DEFAULT_SURVIVAL_PROBABILITY_AT_MEDIAN_TTF,
)
from anomsmith.objects.health_state import HealthStateView
from anomsmith.primitives.survival.cox import CoxSurvivalModel
from anomsmith.workflows.eval.survival_metrics import (
evaluate_survival_model,
)
if TYPE_CHECKING:
try:
from timesmith.typing import SeriesLike
except ImportError:
SeriesLike = None
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
[docs]
def predict_rul_from_survival(
    model: CoxSurvivalModel,
    X: np.ndarray | pd.DataFrame,
    threshold: float = DEFAULT_SURVIVAL_PROBABILITY_AT_MEDIAN_TTF,
    index: pd.Index | None = None,
) -> pd.Series:
    """Predict Remaining Useful Life (RUL) from a survival model.

    The predicted RUL of each row is the value returned by
    ``model.predict_time_to_failure(X, threshold=threshold)``, i.e. the
    survival time at which the survival probability equals ``threshold``.

    Args:
        model: Fitted survival model.
        X: Feature matrix (n_samples, n_features).
        threshold: Survival probability threshold used to read the
            time-to-failure off the survival curve (defaults to
            ``DEFAULT_SURVIVAL_PROBABILITY_AT_MEDIAN_TTF``).
        index: Optional row index for the returned Series (defaults to
            ``X.index`` for DataFrame inputs, else a
            :class:`pandas.RangeIndex`).

    Returns:
        Series named ``"rul"`` holding one predicted RUL value per row of
        ``X``, in the order produced by the model.

    Examples:
        >>> rul_predictions = predict_rul_from_survival(survival_model, X_test)
    """
    logger.info("Predicting RUL from survival model (threshold=%s)", threshold)

    # Drop any index the model attached to its output so the values are
    # assigned positionally. Handing pandas a Series together with ``index=``
    # would reindex by label and silently yield NaN for non-matching labels.
    rul_values = np.asarray(model.predict_time_to_failure(X, threshold=threshold))

    if index is None:
        if isinstance(X, pd.DataFrame):
            index = X.index
        else:
            index = pd.RangeIndex(start=0, stop=len(rul_values))

    return pd.Series(rul_values, index=index, name="rul")
[docs]
def predict_health_states_from_survival(
    model: CoxSurvivalModel,
    X: np.ndarray | pd.DataFrame,
    healthy_threshold: float = DEFAULT_RUL_HEALTHY_THRESHOLD,
    warning_threshold: float = DEFAULT_RUL_WARNING_THRESHOLD,
    threshold: float = DEFAULT_SURVIVAL_PROBABILITY_AT_MEDIAN_TTF,
    index: pd.Index | None = None,
) -> HealthStateView:
    """Predict health states from a survival model.

    Converts survival model predictions to health states by:

    1. Predicting RUL with :func:`predict_rul_from_survival`.
    2. Discretizing that RUL with ``discretize_rul_to_health_states``.

    Args:
        model: Fitted survival model.
        X: Feature matrix (n_samples, n_features).
        healthy_threshold: RUL threshold for the Healthy state (defaults to
            ``DEFAULT_RUL_HEALTHY_THRESHOLD``).
        warning_threshold: RUL threshold for the Warning state (defaults to
            ``DEFAULT_RUL_WARNING_THRESHOLD``).
        threshold: Survival probability threshold used when predicting RUL
            (defaults to ``DEFAULT_SURVIVAL_PROBABILITY_AT_MEDIAN_TTF``).
        index: Optional row index for the intermediate RUL Series. When
            omitted, ``X.index`` is used for DataFrame inputs and a
            :class:`pandas.RangeIndex` otherwise.

    Returns:
        HealthStateView with predicted health states.

    Examples:
        >>> health_states = predict_health_states_from_survival(
        ...     model, X_test, healthy_threshold=30, warning_threshold=10
        ... )
    """
    # Imported here rather than at module level, as in the original code.
    from anomsmith.primitives.health_state.discretize import (
        discretize_rul_to_health_states,
    )

    # Forward ``index`` so array inputs can carry caller-supplied row labels,
    # matching what predict_rul_from_survival already supports.
    rul_series = predict_rul_from_survival(
        model, X, threshold=threshold, index=index
    )

    health_states = discretize_rul_to_health_states(
        rul_series,
        healthy_threshold=healthy_threshold,
        warning_threshold=warning_threshold,
    )

    logger.info(
        "Predicted health states from survival model: %d samples",
        len(health_states.states),
    )
    return health_states
[docs]
def fit_survival_model_for_maintenance(
    X: np.ndarray | pd.DataFrame,
    durations: np.ndarray | pd.Series,
    events: np.ndarray | pd.Series | None = None,
    model_type: str = "logistic_hazard",
    **model_kwargs,
) -> CoxSurvivalModel:
    """Fit a survival model for predictive maintenance.

    Convenience function that builds the model selected by ``model_type``,
    forwards ``model_kwargs`` to its constructor, and fits it on the data.

    Args:
        X: Feature matrix (n_samples, n_features) - sensor readings.
        durations: Time-to-failure values (n_samples,).
        events: Event indicators (1 = failure, 0 = censored), optional.
        model_type: One of ``'cox'`` (alias ``'lifelines'``),
            ``'logistic_hazard'`` or ``'deepsurv'``.
        **model_kwargs: Additional parameters passed to the model constructor.

    Returns:
        Fitted survival model.

    Raises:
        ImportError: If importing or constructing the selected model raises
            ``ImportError`` (missing optional backend). The original error
            is chained as the cause.
        ValueError: If ``model_type`` is not one of the supported names.

    Examples:
        >>> model = fit_survival_model_for_maintenance(
        ...     X_train, durations_train, events_train,
        ...     model_type="logistic_hazard", n_bins=50
        ... )
    """
    logger.info("Fitting %s survival model for predictive maintenance", model_type)

    # Backends are imported lazily so that a missing optional dependency only
    # fails when that model type is requested. Construction stays inside the
    # ``try`` so an ImportError raised by the constructor gets the same hint.
    if model_type in ("cox", "lifelines"):
        try:
            from anomsmith.primitives.survival.lifelines_cox import LifelinesCoxModel

            model = LifelinesCoxModel(**model_kwargs)
        except ImportError as err:
            raise ImportError(
                "lifelines is required for Cox model. Install with: pip install lifelines"
            ) from err
    elif model_type == "logistic_hazard":
        try:
            from anomsmith.primitives.survival.neural import LogisticHazardModel

            model = LogisticHazardModel(**model_kwargs)
        except ImportError as err:
            raise ImportError(
                "pycox is required for LogisticHazard. Install with: pip install pycox"
            ) from err
    elif model_type == "deepsurv":
        try:
            from anomsmith.primitives.survival.neural import DeepSurvModel

            model = DeepSurvModel(**model_kwargs)
        except ImportError as err:
            raise ImportError(
                "pycox is required for DeepSurv. Install with: pip install pycox"
            ) from err
    else:
        raise ValueError(
            f"Unknown model_type: {model_type}. "
            "Choose from: 'cox', 'logistic_hazard', 'deepsurv'"
        )

    model.fit(X, durations, events)

    # Derive the sizes via np.shape so the summary log cannot raise after a
    # successful fit when X is 1-D or a plain sequence without ``.shape``.
    shape = np.shape(X)
    n_samples = shape[0] if shape else 0
    n_features = shape[1] if len(shape) > 1 else 1
    logger.info(
        "Fitted %s model: %s samples, %s features", model_type, n_samples, n_features
    )
    return model
[docs]
def compare_survival_models(
    models: dict[str, CoxSurvivalModel],
    X_test: np.ndarray | pd.DataFrame,
    durations_test: np.ndarray | pd.Series,
    events_test: np.ndarray | pd.Series | None = None,
) -> pd.DataFrame:
    """Compare multiple survival models.

    Each model predicts a survival function and risk scores for ``X_test``;
    both are passed to ``evaluate_survival_model`` and the resulting metrics
    are collected into one row per model.

    Evaluation is best-effort: if a model raises while predicting or being
    evaluated, the error is logged and that model gets a placeholder row
    instead of aborting the whole comparison.

    Args:
        models: Dictionary mapping model names to fitted CoxSurvivalModel
            instances.
        X_test: Test feature matrix.
        durations_test: Test time-to-event values.
        events_test: Test event indicators, optional.

    Returns:
        DataFrame with exactly one row per model. Successful rows hold the
        metrics returned by ``evaluate_survival_model`` plus a ``"model"``
        column; failed rows hold ``"model"``, ``"c_index"`` (NaN) and
        ``"error"`` (the exception message).

    Examples:
        >>> models = {
        ...     "CoxPH": cox_model,
        ...     "LogisticHazard": lhaz_model,
        ...     "DeepSurv": deepsurv_model
        ... }
        >>> comparison = compare_survival_models(models, X_test, durations_test, events_test)
        >>> print(comparison)
    """
    logger.info("Comparing %d survival models", len(models))

    results = []
    for model_name, model in models.items():
        try:
            surv_df = model.predict_survival_function(X_test)
            risk_scores = model.predict_risk_score(X_test)
            metrics = evaluate_survival_model(
                surv_df, durations_test, events_test, risk_scores=risk_scores
            )
            # Copy so the dict returned by the evaluator is not mutated.
            row = dict(metrics)
            row["model"] = model_name
        except Exception as exc:
            # Deliberate broad catch: one failing model must not prevent the
            # remaining models from being compared.
            logger.exception("Error evaluating %s: %s", model_name, exc)
            results.append({"model": model_name, "c_index": np.nan, "error": str(exc)})
        else:
            # Appending and logging happen outside the ``try`` so a problem
            # while formatting the log line can never add a second (error)
            # row for a model that was evaluated successfully.
            results.append(row)
            c_index = row.get("c_index", np.nan)
            try:
                c_index_text = f"{c_index:.3f}"
            except (TypeError, ValueError):
                c_index_text = str(c_index)
            logger.info("%s: C-index = %s", model_name, c_index_text)

    return pd.DataFrame(results)