"""Time series drift detection using forecasting models.
Detects drift by comparing actual values to forecasts from statistical models.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Any
import numpy as np
import pandas as pd
try:
import statsmodels.api as sm
from statsmodels.tsa.arima.model import ARIMA
STATSMODELS_AVAILABLE = True
except ImportError:
STATSMODELS_AVAILABLE = False
ARIMA = None # type: ignore
sm = None # type: ignore
from anomsmith.constants import (
DEFAULT_DRIFT_DETECTION_STDDEV_THRESHOLD,
NUMERICAL_EPSILON,
)
from anomsmith.objects.views import LabelView, ScoreView
from anomsmith.primitives.base import BaseDetector
if TYPE_CHECKING:
try:
from timesmith.typing import PanelLike, SeriesLike
except ImportError:
PanelLike = Any # type: ignore[misc,assignment]
SeriesLike = Any # type: ignore[misc,assignment]
logger = logging.getLogger(__name__)
[docs]
class ARIMADriftDetector(BaseDetector):
    """ARIMA-based drift detector for univariate time series.

    An ARIMA model is fitted on a training series. A series is then scored by
    comparing its values with the fitted model's one-step-ahead predictions;
    points whose normalized absolute residual exceeds ``threshold_std`` are
    labelled as drift.

    Args:
        order: ARIMA order (p, d, q). Default (1, 1, 1)
        threshold_std: Number of residual standard deviations above which a
            point is flagged as drift
        random_state: Random state for reproducibility (not used, kept for
            compatibility)
    """

    def __init__(
        self,
        order: tuple[int, int, int] = (1, 1, 1),
        threshold_std: float = DEFAULT_DRIFT_DETECTION_STDDEV_THRESHOLD,
        random_state: int | None = None,
    ) -> None:
        """Initialize ARIMADriftDetector.

        Args:
            order: ARIMA order (p, d, q)
            threshold_std: Number of standard deviations for drift threshold
            random_state: Random state (not used, kept for compatibility)

        Raises:
            ImportError: If statsmodels is not installed.
        """
        if not STATSMODELS_AVAILABLE:
            raise ImportError(
                "statsmodels is required for ARIMADriftDetector. "
                "Install with: pip install statsmodels"
            )
        self.order = order
        self.threshold_std = threshold_std
        self.random_state = random_state
        # Populated by fit(); the trailing underscore marks fitted state.
        self.model_: ARIMA | None = None  # type: ignore
        self.fitted_model_: Any | None = None  # type: ignore
        self.residual_std_: float = 0.0
        super().__init__(
            order=order, threshold_std=threshold_std, random_state=random_state
        )
        self._fitted = False

    @staticmethod
    def _coerce_univariate(
        y: np.ndarray | pd.Series | SeriesLike,
    ) -> np.ndarray:
        """Convert ``y`` to a flat array, rejecting multi-column input.

        Args:
            y: Series-like input

        Returns:
            1D array of the values in ``y``

        Raises:
            ValueError: If ``y`` has more than one column.
        """
        values = y.values if isinstance(y, pd.Series) else np.asarray(y)
        if values.ndim > 1:
            if values.shape[1] > 1:
                raise ValueError(
                    "ARIMADriftDetector only supports univariate time series."
                )
            values = values.flatten()
        return values

    def fit(
        self,
        y: np.ndarray | pd.Series | SeriesLike,
        X: np.ndarray | pd.DataFrame | PanelLike | None = None,
    ) -> ARIMADriftDetector:
        """Fit the ARIMA model on training data.

        Args:
            y: Training time series (1D)
            X: Optional features (not used for ARIMA)

        Returns:
            Self for method chaining

        Raises:
            ValueError: If ``y`` has more than one column.
            Exception: Any error raised while fitting is logged and re-raised.
        """
        values = self._coerce_univariate(y)
        try:
            model = ARIMA(values, order=self.order)  # type: ignore
            fitted_model = model.fit()
            residuals = np.asarray(fitted_model.resid, dtype=float)
            # With d > 0 the first d residuals are measured against the
            # model's diffuse initial state and are on the scale of the raw
            # series rather than of the innovations; including them would
            # inflate the std and shrink every score. Skip them whenever
            # enough residuals remain.
            burn_in = int(self.order[1])
            if residuals.size > burn_in:
                residuals = residuals[burn_in:]
            residual_std = float(np.std(residuals))
        except Exception as e:
            logger.exception("Error fitting ARIMA model: %s", e)
            raise
        # Assign only after everything succeeded so a failed refit does not
        # leave a new model paired with stale fitted results.
        self.model_ = model
        self.fitted_model_ = fitted_model
        self.residual_std_ = residual_std
        self._fitted = True
        logger.debug(
            "Fitted ARIMADriftDetector: residual_std=%.4f", self.residual_std_
        )
        return self

    def score(self, y: np.ndarray | pd.Series | SeriesLike) -> ScoreView:
        """Score drift using ARIMA residuals.

        Each value of ``y`` from position 1 onwards is compared with the
        fitted model's non-dynamic prediction for the same position. The
        score is the absolute difference divided by the training residual
        standard deviation. Position 0 has no prediction and scores 0.0.

        NOTE(review): predictions are taken from the model fitted in
        ``fit()`` and are indexed from the start of the training series, so
        ``y`` is assumed to be aligned with it - confirm against callers.

        Args:
            y: Time series to score

        Returns:
            ScoreView with drift scores (normalized residual magnitudes),
            one per element of ``y``

        Raises:
            ValueError: If ``y`` has more than one column or the model has
                not been fitted.
        """
        self._check_fitted()
        values = self._coerce_univariate(y)
        if isinstance(y, pd.Series):
            index = y.index
        else:
            # Build the index from the coerced values so it always has the
            # same length as the scores.
            index = pd.RangeIndex(start=0, stop=len(values))
        if self.fitted_model_ is None:
            raise ValueError("Model must be fitted before scoring. Call fit() first.")

        n_obs = len(values)
        if n_obs < 2:
            # Nothing to compare against a prediction: empty input gives
            # empty scores, a single point gives the 0.0 padding value.
            return ScoreView(index=index, scores=np.zeros(n_obs, dtype=float))

        try:
            # ``end`` is inclusive in statsmodels, so positions 1..n_obs-1
            # give exactly n_obs - 1 predictions, matching values[1:].
            forecast = np.asarray(
                self.fitted_model_.predict(
                    start=1, end=n_obs - 1, dynamic=False
                )  # type: ignore
            )
            # Residuals are actual minus predicted.
            residuals = values[1:] - forecast
            # Epsilon guards against division by a zero residual std.
            scores = np.abs(residuals) / (self.residual_std_ + NUMERICAL_EPSILON)
            # Pad the first position, which has no prediction.
            scores = np.concatenate([[0.0], scores])
        except Exception as e:
            logger.exception("Error generating ARIMA forecast: %s", e)
            raise
        return ScoreView(index=index, scores=scores)

    def predict(self, y: np.ndarray | pd.Series | SeriesLike) -> LabelView:
        """Predict drift labels.

        Args:
            y: Time series to detect drift in

        Returns:
            LabelView with binary labels (1 = drift, 0 = normal); a point is
            labelled 1 when its score is strictly greater than
            ``threshold_std``
        """
        score_view = self.score(y)
        labels = (score_view.scores > self.threshold_std).astype(int)
        return LabelView(index=score_view.index, labels=labels)