
A Coding Implementation on MONAI for End-to-End 3D Spleen Segmentation Using UNet on Medical CT Volumes

In this tutorial, we build an end-to-end 3D medical image segmentation pipeline using MONAI to segment the spleen in the Medical Segmentation Decathlon Task09 dataset. We work with volumetric CT scans and apply medical imaging transforms such as orientation alignment, voxel-spacing normalization, intensity windowing, foreground cropping, and patch-based sampling. We then train a 3D UNet model for binary organ segmentation. We also use mixed-precision training, DiceCE loss, sliding-window inference, Dice-based validation, and qualitative visualization to understand how the model learns and how its predictions compare with the ground-truth masks. Overall, we move from raw medical volumes to a complete train–validate–visualize segmentation system.

!pip install -q "monai[nibabel,tqdm,matplotlib]==1.5.2" 2>/dev/null
import os, time, glob, tempfile, warnings
import numpy as np
import torch
import matplotlib.pyplot as plt
from torch.amp import autocast, GradScaler
from monai.apps import DecathlonDataset
from monai.data import DataLoader, decollate_batch
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric
from monai.inferers import sliding_window_inference
from monai.utils import set_determinism
from monai.transforms import (
   Compose, LoadImaged, EnsureChannelFirstd, EnsureTyped, Orientationd,
   Spacingd, ScaleIntensityRanged, CropForegroundd, RandCropByPosNegLabeld,
   RandFlipd, RandRotate90d, RandShiftIntensityd, AsDiscrete,
)
warnings.filterwarnings("ignore")

We begin by installing MONAI with the required medical imaging and visualization dependencies. We then import PyTorch, NumPy, Matplotlib, and the main MONAI modules needed for datasets, transforms, model training, metrics, and inference. We also suppress warnings to keep the notebook output clean while we focus on the segmentation workflow.

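QUICK_RUN   = True
device      = torch.device("cuda" if torch.cuda.is_available() else "cpu")
root_dir    = tempfile.mkdtemp()
roi_size    = (96, 96, 96)    # patch size for training crops and sliding-window inference
num_samples = 4               # random patches drawn per volume per iteration
batch_size  = 2
max_epochs  = 15 if QUICK_RUN else 200
val_every   = 3
train_cache = 8 if QUICK_RUN else 24
val_cache   = 2 if QUICK_RUN else 6
set_determinism(seed=0)
print(f"Device: {device} | epochs: {max_epochs} | data dir: {root_dir}")

# Shared deterministic preprocessing for training and validation.
# The voxel spacing and HU window below are assumed values (taken from MONAI's
# public spleen example); adjust them for your own data.
common = [
    LoadImaged(keys=["image", "label"]),
    EnsureChannelFirstd(keys=["image", "label"]),
    Orientationd(keys=["image", "label"], axcodes="RAS"),
    Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ScaleIntensityRanged(keys=["image"], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=["image", "label"], source_key="image"),
]
train_transforms = Compose(common + [
    # Draw num_samples patches of roi_size per volume, balancing spleen-centred
    # (pos) and background-centred (neg) patches; pos=1/neg=1 is assumed.
    RandCropByPosNegLabeld(keys=["image", "label"], label_key="label",
        spatial_size=roi_size, pos=1, neg=1, num_samples=num_samples,
        image_key="image", image_threshold=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=0),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=1),
    RandFlipd(keys=["image", "label"], prob=0.2, spatial_axis=2),
    RandRotate90d(keys=["image", "label"], prob=0.2, max_k=3),
    RandShiftIntensityd(keys=["image"], offsets=0.10, prob=0.5),
    EnsureTyped(keys=["image", "label"]),
])
val_transforms = Compose(common + [EnsureTyped(keys=["image", "label"])])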

We define the main configuration for the tutorial, including the device, dataset directory, patch size, batch size, number of epochs, and cache settings. We then create the preprocessing pipeline for CT volumes: loading images, aligning orientation, resampling to a common voxel spacing, clipping intensities to a soft-tissue window and rescaling them to [0, 1], and cropping to the foreground. Finally, we define the training and validation transforms; the training pipeline additionally applies random patch crops, flips, 90-degree rotations, and intensity shifts.
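To make the intensity-windowing and foreground-cropping steps concrete, here is a minimal NumPy sketch of what they do to a toy volume. It is not MONAI's implementation: `window_intensity` and `crop_foreground` are hypothetical helpers, and the HU window of [-57, 164] is an assumed soft-tissue range borrowed from MONAI's public spleen example.

```python
import numpy as np

# Assumed soft-tissue HU window (from MONAI's public spleen example), not
# taken from this post; tune it for your own data.
A_MIN, A_MAX = -57.0, 164.0

def window_intensity(ct, a_min=A_MIN, a_max=A_MAX, b_min=0.0, b_max=1.0):
    """Map HU values in [a_min, a_max] linearly to [b_min, b_max], clipping the rest."""
    scaled = (ct - a_min) / (a_max - a_min) * (b_max - b_min) + b_min
    return np.clip(scaled, b_min, b_max)

def crop_foreground(image, label, threshold=0.0):
    """Crop image and label to the bounding box of voxels where image > threshold."""
    coords = np.argwhere(image > threshold)
    if coords.size == 0:              # no foreground: leave arrays unchanged
        return image, label
    lo = coords.min(axis=0)
    hi = coords.max(axis=0) + 1       # exclusive upper corner of the box
    box = tuple(slice(a, b) for a, b in zip(lo, hi))
    return image[box], label[box]

# Toy volume: air (-1000 HU) everywhere except a soft-tissue block at 50 HU.
ct = np.full((8, 8, 8), -1000.0)
ct[2:6, 3:7, 1:5] = 50.0
lab = np.zeros(ct.shape, dtype=np.uint8)
lab[3:5, 4:6, 2:4] = 1                # small "spleen" inside the tissue block

img = window_intensity(ct)            # air -> 0.0, tissue -> about 0.48
img_c, lab_c = crop_foreground(img, lab)
print(img_c.shape, int(lab_c.sum()))  # (4, 4, 4) 8
```

Because the window clips everything below the lower bound to 0, air becomes exactly zero, so a `> 0` selection (the default rule in `CropForegroundd`) isolates the body. The same bounding box is applied to the label so image and mask stay aligned.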

train_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="training",
    transform=train_transforms, download=True, val_frac=0.2,
    cache_num=train_cache, num_workers=2, seed=0)
val_ds = DecathlonDataset(
    root_dir=root_dir, task="Task09_Spleen", section="validation",
    transform=val_transforms, download=False, val_frac=0.2,
    cache_num=val_cache, num_workers=2, seed=0)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=2, pin_memory=torch.cuda.is_available())
val_loader   = DataLoader(val_ds, batch_size=1, shuffle=False,
                          num_workers=1, pin_memory=torch.cuda.is_available())
print(f"Train volumes: {len(train_ds)} | Val volumes: {len(val_ds)}")

# Network configuration assumed (as in MONAI's public spleen example).
# out_channels=2 (background + spleen) matches the to_onehot=2 post-processing.
model = UNet(
    spatial_dims=3, in_channels=1, out_channels=2,
    channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2),
    num_res_units=2, norm=Norm.BATCH,
).to(device)
loss_fn   = DiceCELoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
scaler    = GradScaler("cuda", enabled=torch.cuda.is_available())
dice_metric = DiceMetric(include_background=False, reduction="mean")
post_pred   = Compose([AsDiscrete(argmax=True, to_onehot=2)])   # logits -> one-hot mask
post_label  = Compose([AsDiscrete(to_onehot=2)])                # label -> one-hot mask

We load the Medical Segmentation Decathlon Task09 Spleen dataset using MONAI's DecathlonDataset. We split the labeled volumes into training and validation sections (val_frac=0.2 holds out 20% for validation), apply the appropriate transforms, and wrap both datasets in PyTorch-style data loaders. We then create a 3D UNet model, define the DiceCE loss, and set up the AdamW optimizer, a cosine-annealing learning-rate scheduler, the mixed-precision gradient scaler, the Dice metric, and the post-processing steps.
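The loss and metric configured above can be unpacked with a small NumPy sketch. `dice_foreground` mirrors what a Dice metric with `include_background=False` reports on one-hot masks, and `dice_ce` approximates a DiceCE loss with softmax and one-hot targets as a soft Dice term (over all channels) plus a voxel-wise cross-entropy term, weighted equally. Both are simplified, hypothetical helpers for a single volume without a batch dimension, not MONAI's code.

```python
import numpy as np

def one_hot(labels, num_classes=2):
    """Integer labels of shape (D, H, W) -> one-hot array of shape (C, D, H, W)."""
    return np.moveaxis(np.eye(num_classes)[labels], -1, 0)

def dice_foreground(pred_onehot, label_onehot, eps=1e-8):
    """Hard Dice per class, skipping channel 0 (background), averaged over classes."""
    scores = []
    for c in range(1, pred_onehot.shape[0]):
        p, g = pred_onehot[c], label_onehot[c]
        scores.append(2.0 * (p * g).sum() / (p.sum() + g.sum() + eps))
    return float(np.mean(scores))

def softmax(logits, axis=0):
    z = np.exp(logits - logits.max(axis=axis, keepdims=True))  # numerically stable
    return z / z.sum(axis=axis, keepdims=True)

def dice_ce(logits, labels, smooth=1e-5, lambda_dice=1.0, lambda_ce=1.0):
    """Soft Dice loss over all channels plus mean voxel-wise cross-entropy."""
    probs = softmax(logits, axis=0)
    target = one_hot(labels, logits.shape[0])
    spatial = tuple(range(1, logits.ndim))       # reduce over D, H, W
    inter = (probs * target).sum(axis=spatial)
    denom = probs.sum(axis=spatial) + target.sum(axis=spatial)
    dice_loss = np.mean(1.0 - (2.0 * inter + smooth) / (denom + smooth))
    ce = -np.mean(np.sum(target * np.log(probs + 1e-12), axis=0))
    return lambda_dice * dice_loss + lambda_ce * ce

gt = np.zeros((4, 4, 4), dtype=int)
gt[1:3, 1:3, 1:3] = 1                 # 8 foreground voxels
pred = gt.copy()
pred[1, 1, 1] = 0                     # prediction misses one voxel
print(round(dice_foreground(one_hot(pred), one_hot(gt)), 4))  # 0.9333 = 2*7/(7+8)

good = 10.0 * one_hot(gt)             # confident, correct logits
bad = 10.0 * one_hot(1 - gt)          # confident, wrong logits
print(dice_ce(good, gt) < dice_ce(bad, gt))  # True
```

Excluding the background channel from the metric matters here: the spleen occupies a small fraction of each volume, so a background-inclusive Dice would look high even for a poor spleen mask.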

best_dice, best_epoch = -1.0, -1
loss_hist, dice_hist, dice_epochs = [], [], []
best_path = os.path.join(root_dir, "best_spleen_unet.pth")
for epoch in range(1, max_epochs + 1):
    model.train(); epoch_loss, t0 = 0.0, time.time()
    for batch in train_loader:
        x, y = batch["image"].to(device), batch["label"].to(device)
        optimizer.zero_grad(set_to_none=True)
        # Mixed-precision forward pass (a no-op on CPU).
        with autocast("cuda", enabled=torch.cuda.is_available()):
            logits = model(x)
            loss = loss_fn(logits, y)
        scaler.scale(loss).backward()
        scaler.step(optimizer); scaler.update()
        epoch_loss += loss.item()
    scheduler.step()
    epoch_loss /= len(train_loader); loss_hist.append(epoch_loss)
    print(f"[{epoch:3d}/{max_epochs}] loss={epoch_loss:.4f} "
          f"lr={scheduler.get_last_lr()[0]:.2e} ({time.time()-t0:.0f}s)")
    if epoch % val_every == 0 or epoch == max_epochs:
        model.eval(); dice_metric.reset()
        with torch.no_grad():
            for vb in val_loader:
                vx, vy = vb["image"].to(device), vb["label"].to(device)
                # Whole-volume prediction by tiling roi_size patches (sw_batch_size=4).
                with autocast("cuda", enabled=torch.cuda.is_available()):
                    vout = sliding_window_inference(vx, roi_size, 4, model,
                                                    overlap=0.5)
                vout = [post_pred(o)  for o in decollate_batch(vout)]
                vlab = [post_label(o) for o in decollate_batch(vy)]
                dice_metric(y_pred=vout, y=vlab)
        d = dice_metric.aggregate().item()
        dice_hist.append(d); dice_epochs.append(epoch)
        if d > best_dice:
            best_dice, best_epoch = d, epoch
            torch.save(model.state_dict(), best_path)
        print(f"        >> val Dice={d:.4f} (best={best_dice:.4f} @ {best_epoch})")
print(f"\nDone. Best mean Dice {best_dice:.4f} at epoch {best_epoch}.")

We run the full training loop, where each epoch trains the 3D UNet on cropped volumetric patches from the spleen dataset. We use automatic mixed precision to reduce memory usage and speed up training when a GPU is available. Every `val_every` epochs (and at the final epoch), we validate the model on whole volumes using sliding-window inference, track the mean Dice score, and save the best-performing checkpoint.
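Validation relies on sliding-window inference: a volume larger than `roi_size` is tiled into overlapping patches, each patch is predicted, and overlapping predictions are averaged. The pure-Python sketch below shows how patch start positions can be laid out for `overlap=0.5` and how averaging works in 1D. It mirrors the general scheme of MONAI's `sliding_window_inference` with its default constant blending but omits padding and is not the library's code; the helper names and the volume shape are hypothetical. The `4` passed positionally in the loop above is `sw_batch_size`, the number of windows fed through the network at once.

```python
import math

def window_starts(size, roi, overlap=0.5):
    """Start offsets of length-`roi` windows covering [0, size) along one axis."""
    if size <= roi:
        return [0]                                  # one window (MONAI would pad)
    step = max(int(roi * (1 - overlap)), 1)         # stride between windows
    n = math.ceil((size - roi) / step) + 1          # enough windows to reach the end
    return [min(i * step, size - roi) for i in range(n)]  # shift last one inward

def sliding_window_1d(signal, roi, overlap, predictor):
    """Predict each window, then average predictions where windows overlap."""
    total = [0.0] * len(signal)
    count = [0] * len(signal)
    for s in window_starts(len(signal), roi, overlap):
        for i, v in enumerate(predictor(signal[s:s + roi])):
            total[s + i] += v
            count[s + i] += 1
    return [t / c for t, c in zip(total, count)]

# Hypothetical volume shape after preprocessing; real shapes vary per case.
shape, roi, sw_batch_size = (226, 157, 113), (96, 96, 96), 4
per_axis = [window_starts(s, r, 0.5) for s, r in zip(shape, roi)]
n_windows = math.prod(len(p) for p in per_axis)
print(per_axis[0])                                      # [0, 48, 96, 130]
print(n_windows, math.ceil(n_windows / sw_batch_size))  # 24 6
```

Raising the overlap smooths seams between patches but shrinks the stride, so the number of windows (and forward passes) grows quickly in 3D.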

fig, ax = plt.subplots(1, 2, figsize=(12, 4))
ax[0].plot(range(1, len(loss_hist) + 1), loss_hist, "-o", ms=3)
ax[0].set(title="Training loss", xlabel="epoch", ylabel="DiceCE loss")
ax[1].plot(dice_epochs, dice_hist, "-o", color="seagreen", ms=4)
ax[1].set(title="Validation mean Dice", xlabel="epoch", ylabel="Dice"); ax[1].set_ylim(0, 1)
plt.tight_layout(); plt.show()

# Reload the best checkpoint and predict one validation volume.
model.load_state_dict(torch.load(best_path, map_location=device)); model.eval()
with torch.no_grad():
    sample = next(iter(val_loader))
    img = sample["image"].to(device)
    with autocast("cuda", enabled=torch.cuda.is_available()):
        pred = sliding_window_inference(img, roi_size, 4, model, overlap=0.5)
    pred = torch.argmax(pred, dim=1).cpu().numpy()[0]
    img_np, lab_np = img.cpu().numpy()[0, 0], sample["label"].numpy()[0, 0]
    z = int(np.argmax(lab_np.sum(axis=(0, 1))))   # slice with the most spleen voxels
fig, ax = plt.subplots(1, 3, figsize=(13, 5))
ax[0].imshow(img_np[:, :, z], cmap="gray");    ax[0].set_title("CT slice")
ax[1].imshow(lab_np[:, :, z], cmap="viridis"); ax[1].set_title("Ground truth")
ax[2].imshow(pred[:, :, z], cmap="viridis");   ax[2].set_title("Prediction")
for a in ax: a.axis("off")
plt.tight_layout(); plt.show()

We first plot the training loss and validation Dice score to see how the model improves over time. We then reload the best saved checkpoint and run sliding-window inference on a single validation volume. Finally, we pick the slice containing the most spleen voxels and show the CT slice, ground-truth mask, and predicted segmentation side by side to inspect the model's qualitative performance.
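The displayed slice is chosen as `argmax(label.sum(axis=(0, 1)))`, the index along the last axis with the most spleen voxels. Side-by-side panels can hide small disagreements, so here is a small NumPy sketch that reuses the same selection rule and adds a per-slice Dice plus a true-positive/false-positive/false-negative error map that could be shown as an extra panel. The helpers and toy arrays are hypothetical and not part of the original notebook.

```python
import numpy as np

def best_axial_slice(label):
    """Index along the last axis with the most foreground voxels."""
    return int(np.argmax(label.sum(axis=(0, 1))))

def slice_dice(pred, label, z):
    """Dice between prediction and label on the single slice [:, :, z]."""
    p, g = pred[:, :, z] > 0, label[:, :, z] > 0
    denom = p.sum() + g.sum()
    return 1.0 if denom == 0 else float(2.0 * np.logical_and(p, g).sum() / denom)

def error_map(pred_slice, label_slice):
    """0 = background, 1 = true positive, 2 = false positive, 3 = false negative."""
    p, g = pred_slice > 0, label_slice > 0
    m = np.zeros(p.shape, dtype=np.uint8)
    m[p & g] = 1
    m[p & ~g] = 2
    m[~p & g] = 3
    return m

label = np.zeros((6, 6, 5), dtype=np.uint8)
label[1:4, 1:4, 2] = 1            # 9 voxels on slice 2
label[2:4, 2:4, 3] = 1            # 4 voxels on slice 3
pred = label.copy()
pred[1, 1, 2] = 0                 # one missed voxel (false negative)
pred[5, 5, 2] = 1                 # one spurious voxel (false positive)

z = best_axial_slice(label)
m = error_map(pred[:, :, z], label[:, :, z])
print(z, round(slice_dice(pred, label, z), 4))     # 2 0.8889
print(np.bincount(m.ravel(), minlength=4).tolist())  # [26, 8, 1, 1]
```

In a notebook, `ax.imshow(m, cmap="tab10", vmin=0, vmax=9)` would give each of the four codes a distinct color.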

In conclusion, we built a practical MONAI-based workflow for 3D spleen segmentation with a 3D UNet. We prepared the Medical Segmentation Decathlon dataset, transformed and augmented the CT volumes, trained the model with DiceCE loss, validated it with sliding-window inference, and tracked both loss and Dice score over time. We also inspected the final prediction visually by comparing the CT slice, ground-truth label, and model output side by side. Together, these steps show how MONAI supports medical segmentation tasks end to end, from data loading and preprocessing to model training, evaluation, checkpointing, and qualitative analysis.




The post A Coding Implementation on MONAI for End-to-End 3D Spleen Segmentation Using UNet on Medical CT Volumes appeared first on MarkTechPost.
