Skip to content

Commit 35a9837

Browse files
wallseatsvlyubovskAlc-AlcprovinzkrautAlc-Alc
authored
fix: Enum OAS generation (#3518) (#3525)
--------- Co-authored-by: svlyubovsk <svlyubovsk@mts.ru> Co-authored-by: Alc-Alc <45509143+Alc-Alc@users.noreply.github.com> Co-authored-by: Janek Nouvertné <provinzkraut@posteo.de> Co-authored-by: Alc-Alc <alc@localhost>
1 parent 3510cab commit 35a9837

File tree

18 files changed

+192
-146
lines changed

18 files changed

+192
-146
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
from litestar.openapi.plugins import SwaggerRenderPlugin
22

3-
swagger_plugin = SwaggerRenderPlugin(version="5.1.3", path="/swagger")
3+
swagger_plugin = SwaggerRenderPlugin(version="5.18.2", path="/swagger")

litestar/_openapi/schema_generation/schema.py

+40-24
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from copy import copy
55
from datetime import date, datetime, time, timedelta
66
from decimal import Decimal
7-
from enum import Enum, EnumMeta
7+
from enum import Enum
88
from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
99
from pathlib import Path
1010
from typing import (
@@ -40,9 +40,7 @@
4040
create_string_constrained_field_schema,
4141
)
4242
from litestar._openapi.schema_generation.utils import (
43-
_should_create_enum_schema,
4443
_should_create_literal_schema,
45-
_type_or_first_not_none_inner_type,
4644
get_json_schema_formatted_examples,
4745
)
4846
from litestar.datastructures import SecretBytes, SecretString, UploadFile
@@ -181,22 +179,6 @@ def _get_type_schema_name(field_definition: FieldDefinition) -> str:
181179
return name
182180

183181

184-
def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Schema:
185-
"""Create a schema instance for an enum.
186-
187-
Args:
188-
annotation: An enum.
189-
include_null: Whether to include null as a possible value.
190-
191-
Returns:
192-
A schema instance.
193-
"""
194-
enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated]
195-
if include_null and None not in enum_values:
196-
enum_values.append(None)
197-
return Schema(type=_types_in_list(enum_values), enum=enum_values)
198-
199-
200182
def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]:
201183
"""Iterate over the flattened arguments of a Literal.
202184
@@ -331,18 +313,20 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re
331313
result = self.for_type_alias_type(field_definition)
332314
elif plugin_for_annotation := self.get_plugin_for(field_definition):
333315
result = self.for_plugin(field_definition, plugin_for_annotation)
334-
elif _should_create_enum_schema(field_definition):
335-
annotation = _type_or_first_not_none_inner_type(field_definition)
336-
result = create_enum_schema(annotation, include_null=field_definition.is_optional)
337316
elif _should_create_literal_schema(field_definition):
338317
annotation = (
339318
make_non_optional_union(field_definition.annotation)
340319
if field_definition.is_optional
341320
else field_definition.annotation
342321
)
343-
result = create_literal_schema(annotation, include_null=field_definition.is_optional)
322+
result = create_literal_schema(
323+
annotation,
324+
include_null=field_definition.is_optional,
325+
)
344326
elif field_definition.is_optional:
345327
result = self.for_optional_field(field_definition)
328+
elif field_definition.is_enum:
329+
result = self.for_enum_field(field_definition)
346330
elif field_definition.is_union:
347331
result = self.for_union_field(field_definition)
348332
elif field_definition.is_type_var:
@@ -445,7 +429,7 @@ def for_optional_field(self, field_definition: FieldDefinition) -> Schema:
445429
else:
446430
result = [schema_or_reference]
447431

448-
return Schema(one_of=[Schema(type=OpenAPIType.NULL), *result])
432+
return Schema(one_of=[*result, Schema(type=OpenAPIType.NULL)])
449433

450434
def for_union_field(self, field_definition: FieldDefinition) -> Schema:
451435
"""Create a Schema for a union FieldDefinition.
@@ -569,6 +553,38 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) ->
569553
# INFO: Removed because it was only for pydantic constrained collections
570554
return schema
571555

556+
def for_enum_field(
557+
self,
558+
field_definition: FieldDefinition,
559+
) -> Schema | Reference:
560+
"""Create a schema instance for an enum.
561+
562+
Args:
563+
field_definition: A signature field instance.
564+
565+
Returns:
566+
A schema or reference instance.
567+
"""
568+
enum_type: None | OpenAPIType | list[OpenAPIType] = None
569+
if issubclass(field_definition.annotation, Enum): # pragma: no branch
570+
# This method is only called for enums, so this branch is always executed
571+
if issubclass(field_definition.annotation, str): # StrEnum
572+
enum_type = OpenAPIType.STRING
573+
elif issubclass(field_definition.annotation, int): # IntEnum
574+
enum_type = OpenAPIType.INTEGER
575+
576+
enum_values: list[Any] = [v.value for v in field_definition.annotation]
577+
if enum_type is None:
578+
enum_type = _types_in_list(enum_values)
579+
580+
schema = self.schema_registry.get_schema_for_field_definition(field_definition)
581+
schema.type = enum_type
582+
schema.enum = enum_values
583+
schema.title = get_name(field_definition.annotation)
584+
schema.description = field_definition.annotation.__doc__
585+
586+
return self.schema_registry.get_reference_for_field_definition(field_definition) or schema
587+
572588
def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference:
573589
if field.kwarg_definition and field.is_const and field.has_default and schema.const is None:
574590
schema.const = field.default

litestar/_openapi/schema_generation/utils.py

+1-48
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from __future__ import annotations
22

3-
from enum import Enum
43
from typing import TYPE_CHECKING, Any, Mapping
54

65
from litestar.utils.helpers import get_name
@@ -11,53 +10,7 @@
1110
from litestar.openapi.spec import Example
1211
from litestar.typing import FieldDefinition
1312

14-
__all__ = (
15-
"_should_create_enum_schema",
16-
"_should_create_literal_schema",
17-
"_type_or_first_not_none_inner_type",
18-
)
19-
20-
21-
def _type_or_first_not_none_inner_type(field_definition: FieldDefinition) -> Any:
22-
"""Get the first inner type that is not None.
23-
24-
This is a narrow focussed utility to be used when we know that a field definition either represents
25-
a single type, or a single type in a union with `None`, and we want the single type.
26-
27-
Args:
28-
field_definition: A field definition instance.
29-
30-
Returns:
31-
A field definition instance.
32-
"""
33-
if not field_definition.is_optional:
34-
return field_definition.annotation
35-
inner = next((t for t in field_definition.inner_types if not t.is_none_type), None)
36-
if inner is None:
37-
raise ValueError("Field definition has no inner type that is not None")
38-
return inner.annotation
39-
40-
41-
def _should_create_enum_schema(field_definition: FieldDefinition) -> bool:
42-
"""Predicate to determine if we should create an enum schema for the field def, or not.
43-
44-
This returns true if the field definition is an enum, or if the field definition is a union
45-
of an enum and ``None``.
46-
47-
When an annotation is ``SomeEnum | None`` we should create a schema for the enum that includes ``null``
48-
in the enum values.
49-
50-
Args:
51-
field_definition: A field definition instance.
52-
53-
Returns:
54-
A boolean
55-
"""
56-
return field_definition.is_subclass_of(Enum) or (
57-
field_definition.is_optional
58-
and len(field_definition.args) == 2
59-
and field_definition.has_inner_subclass_of(Enum)
60-
)
13+
__all__ = ("_should_create_literal_schema",)
6114

6215

6316
def _should_create_literal_schema(field_definition: FieldDefinition) -> bool:

litestar/openapi/controller.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ class OpenAPIController(Controller):
3636
"""Base styling of the html body."""
3737
redoc_version: str = "next"
3838
"""Redoc version to download from the CDN."""
39-
swagger_ui_version: str = "5.1.3"
39+
swagger_ui_version: str = "5.18.2"
4040
"""SwaggerUI version to download from the CDN."""
4141
stoplight_elements_version: str = "7.7.18"
4242
"""StopLight Elements version to download from the CDN."""

litestar/openapi/plugins.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ class SwaggerRenderPlugin(OpenAPIRenderPlugin):
499499

500500
def __init__(
501501
self,
502-
version: str = "5.1.3",
502+
version: str = "5.18.2",
503503
js_url: str | None = None,
504504
css_url: str | None = None,
505505
standalone_preset_js_url: str | None = None,

litestar/typing.py

+5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from collections import abc
66
from copy import deepcopy
77
from dataclasses import dataclass, is_dataclass, replace
8+
from enum import Enum
89
from inspect import Parameter, Signature
910
from typing import Any, AnyStr, Callable, Collection, ForwardRef, Literal, Mapping, TypeVar, cast
1011

@@ -339,6 +340,10 @@ def is_typeddict_type(self) -> bool:
339340

340341
return is_typeddict(self.origin or self.annotation)
341342

343+
@property
344+
def is_enum(self) -> bool:
345+
return self.is_subclass_of(Enum)
346+
342347
@property
343348
def type_(self) -> Any:
344349
"""The type of the annotation with all the wrappers removed, including the generic types."""

tests/unit/test_contrib/test_piccolo_orm/test_piccolo_orm_dto.py

+14-9
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,14 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
139139
assert concert_schema
140140
assert concert_schema.to_schema() == {
141141
"properties": {
142-
"band_1": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
143-
"band_2": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
144-
"venue": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
142+
"band_1": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
143+
"band_2": {
144+
"oneOf": [
145+
{"type": "integer"},
146+
{"type": "null"},
147+
]
148+
},
149+
"venue": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
145150
},
146151
"required": [],
147152
"title": "CreateConcertConcertRequestBody",
@@ -152,10 +157,10 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
152157
assert record_studio_schema
153158
assert record_studio_schema.to_schema() == {
154159
"properties": {
155-
"facilities": {"oneOf": [{"type": "null"}, {"type": "string"}]},
156-
"facilities_b": {"oneOf": [{"type": "null"}, {"type": "string"}]},
157-
"microphones": {"oneOf": [{"type": "null"}, {"items": {"type": "string"}, "type": "array"}]},
158-
"id": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
160+
"facilities": {"oneOf": [{"type": "string"}, {"type": "null"}]},
161+
"facilities_b": {"oneOf": [{"type": "string"}, {"type": "null"}]},
162+
"microphones": {"oneOf": [{"items": {"type": "string"}, "type": "array"}, {"type": "null"}]},
163+
"id": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
159164
},
160165
"required": [],
161166
"title": "RetrieveStudioRecordingStudioResponseBody",
@@ -166,8 +171,8 @@ def test_piccolo_dto_openapi_spec_generation() -> None:
166171
assert venue_schema
167172
assert venue_schema.to_schema() == {
168173
"properties": {
169-
"id": {"oneOf": [{"type": "null"}, {"type": "integer"}]},
170-
"name": {"oneOf": [{"type": "null"}, {"type": "string"}]},
174+
"id": {"oneOf": [{"type": "integer"}, {"type": "null"}]},
175+
"name": {"oneOf": [{"type": "string"}, {"type": "null"}]},
171176
},
172177
"required": [],
173178
"title": "RetrieveVenuesVenueResponseBody",

tests/unit/test_openapi/conftest.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from litestar.openapi.spec.example import Example
1111
from litestar.params import Parameter
1212
from tests.models import DataclassPerson, DataclassPersonFactory, DataclassPet
13-
from tests.unit.test_openapi.utils import Gender, PetException
13+
from tests.unit.test_openapi.utils import Gender, LuckyNumber, PetException
1414

1515

1616
class PartialDataclassPersonDTO(DataclassDTO[DataclassPerson]):
@@ -45,8 +45,9 @@ def get_persons(
4545
from_date: Optional[Union[int, datetime, date]] = None,
4646
to_date: Optional[Union[int, datetime, date]] = None,
4747
gender: Optional[Union[Gender, List[Gender]]] = Parameter(
48-
examples=[Example(value="M"), Example(value=["M", "O"])]
48+
examples=[Example(value=Gender.MALE), Example(value=[Gender.MALE, Gender.OTHER])]
4949
),
50+
lucky_number: Optional[LuckyNumber] = Parameter(examples=[Example(value=LuckyNumber.SEVEN)]),
5051
# header parameter
5152
secret_header: str = Parameter(header="secret"),
5253
# cookie parameter

tests/unit/test_openapi/test_endpoints.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_default_redoc_cdn_urls(
3939
def test_default_swagger_ui_cdn_urls(
4040
person_controller: Type[Controller], pet_controller: Type[Controller], config: OpenAPIConfig
4141
) -> None:
42-
default_swagger_ui_version = "5.1.3"
42+
default_swagger_ui_version = "5.18.2"
4343
default_swagger_bundles = [
4444
f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui.css",
4545
f"https://cdn.jsdelivr.net/npm/swagger-ui-dist@{default_swagger_ui_version}/swagger-ui-bundle.js",

tests/unit/test_openapi/test_parameters.py

+24-17
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
from litestar.exceptions import ImproperlyConfiguredException
1616
from litestar.handlers import HTTPRouteHandler
1717
from litestar.openapi import OpenAPIConfig
18-
from litestar.openapi.spec import Example, OpenAPI, Schema
18+
from litestar.openapi.spec import Example, OpenAPI, Reference, Schema
1919
from litestar.openapi.spec.enums import OpenAPIType
2020
from litestar.params import Dependency, Parameter
2121
from litestar.routes import BaseRoute
2222
from litestar.testing import create_test_client
2323
from litestar.utils import find_index
24+
from tests.unit.test_openapi.utils import Gender, LuckyNumber
2425

2526
if TYPE_CHECKING:
2627
from litestar.openapi.spec.parameter import Parameter as OpenAPIParameter
@@ -49,8 +50,10 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
4950
ExampleFactory.seed_random(10)
5051

5152
parameters = _create_parameters(app=Litestar(route_handlers=[person_controller]), path="/{service_id}/person")
52-
assert len(parameters) == 9
53-
page, name, service_id, page_size, from_date, to_date, gender, secret_header, cookie_value = tuple(parameters)
53+
assert len(parameters) == 10
54+
page, name, service_id, page_size, from_date, to_date, gender, lucky_number, secret_header, cookie_value = tuple(
55+
parameters
56+
)
5457

5558
assert service_id.name == "service_id"
5659
assert service_id.param_in == ParamType.PATH
@@ -104,23 +107,15 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
104107
assert is_schema_value(gender.schema)
105108
assert gender.schema == Schema(
106109
one_of=[
107-
Schema(type=OpenAPIType.NULL),
108-
Schema(
109-
type=OpenAPIType.STRING,
110-
enum=["M", "F", "O", "A"],
111-
examples=["M"],
112-
),
110+
Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
113111
Schema(
114112
type=OpenAPIType.ARRAY,
115-
items=Schema(
116-
type=OpenAPIType.STRING,
117-
enum=["M", "F", "O", "A"],
118-
examples=["F"],
119-
),
120-
examples=[["A"]],
113+
items=Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_Gender"),
114+
examples=[[Gender.MALE]],
121115
),
116+
Schema(type=OpenAPIType.NULL),
122117
],
123-
examples=["M", ["M", "O"]],
118+
examples=[Gender.MALE, [Gender.MALE, Gender.OTHER]],
124119
)
125120
assert not gender.required
126121

@@ -136,6 +131,18 @@ def test_create_parameters(person_controller: Type[Controller]) -> None:
136131
assert cookie_value.required
137132
assert cookie_value.schema.examples
138133

134+
assert lucky_number.param_in == ParamType.QUERY
135+
assert lucky_number.name == "lucky_number"
136+
assert is_schema_value(lucky_number.schema)
137+
assert lucky_number.schema == Schema(
138+
one_of=[
139+
Reference(ref="#/components/schemas/tests_unit_test_openapi_utils_LuckyNumber"),
140+
Schema(type=OpenAPIType.NULL),
141+
],
142+
examples=[LuckyNumber.SEVEN],
143+
)
144+
assert not lucky_number.required
145+
139146

140147
def test_deduplication_for_param_where_key_and_type_are_equal() -> None:
141148
class BaseDep:
@@ -397,8 +404,8 @@ async def handler(
397404
app = Litestar([handler])
398405
assert app.openapi_schema.paths["/{path_param}"].get.parameters[0].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
399406
assert app.openapi_schema.paths["/{path_param}"].get.parameters[1].schema.one_of == [ # type: ignore[index, union-attr]
400-
Schema(type=OpenAPIType.NULL),
401407
Schema(type=OpenAPIType.STRING),
408+
Schema(type=OpenAPIType.NULL),
402409
]
403410
assert app.openapi_schema.paths["/{path_param}"].get.parameters[2].schema.type == OpenAPIType.STRING # type: ignore[index, union-attr]
404411
assert (

0 commit comments

Comments
 (0)