Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes for strict-bytes #10454

Merged
merged 8 commits into from
Feb 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pretty = True
show_column_numbers = True
show_error_codes = True
show_error_code_links = True
strict_bytes = True
strict_equality = True
warn_incomplete_stub = True
warn_redundant_casts = True
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/_websocket/reader_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,11 +309,11 @@ def _feed_data(self, data: bytes) -> None:
self.queue.feed_data(msg)
elif opcode == OP_CODE_PING:
self.queue.feed_data(
WSMessagePing(data=payload, size=len(payload), extra="")
WSMessagePing(data=bytes(payload), size=len(payload), extra="")
)
elif opcode == OP_CODE_PONG:
self.queue.feed_data(
WSMessagePong(data=payload, size=len(payload), extra="")
WSMessagePong(data=bytes(payload), size=len(payload), extra="")
)
else:
raise WebSocketError(
Expand Down
6 changes: 3 additions & 3 deletions aiohttp/_websocket/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,9 @@ async def send_frame(
# when aiohttp is acting as a client. Servers do not use a mask.
if use_mask:
mask = PACK_RANDBITS(self.get_random_bits())
message = bytearray(message)
websocket_mask(mask, message)
self.transport.write(header + mask + message)
message_arr = bytearray(message)
websocket_mask(mask, message_arr)
self.transport.write(header + mask + message_arr)
self._output_size += MASK_LEN
elif msg_length > MSG_SIZE:
self.transport.write(header)
Expand Down
4 changes: 3 additions & 1 deletion aiohttp/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ class AbstractStreamWriter(ABC):
length: Optional[int] = 0

@abstractmethod
async def write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
async def write(
self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
) -> None:
"""Write chunk into stream."""

@abstractmethod
Expand Down
18 changes: 13 additions & 5 deletions aiohttp/compression_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import asyncio
import sys
import zlib
from concurrent.futures import Executor
from typing import Optional, cast

if sys.version_info >= (3, 12):
from collections.abc import Buffer
else:
from typing import Union

Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]

try:
try:
import brotlicffi as brotli
Expand Down Expand Up @@ -66,10 +74,10 @@ def __init__(
)
self._compress_lock = asyncio.Lock()

def compress_sync(self, data: bytes) -> bytes:
def compress_sync(self, data: Buffer) -> bytes:
return self._compressor.compress(data)

async def compress(self, data: bytes) -> bytes:
async def compress(self, data: Buffer) -> bytes:
"""Compress the data and returned the compressed bytes.

Note that flush() must be called after the last call to compress()
Expand Down Expand Up @@ -111,10 +119,10 @@ def __init__(
)
self._decompressor = zlib.decompressobj(wbits=self._mode)

def decompress_sync(self, data: bytes, max_length: int = 0) -> bytes:
def decompress_sync(self, data: Buffer, max_length: int = 0) -> bytes:
return self._decompressor.decompress(data, max_length)

async def decompress(self, data: bytes, max_length: int = 0) -> bytes:
async def decompress(self, data: Buffer, max_length: int = 0) -> bytes:
"""Decompress the data and return the decompressed bytes.

If the data size is large than the max_sync_chunk_size, the decompression
Expand Down Expand Up @@ -162,7 +170,7 @@ def __init__(self) -> None:
)
self._obj = brotli.Decompressor()

def decompress_sync(self, data: bytes) -> bytes:
def decompress_sync(self, data: Buffer) -> bytes:
if hasattr(self._obj, "decompress"):
return cast(bytes, self._obj.decompress(data))
return cast(bytes, self._obj.process(data))
Expand Down
24 changes: 18 additions & 6 deletions aiohttp/http_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,12 @@ class HttpVersion(NamedTuple):
HttpVersion11 = HttpVersion(1, 1)


_T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
_T_OnChunkSent = Optional[
Callable[
[Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]],
Awaitable[None],
]
]
_T_OnHeadersSent = Optional[Callable[["CIMultiDict[str]"], Awaitable[None]]]


Expand Down Expand Up @@ -84,16 +89,23 @@ def enable_compression(
) -> None:
self._compress = ZLibCompressor(encoding=encoding, strategy=strategy)

def _write(self, chunk: Union[bytes, bytearray, memoryview]) -> None:
def _write(
self, chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
) -> None:
size = len(chunk)
self.buffer_size += size
self.output_size += size
transport = self._protocol.transport
if transport is None or transport.is_closing():
raise ClientConnectionResetError("Cannot write to closing transport")
transport.write(chunk)
transport.write(chunk) # type: ignore[arg-type]

def _writelines(self, chunks: Iterable[bytes]) -> None:
def _writelines(
self,
chunks: Iterable[
Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
],
) -> None:
size = 0
for chunk in chunks:
size += len(chunk)
Expand All @@ -105,11 +117,11 @@ def _writelines(self, chunks: Iterable[bytes]) -> None:
if SKIP_WRITELINES or size < MIN_PAYLOAD_FOR_WRITELINES:
transport.write(b"".join(chunks))
else:
transport.writelines(chunks)
transport.writelines(chunks) # type: ignore[arg-type]

async def write(
self,
chunk: Union[bytes, bytearray, memoryview],
chunk: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"],
*,
drain: bool = True,
LIMIT: int = 0x10000,
Expand Down
8 changes: 5 additions & 3 deletions aiohttp/web_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,9 @@ async def _write_headers(self) -> None:
status_line = f"HTTP/{version[0]}.{version[1]} {self._status} {self._reason}"
await writer.write_headers(status_line, self._headers)

async def write(self, data: Union[bytes, bytearray, memoryview]) -> None:
async def write(
self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
) -> None:
assert isinstance(
data, (bytes, bytearray, memoryview)
), "data argument must be byte-ish (%r)" % type(data)
Expand Down Expand Up @@ -580,7 +582,7 @@ def __init__(
self._zlib_executor = zlib_executor

@property
def body(self) -> Optional[Union[bytes, Payload]]:
def body(self) -> Optional[Union[bytes, bytearray, Payload]]:
return self._body

@body.setter
Expand Down Expand Up @@ -654,7 +656,7 @@ async def write_eof(self, data: bytes = b"") -> None:
if self._eof_sent:
return
if self._compressed_body is None:
body: Optional[Union[bytes, Payload]] = self._body
body = self._body
else:
body = self._compressed_body
assert not data, f"data arg is not supported, got {data!r}"
Expand Down
4 changes: 3 additions & 1 deletion aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,9 @@ async def receive_json(
data = await self.receive_str(timeout=timeout)
return loads(data)

async def write(self, data: bytes) -> None:
async def write(
self, data: Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"]
) -> None:
raise RuntimeError("Cannot call .write() for websocket")

def __aiter__(self) -> "WebSocketResponse":
Expand Down
Loading