diff --git a/CHANGES/10474.feature.rst b/CHANGES/10474.feature.rst deleted file mode 100644 index d5d6e4b40b9..00000000000 --- a/CHANGES/10474.feature.rst +++ /dev/null @@ -1,2 +0,0 @@ -Added ``tcp_sockopts`` to ``TCPConnector`` to allow specifying custom socket options --- by :user:`TimMenninger`. diff --git a/CHANGES/10520.feature.rst b/CHANGES/10520.feature.rst new file mode 100644 index 00000000000..f03cc1f26bd --- /dev/null +++ b/CHANGES/10520.feature.rst @@ -0,0 +1,2 @@ +Added ``socket_factory`` to ``TCPConnector`` to allow specifying custom socket options +-- by :user:`TimMenninger`. diff --git a/aiohttp/__init__.py b/aiohttp/__init__.py index f7864247791..7759a997cb9 100644 --- a/aiohttp/__init__.py +++ b/aiohttp/__init__.py @@ -47,6 +47,7 @@ WSServerHandshakeError, request, ) +from .connector import AddrInfoType, SocketFactoryType from .cookiejar import CookieJar, DummyCookieJar from .formdata import FormData from .helpers import BasicAuth, ChainMapProxy, ETag @@ -112,6 +113,7 @@ __all__: Tuple[str, ...] = ( "hdrs", # client + "AddrInfoType", "BaseConnector", "ClientConnectionError", "ClientConnectionResetError", @@ -146,6 +148,7 @@ "ServerDisconnectedError", "ServerFingerprintMismatch", "ServerTimeoutError", + "SocketFactoryType", "SocketTimeoutError", "TCPConnector", "TooManyRedirects", diff --git a/aiohttp/connector.py b/aiohttp/connector.py index 8a3f1bcbf2b..37e1d568ba5 100644 --- a/aiohttp/connector.py +++ b/aiohttp/connector.py @@ -20,7 +20,6 @@ DefaultDict, Deque, Dict, - Iterable, Iterator, List, Literal, @@ -34,6 +33,7 @@ ) import aiohappyeyeballs +from aiohappyeyeballs import AddrInfoType, SocketFactoryType from . import hdrs, helpers from .abc import AbstractResolver, ResolveResult @@ -96,7 +96,14 @@ # which first appeared in Python 3.12.7 and 3.13.1 -__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector") +__all__ = ( + "BaseConnector", + "TCPConnector", + "UnixConnector", + "NamedPipeConnector", + "AddrInfoType", + "SocketFactoryType", +) if TYPE_CHECKING: @@ -826,8 +833,9 @@ class TCPConnector(BaseConnector): the happy eyeballs algorithm, set to None. interleave - “First Address Family Count” as defined in RFC 8305 loop - Optional event loop. - tcp_sockopts - List of tuples of sockopts applied to underlying - socket + socket_factory - A SocketFactoryType function that, if supplied, + will be used to create sockets given an + AddrInfoType. """ allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"}) @@ -849,7 +857,7 @@ def __init__( timeout_ceil_threshold: float = 5, happy_eyeballs_delay: Optional[float] = 0.25, interleave: Optional[int] = None, - tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [], + socket_factory: Optional[SocketFactoryType] = None, ): super().__init__( keepalive_timeout=keepalive_timeout, @@ -880,7 +888,7 @@ def __init__( self._happy_eyeballs_delay = happy_eyeballs_delay self._interleave = interleave self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set() - self._tcp_sockopts = tcp_sockopts + self._socket_factory = socket_factory def _close_immediately(self) -> List[Awaitable[object]]: for fut in chain.from_iterable(self._throttle_dns_futures.values()): @@ -1105,7 +1113,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]: async def _wrap_create_connection( self, *args: Any, - addr_infos: List[aiohappyeyeballs.AddrInfoType], + addr_infos: List[AddrInfoType], req: ClientRequest, timeout: "ClientTimeout", client_error: Type[Exception] = ClientConnectorError, @@ -1122,9 +1130,8 @@ async def _wrap_create_connection( happy_eyeballs_delay=self._happy_eyeballs_delay, interleave=self._interleave, loop=self._loop, + socket_factory=self._socket_factory, ) - for sockopt in self._tcp_sockopts: - sock.setsockopt(*sockopt) connection = await self._loop.create_connection( *args, **kwargs, sock=sock ) @@ -1256,13 +1263,13 @@ async def _start_tls_connection( def _convert_hosts_to_addr_infos( self, hosts: List[ResolveResult] - ) -> List[aiohappyeyeballs.AddrInfoType]: + ) -> List[AddrInfoType]: """Converts the list of hosts to a list of addr_infos. The list of hosts is the result of a DNS lookup. The list of addr_infos is the result of a call to `socket.getaddrinfo()`. """ - addr_infos: List[aiohappyeyeballs.AddrInfoType] = [] + addr_infos: List[AddrInfoType] = [] for hinfo in hosts: host = hinfo["host"] is_ipv6 = ":" in host diff --git a/docs/client_advanced.rst b/docs/client_advanced.rst index 8f34fefaf81..4b0a878d715 100644 --- a/docs/client_advanced.rst +++ b/docs/client_advanced.rst @@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use session = aiohttp.ClientSession(connector=conn) -Setting socket options +Custom socket creation ^^^^^^^^^^^^^^^^^^^^^^ -Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed -to the underlying socket when creating a connection. For example, we may -want to change the conditions under which we consider a connection dead. -The following would change that to 9*7200 = 18 hours:: +If the default socket is insufficient for your use case, pass an optional +`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements +`SocketFactoryType`. This will be used to create all sockets for the +lifetime of the class object. For example, we may want to change the +conditions under which we consider a connection dead. The following would +make all sockets respect 9*7200 = 18 hours:: import socket - conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True), - (socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200), - (socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ]) + def socket_factory(addr_info): + family, type_, proto, _, _, _ = addr_info + sock = socket.socket(family=family, type=type_, proto=proto) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200) + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) + return sock + conn = aiohttp.TCPConnector(socket_factory=socket_factory) Named pipes in Windows diff --git a/docs/client_reference.rst b/docs/client_reference.rst index e1128934631..dad3761e868 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1122,6 +1122,16 @@ is controlled by *force_close* constructor's parameter). overridden in subclasses. +.. autodata:: AddrInfoType + + Refer to :py:data:`aiohappyeyeballs.AddrInfoType` + + +.. autodata:: SocketFactoryType + + Refer to :py:data:`aiohappyeyeballs.SocketFactoryType` + + .. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \ use_dns_cache=True, ttl_dns_cache=10, \ family=0, ssl_context=None, local_addr=None, \ @@ -1129,7 +1139,7 @@ is controlled by *force_close* constructor's parameter). force_close=False, limit=100, limit_per_host=0, \ enable_cleanup_closed=False, timeout_ceil_threshold=5, \ happy_eyeballs_delay=0.25, interleave=None, loop=None, \ - tcp_sockopts=[]) + socket_factory=None) Connector for working with *HTTP* and *HTTPS* via *TCP* sockets. @@ -1250,9 +1260,9 @@ is controlled by *force_close* constructor's parameter). .. versionadded:: 3.10 - :param list tcp_sockopts: options applied to the socket when a connection is - created. This should be a list of 3-tuples, each a ``(level, optname, value)``. - Each tuple is deconstructed and passed verbatim to ``.setsockopt``. + :param :py:data:``SocketFactoryType`` socket_factory: This function takes an + :py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when + creating TCP connections. .. versionadded:: 3.12 diff --git a/docs/conf.py b/docs/conf.py index 2deabea1b4f..eba93188b44 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -53,6 +53,7 @@ # ones. extensions = [ # stdlib-party extensions: + "sphinx.ext.autodoc", "sphinx.ext.extlinks", "sphinx.ext.graphviz", "sphinx.ext.intersphinx", @@ -82,6 +83,7 @@ "aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None), "aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None), "aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None), + "aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None), } # Add any paths that contain templates here, relative to this directory. @@ -425,6 +427,7 @@ ("py:class", "cgi.FieldStorage"), # undocumented ("py:meth", "aiohttp.web.UrlDispatcher.register_resource"), # undocumented ("py:func", "aiohttp_debugtoolbar.setup"), # undocumented + ("py:class", "socket.SocketKind"), # undocumented ] # -- Options for towncrier_draft extension ----------------------------------- diff --git a/requirements/runtime-deps.in b/requirements/runtime-deps.in index 48ee7016d13..4dcf7a1dea3 100644 --- a/requirements/runtime-deps.in +++ b/requirements/runtime-deps.in @@ -1,7 +1,7 @@ # Extracted from `setup.cfg` via `make sync-direct-runtime-deps` aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin" -aiohappyeyeballs >= 2.3.0 +aiohappyeyeballs >= 2.5.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 6.0 ; python_version < "3.11" Brotli; platform_python_implementation == 'CPython' diff --git a/setup.cfg b/setup.cfg index 66b779b8db9..674d9ed7c44 100644 --- a/setup.cfg +++ b/setup.cfg @@ -51,7 +51,7 @@ zip_safe = False include_package_data = True install_requires = - aiohappyeyeballs >= 2.3.0 + aiohappyeyeballs >= 2.5.0 aiosignal >= 1.1.2 async-timeout >= 4.0, < 6.0 ; python_version < "3.11" frozenlist >= 1.1.1 diff --git a/tests/test_connector.py b/tests/test_connector.py index 076ed556971..6a04ec472f9 100644 --- a/tests/test_connector.py +++ b/tests/test_connector.py @@ -26,7 +26,6 @@ from unittest import mock import pytest -from aiohappyeyeballs import AddrInfoType from pytest_mock import MockerFixture from yarl import URL @@ -44,6 +43,7 @@ from aiohttp.connector import ( _SSL_CONTEXT_UNVERIFIED, _SSL_CONTEXT_VERIFIED, + AddrInfoType, Connection, TCPConnector, _DNSCacheTable, @@ -3767,27 +3767,48 @@ def test_connect() -> Literal[True]: assert raw_response_list == [True, True] -async def test_tcp_connector_setsockopts( +async def test_tcp_connector_socket_factory( loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock ) -> None: - """Check that sockopts get passed to socket""" - conn = aiohttp.TCPConnector( - tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)] - ) - - with mock.patch.object( - conn._loop, "create_connection", autospec=True, spec_set=True - ) as create_connection: - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - start_connection.return_value = s - create_connection.return_value = mock.Mock(), mock.Mock() + """Check that socket factory is called""" + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + start_connection.return_value = s - req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop) + local_addr = None + socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s + happy_eyeballs_delay = 0.123 + interleave = 3 + conn = aiohttp.TCPConnector( + interleave=interleave, + local_addr=local_addr, + happy_eyeballs_delay=happy_eyeballs_delay, + socket_factory=socket_factory, + ) + with mock.patch.object( + conn._loop, + "create_connection", + autospec=True, + spec_set=True, + return_value=(mock.Mock(), mock.Mock()), + ): + host = "127.0.0.1" + port = 443 + req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop) with closing(await conn.connect(req, [], ClientTimeout())): - assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2 - - await conn.close() + pass + await conn.close() + + start_connection.assert_called_with( + addr_infos=[ + (socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port)) + ], + local_addr_infos=local_addr, + happy_eyeballs_delay=happy_eyeballs_delay, + interleave=interleave, + loop=loop, + socket_factory=socket_factory, + ) def test_default_ssl_context_creation_without_ssl() -> None: