diff --git a/CHANGES/10474.feature.rst b/CHANGES/10474.feature.rst new file mode 100644 index 00000000000..d5d6e4b40b9 --- /dev/null +++ b/CHANGES/10474.feature.rst @@ -0,0 +1,2 @@ +Added ``tcp_sockopts`` to ``TCPConnector`` to allow specifying custom socket options +-- by :user:`TimMenninger`. diff --git a/CONTRIBUTORS.txt b/CONTRIBUTORS.txt index 9dd9d873003..fb5217e3e6b 100644 --- a/CONTRIBUTORS.txt +++ b/CONTRIBUTORS.txt @@ -332,6 +332,7 @@ Thanos Lefteris Thijs Vermeir Thomas Forbes Thomas Grainger +Tim Menninger Tolga Tezel Tomasz Trebski Toshiaki Tanaka diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 14433ba37e1..75d5796f7d2 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -19,6 +19,7 @@ DefaultDict, Deque, Dict, + Iterable, Iterator, List, Literal, @@ -60,6 +61,11 @@ ) from .resolver import DefaultResolver +if sys.version_info >= (3, 12): + from collections.abc import Buffer +else: + Buffer = Union[bytes, bytearray, "memoryview[int]", "memoryview[bytes]"] + if TYPE_CHECKING: import ssl @@ -828,6 +834,8 @@ class TCPConnector(BaseConnector): the happy eyeballs algorithm, set to None. interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. + tcp_sockopts - List of tuples of sockopts applied to underlying + socket """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) @@ -853,6 +861,7 @@ def __init__( timeout_ceil_threshold: float = 5, happy_eyeballs_delay: Optional[float] = 0.25, interleave: Optional[int] = None, + tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [], ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -879,6 +888,7 @@ def __init__( self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() + self._tcp_sockopts = tcp_sockopts def close(self) -> Awaitable[None]: """Close all ongoing DNS calls.""" @@ -1120,6 +1130,8 @@ async def _wrap_create_connection( interleave=self._interleave, loop=self._loop, ) + for sockopt in self._tcp_sockopts: + sock.setsockopt(*sockopt) connection = await self._loop.create_connection( *args, **kwargs, sock=sock ) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 2d00418ffac..eeb0ee98574 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -461,6 +461,21 @@ If your HTTP server uses UNIX domain sockets you can use session = aiohttp.ClientSession(connector=conn) +Setting socket options +^^^^^^^^^^^^^^^^^^^^^^ + +Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed +to the underlying socket when creating a connection. For example, we may +want to change the conditions under which we consider a connection dead. +The following would change that to 9*7200 = 18 hours:: + + import socket + + conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True), + (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200), + (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ]) + + Named pipes in Windows ^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/client_reference.rst b/docs/client_reference.rst index 013c43a13e4..1e49b014007 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1144,7 +1144,8 @@ is controlled by *force_close* constructor's parameter). resolver=None, keepalive_timeout=sentinel, \ force_close=False, limit=100, limit_per_host=0, \ enable_cleanup_closed=False, timeout_ceil_threshold=5, \ - happy_eyeballs_delay=0.25, interleave=None, loop=None) + happy_eyeballs_delay=0.25, interleave=None, loop=None, \ + tcp_sockopts=[]) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -1265,6 +1266,12 @@ is controlled by *force_close* constructor's parameter). .. versionadded:: 3.10 + :param list tcp_sockopts: options applied to the socket when a connection is + created. This should be a list of 3-tuples, each a ``(level, optname, value)``. + Each tuple is deconstructed and passed verbatim to ``.setsockopt``. + + .. versionadded:: 3.12 + .. attribute:: family *TCP* socket family e.g. :data:`socket.AF_INET` or diff --git a/tests/test_connector.py b/tests/test_connector.py index e79b36a673d..b7531361287 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -3581,6 +3581,29 @@ def test_connect() -> Literal[True]: assert raw_response_list == [True, True] +async def test_tcp_connector_setsockopts( + loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock +) -> None: + """Check that sockopts get passed to socket""" + conn = aiohttp.TCPConnector( + tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)] + ) + + with mock.patch.object( + conn._loop, "create_connection", autospec=True, spec_set=True + ) as create_connection: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + start_connection.return_value = s + create_connection.return_value = mock.Mock(), mock.Mock() + + req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + + with closing(await conn.connect(req, [], ClientTimeout())): + assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2 + + await conn.close() + + def test_default_ssl_context_creation_without_ssl() -> None: """Verify _make_ssl_context does not raise when ssl is not available.""" with mock.patch.object(connector_module, "ssl", None):