Skip to content

Commit 3510cab

Browse files
fix: resource warnings (#3838)
--------- Co-authored-by: Janek Nouvertné <25355197+provinzkraut@users.noreply.github.com>
1 parent 69181fd commit 3510cab

File tree

12 files changed

+92
-64
lines changed

12 files changed

+92
-64
lines changed

litestar/_kwargs/extractors.py

+13-11
Original file line numberDiff line numberDiff line change
@@ -337,16 +337,16 @@ async def _extract_multipart(
337337
if body_kwarg_multipart_form_part_limit is not None
338338
else connection.app.multipart_form_part_limit
339339
)
340-
connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key]
341-
connection.scope["_form"] # type: ignore[typeddict-item]
342-
if "_form" in connection.scope
343-
else await parse_multipart_form(
340+
scope_state = ScopeState.from_scope(connection.scope)
341+
if scope_state.form is Empty:
342+
scope_state.form = form_values = await parse_multipart_form(
344343
stream=connection.stream(),
345344
boundary=connection.content_type[-1].get("boundary", "").encode(),
346345
multipart_form_part_limit=multipart_form_part_limit,
347346
type_decoders=connection.route_handler.resolve_type_decoders(),
348347
)
349-
)
348+
else:
349+
form_values = scope_state.form
350350

351351
if field_definition.is_non_string_sequence:
352352
values = list(form_values.values())
@@ -377,7 +377,7 @@ async def _extract_multipart(
377377
or (is_optional_union(tp) and is_non_string_sequence(make_non_optional_union(tp)))
378378
)
379379
):
380-
form_values[name] = [value]
380+
form_values[name] = [value] # pyright: ignore
381381

382382
return form_values
383383

@@ -426,11 +426,13 @@ def create_url_encoded_data_extractor(
426426
async def extract_url_encoded_extractor(
427427
connection: Request[Any, Any, Any],
428428
) -> Any:
429-
connection.scope["_form"] = form_values = ( # type: ignore[typeddict-unknown-key]
430-
connection.scope["_form"] # type: ignore[typeddict-item]
431-
if "_form" in connection.scope
432-
else parse_url_encoded_form_data(await connection.body())
433-
)
429+
scope_state = ScopeState.from_scope(connection.scope)
430+
if scope_state.form is Empty:
431+
scope_state.form = form_values = ( # type: ignore[assignment]
432+
parse_url_encoded_form_data(await connection.body())
433+
)
434+
else:
435+
form_values = scope_state.form # type: ignore[assignment]
434436

435437
if not form_values and is_data_optional:
436438
return None

litestar/_multipart.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ def parse_content_header(value: str) -> tuple[str, dict[str, str]]:
4848
return value.strip().lower(), options
4949

5050

51+
async def _close_upload_files(fields: dict[str, list[Any]]) -> None:
52+
for values in fields.values():
53+
for value in values:
54+
if isinstance(value, UploadFile):
55+
await value.close()
56+
57+
5158
async def parse_multipart_form( # noqa: C901
5259
stream: AsyncGenerator[bytes, None],
5360
boundary: bytes,
@@ -72,10 +79,11 @@ async def parse_multipart_form( # noqa: C901
7279
if not chunk:
7380
return fields
7481

82+
data: UploadFile | bytearray = bytearray()
83+
7584
try:
7685
with PushMultipartParser(boundary, max_segment_count=multipart_form_part_limit) as parser:
7786
segment: MultipartSegment | None = None
78-
data: UploadFile | bytearray = bytearray()
7987
while not parser.closed:
8088
for form_part in parser.parse(chunk):
8189
if isinstance(form_part, MultipartSegment):
@@ -113,8 +121,17 @@ async def parse_multipart_form( # noqa: C901
113121
chunk = await async_next(stream, b"")
114122

115123
except ParserError as exc:
124+
# if an exception is raised, make sure that all 'UploadFile's are closed
125+
if isinstance(data, UploadFile):
126+
await data.close()
127+
await _close_upload_files(fields)
128+
116129
raise ClientException("Invalid multipart/form-data") from exc
117130
except ParserLimitReached:
131+
if isinstance(data, UploadFile):
132+
await data.close()
133+
await _close_upload_files(fields)
134+
118135
# FIXME (3.0): This should raise a '413 - Request Entity Too Large', but for
119136
# backwards compatibility, we keep it as a 400 for now
120137
raise ClientException("Request Entity Too Large") from None

litestar/connection/request.py

+3-13
Original file line numberDiff line numberDiff line change
@@ -270,25 +270,15 @@ async def form(self) -> FormMultiDict:
270270
multipart_form_part_limit=self.app.multipart_form_part_limit,
271271
)
272272
elif content_type == RequestEncodingType.URL_ENCODED:
273-
form_data = parse_url_encoded_form_data(
273+
form_data = parse_url_encoded_form_data( # type: ignore[assignment]
274274
await self.body(),
275275
)
276276
else:
277277
form_data = {}
278278

279-
self._connection_state.form = form_data
279+
self._connection_state.form = form_data # pyright: ignore
280280

281-
# form_data is a dict[str, list[str] | str | UploadFile]. Convert it to a
282-
# list[tuple[str, str | UploadFile]] before passing it to FormMultiDict so
283-
# multi-keys can be accessed properly
284-
items = []
285-
for k, v in form_data.items():
286-
if isinstance(v, list):
287-
for sv in v:
288-
items.append((k, sv))
289-
else:
290-
items.append((k, v))
291-
self._form = FormMultiDict(items)
281+
self._form = FormMultiDict.from_form_data(cast("dict[str, Any]", form_data))
292282

293283
return self._form
294284

litestar/datastructures/multi_dicts.py

+21
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,27 @@ def copy(self) -> Self: # type: ignore[override]
9595
class FormMultiDict(ImmutableMultiDict[Any]):
9696
"""MultiDict for form data."""
9797

98+
@classmethod
99+
def from_form_data(cls, form_data: dict[str, list[str] | str | UploadFile]) -> FormMultiDict:
100+
"""Create a FormMultiDict from form data.
101+
102+
Args:
103+
form_data: Form data to create the FormMultiDict from.
104+
105+
Returns:
106+
A FormMultiDict instance
107+
"""
108+
# Convert form_data to a list[tuple[str, str | UploadFile]] before passing it
109+
# to FormMultiDict so multi-keys can be accessed properly
110+
items = []
111+
for k, v in form_data.items():
112+
if not isinstance(v, list):
113+
items.append((k, v))
114+
else:
115+
for sv in v:
116+
items.append((k, sv))
117+
return cls(items)
118+
98119
async def close(self) -> None:
99120
"""Close all files in the multi-dict.
100121

litestar/datastructures/upload_file.py

+2
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,8 @@ async def close(self) -> None:
9494
Returns:
9595
None.
9696
"""
97+
if self.file.closed:
98+
return None
9799
if self.rolled_to_disk:
98100
return await sync_to_thread(self.file.close)
99101
return self.file.close()

litestar/routes/http.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
from itertools import chain
4-
from typing import TYPE_CHECKING, Any, cast
4+
from typing import TYPE_CHECKING, Any
55

66
from msgspec.msgpack import decode as _decode_msgpack_plain
77

8-
from litestar.datastructures.upload_file import UploadFile
8+
from litestar.datastructures.multi_dicts import FormMultiDict
99
from litestar.enums import HttpMethod, MediaType, ScopeType
1010
from litestar.exceptions import ClientException, ImproperlyConfiguredException, SerializationException
1111
from litestar.handlers.http_handlers import HTTPRouteHandler
@@ -77,17 +77,18 @@ async def handle(self, scope: HTTPScope, receive: Receive, send: Send) -> None:
7777
if route_handler.resolve_guards():
7878
await route_handler.authorize_connection(connection=request)
7979

80-
response = await self._get_response_for_request(
81-
scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
82-
)
83-
84-
await response(scope, receive, send)
80+
try:
81+
response = await self._get_response_for_request(
82+
scope=scope, request=request, route_handler=route_handler, parameter_model=parameter_model
83+
)
8584

86-
if after_response_handler := route_handler.resolve_after_response():
87-
await after_response_handler(request)
85+
await response(scope, receive, send)
8886

89-
if form_data := scope.get("_form", {}):
90-
await self._cleanup_temporary_files(form_data=cast("dict[str, Any]", form_data))
87+
if after_response_handler := route_handler.resolve_after_response():
88+
await after_response_handler(request)
89+
finally:
90+
if (form_data := ScopeState.from_scope(scope).form) is not Empty:
91+
await FormMultiDict.from_form_data(form_data).close()
9192

9293
def create_handler_map(self) -> None:
9394
"""Parse the ``router_handlers`` of this route and return a mapping of
@@ -258,9 +259,3 @@ def options_handler(scope: Scope) -> Response:
258259
include_in_schema=False,
259260
sync_to_thread=False,
260261
)(options_handler)
261-
262-
@staticmethod
263-
async def _cleanup_temporary_files(form_data: dict[str, Any]) -> None:
264-
for v in form_data.values():
265-
if isinstance(v, UploadFile) and not v.file.closed:
266-
await v.close()

litestar/utils/scope/state.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
if TYPE_CHECKING:
1010
from typing_extensions import Self
1111

12-
from litestar.datastructures import URL, Accept, Headers
12+
from litestar.datastructures import URL, Accept, Headers, UploadFile
1313
from litestar.types.asgi_types import Scope
1414
from litestar.types.composite_types import ExceptionHandlersMap
1515

@@ -83,7 +83,7 @@ def __init__(self) -> None:
8383
dependency_cache: dict[str, Any] | EmptyType
8484
do_cache: bool | EmptyType
8585
exception_handlers: ExceptionHandlersMap | EmptyType
86-
form: dict[str, str | list[str]] | EmptyType
86+
form: dict[str, str | list[str] | UploadFile] | EmptyType
8787
flash_messages: list[dict[str, str]]
8888
headers: Headers | EmptyType
8989
is_cached: bool | EmptyType

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ fail_under = 50
203203
addopts = "--strict-markers --strict-config --dist=loadgroup -m 'not server_integration'"
204204
asyncio_mode = "auto"
205205
filterwarnings = [
206+
"error",
207+
# https://github.com/pytest-dev/pytest-asyncio/issues/724
208+
"default:.*socket.socket:pytest.PytestUnraisableExceptionWarning",
206209
"ignore::trio.TrioDeprecationWarning:anyio._backends._trio*:",
207210
"ignore::DeprecationWarning:pkg_resources.*",
208211
"ignore::DeprecationWarning:google.rpc",

tests/e2e/test_routing/conftest.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import subprocess
21
import time
32
from pathlib import Path
43
from typing import Callable, List
@@ -16,16 +15,13 @@ def runner(app: str, server_command: List[str]) -> None:
1615
tmp_path.joinpath("app.py").write_text(app)
1716
monkeypatch.chdir(tmp_path)
1817

19-
proc = psutil.Popen(
20-
server_command,
21-
stderr=subprocess.PIPE,
22-
stdout=subprocess.PIPE,
23-
)
18+
proc = psutil.Popen(server_command)
2419

2520
def kill() -> None:
2621
for child in proc.children(recursive=True):
2722
child.kill()
2823
proc.kill()
24+
proc.wait()
2925

3026
request.addfinalizer(kill)
3127

tests/unit/test_datastructures/test_multi_dicts.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
from unittest.mock import patch
4+
35
import pytest
4-
from pytest_mock import MockerFixture
56

67
from litestar.datastructures import UploadFile
78
from litestar.datastructures.multi_dicts import FormMultiDict, ImmutableMultiDict, MultiDict
@@ -34,20 +35,19 @@ def test_immutable_multi_dict_as_mutable() -> None:
3435
assert multi.mutable_copy().dict() == MultiDict(data).dict()
3536

3637

37-
async def test_form_multi_dict_close(mocker: MockerFixture) -> None:
38-
close = mocker.patch("litestar.datastructures.multi_dicts.UploadFile.close")
39-
38+
async def test_form_multi_dict_close() -> None:
4039
multi = FormMultiDict(
4140
[
4241
("foo", UploadFile(filename="foo", content_type="text/plain")),
4342
("bar", UploadFile(filename="foo", content_type="text/plain")),
4443
]
4544
)
46-
45+
with patch("litestar.datastructures.multi_dicts.UploadFile.close") as mock_close:
46+
await multi.close()
47+
assert mock_close.call_count == 2
48+
# calls the real UploadFile.close method to clean up
4749
await multi.close()
4850

49-
assert close.call_count == 2
50-
5151

5252
@pytest.mark.parametrize("type_", [MultiDict, ImmutableMultiDict])
5353
def test_copy(type_: type[MultiDict | ImmutableMultiDict]) -> None:

tests/unit/test_dto/test_factory/test_integration.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -56,11 +56,13 @@ def handler(data: User = Body(media_type=RequestEncodingType.URL_ENCODED)) -> Us
5656

5757

5858
async def test_multipart_encoded_form_data(use_experimental_dto_backend: bool) -> None:
59+
default_file = UploadFile(content_type="text/plain", filename="forbidden", file_data=b"forbidden")
60+
5961
@dataclass
6062
class Payload:
6163
file: UploadFile
6264
forbidden: UploadFile = field(
63-
default=UploadFile(content_type="text/plain", filename="forbidden", file_data=b"forbidden"),
65+
default=default_file,
6466
metadata=dto_field("read-only"),
6567
)
6668

@@ -78,6 +80,8 @@ async def handler(data: Payload = Body(media_type=RequestEncodingType.MULTI_PART
7880
)
7981
assert response.content == b"forbidden"
8082

83+
await default_file.close()
84+
8185

8286
def test_renamed_field(use_experimental_dto_backend: bool) -> None:
8387
@dataclass

tests/unit/test_plugins/test_pydantic/test_beanie_integration.py

+4-6
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
1-
from typing import TYPE_CHECKING, Optional, Type
1+
from typing import Optional
22

33
import beanie
4-
5-
if TYPE_CHECKING:
6-
from pydantic import BaseModel
4+
from pydantic import BaseModel
75

86
from litestar.plugins.pydantic import PydanticDTO
97

108

11-
def test_generate_field_definitions_from_beanie_models(base_model: "Type[BaseModel]") -> None:
12-
class Category(base_model): # type: ignore[valid-type, misc]
9+
def test_generate_field_definitions_from_beanie_models() -> None:
10+
class Category(BaseModel):
1311
name: str
1412
description: str
1513

0 commit comments

Comments
 (0)