"""
MONAI WholeBody CT Segmentation - Hugging Face Space
Segments 104 anatomical structures from CT scans using MONAI's SegResNet model
"""
import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import torch
import trimesh
from huggingface_hub import hf_hub_download
from matplotlib.colors import ListedColormap
from monai.inferers import sliding_window_inference
from monai.networks.nets import SegResNet
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    Orientation,
    Spacing,
    ScaleIntensityRange,
    CropForeground,
    Activations,
    AsDiscrete,
)
from skimage import measure

from labels import LABEL_NAMES, get_color_map, get_label_name, get_organ_categories
| # Constants | |
| DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| MODEL_REPO = "MONAI/wholeBody_ct_segmentation" | |
| SPATIAL_SIZE = (96, 96, 96) | |
| PIXDIM = (3.0, 3.0, 3.0) # Low-res model spacing | |
| # Global model variable | |
| model = None | |
def load_model():
    """Download and construct the SegResNet segmentation model (cached).

    The model is built only once; later calls return the cached global
    instance. Weights are fetched from the Hugging Face Hub, with a
    fallback to the full-resolution checkpoint if the low-res file is
    unavailable.

    Returns:
        The SegResNet module in eval mode on DEVICE.
    """
    global model
    if model is not None:
        return model

    print("Downloading model weights...")
    try:
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="models/model_lowres.pt",
        )
    except Exception as e:
        print(f"Failed to download from HF, trying alternative: {e}")
        # Fallback checkpoint from the same repository.
        model_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename="models/model.pt",
        )

    print(f"Loading model from {model_path}...")
    # 105 output channels = background + 104 anatomical structures.
    net = SegResNet(
        blocks_down=[1, 2, 2, 4],
        blocks_up=[1, 1, 1],
        init_filters=32,
        in_channels=1,
        out_channels=105,
        dropout_prob=0.2,
    )

    # Checkpoints may be a raw state dict or a dict wrapping one.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    state = (
        checkpoint["state_dict"]
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint
        else checkpoint
    )
    net.load_state_dict(state)

    net.to(DEVICE)
    net.eval()
    model = net
    print(f"Model loaded successfully on {DEVICE}")
    return model
def get_preprocessing_transforms():
    """Build the MONAI preprocessing pipeline.

    Steps: load NIfTI → ensure channel-first layout → reorient to RAS →
    resample to PIXDIM spacing → window HU to [-1024, 1024] and rescale
    into [0, 1].
    """
    steps = [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Orientation(axcodes="RAS"),
        Spacing(pixdim=PIXDIM, mode="bilinear"),
        # Clip Hounsfield units and normalize intensities for the model.
        ScaleIntensityRange(a_min=-1024, a_max=1024, b_min=0.0, b_max=1.0, clip=True),
    ]
    return Compose(steps)
def get_postprocessing_transforms():
    """Build the MONAI postprocessing pipeline: softmax then per-voxel argmax."""
    steps = [
        Activations(softmax=True),
        AsDiscrete(argmax=True),
    ]
    return Compose(steps)
def run_inference(image_path: str, progress=gr.Progress()):
    """Segment a CT volume from disk.

    Args:
        image_path: path to a NIfTI CT scan.
        progress: Gradio progress tracker (standard Gradio injection pattern).

    Returns:
        (original_data, seg_data): the untouched CT voxel array (for display)
        and the predicted uint8 label volume.
    """
    progress(0.1, desc="Loading model...")
    net = load_model()

    progress(0.2, desc="Preprocessing image...")
    preprocess = get_preprocessing_transforms()
    postprocess = get_postprocessing_transforms()

    # Keep the raw voxel data around for visualization alongside the prediction.
    original_data = nib.load(image_path).get_fdata()
    image = preprocess(image_path).unsqueeze(0).to(DEVICE)  # add batch dim

    progress(0.4, desc="Running segmentation (this may take a few minutes)...")
    with torch.no_grad():
        # Sliding-window inference keeps memory bounded on large volumes.
        outputs = sliding_window_inference(
            image,
            roi_size=SPATIAL_SIZE,
            sw_batch_size=4,
            predictor=net,
            overlap=0.5,
        )

    progress(0.8, desc="Post-processing...")
    seg_data = postprocess(outputs).squeeze().cpu().numpy().astype(np.uint8)

    progress(1.0, desc="Complete!")
    return original_data, seg_data
def generate_3d_mesh(seg_data, step_size=2):
    """Export a GLB surface mesh of all non-background labels.

    Runs marching cubes over the binary foreground mask. A step_size > 1
    coarsens the mesh but speeds generation up substantially, which matters
    on CPU-only Hugging Face Spaces hardware.

    Args:
        seg_data: 3-D integer label volume, or None.
        step_size: marching-cubes stride (larger = faster, coarser).

    Returns:
        Path to a temporary .glb file, or None if the volume is empty or
        mesh generation fails.
    """
    if seg_data is None or np.max(seg_data) == 0:
        return None
    try:
        # Boolean mask of every labeled structure; level=0.5 separates 0 from 1.
        foreground = seg_data > 0
        verts, faces, normals, _ = measure.marching_cubes(
            foreground, level=0.5, step_size=step_size
        )
        surface = trimesh.Trimesh(vertices=verts, faces=faces, vertex_normals=normals)
        # GLB is a compact binary format that gr.Model3D renders directly.
        out = tempfile.NamedTemporaryFile(suffix=".glb", delete=False)
        surface.export(out.name)
        out.close()
        return out.name
    except Exception as e:
        print(f"Error generating 3D mesh: {e}")
        return None
def create_slice_visualization(ct_data, seg_data, axis, slice_idx, alpha=0.5, show_overlay=True):
    """Render one CT slice with an optional colored segmentation overlay.

    Args:
        ct_data: 3-D CT volume (Hounsfield units).
        seg_data: matching 3-D integer label volume, or None.
        axis: "Axial", "Coronal", or anything else for Sagittal.
        slice_idx: slice index along the chosen axis (clamped to valid range).
        alpha: overlay opacity in [0, 1].
        show_overlay: whether to draw the segmentation overlay at all.

    Returns:
        A matplotlib Figure containing the rendered slice.
    """
    # Select the 2-D slice along the requested axis, clamping the index.
    if axis == "Axial":
        slice_idx = max(0, min(slice_idx, ct_data.shape[2] - 1))
        ct_slice = ct_data[:, :, slice_idx]
        seg_slice = seg_data[:, :, slice_idx] if seg_data is not None else None
    elif axis == "Coronal":
        slice_idx = max(0, min(slice_idx, ct_data.shape[1] - 1))
        ct_slice = ct_data[:, slice_idx, :]
        seg_slice = seg_data[:, slice_idx, :] if seg_data is not None else None
    else:  # Sagittal
        slice_idx = max(0, min(slice_idx, ct_data.shape[0] - 1))
        ct_slice = ct_data[slice_idx, :, :]
        seg_slice = seg_data[slice_idx, :, :] if seg_data is not None else None

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))

    # Window the CT to [-1024, 1024] HU and min-max normalize for display.
    ct_normalized = np.clip(ct_slice, -1024, 1024)
    ct_normalized = (ct_normalized - ct_normalized.min()) / (
        ct_normalized.max() - ct_normalized.min() + 1e-8
    )
    ax.imshow(ct_normalized.T, cmap='gray', origin='lower')

    # Overlay segmentation
    if show_overlay and seg_slice is not None and np.any(seg_slice > 0):
        # BUGFIX: the previous code assigned a 4-element row into the color
        # map (implying RGBA) and then concatenated ANOTHER alpha channel onto
        # the indexed array, producing an invalid 5-channel image for imshow.
        # Build RGBA explicitly from the RGB channels + one alpha channel.
        colors = get_color_map() / 255.0  # assumed shape (num_labels, 3 or 4) — TODO confirm
        rgb = colors[:, :3]
        seg_rgb = rgb[seg_slice.astype(int)]
        alpha_channel = np.full((*seg_slice.shape, 1), alpha)
        seg_rgba = np.concatenate([seg_rgb, alpha_channel], axis=-1)
        seg_rgba[seg_slice == 0, 3] = 0  # background stays fully transparent
        ax.imshow(seg_rgba.transpose(1, 0, 2), origin='lower')

    ax.axis('off')
    ax.set_title(f"{axis} View - Slice {slice_idx}")
    plt.tight_layout()
    return fig
def get_detected_structures(seg_data):
    """Return a bulleted, human-readable list of labels present in seg_data.

    Background (label 0) is excluded. Each entry shows the structure name,
    its numeric label, and its voxel count.

    Args:
        seg_data: 3-D integer label volume.

    Returns:
        Newline-joined bullet list, or a placeholder message when no
        foreground labels are present.
    """
    unique_labels = np.unique(seg_data)
    unique_labels = unique_labels[unique_labels > 0]  # Exclude background
    structures = []
    for label in unique_labels:
        name = get_label_name(label)
        # The voxel count was previously computed but never displayed; it is
        # useful context, so include it in the listing.
        count = int(np.sum(seg_data == label))
        structures.append(f"• {name} (Label {label}, {count} voxels)")
    return "\n".join(structures) if structures else "No structures detected"
| # Global state for current visualization | |
| current_ct_data = None | |
| current_seg_data = None | |
def process_upload(file_path, progress=gr.Progress()):
    """Run segmentation on an uploaded NIfTI file and prepare all UI outputs.

    Args:
        file_path: path to the uploaded scan, or None if nothing was uploaded.
        progress: Gradio progress tracker.

    Returns:
        A 6-tuple matching the registered Gradio outputs:
        (figure, structures text, mesh path, axial slider update,
         coronal slider update, sagittal slider update).
    """
    global current_ct_data, current_seg_data
    if file_path is None:
        # BUGFIX: this branch must return 6 values to match the 6 outputs
        # wired to process_btn.click (the original returned only 5, which
        # made the callback fail when no file was uploaded).
        return (
            None,
            "Please upload a NIfTI file",
            None,
            gr.update(),
            gr.update(),
            gr.update(),
        )
    try:
        ct_data, seg_data = run_inference(file_path, progress)
        current_ct_data = ct_data
        current_seg_data = seg_data

        # Start every view at its middle slice.
        mid_axial = ct_data.shape[2] // 2
        mid_coronal = ct_data.shape[1] // 2
        mid_sagittal = ct_data.shape[0] // 2

        fig = create_slice_visualization(ct_data, seg_data, "Axial", mid_axial)
        structures = get_detected_structures(seg_data)
        # Mesh generation can take a few seconds on CPU.
        mesh_path = generate_3d_mesh(seg_data)

        return (
            fig,
            structures,
            mesh_path,
            gr.update(maximum=ct_data.shape[2] - 1, value=mid_axial),
            gr.update(maximum=ct_data.shape[1] - 1, value=mid_coronal),
            gr.update(maximum=ct_data.shape[0] - 1, value=mid_sagittal),
        )
    except Exception as e:
        return None, f"Error processing file: {str(e)}", None, gr.update(), gr.update(), gr.update()
def update_visualization(axis, slice_idx, alpha, show_overlay):
    """Re-render the current slice view after a control change.

    Reads the module-level current_ct_data/current_seg_data set by
    process_upload(). Returns None when no scan has been processed yet.
    """
    if current_ct_data is None:
        return None
    return create_slice_visualization(
        current_ct_data,
        current_seg_data,
        axis,
        int(slice_idx),
        alpha,
        show_overlay,
    )
def load_example(example_name):
    """Resolve a bundled example filename to its full path.

    Looks inside the "examples" directory next to this module. Returns the
    path when the file exists, otherwise None.
    """
    candidate = os.path.join(os.path.dirname(__file__), "examples", example_name)
    return candidate if os.path.exists(candidate) else None
# Create Gradio interface
with gr.Blocks(
    title="MONAI WholeBody CT Segmentation",
    theme=gr.themes.Soft(),
    css="""
    .gradio-container {max-width: 1200px !important}
    .output-image {min-height: 500px}
    """
) as demo:
    gr.Markdown("""
    # 🏥 MONAI WholeBody CT Segmentation
    **Automatic segmentation of 104 anatomical structures from CT scans**
    This application uses MONAI's pre-trained SegResNet model trained on the TotalSegmentator dataset.
    Upload a CT scan in NIfTI format (.nii or .nii.gz) to get started.
    > ⚡ **Note**: Processing may take 1-5 minutes depending on the CT volume size.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            # Input section
            gr.Markdown("### 📤 Upload CT Scan")
            file_input = gr.File(
                label="Upload NIfTI file (.nii, .nii.gz)",
                file_types=[".nii", ".nii.gz", ".gz"],
                type="filepath"
            )
            # Example files
            gr.Markdown("### 📁 Example Files")
            # BUGFIX: this UI is constructed at import time, but the original
            # os.makedirs("examples") only ran under __main__ — i.e. AFTER this
            # block — so os.listdir crashed on startup when the directory was
            # missing. Create it here before listing.
            os.makedirs("examples", exist_ok=True)
            example_files = [
                [os.path.join("examples", f)]
                for f in sorted(os.listdir("examples"))
                if f.endswith(".nii.gz")
            ]
            # Only add the Examples widget when bundled scans exist, to avoid
            # handing gr.Examples an empty list.
            if example_files:
                gr.Examples(
                    examples=example_files,
                    inputs=[file_input],
                    label="Click to load example"
                )
            process_btn = gr.Button("🔬 Run Segmentation", variant="primary", size="lg")
            # Visualization controls
            gr.Markdown("### 🎛️ Visualization Controls")
            view_axis = gr.Radio(
                choices=["Axial", "Coronal", "Sagittal"],
                value="Axial",
                label="View Axis"
            )
            with gr.Row():
                axial_slider = gr.Slider(0, 100, value=50, step=1, label="Axial Slice")
                coronal_slider = gr.Slider(0, 100, value=50, step=1, label="Coronal Slice")
                sagittal_slider = gr.Slider(0, 100, value=50, step=1, label="Sagittal Slice")
            alpha_slider = gr.Slider(0, 1, value=0.5, step=0.1, label="Overlay Opacity")
            show_overlay = gr.Checkbox(value=True, label="Show Segmentation Overlay")
        with gr.Column(scale=2):
            # Output section
            gr.Markdown("### 🖼️ Segmentation Result")
            output_image = gr.Plot(label="CT with Segmentation Overlay")
            gr.Markdown("### 📋 Detected Structures")
            structures_output = gr.Textbox(
                label="Anatomical Structures Found",
                lines=10,
                max_lines=20
            )
            # 3D Model Output
            gr.Markdown("### 🧊 3D View")
            model_3d_output = gr.Model3D(
                label="3D Segmentation Mesh",
                clear_color=[0.0, 0.0, 0.0, 0.0],
                camera_position=(90, 90, 3)
            )
    # Model info section
    with gr.Accordion("ℹ️ Model Information", open=False):
        gr.Markdown("""
        ### About the Model
        This model is based on **SegResNet** architecture from MONAI, trained on the **TotalSegmentator** dataset.
        **Capabilities:**
        - Segments 104 distinct anatomical structures
        - Works on whole-body CT scans
        - Uses 3.0mm isotropic spacing (low-resolution model for faster inference)
        **Segmented Structures include:**
        - **Major Organs**: Liver, Spleen, Kidneys, Pancreas, Gallbladder, Stomach, Bladder
        - **Cardiovascular**: Heart chambers, Aorta, Vena Cava, Portal Vein
        - **Respiratory**: Lung lobes, Trachea
        - **Skeletal**: Vertebrae (C1-L5), Ribs, Hip bones, Femur, Humerus, Scapula
        - **Muscles**: Gluteal muscles, Iliopsoas
        - And many more...
        **References:**
        - [MONAI Model Zoo](https://monai.io/model-zoo.html)
        - [TotalSegmentator Paper](https://pubs.rsna.org/doi/10.1148/ryai.230024)
        """)
    # Event handlers
    process_btn.click(
        fn=process_upload,
        inputs=[file_input],
        outputs=[output_image, structures_output, model_3d_output, axial_slider, coronal_slider, sagittal_slider]
    )
    # Update visualization when view/opacity/overlay controls change. The
    # lambda reads only its own parameters, so there is no late-binding issue
    # with the loop variable.
    for control in [view_axis, alpha_slider, show_overlay]:
        control.change(
            fn=lambda axis, alpha, overlay, ax_s, cor_s, sag_s: update_visualization(
                axis,
                ax_s if axis == "Axial" else (cor_s if axis == "Coronal" else sag_s),
                alpha,
                overlay
            ),
            inputs=[view_axis, alpha_slider, show_overlay, axial_slider, coronal_slider, sagittal_slider],
            outputs=[output_image]
        )
    # Update when sliders change
    axial_slider.change(
        fn=lambda s, alpha, overlay: update_visualization("Axial", s, alpha, overlay),
        inputs=[axial_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )
    coronal_slider.change(
        fn=lambda s, alpha, overlay: update_visualization("Coronal", s, alpha, overlay),
        inputs=[coronal_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )
    sagittal_slider.change(
        fn=lambda s, alpha, overlay: update_visualization("Sagittal", s, alpha, overlay),
        inputs=[sagittal_slider, alpha_slider, show_overlay],
        outputs=[output_image]
    )
if __name__ == "__main__":
    # Make sure the examples folder exists, then start the Gradio app.
    os.makedirs("examples", exist_ok=True)
    demo.launch()