Skip to content

Commit

Permalink
Expose setsockopt in TCPConnector API
Browse files Browse the repository at this point in the history
Optionally give tcp_sockopts to the constructor of TCPConnector, which
will be a list of tuples of (level, optname, value). Each tuple is
deconstructed and passed as arguments to <socket>.setsockopt.
  • Loading branch information
Tim Menninger authored and TimMenninger committed Feb 20, 2025
1 parent aa3296d commit 3ccf8d2
Show file tree
Hide file tree
Showing 8 changed files with 81 additions and 5 deletions.
2 changes: 2 additions & 0 deletions CHANGES/10474.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Allow user to specify sockopts in TCPConnector
-- by :user:`TimMenninger`.
1 change: 1 addition & 0 deletions CONTRIBUTORS.txt
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ Thanos Lefteris
Thijs Vermeir
Thomas Forbes
Thomas Grainger
Tim Menninger
Tolga Tezel
Tomasz Trebski
Toshiaki Tanaka
Expand Down
10 changes: 10 additions & 0 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@
set_result,
)
from .resolver import DefaultResolver
from .tcp_helpers import (
_SOCKOPT,
tcp_setsockopt,
)

if TYPE_CHECKING:
import ssl
Expand Down Expand Up @@ -820,6 +824,8 @@ class TCPConnector(BaseConnector):
the happy eyeballs algorithm, set to None.
interleave - “First Address Family Count” as defined in RFC 8305
loop - Optional event loop.
tcp_sockopts - List of tuples of sockopts applied to underlying
socket
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
Expand All @@ -841,6 +847,7 @@ def __init__(
timeout_ceil_threshold: float = 5,
happy_eyeballs_delay: Optional[float] = 0.25,
interleave: Optional[int] = None,
tcp_sockopts: List[_SOCKOPT] = [],
):
super().__init__(
keepalive_timeout=keepalive_timeout,
Expand Down Expand Up @@ -871,6 +878,7 @@ def __init__(
self._happy_eyeballs_delay = happy_eyeballs_delay
self._interleave = interleave
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
self._tcp_sockopts = tcp_sockopts

def _close_immediately(self) -> List[Awaitable[object]]:
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
Expand Down Expand Up @@ -1112,6 +1120,8 @@ async def _wrap_create_connection(
interleave=self._interleave,
loop=self._loop,
)
if self._tcp_sockopts:
tcp_setsockopt(sock, self._tcp_sockopts)
return await self._loop.create_connection(*args, **kwargs, sock=sock)
except cert_errors as exc:
raise ClientConnectorCertificateError(req.connection_key, exc) from exc
Expand Down
18 changes: 15 additions & 3 deletions aiohttp/tcp_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,23 @@
import asyncio
import socket
from contextlib import suppress
from typing import (
Any,
List,
Optional,
Tuple,
)

__all__ = ("tcp_keepalive", "tcp_nodelay")
__all__ = ("tcp_keepalive", "tcp_nodelay", "tcp_setsockopt")

_SOCKOPT = Tuple[int, int, Any]

if hasattr(socket, "SO_KEEPALIVE"):

def tcp_keepalive(transport: asyncio.Transport) -> None:
sock = transport.get_extra_info("socket")
if sock is not None:
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
tcp_setsockopt(sock, [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)])

else:

Expand All @@ -33,4 +40,9 @@ def tcp_nodelay(transport: asyncio.Transport, value: bool) -> None:

# socket may be closed already, on windows OSError get raised
with suppress(OSError):
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)
tcp_setsockopt(sock, [(socket.IPPROTO_TCP, socket.TCP_NODELAY, value)])

def tcp_setsockopt(sock: Optional[socket.socket], sockopts: List[_SOCKOPT]) -> None:
if sock is not None:
for sockopt in sockopts:
sock.setsockopt(*sockopt)
11 changes: 11 additions & 0 deletions docs/client_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,17 @@ If your HTTP server uses UNIX domain sockets you can use
session = aiohttp.ClientSession(connector=conn)


Setting socket options
^^^^^^^^^^^^^^^^^^^^^^

Socket options passed to the :class:`~aiohttp.TPCConnector` will be passed
to the underlying socket when creating a connection::

import socket

conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)])


Named pipes in Windows
^^^^^^^^^^^^^^^^^^^^^^

Expand Down
9 changes: 8 additions & 1 deletion docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1128,7 +1128,8 @@ is controlled by *force_close* constructor's parameter).
resolver=None, keepalive_timeout=sentinel, \
force_close=False, limit=100, limit_per_host=0, \
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
happy_eyeballs_delay=0.25, interleave=None, loop=None)
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
tcp_sockopts=[])

Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.

Expand Down Expand Up @@ -1249,6 +1250,12 @@ is controlled by *force_close* constructor's parameter).

.. versionadded:: 3.10

:param list tcp_sockopts: options applied to the socket when a connection is
created. This should be a list of 3-tuples, each a (level, optname, value).
Each tuple is deconstructed and passed verbatim to `<socket>.setsockopt`.

.. versionadded:: 3.11

.. attribute:: family

*TCP* socket family e.g. :data:`socket.AF_INET` or
Expand Down
21 changes: 21 additions & 0 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3744,6 +3744,27 @@ def test_connect() -> Literal[True]:
assert raw_response_list == [True, True]


async def test_tcp_connector_setsockopts(
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
) -> None:
"""Check that sockopts get passed to socket"""
conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)])

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
start_connection.return_value = s
create_connection.return_value = mock.Mock(), mock.Mock()

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)

with closing(await conn.connect(req, [], ClientTimeout())):
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2

await conn.close()


def test_default_ssl_context_creation_without_ssl() -> None:
"""Verify _make_ssl_context does not raise when ssl is not available."""
with mock.patch.object(connector_module, "ssl", None):
Expand Down
14 changes: 13 additions & 1 deletion tests/test_tcp_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from aiohttp.tcp_helpers import tcp_nodelay
from aiohttp.tcp_helpers import tcp_nodelay, tcp_setsockopt

has_ipv6: bool = socket.has_ipv6
if has_ipv6:
Expand Down Expand Up @@ -72,3 +72,15 @@ def test_tcp_nodelay_enable_no_socket() -> None:
transport = mock.Mock()
transport.get_extra_info.return_value = None
tcp_nodelay(transport, True)


def test_tcp_setsockopts() -> None:
# Should do nothing
tcp_setsockopt(None, [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 20)])

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
tcp_setsockopt(s, [(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL, 10)])
assert s.getsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE)
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPINTVL) == 10

0 comments on commit 3ccf8d2

Please sign in to comment.