Skip to content

Commit b6be09a

Browse files
committed
Fix get_models() and get_async_models() duplicates bug
Closes #667, refs #640
1 parent e78fea1 commit b6be09a

File tree

3 files changed

+11
-2
lines changed

3 files changed

+11
-2
lines changed

llm/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -169,12 +169,14 @@ class UnknownModelError(KeyError):
169169

170170
def get_models() -> List[Model]:
171171
"Get all registered models"
172-
return [model for model in get_model_aliases().values()]
172+
models_with_aliases = get_models_with_aliases()
173+
return [mwa.model for mwa in models_with_aliases if mwa.model]
173174

174175

175176
def get_async_models() -> List[AsyncModel]:
176177
"Get all registered async models"
177-
return [model for model in get_async_model_aliases().values()]
178+
models_with_aliases = get_models_with_aliases()
179+
return [mwa.async_model for mwa in models_with_aliases if mwa.async_model]
178180

179181

180182
def get_async_model(name: Optional[str] = None) -> AsyncModel:

llm/default_plugins/openai_models.py

+4
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ def register_models(register):
6262
aliases=("gpt-4-turbo-preview", "4-turbo", "4t"),
6363
)
6464
# o1
65+
# register(
66+
# Chat("o1", can_stream=False, allows_system_prompt=False, vision=True),
67+
# AsyncChat("o1", can_stream=False, allows_system_prompt=False, vision=True),
68+
# )
6569
register(
6670
Chat("o1-preview", can_stream=False, allows_system_prompt=False),
6771
AsyncChat("o1-preview", can_stream=False, allows_system_prompt=False),

tests/test_llm.py

+3
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,9 @@ def test_get_models():
604604
assert all(isinstance(model, llm.Model) for model in models)
605605
model_ids = [model.model_id for model in models]
606606
assert "gpt-4o-mini" in model_ids
607+
# Ensure no model_ids are duplicated
608+
# https://github.com/simonw/llm/issues/667
609+
assert len(model_ids) == len(set(model_ids))
607610

608611

609612
def test_get_async_models():

0 commit comments

Comments
 (0)