| | """ |
| | Self-contained Hugging Face wrapper for Sybil lung cancer risk prediction model. |
| | This version works directly from HF without requiring external Sybil package. |
| | """ |
| |
|
| | import os |
| | import json |
| | import sys |
| | import torch |
| | import numpy as np |
| | from typing import List, Dict, Optional |
| | from dataclasses import dataclass |
| | from transformers.modeling_outputs import BaseModelOutput |
| | from safetensors.torch import load_file |
| |
|
| | |
# Put the directory containing this file on sys.path so that the sibling
# modules can be imported by bare module name in the fallback branch below.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Prefer package-relative imports. If they raise ImportError (for example when
# this module is not loaded as part of a package), import the same three names
# as top-level modules, which works because of the sys.path entry added above.
try:
    from .configuration_sybil import SybilConfig
    from .modeling_sybil import SybilForRiskPrediction
    from .image_processing_sybil import SybilImageProcessor
except ImportError:
    from configuration_sybil import SybilConfig
    from modeling_sybil import SybilForRiskPrediction
    from image_processing_sybil import SybilImageProcessor
| |
|
| |
|
@dataclass
class SybilOutput(BaseModelOutput):
    """
    Output class for Sybil model predictions.

    Args:
        risk_scores: Risk scores for each year (1-6 years by default).
        attentions: Optional attention maps if requested, otherwise ``None``.
    """
    # Both fields default to None.
    risk_scores: torch.FloatTensor = None
    # SybilHFWrapper.predict fills this with a dict holding the key
    # "image_attention" when attention maps were requested and produced.
    attentions: Optional[Dict] = None
| |
|
| |
|
class SybilHFWrapper:
    """
    Hugging Face wrapper for Sybil ensemble model.

    Provides a simple interface for lung cancer risk prediction from CT scans:
    the DICOM files are preprocessed, every model of the ensemble is run, the
    per-model predictions are averaged and the average is calibrated.
    """

    # Calibrator file locations relative to ``model_dir``, in lookup order.
    _CALIBRATOR_CANDIDATES = (
        ("checkpoints", "sybil_ensemble_simple_calibrator.json"),
        ("calibrator_data.json",),
    )

    # Keys a calibrator entry must provide; entries lacking any are skipped.
    _CALIBRATOR_REQUIRED_KEYS = ("coef", "intercept", "x0", "y0")

    def __init__(self, config: "Optional[SybilConfig]" = None, model_dir: Optional[str] = None):
        """
        Initialize the Sybil model ensemble.

        Args:
            config: Model configuration (a default ``SybilConfig`` is used if
                not provided).
            model_dir: Directory containing model files (defaults to the
                directory of this source file).

        Raises:
            ValueError: If no ensemble model could be loaded.
        """
        self.config = config if config is not None else SybilConfig()
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        if model_dir is None:
            model_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_dir = model_dir

        self.image_processor = SybilImageProcessor()
        # The calibrator is optional; an empty dict disables calibration.
        self.calibrator = self._load_calibrator()
        self.models = self._load_ensemble_models()

    def _load_calibrator(self) -> Dict:
        """
        Load ensemble calibrator data.

        The candidate files are tried in order and the first one found is
        parsed as JSON.

        Returns:
            The parsed calibrator data, or an empty dict if no calibrator
            file exists under ``self.model_dir``.
        """
        for relative_parts in self._CALIBRATOR_CANDIDATES:
            calibrator_path = os.path.join(self.model_dir, *relative_parts)
            if os.path.isfile(calibrator_path):
                # Explicit encoding so parsing does not depend on the locale.
                with open(calibrator_path, "r", encoding="utf-8") as f:
                    return json.load(f)
        return {}

    def _load_ensemble_models(self) -> List[torch.nn.Module]:
        """
        Load all models in the ensemble from original checkpoints.

        Note: We load from .ckpt files instead of safetensors because the safetensors
        were created with the wrong CumulativeProbabilityLayer architecture.

        Every ``*.ckpt`` file in ``<model_dir>/checkpoints`` is loaded in
        sorted order. A checkpoint that fails to load is reported and skipped
        so that the remaining ensemble members can still be used.

        Returns:
            The loaded models, moved to ``self.device`` and set to eval mode.

        Raises:
            ValueError: If not a single checkpoint could be loaded.
        """
        import glob as glob_module

        checkpoints_dir = os.path.join(self.model_dir, "checkpoints")
        checkpoint_files = sorted(glob_module.glob(os.path.join(checkpoints_dir, "*.ckpt")))
        print(f"Found {len(checkpoint_files)} checkpoint files")

        key_prefix = 'model.'
        models = []
        for checkpoint_path in checkpoint_files:
            checkpoint_name = os.path.basename(checkpoint_path)
            try:
                model = SybilForRiskPrediction(self.config)
                # NOTE(security): weights_only=False makes torch.load unpickle
                # arbitrary Python objects. Only load checkpoints that come
                # from a trusted source.
                checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

                # The weights may be wrapped under a 'state_dict' key.
                state_dict = checkpoint['state_dict'] if 'state_dict' in checkpoint else checkpoint

                # Drop a leading 'model.' from parameter names.
                cleaned_state_dict = {
                    (k[len(key_prefix):] if k.startswith(key_prefix) else k): v
                    for k, v in state_dict.items()
                }

                # strict=False tolerates mismatching keys; report them so that
                # a partially initialised model does not go unnoticed.
                load_result = model.load_state_dict(cleaned_state_dict, strict=False)
                missing_keys = getattr(load_result, "missing_keys", None) or []
                unexpected_keys = getattr(load_result, "unexpected_keys", None) or []
                if missing_keys or unexpected_keys:
                    print(
                        f" Warning: {checkpoint_name} has {len(missing_keys)} missing and "
                        f"{len(unexpected_keys)} unexpected state_dict keys"
                    )

                model.to(self.device)
                model.eval()
                models.append(model)
                print(f" Loaded model from {checkpoint_name}")
            except Exception as e:
                # Best effort: one broken checkpoint must not disable the rest.
                print(f" Warning: Could not load {checkpoint_name}: {e}")
                continue

        if not models:
            raise ValueError("No models could be loaded from the ensemble. Please ensure model files are present.")

        print(f"Loaded {len(models)} models in ensemble")
        return models

    def _calibrate_with_entry(self, raw_scores: np.ndarray, cal_data) -> Optional[np.ndarray]:
        """
        Apply one calibrator entry to the raw scores of a single year.

        The entry applies: linear transform -> clip -> isotonic regression
        (``np.interp``).

        Args:
            raw_scores: Raw scores of one year (one value per sample).
            cal_data: Calibrator entry; expected to be a dict with the keys
                "coef", "intercept", "x0" and "y0" and optionally "x_min"
                and "x_max".

        Returns:
            The calibrated scores, or ``None`` if the entry is not a dict or
            lacks a required key.
        """
        if not isinstance(cal_data, dict):
            return None
        if any(key not in cal_data for key in self._CALIBRATOR_REQUIRED_KEYS):
            return None

        coef = np.array(cal_data["coef"])
        intercept = np.array(cal_data["intercept"])
        x0 = np.array(cal_data["x0"])
        y0 = np.array(cal_data["y0"])

        # A missing or null bound means "unbounded" on that side.
        x_min = cal_data.get("x_min")
        x_max = cal_data.get("x_max")
        x_min = -np.inf if x_min is None else x_min
        x_max = np.inf if x_max is None else x_max

        transformed = (raw_scores.reshape(-1, 1) @ coef + intercept).flatten()
        transformed = np.clip(transformed, x_min, x_max)
        return np.interp(transformed, x0, y0)

    def _apply_calibration(self, scores: np.ndarray) -> np.ndarray:
        """
        Apply complete isotonic regression calibration matching the original Sybil implementation.

        This method applies the same calibration as the original SimpleClassifierGroup.predict_proba:
        1. For each year, apply each calibrator in the ensemble
        2. Each calibrator applies: linear transform -> clip -> isotonic regression (np.interp)
        3. Average predictions from all calibrators

        A year keeps its raw scores when the calibrator has no "Year<N>" key
        for it, when that value is not a non-empty list, or when none of its
        entries is usable.

        Args:
            scores: Raw risk scores from the model (shape: [batch_size, num_years])

        Returns:
            Calibrated risk scores (shape: [batch_size, num_years]). If no
            calibrator is loaded, ``scores`` is returned unchanged.
        """
        if not self.calibrator:
            return scores

        calibrated_scores = []
        for year in range(scores.shape[1]):
            raw_year_scores = scores[:, year]
            cal_list = self.calibrator.get(f"Year{year + 1}")

            year_predictions = []
            if isinstance(cal_list, list):
                for cal_data in cal_list:
                    calibrated = self._calibrate_with_entry(raw_year_scores, cal_data)
                    if calibrated is not None:
                        year_predictions.append(calibrated)

            if year_predictions:
                calibrated_scores.append(np.mean(year_predictions, axis=0))
            else:
                calibrated_scores.append(raw_year_scores)

        return np.stack(calibrated_scores, axis=1)

    def preprocess_dicom(self, dicom_paths: List[str]) -> torch.Tensor:
        """
        Preprocess DICOM files for model input.

        Args:
            dicom_paths: List of paths to DICOM files

        Returns:
            Preprocessed tensor on ``self.device``, ready for model input
        """
        result = self.image_processor(dicom_paths, file_type="dicom", return_tensors="pt")
        pixel_values = result["pixel_values"]

        # A 4-D result gets a leading dimension of size 1 added
        # (presumably the batch dimension -- confirm against the processor).
        if pixel_values.ndim == 4:
            pixel_values = pixel_values.unsqueeze(0)

        return pixel_values.to(self.device)

    def predict(self, dicom_paths: List[str], return_attentions: bool = False) -> "SybilOutput":
        """
        Run prediction on a CT scan series.

        Every ensemble model is evaluated without gradient tracking; the
        predictions are averaged over the models and then calibrated.

        Args:
            dicom_paths: List of paths to DICOM files for a single CT series
            return_attentions: Whether to return attention maps

        Returns:
            SybilOutput with risk scores and optional attention maps. The
            attention maps are averaged over the models that provided one.

        Raises:
            ValueError: If ``dicom_paths`` is empty.
        """
        if not dicom_paths:
            raise ValueError("dicom_paths must contain at least one DICOM file path")

        pixel_values = self.preprocess_dicom(dicom_paths)

        all_predictions = []
        all_attentions = []

        with torch.no_grad():
            for model in self.models:
                output = model(
                    pixel_values=pixel_values,
                    return_attentions=return_attentions
                )

                # Accept an output object with ``risk_scores``, a tuple whose
                # first element is the prediction, or the bare prediction.
                if hasattr(output, 'risk_scores'):
                    predictions = output.risk_scores
                else:
                    predictions = output[0] if isinstance(output, tuple) else output

                all_predictions.append(predictions.cpu().numpy())

                if return_attentions:
                    # Skip models whose attention is absent or None so that
                    # torch.stack below only ever sees tensors.
                    image_attention = getattr(output, 'image_attention', None)
                    if image_attention is not None:
                        all_attentions.append(image_attention)

        ensemble_pred = np.mean(all_predictions, axis=0)
        calibrated_pred = self._apply_calibration(ensemble_pred)
        risk_scores = torch.from_numpy(calibrated_pred).float()

        attentions = None
        if all_attentions:
            attentions = {"image_attention": torch.stack(all_attentions).mean(dim=0)}

        return SybilOutput(risk_scores=risk_scores, attentions=attentions)

    def __call__(self, dicom_paths: List[str] = None, dicom_series: List[List[str]] = None, **kwargs) -> "SybilOutput":
        """
        Convenience method for prediction.

        ``dicom_series`` takes precedence when both arguments are given. In
        that mode only the risk scores of the individual predictions are
        kept (stacked along a new first dimension); attention maps are not
        returned.

        Args:
            dicom_paths: List of DICOM file paths for a single series
            dicom_series: List of lists of DICOM paths for batch processing
            **kwargs: Additional arguments passed to predict()

        Returns:
            SybilOutput with predictions

        Raises:
            ValueError: If neither argument is provided or ``dicom_series``
                contains no series.
        """
        if dicom_series is not None:
            all_outputs = [self.predict(paths, **kwargs).risk_scores for paths in dicom_series]
            if not all_outputs:
                # torch.stack would otherwise fail with an opaque RuntimeError.
                raise ValueError("dicom_series must contain at least one series")
            return SybilOutput(risk_scores=torch.stack(all_outputs))

        if dicom_paths is not None:
            return self.predict(dicom_paths, **kwargs)

        raise ValueError("Either dicom_paths or dicom_series must be provided")

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        """
        Load model from Hugging Face hub or local path.

        Args:
            pretrained_model_name_or_path: HF model ID or local path. If it is
                a local directory containing a ``checkpoints`` sub-directory,
                the model files are read from there.
            **kwargs: Additional configuration arguments. Recognised keys:
                ``config`` (a ready configuration object, skips loading it)
                and ``model_dir`` (explicit directory containing model files).

        Returns:
            SybilHFWrapper instance
        """
        config = kwargs.pop("config", None)
        if config is None:
            try:
                config = SybilConfig.from_pretrained(pretrained_model_name_or_path)
            except Exception as e:
                # Best effort: fall back to the default configuration, but do
                # not swallow KeyboardInterrupt/SystemExit as a bare except would.
                print(f"Warning: Could not load config from {pretrained_model_name_or_path}: {e}")
                config = SybilConfig()

        model_dir = kwargs.pop("model_dir", None)
        if model_dir is None:
            candidate_dir = str(pretrained_model_name_or_path)
            # Only switch away from the default location when the given path
            # really holds checkpoints; otherwise __init__ uses its default.
            if os.path.isdir(os.path.join(candidate_dir, "checkpoints")):
                model_dir = candidate_dir

        return cls(config=config, model_dir=model_dir)