亚马逊AWS官方博客

基于 Amazon SageMaker 利用 MONAI 处理医疗影像数据最佳实践

介绍

神经网络已被证明可有效解决复杂的计算机视觉任务,例如对象检测、图像相似性和分类。随着低成本 GPU 的发展,构建和部署神经网络的计算成本已大幅降低。然而,大多数技术旨在处理视觉媒体中常见的像素分辨率。例如,典型的分辨率大小对于 YOLOv3 是 544 和 416 像素,对于 SSD 是 300 和 512 像素,对于 VGG 是 224 像素。在由卫星或数字病理图像等千兆像素图像(10^9+ 像素)组成的数据集上训练分类器在计算上具有挑战性。这些图像不能直接输入到神经网络中,因为每个 GPU 都受到可用内存的限制。这需要特定的预处理技术,例如平铺,以便能够以较小的块处理原始图像。此外,由于这些图像的尺寸很大,整体训练时间往往很长,通常需要几天或几周的时间,而没有使用适当的缩放技术,例如分布式训练。本文介绍基于亚马逊云科技部署MONAI进行大规模医疗影像的分析。MONAI 框架是由 Project MONAI 创建的开源基金会。 MONAI 是一个免费的、社区支持的、MONAI 是一个用于医学成像领域的深度学习框架,可在原生 PyTorch 范式中开发医学成像训练工作流。 MONAI 项目还包括 MONAI Label,这是一种智能开源图像标记和学习工具,可帮助研究人员和临床医生协作,创建带注释的数据集,并在标准化的 MONAI 范式中构建 AI 模型。

部署

创建Amazon SageMaker instance:

登录管理控制台并切换至Amazon SageMaker服务,创建Notebook Instance。

打开Jupyter Lab运行以下命令安装Monai和所需组件。

pip install monai -i https://opentuna.cn/pypi/web/simple
pip install -r https://raw.githubusercontent.com/Project-MONAI/MONAI/dev/requirements-dev.txt -i https://opentuna.cn/pypi/web/simple

使用Python运行以下代码,测试是否部署成功:

import logging
import os
import sys
import tempfile
import shutil

import matplotlib.pyplot as plt
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np

import monai
from monai.apps import download_and_extract
from monai.config import print_config
from monai.data import CacheDataset, DataLoader, ImageDataset
from monai.transforms import (
AddChannel,
Compose,
RandRotate90,
Resize,
ScaleIntensity,
EnsureType,
Randomizable,
LoadImaged,
EnsureTyped,
)

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
print_config()

输出结果如下:

下载IXI Dataset作为训练集,下载地址:https://brain-development.org/ixi-dataset/

directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)
class IXIDataset(Randomizable, CacheDataset):
resource = "http://biomedic.doc.ic.ac.uk/" \
+ "brain-development/downloads/IXI/IXI-T1.tar"
md5 = "34901a0593b41dd19c1a1f746eac2d58"

def __init__(
self,
root_dir,
section,
transform,
download=False,
seed=0,
val_frac=0.2,
test_frac=0.2,
cache_num=sys.maxsize,
cache_rate=1.0,
num_workers=0,
):
if not os.path.isdir(root_dir):
raise ValueError("Root directory root_dir must be a directory.")
self.section = section
self.val_frac = val_frac
self.test_frac = test_frac
self.set_random_state(seed=seed)
dataset_dir = os.path.join(root_dir, "ixi")
tarfile_name = f"{dataset_dir}.tar"
if download:
download_and_extract(
self.resource, tarfile_name, dataset_dir, self.md5)
# as a quick demo, we just use 10 images to show

self.datalist = [
{"image": os.path.join(
dataset_dir, "IXI314-IOP-0889-T1.nii.gz"), "label": 0},
{"image": os.path.join(
dataset_dir, "IXI249-Guys-1072-T1.nii.gz"), "label": 0},
{"image": os.path.join(
dataset_dir, "IXI609-HH-2600-T1.nii.gz"), "label": 0},
{"image": os.path.join(
dataset_dir, "IXI173-HH-1590-T1.nii.gz"), "label": 1},
{"image": os.path.join(
dataset_dir, "IXI020-Guys-0700-T1.nii.gz"), "label": 0},
{"image": os.path.join(
dataset_dir, "IXI342-Guys-0909-T1.nii.gz"), "label": 0},
{"image": os.path.join(
dataset_dir, "IXI134-Guys-0780-T1.nii.gz"), "label": 0},
{"image": os.path.join(
dataset_dir, "IXI577-HH-2661-T1.nii.gz"), "label": 1},
{"image": os.path.join(
dataset_dir, "IXI066-Guys-0731-T1.nii.gz"), "label": 1},
{"image": os.path.join(
dataset_dir, "IXI130-HH-1528-T1.nii.gz"), "label": 0},
]
data = self._generate_data_list()
super().__init__(
data, transform, cache_num=cache_num,
cache_rate=cache_rate, num_workers=num_workers,
)

def randomize(self, data=None):
self.rann = self.R.random()

def _generate_data_list(self):
data = []
for d in self.datalist:
self.randomize()
if self.section == "training":
if self.rann < self.val_frac + self.test_frac:
continue
elif self.section == "validation":
if self.rann >= self.val_frac:
continue
elif self.section == "test":
if self.rann < self.val_frac or \
self.rann >= self.val_frac + self.test_frac:
continue
else:
raise ValueError(
f"Unsupported section: {self.section}, "
"available options are ['training', 'validation', 'test']."
)
data.append(d)
return data
images = [
os.sep.join([root_dir, "ixi", "IXI314-IOP-0889-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI249-Guys-1072-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI609-HH-2600-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI173-HH-1590-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI020-Guys-0700-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI342-Guys-0909-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI134-Guys-0780-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI577-HH-2661-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI066-Guys-0731-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI130-HH-1528-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI607-Guys-1097-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI175-HH-1570-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI385-HH-2078-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI344-Guys-0905-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI409-Guys-0960-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI584-Guys-1129-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI253-HH-1694-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI092-HH-1436-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI574-IOP-1156-T1.nii.gz"]),
os.sep.join([root_dir, "ixi", "IXI585-Guys-1130-T1.nii.gz"]),
]

# 2 binary labels for gender classification: man and woman
labels = np.array([0, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1,
0, 1, 0, 1, 0, 1, 0], dtype=np.int64)

# Define transforms
train_transforms = Compose([ScaleIntensity(), AddChannel(), Resize(
(96, 96, 96)), RandRotate90(), EnsureType()])
val_transforms = Compose(
[ScaleIntensity(), AddChannel(), Resize((96, 96, 96)), EnsureType()])

# Define nifti dataset, data loader
check_ds = ImageDataset(image_files=images, labels=labels,
transform=train_transforms)
check_loader = DataLoader(check_ds, batch_size=2,
num_workers=2, pin_memory=torch.cuda.is_available())
im, label = monai.utils.misc.first(check_loader)
print(type(im), im.shape, label)

# create a training data loader
train_ds = ImageDataset(
image_files=images[:10], labels=labels[:10], transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True,
num_workers=2, pin_memory=torch.cuda.is_available())

# create a validation data loader
val_ds = ImageDataset(
image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=2, num_workers=2,
pin_memory=torch.cuda.is_available())

# Create DenseNet121, CrossEntropyLoss and Adam optimizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = monai.networks.nets.DenseNet121(
spatial_dims=3, in_channels=1, out_channels=2).to(device)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), 1e-5)

# start a typical PyTorch training
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
writer = SummaryWriter()
max_epochs = 5
for epoch in range(max_epochs):
print("-" * 10)
print(f"epoch {epoch + 1}/{max_epochs}")
model.train()
epoch_loss = 0
step = 0
for batch_data in train_loader:
step += 1
inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = loss_function(outputs, labels)
loss.backward()
optimizer.step()
epoch_loss += loss.item()
epoch_len = len(train_ds) // train_loader.batch_size
print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
epoch_loss /= step
epoch_loss_values.append(epoch_loss)
print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

if (epoch + 1) % val_interval == 0:
model.eval()
with torch.no_grad():
num_correct = 0.0
metric_count = 0
for val_data in val_loader:
val_images, val_labels = val_data[0].to(
device), val_data[1].to(device)
val_outputs = model(val_images)
value = torch.eq(val_outputs.argmax(dim=1), val_labels)
metric_count += len(value)
num_correct += value.sum().item()
metric = num_correct / metric_count
metric_values.append(metric)
if metric > best_metric:
best_metric = metric
best_metric_epoch = epoch + 1
torch.save(model.state_dict(),
"best_metric_model_classification3d_array.pth")
print("saved new best metric model")
print(
"current epoch: {} current accuracy: {:.4f} "
"best accuracy: {:.4f} at epoch {}".format(
epoch + 1, metric, best_metric, best_metric_epoch
)
)
writer.add_scalar("val_accuracy", metric, epoch + 1)
print(
f"train completed, best_metric: {best_metric:.4f} "
f"at epoch: {best_metric_epoch}")
writer.close()

输出结果如下

Occlusion sensitivity 分析:
尝试可视化网络为何做出给定预测的一种方法是遮挡敏感度。我们遮挡了图像的一部分,看看给定预测的概率如何变化。然后我们迭代图像,随着我们移动被遮挡的部分,在这样做的过程中,我们构建了一个敏感度图,详细说明哪些区域在做出决定时最重要。
边界分析:
如果我们要测试以图像中所有体素为中心的遮挡,我们将不得不进行 torch.prod(im.shape) = 96^3 = ~1e6 预测。我们可以使用边界框仅对感兴趣区域的估计,例如在一个切片上。

为此,我们只需将边界框指定为 (minC,maxC,minD,maxD,minH,maxH,minW,maxW)。我们可以将 -1 用于任何值以使用其完整范围(0 和 im.shape-1 分别用于最小值和最大值)。

结果分析:
这个例子中的输出图像看起来相当糟糕,因为我们的网络没有经过很长时间的训练。更长时间的训练应该会提高遮挡图的质量。

# create a validation data loader
test_ds = ImageDataset(
image_files=images[-10:], labels=labels[-10:], transform=val_transforms)
test_loader = DataLoader(val_ds, batch_size=1, num_workers=2,
pin_memory=torch.cuda.is_available())
itera = iter(test_loader)


def get_next_im():
test_data = next(itera)
return test_data[0].to(device), test_data[1].unsqueeze(0).to(device)


def plot_occlusion_heatmap(im, heatmap):
plt.subplots(1, 2)
plt.subplot(1, 2, 1)
plt.imshow(np.squeeze(im.cpu()))
plt.colorbar()
plt.subplot(1, 2, 2)
plt.imshow(heatmap)
plt.colorbar()
plt.show()

# Get a random image and its corresponding label
img, label = get_next_im()

# Get the occlusion sensitivity map
occ_sens = monai.visualize.OcclusionSensitivity(
nn_module=model, mask_size=12, n_batch=10, stride=12)
# Only get a single slice to save time.
# For the other dimensions (channel, width, height), use
# -1 to use 0 and img.shape[x]-1 for min and max, respectively
depth_slice = img.shape[2] // 2
occ_sens_b_box = [-1, -1, depth_slice, depth_slice, -1, -1, -1, -1]

occ_result, _ = occ_sens(x=img, b_box=occ_sens_b_box)
occ_result = occ_result[..., label.item()]

fig, axes = plt.subplots(1, 2, figsize=(25, 15), facecolor='white')

for i, im in enumerate([img[:, :, depth_slice, ...], occ_result]):
cmap = 'gray' if i == 0 else 'jet'
ax = axes[i]
im_show = ax.imshow(np.squeeze(im[0][0].detach().cpu()), cmap=cmap)
ax.axis('off')
fig.colorbar(im_show, ax=ax)

结果如下:

总结

在这篇博文中,我们为使用 Amazon SageMaker 处理超高分辨率图像引入了一个可扩展的机器学习框架。该框架简化了对由接近千兆像素级的图像组成的数据集进行大规模分类器训练的复杂过程。有关 Amazon SageMaker 的更多信息,请参阅文档学习如何使用 Amazon SageMaker 构建、训练和部署机器学习模型。

本篇作者

余昶

现任大中华地区亚马逊医疗&生命科学行业总监。负责生命科学基因行业上云业务创新及解决方案。在加入亚马逊前,曾先后就职于华大基因,英特尔,英伟达,担任生物信息总监,基因解决方案架构师,医疗行业总监。长期从事生物信息开发及基因组学研究工作,并推动人工智能及云计算技术在医疗生命科学行业中的创新及应用,拥有丰富工业界经验和学术经历。

方康

卡内基梅隆大学计算机硕士,现任大中华地区亚马逊医疗&生命科学行业解决方案架构师。在加入亚马逊之前就职于华大基因,任职首席云架构师,负责开发和维护基因组学领域云计算平台, 为全球合作伙伴提供基因组学数据管理,转化和分析方案。同时致力于基因组学数据管理与分析, HPC(高性能计算集群)与异构计算,工作流语言,数据转化与压缩和生命科学数据的合规与安全等领域的应用。

刘光

亚马逊云科技解决方案架构师,目前负责基于亚马逊云科技云计算方案架构的咨询和设计,同时致力于亚马逊云科技云服务在政企、教育和医疗行业客户的推广。在加入亚马逊云科技之前就职于Citrix,具有多年企业虚拟化、VDI架构设计和支持经验。