Skip to content

Commit bbf6218

Browse files
arthurdarcet authored and asvetlov committed
Slow request body copy (#2126)
* avoid copying the request payload before writing it to the transport
* when the server closes the connection before writing EOF, the client should see a payload error, not expect to see the disconnection before the body is read
* do not use private attributes in SendfilePayloadWriter
* on windows, some test is now raising a fully fledged ClientOSError instead of an empty ServerDisconnectedError
* code review
* add CHANGES file
* in PayloadWriter, the waiter result is set and the self._drain_waiter is set back to None at the same time. So self._drain_waiter can never be done()
* test compressing by chunks in the PayloadWriter (cover the branch where a chunk does not yield anything that should be written to the transport)
* Test writing too many chunks to a PayloadWriter that has a Content-Length
* test multiple waiting drains in PayloadWriter
1 parent 84825da commit bbf6218

6 files changed

+98
-57
lines changed

CHANGES/2126.feature

+1
Original file line number | Diff line number | Diff line change
@@ -0,0 +1 @@
1+
Speed up the `PayloadWriter.write` method for large request bodies.

aiohttp/http_writer.py

+22-36
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,7 @@ class PayloadWriter(AbstractPayloadWriter):
125125
def __init__(self, stream, loop, acquire=True):
126126
self._stream = stream
127127
self._transport = None
128+
self._buffer = []
128129

129130
self.loop = loop
130131
self.length = None
@@ -133,28 +134,35 @@ def __init__(self, stream, loop, acquire=True):
133134
self.output_size = 0
134135

135136
self._eof = False
136-
self._buffer = []
137137
self._compress = None
138138
self._drain_waiter = None
139139

140140
if self._stream.available:
141141
self._transport = self._stream.transport
142+
self._buffer = None
142143
self._stream.available = False
143144
elif acquire:
144145
self._stream.acquire(self)
145146

146147
def set_transport(self, transport):
147148
self._transport = transport
148149

149-
chunk = b''.join(self._buffer)
150-
if chunk:
151-
transport.write(chunk)
152-
self._buffer.clear()
150+
if self._buffer is not None:
151+
for chunk in self._buffer:
152+
transport.write(chunk)
153+
self._buffer = None
153154

154155
if self._drain_waiter is not None:
155156
waiter, self._drain_waiter = self._drain_waiter, None
156-
if not waiter.done():
157-
waiter.set_result(None)
157+
waiter.set_result(None)
158+
159+
async def get_transport(self):
160+
if self._transport is None:
161+
if self._drain_waiter is None:
162+
self._drain_waiter = self.loop.create_future()
163+
await self._drain_waiter
164+
165+
return self._transport
158166

159167
@property
160168
def tcp_nodelay(self):
@@ -178,25 +186,14 @@ def enable_compression(self, encoding='deflate'):
178186
if encoding == 'gzip' else -zlib.MAX_WBITS)
179187
self._compress = zlib.compressobj(wbits=zlib_mode)
180188

181-
def buffer_data(self, chunk):
182-
if chunk:
183-
size = len(chunk)
184-
self.buffer_size += size
185-
self.output_size += size
186-
self._buffer.append(chunk)
187-
188189
def _write(self, chunk):
189190
size = len(chunk)
190191
self.buffer_size += size
191192
self.output_size += size
192193

194+
# see set_transport: exactly one of _buffer or _transport is None
193195
if self._transport is not None:
194-
if self._buffer:
195-
self._buffer.append(chunk)
196-
self._transport.write(b''.join(self._buffer))
197-
self._buffer.clear()
198-
else:
199-
self._transport.write(chunk)
196+
self._transport.write(chunk)
200197
else:
201198
self._buffer.append(chunk)
202199

@@ -241,11 +238,7 @@ def write_headers(self, status_line, headers, SEP=': ', END='\r\n'):
241238
headers = status_line + ''.join(
242239
[k + SEP + v + END for k, v in headers.items()])
243240
headers = headers.encode('utf-8') + b'\r\n'
244-
245-
size = len(headers)
246-
self.buffer_size += size
247-
self.output_size += size
248-
self._buffer.append(headers)
241+
self._write(headers)
249242

250243
async def write_eof(self, chunk=b''):
251244
if self._eof:
@@ -268,24 +261,17 @@ async def write_eof(self, chunk=b''):
268261
chunk = b'0\r\n\r\n'
269262

270263
if chunk:
271-
self.buffer_data(chunk)
264+
self._write(chunk)
272265

273-
await self.drain(True)
266+
await self.drain()
274267

275268
self._eof = True
276269
self._transport = None
277270
self._stream.release()
278271

279-
async def drain(self, last=False):
272+
async def drain(self):
280273
if self._transport is not None:
281-
if self._buffer:
282-
self._transport.write(b''.join(self._buffer))
283-
if not last:
284-
self._buffer.clear()
285274
await self._stream.drain()
286275
else:
287276
# wait for transport
288-
if self._drain_waiter is None:
289-
self._drain_waiter = self.loop.create_future()
290-
291-
await self._drain_waiter
277+
await self.get_transport()

aiohttp/web_fileresponse.py

+13-18
Original file line number | Diff line number | Diff line change
@@ -18,17 +18,16 @@
1818

1919
class SendfilePayloadWriter(PayloadWriter):
2020

21-
def set_transport(self, transport):
22-
self._transport = transport
23-
24-
if self._drain_waiter is not None:
25-
waiter, self._drain_maiter = self._drain_maiter, None
26-
if not waiter.done():
27-
waiter.set_result(None)
21+
def __init__(self, *args, **kwargs):
22+
self._sendfile_buffer = []
23+
super().__init__(*args, **kwargs)
2824

2925
def _write(self, chunk):
26+
# we overwrite PayloadWriter._write, so nothing can be appended to
27+
# _buffer, and nothing is written to the transport directly by the
28+
# parent class
3029
self.output_size += len(chunk)
31-
self._buffer.append(chunk)
30+
self._sendfile_buffer.append(chunk)
3231

3332
def _sendfile_cb(self, fut, out_fd, in_fd,
3433
offset, count, loop, registered):
@@ -54,33 +53,29 @@ def _sendfile_cb(self, fut, out_fd, in_fd,
5453
fut.set_result(None)
5554

5655
async def sendfile(self, fobj, count):
57-
if self._transport is None:
58-
if self._drain_waiter is None:
59-
self._drain_waiter = self.loop.create_future()
60-
61-
await self._drain_waiter
56+
transport = await self.get_transport()
6257

63-
out_socket = self._transport.get_extra_info("socket").dup()
58+
out_socket = transport.get_extra_info('socket').dup()
6459
out_socket.setblocking(False)
6560
out_fd = out_socket.fileno()
6661
in_fd = fobj.fileno()
6762
offset = fobj.tell()
6863

6964
loop = self.loop
65+
data = b''.join(self._sendfile_buffer)
7066
try:
71-
await loop.sock_sendall(out_socket, b''.join(self._buffer))
67+
await loop.sock_sendall(out_socket, data)
7268
fut = loop.create_future()
7369
self._sendfile_cb(fut, out_fd, in_fd, offset, count, loop, False)
7470
await fut
7571
except Exception:
7672
server_logger.debug('Socket error')
77-
self._transport.close()
73+
transport.close()
7874
finally:
7975
out_socket.close()
8076

8177
self.output_size += count
82-
self._transport = None
83-
self._stream.release()
78+
await super().write_eof()
8479

8580
async def write_eof(self, chunk=b''):
8681
pass

tests/test_client_functional.py

+42-1
Original file line number | Diff line number | Diff line change
@@ -1694,6 +1694,27 @@ async def handler(request):
16941694
resp.close()
16951695

16961696

1697+
async def test_encoding_gzip_write_by_chunks(loop, test_client):
1698+
1699+
async def handler(request):
1700+
resp = web.StreamResponse()
1701+
resp.enable_compression(web.ContentCoding.gzip)
1702+
await resp.prepare(request)
1703+
await resp.write(b'0')
1704+
await resp.write(b'0')
1705+
return resp
1706+
1707+
app = web.Application()
1708+
app.router.add_get('/', handler)
1709+
client = await test_client(app)
1710+
1711+
resp = await client.get('/')
1712+
assert 200 == resp.status
1713+
txt = await resp.text()
1714+
assert txt == '00'
1715+
resp.close()
1716+
1717+
16971718
async def test_encoding_gzip_nochunk(loop, test_client):
16981719

16991720
async def handler(request):
@@ -1778,6 +1799,26 @@ async def handler(request):
17781799
resp.close()
17791800

17801801

1802+
async def test_payload_content_length_by_chunks(loop, test_client):
1803+
1804+
async def handler(request):
1805+
resp = web.StreamResponse(headers={'content-length': '3'})
1806+
await resp.prepare(request)
1807+
await resp.write(b'answer')
1808+
await resp.write(b'two')
1809+
request.transport.close()
1810+
return resp
1811+
1812+
app = web.Application()
1813+
app.router.add_get('/', handler)
1814+
client = await test_client(app)
1815+
1816+
resp = await client.get('/')
1817+
data = await resp.read()
1818+
assert data == b'ans'
1819+
resp.close()
1820+
1821+
17811822
async def test_chunked(loop, test_client):
17821823

17831824
async def handler(request):
@@ -2462,7 +2503,7 @@ def connection_lost(self, exc):
24622503
await r.read()
24632504
assert 1 == len(connector._conns)
24642505

2465-
with pytest.raises(aiohttp.ServerDisconnectedError):
2506+
with pytest.raises(aiohttp.ClientConnectionError):
24662507
await session.request('GET', url)
24672508
assert 0 == len(connector._conns)
24682509

tests/test_http_writer.py

+17
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
"""Tests for aiohttp/http_writer.py"""
22

3+
import asyncio
34
import zlib
45
from unittest import mock
56

@@ -148,3 +149,19 @@ def test_write_drain(stream, loop):
148149
msg.write(b'1', drain=True)
149150
assert msg.drain.called
150151
assert msg.buffer_size == 0
152+
153+
154+
async def test_multiple_drains(stream, loop):
155+
stream.available = False
156+
msg = http.PayloadWriter(stream, loop)
157+
fut1 = loop.create_task(msg.drain())
158+
fut2 = loop.create_task(msg.drain())
159+
160+
assert not fut1.done()
161+
assert not fut2.done()
162+
163+
msg.set_transport(stream.transport)
164+
165+
await asyncio.sleep(0)
166+
assert fut1.done()
167+
assert fut2.done()

tests/test_web_server.py

+3-2
Original file line number | Diff line number | Diff line change
@@ -80,8 +80,9 @@ async def handler(request):
8080
server = await raw_test_server(handler, logger=logger)
8181
cli = await test_client(server)
8282

83-
with pytest.raises(client.ServerDisconnectedError):
84-
await cli.get('/path/to')
83+
resp = await cli.get('/path/to')
84+
with pytest.raises(client.ClientPayloadError):
85+
await resp.read()
8586

8687
logger.debug.assert_called_with('Ignored premature client disconnection ')
8788

0 commit comments

Comments
 (0)