Skip to content

Commit 6f5aa81

Browse files
authored
✨ Add support for multiple Annotated annotations, e.g. Annotated[str, Field(), Query()] (#10773)
1 parent 73dcc40 commit 6f5aa81

File tree

5 files changed

+57
-36
lines changed

5 files changed

+57
-36
lines changed

.github/workflows/test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
id: cache
3030
with:
3131
path: ${{ env.pythonLocation }}
32-
key: ${{ runner.os }}-python-${{ env.pythonLocation }}-pydantic-v2-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v06
32+
key: ${{ runner.os }}-python-${{ env.pythonLocation }}-pydantic-v2-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v07
3333
- name: Install Dependencies
3434
if: steps.cache.outputs.cache-hit != 'true'
3535
run: pip install -r requirements-tests.txt
@@ -62,7 +62,7 @@ jobs:
6262
id: cache
6363
with:
6464
path: ${{ env.pythonLocation }}
65-
key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ matrix.pydantic-version }}-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v06
65+
key: ${{ runner.os }}-python-${{ env.pythonLocation }}-${{ matrix.pydantic-version }}-${{ hashFiles('pyproject.toml', 'requirements-tests.txt', 'requirements-docs-tests.txt') }}-test-v07
6666
- name: Install Dependencies
6767
if: steps.cache.outputs.cache-hit != 'true'
6868
run: pip install -r requirements-tests.txt

fastapi/_compat.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,12 @@ def is_bytes_sequence_field(field: ModelField) -> bool:
249249
return is_bytes_sequence_annotation(field.type_)
250250

251251
def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:
252-
return type(field_info).from_annotation(annotation)
252+
cls = type(field_info)
253+
merged_field_info = cls.from_annotation(annotation)
254+
new_field_info = copy(field_info)
255+
new_field_info.metadata = merged_field_info.metadata
256+
new_field_info.annotation = merged_field_info.annotation
257+
return new_field_info
253258

254259
def serialize_sequence_value(*, field: ModelField, value: Any) -> Sequence[Any]:
255260
origin_type = (

fastapi/dependencies/utils.py

+30-23
Original file line numberDiff line numberDiff line change
@@ -325,25 +325,33 @@ def analyze_param(
325325
field_info = None
326326
depends = None
327327
type_annotation: Any = Any
328-
if (
329-
annotation is not inspect.Signature.empty
330-
and get_origin(annotation) is Annotated
331-
):
328+
use_annotation: Any = Any
329+
if annotation is not inspect.Signature.empty:
330+
use_annotation = annotation
331+
type_annotation = annotation
332+
if get_origin(use_annotation) is Annotated:
332333
annotated_args = get_args(annotation)
333334
type_annotation = annotated_args[0]
334335
fastapi_annotations = [
335336
arg
336337
for arg in annotated_args[1:]
337338
if isinstance(arg, (FieldInfo, params.Depends))
338339
]
339-
assert (
340-
len(fastapi_annotations) <= 1
341-
), f"Cannot specify multiple `Annotated` FastAPI arguments for {param_name!r}"
342-
fastapi_annotation = next(iter(fastapi_annotations), None)
340+
fastapi_specific_annotations = [
341+
arg
342+
for arg in fastapi_annotations
343+
if isinstance(arg, (params.Param, params.Body, params.Depends))
344+
]
345+
if fastapi_specific_annotations:
346+
fastapi_annotation: Union[
347+
FieldInfo, params.Depends, None
348+
] = fastapi_specific_annotations[-1]
349+
else:
350+
fastapi_annotation = None
343351
if isinstance(fastapi_annotation, FieldInfo):
344352
# Copy `field_info` because we mutate `field_info.default` below.
345353
field_info = copy_field_info(
346-
field_info=fastapi_annotation, annotation=annotation
354+
field_info=fastapi_annotation, annotation=use_annotation
347355
)
348356
assert field_info.default is Undefined or field_info.default is Required, (
349357
f"`{field_info.__class__.__name__}` default value cannot be set in"
@@ -356,8 +364,6 @@ def analyze_param(
356364
field_info.default = Required
357365
elif isinstance(fastapi_annotation, params.Depends):
358366
depends = fastapi_annotation
359-
elif annotation is not inspect.Signature.empty:
360-
type_annotation = annotation
361367

362368
if isinstance(value, params.Depends):
363369
assert depends is None, (
@@ -402,15 +408,15 @@ def analyze_param(
402408
# We might check here that `default_value is Required`, but the fact is that the same
403409
# parameter might sometimes be a path parameter and sometimes not. See
404410
# `tests/test_infer_param_optionality.py` for an example.
405-
field_info = params.Path(annotation=type_annotation)
411+
field_info = params.Path(annotation=use_annotation)
406412
elif is_uploadfile_or_nonable_uploadfile_annotation(
407413
type_annotation
408414
) or is_uploadfile_sequence_annotation(type_annotation):
409-
field_info = params.File(annotation=type_annotation, default=default_value)
415+
field_info = params.File(annotation=use_annotation, default=default_value)
410416
elif not field_annotation_is_scalar(annotation=type_annotation):
411-
field_info = params.Body(annotation=type_annotation, default=default_value)
417+
field_info = params.Body(annotation=use_annotation, default=default_value)
412418
else:
413-
field_info = params.Query(annotation=type_annotation, default=default_value)
419+
field_info = params.Query(annotation=use_annotation, default=default_value)
414420

415421
field = None
416422
if field_info is not None:
@@ -424,8 +430,8 @@ def analyze_param(
424430
and getattr(field_info, "in_", None) is None
425431
):
426432
field_info.in_ = params.ParamTypes.query
427-
use_annotation = get_annotation_from_field_info(
428-
type_annotation,
433+
use_annotation_from_field_info = get_annotation_from_field_info(
434+
use_annotation,
429435
field_info,
430436
param_name,
431437
)
@@ -436,7 +442,7 @@ def analyze_param(
436442
field_info.alias = alias
437443
field = create_response_field(
438444
name=param_name,
439-
type_=use_annotation,
445+
type_=use_annotation_from_field_info,
440446
default=field_info.default,
441447
alias=alias,
442448
required=field_info.default in (Required, Undefined),
@@ -466,16 +472,17 @@ def is_body_param(*, param_field: ModelField, is_path_param: bool) -> bool:
466472

467473

468474
def add_param_to_fields(*, field: ModelField, dependant: Dependant) -> None:
469-
field_info = cast(params.Param, field.field_info)
470-
if field_info.in_ == params.ParamTypes.path:
475+
field_info = field.field_info
476+
field_info_in = getattr(field_info, "in_", None)
477+
if field_info_in == params.ParamTypes.path:
471478
dependant.path_params.append(field)
472-
elif field_info.in_ == params.ParamTypes.query:
479+
elif field_info_in == params.ParamTypes.query:
473480
dependant.query_params.append(field)
474-
elif field_info.in_ == params.ParamTypes.header:
481+
elif field_info_in == params.ParamTypes.header:
475482
dependant.header_params.append(field)
476483
else:
477484
assert (
478-
field_info.in_ == params.ParamTypes.cookie
485+
field_info_in == params.ParamTypes.cookie
479486
), f"non-body parameters must be in path, query, header or cookie: {field.name}"
480487
dependant.cookie_params.append(field)
481488

tests/test_ambiguous_params.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import pytest
22
from fastapi import Depends, FastAPI, Path
33
from fastapi.param_functions import Query
4+
from fastapi.testclient import TestClient
5+
from fastapi.utils import PYDANTIC_V2
46
from typing_extensions import Annotated
57

68
app = FastAPI()
@@ -28,18 +30,13 @@ async def get(item_id: Annotated[int, Query(default=1)]):
2830
pass # pragma: nocover
2931

3032

31-
def test_no_multiple_annotations():
33+
def test_multiple_annotations():
3234
async def dep():
3335
pass # pragma: nocover
3436

35-
with pytest.raises(
36-
AssertionError,
37-
match="Cannot specify multiple `Annotated` FastAPI arguments for 'foo'",
38-
):
39-
40-
@app.get("/")
41-
async def get(foo: Annotated[int, Query(min_length=1), Query()]):
42-
pass # pragma: nocover
37+
@app.get("/multi-query")
38+
async def get(foo: Annotated[int, Query(gt=2), Query(lt=10)]):
39+
return foo
4340

4441
with pytest.raises(
4542
AssertionError,
@@ -64,3 +61,15 @@ async def get2(foo: Annotated[int, Depends(dep)] = Depends(dep)):
6461
@app.get("/")
6562
async def get3(foo: Annotated[int, Query(min_length=1)] = Depends(dep)):
6663
pass # pragma: nocover
64+
65+
client = TestClient(app)
66+
response = client.get("/multi-query", params={"foo": "5"})
67+
assert response.status_code == 200
68+
assert response.json() == 5
69+
70+
response = client.get("/multi-query", params={"foo": "123"})
71+
assert response.status_code == 422
72+
73+
if PYDANTIC_V2:
74+
response = client.get("/multi-query", params={"foo": "1"})
75+
assert response.status_code == 422

tests/test_annotated.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ async def unrelated(foo: Annotated[str, object()]):
5757
{
5858
"ctx": {"min_length": 1},
5959
"loc": ["query", "foo"],
60-
"msg": "String should have at least 1 characters",
60+
"msg": "String should have at least 1 character",
6161
"type": "string_too_short",
6262
"input": "",
6363
"url": match_pydantic_error_url("string_too_short"),

0 commit comments

Comments
 (0)