diff --git a/CHANGES/10520.feature.rst b/CHANGES/10520.feature.rst new file mode 100644 index 00000000000..d5d6e4b40b9 --- /dev/null +++ b/CHANGES/10520.feature.rst @@ -0,0 +1,2 @@ +Added ``tcp_sockopts`` to ``TCPConnector`` to allow specifying custom socket options +-- by :user:`TimMenninger`. diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 8a3f1bcbf2b..f198ac8a383 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -826,8 +826,9 @@ class TCPConnector(BaseConnector): the happy eyeballs algorithm, set to None. interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. - tcp_sockopts - List of tuples of sockopts applied to underlying - socket + socket_factory - An aiohappyeyeballs.SocketFactoryType function + that, if supplied, will be used to create sockets + given an aiohappyeyeballs.AddrInfoType. """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) @@ -849,7 +850,7 @@ def __init__( timeout_ceil_threshold: float = 5, happy_eyeballs_delay: Optional[float] = 0.25, interleave: Optional[int] = None, - tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [], + socket_factory: Optional[aiohappyeyeballs.SocketFactoryType] = None, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -880,7 +881,7 @@ def __init__( self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() - self._tcp_sockopts = tcp_sockopts + self._socket_factory = socket_factory def _close_immediately(self) -> List[Awaitable[object]]: for fut in chain.from_iterable(self._throttle_dns_futures.values()): @@ -1122,9 +1123,8 @@ async def _wrap_create_connection( happy_eyeballs_delay=self._happy_eyeballs_delay, interleave=self._interleave, loop=self._loop, + socket_factory=self._socket_factory, ) - for sockopt in self._tcp_sockopts: - sock.setsockopt(*sockopt) connection = await self._loop.create_connection( *args, **kwargs, sock=sock ) diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 8f34fefaf81..d7e91712035 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use session = aiohttp.ClientSession(connector=conn) -Setting socket options +Custom socket creation ^^^^^^^^^^^^^^^^^^^^^^ -Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed -to the underlying socket when creating a connection. For example, we may -want to change the conditions under which we consider a connection dead. -The following would change that to 9*7200 = 18 hours:: +If the default socket is insufficient for your use case, pass an optional +`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements +`aiohappyeyeballs.SocketFactoryType`. This will be used to create all +sockets for the lifetime of the class object. For example, we may want to +change the conditions under which we consider a connection dead. The +following would make all sockets respect 9*7200 = 18 hours:: import socket - conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True), - (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200), - (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ]) + def socket_factory(addr_info): + family, type_, proto, _, _, _ = addr_info + sock = socket.socket(family=family, type=type_, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) + return sock + conn = aiohttp.TCPConnector(socket_factory=socket_factory) Named pipes in Windows diff --git a/docs/client_reference.rst b/docs/client_reference.rst index e1128934631..539ea702752 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1129,7 +1129,7 @@ is controlled by *force_close* constructor's parameter). force_close=False, limit=100, limit_per_host=0, \ enable_cleanup_closed=False, timeout_ceil_threshold=5, \ happy_eyeballs_delay=0.25, interleave=None, loop=None, \ - tcp_sockopts=[]) + socket_factory=None) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -1250,9 +1250,9 @@ is controlled by *force_close* constructor's parameter). .. versionadded:: 3.10 - :param list tcp_sockopts: options applied to the socket when a connection is - created. This should be a list of 3-tuples, each a ``(level, optname, value)``. - Each tuple is deconstructed and passed verbatim to ``.setsockopt``. + :param ``aiohappyeyeballs.SocketFactoryType`` socket_factory: This function takes + an ``aiohappyeyeballs.AddrInfoType`` and is used in lieu of ``socket.socket()`` + when creating TCP connections. .. versionadded:: 3.12 diff --git a/docs/conf.py b/docs/conf.py index 2deabea1b4f..52f86294f60 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -82,6 +82,7 @@ "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None), "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None), + "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None), } # Add any paths that contain templates here, relative to this directory. diff --git a/tests/test_connector.py b/tests/test_connector.py index 076ed556971..599009e454d 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -3767,27 +3767,48 @@ def test_connect() -> Literal[True]: assert raw_response_list == [True, True] -async def test_tcp_connector_setsockopts( +async def test_tcp_connector_socket_factory( loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock ) -> None: - """Check that sockopts get passed to socket""" - conn = aiohttp.TCPConnector( - tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)] - ) - - with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True - ) as create_connection: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - start_connection.return_value = s - create_connection.return_value = mock.Mock(), mock.Mock() + """Check that socket factory is called""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + start_connection.return_value = s - req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + local_addr = None + socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s + happy_eyeballs_delay = 0.123 + interleave = 3 + conn = aiohttp.TCPConnector( + interleave=interleave, + local_addr=local_addr, + happy_eyeballs_delay=happy_eyeballs_delay, + socket_factory=socket_factory, + ) + with mock.patch.object( + conn._loop, + "create_connection", + autospec=True, + spec_set=True, + return_value=(mock.Mock(), mock.Mock()), + ) as create_connection: + host = "127.0.0.1" + port = 443 + req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2 - - await conn.close() + pass + await conn.close() + + start_connection.assert_called_with( + addr_infos=[ + (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port)) + ], + local_addr_infos=local_addr, + happy_eyeballs_delay=happy_eyeballs_delay, + interleave=interleave, + loop=loop, + socket_factory=socket_factory, + ) def test_default_ssl_context_creation_without_ssl() -> None: