Skip to content

Commit

Permalink
Replace tcp_sockopts with socket_factory (#10520)
Browse files Browse the repository at this point in the history
Instead of TCPConnector taking a list of sockopts to be applied sockets
created, take a socket_factory callback that allows the caller to
implement socket creation entirely.
  • Loading branch information
TimMenninger committed Mar 11, 2025
1 parent 4399a6c commit 56fc83d
Show file tree
Hide file tree
Showing 10 changed files with 94 additions and 44 deletions.
2 changes: 0 additions & 2 deletions CHANGES/10474.feature.rst

This file was deleted.

2 changes: 2 additions & 0 deletions CHANGES/10520.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added ``socket_factory`` to ``TCPConnector`` to allow specifying custom socket options
-- by :user:`TimMenninger`.
3 changes: 3 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
WSServerHandshakeError,
request,
)
from .connector import AddrInfoType, SocketFactoryType
from .cookiejar import CookieJar, DummyCookieJar
from .formdata import FormData
from .helpers import BasicAuth, ChainMapProxy, ETag
Expand Down Expand Up @@ -112,6 +113,7 @@
__all__: Tuple[str, ...] = (
"hdrs",
# client
"AddrInfoType",
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
Expand Down Expand Up @@ -146,6 +148,7 @@
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketFactoryType",
"SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
Expand Down
29 changes: 18 additions & 11 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
DefaultDict,
Deque,
Dict,
Iterable,
Iterator,
List,
Literal,
Expand All @@ -34,6 +33,7 @@
)

import aiohappyeyeballs
from aiohappyeyeballs import AddrInfoType, SocketFactoryType

from . import hdrs, helpers
from .abc import AbstractResolver, ResolveResult
Expand Down Expand Up @@ -96,7 +96,14 @@
# which first appeared in Python 3.12.7 and 3.13.1


__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
__all__ = (
"BaseConnector",
"TCPConnector",
"UnixConnector",
"NamedPipeConnector",
"AddrInfoType",
"SocketFactoryType",
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -826,8 +833,9 @@ class TCPConnector(BaseConnector):
the happy eyeballs algorithm, set to None.
interleave - “First Address Family Count” as defined in RFC 8305
loop - Optional event loop.
tcp_sockopts - List of tuples of sockopts applied to underlying
socket
socket_factory - A SocketFactoryType function that, if supplied,
will be used to create sockets given an
AddrInfoType.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
Expand All @@ -849,7 +857,7 @@ def __init__(
timeout_ceil_threshold: float = 5,
happy_eyeballs_delay: Optional[float] = 0.25,
interleave: Optional[int] = None,
tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [],
socket_factory: Optional[SocketFactoryType] = None,
):
super().__init__(
keepalive_timeout=keepalive_timeout,
Expand Down Expand Up @@ -880,7 +888,7 @@ def __init__(
self._happy_eyeballs_delay = happy_eyeballs_delay
self._interleave = interleave
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
self._tcp_sockopts = tcp_sockopts
self._socket_factory = socket_factory

def _close_immediately(self) -> List[Awaitable[object]]:
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
Expand Down Expand Up @@ -1105,7 +1113,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
async def _wrap_create_connection(
self,
*args: Any,
addr_infos: List[aiohappyeyeballs.AddrInfoType],
addr_infos: List[AddrInfoType],
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
Expand All @@ -1122,9 +1130,8 @@ async def _wrap_create_connection(
happy_eyeballs_delay=self._happy_eyeballs_delay,
interleave=self._interleave,
loop=self._loop,
socket_factory=self._socket_factory,
)
for sockopt in self._tcp_sockopts:
sock.setsockopt(*sockopt)
connection = await self._loop.create_connection(
*args, **kwargs, sock=sock
)
Expand Down Expand Up @@ -1256,13 +1263,13 @@ async def _start_tls_connection(

def _convert_hosts_to_addr_infos(
self, hosts: List[ResolveResult]
) -> List[aiohappyeyeballs.AddrInfoType]:
) -> List[AddrInfoType]:
"""Converts the list of hosts to a list of addr_infos.
The list of hosts is the result of a DNS lookup. The list of
addr_infos is the result of a call to `socket.getaddrinfo()`.
"""
addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
addr_infos: List[AddrInfoType] = []
for hinfo in hosts:
host = hinfo["host"]
is_ipv6 = ":" in host
Expand Down
23 changes: 15 additions & 8 deletions docs/client_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use
session = aiohttp.ClientSession(connector=conn)


Setting socket options
Custom socket creation
^^^^^^^^^^^^^^^^^^^^^^

Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed
to the underlying socket when creating a connection. For example, we may
want to change the conditions under which we consider a connection dead.
The following would change that to 9*7200 = 18 hours::
If the default socket is insufficient for your use case, pass an optional
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
`SocketFactoryType`. This will be used to create all sockets for the
lifetime of the class object. For example, we may want to change the
conditions under which we consider a connection dead. The following would
make all sockets respect 9*7200 = 18 hours::

import socket

conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200),
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ])
def socket_factory(addr_info):
family, type_, proto, _, _, _ = addr_info
sock = socket.socket(family=family, type=type_, proto=proto)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9)
return sock
conn = aiohttp.TCPConnector(socket_factory=socket_factory)


Named pipes in Windows
Expand Down
18 changes: 14 additions & 4 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1122,14 +1122,24 @@ is controlled by *force_close* constructor's parameter).
overridden in subclasses.


.. autodata:: AddrInfoType

Refer to :py:data:`aiohappyeyeballs.AddrInfoType`


.. autodata:: SocketFactoryType

Refer to :py:data:`aiohappyeyeballs.SocketFactoryType`


.. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \
use_dns_cache=True, ttl_dns_cache=10, \
family=0, ssl_context=None, local_addr=None, \
resolver=None, keepalive_timeout=sentinel, \
force_close=False, limit=100, limit_per_host=0, \
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
tcp_sockopts=[])
socket_factory=None)

Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.

Expand Down Expand Up @@ -1250,9 +1260,9 @@ is controlled by *force_close* constructor's parameter).

.. versionadded:: 3.10

:param list tcp_sockopts: options applied to the socket when a connection is
created. This should be a list of 3-tuples, each a ``(level, optname, value)``.
Each tuple is deconstructed and passed verbatim to ``<socket>.setsockopt``.
:param :py:data:``SocketFactoryType`` socket_factory: This function takes an
:py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when
creating TCP connections.

.. versionadded:: 3.12

Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
# ones.
extensions = [
# stdlib-party extensions:
"sphinx.ext.autodoc",
"sphinx.ext.extlinks",
"sphinx.ext.graphviz",
"sphinx.ext.intersphinx",
Expand Down Expand Up @@ -82,6 +83,7 @@
"aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None),
}

# Add any paths that contain templates here, relative to this directory.
Expand Down
2 changes: 1 addition & 1 deletion requirements/runtime-deps.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`

aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
aiohappyeyeballs >= 2.3.0
aiohappyeyeballs >= 2.5.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
Brotli; platform_python_implementation == 'CPython'
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ zip_safe = False
include_package_data = True

install_requires =
aiohappyeyeballs >= 2.3.0
aiohappyeyeballs >= 2.5.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
frozenlist >= 1.1.1
Expand Down
55 changes: 38 additions & 17 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from unittest import mock

import pytest
from aiohappyeyeballs import AddrInfoType
from pytest_mock import MockerFixture
from yarl import URL

Expand All @@ -44,6 +43,7 @@
from aiohttp.connector import (
_SSL_CONTEXT_UNVERIFIED,
_SSL_CONTEXT_VERIFIED,
AddrInfoType,
Connection,
TCPConnector,
_DNSCacheTable,
Expand Down Expand Up @@ -3767,27 +3767,48 @@ def test_connect() -> Literal[True]:
assert raw_response_list == [True, True]


async def test_tcp_connector_setsockopts(
async def test_tcp_connector_socket_factory(
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
) -> None:
"""Check that sockopts get passed to socket"""
conn = aiohttp.TCPConnector(
tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)]
)

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
start_connection.return_value = s
create_connection.return_value = mock.Mock(), mock.Mock()
"""Check that socket factory is called"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
start_connection.return_value = s

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)
local_addr = None
socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s
happy_eyeballs_delay = 0.123
interleave = 3
conn = aiohttp.TCPConnector(
interleave=interleave,
local_addr=local_addr,
happy_eyeballs_delay=happy_eyeballs_delay,
socket_factory=socket_factory,
)

with mock.patch.object(
conn._loop,
"create_connection",
autospec=True,
spec_set=True,
return_value=(mock.Mock(), mock.Mock()),
):
host = "127.0.0.1"
port = 443
req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop)
with closing(await conn.connect(req, [], ClientTimeout())):
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2

await conn.close()
pass
await conn.close()

start_connection.assert_called_with(
addr_infos=[
(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))
],
local_addr_infos=local_addr,
happy_eyeballs_delay=happy_eyeballs_delay,
interleave=interleave,
loop=loop,
socket_factory=socket_factory,
)


def test_default_ssl_context_creation_without_ssl() -> None:
Expand Down

0 comments on commit 56fc83d

Please sign in to comment.