Slow request body copy #2126

Merged · 10 commits · Nov 24, 2017
1 change: 1 addition & 0 deletions CHANGES/2126.feature
@@ -0,0 +1 @@
Speed up the `PayloadWriter.write` method for large request bodies.
58 changes: 22 additions & 36 deletions aiohttp/http_writer.py
@@ -125,6 +125,7 @@ class PayloadWriter(AbstractPayloadWriter):
def __init__(self, stream, loop, acquire=True):
self._stream = stream
self._transport = None
self._buffer = []
Member:
Side note: asyncio streams switched from a list to a bytearray for the sake of speed; maybe we need that too.
Feel free to create a PR for this.

Contributor Author:
If I remember correctly, aiohttp switched from bytearrays to lists for performance reasons too, so this probably needs to be benchmarked…

Member:
We just follow the asyncio streams design; it switched from bytearrays to lists and back to bytearrays :)

Contributor Author:
I tried to benchmark bytearray vs join in #2179, but the results were inconclusive at best (or leaned in favor of b''.join). I'll try to dig up why asyncio switched back to bytearray.
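For illustration, a rough micro-benchmark of the two buffering strategies (not the one from #2179; chunk size, iteration counts, and function names here are arbitrary assumptions) might look like this:

```python
# Sketch of a micro-benchmark comparing "append to a list, then b''.join"
# against "extend a bytearray in place". Results depend heavily on chunk
# size, chunk count, and Python version, so treat the numbers as indicative.
import timeit

CHUNK = b'x' * 4096   # assumed chunk size
N = 1000              # assumed number of chunks per run

def list_then_join():
    buf = []
    for _ in range(N):
        buf.append(CHUNK)
    return b''.join(buf)

def bytearray_extend():
    buf = bytearray()
    for _ in range(N):
        buf += CHUNK
    return bytes(buf)

if __name__ == '__main__':
    print('list + join :', timeit.timeit(list_then_join, number=200))
    print('bytearray +=:', timeit.timeit(bytearray_extend, number=200))
```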


self.loop = loop
self.length = None
@@ -133,28 +134,35 @@ def __init__(self, stream, loop, acquire=True):
self.output_size = 0

self._eof = False
self._buffer = []
self._compress = None
self._drain_waiter = None

if self._stream.available:
self._transport = self._stream.transport
self._buffer = None
Member:
I don't think it's a good idea to juggle the variable's type. How is an empty buffer [] different from None?

Contributor Author:
I'm using [] and None a bit differently here: once the buffer has been emptied into the new transport, it should never be used again. So I would prefer setting it to None and crashing if something goes very wrong.

Member:
Hm... I see. Well, then such a crash should be controlled, imho. Otherwise it would be hard to distinguish the transport re-use case from other errors.

Contributor Author:
It's purely internal to PayloadWriter, so everything is controlled where it should be.

This None means that at any given time exactly one of (self._transport, self._buffer) is None, which makes it easier to check the code and prove that it's correct.
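To make that invariant concrete, here is a stripped-down sketch of the pattern being described (a toy class for illustration, not the actual PayloadWriter code):

```python
# Simplified illustration of the invariant: exactly one of
# (_transport, _buffer) is None at any time. Names mirror the PR,
# but this is a toy class, not the real aiohttp implementation.
class BufferedWriter:
    def __init__(self):
        self._transport = None   # no transport yet...
        self._buffer = []        # ...so chunks accumulate in the buffer

    def set_transport(self, transport):
        self._transport = transport
        for chunk in self._buffer:      # flush everything buffered so far
            transport.write(chunk)
        self._buffer = None             # buffer must never be used again

    def write(self, chunk):
        if self._transport is not None:
            self._transport.write(chunk)
        else:
            # An AttributeError here would mean the invariant was violated,
            # i.e. the buffer was used after the transport became available.
            self._buffer.append(chunk)
```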

self._stream.available = False
elif acquire:
self._stream.acquire(self)

def set_transport(self, transport):
self._transport = transport

chunk = b''.join(self._buffer)
if chunk:
transport.write(chunk)
self._buffer.clear()
if self._buffer is not None:
for chunk in self._buffer:
transport.write(chunk)
self._buffer = None

if self._drain_waiter is not None:
waiter, self._drain_waiter = self._drain_waiter, None
if not waiter.done():
waiter.set_result(None)
waiter.set_result(None)

async def get_transport(self):
if self._transport is None:
if self._drain_waiter is None:
self._drain_waiter = self.loop.create_future()
await self._drain_waiter

return self._transport

@property
def tcp_nodelay(self):
@@ -178,25 +186,14 @@ def enable_compression(self, encoding='deflate'):
if encoding == 'gzip' else -zlib.MAX_WBITS)
self._compress = zlib.compressobj(wbits=zlib_mode)

def buffer_data(self, chunk):
if chunk:
size = len(chunk)
self.buffer_size += size
self.output_size += size
self._buffer.append(chunk)

def _write(self, chunk):
size = len(chunk)
self.buffer_size += size
self.output_size += size

# see set_transport: exactly one of _buffer or _transport is None
if self._transport is not None:
if self._buffer:
self._buffer.append(chunk)
self._transport.write(b''.join(self._buffer))
self._buffer.clear()
else:
self._transport.write(chunk)
self._transport.write(chunk)
else:
self._buffer.append(chunk)

@@ -241,11 +238,7 @@ def write_headers(self, status_line, headers, SEP=': ', END='\r\n'):
headers = status_line + ''.join(
[k + SEP + v + END for k, v in headers.items()])
headers = headers.encode('utf-8') + b'\r\n'

size = len(headers)
self.buffer_size += size
self.output_size += size
self._buffer.append(headers)
self._write(headers)

async def write_eof(self, chunk=b''):
if self._eof:
@@ -268,24 +261,17 @@ async def write_eof(self, chunk=b''):
chunk = b'0\r\n\r\n'

if chunk:
self.buffer_data(chunk)
self._write(chunk)

await self.drain(True)
await self.drain()

self._eof = True
self._transport = None
self._stream.release()

async def drain(self, last=False):
async def drain(self):
if self._transport is not None:
if self._buffer:
self._transport.write(b''.join(self._buffer))
if not last:
self._buffer.clear()
await self._stream.drain()
else:
# wait for transport
if self._drain_waiter is None:
self._drain_waiter = self.loop.create_future()

await self._drain_waiter
await self.get_transport()
31 changes: 13 additions & 18 deletions aiohttp/web_fileresponse.py
@@ -18,17 +18,16 @@

class SendfilePayloadWriter(PayloadWriter):

def set_transport(self, transport):
self._transport = transport

if self._drain_waiter is not None:
waiter, self._drain_maiter = self._drain_maiter, None
if not waiter.done():
waiter.set_result(None)
def __init__(self, *args, **kwargs):
self._sendfile_buffer = []
super().__init__(*args, **kwargs)

def _write(self, chunk):
# we overwrite PayloadWriter._write, so nothing can be appended to
# _buffer, and nothing is written to the transport directly by the
# parent class
self.output_size += len(chunk)
self._buffer.append(chunk)
self._sendfile_buffer.append(chunk)

def _sendfile_cb(self, fut, out_fd, in_fd,
offset, count, loop, registered):
@@ -54,33 +53,29 @@ def _sendfile_cb(self, fut, out_fd, in_fd,
fut.set_result(None)

async def sendfile(self, fobj, count):
if self._transport is None:
if self._drain_waiter is None:
self._drain_waiter = self.loop.create_future()

await self._drain_waiter
transport = await self.get_transport()

out_socket = self._transport.get_extra_info("socket").dup()
out_socket = transport.get_extra_info('socket').dup()
out_socket.setblocking(False)
out_fd = out_socket.fileno()
in_fd = fobj.fileno()
offset = fobj.tell()

loop = self.loop
data = b''.join(self._sendfile_buffer)
try:
await loop.sock_sendall(out_socket, b''.join(self._buffer))
await loop.sock_sendall(out_socket, data)
fut = loop.create_future()
self._sendfile_cb(fut, out_fd, in_fd, offset, count, loop, False)
await fut
except Exception:
server_logger.debug('Socket error')
self._transport.close()
transport.close()
finally:
out_socket.close()

self.output_size += count
self._transport = None
self._stream.release()
await super().write_eof()

async def write_eof(self, chunk=b''):
pass
43 changes: 42 additions & 1 deletion tests/test_client_functional.py
@@ -1694,6 +1694,27 @@ async def handler(request):
resp.close()


async def test_encoding_gzip_write_by_chunks(loop, test_client):

async def handler(request):
resp = web.StreamResponse()
resp.enable_compression(web.ContentCoding.gzip)
await resp.prepare(request)
await resp.write(b'0')
await resp.write(b'0')
return resp

app = web.Application()
app.router.add_get('/', handler)
client = await test_client(app)

resp = await client.get('/')
assert 200 == resp.status
txt = await resp.text()
assert txt == '00'
resp.close()


async def test_encoding_gzip_nochunk(loop, test_client):

async def handler(request):
@@ -1778,6 +1799,26 @@ async def handler(request):
resp.close()


async def test_payload_content_length_by_chunks(loop, test_client):

async def handler(request):
resp = web.StreamResponse(headers={'content-length': '3'})
await resp.prepare(request)
await resp.write(b'answer')
await resp.write(b'two')
request.transport.close()
return resp

app = web.Application()
app.router.add_get('/', handler)
client = await test_client(app)

resp = await client.get('/')
data = await resp.read()
assert data == b'ans'
resp.close()


async def test_chunked(loop, test_client):

async def handler(request):
@@ -2462,7 +2503,7 @@ def connection_lost(self, exc):
await r.read()
assert 1 == len(connector._conns)

with pytest.raises(aiohttp.ServerDisconnectedError):
with pytest.raises(aiohttp.ClientConnectionError):
await session.request('GET', url)
assert 0 == len(connector._conns)

17 changes: 17 additions & 0 deletions tests/test_http_writer.py
@@ -1,5 +1,6 @@
"""Tests for aiohttp/http_writer.py"""

import asyncio
import zlib
from unittest import mock

@@ -148,3 +149,19 @@ def test_write_drain(stream, loop):
msg.write(b'1', drain=True)
assert msg.drain.called
assert msg.buffer_size == 0


async def test_multiple_drains(stream, loop):
stream.available = False
msg = http.PayloadWriter(stream, loop)
fut1 = loop.create_task(msg.drain())
fut2 = loop.create_task(msg.drain())

assert not fut1.done()
assert not fut2.done()

msg.set_transport(stream.transport)

await asyncio.sleep(0)
assert fut1.done()
assert fut2.done()
5 changes: 3 additions & 2 deletions tests/test_web_server.py
@@ -80,8 +80,9 @@ async def handler(request):
server = await raw_test_server(handler, logger=logger)
cli = await test_client(server)

with pytest.raises(client.ServerDisconnectedError):
await cli.get('/path/to')
resp = await cli.get('/path/to')
with pytest.raises(client.ClientPayloadError):
await resp.read()

logger.debug.assert_called_with('Ignored premature client disconnection ')
