Commit 11f2bf5: MobileSAM

1 parent: 780fc49

File tree

4 files changed (+685, -13 lines)


README.md (+4, -1)
@@ -19,7 +19,10 @@ This extension aim for connecting [AUTOMATIC1111 Stable Diffusion WebUI](https:/
 - `2023/05/29`: [v1.4.2](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.4.2) You may now do SAM inference on CPU by checking "Use CPU for SAM". This is for some Mac users who are not able to do SAM inference on GPU. I discourage other users from using this feature because it is significantly slower than CUDA.
 - `2023/06/01`: [v1.5.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.0) You may now choose to use local GroundingDINO to bypass the C++ problem. See [FAQ](#faq)-1 for more detail.
 - `2023/06/04`: [v1.5.1](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.5.1) `Upload Mask to ControlNet Inpainting` comes back in response to [ControlNet inpaint improvement](https://github.com/Mikubill/sd-webui-controlnet/discussions/1464). You should see a new tab beside `AutoSAM` after updating the extension. This feature will again be removed once the ControlNet extension has its own uploading feature.
-- `2023/06/13`: [v1.6.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.6.0) [SAM-HQ](https://github.com/SysCV/sam-hq) supported by [@SpenserCai](https://github.com/SpenserCai) and me. This is an "upgraded" SAM from researchers at ETH Zurich & HKUST. However, I cannot guarantee which one is better and you should make your own choice based on your own experiments. Go to [Installation](#installation) to get the link to the models.
+- `2023/06/13`: [v1.6.0](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.6.0) [SAM-HQ](https://github.com/SysCV/sam-hq) supported by [@SpenserCai](https://github.com/SpenserCai) and me. This is an "upgraded" SAM, created by researchers at ETH Zurich & HKUST. However, I cannot guarantee which one is better; you should make your own choice based on your own experiments. Go to [Installation](#installation) to get the link to the models.
+- `2023/06/29`: [v1.6.1](https://github.com/continue-revolution/sd-webui-segment-anything/releases/tag/v1.6.1) [MobileSAM](https://github.com/ChaoningZhang/MobileSAM) supported. This is a tiny version of SAM, created by researchers at Kyung Hee University. Visit [here](https://github.com/continue-revolution/sd-webui-segment-anything/issues/139) for more information.
+
+Note that support for some other variations of SAM, such as [Matting-Anything](https://github.com/SHI-Labs/Matting-Anything) and [FastSAM](https://github.com/CASIA-IVA-Lab/FastSAM), is still on the way. Unlike MobileSAM, supporting these models is non-trivial, especially FastSAM, which uses a completely different pipeline (ultralytics/YOLO). Introducing these new works into the current codebase would make an already messy codebase even messier. They will be supported once I finish a major refactor of the current codebase.

 ## FAQ

sam_hq/build_sam_hq.py (+62, -11)
@@ -10,7 +10,8 @@

 from .modeling.mask_decoder_hq import MaskDecoderHQ
 from .modeling.image_encoder import ImageEncoderViTHQ
-from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer
+from .modeling.tiny_vit import TinyViT
+from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer, MaskDecoder
 from segment_anything import build_sam_vit_h, build_sam_vit_l, build_sam_vit_b

@@ -44,16 +45,32 @@ def build_sam_hq_vit_b(checkpoint=None):
     )


+def build_mobile_sam(checkpoint=None):
+    return _build_mobile_sam(checkpoint)
+
+
 sam_model_registry = {
     "sam_vit_h": build_sam_vit_h,
     "sam_vit_l": build_sam_vit_l,
     "sam_vit_b": build_sam_vit_b,
     "sam_hq_vit_h": build_sam_hq_vit_h,
     "sam_hq_vit_l": build_sam_hq_vit_l,
     "sam_hq_vit_b": build_sam_hq_vit_b,
+    "mobile_sam": build_mobile_sam,
 }


+def _load_sam_checkpoint(sam: Sam, checkpoint=None):
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        info = sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for _, p in sam.named_parameters():
+        p.requires_grad = False
+    return sam
+
 def _build_sam_hq(
     encoder_embed_dim,
     encoder_depth,
@@ -102,14 +119,48 @@ def _build_sam_hq(
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
-    sam.eval()
-    if checkpoint is not None:
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
-        info = sam.load_state_dict(state_dict, strict=False)
-        print(info)
-    for n, p in sam.named_parameters():
-        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
-            p.requires_grad = False
+    return _load_sam_checkpoint(sam, checkpoint)

-    return sam
+
+def _build_mobile_sam(checkpoint=None):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    mobile_sam = Sam(
+        image_encoder=TinyViT(
+            img_size=1024, in_chans=3, num_classes=1000,
+            embed_dims=[64, 128, 160, 320],
+            depths=[2, 2, 6, 2],
+            num_heads=[2, 4, 5, 10],
+            window_sizes=[7, 7, 14, 7],
+            mlp_ratio=4.,
+            drop_rate=0.,
+            drop_path_rate=0.0,
+            use_checkpoint=False,
+            mbconv_expand_ratio=4.0,
+            local_conv_size=3,
+            layer_lr_decay=0.8
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    return _load_sam_checkpoint(mobile_sam, checkpoint)
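
For context, here is a minimal sketch of how the updated registry might be driven. The import path, the checkpoint location `models/mobile_sam.pt`, and the `SamPredictor` workflow are assumptions based on the upstream `segment_anything` API, not anything shown in this commit:

```python
# Minimal usage sketch, not part of this commit. Assumptions: the extension
# package is importable as `sam_hq`, a MobileSAM checkpoint sits at a
# hypothetical path, and SamPredictor comes from the upstream
# `segment_anything` package.
import numpy as np

from segment_anything import SamPredictor
from sam_hq.build_sam_hq import sam_model_registry

# Build the TinyViT-based MobileSAM. _load_sam_checkpoint loads weights
# with strict=False, prints any missing/unexpected keys, and freezes
# every parameter.
sam = sam_model_registry["mobile_sam"](checkpoint="models/mobile_sam.pt")

predictor = SamPredictor(sam)
image = np.zeros((512, 512, 3), dtype=np.uint8)  # placeholder RGB image
predictor.set_image(image)

# Prompt with a single foreground point at the image center.
masks, scores, logits = predictor.predict(
    point_coords=np.array([[256, 256]]),
    point_labels=np.array([1]),
)
print(masks.shape, scores.shape)  # e.g. (3, 512, 512) and (3,)
```

Because MobileSAM keeps the standard `Sam` interface (only the image encoder is swapped for TinyViT), it should slot into the existing predictor and registry code paths without further changes, which is why the diff above only adds a builder and a registry entry.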
