Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace tcp_sockopts with socket_factory (#10520) #10534

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions CHANGES/10474.feature.rst

This file was deleted.

2 changes: 2 additions & 0 deletions CHANGES/10520.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added ``socket_factory`` to ``TCPConnector`` to allow specifying custom socket options
-- by :user:`TimMenninger`.
3 changes: 3 additions & 0 deletions aiohttp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
WSServerHandshakeError,
request,
)
from .connector import AddrInfoType, SocketFactoryType
from .cookiejar import CookieJar, DummyCookieJar
from .formdata import FormData
from .helpers import BasicAuth, ChainMapProxy, ETag
Expand Down Expand Up @@ -112,6 +113,7 @@
__all__: Tuple[str, ...] = (
"hdrs",
# client
"AddrInfoType",
"BaseConnector",
"ClientConnectionError",
"ClientConnectionResetError",
Expand Down Expand Up @@ -146,6 +148,7 @@
"ServerDisconnectedError",
"ServerFingerprintMismatch",
"ServerTimeoutError",
"SocketFactoryType",
"SocketTimeoutError",
"TCPConnector",
"TooManyRedirects",
Expand Down
29 changes: 18 additions & 11 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
DefaultDict,
Deque,
Dict,
Iterable,
Iterator,
List,
Literal,
Expand All @@ -34,6 +33,7 @@
)

import aiohappyeyeballs
from aiohappyeyeballs import AddrInfoType, SocketFactoryType

from . import hdrs, helpers
from .abc import AbstractResolver, ResolveResult
Expand Down Expand Up @@ -96,7 +96,14 @@
# which first appeared in Python 3.12.7 and 3.13.1


__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
__all__ = (
"BaseConnector",
"TCPConnector",
"UnixConnector",
"NamedPipeConnector",
"AddrInfoType",
"SocketFactoryType",
)


if TYPE_CHECKING:
Expand Down Expand Up @@ -826,8 +833,9 @@ class TCPConnector(BaseConnector):
the happy eyeballs algorithm, set to None.
interleave - “First Address Family Count” as defined in RFC 8305
loop - Optional event loop.
tcp_sockopts - List of tuples of sockopts applied to underlying
socket
socket_factory - A SocketFactoryType function that, if supplied,
will be used to create sockets given an
AddrInfoType.
"""

allowed_protocol_schema_set = HIGH_LEVEL_SCHEMA_SET | frozenset({"tcp"})
Expand All @@ -849,7 +857,7 @@ def __init__(
timeout_ceil_threshold: float = 5,
happy_eyeballs_delay: Optional[float] = 0.25,
interleave: Optional[int] = None,
tcp_sockopts: Iterable[Tuple[int, int, Union[int, Buffer]]] = [],
socket_factory: Optional[SocketFactoryType] = None,
):
super().__init__(
keepalive_timeout=keepalive_timeout,
Expand Down Expand Up @@ -880,7 +888,7 @@ def __init__(
self._happy_eyeballs_delay = happy_eyeballs_delay
self._interleave = interleave
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
self._tcp_sockopts = tcp_sockopts
self._socket_factory = socket_factory

def _close_immediately(self) -> List[Awaitable[object]]:
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
Expand Down Expand Up @@ -1105,7 +1113,7 @@ def _get_fingerprint(self, req: ClientRequest) -> Optional["Fingerprint"]:
async def _wrap_create_connection(
self,
*args: Any,
addr_infos: List[aiohappyeyeballs.AddrInfoType],
addr_infos: List[AddrInfoType],
req: ClientRequest,
timeout: "ClientTimeout",
client_error: Type[Exception] = ClientConnectorError,
Expand All @@ -1122,9 +1130,8 @@ async def _wrap_create_connection(
happy_eyeballs_delay=self._happy_eyeballs_delay,
interleave=self._interleave,
loop=self._loop,
socket_factory=self._socket_factory,
)
for sockopt in self._tcp_sockopts:
sock.setsockopt(*sockopt)
connection = await self._loop.create_connection(
*args, **kwargs, sock=sock
)
Expand Down Expand Up @@ -1256,13 +1263,13 @@ async def _start_tls_connection(

def _convert_hosts_to_addr_infos(
self, hosts: List[ResolveResult]
) -> List[aiohappyeyeballs.AddrInfoType]:
) -> List[AddrInfoType]:
"""Converts the list of hosts to a list of addr_infos.

The list of hosts is the result of a DNS lookup. The list of
addr_infos is the result of a call to `socket.getaddrinfo()`.
"""
addr_infos: List[aiohappyeyeballs.AddrInfoType] = []
addr_infos: List[AddrInfoType] = []
for hinfo in hosts:
host = hinfo["host"]
is_ipv6 = ":" in host
Expand Down
23 changes: 15 additions & 8 deletions docs/client_advanced.rst
Original file line number Diff line number Diff line change
Expand Up @@ -468,19 +468,26 @@ If your HTTP server uses UNIX domain sockets you can use
session = aiohttp.ClientSession(connector=conn)


Setting socket options
Custom socket creation
^^^^^^^^^^^^^^^^^^^^^^

Socket options passed to the :class:`~aiohttp.TCPConnector` will be passed
to the underlying socket when creating a connection. For example, we may
want to change the conditions under which we consider a connection dead.
The following would change that to 9*7200 = 18 hours::
If the default socket is insufficient for your use case, pass an optional
`socket_factory` to the :class:`~aiohttp.TCPConnector`, which implements
`SocketFactoryType`. This will be used to create all sockets for the
lifetime of the class object. For example, we may want to change the
conditions under which we consider a connection dead. The following would
make all sockets respect 9*7200 = 18 hours::

import socket

conn = aiohttp.TCPConnector(tcp_sockopts=[(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True),
(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200),
(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9) ])
def socket_factory(addr_info):
family, type_, proto, _, _, _ = addr_info
sock = socket.socket(family=family, type=type_, proto=proto)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, True)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPIDLE, 7200)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 9)
return sock
conn = aiohttp.TCPConnector(socket_factory=socket_factory)


Named pipes in Windows
Expand Down
18 changes: 14 additions & 4 deletions docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1122,14 +1122,24 @@ is controlled by *force_close* constructor's parameter).
overridden in subclasses.


.. autodata:: AddrInfoType

Refer to :py:data:`aiohappyeyeballs.AddrInfoType`


.. autodata:: SocketFactoryType

Refer to :py:data:`aiohappyeyeballs.SocketFactoryType`


.. class:: TCPConnector(*, ssl=True, verify_ssl=True, fingerprint=None, \
use_dns_cache=True, ttl_dns_cache=10, \
family=0, ssl_context=None, local_addr=None, \
resolver=None, keepalive_timeout=sentinel, \
force_close=False, limit=100, limit_per_host=0, \
enable_cleanup_closed=False, timeout_ceil_threshold=5, \
happy_eyeballs_delay=0.25, interleave=None, loop=None, \
tcp_sockopts=[])
socket_factory=None)

Connector for working with *HTTP* and *HTTPS* via *TCP* sockets.

Expand Down Expand Up @@ -1250,9 +1260,9 @@ is controlled by *force_close* constructor's parameter).

.. versionadded:: 3.10

:param list tcp_sockopts: options applied to the socket when a connection is
created. This should be a list of 3-tuples, each a ``(level, optname, value)``.
Each tuple is deconstructed and passed verbatim to ``<socket>.setsockopt``.
:param :py:data:``SocketFactoryType`` socket_factory: This function takes an
:py:data:``AddrInfoType`` and is used in lieu of ``socket.socket()`` when
creating TCP connections.

.. versionadded:: 3.12

Expand Down
3 changes: 3 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
# ones.
extensions = [
# stdlib-party extensions:
"sphinx.ext.autodoc",
"sphinx.ext.extlinks",
"sphinx.ext.graphviz",
"sphinx.ext.intersphinx",
Expand Down Expand Up @@ -82,6 +83,7 @@
"aiohttpsession": ("https://aiohttp-session.readthedocs.io/en/stable/", None),
"aiohttpdemos": ("https://aiohttp-demos.readthedocs.io/en/latest/", None),
"aiojobs": ("https://aiojobs.readthedocs.io/en/stable/", None),
"aiohappyeyeballs": ("https://aiohappyeyeballs.readthedocs.io/en/stable/", None),
}

# Add any paths that contain templates here, relative to this directory.
Expand Down Expand Up @@ -425,6 +427,7 @@
("py:class", "cgi.FieldStorage"), # undocumented
("py:meth", "aiohttp.web.UrlDispatcher.register_resource"), # undocumented
("py:func", "aiohttp_debugtoolbar.setup"), # undocumented
("py:class", "socket.SocketKind"), # undocumented
]

# -- Options for towncrier_draft extension -----------------------------------
Expand Down
2 changes: 1 addition & 1 deletion requirements/runtime-deps.in
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Extracted from `setup.cfg` via `make sync-direct-runtime-deps`

aiodns >= 3.2.0; sys_platform=="linux" or sys_platform=="darwin"
aiohappyeyeballs >= 2.3.0
aiohappyeyeballs >= 2.5.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
Brotli; platform_python_implementation == 'CPython'
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ zip_safe = False
include_package_data = True

install_requires =
aiohappyeyeballs >= 2.3.0
aiohappyeyeballs >= 2.5.0
aiosignal >= 1.1.2
async-timeout >= 4.0, < 6.0 ; python_version < "3.11"
frozenlist >= 1.1.1
Expand Down
55 changes: 38 additions & 17 deletions tests/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from unittest import mock

import pytest
from aiohappyeyeballs import AddrInfoType
from pytest_mock import MockerFixture
from yarl import URL

Expand All @@ -44,6 +43,7 @@
from aiohttp.connector import (
_SSL_CONTEXT_UNVERIFIED,
_SSL_CONTEXT_VERIFIED,
AddrInfoType,
Connection,
TCPConnector,
_DNSCacheTable,
Expand Down Expand Up @@ -3767,27 +3767,48 @@ def test_connect() -> Literal[True]:
assert raw_response_list == [True, True]


async def test_tcp_connector_setsockopts(
async def test_tcp_connector_socket_factory(
loop: asyncio.AbstractEventLoop, start_connection: mock.AsyncMock
) -> None:
"""Check that sockopts get passed to socket"""
conn = aiohttp.TCPConnector(
tcp_sockopts=[(socket.IPPROTO_TCP, socket.TCP_KEEPCNT, 2)]
)

with mock.patch.object(
conn._loop, "create_connection", autospec=True, spec_set=True
) as create_connection:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
start_connection.return_value = s
create_connection.return_value = mock.Mock(), mock.Mock()
"""Check that socket factory is called"""
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
start_connection.return_value = s

req = ClientRequest("GET", URL("https://127.0.0.1:443"), loop=loop)
local_addr = None
socket_factory: Callable[[AddrInfoType], socket.socket] = lambda _: s
happy_eyeballs_delay = 0.123
interleave = 3
conn = aiohttp.TCPConnector(
interleave=interleave,
local_addr=local_addr,
happy_eyeballs_delay=happy_eyeballs_delay,
socket_factory=socket_factory,
)

with mock.patch.object(
conn._loop,
"create_connection",
autospec=True,
spec_set=True,
return_value=(mock.Mock(), mock.Mock()),
):
host = "127.0.0.1"
port = 443
req = ClientRequest("GET", URL(f"https://{host}:{port}"), loop=loop)
with closing(await conn.connect(req, [], ClientTimeout())):
assert s.getsockopt(socket.IPPROTO_TCP, socket.TCP_KEEPCNT) == 2

await conn.close()
pass
await conn.close()

start_connection.assert_called_with(
addr_infos=[
(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP, "", (host, port))
],
local_addr_infos=local_addr,
happy_eyeballs_delay=happy_eyeballs_delay,
interleave=interleave,
loop=loop,
socket_factory=socket_factory,
)


def test_default_ssl_context_creation_without_ssl() -> None:
Expand Down
Loading