nnUNet(neural network Universal Network)是一个基于深度学习的医学图像分割开源框架,旨在为医学影像分割任务提供一个通用、自动化且高性能的解决方案。该项目由医学影像与深度学习领域的专家团队开发,目的是应对不同医学影像分割任务中的常见挑战,例如模型适配性不佳、参数调优复杂以及工程实施成本高等问题。通过这一框架,即使没有深厚深度学习背景的用户也能轻松地适应各种模态和器官的分割需求。
自从nnUNet开源以来,它迅速成为了医学影像分割领域的标杆工具,广泛应用于学术研究和临床前研究。其“数据驱动的自适应配置”核心理念已成为医学图像分割工具设计的重要参考。
在诸如BraTS、MSD Challenge等国际顶级医学影像分割比赛中,nnUNet持续获得顶尖排名,成为这些比赛中的主要基准框架之一。
该框架已被超过1000篇SCI论文引用,涉及肿瘤分割、器官分割、病灶检测等多个医学影像领域,成为医学深度学习的标准工具库。
nnUNet已经成功应用于CT、MRI、PET等多种医学影像模式,支持超过20种器官或病灶的分割任务,无需大量的定制化开发。
在公开的医学影像数据集中,如LiTS、Pancreas-CT,nnUNet的分割准确率(Dice系数)通常达到0.85以上,某些任务甚至超过了0.95,显著优于传统的分割方法。
| 技术类别 | 具体技术 / 工具 | 核心作用 |
|---|---|---|
| 编程语言 | Python 3.7+ | 作为主要开发语言,确保开发效率和生态系统的完整性 |
| 深度学习框架 | PyTorch 1.6+ | 用于模型构建、训练和推理,支持动态图和分布式训练 |
| 数据处理 | NumPy、SciPy、SimpleITK | 处理医学影像(如DICOM/NIfTI格式)的读取及预处理(重采样、归一化) |
| 工程化工具 | Numba | 加速数值计算,如指标计算和数据增强 |
| 可视化工具 | Matplotlib、Seaborn | 用于分割结果的可视化和训练曲线的监控 |
| 分布式训练 | PyTorch DistributedDataParallel | 实现多GPU并行训练,提高大规模数据训练的效率 |
该工具作为核心分割模块,旨在集成到医疗AI产品中,加速产品落地(例如辅助诊断系统、影像分析平台)。
在教学场景中,该工具适用于医学深度学习、医学影像处理课程的实践教学,帮助学生快速掌握分割模型的工程实现逻辑。
plaintextnnUNet/ ├── nnunet/ │ ├── configuration/ # 配置模块:自适应配置生成、参数管理 │ ├── data_loading/ # 数据加载:影像读取、数据增强、batch生成 │ ├── evaluation/ # 评估模块:Dice系数、Hausdorff距离等指标计算 │ ├── inference/ # 推理模块:模型预测、后处理 │ ├── networks/ # 网络模块:U-Net变体、损失函数定义 │ ├── training/ # 训练模块:训练循环、优化器配置 │ └── utilities/ # 工具函数:影像处理、文件操作、日志管理 ├── examples/ # 示例代码:快速上手教程 ├── tests/ # 单元测试:模块功能验证 └── setup.py # 安装配置
plaintext┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ │ 数据准备 │────?│ 数据预处理 │────?│ 模型配置 │────?│ 模型训练 │────?│ 推理后处理 │ └───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘ │ │ │ │ │ ▼ ▼ ▼ ▼ ▼ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ ┌───────────┐ │格式转换/ │ ┌───────────┐ │自动选择 │ │k-fold交叉 │ │分块预测/ │ │目录结构化 │ │重采样/归一化/│ │网络/参数 │ │验证/模型保存│ │后处理/结果输出│ │ │ │数据增强 │ │ │ │ │ │ │ └───────────┘ └───────────┘ └───────────┘ └───────────┘ └───────────┘
以下示例代码展示了nnUNet的核心流程(简化版),涵盖了数据准备、模型定义、训练与推理的基本功能。
bash 运行# 安装依赖 pip install torch numpy scipy simpleitk numba matplotlib
python 运行import os import SimpleITK as sitk import numpy as np def prepare_nnunet_data(raw_data_dir, output_dir): """ 简化版数据准备:将DICOM格式转换为nnUNet标准NIfTI格式 """ # 创建nnUNet标准目录结构 os.makedirs(os.path.join(output_dir, "imagesTr"), exist_ok=True) os.makedirs(os.path.join(output_dir, "labelsTr"), exist_ok=True) # 遍历原始DICOM数据 for patient_id in os.listdir(raw_data_dir): patient_dir = os.path.join(raw_data_dir, patient_id) if not os.path.isdir(patient_dir): continue # 读取DICOM图像 img_reader = sitk.ImageSeriesReader() img_filenames = img_reader.GetGDCMSeriesFileNames(patient_dir) img_reader.SetFileNames(img_filenames) img = img_reader.Execute() # 读取标签(假设标签为单独的DICOM序列) label_dir = os.path.join(patient_dir, "label") label_filenames = img_reader.GetGDCMSeriesFileNames(label_dir) img_reader.SetFileNames(label_filenames) label = img_reader.Execute() # 保存为NIfTI格式(nnUNet标准命名:patient_id_0000.nii.gz,0000表示模态) sitk.WriteImage(img, os.path.join(output_dir, "imagesTr", f"{patient_id}_0000.nii.gz")) sitk.WriteImage(label, os.path.join(output_dir, "labelsTr", f"{patient_id}.nii.gz")) print("数据准备完成,输出目录:", output_dir) # 调用示例 prepare_nnunet_data(raw_data_dir="./raw_dicom", output_dir="./nnunet_data")
python 运行import torch import torch.nn as nn class SimpleUNet(nn.Module): """简化版3D U-Net,模拟nnUNet核心网络结构""" def __init__(self, in_channels=1, num_classes=2): super(SimpleUNet, self).__init__() # 编码器(下采样) self.enc1 = self.conv_block(in_channels, 64) self.enc2 = self.conv_block(64, 128) self.enc3 = self.conv_block(128, 256) # 解码器(上采样) self.dec1 = self.conv_block(256, 128) self.dec2 = self.conv_block(128, 64) self.dec3 = self.conv_block(64, num_classes) # 池化与上采样 self.pool = nn.MaxPool3d(2, 2) self.upconv = nn.ConvTranspose3d(256, 128, 2, stride=2) self.final_conv = nn.Conv3d(num_classes, num_classes, 1) def conv_block(self, in_channels, out_channels): """卷积块:Conv3d + BatchNorm + ReLU""" return nn.Sequential( nn.Conv3d(in_channels, out_channels, 3, padding=1), nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True), nn.Conv3d(out_channels, out_channels, 3, padding=1), nn.BatchNorm3d(out_channels), nn.ReLU(inplace=True) ) def forward(self, x): # 编码器 x1 = self.enc1(x) x2 = self.pool(x1) x2 = self.enc2(x2) x3 = self.pool(x2) x3 = self.enc3(x3) # 解码器 x = self.upconv(x3) x = torch.cat([x, x2], dim=1) # 跳跃连接 x = self.dec1(x) x = self.upconv(x) x = torch.cat([x, x1], dim=1) # 跳跃连接 x = self.dec2(x) x = self.dec3(x) out = self.final_conv(x) return out # 模型实例化 model = SimpleUNet(in_channels=1, num_classes=2) print("模型结构:", model)
python 运行import torch.optim as optim from torch.utils.data import DataLoader, Dataset # 自定义数据集(简化版) class MedicalDataset(Dataset): def __init__(self, data_dir): self.image_dir = os.path.join(data_dir, "imagesTr") self.label_dir = os.path.join(data_dir, "labelsTr") self.patients = [f.split("_0000.nii.gz")[0] for f in os.listdir(self.image_dir)] def __len__(self): return len(self.patients) def __getitem__(self, idx): patient_id = self.patients[idx] # 读取NIfTI图像 img = sitk.ReadImage(os.path.join(self.image_dir, f"{patient_id}_0000.nii.gz")) img = sitk.GetArrayFromImage(img).astype(np.float32)[None, ...] # (C, D, H, W) # 读取标签 label = sitk.ReadImage(os.path.join(self.label_dir, f"{patient_id}.nii.gz")) label = sitk.GetArrayFromImage(label).astype(np.longlong) return torch.from_numpy(img), torch.from_numpy(label) # 训练配置 device = torch.device("cuda" if torch.cuda.is_available() else "cpu") data_dir = "./nnunet_data" dataset = MedicalDataset(data_dir) dataloader = DataLoader(dataset, batch_size=1, shuffle=True) # 模型、损失函数、优化器 model = SimpleUNet().to(device) criterion = nn.CrossEntropyLoss() # 结合Dice损失效果更优,此处简化 optimizer = optim.AdamW(model.parameters(), lr=1e-4) # 训练循环 def train_epoch(model, dataloader, criterion, optimizer, device): model.train() total_loss = 0.0 for imgs, labels in dataloader: imgs, labels = imgs.to(device), labels.to(device) # 前向传播 outputs = model(imgs) loss = criterion(outputs, labels) # 反向传播与优化 optimizer.zero_grad() loss.backward() optimizer.step() total_loss += loss.item() return total_loss / len(dataloader) # 执行训练 num_epochs = 10 for epoch in range(num_epochs): train_loss = train_epoch(model, dataloader, criterion, optimizer, device) print(f"Epoch {epoch+1}/{num_epochs},
扫码加好友,拉您进群



收藏
