diff --git a/CHANGES/2126.feature b/CHANGES/2126.feature new file mode 100644 index 00000000000..ee2562d7df6 --- /dev/null +++ b/CHANGES/2126.feature @@ -0,0 +1 @@ +Speed up the `PayloadWriter.write` method for large request bodies. diff --git a/aiohttp/http_writer.py b/aiohttp/http_writer.py index f8da95074c2..03869c44e5e 100644 --- a/aiohttp/http_writer.py +++ b/aiohttp/http_writer.py @@ -125,6 +125,7 @@ class PayloadWriter(AbstractPayloadWriter): def __init__(self, stream, loop, acquire=True): self._stream = stream self._transport = None + self._buffer = [] self.loop = loop self.length = None @@ -133,12 +134,12 @@ def __init__(self, stream, loop, acquire=True): self.output_size = 0 self._eof = False - self._buffer = [] self._compress = None self._drain_waiter = None if self._stream.available: self._transport = self._stream.transport + self._buffer = None self._stream.available = False elif acquire: self._stream.acquire(self) @@ -146,15 +147,22 @@ def __init__(self, stream, loop, acquire=True): def set_transport(self, transport): self._transport = transport - chunk = b''.join(self._buffer) - if chunk: - transport.write(chunk) - self._buffer.clear() + if self._buffer is not None: + for chunk in self._buffer: + transport.write(chunk) + self._buffer = None if self._drain_waiter is not None: waiter, self._drain_waiter = self._drain_waiter, None - if not waiter.done(): - waiter.set_result(None) + waiter.set_result(None) + + async def get_transport(self): + if self._transport is None: + if self._drain_waiter is None: + self._drain_waiter = self.loop.create_future() + await self._drain_waiter + + return self._transport @property def tcp_nodelay(self): @@ -178,25 +186,14 @@ def enable_compression(self, encoding='deflate'): if encoding == 'gzip' else -zlib.MAX_WBITS) self._compress = zlib.compressobj(wbits=zlib_mode) - def buffer_data(self, chunk): - if chunk: - size = len(chunk) - self.buffer_size += size - self.output_size += size - self._buffer.append(chunk) - def _write(self, chunk): size = len(chunk) self.buffer_size += size self.output_size += size + # see set_transport: exactly one of _buffer or _transport is None if self._transport is not None: - if self._buffer: - self._buffer.append(chunk) - self._transport.write(b''.join(self._buffer)) - self._buffer.clear() - else: - self._transport.write(chunk) + self._transport.write(chunk) else: self._buffer.append(chunk) @@ -241,11 +238,7 @@ def write_headers(self, status_line, headers, SEP=': ', END='\r\n'): headers = status_line + ''.join( [k + SEP + v + END for k, v in headers.items()]) headers = headers.encode('utf-8') + b'\r\n' - - size = len(headers) - self.buffer_size += size - self.output_size += size - self._buffer.append(headers) + self._write(headers) async def write_eof(self, chunk=b''): if self._eof: @@ -268,24 +261,17 @@ async def write_eof(self, chunk=b''): chunk = b'0\r\n\r\n' if chunk: - self.buffer_data(chunk) + self._write(chunk) - await self.drain(True) + await self.drain() self._eof = True self._transport = None self._stream.release() - async def drain(self, last=False): + async def drain(self): if self._transport is not None: - if self._buffer: - self._transport.write(b''.join(self._buffer)) - if not last: - self._buffer.clear() await self._stream.drain() else: # wait for transport - if self._drain_waiter is None: - self._drain_waiter = self.loop.create_future() - - await self._drain_waiter + await self.get_transport() diff --git a/aiohttp/web_fileresponse.py b/aiohttp/web_fileresponse.py index 7ef19bb8931..fdb68916ebf 100644 --- a/aiohttp/web_fileresponse.py +++ b/aiohttp/web_fileresponse.py @@ -18,17 +18,16 @@ class SendfilePayloadWriter(PayloadWriter): - def set_transport(self, transport): - self._transport = transport - - if self._drain_waiter is not None: - waiter, self._drain_maiter = self._drain_maiter, None - if not waiter.done(): - waiter.set_result(None) + def __init__(self, *args, **kwargs): + self._sendfile_buffer = [] + super().__init__(*args, **kwargs) def _write(self, chunk): + # we overwrite PayloadWriter._write, so nothing can be appended to + # _buffer, and nothing is written to the transport directly by the + # parent class self.output_size += len(chunk) - self._buffer.append(chunk) + self._sendfile_buffer.append(chunk) def _sendfile_cb(self, fut, out_fd, in_fd, offset, count, loop, registered): @@ -54,33 +53,29 @@ def _sendfile_cb(self, fut, out_fd, in_fd, fut.set_result(None) async def sendfile(self, fobj, count): - if self._transport is None: - if self._drain_waiter is None: - self._drain_waiter = self.loop.create_future() - - await self._drain_waiter + transport = await self.get_transport() - out_socket = self._transport.get_extra_info("socket").dup() + out_socket = transport.get_extra_info('socket').dup() out_socket.setblocking(False) out_fd = out_socket.fileno() in_fd = fobj.fileno() offset = fobj.tell() loop = self.loop + data = b''.join(self._sendfile_buffer) try: - await loop.sock_sendall(out_socket, b''.join(self._buffer)) + await loop.sock_sendall(out_socket, data) fut = loop.create_future() self._sendfile_cb(fut, out_fd, in_fd, offset, count, loop, False) await fut except Exception: server_logger.debug('Socket error') - self._transport.close() + transport.close() finally: out_socket.close() self.output_size += count - self._transport = None - self._stream.release() + await super().write_eof() async def write_eof(self, chunk=b''): pass diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index e0377d2d146..60215de1f56 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -1694,6 +1694,27 @@ async def handler(request): resp.close() +async def test_encoding_gzip_write_by_chunks(loop, test_client): + + async def handler(request): + resp = web.StreamResponse() + resp.enable_compression(web.ContentCoding.gzip) + await resp.prepare(request) + await resp.write(b'0') + await resp.write(b'0') + return resp + + app = web.Application() + app.router.add_get('/', handler) + client = await test_client(app) + + resp = await client.get('/') + assert 200 == resp.status + txt = await resp.text() + assert txt == '00' + resp.close() + + async def test_encoding_gzip_nochunk(loop, test_client): async def handler(request): @@ -1778,6 +1799,26 @@ async def handler(request): resp.close() +async def test_payload_content_length_by_chunks(loop, test_client): + + async def handler(request): + resp = web.StreamResponse(headers={'content-length': '3'}) + await resp.prepare(request) + await resp.write(b'answer') + await resp.write(b'two') + request.transport.close() + return resp + + app = web.Application() + app.router.add_get('/', handler) + client = await test_client(app) + + resp = await client.get('/') + data = await resp.read() + assert data == b'ans' + resp.close() + + async def test_chunked(loop, test_client): async def handler(request): @@ -2462,7 +2503,7 @@ def connection_lost(self, exc): await r.read() assert 1 == len(connector._conns) - with pytest.raises(aiohttp.ServerDisconnectedError): + with pytest.raises(aiohttp.ClientConnectionError): await session.request('GET', url) assert 0 == len(connector._conns) diff --git a/tests/test_http_writer.py b/tests/test_http_writer.py index ac5280fc391..3530abcc7e8 100644 --- a/tests/test_http_writer.py +++ b/tests/test_http_writer.py @@ -1,5 +1,6 @@ """Tests for aiohttp/http_writer.py""" +import asyncio import zlib from unittest import mock @@ -148,3 +149,19 @@ def test_write_drain(stream, loop): msg.write(b'1', drain=True) assert msg.drain.called assert msg.buffer_size == 0 + + +async def test_multiple_drains(stream, loop): + stream.available = False + msg = http.PayloadWriter(stream, loop) + fut1 = loop.create_task(msg.drain()) + fut2 = loop.create_task(msg.drain()) + + assert not fut1.done() + assert not fut2.done() + + msg.set_transport(stream.transport) + + await asyncio.sleep(0) + assert fut1.done() + assert fut2.done() diff --git a/tests/test_web_server.py b/tests/test_web_server.py index b8284502d3c..3561dae504d 100644 --- a/tests/test_web_server.py +++ b/tests/test_web_server.py @@ -80,8 +80,9 @@ async def handler(request): server = await raw_test_server(handler, logger=logger) cli = await test_client(server) - with pytest.raises(client.ServerDisconnectedError): - await cli.get('/path/to') + resp = await cli.get('/path/to') + with pytest.raises(client.ClientPayloadError): + await resp.read() logger.debug.assert_called_with('Ignored premature client disconnection ')