Skip to content

Commit

Permalink
Update model sizes
Browse files Browse the repository at this point in the history
  • Loading branch information
HonzaCuhel committed Feb 3, 2025
1 parent fe83647 commit 37e2c46
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ datadreamer --config <path-to-config>
| | [CLIP](https://huggingface.co/openai/clip-vit-base-patch32) | Zero-shot-image-classification |
| | [AIMv2](https://huggingface.co/apple/aimv2-large-patch14-224-lit) | Zero-shot-image-classification |
| | [SlimSAM](https://huggingface.co/Zigeng/SlimSAM-uniform-50) | Zero-shot-instance-segmentation |
| | [SAM2.1](https://huggingface.co/facebook/sam2-hiera-tiny) | Zero-shot-instance-segmentation |
| | [SAM2.1](https://huggingface.co/facebook/sam2.1-hiera-large) | Zero-shot-instance-segmentation |

<a name="example"></a>

Expand Down
6 changes: 3 additions & 3 deletions datadreamer/dataset_annotation/sam2_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,10 @@ def _init_model(self, device: str) -> SAM2ImagePredictor:
logger.info(f"Initializing SAM2.1 {self.size} model...")
if self.size == "large":
return SAM2ImagePredictor.from_pretrained(
"facebook/sam2.1-hiera-base-plus", device=device
"facebook/sam2-hiera-large", device=device
)
return SAM2ImagePredictor.from_pretrained(
"facebook/sam2-hiera-tiny", device=device
"facebook/sam2.1-hiera-base-plus", device=device
)

def annotate_batch(
Expand Down Expand Up @@ -131,7 +131,7 @@ def release(self, empty_cuda_cache: bool = False) -> None:

url = "https://ultralytics.com/images/bus.jpg"
im = Image.open(requests.get(url, stream=True).raw)
annotator = SAM2Annotator(device="cpu", size="base")
annotator = SAM2Annotator(device="cpu", size="large")
final_segments = annotator.annotate_batch([im], [np.array([[3, 229, 559, 650]])])
print(len(final_segments), len(final_segments[0]))
print(final_segments[0][0][:5])

0 comments on commit 37e2c46

Please sign in to comment.