Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Async writer #2774

Merged
merged 4 commits into from
Feb 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ async def write_bytes(self, writer, conn):
self.body = (self.body,)

for chunk in self.body:
writer.write(chunk)
await writer.write(chunk)

await writer.write_eof()
except OSError as exc:
Expand Down
11 changes: 4 additions & 7 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import zlib

from .abc import AbstractStreamWriter
from .helpers import noop


__all__ = ('StreamWriter', 'HttpVersion', 'HttpVersion10', 'HttpVersion11')
Expand Down Expand Up @@ -56,7 +55,7 @@ def _write(self, chunk):
raise asyncio.CancelledError('Cannot write to closing transport')
self._transport.write(chunk)

def write(self, chunk, *, drain=True, LIMIT=64*1024):
async def write(self, chunk, *, drain=True, LIMIT=64*1024):
"""Writes chunk of data to a stream.

write_eof() indicates end of stream.
Expand All @@ -66,7 +65,7 @@ def write(self, chunk, *, drain=True, LIMIT=64*1024):
if self._compress is not None:
chunk = self._compress.compress(chunk)
if not chunk:
return noop()
return

if self.length is not None:
chunk_len = len(chunk)
Expand All @@ -76,7 +75,7 @@ def write(self, chunk, *, drain=True, LIMIT=64*1024):
chunk = chunk[:self.length]
self.length = 0
if not chunk:
return noop()
return

if chunk:
if self.chunked:
Expand All @@ -87,9 +86,7 @@ def write(self, chunk, *, drain=True, LIMIT=64*1024):

if self.buffer_size > LIMIT and drain:
self.buffer_size = 0
return self.drain()

return noop()
await self.drain()

async def write_headers(self, status_line, headers, SEP=': ', END='\r\n'):
"""Write request/response status and headers."""
Expand Down
2 changes: 1 addition & 1 deletion aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ async def write(self, writer):
field = self._value
chunk = await field.read_chunk(size=2**16)
while chunk:
writer.write(field.decode(chunk))
await writer.write(field.decode(chunk))
chunk = await field.read_chunk(size=2**16)


Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_urldispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ async def _default_expect_handler(request):
expect = request.headers.get(hdrs.EXPECT)
if request.version == HttpVersion11:
if expect.lower() == "100-continue":
request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n", drain=False)
await request.writer.write(
b"HTTP/1.1 100 Continue\r\n\r\n", drain=False)
else:
raise HTTPExpectationFailed(text="Unknown Expect: %s" % expect)

Expand Down
14 changes: 7 additions & 7 deletions tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,9 +835,9 @@ async def test_expect_100_continue_header(loop, conn):

async def test_data_stream(loop, buf, conn):
@aiohttp.streamer
def gen(writer):
writer.write(b'binary data')
writer.write(b' result')
async def gen(writer):
await writer.write(b'binary data')
await writer.write(b' result')

req = ClientRequest(
'POST', URL('http://python.org/'), data=gen(), loop=loop)
Expand Down Expand Up @@ -876,7 +876,7 @@ async def test_data_stream_exc(loop, conn):

@aiohttp.streamer
async def gen(writer):
writer.write(b'binary data')
await writer.write(b'binary data')
await fut

req = ClientRequest(
Expand Down Expand Up @@ -929,8 +929,8 @@ async def throw_exc():
async def test_data_stream_continue(loop, buf, conn):
@aiohttp.streamer
async def gen(writer):
writer.write(b'binary data')
writer.write(b' result')
await writer.write(b'binary data')
await writer.write(b' result')
await writer.write_eof()

req = ClientRequest(
Expand Down Expand Up @@ -975,7 +975,7 @@ async def test_close(loop, buf, conn):
@aiohttp.streamer
async def gen(writer):
await asyncio.sleep(0.00001, loop=loop)
writer.write(b'result')
await writer.write(b'result')

req = ClientRequest(
'POST', URL('http://python.org/'), data=gen(), loop=loop)
Expand Down
52 changes: 26 additions & 26 deletions tests/test_http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

from aiohttp import http
from aiohttp.test_utils import make_mocked_coro


@pytest.fixture
Expand All @@ -28,8 +29,7 @@ def write(chunk):
@pytest.fixture
def protocol(loop, transport):
protocol = mock.Mock(transport=transport)
protocol._drain_helper.return_value = loop.create_future()
protocol._drain_helper.return_value.set_result(None)
protocol._drain_helper = make_mocked_coro()
return protocol


Expand All @@ -43,8 +43,8 @@ async def test_write_payload_eof(transport, protocol, loop):
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, transport, loop)

msg.write(b'data1')
msg.write(b'data2')
await msg.write(b'data1')
await msg.write(b'data2')
await msg.write_eof()

content = b''.join([c[1][0] for c in list(write.mock_calls)])
Expand All @@ -54,7 +54,7 @@ async def test_write_payload_eof(transport, protocol, loop):
async def test_write_payload_chunked(buf, protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'data')
await msg.write(b'data')
await msg.write_eof()

assert b'4\r\ndata\r\n0\r\n\r\n' == buf
Expand All @@ -63,8 +63,8 @@ async def test_write_payload_chunked(buf, protocol, transport, loop):
async def test_write_payload_chunked_multiple(buf, protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'data1')
msg.write(b'data2')
await msg.write(b'data1')
await msg.write(b'data2')
await msg.write_eof()

assert b'5\r\ndata1\r\n5\r\ndata2\r\n0\r\n\r\n' == buf
Expand All @@ -75,8 +75,8 @@ async def test_write_payload_length(protocol, transport, loop):

msg = http.StreamWriter(protocol, transport, loop)
msg.length = 2
msg.write(b'd')
msg.write(b'ata')
await msg.write(b'd')
await msg.write(b'ata')
await msg.write_eof()

content = b''.join([c[1][0] for c in list(write.mock_calls)])
Expand All @@ -88,8 +88,8 @@ async def test_write_payload_chunked_filter(protocol, transport, loop):

msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'da')
msg.write(b'ta')
await msg.write(b'da')
await msg.write(b'ta')
await msg.write_eof()

content = b''.join([c[1][0] for c in list(write.mock_calls)])
Expand All @@ -103,11 +103,11 @@ async def test_write_payload_chunked_filter_mutiple_chunks(
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_chunking()
msg.write(b'da')
msg.write(b'ta')
msg.write(b'1d')
msg.write(b'at')
msg.write(b'a2')
await msg.write(b'da')
await msg.write(b'ta')
await msg.write(b'1d')
await msg.write(b'at')
await msg.write(b'a2')
await msg.write_eof()
content = b''.join([c[1][0] for c in list(write.mock_calls)])
assert content.endswith(
Expand All @@ -123,7 +123,7 @@ async def test_write_payload_deflate_compression(protocol, transport, loop):
write = transport.write = mock.Mock()
msg = http.StreamWriter(protocol, transport, loop)
msg.enable_compression('deflate')
msg.write(b'data')
await msg.write(b'data')
await msg.write_eof()

chunks = [c[1][0] for c in list(write.mock_calls)]
Expand All @@ -141,32 +141,32 @@ async def test_write_payload_deflate_and_chunked(
msg.enable_compression('deflate')
msg.enable_chunking()

msg.write(b'da')
msg.write(b'ta')
await msg.write(b'da')
await msg.write(b'ta')
await msg.write_eof()

assert b'6\r\nKI,I\x04\x00\r\n0\r\n\r\n' == buf


def test_write_drain(protocol, transport, loop):
async def test_write_drain(protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)
msg.drain = mock.Mock()
msg.write(b'1' * (64 * 1024 * 2), drain=False)
msg.drain = make_mocked_coro()
await msg.write(b'1' * (64 * 1024 * 2), drain=False)
assert not msg.drain.called

msg.write(b'1', drain=True)
await msg.write(b'1', drain=True)
assert msg.drain.called
assert msg.buffer_size == 0


def test_write_to_closing_transport(protocol, transport, loop):
async def test_write_to_closing_transport(protocol, transport, loop):
msg = http.StreamWriter(protocol, transport, loop)

msg.write(b'Before closing')
await msg.write(b'Before closing')
transport.is_closing.return_value = True

with pytest.raises(asyncio.CancelledError):
msg.write(b'After closing')
await msg.write(b'After closing')


async def test_drain(protocol, transport, loop):
Expand Down
12 changes: 6 additions & 6 deletions tests/test_web_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ async def expect_handler(request):
nonlocal expect_received
expect_received = True
if request.version == HttpVersion11:
request.transport.write(b"HTTP/1.1 100 Continue\r\n\r\n")
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")

form = FormData()
form.add_field('name', b'123',
Expand All @@ -487,7 +487,7 @@ async def expect_handler(request):
if auth_err:
raise web.HTTPForbidden()

request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")
await request.writer.write(b"HTTP/1.1 100 Continue\r\n\r\n")

form = FormData()
form.add_field('name', b'123',
Expand Down Expand Up @@ -737,11 +737,11 @@ async def test_response_with_streamer(aiohttp_client, fname):
data_size = len(data)

@aiohttp.streamer
def stream(writer, f_name):
async def stream(writer, f_name):
with f_name.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
await writer.write(data)
data = f.read(100)

async def handler(request):
Expand All @@ -767,11 +767,11 @@ async def test_response_with_streamer_no_params(aiohttp_client, fname):
data_size = len(data)

@aiohttp.streamer
def stream(writer):
async def stream(writer):
with fname.open('rb') as f:
data = f.read(100)
while data:
yield from writer.write(data)
await writer.write(data)
data = f.read(100)

async def handler(request):
Expand Down