@@ -10,7 +10,8 @@
 
 from .modeling.mask_decoder_hq import MaskDecoderHQ
 from .modeling.image_encoder import ImageEncoderViTHQ
-from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer
+from .modeling.tiny_vit import TinyViT
+from segment_anything.modeling import PromptEncoder, Sam, TwoWayTransformer, MaskDecoder
 from segment_anything import build_sam_vit_h, build_sam_vit_l, build_sam_vit_b
 
 
@@ -44,16 +45,32 @@ def build_sam_hq_vit_b(checkpoint=None):
     )
 
 
+def build_mobile_sam(checkpoint=None):
+    return _build_mobile_sam(checkpoint)
+
+
 sam_model_registry = {
     "sam_vit_h": build_sam_vit_h,
     "sam_vit_l": build_sam_vit_l,
     "sam_vit_b": build_sam_vit_b,
     "sam_hq_vit_h": build_sam_hq_vit_h,
     "sam_hq_vit_l": build_sam_hq_vit_l,
     "sam_hq_vit_b": build_sam_hq_vit_b,
+    "mobile_sam": build_mobile_sam,
 }
 
 
+def _load_sam_checkpoint(sam: Sam, checkpoint=None):
+    sam.eval()
+    if checkpoint is not None:
+        with open(checkpoint, "rb") as f:
+            state_dict = torch.load(f)
+        info = sam.load_state_dict(state_dict, strict=False)
+        print(info)
+    for _, p in sam.named_parameters():
+        p.requires_grad = False
+    return sam
+
 def _build_sam_hq(
     encoder_embed_dim,
     encoder_depth,
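Note on the refactor above: the shared _load_sam_checkpoint helper freezes every parameter, whereas the _build_sam_hq body it replaces (removed in the next hunk) kept the HQ-specific modules (hf_token, hf_mlp, compress_vit_feat, embedding_encoder, embedding_maskfeature) trainable. This is presumably intentional for inference-only use. If the old behavior were ever needed again, one option would be to thread a keep-list through the helper; a minimal sketch, where the trainable_keys parameter is hypothetical and not part of this commit:

import torch
from segment_anything.modeling import Sam

def _load_sam_checkpoint(sam: Sam, checkpoint=None, trainable_keys=()):
    sam.eval()
    if checkpoint is not None:
        with open(checkpoint, "rb") as f:
            state_dict = torch.load(f)
        info = sam.load_state_dict(state_dict, strict=False)
        print(info)
    for n, p in sam.named_parameters():
        # Freeze everything except parameters whose name contains a keep-key.
        # With the default empty tuple this matches the commit's behavior:
        # all parameters frozen.
        p.requires_grad = any(k in n for k in trainable_keys)
    return sam

# The HQ builders could then opt back in to trainable HQ heads, e.g.:
# _load_sam_checkpoint(sam, checkpoint, trainable_keys=(
#     "hf_token", "hf_mlp", "compress_vit_feat",
#     "embedding_encoder", "embedding_maskfeature"))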
@@ -102,14 +119,48 @@ def _build_sam_hq(
         pixel_mean=[123.675, 116.28, 103.53],
         pixel_std=[58.395, 57.12, 57.375],
     )
-    sam.eval()
-    if checkpoint is not None:
-        with open(checkpoint, "rb") as f:
-            state_dict = torch.load(f)
-        info = sam.load_state_dict(state_dict, strict=False)
-        print(info)
-    for n, p in sam.named_parameters():
-        if 'hf_token' not in n and 'hf_mlp' not in n and 'compress_vit_feat' not in n and 'embedding_encoder' not in n and 'embedding_maskfeature' not in n:
-            p.requires_grad = False
+    return _load_sam_checkpoint(sam, checkpoint)
 
-    return sam
+
+def _build_mobile_sam(checkpoint=None):
+    prompt_embed_dim = 256
+    image_size = 1024
+    vit_patch_size = 16
+    image_embedding_size = image_size // vit_patch_size
+    mobile_sam = Sam(
+        image_encoder=TinyViT(
+            img_size=1024, in_chans=3, num_classes=1000,
+            embed_dims=[64, 128, 160, 320],
+            depths=[2, 2, 6, 2],
+            num_heads=[2, 4, 5, 10],
+            window_sizes=[7, 7, 14, 7],
+            mlp_ratio=4.,
+            drop_rate=0.,
+            drop_path_rate=0.0,
+            use_checkpoint=False,
+            mbconv_expand_ratio=4.0,
+            local_conv_size=3,
+            layer_lr_decay=0.8
+        ),
+        prompt_encoder=PromptEncoder(
+            embed_dim=prompt_embed_dim,
+            image_embedding_size=(image_embedding_size, image_embedding_size),
+            input_image_size=(image_size, image_size),
+            mask_in_chans=16,
+        ),
+        mask_decoder=MaskDecoder(
+            num_multimask_outputs=3,
+            transformer=TwoWayTransformer(
+                depth=2,
+                embedding_dim=prompt_embed_dim,
+                mlp_dim=2048,
+                num_heads=8,
+            ),
+            transformer_dim=prompt_embed_dim,
+            iou_head_depth=3,
+            iou_head_hidden_dim=256,
+        ),
+        pixel_mean=[123.675, 116.28, 103.53],
+        pixel_std=[58.395, 57.12, 57.375],
+    )
+    return _load_sam_checkpoint(mobile_sam, checkpoint)
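For context, a usage sketch, not part of the commit: after this change MobileSAM is built through the same registry as the ViT and HQ variants. The module path and checkpoint filename below are assumptions.

# Assumes this module is importable at the path shown and that a MobileSAM
# checkpoint has been downloaded to the placeholder path.
from segment_anything.build_sam_hq import sam_model_registry

mobile_sam = sam_model_registry["mobile_sam"](checkpoint="weights/mobile_sam.pt")
# _load_sam_checkpoint has already called .eval() and frozen all parameters,
# so the returned model is ready for inference-only use.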