0-SAM2-UNet: Segment Anything 2 Makes Strong Encoder for Natural and Medical Image Segmentation

SAM2 - UNet:利用 Segment Anything 2 为自然和医学图像分割打造强大编码器总结

SAM2-UNet

arXiv:2408.08870v1 [cs.CV] 16 Aug 2024

https://github.com/WZH0120/SAM2-UNet.

背景

图像分割是计算机视觉重要任务,多种下游任务依赖于它,但设计统一架构处理不同分割任务仍是挑战。视觉基础模型(VFMs)如 SAM2 在图像分割有潜力,但存在生成与任务无关的分割结果等问题,需探索提升其适应性和性能的策略。已有研究尝试将 SAM 适配到下游任务,如使用适配器、集成额外输入或改变架构,但受限于 SAM1 中 ViT 编码器的结构,SAM2 的分层骨干为设计更有效的 U 形网络提供了新机会。

所以作者提出 SAM2 - UNet,证明 Segment Anything Model 2(SAM2)可作为 U 形分割模型的强大编码器,通过简单有效的框架实现通用图像分割。

实验方法

模型架构

模型架构

SAM2-UNet 概述。Hiera 块有一些变体,为了便于理解,只演示了一个简化的结构。

编码器

采用 SAM2 预训练的 Hiera 骨干网络,其分层结构适合 U 形网络,能多尺度捕捉特征,输出不同层次特征。

感受野模块(RFBs)

用于减少编码器特征的通道数至 64 并增强特征。

适配器

因 Hiera 参数量大,冻结其参数并在每个多尺度块前插入适配器,实现参数高效微调,适配器由线性层、GeLU 激活函数等组成。

解码器

采用经典 U - Net 设计,包含三个解码器块,每个块有两个 “Conv - BN - ReLU” 组合,输出特征经 1×1 Conv 分割头生成分割结果并上采样,与真实掩码监督训练,使用加权 IoU 损失和二元交叉熵(BCE)损失作为训练目标并应用深度监督。

实验结果

在五个不同基准上的十八个数据集进行实验,包括伪装对象检测、显著对象检测、海洋动物分割、镜子检测和息肉分割任务,各任务采用相应数据集和评估指标。

基于 PyTorch 在单张 NVIDIA RTX 4090 GPU 上实现,使用 AdamW 优化器,采用随机垂直和水平翻转的数据增强策略,输入图像调整为 352×352,不同任务设置不同训练轮数,息肉分割任务采用多尺度训练策略。

在伪装对象检测、显著对象检测、海洋动物分割、镜子检测和息肉分割任务中,SAM2 - UNet 在多个数据集上的各项评估指标上表现优异,超过其他对比方法,部分任务实现了最先进性能,可视化结果也显示其在不同场景下的准确性优势。

伪装对象检测的可视化结果

息肉分割的可视化结果。

消融实验

对 Hiera 骨干网络大小进行消融实验,结果表明较大骨干通常性能更好,但较小骨干的 SAM2 - UNet 也能取得不错结果,证明 SAM2 预训练的 Hiera 骨干网络提供的高质量表示有效。

消融实验

总结

提出的 SAM2 - UNet 框架简单有效,适用于自然和医学图像分割任务,其采用的 SAM2 预训练 Hiera 编码器与经典 U - Net 解码器结合的方式被证明有效,可作为未来 SAM2 变体开发的新基线。

代码分析

使用了SAM2.0基础模型的编码器与解码器,去除了其提示模块,并构建了无零样本分割能力的网络。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class SAM2UNet(nn.Module):
def __init__(self, checkpoint_path=None) -> None:
super(SAM2UNet, self).__init__()
model_cfg = "sam2_hiera_l.yaml"
if checkpoint_path:
model = build_sam2(model_cfg, checkpoint_path)
else:
model = build_sam2(model_cfg)
del model.sam_mask_decoder
del model.sam_prompt_encoder
del model.memory_encoder
del model.memory_attention
del model.mask_downsample
del model.obj_ptr_tpos_proj
del model.obj_ptr_proj
del model.image_encoder.neck
self.encoder = model.image_encoder.trunk

for param in self.encoder.parameters():
param.requires_grad = False
blocks = []
for block in self.encoder.blocks:
blocks.append(
Adapter(block)
)
self.encoder.blocks = nn.Sequential(
*blocks
)
self.rfb1 = RFB_modified(144, 64)
self.rfb2 = RFB_modified(288, 64)
self.rfb3 = RFB_modified(576, 64)
self.rfb4 = RFB_modified(1152, 64)
self.up1 = (Up(128, 64))
self.up2 = (Up(128, 64))
self.up3 = (Up(128, 64))
self.up4 = (Up(128, 64))
self.side1 = nn.Conv2d(64, 1, kernel_size=1)
self.side2 = nn.Conv2d(64, 1, kernel_size=1)
self.head = nn.Conv2d(64, 1, kernel_size=1)

def forward(self, x):
x1, x2, x3, x4 = self.encoder(x)
x1, x2, x3, x4 = self.rfb1(x1), self.rfb2(x2), self.rfb3(x3), self.rfb4(x4)
x = self.up1(x4, x3)
out1 = F.interpolate(self.side1(x), scale_factor=16, mode='bilinear')
x = self.up2(x, x2)
out2 = F.interpolate(self.side2(x), scale_factor=8, mode='bilinear')
x = self.up3(x, x1)
out = F.interpolate(self.head(x), scale_factor=4, mode='bilinear')
return out, out1, out2

通过设置不同的上、下采样数值,在前项传播中进行多尺度的融合计算。通过插值法bilinear计算3个不同的输出。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
class RFB_modified(nn.Module):
def __init__(self, in_channel, out_channel):
super(RFB_modified, self).__init__()
self.relu = nn.ReLU(True)
self.branch0 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
)
self.branch1 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
BasicConv2d(out_channel, out_channel, kernel_size=(1, 3), padding=(0, 1)),
BasicConv2d(out_channel, out_channel, kernel_size=(3, 1), padding=(1, 0)),
BasicConv2d(out_channel, out_channel, 3, padding=3, dilation=3)
)
self.branch2 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
BasicConv2d(out_channel, out_channel, kernel_size=(1, 5), padding=(0, 2)),
BasicConv2d(out_channel, out_channel, kernel_size=(5, 1), padding=(2, 0)),
BasicConv2d(out_channel, out_channel, 3, padding=5, dilation=5)
)
self.branch3 = nn.Sequential(
BasicConv2d(in_channel, out_channel, 1),
BasicConv2d(out_channel, out_channel, kernel_size=(1, 7), padding=(0, 3)),
BasicConv2d(out_channel, out_channel, kernel_size=(7, 1), padding=(3, 0)),
BasicConv2d(out_channel, out_channel, 3, padding=7, dilation=7)
)
self.conv_cat = BasicConv2d(4*out_channel, out_channel, 3, padding=1)
self.conv_res = BasicConv2d(in_channel, out_channel, 1)

def forward(self, x):
x0 = self.branch0(x)
x1 = self.branch1(x)
x2 = self.branch2(x)
x3 = self.branch3(x)
x_cat = self.conv_cat(torch.cat((x0, x1, x2, x3), 1))

x = self.relu(x_cat + self.conv_res(x))
return x

感受野模块用来确保尺度对齐。