|
4 | 4 | from copy import copy
|
5 | 5 | from datetime import date, datetime, time, timedelta
|
6 | 6 | from decimal import Decimal
|
7 |
| -from enum import Enum, EnumMeta |
| 7 | +from enum import Enum |
8 | 8 | from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network
|
9 | 9 | from pathlib import Path
|
10 | 10 | from typing import (
|
|
40 | 40 | create_string_constrained_field_schema,
|
41 | 41 | )
|
42 | 42 | from litestar._openapi.schema_generation.utils import (
|
43 |
| - _should_create_enum_schema, |
44 | 43 | _should_create_literal_schema,
|
45 |
| - _type_or_first_not_none_inner_type, |
46 | 44 | get_json_schema_formatted_examples,
|
47 | 45 | )
|
48 | 46 | from litestar.datastructures import SecretBytes, SecretString, UploadFile
|
@@ -181,22 +179,6 @@ def _get_type_schema_name(field_definition: FieldDefinition) -> str:
|
181 | 179 | return name
|
182 | 180 |
|
183 | 181 |
|
184 |
| -def create_enum_schema(annotation: EnumMeta, include_null: bool = False) -> Schema: |
185 |
| - """Create a schema instance for an enum. |
186 |
| -
|
187 |
| - Args: |
188 |
| - annotation: An enum. |
189 |
| - include_null: Whether to include null as a possible value. |
190 |
| -
|
191 |
| - Returns: |
192 |
| - A schema instance. |
193 |
| - """ |
194 |
| - enum_values: list[str | int | None] = [v.value for v in annotation] # type: ignore[var-annotated] |
195 |
| - if include_null and None not in enum_values: |
196 |
| - enum_values.append(None) |
197 |
| - return Schema(type=_types_in_list(enum_values), enum=enum_values) |
198 |
| - |
199 |
| - |
200 | 182 | def _iter_flat_literal_args(annotation: Any) -> Iterable[Any]:
|
201 | 183 | """Iterate over the flattened arguments of a Literal.
|
202 | 184 |
|
@@ -331,18 +313,20 @@ def for_field_definition(self, field_definition: FieldDefinition) -> Schema | Re
|
331 | 313 | result = self.for_type_alias_type(field_definition)
|
332 | 314 | elif plugin_for_annotation := self.get_plugin_for(field_definition):
|
333 | 315 | result = self.for_plugin(field_definition, plugin_for_annotation)
|
334 |
| - elif _should_create_enum_schema(field_definition): |
335 |
| - annotation = _type_or_first_not_none_inner_type(field_definition) |
336 |
| - result = create_enum_schema(annotation, include_null=field_definition.is_optional) |
337 | 316 | elif _should_create_literal_schema(field_definition):
|
338 | 317 | annotation = (
|
339 | 318 | make_non_optional_union(field_definition.annotation)
|
340 | 319 | if field_definition.is_optional
|
341 | 320 | else field_definition.annotation
|
342 | 321 | )
|
343 |
| - result = create_literal_schema(annotation, include_null=field_definition.is_optional) |
| 322 | + result = create_literal_schema( |
| 323 | + annotation, |
| 324 | + include_null=field_definition.is_optional, |
| 325 | + ) |
344 | 326 | elif field_definition.is_optional:
|
345 | 327 | result = self.for_optional_field(field_definition)
|
| 328 | + elif field_definition.is_enum: |
| 329 | + result = self.for_enum_field(field_definition) |
346 | 330 | elif field_definition.is_union:
|
347 | 331 | result = self.for_union_field(field_definition)
|
348 | 332 | elif field_definition.is_type_var:
|
@@ -445,7 +429,7 @@ def for_optional_field(self, field_definition: FieldDefinition) -> Schema:
|
445 | 429 | else:
|
446 | 430 | result = [schema_or_reference]
|
447 | 431 |
|
448 |
| - return Schema(one_of=[Schema(type=OpenAPIType.NULL), *result]) |
| 432 | + return Schema(one_of=[*result, Schema(type=OpenAPIType.NULL)]) |
449 | 433 |
|
450 | 434 | def for_union_field(self, field_definition: FieldDefinition) -> Schema:
|
451 | 435 | """Create a Schema for a union FieldDefinition.
|
@@ -569,6 +553,38 @@ def for_collection_constrained_field(self, field_definition: FieldDefinition) ->
|
569 | 553 | # INFO: Removed because it was only for pydantic constrained collections
|
570 | 554 | return schema
|
571 | 555 |
|
| 556 | + def for_enum_field( |
| 557 | + self, |
| 558 | + field_definition: FieldDefinition, |
| 559 | + ) -> Schema | Reference: |
| 560 | + """Create a schema instance for an enum. |
| 561 | +
|
| 562 | + Args: |
| 563 | + field_definition: A signature field instance. |
| 564 | +
|
| 565 | + Returns: |
| 566 | + A schema or reference instance. |
| 567 | + """ |
| 568 | + enum_type: None | OpenAPIType | list[OpenAPIType] = None |
| 569 | + if issubclass(field_definition.annotation, Enum): # pragma: no branch |
| 570 | + # This method is only called for enums, so this branch is always executed |
| 571 | + if issubclass(field_definition.annotation, str): # StrEnum |
| 572 | + enum_type = OpenAPIType.STRING |
| 573 | + elif issubclass(field_definition.annotation, int): # IntEnum |
| 574 | + enum_type = OpenAPIType.INTEGER |
| 575 | + |
| 576 | + enum_values: list[Any] = [v.value for v in field_definition.annotation] |
| 577 | + if enum_type is None: |
| 578 | + enum_type = _types_in_list(enum_values) |
| 579 | + |
| 580 | + schema = self.schema_registry.get_schema_for_field_definition(field_definition) |
| 581 | + schema.type = enum_type |
| 582 | + schema.enum = enum_values |
| 583 | + schema.title = get_name(field_definition.annotation) |
| 584 | + schema.description = field_definition.annotation.__doc__ |
| 585 | + |
| 586 | + return self.schema_registry.get_reference_for_field_definition(field_definition) or schema |
| 587 | + |
572 | 588 | def process_schema_result(self, field: FieldDefinition, schema: Schema) -> Schema | Reference:
|
573 | 589 | if field.kwarg_definition and field.is_const and field.has_default and schema.const is None:
|
574 | 590 | schema.const = field.default
|
|
0 commit comments