UNet分割脊柱-600学习网
600学习网终身会员188,所有资源无秘无压缩-购买会员
随着我们每天收集更多的数据,人工智能(AI)将越来越多地应用于医学领域。人工智能在医学领域的一个关键应用是诊断。人工智能在医学诊断中有助于决策.管理.自动化等。
脊柱是肌肉骨骼系统的重要组成部分,支撑着身体及其器官结构,在我们的活动姓和负荷转移中发挥着重要作用。它还可以保护脊髓免受撞击造成的损伤和机械冲击。
在自动脊柱处理流水线中,脊柱标记和分割是两个基本任务。
可靠和准确的脊柱图像处理有望为脊柱和骨骼健康诊断.手术计划和基于人群的分析的临床决策支持系统提供帮助。设计用于脊柱处理的自动化算法具有挑战姓,主要是因为解剖和采集协议之间存在巨大差异,以及公开数据的严重短缺。
在这个博客中,我将只关注给定CT扫描数据集中的脊柱分割。标记每个椎骨并进行进一步诊断的任务不包括在这个博客中,可以作为这个任务的延续。
脊柱或脊柱分割是脊柱形态学和病理学所有自动定量应用中的关键步骤。
随着深度学习的出现,大而多样的数据成为计算机断层扫描(CT)等任务的主要热门资源。然而,没有大规模的公共数据集。
VerSe是一个大型.多探测器.多部位CT脊柱数据集,由355名患者的374次扫描组成。2019年和2020年的数据集可用。在这个博客中,我将两个数据集合并为一个,以从更多数据中受益。
这些数据是根据CC BY-SA 4.0许可证提供的,因此它们是完全开源的。
NIfTI(神经成像信息技术倡议)是神经成像的文件格式。NIfTI文件在神经科学甚至神经放射学研究的成像信息学中非常常用。每个NIfTI文件最多包含7个维度的元数据,并支持多种数据类型。
前三个维度用于定义三个空间维度x.y和z,而第四个维度用于确定时间点t。其余维度(从第五维度到第七维度)用于其他目的。然而,第五维仍然可以有一些预定义的用途,例如存储体素特定的分布参数或保存基于矢量的数据。
ITK-SNAP是用于3D医学图像中的结构分割的软件应用程序。它是可以安装在不同平台上的开源软件。我可以使用它在3D视图中可视化NifTi文件,并在原始图像上加载和覆盖3D遮罩。我强烈建议将其用于此任务。
计算机断层摄影(CT)是一种x射线成像程序,其中x射线以快速旋转速度指向患者。机器收集的信号将存储在计算机中,以生成身体的横截面图像,也称为”切片”。
这些切片称为断层摄影,包含比传统x射线更详细的信息。一系列切片可以通过数字”叠加”形成患者的3D图像,从而更容易识别和定位基本结构以及可能的肿瘤或异常。
步骤如下。首先,下载2019年和2020年的数据集。
然后将这两个数据集合并到它们的培训.验证和测试文件夹中。下一步是读取CT扫描图像并将CT扫描图像的每个切片转换为一系列PNG原始图像和掩模。后来,我在Github仓库中使用了UNet模型,并训练了一个UNet模型。
数据理解:在开始数据处理和培训之前,我想加载几个NIfTI文件,以便更熟悉它们的3D数据结构,可视化它们,并从图像中提取元数据。
下载VerSe数据集后,我打开了一个*。不。gz*文件。通过读取文件并查看CT扫描图像的特定切片,我可以运行Numpy translate函数以三种不同的视图查看切片:轴向.矢状和冠状。
在熟悉原始图像并能够从原始3D图像中提取切片后,是时候查看同一切片的掩码文件了。
如图所示
培训:首先,我定义了UNet类,然后定义了PyTorch数据集类,包括读取和预处理图像。预处理任务包括加载PNG文件,将其大小调整为一个大小(本例中为250×250),将其全部转换为NumPy数组,然后将其转换为PyTorch张量。
通过调用数据集类(VerSeDataset),我们可以在我定义的批中准备数据。为了确保原始图像和掩码图像之间的映射是正确的,我调用next(iter(valid_dataloader))来获取批处理中的下一项并将其可视化。
后来,模型被定义为model=UNet(n_channels=1,n_classes=1)。通道数为1,因为存在灰度图像而不是RGB。如果图像为RGB,则可以将n_channels更改为3。类的数量为1,因为只有一个类可以确定像素是否为脊椎的一部分。如果你的问题是多类分割,你可以将类的数量设置为你有多少类。
后来,模型被训练。对于每个批次,首先计算损失值,然后通过反向传播更新参数。随后,再次检查所有批次,只计算验证数据集的损失,并存储损失值。接下来,我们目视观察列车和验证的损失值,并跟踪模型的性能。
保存模型后,可以捕获其中一个图像并将其传输到训练模型,并且可以接收预测的掩模图像。通过并排绘制原始.真实掩模和预测掩模的三幅图像,可以直观地评估结果。
从上图中可以看出,该模型在矢状和轴向视图中都表现良好,因为预测的掩模面积与实际掩模面积非常相似。
完整代码:
作者:Mazi Boustani
日期:2021 12月24日
目的:使用PyTorch训练UNet模型,以便使用VerSe数据集分割脊柱
将numpy导入为np
将熊猫导入为pd
导入操作系统
从os导入listdir
从os.path导入拆分文本
导入全局变量
进口关闭
随机导入
从pathlib导入路径
从PIL导入图像
从tqdm导入tqdm
将matplotlib.pyplot导入为plt
%matplotlib内联
尝试:
将nibabel导入为nib
除了:
引发ImportError(“安装NIBABEL”)
进口焊炬
将torch.nn导入为nn
从火炬导入张量
导入火炬.nn.功能为F
从火炬导入optim
将torchvision.transforms导入为T
从torch.utils.data导入DataLoader,随机_拆分
从torch.utils.data导入数据集
#为列车和验证数据设置文件夹路径
data_文件夹_路径=”/用户/mazi/项目/其他/CT/数据”
列车_数据=数据_文件夹_路径+”/版本_19_20_培训/”
验证_数据=数据_文件夹_路径+”/版本_19_20_验证/”
数据理解
#获取一个要加载的图像
列车_数据_原始_图像=列车_数据+”/rawdata/sub-verse521/sub-verse521_dir-ax_ct.nii.gz”
一个图像=笔尖加载(列车数据原始图像)
#查看图像形状
打印(一张图片。形状)
#查看图像标题。要了解标题,请参阅:https://brainder.org/2012/09/23/the-nifti-file-format/
打印(一个__image.header)
#查看原始数据
一个图像数据=一个图像
打印(一个图像数据)
#以三个不同的角度可视化一幅图像
一个图像数据轴=一个图像
#更改视图
一个图像数据矢状面=np.转置(一个图像图像数据,[2,1,0])
一个图像数据矢状=np.flip(一个图像矢状,轴=0)
#更改视图
一个图像数据冠状=np.转置(一个图像图像数据,[2,0,1])
一个图像数据冠状=np.flip(一个图像图像数据冠状,轴=0)
图,ax=plt.子时隙(1,3,图=(60,60))
ax[0].imshow(一个图像数据轴,cmap=’骨头’)
ax[0].set_title(“轴向视图”,fontsize=60)
ax〔1〕.imshow(一个图像数据矢状面〔:,:,260〕,cmap=’骨头’)
ax〔1〕.set_title(“矢状视图”,字体大小=60)
ax〔2〕.imshow(一个图像数据冠状面〔:,:,200〕,cmap=’骨头’)
ax〔2〕.set_title(“冠状视图”,字体大小=60)
plt.show()
#在原始图像(CT扫描的一个切片)上覆盖一个掩模
train_data_mask_image=train_data+”导数/sub-verse521/sub-verse521_dir-ax_seg-vert_msk.nii.gz”
train_data_mask_image=笔尖加载
plt.图(图=(10,10))
旋转_原始=np.转置(一个_图像_数据,[2,1,0])
旋转_原始=np.flip(旋转_原始,轴=0)
plt.imshow(旋转_raw〔:,:,260〕,cmap=’bone’,插值=’none’)
列车数据掩码图像
旋转_掩码=np.转置(训练_数据_掩码_图像,[2,1,0])
旋转_掩码=np.flip(旋转_掩码,轴=0)
plt.imshow(旋转_掩码〔:,:,260〕,cmap=”酷”)
预处理数据
#设置路径以存储处理过的列车和验证原始图像和掩码
已加工的_列=”。/已加工的__列/”
已处理的确认=”./已处理的验证/”
处理后的_列车_原始_图像=处理后的列车+”原始_图片/”
加工的_列_口罩=加工的_系列+”口罩/”
已处理的_验证_原始_图像=已处理_验证+”原始_图片/”
加工的_验证_口罩=加工的__验证+”口罩/”
#阅读所有2019年和2020年的原始文件,包括培训和验证
raw_train_files=glob.glob(os.path.join(train_data,’rawdatanii.gz’))
原始验证文件=glob.glob(os.path.join(验证数据,’rawdatanii.gz’))
打印(“原始图像计数列:{0},验证:{1}”。格式(len(原始图像列),len(原图像验证)
#阅读所有2019年和2020年的原始文件,包括培训和验证
raw_train_files=glob.glob(os.path.join(train_data,’rawdatanii.gz’))
原始验证文件=glob.glob(os.path.join(验证数据,’rawdatanii.gz’))
打印(“原始图像计数列:{0},验证:{1}”.format(len(原始图像列_文件),len(原图像验证_文件文件))
#阅读所有2019年和2020年衍生品文件,包括培训和验证
masks_train_files=glob.glob(os.path.join(train_data,’derivativesnii.gz’))
掩码_验证_文件=glob.glob(os.path.join(验证_数据,’derivativesnii.gz’))
打印(“掩码图像计数列:{0},验证:{1}”.format(len(掩码_列_文件),len(面具_验证_文件))
def读取_文件(nii _文件):
'''
读取.nii.gz文件。
参数:
nii_file(str):文件路径。
返回:
CT图像数据的3D数字阵列。
'''
返回np.asanyarray(nib.load(nii_file).dataobj)
定义保存_文件(原始_数据,标签_数据,文件_名称,索引,输出_原始_文件_路径,输出_标签_文件_路径):
'''
将文件保存为npz格式。
参数:
raw_ data(array):原始图像数据的2D numpy数组。
label_data(array):标签图像数据的2D numpy数组。
file_name(str):文件名。
index(int):CT图像的切片。
output_raw_file_path(str):所有原始文件的路径。
output_label_file_path(str):所有掩码文件的路径。
'''
#将所有非零像素替换为1
唯一_值=np.unique(标签_数据)
如果len(唯一的_值)>1:
raw_file_name=”{0}{1}_{2}.png”.format(输出_raw文件_路径,文件_名称,索引)
im=Image.fromarray(原始数据)
im=im.转换(“L”)
im.save(原始文件名)
label_file_name=”{0}{1}_{2}.png”.format(输出_label_文件_路径,文件_名称,索引)
im=Image.fromarray(标签_数据)
im=im.转换(“L”)
im.save(标签文件名)
def是对角线(矩阵):
'''
检查给定矩阵是否为对角矩阵。
参数:
矩阵(np数组):numpy数组
'''
对于范围(0,3)中的i:
对于范围(0,3)中的j:
如果((i!=j)和(矩阵〔i〕〔j〕!=0)):
return False
return True
def生成_数据(原始_文件,标签_文件,文件_名称,输出_原始_文件_路径,输出_标签_文件_路径):
'''
读取每个原始和标签文件并生成一系列图像的主要功能
每个切片。
参数:
raw_file(str):原始文件的路径。
label_file(str):标签文件的路径。
file_name(str):文件名。
output_raw_file_path(str):所有原始文件的路径。
output_label_file_path(str):所有掩码文件的路径。
'''
#如果每2个切片跳过一次。相邻的切片可以非常相似
#将生成冗余数据
跳过_切片=3
continue _ it=真
原始数据=读取文件(原始文件)
标签_数据=读取_文件(标签_文件)
如果原始文件中有”拆分”:
continue _ it=假
affine=nib.load(原始文件).affine
如果是_对角线(仿射〔:3,:3〕):
转置的_标签_数据=np.转置(标签_数据,[2,1,0])
转置的_标签_数据=np.flip(转置的_label _数据)
否则:
转置的_标签_数据=np.flip(转置的_label _数据)
如果继续:
如果转置了_原始_数据形状:
切片计数=转置的原始数据。形状〔-1〕
打印(“文件名:”,文件名,”-切片计数:”,切片计数)
#跳过一些切片
对于范围内的每个切片(1,切片计数,跳过切片):
保存_文件(转置的_原始_数据[:,:,每个_切片]
转置的_标签_数据[:,:,每个_切片]
文件_名称
每个_切片
输出_原始_文件_路径
输出_标签_文件_路径)
#循环处理原始图像和掩码,生成”PNG”图像。
打印(“处理已开始。”)
对于原始列车文件中的每个原始列车文件:
raw _ file _ name=每个_ raw _文件.split(“/”)〔-1〕.split(“_ct.nii.gz”)〔0〕
对于每个掩码中的_掩码_文件_列车_文件:
如果每个_ mask _ file.split(“/”)[-1]中的原始_ file _ name:
生成_数据(每个_原始_文件,每个_掩码_文件.原始_名称.已处理的_训练_原始图片.已处理_训练_掩码)
打印(“处理列车数据完成”)
#循环处理原始图像和掩码,生成”PNG”图像。
对于原始_验证_文件中的每个_原始_文件:
raw _ file _ name=每个_ raw _文件.split(“/”)〔-1〕.split(“_ct.nii.gz”)〔0〕
对于掩码_验证_文件中的每个_掩码_文件:
如果每个_ mask _ file.split(“/”)[-1]中的原始_ file _ name:
生成_数据(每个_原始_文件,每个_掩码_文件,原始_文件_名称,处理的_验证_原始_图像,处理的_验证_掩码)
打印(“处理验证数据完成”)
火车
#定义模型参数
设备=”cuda”(如果火炬cuda可用),否则为”cpu”
#要转换为的图像大小
图像_高度=250
图像_宽度=250
学习率=1e-4
批次_尺寸=10
纪元=10
人数_工人=8
#设置设备
device=torch.device(如果torch.cuda可用,则为”cuda”(否则为”cpu”)
#UNet模型部件
#源代码:https://github.com/milesial/Pytorch-UNet/blob/master/UNet/UNet_parts.py
类DoubleConv(nn.模块):
“””(卷积=>〔BN〕=>ReLU)*2″”
def_init_(self,in_channel,out_channels,mid_channes=None):
super().__init()
如果不是中间频道:
中间通道=外部通道
self.double_conv=nn.顺序(
nn.Conv2d(在_个通道中,中间_个通道,内核_大小=3,填充=1)
nn.BatchNorm2d(中间通道)
nn.ReLU(inplace=True)
nn.Conv2d(中间通道,外部通道,内核大小=3,填充=1)
nn.BatchNorm2d(出_个通道)
nn.ReLU(就地=真)
)
向前定义(自身,x):
返回self.double_conv(x)
类向下(nn.模块):
“””使用maxpool进行缩小,然后使用double conv”””
def_init_(self,in_channel,out_channels):
super().__init()
self.maxpool _ conv=nn.顺序(
最大池2d(2)
DoubleConv(输入_个频道,输出_个通道)
)
向前定义(自身,x):
返回self.maxpool _ conv(x)
class Up(nn.模块):
“””升级,然后双转换”””
def_init_(self,in_channel,out_channels,双线姓=True):
super().__init()
#如果双线姓,使用正常卷积来减少通道数
如果双线姓:self.up=nn.Upsample(scale_factor=2,mode=”双线姓”,align_corners=True)self.conv=DoubleConv(in_channel,out_channels,in_channels/2)否则:self.up=nn.ConvTranspose2d(in__channel,in__channes/2,kernel_size=2,step=2)self.conv=DoppleConv(在_____个频道内,在_____频道外)
向前定义(self,x1,x2):
x1=自升(x1)
#输入为CHW
diffY=x2.size()[2]-x1.size
diffX=x2.size()〔3〕-x1.size
x1=F.pad(x1,〔diffX//2
diffY//2.diffY-diffY//2〕)
#如果您有填充问题,请参阅
#https:/github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
#https:/github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x=火炬cat([x2,x1],dim=1)
返回自身conv(x)
类OutConv(nn.模块):
def_init_(self,in_channel,out_channels):
super(OutConv,self).__init()
self.conv=nn.Conv2d(in_通道,out_通道.内核_大小=1)
向前定义(自身,x):
返回自身conv(x)
#定义UNet架构
#源代码:https://github.com/milesial/Pytorch-UNet/blob/master/UNet/UNet_model.py
类UNet(nn.模块):
def__init_(self,n_通道,n_类,双线姓=True):
super(UNet,self).__初始化__(
self.n个频道=n个频道
self.n类=n_类
自双线姓=双线姓
self.inc=DoubleConv(n个通道,64)
self.down1=向下(64,128)
self.down2=向下(128,256)
self.down3=向下(256,512)
如果双线姓为1,则因子=2
self.down4=向下(5121024//因子)
self.up1=向上(1024,512//因子,双线姓)
self.up2=向上(512,256//因子,双线姓)
self.up3=向上(256,128//因子,双线姓)
self.up4=向上(128,64,双线姓)
self.outc=OutConv(64,n_类)
向前定义(自身,x):
x1=self.inc(x)
x2=自身向下1(x1)
x3=自身向下2(x2)
x4=自身向下3(x3)
x5=自身向下4(x4)
x=自升1(x5,x4)
x=自身up2(x,x3)
x=自我up3(x,x2)
x=自身up4(x,x1)
logits=self.outc(x)
返回逻辑
#定义PyTorch数据集类
#本课程将访问图像和掩码,对其进行预处理,以便进行培训和验证
类VerSeDataset(数据集):
def_init_(self,raw_images_path,masks_patch,images _ name):
self.raw_images_path=原始_images_path
self.masks_path=掩码_path
self.images_name=图片_name
def__len_(自我):
return len(self.images_name)
def___getitem___(自身,索引):
#获取给定索引的图像和掩码
img_路径=os.path.join(self.raw_images_路径,self.images_名称〔索引〕)
mask_path=os.path.join(self.masks_path,self.images_name〔index〕)
#读取图像和掩码
image=image.open(img_路径)
mask=Image.open(mask_路径)
#调整图像大小并将形状更改为(1,图像宽度,图像高度)
w. h=图像尺寸
image=image.resize((w,h),重采样=image.BICUBIC)
image=T.Resize(大小=(图像宽度,图像高度))(图像)
image_ndarray=np.asarray(图像)
image_ndarray=image_darray.shape(1,image_endarray.shape[0],image_narray.shape[1])
#调整遮罩大小。遮罩形状为(图像_width,图像_height)
mask=mask.resize((w,h),重采样=Image.NEAREST)
mask=T.Resize(大小=(图像宽度,图像高度))(mask)
掩码_ ndarray=np.asarray(掩码)
返回{
“image”:火炬.as张量(image _ndarray.copy()).float().contiguous()
“mask”:火炬.as张量(mask _ndarray.copy()).float().连续(
}
#获取所有图像和掩码的路径
train_images_paths=os.listdir(已处理的_train__原始_图像)
列车掩码路径=os.listdir(已处理列车掩码)
验证_图像_路径=os.listdir(已处理_验证_原始_图像)
验证_掩码_路径=os.listdir(已处理_验证_屏蔽)
#加载图像和掩码数据
列车_数据=VerSeDataset(处理的列车_原始图像.处理的列车_掩码.列车_图像_路径)
有效的_数据=VerSeDataset(处理的_验证的_原始的_图像,处理的_确认的_掩码,验证的_图像的_路径)
#创建PyTorch DataLoader
列车_数据加载器=dataloader(列车_数据,批次_大小=批次_大小,混洗=True)
有效数据加载程序=数据加载程序(有效数据,批次大小=批次大小,混洗=假)
#从一个批次中查看一个图像和掩模,以便目视检查
next _ image=next(iter(有效的_ dataloader))
图,ax=plt.子时隙(1,2,图=(60,60))
ax〔0〕.imshow(下一个图像〔’image’〕〔0〕〔0,:,:〕,cmap=’bone’)
ax[0].set_title(“原始图像”,字体大小=60)
ax〔1〕.imshow(下一张图片〔’mask’〕〔0〕〔:,:〕,cmap=’bone’)
ax〔1〕.set_title(“蒙版图像”,字体大小=60)
plt.show()
#定义骰子损失等级
#源代码:https://www.kaggle.com/bigironsphere/loss-function-library-keras-pytorch
类骰子损失(nn.模块):
def_init_(自我,体重=无,尺寸_平均值=真):
super(骰子损失,自我).__init___()
def forward(自我,输入,目标,平滑=1):
输入=火炬.乙状体(输入)
#平坦标签和预测张量
输入=输入视图(-1)
targets=targets.view(-1)
交集=(输入*目标).sum()
骰子=(2.*交集+平滑)/(inputs.sum()+targets.sum()+平滑)
bce=F.binary_交叉_熵_与_逻辑(输入,目标)
pred=torch.sigoid(输入)
损失=bce*0.5+骰子*(1-0.5)
#从骰子值中减去1以计算损失
return 1-骰子
#将模型定义为UNet
模型=UNet(n个通道=1,n个类=1)
型号至(设备=设备)
优化器=optim.Adam(model.parameters(),lr=学习_速率)
#培训和验证
列车_损失=〔〕
val_损失=〔〕
对于范围内的历元(EPOCHS):
模型.train()
列车运行损耗=0.0
计数器=0
以tqdm(总=len(列车数据),desc=f’Epoch{Epoch+1}/{EPOCHS}’,单位=’img’)为pbar:
对于批量在列数据加载器:
计数器+=1
image=批次〔”image”〕到(DEVICE)
mask=批次〔”mask”〕到(DEVICE)
优化器.zero_grad()
输出=模型(图像)
输出=输出。挤压(1)
loss=DiceLoss()(输出,掩码)
列车运行损耗+=损耗item()
向后损失()
优化器.step()
pbar.update(image.shape[0])
pbar.set_后缀(**{’loss(batch)’:loss.item()})
列车损失附加(列车运行损失/计数器)
模型评估()
有效运行损耗=0.0
计数器=0
带火炬。无_级():
对于i,枚举中的数据(有效的_dataloader):
计数器+=1
image=数据〔”image”〕到(DEVICE)
mask=数据〔’mask’〕到(DEVICE)
输出=模型(图像)
输出=输出。挤压(1)
loss=DiceLoss()(输出,掩码)
有效_运行_损失+=损失.item()
val_loss.append(有效_运行_loss)
纪元1/10:100%██████████ 4790/4790〔4:00:34<00:00,3.01s/img,损失(批次)=0.385〕
纪元2/10:100%██████████ 4790/4790〔4:00:02<00:00,3.01s/img,损失(批次)=0.268〕
纪元3/10:100%██████████ 4790/4790〔3:57:30<00:00,2.98s/img,损失(批次)=0.152〕
纪元4/10:100%██████████ 4790/4790〔3:57:05<00:00,2.97s/img,损失(批次)=0.105〕
纪元5/10:100%██████████ 4790/4790〔4:08:29<00:00,3.11s/img,损失(批次)=0.103〕
纪元6/10:100%██████████ 4790/4790〔4:04:12<00:00,3.06s/img,损失(批次)=0.0874〕
纪元7/10:100%██████████ 4790/4790[4:02:00<00:00,3.03s/img,损失(批次)=0.0759]
纪元8/10:100%██████████ 4790/4790〔3:58:32<00:00,2.99s/img,损失(批次)=0.0655〕
纪元9/10:100%██████████ 4790/4790〔4:00:47<00:00,3.02s/img,损失(批次)=0.0644〕
纪元10/10:100%██████████ 4790/4790〔4:08:54<00:00,3.12s/img,损失(批次)=0.0604〕
#绘制列车与验证损失
plt.图(图=(10,7))
plt.plot(列车_损失,颜涩=”橙涩”,标签=”列车损失”)
plt.plot(val_损失,颜涩=”红涩”,标签=”验证损失”)
plt.xlabel(“大纪元”)
plt.ylabel(“损失”)
plt.legend()
plt.show()
#保存训练过的模型
火炬保存({
“纪元”:EPOCHS
“model_state_dict”:model.state_ditt()
‘优化器_state_dict’:优化器.state_dict(),},”./unet_model.pth”)
#从视觉上看一个预测
next _ image=next(iter(有效的_ dataloader))
#做预测
输出=模型(下一个_image〔’image’〕.float())
outputs=outputs.deptach().cpu()
loss=DiceLoss()(输出,下一个_图像〔’mask’〕)
打印(“骰子得分:”,1-loss.item())
输出〔输出<=0.0〕=0
输出〔输出>0.0〕=1.0
#绘制所有三幅图像
图,ax=plt.子时隙(1,3,图=(60,60))
ax〔0〕.imshow(下一个图像〔’image’〕〔0〕〔0,:,:〕,cmap=’bone’)
ax[0].set_title(“原始图像”,fontsize=60)
ax〔1〕.imshow(下一张图片〔’mask’〕〔0〕〔:,:〕,cmap=’bone’)
ax〔1〕.set_title(“True Mask”,字体大小=60)
ax〔2〕.imshow(输出〔0,0,:,:〕,cmap=’bone’)
ax[2].set_title(“预测掩码”,fontsize=60)
plt.show()
未来工作:这项任务也可以通过3D UNet完成,这可能是学习脊柱结构的更好方法。
因为每个椎骨的每个掩模区域都有标签,所以我们可以进一步执行多类掩模分割。此外,当图像视图为矢状时,模型表现最佳,因此将所有切片转换为矢状可能会获得更好的结果。
谢谢你的阅读!
600学习网 » UNet分割脊柱-600学习网