diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 376ef684..86b4a03d 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9", "3.10.0-beta.4"] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"] steps: - uses: "actions/checkout@v2" diff --git a/CHANGELOG.md b/CHANGELOG.md index e269872f..501a1b62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,22 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). +## 0.14.0 + +The 0.14 release is a complete reworking of `httpcore`, comprehensively addressing some underlying issues in the connection pooling, as well as substantially redesigning the API to be more user friendly. + +Some of the lower-level API design also makes the components more easily testable in isolation, and the package now has 100% test coverage. + +See [discussion #419](https://github.com/encode/httpcore/discussions/419) for a little more background. + +There's some other neat bits in there too, such as the "trace" extension, which gives a hook into inspecting the internal events that occur during the request/response cycle. This extension is needed for the HTTPX cli, in order to... + +* Log the point at which the connection is established, and the IP/port on which it is made. +* Determine if the outgoing request should log as HTTP/1.1 or HTTP/2, rather than having to assume it's HTTP/2 if the --http2 flag was passed. (Which may not actually be true.) +* Log SSL version info / certificate info. + +Note that `curio` support is not currently available in 0.14.0. If you're using `httpcore` with `curio` please get in touch, so we can assess if we ought to prioritize it as a feature or not. + ## 0.13.7 (September 13th, 2021) - Fix broken error messaging when URL scheme is missing, or a non HTTP(S) scheme is used. (Pull #403) diff --git a/README.md b/README.md index c4cee5e6..003ab4a6 100644 --- a/README.md +++ b/README.md @@ -17,61 +17,59 @@ defaults, or any of that Jazz. Some things HTTP Core does do: * Sending HTTP requests. -* Provides both sync and async interfaces. -* Supports HTTP/1.1 and HTTP/2. -* Async backend support for `asyncio`, `trio` and `curio`. -* Automatic connection pooling. +* Thread-safe / task-safe connection pooling. * HTTP(S) proxy support. +* Supports HTTP/1.1 and HTTP/2. +* Provides both sync and async interfaces. +* Async backend support for `asyncio` and `trio`. ## Installation For HTTP/1.1 only support, install with... ```shell -$ pip install httpcore +$ pip install git+https://github.com/encode/httpcore ``` For HTTP/1.1 and HTTP/2 support, install with... ```shell -$ pip install httpcore[http2] +$ pip install git+https://github.com/encode/httpcore[http2] ``` -## Quickstart +# Sending requests -Here's an example of making an HTTP GET request using `httpcore`... 
+Send an HTTP request:
+
 ```python
-with httpcore.SyncConnectionPool() as http:
-    status_code, headers, stream, extensions = http.handle_request(
-        method=b'GET',
-        url=(b'https', b'example.org', 443, b'/'),
-        headers=[(b'host', b'example.org'), (b'user-agent', b'httpcore')],
-        stream=httpcore.ByteStream(b''),
-        extensions={}
-    )
-    body = stream.read()
-    print(status_code, body)
+import httpcore
+
+response = httpcore.request("GET", "https://www.example.com/")
+
+print(response)
+# <Response [200]>
+print(response.status)
+# 200
+print(response.headers)
+# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...]
+print(response.content)
+# b'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>\n\n<meta ...'
 ```

-Or, using async...
+The top-level `httpcore.request()` function is provided for convenience. In practice whenever you're working with `httpcore` you'll want to use the connection pooling functionality that it provides.

 ```python
-async with httpcore.AsyncConnectionPool() as http:
-    status_code, headers, stream, extensions = await http.handle_async_request(
-        method=b'GET',
-        url=(b'https', b'example.org', 443, b'/'),
-        headers=[(b'host', b'example.org'), (b'user-agent', b'httpcore')],
-        stream=httpcore.ByteStream(b''),
-        extensions={}
-    )
-    body = await stream.aread()
-    print(status_code, body)
+import httpcore
+
+pool = httpcore.ConnectionPool()
+response = pool.request("GET", "https://www.example.com/")
 ```

+Once you're ready to get going, [head over to the documentation](https://www.encode.io/httpcore/).
+
 ## Motivation

-You probably don't want to be using HTTP Core directly. It might make sense if
+You *probably* don't want to be using HTTP Core directly. It might make sense if
 you're writing something like a proxy service in Python, and you just want
 something at the lowest possible level, but more typically you'll want to use
 a higher level client library, such as `httpx`.
diff --git a/docs/api.md b/docs/api.md
deleted file mode 100644
index c6a2291b..00000000
--- a/docs/api.md
+++ /dev/null
@@ -1,82 +0,0 @@
-# Developer Interface
-
-## Async API Overview
-
-### Base async interfaces
-
-These classes provide the base interface which transport classes need to implement.
-
-:::{eval-rst}
-.. autoclass:: httpcore.AsyncHTTPTransport
-    :members: handle_async_request, aclose
-
-.. autoclass:: httpcore.AsyncByteStream
-    :members: __aiter__, aclose
-:::
-
-### Async connection pool
-
-:::{eval-rst}
-.. autoclass:: httpcore.AsyncConnectionPool
-    :show-inheritance:
-:::
-
-### Async proxy
-
-:::{eval-rst}
-.. autoclass:: httpcore.AsyncHTTPProxy
-    :show-inheritance:
-:::
-
-### Async byte streams
-
-These classes are concrete implementations of [`AsyncByteStream`](httpcore.AsyncByteStream).
-
-:::{eval-rst}
-.. autoclass:: httpcore.ByteStream
-    :show-inheritance:
-
-.. autoclass:: httpcore.AsyncIteratorByteStream
-    :show-inheritance:
-:::
-
-## Sync API Overview
-
-### Base sync interfaces
-
-These classes provide the base interface which transport classes need to implement.
-
-:::{eval-rst}
-.. autoclass:: httpcore.SyncHTTPTransport
-    :members: request, close
-
-.. autoclass:: httpcore.SyncByteStream
-    :members: __iter__, close
-:::
-
-### Sync connection pool
-
-:::{eval-rst}
-.. autoclass:: httpcore.SyncConnectionPool
-    :show-inheritance:
-:::
-
-### Sync proxy
-
-:::{eval-rst}
-.. autoclass:: httpcore.SyncHTTPProxy
-    :show-inheritance:
-:::
-
-### Sync byte streams
-
-These classes are concrete implementations of [`SyncByteStream`](httpcore.SyncByteStream).
-
-:::{eval-rst}
-.. autoclass:: httpcore.ByteStream
-    :show-inheritance:
-    :noindex:
-
-.. autoclass:: httpcore.IteratorByteStream
-    :show-inheritance:
-:::
diff --git a/docs/async.md b/docs/async.md
new file mode 100644
index 00000000..1023e04a
--- /dev/null
+++ b/docs/async.md
@@ -0,0 +1,214 @@
+# Async Support
+
+`httpcore` offers a standard synchronous API by default, but also gives you the option of an async client if you need it.
+
+Async is a concurrency model that is far more efficient than multi-threading, and can provide significant performance benefits and enable the use of long-lived network connections such as WebSockets.
+
+If you're working with an async web framework then you'll also want to use an async client for sending outgoing HTTP requests.
+
+Launching concurrent async tasks is far more resource efficient than spawning multiple threads. The Python interpreter should be able to comfortably handle switching between over 1000 concurrent tasks, while a sensible number of threads in a thread pool might be around 10 or 20.
+
+## API differences
+
+When using async support, you need to make sure to use an async connection pool class:
+
+```python
+# The async variation of `httpcore.ConnectionPool`
+async with httpcore.AsyncConnectionPool() as http:
+    ...
+```
+
+Or if connecting via a proxy:
+
+```python
+# The async variation of `httpcore.HTTPProxy`
+async with httpcore.AsyncHTTPProxy() as proxy:
+    ...
+```
+
+### Sending requests
+
+Sending requests with the async version of `httpcore` requires the `await` keyword:
+
+```python
+import asyncio
+import httpcore
+
+async def main():
+    async with httpcore.AsyncConnectionPool() as http:
+        response = await http.request("GET", "https://www.example.com/")
+
+
+asyncio.run(main())
+```
+
+When including content in the request, the content must either be bytes or an *async iterable* yielding bytes.
+
+### Streaming responses
+
+Streaming responses also require a slightly different interface to the sync version:
+
+* `with <pool>.stream(...) as response` → `async with <pool>.stream(...) as response`.
+* `for chunk in response.iter_stream()` → `async for chunk in response.aiter_stream()`.
+* `response.read()` → `await response.aread()`.
+* `response.close()` → `await response.aclose()`.
+
+For example:
+
+```python
+import asyncio
+import httpcore
+
+
+async def main():
+    async with httpcore.AsyncConnectionPool() as http:
+        async with http.stream("GET", "https://www.example.com/") as response:
+            async for chunk in response.aiter_stream():
+                print(f"Downloaded: {chunk}")
+
+
+asyncio.run(main())
+```
+
+### Pool lifespans
+
+When using `httpcore` in an async environment it is strongly recommended that you instantiate and use connection pools using the context managed style:
+
+```python
+async with httpcore.AsyncConnectionPool() as http:
+    ...
+```
+
+To benefit from connection pooling it is recommended that you instantiate a single connection pool in this style, and pass it around throughout your application.
+
+If you do want to use a connection pool without this style then you'll need to ensure that you explicitly close the pool once it is no longer required:
+
+```python
+try:
+    http = httpcore.AsyncConnectionPool()
+    ...
+finally:
+    await http.aclose()
+```
+
+This is a little different to the threaded context, where it's okay to simply instantiate a globally available connection pool, and then allow Python's garbage collection to deal with closing any connections in the pool, once the `__del__` method is called.
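+
+For instance, in purely synchronous code this kind of module-level pattern is workable. A minimal sketch, assuming a small helper of our own (`fetch_status` is illustrative, not part of the `httpcore` API):
+
+```python
+import httpcore
+
+# A module-level pool. In sync code it is acceptable to rely on garbage
+# collection calling `__del__` to close the pool when it is no longer referenced.
+http = httpcore.ConnectionPool()
+
+
+def fetch_status(url: str) -> int:
+    response = http.request("GET", url)
+    return response.status
+```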
+
+The reason for this difference is that asynchronous code is not able to run within the context of the synchronous `__del__` method, so there is no way for connections to be automatically closed at the point of garbage collection. This can lead to unterminated TCP connections still remaining after the Python interpreter quits.
+
+## Supported environments
+
+`httpcore` supports either `asyncio` or `trio` as an async environment.
+
+It will auto-detect which of those two to use as the backend for socket operations and concurrency primitives.
+
+### AsyncIO
+
+AsyncIO is Python's [built-in library](https://docs.python.org/3/library/asyncio.html) for writing concurrent code with the async/await syntax.
+
+Let's take a look at sending several outgoing HTTP requests concurrently, using `asyncio`:
+
+```python
+import asyncio
+import httpcore
+import time
+
+
+async def download(http, year):
+    await http.request("GET", f"https://en.wikipedia.org/wiki/{year}")
+
+
+async def main():
+    async with httpcore.AsyncConnectionPool() as http:
+        started = time.time()
+        # Here we use `asyncio.gather()` in order to run several tasks concurrently...
+        tasks = [download(http, year) for year in range(2000, 2020)]
+        await asyncio.gather(*tasks)
+        complete = time.time()
+
+        for connection in http.connections:
+            print(connection)
+        print("Complete in %.3f seconds" % (complete - started))
+
+
+asyncio.run(main())
+```
+
+### Trio
+
+Trio is [an alternative async library](https://trio.readthedocs.io/en/stable/), designed around [the principles of structured concurrency](https://en.wikipedia.org/wiki/Structured_concurrency).
+
+```python
+import httpcore
+import trio
+import time
+
+
+async def download(http, year):
+    await http.request("GET", f"https://en.wikipedia.org/wiki/{year}")
+
+
+async def main():
+    async with httpcore.AsyncConnectionPool() as http:
+        started = time.time()
+        async with trio.open_nursery() as nursery:
+            for year in range(2000, 2020):
+                nursery.start_soon(download, http, year)
+        complete = time.time()
+
+        for connection in http.connections:
+            print(connection)
+        print("Complete in %.3f seconds" % (complete - started))
+
+
+trio.run(main)
+```
+
+### AnyIO
+
+AnyIO is an [asynchronous networking and concurrency library](https://anyio.readthedocs.io/) that works on top of either asyncio or trio. It blends in with native libraries of your chosen backend (defaults to asyncio).
+
+The `anyio` library is designed around [the principles of structured concurrency](https://en.wikipedia.org/wiki/Structured_concurrency), and brings many of the same correctness and usability benefits that Trio provides, while interoperating with existing `asyncio` libraries.
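+
+Note that `anyio.run()` also accepts a `backend` argument, so the example below can be switched over to run on the Trio backend without any other changes. A minimal sketch of that variation, assuming the same `main()` as defined in the example below:
+
+```python
+import anyio
+
+# Run on the Trio backend, rather than the default asyncio backend.
+anyio.run(main, backend="trio")
+```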
+
+```python
+import httpcore
+import anyio
+import time
+
+
+async def download(http, year):
+    await http.request("GET", f"https://en.wikipedia.org/wiki/{year}")
+
+
+async def main():
+    async with httpcore.AsyncConnectionPool() as http:
+        started = time.time()
+        async with anyio.create_task_group() as task_group:
+            for year in range(2000, 2020):
+                task_group.start_soon(download, http, year)
+        complete = time.time()
+
+        for connection in http.connections:
+            print(connection)
+        print("Complete in %.3f seconds" % (complete - started))
+
+
+anyio.run(main)
+```
+
+---
+
+# Reference
+
+## `httpcore.AsyncConnectionPool`
+
+::: httpcore.AsyncConnectionPool
+    handler: python
+    rendering:
+        show_source: False
+
+## `httpcore.AsyncHTTPProxy`
+
+::: httpcore.AsyncHTTPProxy
+    handler: python
+    rendering:
+        show_source: False
diff --git a/docs/conf.py b/docs/conf.py
deleted file mode 100644
index e5f646c6..00000000
--- a/docs/conf.py
+++ /dev/null
@@ -1,60 +0,0 @@
-# See: https://www.sphinx-doc.org/en/master/usage/configuration.html
-
-# -- Path setup --
-
-import os
-import sys
-
-# Allow sphinx-autodoc to access `httpcore` contents.
-sys.path.insert(0, os.path.abspath("."))
-
-# -- Project information --
-
-project = "HTTPCore"
-copyright = "2021, Encode"
-author = "Encode"
-
-# -- General configuration --
-
-extensions = [
-    "myst_parser",
-    "sphinx.ext.autodoc",
-    "sphinx.ext.viewcode",
-    "sphinx.ext.napoleon",
-]
-
-myst_enable_extensions = [
-    "colon_fence",
-]
-
-# Preserve :members: order.
-autodoc_member_order = "bysource"
-
-# Show type hints in descriptions, rather than signatures.
-autodoc_typehints = "description"
-
-# -- HTML configuration --
-
-html_theme = "furo"
-
-# -- App setup --
-
-
-def _viewcode_follow_imported(app, modname, attribute):
-    # We set `__module__ = "httpcore"` on all public attributes for prettier
-    # repr(), so viewcode needs a little help to find the original source modules.
-
-    if modname != "httpcore":
-        return None
-
-    import httpcore
-
-    try:
-        # Set in httpcore/__init__.py
-        return getattr(httpcore, attribute).__source_module__
-    except AttributeError:
-        return None
-
-
-def setup(app):
-    app.connect("viewcode-follow-imported", _viewcode_follow_imported)
diff --git a/docs/connection-pools.md b/docs/connection-pools.md
new file mode 100644
index 00000000..3171e921
--- /dev/null
+++ b/docs/connection-pools.md
@@ -0,0 +1,130 @@
+# Connection Pools
+
+While the top-level API provides convenience functions for working with `httpcore`,
+in practice you'll almost always want to take advantage of the connection pooling
+functionality that it provides.
+
+To do so, instantiate a pool instance, and use it to send requests:
+
+```python
+import httpcore
+
+http = httpcore.ConnectionPool()
+r = http.request("GET", "https://www.example.com/")
+
+print(r)
+# <Response [200]>
+```
+
+Connection pools support the same `.request()` and `.stream()` APIs [as described in the Quickstart](../quickstart).
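+
+For instance, streaming a response through the pool looks just like the top-level `httpcore.stream()` API. A minimal sketch:
+
+```python
+import httpcore
+
+http = httpcore.ConnectionPool()
+
+# Stream the response body, rather than reading it all into memory up front.
+with http.stream("GET", "https://www.example.com/") as response:
+    for chunk in response.iter_stream():
+        print(f"Downloaded: {chunk}")
+```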
+
+We can observe the benefits of connection pooling with a simple script like so:
+
+```python
+import httpcore
+import time
+
+
+http = httpcore.ConnectionPool()
+for counter in range(5):
+    started = time.time()
+    response = http.request("GET", "https://www.example.com/")
+    complete = time.time()
+    print(response, "in %.3f seconds" % (complete - started))
+```
+
+The output *should* demonstrate the initial request as being substantially slower than the subsequent requests:
+
+```
+<Response [200]> in 0.529 seconds
+<Response [200]> in 0.096 seconds
+<Response [200]> in 0.097 seconds
+<Response [200]> in 0.095 seconds
+<Response [200]> in 0.098 seconds
+```
+
+This is to be expected. Once we've established a connection to `"www.example.com"` we're able to reuse it for subsequent requests.
+
+## Configuration
+
+The connection pool instance is also the main point of configuration. Let's take a look at the various options that it provides:
+
+### SSL configuration
+
+* `ssl_context`: An SSL context to use for verifying connections.
+  If not specified, the default `httpcore.default_ssl_context()`
+  will be used.
+
+### Pooling configuration
+
+* `max_connections`: The maximum number of concurrent HTTP connections that the pool
+  should allow. Any attempt to send a request on a pool that would
+  exceed this amount will block until a connection is available.
+* `max_keepalive_connections`: The maximum number of idle HTTP connections that will
+  be maintained in the pool.
+* `keepalive_expiry`: The duration in seconds that an idle HTTP connection may be
+  maintained for before being expired from the pool.
+
+### HTTP version support
+
+* `http1`: A boolean indicating if HTTP/1.1 requests should be supported by the connection
+  pool. Defaults to `True`.
+* `http2`: A boolean indicating if HTTP/2 requests should be supported by the connection
+  pool. Defaults to `False`.
+
+### Other options
+
+* `retries`: The maximum number of retries when trying to establish a connection.
+* `local_address`: Local address to connect from. Can also be used to connect using
+  a particular address family. Using `local_address="0.0.0.0"` will
+  connect using an `AF_INET` address (IPv4), while using `local_address="::"`
+  will connect using an `AF_INET6` address (IPv6).
+* `uds`: Path to a Unix Domain Socket to use instead of TCP sockets.
+* `network_backend`: A backend instance to use for handling network I/O.
+
+## Pool lifespans
+
+Because connection pools hold onto network resources, careful developers may want to ensure that instances are properly closed once they are no longer required.
+
+Working with a single global instance isn't a bad idea for many use cases, since the connection pool will automatically be closed when the `__del__` method is called on it:
+
+```python
+# This is perfectly fine for most purposes.
+# The connection pool will automatically be closed when it is garbage collected,
+# or when the Python interpreter exits.
+http = httpcore.ConnectionPool()
+```
+
+However, to be more explicit around the resource usage, we can use the connection pool within a context manager:
+
+```python
+with httpcore.ConnectionPool() as http:
+    ...
+```
+
+Or else close the pool explicitly:
+
+```python
+http = httpcore.ConnectionPool()
+try:
+    ...
+finally:
+    http.close()
+```
+
+## Thread and task safety
+
+Connection pools are designed to be thread-safe. Similarly, when using `httpcore` in an async context connection pools are task-safe.
+
+This means that you can have a single connection pool instance shared by multiple threads.
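+
+For example, a single shared pool can serve a whole pool of worker threads. A minimal sketch:
+
+```python
+import concurrent.futures
+import httpcore
+
+# One pool, shared across all worker threads.
+http = httpcore.ConnectionPool()
+
+
+def fetch(path):
+    # Each worker thread reuses connections from the shared pool.
+    response = http.request("GET", f"https://www.example.com{path}")
+    return response.status
+
+
+with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
+    print(list(executor.map(fetch, ["/"] * 10)))
+```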
+ +--- + +# Reference + +## `httpcore.ConnectionPool` + +::: httpcore.ConnectionPool + handler: python + rendering: + show_source: False diff --git a/docs/connections.md b/docs/connections.md new file mode 100644 index 00000000..0ca21556 --- /dev/null +++ b/docs/connections.md @@ -0,0 +1,28 @@ +# Connections + +TODO + +--- + +# Reference + +## `httpcore.HTTPConnection` + +::: httpcore.HTTPConnection + handler: python + rendering: + show_source: False + +## `httpcore.HTTP11Connection` + +::: httpcore.HTTP11Connection + handler: python + rendering: + show_source: False + +## `httpcore.HTTP2Connection` + +::: httpcore.HTTP2Connection + handler: python + rendering: + show_source: False diff --git a/docs/contributing.md b/docs/contributing.md deleted file mode 100644 index eca86dd4..00000000 --- a/docs/contributing.md +++ /dev/null @@ -1,208 +0,0 @@ -# Contributing - -Thanks for considering contributing to HTTP Core! - -We welcome contributors to: - -- Try [HTTPX](https://www.python-httpx.org), as it is HTTP Core's main entry point, -and [report bugs/issues you find](https://github.com/encode/httpx/issues/new) -- Help triage [issues](https://github.com/encode/httpcore/issues) and investigate -root causes of bugs -- [Review Pull Requests of others](https://github.com/encode/httpcore/pulls) -- Review, clarify and write documentation -- Participate in discussions - -## Reporting Bugs or Other Issues - -HTTP Core is a fairly specialized library and its main purpose is to provide a -solid base for [HTTPX](https://www.python-httpx.org). HTTPX should be considered -the main entry point to HTTP Core and as such we encourage users to test and raise -issues in [HTTPX's issue tracker](https://github.com/encode/httpx/issues/new) -where maintainers and contributors can triage and move to HTTP Core if appropriate. - -If you are convinced that the cause of the issue is on HTTP Core you're more than -welcome to [open an issue](https://github.com/encode/httpcore/issues/new). - -Please attach as much detail as possible and, in case of a -bug report, provide information like: - -- OS platform or Docker image -- Python version -- Installed dependencies and versions (`python -m pip freeze`) -- Code snippet to reproduce the issue -- Error traceback and output - -It is quite helpful to increase the logging level of HTTP Core and include the -output of your program. To do so set the `HTTPCORE_LOG_LEVEL` or `HTTPX_LOG_LEVEL` -environment variables to `TRACE`, for example: - -```console -$ HTTPCORE_LOG_LEVEL=TRACE python test_script.py -TRACE [2020-06-06 09:55:10] httpcore._async.connection_pool - get_connection_from_pool=(b'https', b'localhost', 5000) -TRACE [2020-06-06 09:55:10] httpcore._async.connection_pool - created connection= -... -``` - -The output will be quite long but it will help dramatically in diagnosing the problem. - -For more examples please refer to the -[environment variables documentation in HTTPX](https://www.python-httpx.org/environment_variables/#httpx_log_level). - -## Development - -To start developing HTTP Core create a **fork** of the -[repository](https://github.com/encode/httpcore) on GitHub. - -Then clone your fork with the following command replacing `YOUR-USERNAME` with -your GitHub username: - -```shell -$ git clone https://github.com/YOUR-USERNAME/httpcore -``` - -You can now install the project and its dependencies using: - -```shell -$ cd httpcore -$ scripts/install -``` - -## Unasync - -HTTP Core provides synchronous and asynchronous interfaces. 
As you can imagine, -keeping two almost identical versions of code in sync can be quite time consuming. -To work around this problem HTTP Core uses a technique called _unasync_, where -the development is focused on the asynchronous version of the code and a script -generates the synchronous version from it. - -As such developers should: - -- Only make modifications in the asynchronous and shared portions of the code. -In practice this roughly means avoiding the `httpcore/_sync` directory. -- Write tests _only under `async_tests`_, synchronous tests are also generated -as part of the unasync process. -- Run `scripts/unasync` to generate the synchronous versions. Note the script -is ran as part of other scripts as well, so you don't usually need to run this -yourself. -- Run the entire test suite as described below. - -## Testing and Linting - -We use custom shell scripts to automate testing, linting, -and documentation building workflow. - -To run the tests, use: - -```shell -$ scripts/test -``` - -:::{warning} -The test suite spawns testing servers on ports **8000** and **8001**. -Make sure these are not in use, so the tests can run properly. -::: - -You can run a single test script like this: - -```shell -$ scripts/test -- tests/async_tests/test_interfaces.py -``` - -To run the code auto-formatting: - -```shell -$ scripts/lint -``` - -Lastly, to run code checks separately (they are also run as part of `scripts/test`), run: - -```shell -$ scripts/check -``` - -## Documenting - -Documentation pages are located under the `docs/` folder. - -To run the documentation site locally (useful for previewing changes), use: - -```shell -$ scripts/docs -``` - -## Resolving Build / CI Failures - -Once you've submitted your pull request, the test suite will automatically run, and the results will show up in GitHub. -If the test suite fails, you'll want to click through to the "Details" link, and try to identify why the test suite failed. - -
-[Image: Failing PR commit status]
- -Here are some common ways the test suite can fail: - -### Check Job Failed - -
-[Image: Failing GitHub action lint job]
- -This job failing means there is either a code formatting issue or type-annotation issue. -You can look at the job output to figure out why it's failed or within a shell run: - -```shell -$ scripts/check -``` - -It may be worth it to run `$ scripts/lint` to attempt auto-formatting the code -and if that job succeeds commit the changes. - -### Docs Job Failed - -This job failing means the documentation failed to build. This can happen for -a variety of reasons like invalid markdown or missing configuration within `mkdocs.yml`. - -### Python 3.X Job Failed - -
-[Image: Failing GitHub action test job]
-
-This job failing means the unit tests failed or not all code paths are covered by unit tests.
-
-If tests are failing you will see this message under the coverage report:
-
-`=== 1 failed, 435 passed, 1 skipped, 1 xfailed in 11.09s ===`
-
-If tests succeed but coverage is lower than our current threshold, you will see this message under the coverage report:
-
-`FAIL Required test coverage of 100% not reached. Total coverage: 99.00%`
-
-## Releasing
-
-*This section is targeted at HTTPX maintainers.*
-
-Before releasing a new version, create a pull request that includes:
-
-- **An update to the changelog**:
-    - We follow the format from [keepachangelog](https://keepachangelog.com/en/1.0.0/).
-    - [Compare](https://github.com/encode/httpcore/compare/) `master` with the tag of the latest release, and list all entries that are of interest to our users:
-        - Things that **must** go in the changelog: added, changed, deprecated or removed features, and bug fixes.
-        - Things that **should not** go in the changelog: changes to documentation, tests or tooling.
-        - Try sorting entries in descending order of impact / importance.
-        - Keep it concise and to-the-point. 🎯
-- **A version bump**: see `__version__.py`.
-
-For an example, see [#99](https://github.com/encode/httpcore/pull/99).
-
-Once the release PR is merged, create a
-[new release](https://github.com/encode/httpcore/releases/new) including:
-
-- Tag version like `0.9.3`.
-- Release title `Version 0.9.3`
-- Description copied from the changelog.
-
-Once created this release will be automatically uploaded to PyPI.
-
-If something goes wrong with the PyPI job the release can be published using the
-`scripts/publish` script.
diff --git a/docs/exceptions.md b/docs/exceptions.md
new file mode 100644
index 00000000..63ef3f28
--- /dev/null
+++ b/docs/exceptions.md
@@ -0,0 +1,18 @@
+# Exceptions
+
+The following exceptions may be raised when sending a request:
+
+* `httpcore.TimeoutException`
+    * `httpcore.PoolTimeout`
+    * `httpcore.ConnectTimeout`
+    * `httpcore.ReadTimeout`
+    * `httpcore.WriteTimeout`
+* `httpcore.NetworkError`
+    * `httpcore.ConnectError`
+    * `httpcore.ReadError`
+    * `httpcore.WriteError`
+* `httpcore.ProtocolError`
+    * `httpcore.RemoteProtocolError`
+    * `httpcore.LocalProtocolError`
+* `httpcore.ProxyError`
+* `httpcore.UnsupportedProtocol`
diff --git a/docs/extensions.md b/docs/extensions.md
new file mode 100644
index 00000000..51565fae
--- /dev/null
+++ b/docs/extensions.md
@@ -0,0 +1,229 @@
+# Extensions
+
+The request/response API used by `httpcore` is kept deliberately simple and explicit.
+
+The `Request` and `Response` models are pretty slim wrappers around this core API:
+
+```
+# Pseudo-code expressing the essentials of the request/response model.
+(
+    status_code: int,
+    headers: List[Tuple(bytes, bytes)],
+    stream: Iterable[bytes]
+) = handle_request(
+    method: bytes,
+    url: URL,
+    headers: List[Tuple(bytes, bytes)],
+    stream: Iterable[bytes]
+)
+```
+
+This is everything that's needed in order to represent an HTTP exchange.
+
+Well... almost.
+
+There is a maxim in Computer Science that *"All non-trivial abstractions, to some degree, are leaky"*. When an abstraction is leaky, it's important that it at least leaks only in well-defined places.
+
+In order to handle cases that don't otherwise fit inside this core abstraction, `httpcore` requests and responses have 'extensions'. These are a dictionary of optional additional information.
+
+Let's expand on our request/response abstraction...
+
+```
+# Pseudo-code expressing the essentials of the request/response model,
+# plus extensions allowing for additional API that does not fit into
+# this abstraction.
+(
+    status_code: int,
+    headers: List[Tuple(bytes, bytes)],
+    stream: Iterable[bytes],
+    extensions: dict
+) = handle_request(
+    method: bytes,
+    url: URL,
+    headers: List[Tuple(bytes, bytes)],
+    stream: Iterable[bytes],
+    extensions: dict
+)
+```
+
+Several extensions are supported both on the request:
+
+```python
+r = httpcore.request(
+    "GET",
+    "https://www.example.com",
+    extensions={"timeout": {"connect": 5.0}}
+)
+```
+
+And on the response:
+
+```python
+r = httpcore.request("GET", "https://www.example.com")
+
+print(r.extensions["http_version"])
+# When using HTTP/1.1 on the client side, the server HTTP response
+# could feasibly be one of b"HTTP/0.9", b"HTTP/1.0", or b"HTTP/1.1".
+```
+
+## Request Extensions
+
+### `"timeout"`
+
+A dictionary of `str: Optional[float]` timeout values.
+
+May include values for `'connect'`, `'read'`, `'write'`, or `'pool'`.
+
+For example:
+
+```python
+# Timeout if a connection takes more than 5 seconds to be established, or if
+# we are blocked waiting on the connection pool for more than 10 seconds.
+r = httpcore.request(
+    "GET",
+    "https://www.example.com",
+    extensions={"timeout": {"connect": 5.0, "pool": 10.0}}
+)
+```
+
+### `"trace"`
+
+The trace extension allows a callback handler to be installed to monitor the internal
+flow of events within `httpcore`. The simplest way to explain this is with an example:
+
+```python
+import httpcore
+
+def log(event_name, info):
+    print(event_name, info)
+
+r = httpcore.request("GET", "https://www.example.com/", extensions={"trace": log})
+# connection.connect_tcp.started {'host': 'www.example.com', 'port': 443, 'local_address': None, 'timeout': None}
+# connection.connect_tcp.complete {'return_value': <httpcore.backends.sync.SyncStream object at 0x...>}
+# connection.start_tls.started {'ssl_context': <ssl.SSLContext object at 0x...>, 'server_hostname': b'www.example.com', 'timeout': None}
+# connection.start_tls.complete {'return_value': <httpcore.backends.sync.SyncStream object at 0x...>}
+# http11.send_request_headers.started {'request': <Request [b'GET']>}
+# http11.send_request_headers.complete {'return_value': None}
+# http11.send_request_body.started {'request': <Request [b'GET']>}
+# http11.send_request_body.complete {'return_value': None}
+# http11.receive_response_headers.started {'request': <Request [b'GET']>}
+# http11.receive_response_headers.complete {'return_value': (b'HTTP/1.1', 200, b'OK', [(b'Age', b'553715'), (b'Cache-Control', b'max-age=604800'), (b'Content-Type', b'text/html; charset=UTF-8'), (b'Date', b'Thu, 21 Oct 2021 17:08:42 GMT'), (b'Etag', b'"3147526947+ident"'), (b'Expires', b'Thu, 28 Oct 2021 17:08:42 GMT'), (b'Last-Modified', b'Thu, 17 Oct 2019 07:18:26 GMT'), (b'Server', b'ECS (nyb/1DCD)'), (b'Vary', b'Accept-Encoding'), (b'X-Cache', b'HIT'), (b'Content-Length', b'1256')])}
+# http11.receive_response_body.started {'request': <Request [b'GET']>}
+# http11.receive_response_body.complete {'return_value': None}
+# http11.response_closed.started {}
+# http11.response_closed.complete {'return_value': None}
+```
+
+The `event_name` and `info` arguments here will be one of the following:
+
+* `{event_type}.{event_name}.started`, `{<event-specific information>}`
+* `{event_type}.{event_name}.complete`, `{"return_value": <...>}`
+* `{event_type}.{event_name}.failed`, `{"exception": <...>}`
+
+Note that when using the async variant of `httpcore` the handler function passed to `"trace"` must be an `async def ...` function.
+
+The following event types are currently exposed...
+
+**Establishing the connection**
+
+* `"connection.connect_tcp"`
+* `"connection.connect_unix_socket"`
+* `"connection.start_tls"`
+
+**HTTP/1.1 events**
+
+* `"http11.send_request_headers"`
+* `"http11.send_request_body"`
+* `"http11.receive_response_headers"`
+* `"http11.receive_response_body"`
+* `"http11.response_closed"`
+
+**HTTP/2 events**
+
+* `"http2.send_connection_init"`
+* `"http2.send_request_headers"`
+* `"http2.send_request_body"`
+* `"http2.receive_response_headers"`
+* `"http2.receive_response_body"`
+* `"http2.response_closed"`
+
+## Response Extensions
+
+### `"http_version"`
+
+The HTTP version, as bytes. Eg. `b"HTTP/1.1"`.
+
+When using HTTP/1.1 the response line includes an explicit version, and the value of this key could feasibly be one of `b"HTTP/0.9"`, `b"HTTP/1.0"`, or `b"HTTP/1.1"`.
+
+When using HTTP/2 there is no further response versioning included in the protocol, and the value of this key will always be `b"HTTP/2"`.
+
+### `"reason_phrase"`
+
+The reason-phrase of the HTTP response, as bytes. For example `b"OK"`. Some servers may include a custom reason phrase, although this is not recommended.
+
+HTTP/2 onwards does not include a reason phrase on the wire.
+
+When no key is included, a default based on the status code may be used.
+
+### `"network_stream"`
+
+The `"network_stream"` extension allows developers to handle HTTP `CONNECT` and `Upgrade` requests, by providing an API that steps outside the standard request/response model, and can directly read or write to the network.
+
+The interface provided by the network stream:
+
+* `read(max_bytes, timeout = None) -> bytes`
+* `write(buffer, timeout = None)`
+* `close()`
+* `start_tls(ssl_context, server_hostname = None, timeout = None) -> NetworkStream`
+* `get_extra_info(info) -> Any`
+
+This API can be used as the foundation for working with HTTP proxies, WebSocket upgrades, and other advanced use-cases.
+
+An example to demonstrate:
+
+```python
+# Formulate a CONNECT request...
+#
+# This will establish a connection to 127.0.0.1:8080, and then send the following...
+#
+# CONNECT http://www.example.com HTTP/1.1
+# Host: 127.0.0.1:8080
+url = httpcore.URL(b"http", b"127.0.0.1", 8080, b"http://www.example.com")
+with httpcore.stream("CONNECT", url) as response:
+    network_stream = response.extensions["network_stream"]
+
+    # Upgrade to an SSL stream...
+    network_stream = network_stream.start_tls(
+        ssl_context=httpcore.default_ssl_context(),
+        server_hostname=b"www.example.com",
+    )
+
+    # Manually send an HTTP request over the network stream, and read the response...
+    #
+    # For a more complete example see the httpcore `TunnelHTTPConnection` implementation.
+    network_stream.write(b"GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
+    data = network_stream.read()
+    print(data)
+```
+
+The network stream abstraction also allows access to various low-level information that may be exposed by the underlying socket:
+
+```python
+response = httpcore.request("GET", "https://www.example.com")
+network_stream = response.extensions["network_stream"]
+
+client_addr = network_stream.get_extra_info("client_addr")
+server_addr = network_stream.get_extra_info("server_addr")
+print("Client address", client_addr)
+print("Server address", server_addr)
+```
+
+The socket SSL information is also available through this interface, although you need to ensure that the underlying connection is still open, in order to access it...
+
+```python
+with httpcore.stream("GET", "https://www.example.com") as response:
+    network_stream = response.extensions["network_stream"]
+
+    ssl_object = network_stream.get_extra_info("ssl_object")
+    print("TLS version", ssl_object.version())
+```
diff --git a/docs/http2.md b/docs/http2.md
new file mode 100644
index 00000000..03a3d7a7
--- /dev/null
+++ b/docs/http2.md
@@ -0,0 +1,166 @@
+# HTTP/2
+
+HTTP/2 is a major new iteration of the HTTP protocol, that provides a more efficient transport, with potential performance benefits. HTTP/2 does not change the core semantics of the request or response, but alters the way that data is sent to and from the server.
+
+Rather than the text format that HTTP/1.1 uses, HTTP/2 is a binary format. The binary format provides full request and response multiplexing, and efficient compression of HTTP headers. The stream multiplexing means that where HTTP/1.1 requires one TCP stream for each concurrent request, HTTP/2 allows a single TCP stream to handle multiple concurrent requests.
+
+HTTP/2 also provides support for functionality such as response prioritization, and server push.
+
+For a comprehensive guide to HTTP/2 you may want to check out "[HTTP2 Explained](https://http2-explained.haxx.se)".
+
+## Enabling HTTP/2
+
+When using the `httpcore` client, HTTP/2 support is not enabled by default, because HTTP/1.1 is a mature, battle-hardened transport layer, and our HTTP/1.1 implementation may be considered the more robust option at this point in time. It is possible that a future version of `httpcore` may enable HTTP/2 support by default.
+
+If you're issuing highly concurrent requests you might want to consider trying out our HTTP/2 support. You can do so by first making sure to install the optional HTTP/2 dependencies...
+
+```shell
+$ pip install httpcore[http2]
+```
+
+And then instantiating a connection pool with HTTP/2 support enabled:
+
+```python
+import httpcore
+
+pool = httpcore.ConnectionPool(http2=True)
+```
+
+We can take a look at the difference in behaviour by issuing several outgoing requests in parallel.
+
+Start out by using a standard HTTP/1.1 connection pool:
+
+```python
+import httpcore
+import concurrent.futures
+import time
+
+
+def download(http, year):
+    http.request("GET", f"https://en.wikipedia.org/wiki/{year}")
+
+
+def main():
+    with httpcore.ConnectionPool() as http:
+        started = time.time()
+        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as threads:
+            for year in range(2000, 2020):
+                threads.submit(download, http, year)
+        complete = time.time()
+
+        for connection in http.connections:
+            print(connection)
+        print("Complete in %.3f seconds" % (complete - started))
+
+
+main()
+```
+
+If you run this with an HTTP/1.1 connection pool, you ought to see output similar to the following:
+
+```python
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 3]>,
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 3]>,
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 3]>,
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 3]>,
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 2]>,
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 2]>,
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 2]>,
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/1.1, IDLE, Request Count: 2]>
+Complete in 0.586 seconds
+```
+
+We can see that the connection pool required a number of connections in order to handle the parallel requests.
+
+If we now upgrade our connection pool to support HTTP/2:
+
+```python
+with httpcore.ConnectionPool(http2=True) as http:
+    ...
+```
+
+And run the same script again, we should end up with something like this:
+
+```python
+<HTTPConnection ['https://en.wikipedia.org:443', HTTP/2, IDLE, Request Count: 20]>
+Complete in 0.573 seconds
+```
+
+All of our requests have been handled over a single connection.
+
+Switching to HTTP/2 should not *necessarily* be considered an "upgrade". It is more complex, and requires more computational power, and so particularly in an interpreted language like Python it *could* be slower in some instances.
Moreover, utilising multiple connections may end up connecting to multiple hosts, and could sometimes appear faster to the client, at the cost of requiring more server resources. Enabling HTTP/2 is most likely to be beneficial if you are sending requests in high concurrency, and may often be more well suited to an async context, rather than multi-threading. + +## Inspecting the HTTP version + +Enabling HTTP/2 support on the client does not *necessarily* mean that your requests and responses will be transported over HTTP/2, since both the client *and* the server need to support HTTP/2. If you connect to a server that only supports HTTP/1.1 the client will use a standard HTTP/1.1 connection instead. + +You can determine which version of the HTTP protocol was used by examining the `"http_version"` response extension. + +```python +import httpcore + +pool = httpcore.ConnectionPool(http2=True) +response = pool.request("GET", "https://www.example.com/") + +# Should be one of b"HTTP/2", b"HTTP/1.1", b"HTTP/1.0", or b"HTTP/0.9". +print(response.extensions["http_version"]) +``` + +See [the extensions documentation](extensions.md) for more details. + +## HTTP/2 negotiation + +Robust servers need to support both HTTP/2 and HTTP/1.1 capable clients, and so need some way to "negotiate" with the client which protocol version will be used. + +### HTTP/2 over HTTPS + +Generally the method used is for the server to advertise if it has HTTP/2 support during the part of the SSL connection handshake. This is known as ALPN - "Application Layer Protocol Negotiation". + +Most browsers only provide HTTP/2 support over HTTPS connections, and this is also the default behaviour that `httpcore` provides. If you enable HTTP/2 support you should still expect to see HTTP/1.1 connections for any `http://` URLs. + +### HTTP/2 over HTTP + +Servers can optionally also support HTTP/2 over HTTP by supporting the `Upgrade: h2c` header. + +This mechanism is not supported by `httpcore`. It requires an additional round-trip between the client and server, and also requires any request body to be sent twice. + +### Prior Knowledge + +If you know in advance that the server you are communicating with will support HTTP/2, then you can enforce that the client uses HTTP/2, without requiring either ALPN support or an HTTP `Upgrade: h2c` header. + +This is managed by disabling HTTP/1.1 support on the connection pool: + +```python +pool = httpcore.ConnectionPool(http1=False, http2=True) +``` + +## Request & response headers + +Because HTTP/2 frames the requests and responses somewhat differently to HTTP/1.1, there is a difference in some of the headers that are used. + +In order for the `httpcore` library to support both HTTP/1.1 and HTTP/2 transparently, the HTTP/1.1 style is always used throughout the API. Any differences in header styles are only mapped onto HTTP/2 at the internal network layer. + +## Request headers + +The following pseudo-headers are used by HTTP/2 in the request: + +* `:method` - The request method. +* `:path` - Taken from the URL of the request. +* `:authority` - Equivalent to the `Host` header in HTTP/1.1. In `httpcore` this is represented using the request `Host` header, which is automatically populated from the request URL if no `Host` header is explicitly included. +* `:scheme` - Taken from the URL of the request. + +These pseudo-headers are included in `httpcore` as part of the `request.method` and `request.url` attributes, and through the `request.headers["Host"]` header. 
*They are not exposed directly by their pseudo-header names.*
+
+The one other difference to be aware of is the `Transfer-Encoding: chunked` header.
+
+In HTTP/2 this header is never used, since streaming data is framed using a different mechanism.
+
+In `httpcore` the `Transfer-Encoding: chunked` header is always used to represent the presence of a streaming body on the request, and is automatically populated if required. However the header is only sent if the underlying connection ends up being HTTP/1.1, and is omitted if the underlying connection ends up being HTTP/2.
+
+## Response headers
+
+The following pseudo-header is used by HTTP/2 in the response:
+
+* `:status` - The response status code.
+
+In `httpcore` this *is represented by the `response.status` attribute, rather than being exposed as a pseudo-header*.
diff --git a/docs/index.md b/docs/index.md
index 29f3c54c..414e1ed6 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -1,21 +1,61 @@
-:::{include} ../README.md
-:::
-
+# HTTPCore
-
+[![Test Suite](https://github.com/encode/httpcore/workflows/Test%20Suite/badge.svg)](https://github.com/encode/httpcore/actions)
+[![Package version](https://badge.fury.io/py/httpcore.svg)](https://pypi.org/project/httpcore/)
-:::{toctree}
-:hidden:
-:caption: Usage
+> *Do one thing, and do it well.*
-api
-:::
+The HTTP Core package provides a minimal low-level HTTP client, which does
+one thing only. Sending HTTP requests.
-:::{toctree}
-:hidden:
-:caption: Development
+It does not provide any high level model abstractions over the API,
+does not handle redirects, multipart uploads, building authentication headers,
+transparent HTTP caching, URL parsing, session cookie handling,
+content or charset decoding, handling JSON, environment based configuration
+defaults, or any of that Jazz.
-contributing
-Changelog
-License
-Source Code
-:::
+Some things HTTP Core does do:
+
+* Sending HTTP requests.
+* Thread-safe / task-safe connection pooling.
+* HTTP(S) proxy support.
+* Supports HTTP/1.1 and HTTP/2.
+* Provides both sync and async interfaces.
+* Async backend support for `asyncio` and `trio`.
+
+## Installation
+
+For HTTP/1.1 only support, install with...
+
+```shell
+$ pip install git+https://github.com/encode/httpcore
+```
+
+For HTTP/1.1 and HTTP/2 support, install with...
+
+```shell
+$ pip install git+https://github.com/encode/httpcore[http2]
+```
+
+## Example
+
+Let's check we're able to send HTTP requests:
+
+```python
+import httpcore
+
+response = httpcore.request("GET", "https://www.example.com/")
+
+print(response)
+# <Response [200]>
+print(response.status)
+# 200
+print(response.headers)
+# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...]
+print(response.content)
+# b'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>\n\n<meta ...'
+```
+
+Ready to get going?
+
+Head over to [the quickstart documentation](quickstart.md).
diff --git a/docs/network-backends.md b/docs/network-backends.md
new file mode 100644
index 00000000..b3783e78
--- /dev/null
+++ b/docs/network-backends.md
@@ -0,0 +1,3 @@
+# Network Backends
+
+TODO
diff --git a/docs/proxies.md b/docs/proxies.md
new file mode 100644
index 00000000..bb59cf1e
--- /dev/null
+++ b/docs/proxies.md
@@ -0,0 +1,54 @@
+# Proxies
+
+The `httpcore` package currently provides support for HTTP proxies, using either "HTTP Forwarding" or "HTTP Tunnelling". Forwarding is a proxy mechanism for sending requests to `http` URLs via an intermediate proxy. Tunnelling is a proxy mechanism for sending requests to `https` URLs via an intermediate proxy.
+
+Sending requests via a proxy is very similar to sending requests using a standard connection pool:
+
+```python
+import httpcore
+
+proxy = httpcore.HTTPProxy(proxy_url="http://127.0.0.1:8080/")
+r = proxy.request("GET", "https://www.example.com/")
+
+print(r)
+# <Response [200]>
+```
+
+You can test the `httpcore` proxy support, using the Python [`proxy.py`](https://pypi.org/project/proxy.py/) tool:
+
+```shell
+$ pip install proxy.py
+$ proxy --hostname 127.0.0.1 --port 8080
+```
+
+Requests will automatically use either forwarding or tunnelling, depending on if the scheme is `http` or `https`.
+
+## Authentication
+
+Proxy headers can be included in the initial configuration:
+
+```python
+import httpcore
+import base64
+
+auth = b"Basic " + base64.b64encode(b"<username>:<password>")
+proxy = httpcore.HTTPProxy(
+    proxy_url="http://127.0.0.1:8080/",
+    proxy_headers={"Proxy-Authorization": auth}
+)
+```
+
+## HTTP Versions
+
+Proxy support currently only allows for HTTP/1.1 connections to the proxy.
+
+---
+
+# Reference
+
+## `httpcore.HTTPProxy`
+
+::: httpcore.HTTPProxy
+    handler: python
+    rendering:
+        show_source: False
diff --git a/docs/quickstart.md b/docs/quickstart.md
new file mode 100644
index 00000000..e43b0596
--- /dev/null
+++ b/docs/quickstart.md
@@ -0,0 +1,163 @@
+# Quickstart
+
+For convenience, the `httpcore` package provides a couple of top-level functions that you can use for sending HTTP requests. You probably don't want to integrate against these functions if you're writing a library that uses `httpcore`, but you might find them useful for testing `httpcore` from the command-line, or if you're writing a simple script that doesn't require any of the connection pooling or advanced configuration that `httpcore` offers.
+
+## Sending a request
+
+We'll start off by sending a request...
+
+```python
+import httpcore
+
+response = httpcore.request("GET", "https://www.example.com/")
+
+print(response)
+# <Response [200]>
+print(response.status)
+# 200
+print(response.headers)
+# [(b'Accept-Ranges', b'bytes'), (b'Age', b'557328'), (b'Cache-Control', b'max-age=604800'), ...]
+print(response.content)
+# b'<!doctype html>\n<html>\n<head>\n<title>Example Domain</title>\n\n<meta ...'
+```
+
+## Request headers
+
+Request headers may be included either in a dictionary style, or as a list of two-tuples.
+
+```python
+import httpcore
+import json
+
+headers = {'User-Agent': 'httpcore'}
+r = httpcore.request('GET', 'https://httpbin.org/headers', headers=headers)
+
+print(json.loads(r.content))
+# {
+#     'headers': {
+#         'Host': 'httpbin.org',
+#         'User-Agent': 'httpcore',
+#         'X-Amzn-Trace-Id': 'Root=1-616ff5de-5ea1b7e12766f1cf3b8e3a33'
+#     }
+# }
+```
+
+The keys and values may either be provided as strings or as bytes. Where strings are provided they may only contain characters within the ASCII range `chr(0)` - `chr(127)`. To include characters outside this range you must deal with any character encoding explicitly, and pass bytes as the header keys/values.
+
+The `Host` header will always be automatically included in any outgoing request, as it is strictly required to be present by the HTTP protocol.
+
+*Note that the `X-Amzn-Trace-Id` header shown in the example above is not an outgoing request header, but has been added by a gateway server.*
+
+## Request body
+
+A request body can be included either as bytes...
+ +```python +import httpcore +import json + +r = httpcore.request('POST', 'https://httpbin.org/post', content=b'Hello, world') + +print(json.loads(r.content)) +# { +# 'args': {}, +# 'data': 'Hello, world', +# 'files': {}, +# 'form': {}, +# 'headers': { +# 'Host': 'httpbin.org', +# 'Content-Length': '12', +# 'X-Amzn-Trace-Id': 'Root=1-61700258-00e338a124ca55854bf8435f' +# }, +# 'json': None, +# 'origin': '68.41.35.196', +# 'url': 'https://httpbin.org/post' +# } +``` + +Or as an iterable that returns bytes... + +```python +import httpcore +import json + +with open("hello-world.txt", "rb") as input_file: + r = httpcore.request('POST', 'https://httpbin.org/post', content=input_file) + +print(json.loads(r.content)) +# { +# 'args': {}, +# 'data': 'Hello, world', +# 'files': {}, +# 'form': {}, +# 'headers': { +# 'Host': 'httpbin.org', +# 'Transfer-Encoding': 'chunked', +# 'X-Amzn-Trace-Id': 'Root=1-61700258-00e338a124ca55854bf8435f' +# }, +# 'json': None, +# 'origin': '68.41.35.196', +# 'url': 'https://httpbin.org/post' +# } +``` + +When a request body is included, either a `Content-Length` header or a `Transfer-Encoding: chunked` header will be automatically included. + +The `Content-Length` header is used when passing bytes, and indicates an HTTP request with a body of a pre-determined length. + +The `Transfer-Encoding: chunked` header is the mechanism that HTTP/1.1 uses for sending HTTP request bodies without a pre-determined length. + +## Streaming responses + +When using the `httpcore.request()` function, the response body will automatically be read to completion, and made available in the `response.content` attribute. + +Sometimes you may be dealing with large responses and not want to read the entire response into memory. The `httpcore.stream()` function provides a mechanism for sending a request and dealing with a streaming response: + +```python +import httpcore + +with httpcore.stream('GET', 'https://example.com') as response: + for chunk in response.iter_stream(): + print(f"Downloaded: {chunk}") +``` + +Here's a more complete example that demonstrates downloading a response: + +```python +import httpcore + +with httpcore.stream('GET', 'https://speed.hetzner.de/100MB.bin') as response: + with open("download.bin", "wb") as output_file: + for chunk in response.iter_stream(): + output_file.write(chunk) +``` + +The `httpcore.stream()` API also allows you to *conditionally* read the response... + +```python +import httpcore + +with httpcore.stream('GET', 'https://example.com') as response: + content_length = [int(v) for k, v in response.headers if k.lower() == b'content-length'][0] + if content_length > 100_000_000: + raise Exception("Response too large.") + response.read() # `response.content` is now available. +``` + +--- + +# Reference + +## `httpcore.request()` + +::: httpcore.request + handler: python + rendering: + show_source: False + +## `httpcore.stream()` + +::: httpcore.stream + handler: python + rendering: + show_source: False diff --git a/docs/requests-responses-urls.md b/docs/requests-responses-urls.md new file mode 100644 index 00000000..7919a54c --- /dev/null +++ b/docs/requests-responses-urls.md @@ -0,0 +1,62 @@ +# Requests, Responses, and URLs + +TODO + +## Requests + +Request instances in `httpcore` are deliberately simple, and only include the essential information required to represent an HTTP request. + +Properties on the request are plain byte-wise representations. 
+
+```python
+>>> request = httpcore.Request("GET", "https://www.example.com/")
+>>> request.method
+b"GET"
+>>> request.url
+httpcore.URL(scheme=b"https", host=b"www.example.com", port=None, target=b"/")
+>>> request.headers
+[(b'Host', b'www.example.com')]
+>>> request.stream
+<httpcore._models.ByteStream object at 0x...>
+```
+
+The interface is liberal in the types that it accepts, but specific in the properties that it uses to represent them. For example, headers may be specified as a dictionary of strings, but internally are represented as a list of `(bytes, bytes)` tuples.
+
+```python
+>>> headers = {"User-Agent": "custom"}
+>>> request = httpcore.Request("GET", "https://www.example.com/", headers=headers)
+>>> request.headers
+[(b'Host', b'www.example.com'), (b"User-Agent", b"custom")]
+```
+
+## Responses
+
+...
+
+## URLs
+
+...
+
+---
+
+# Reference
+
+## `httpcore.Request`
+
+::: httpcore.Request
+    handler: python
+    rendering:
+        show_source: False
+
+## `httpcore.Response`
+
+::: httpcore.Response
+    handler: python
+    rendering:
+        show_source: False
+
+## `httpcore.URL`
+
+::: httpcore.URL
+    handler: python
+    rendering:
+        show_source: False
diff --git a/docs/table-of-contents.md b/docs/table-of-contents.md
new file mode 100644
index 00000000..3cf1f725
--- /dev/null
+++ b/docs/table-of-contents.md
@@ -0,0 +1,50 @@
+# API Reference
+
+* Quickstart
+    * `httpcore.request()`
+    * `httpcore.stream()`
+* Requests, Responses, and URLs
+    * `httpcore.Request`
+    * `httpcore.Response`
+    * `httpcore.URL`
+* Connection Pools
+    * `httpcore.ConnectionPool`
+* Proxies
+    * `httpcore.HTTPProxy`
+* Connections
+    * `httpcore.HTTPConnection`
+    * `httpcore.HTTP11Connection`
+    * `httpcore.HTTP2Connection`
+* Async Support
+    * `httpcore.AsyncConnectionPool`
+    * `httpcore.AsyncHTTPProxy`
+    * `httpcore.AsyncHTTPConnection`
+    * `httpcore.AsyncHTTP11Connection`
+    * `httpcore.AsyncHTTP2Connection`
+* Network Backends
+    * Sync
+        * `httpcore.backends.sync.SyncBackend`
+        * `httpcore.backends.mock.MockBackend`
+    * Async
+        * `httpcore.backends.auto.AutoBackend`
+        * `httpcore.backends.asyncio.AsyncioBackend`
+        * `httpcore.backends.trio.TrioBackend`
+        * `httpcore.backends.mock.AsyncMockBackend`
+    * Base interfaces
+        * `httpcore.backends.base.NetworkBackend`
+        * `httpcore.backends.base.AsyncNetworkBackend`
+* Exceptions
+    * `httpcore.TimeoutException`
+        * `httpcore.PoolTimeout`
+        * `httpcore.ConnectTimeout`
+        * `httpcore.ReadTimeout`
+        * `httpcore.WriteTimeout`
+    * `httpcore.NetworkError`
+        * `httpcore.ConnectError`
+        * `httpcore.ReadError`
+        * `httpcore.WriteError`
+    * `httpcore.ProtocolError`
+        * `httpcore.RemoteProtocolError`
+        * `httpcore.LocalProtocolError`
+    * `httpcore.ProxyError`
+    * `httpcore.UnsupportedProtocol`
diff --git a/httpcore/__init__.py b/httpcore/__init__.py
index 3ddc6d61..cd48805e 100644
--- a/httpcore/__init__.py
+++ b/httpcore/__init__.py
@@ -1,10 +1,15 @@
-from ._async.base import AsyncByteStream, AsyncHTTPTransport
-from ._async.connection_pool import AsyncConnectionPool
-from ._async.http_proxy import AsyncHTTPProxy
-from ._bytestreams import AsyncIteratorByteStream, ByteStream, IteratorByteStream
+from ._api import request, stream
+from ._async import (
+    AsyncConnectionInterface,
+    AsyncConnectionPool,
+    AsyncHTTP2Connection,
+    AsyncHTTP11Connection,
+    AsyncHTTPConnection,
+    AsyncHTTPProxy,
+)
 from ._exceptions import (
-    CloseError,
     ConnectError,
+    ConnectionNotAvailable,
     ConnectTimeout,
     LocalProtocolError,
     NetworkError,
@@ -19,45 +24,64 @@
     WriteError,
     WriteTimeout,
 )
-from ._sync.base import SyncByteStream, SyncHTTPTransport
-from
._sync.connection_pool import SyncConnectionPool -from ._sync.http_proxy import SyncHTTPProxy +from ._models import URL, Origin, Request, Response +from ._ssl import default_ssl_context +from ._sync import ( + ConnectionInterface, + ConnectionPool, + HTTP2Connection, + HTTP11Connection, + HTTPConnection, + HTTPProxy, +) __all__ = [ - "AsyncByteStream", + # top-level requests + "request", + "stream", + # models + "Origin", + "URL", + "Request", + "Response", + # async + "AsyncHTTPConnection", "AsyncConnectionPool", "AsyncHTTPProxy", - "AsyncHTTPTransport", - "AsyncIteratorByteStream", - "ByteStream", - "CloseError", - "ConnectError", - "ConnectTimeout", - "IteratorByteStream", - "LocalProtocolError", - "NetworkError", - "PoolTimeout", - "ProtocolError", + "AsyncHTTP11Connection", + "AsyncHTTP2Connection", + "AsyncConnectionInterface", + # sync + "HTTPConnection", + "ConnectionPool", + "HTTPProxy", + "HTTP11Connection", + "HTTP2Connection", + "ConnectionInterface", + # util + "default_ssl_context", + # exceptions + "ConnectionNotAvailable", "ProxyError", - "ReadError", - "ReadTimeout", + "ProtocolError", + "LocalProtocolError", "RemoteProtocolError", - "SyncByteStream", - "SyncConnectionPool", - "SyncHTTPProxy", - "SyncHTTPTransport", - "TimeoutException", "UnsupportedProtocol", - "WriteError", + "TimeoutException", + "PoolTimeout", + "ConnectTimeout", + "ReadTimeout", "WriteTimeout", + "NetworkError", + "ConnectError", + "ReadError", + "WriteError", ] -__version__ = "0.13.7" -__locals = locals() +__version__ = "0.14.0" + -for _name in __all__: - if not _name.startswith("__"): - # Save original source module, used by Sphinx. - __locals[_name].__source_module__ = __locals[_name].__module__ - # Override module for prettier repr(). - setattr(__locals[_name], "__module__", "httpcore") # noqa +__locals = locals() +for __name in __all__: + if not __name.startswith("__"): + setattr(__locals[__name], "__module__", "httpcore") # noqa diff --git a/httpcore/_api.py b/httpcore/_api.py new file mode 100644 index 00000000..e2e5e5a3 --- /dev/null +++ b/httpcore/_api.py @@ -0,0 +1,92 @@ +from contextlib import contextmanager +from typing import Iterator, Union + +from ._models import URL, Response +from ._sync.connection_pool import ConnectionPool + + +def request( + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: Union[dict, list] = None, + content: Union[bytes, Iterator[bytes]] = None, + extensions: dict = None, +) -> Response: + """ + Sends an HTTP request, returning the response. + + ``` + response = httpcore.request("GET", "https://www.example.com/") + ``` + + Arguments: + method: The HTTP method for the request. Typically one of `"GET"`, + `"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`. + url: The URL of the HTTP request. Either as an instance of `httpcore.URL`, + or as str/bytes. + headers: The HTTP request headers. Either as a dictionary of str/bytes, + or as a list of two-tuples of str/bytes. + content: The content of the request body. Either as bytes, + or as a bytes iterator. + extensions: A dictionary of optional extra information included on the request. + Possible keys include `"timeout"`. + + Returns: + An instance of `httpcore.Response`. 
+    """
+    with ConnectionPool() as pool:
+        return pool.request(
+            method=method,
+            url=url,
+            headers=headers,
+            content=content,
+            extensions=extensions,
+        )
+
+
+@contextmanager
+def stream(
+    method: Union[bytes, str],
+    url: Union[URL, bytes, str],
+    *,
+    headers: Union[dict, list] = None,
+    content: Union[bytes, Iterator[bytes]] = None,
+    extensions: dict = None,
+) -> Iterator[Response]:
+    """
+    Sends an HTTP request, returning the response within a context manager.
+
+    ```
+    with httpcore.stream("GET", "https://www.example.com/") as response:
+        ...
+    ```
+
+    When using the `stream()` function, the body of the response will not be
+    automatically read. If you want to access the response body you should
+    either use `content = response.read()`, or `for chunk in response.iter_stream()`.
+
+    Arguments:
+        method: The HTTP method for the request. Typically one of `"GET"`,
+            `"OPTIONS"`, `"HEAD"`, `"POST"`, `"PUT"`, `"PATCH"`, or `"DELETE"`.
+        url: The URL of the HTTP request. Either as an instance of `httpcore.URL`,
+            or as str/bytes.
+        headers: The HTTP request headers. Either as a dictionary of str/bytes,
+            or as a list of two-tuples of str/bytes.
+        content: The content of the request body. Either as bytes,
+            or as a bytes iterator.
+        extensions: A dictionary of optional extra information included on the request.
+            Possible keys include `"timeout"`.
+
+    Returns:
+        An instance of `httpcore.Response`.
+    """
+    with ConnectionPool() as pool:
+        with pool.stream(
+            method=method,
+            url=url,
+            headers=headers,
+            content=content,
+            extensions=extensions,
+        ) as response:
+            yield response
diff --git a/httpcore/_async/__init__.py b/httpcore/_async/__init__.py
index e69de29b..f089fab8 100644
--- a/httpcore/_async/__init__.py
+++ b/httpcore/_async/__init__.py
@@ -0,0 +1,20 @@
+from .connection import AsyncHTTPConnection
+from .connection_pool import AsyncConnectionPool
+from .http11 import AsyncHTTP11Connection
+from .http_proxy import AsyncHTTPProxy
+from .interfaces import AsyncConnectionInterface
+
+try:
+    from .http2 import AsyncHTTP2Connection
+except ImportError:  # pragma: nocover
+    pass
+
+
+__all__ = [
+    "AsyncHTTPConnection",
+    "AsyncConnectionPool",
+    "AsyncHTTPProxy",
+    "AsyncHTTP11Connection",
+    "AsyncHTTP2Connection",
+    "AsyncConnectionInterface",
+]
diff --git a/httpcore/_async/base.py b/httpcore/_async/base.py
deleted file mode 100644
index 2b3961c2..00000000
--- a/httpcore/_async/base.py
+++ /dev/null
@@ -1,122 +0,0 @@
-import enum
-from types import TracebackType
-from typing import AsyncIterator, Tuple, Type
-
-from .._types import URL, Headers, T
-
-
-class NewConnectionRequired(Exception):
-    pass
-
-
-class ConnectionState(enum.IntEnum):
-    """
-    PENDING  READY
-        |    |   ^
-        v    V   |
-        ACTIVE   |
-        |    |   |
-        |    V   |
-        V  IDLE-+
-     FULL   |
-        |   |
-        V   V
-        CLOSED
-    """
-
-    PENDING = 0  # Connection not yet acquired.
-    READY = 1  # Re-acquired from pool, about to send a request.
-    ACTIVE = 2  # Active requests.
-    FULL = 3  # Active requests, no more stream IDs available.
-    IDLE = 4  # No active requests.
-    CLOSED = 5  # Connection closed.
-
-
-class AsyncByteStream:
-    """
-    The base interface for request and response bodies.
-
-    Concrete implementations should subclass this class, and implement
-    the :meth:`__aiter__` method, and optionally the :meth:`aclose` method.
-    """
-
-    async def __aiter__(self) -> AsyncIterator[bytes]:
-        """
-        Yield bytes representing the request or response body.
- """ - yield b"" # pragma: nocover - - async def aclose(self) -> None: - """ - Must be called by the client to indicate that the stream has been closed. - """ - pass # pragma: nocover - - async def aread(self) -> bytes: - try: - return b"".join([part async for part in self]) - finally: - await self.aclose() - - -class AsyncHTTPTransport: - """ - The base interface for sending HTTP requests. - - Concrete implementations should subclass this class, and implement - the :meth:`handle_async_request` method, and optionally the :meth:`aclose` method. - """ - - async def handle_async_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - """ - The interface for sending a single HTTP request, and returning a response. - - Parameters - ---------- - method: - The HTTP method, such as ``b'GET'``. - url: - The URL as a 4-tuple of (scheme, host, port, path). - headers: - Any HTTP headers to send with the request. - stream: - The body of the HTTP request. - extensions: - A dictionary of optional extensions. - - Returns - ------- - status_code: - The HTTP status code, such as ``200``. - headers: - Any HTTP headers included on the response. - stream: - The body of the HTTP response. - extensions: - A dictionary of optional extensions. - """ - raise NotImplementedError() # pragma: nocover - - async def aclose(self) -> None: - """ - Close the implementation, which should close any outstanding response streams, - and any keep alive connections. - """ - - async def __aenter__(self: T) -> T: - return self - - async def __aexit__( - self, - exc_type: Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.aclose() diff --git a/httpcore/_async/connection.py b/httpcore/_async/connection.py index 2add4d85..c86d4ef4 100644 --- a/httpcore/_async/connection.py +++ b/httpcore/_async/connection.py @@ -1,158 +1,96 @@ -from ssl import SSLContext -from typing import List, Optional, Tuple, cast - -from .._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream, AutoBackend -from .._exceptions import ConnectError, ConnectTimeout -from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import AsyncByteStream, AsyncHTTPTransport, NewConnectionRequired -from .http import AsyncBaseHTTPConnection +import itertools +import ssl +from types import TracebackType +from typing import Iterator, Optional, Type + +from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from .._trace import Trace +from ..backends.auto import AutoBackend +from ..backends.base import AsyncNetworkBackend, AsyncNetworkStream from .http11 import AsyncHTTP11Connection - -logger = get_logger(__name__) +from .interfaces import AsyncConnectionInterface RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. 
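+
+# A quick sanity check (illustrative only) of the delays noted above, using
+# the `exponential_backoff` helper defined just below:
+#
+#     delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)
+#     [next(delays) for _ in range(5)]   # -> [0, 0.5, 1.0, 2.0, 4.0]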
-class AsyncHTTPConnection(AsyncHTTPTransport): +def exponential_backoff(factor: float) -> Iterator[float]: + yield 0 + for n in itertools.count(2): + yield factor * (2 ** (n - 2)) + + +class AsyncHTTPConnection(AsyncConnectionInterface): def __init__( self, origin: Origin, + ssl_context: ssl.SSLContext = None, + keepalive_expiry: float = None, http1: bool = True, http2: bool = False, - keepalive_expiry: float = None, - uds: str = None, - ssl_context: SSLContext = None, - socket: AsyncSocketStream = None, - local_address: str = None, retries: int = 0, - backend: AsyncBackend = None, - ): - self.origin = origin - self._http1_enabled = http1 - self._http2_enabled = http2 + local_address: str = None, + uds: str = None, + network_backend: AsyncNetworkBackend = None, + ) -> None: + ssl_context = default_ssl_context() if ssl_context is None else ssl_context + alpn_protocols = ["http/1.1", "h2"] if http2 else ["http/1.1"] + ssl_context.set_alpn_protocols(alpn_protocols) + + self._origin = origin + self._ssl_context = ssl_context self._keepalive_expiry = keepalive_expiry - self._uds = uds - self._ssl_context = SSLContext() if ssl_context is None else ssl_context - self.socket = socket - self._local_address = local_address + self._http1 = http1 + self._http2 = http2 self._retries = retries + self._local_address = local_address + self._uds = uds - alpn_protocols: List[str] = [] - if http1: - alpn_protocols.append("http/1.1") - if http2: - alpn_protocols.append("h2") - - self._ssl_context.set_alpn_protocols(alpn_protocols) - - self.connection: Optional[AsyncBaseHTTPConnection] = None - self._is_http11 = False - self._is_http2 = False - self._connect_failed = False - self._expires_at: Optional[float] = None - self._backend = AutoBackend() if backend is None else backend - - def __repr__(self) -> str: - return f"" - - def info(self) -> str: - if self.connection is None: - return "Connection failed" if self._connect_failed else "Connecting" - return self.connection.info() - - def should_close(self) -> bool: - """ - Return `True` if the connection is in a state where it should be closed. - This occurs when any of the following occur: - - * There are no active requests on an HTTP/1.1 connection, and the underlying - socket is readable. The only valid state the socket can be readable in - if this occurs is when the b"" EOF marker is about to be returned, - indicating a server disconnect. - * There are no active requests being made and the keepalive timeout has passed. - """ - if self.connection is None: - return False - return self.connection.should_close() - - def is_idle(self) -> bool: - """ - Return `True` if the connection is currently idle. - """ - if self.connection is None: - return False - return self.connection.is_idle() + self._network_backend: AsyncNetworkBackend = ( + AutoBackend() if network_backend is None else network_backend + ) + self._connection: Optional[AsyncConnectionInterface] = None + self._request_lock = AsyncLock() - def is_closed(self) -> bool: - if self.connection is None: - return self._connect_failed - return self.connection.is_closed() + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection to {self._origin}" + ) - def is_available(self) -> bool: - """ - Return `True` if the connection is currently able to accept an outgoing request. 
- This occurs when any of the following occur: - - * The connection has not yet been opened, and HTTP/2 support is enabled. - We don't *know* at this point if we'll end up on an HTTP/2 connection or - not, but we *might* do, so we indicate availability. - * The connection has been opened, and is currently idle. - * The connection is open, and is an HTTP/2 connection. The connection must - also not currently be exceeding the maximum number of allowable concurrent - streams and must not have exhausted the maximum total number of stream IDs. - """ - if self.connection is None: - return self._http2_enabled and not self.is_closed - return self.connection.is_available() - - @property - def request_lock(self) -> AsyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_request_lock"): - self._request_lock = self._backend.create_lock() - return self._request_lock - - async def handle_async_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - assert url_to_origin(url) == self.origin - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - async with self.request_lock: - if self.connection is None: - if self._connect_failed: - raise NewConnectionRequired() - if not self.socket: - logger.trace( - "open_socket origin=%r timeout=%r", self.origin, timeout + async with self._request_lock: + if self._connection is None: + stream = await self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import AsyncHTTP2Connection + + self._connection = AsyncHTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, ) - self.socket = await self._open_socket(timeout) - self._create_connection(self.socket) - elif not self.connection.is_available(): - raise NewConnectionRequired() - - assert self.connection is not None - logger.trace( - "connection.handle_async_request method=%r url=%r headers=%r", - method, - url, - headers, - ) - return await self.connection.handle_async_request( - method, url, headers, stream, extensions - ) + else: + self._connection = AsyncHTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + elif not self._connection.is_available(): + raise ConnectionNotAvailable() + + return await self._connection.handle_async_request(request) - async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream: - scheme, hostname, port = self.origin - timeout = {} if timeout is None else timeout - ssl_context = self._ssl_context if scheme == b"https" else None + async def _connect(self, request: Request) -> AsyncNetworkStream: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) retries_left = self._retries delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) @@ -160,61 +98,98 @@ async def _open_socket(self, timeout: TimeoutDict = None) -> AsyncSocketStream: while True: try: if self._uds is None: - return await self._backend.open_tcp_stream( - hostname, - port, - ssl_context, - timeout, - local_address=self._local_address, - ) + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + 
} + async with Trace( + "connection.connect_tcp", request, kwargs + ) as trace: + stream = await self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream else: - return await self._backend.open_uds_stream( - self._uds, hostname, ssl_context, timeout - ) + kwargs = { + "path": self._uds, + "timeout": timeout, + } + async with Trace( + "connection.connect_unix_socket", request, kwargs + ) as trace: + stream = await self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream except (ConnectError, ConnectTimeout): if retries_left <= 0: - self._connect_failed = True raise retries_left -= 1 delay = next(delays) - await self._backend.sleep(delay) - except Exception: # noqa: PIE786 - self._connect_failed = True - raise - - def _create_connection(self, socket: AsyncSocketStream) -> None: - http_version = socket.get_http_version() - logger.trace( - "create_connection socket=%r http_version=%r", socket, http_version - ) - if http_version == "HTTP/2" or ( - self._http2_enabled and not self._http1_enabled - ): - from .http2 import AsyncHTTP2Connection - - self._is_http2 = True - self.connection = AsyncHTTP2Connection( - socket=socket, - keepalive_expiry=self._keepalive_expiry, - backend=self._backend, - ) - else: - self._is_http11 = True - self.connection = AsyncHTTP11Connection( - socket=socket, keepalive_expiry=self._keepalive_expiry - ) - - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> None: - if self.connection is not None: - logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) - self.socket = await self.connection.start_tls( - hostname, ssl_context, timeout - ) - logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) + # TRACE 'retry' + await self._network_backend.sleep(delay) + else: + break + + if self._origin.scheme == b"https": + kwargs = { + "ssl_context": self._ssl_context, + "server_hostname": self._origin.host.decode("ascii"), + "timeout": timeout, + } + async with Trace("connection.start_tls", request, kwargs) as trace: + stream = await stream.start_tls(**kwargs) + trace.return_value = stream + return stream + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin async def aclose(self) -> None: - async with self.request_lock: - if self.connection is not None: - await self.connection.aclose() + if self._connection is not None: + await self._connection.aclose() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return self._http2 and (self._origin.scheme == b"https" or not self._http1) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return False + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return False + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return False + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
+ + async def __aenter__(self) -> "AsyncHTTPConnection": + return self + + async def __aexit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.aclose() diff --git a/httpcore/_async/connection_pool.py b/httpcore/_async/connection_pool.py index 0902ac2f..b9d79c64 100644 --- a/httpcore/_async/connection_pool.py +++ b/httpcore/_async/connection_pool.py @@ -1,362 +1,335 @@ -import warnings -from ssl import SSLContext -from typing import ( - AsyncIterator, - Callable, - Dict, - List, - Optional, - Set, - Tuple, - Union, - cast, -) - -from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore -from .._backends.base import lookup_async_backend -from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol -from .._threadlock import ThreadLock -from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import get_logger, origin_to_url_string, url_to_origin -from .base import AsyncByteStream, AsyncHTTPTransport, NewConnectionRequired +import ssl +from types import TracebackType +from typing import AsyncIterable, AsyncIterator, List, Optional, Type + +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import AsyncEvent, AsyncLock +from ..backends.auto import AutoBackend +from ..backends.base import AsyncNetworkBackend from .connection import AsyncHTTPConnection +from .interfaces import AsyncConnectionInterface, AsyncRequestInterface -logger = get_logger(__name__) +class RequestStatus: + def __init__(self, request: Request): + self.request = request + self.connection: Optional[AsyncConnectionInterface] = None + self._connection_acquired = AsyncEvent() -class NullSemaphore(AsyncSemaphore): - def __init__(self) -> None: - pass - - async def acquire(self, timeout: float = None) -> None: - return - - async def release(self) -> None: - return - - -class ResponseByteStream(AsyncByteStream): - def __init__( - self, - stream: AsyncByteStream, - connection: AsyncHTTPConnection, - callback: Callable, - ) -> None: - """ - A wrapper around the response stream that we return from - `.handle_async_request()`. - - Ensures that when `stream.aclose()` is called, the connection pool - is notified via a callback. - """ - self.stream = stream + def set_connection(self, connection: AsyncConnectionInterface) -> None: + assert self.connection is None self.connection = connection - self.callback = callback + self._connection_acquired.set() - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self.stream: - yield chunk + def unset_connection(self) -> None: + assert self.connection is not None + self.connection = None + self._connection_acquired = AsyncEvent() - async def aclose(self) -> None: - try: - # Call the underlying stream close callback. - # This will be a call to `AsyncHTTP11Connection._response_closed()` - # or `AsyncHTTP2Stream._response_closed()`. - await self.stream.aclose() - finally: - # Call the connection pool close callback. - # This will be a call to `AsyncConnectionPool._response_closed()`. 
- await self.callback(self.connection) + async def wait_for_connection( + self, timeout: float = None + ) -> AsyncConnectionInterface: + await self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection -class AsyncConnectionPool(AsyncHTTPTransport): +class AsyncConnectionPool(AsyncRequestInterface): """ A connection pool for making HTTP requests. - - Parameters - ---------- - ssl_context: - An SSL context to use for verifying connections. - max_connections: - The maximum number of concurrent connections to allow. - max_keepalive_connections: - The maximum number of connections to allow before closing keep-alive - connections. - keepalive_expiry: - The maximum time to allow before closing a keep-alive connection. - http1: - Enable/Disable HTTP/1.1 support. Defaults to True. - http2: - Enable/Disable HTTP/2 support. Defaults to False. - uds: - Path to a Unix Domain Socket to use instead of TCP sockets. - local_address: - Local address to connect from. Can also be used to connect using a particular - address family. Using ``local_address="0.0.0.0"`` will connect using an - ``AF_INET`` address (IPv4), while using ``local_address="::"`` will connect - using an ``AF_INET6`` address (IPv6). - retries: - The maximum number of retries when trying to establish a connection. - backend: - A name indicating which concurrency backend to use. """ def __init__( self, - ssl_context: SSLContext = None, - max_connections: int = None, + ssl_context: ssl.SSLContext = None, + max_connections: int = 10, max_keepalive_connections: int = None, keepalive_expiry: float = None, http1: bool = True, http2: bool = False, - uds: str = None, - local_address: str = None, retries: int = 0, - max_keepalive: int = None, - backend: Union[AsyncBackend, str] = "auto", - ): - if max_keepalive is not None: - warnings.warn( - "'max_keepalive' is deprecated. Use 'max_keepalive_connections'.", - DeprecationWarning, - ) - max_keepalive_connections = max_keepalive + local_address: str = None, + uds: str = None, + network_backend: AsyncNetworkBackend = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish a + connection. + local_address: Local address to connect from. Can also be used to connect + using a particular address family. Using `local_address="0.0.0.0"` + will connect using an `AF_INET` address (IPv4), while using + `local_address="::"` will connect using an `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
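+
+        For example, a sketch using illustrative values:
+
+        ```
+        async with AsyncConnectionPool(max_connections=10, keepalive_expiry=5.0) as pool:
+            response = await pool.request("GET", "https://www.example.com/")
+        ```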
+ """ + if max_keepalive_connections is None: + max_keepalive_connections = max_connections + + if ssl_context is None: + ssl_context = default_ssl_context() - if isinstance(backend, str): - backend = lookup_async_backend(backend) + self._ssl_context = ssl_context - self._ssl_context = SSLContext() if ssl_context is None else ssl_context self._max_connections = max_connections - self._max_keepalive_connections = max_keepalive_connections + self._max_keepalive_connections = min( + max_keepalive_connections, max_connections + ) + self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 - self._uds = uds - self._local_address = local_address self._retries = retries - self._connections: Dict[Origin, Set[AsyncHTTPConnection]] = {} - self._thread_lock = ThreadLock() - self._backend = backend - self._next_keepalive_check = 0.0 - - if not (http1 or http2): - raise ValueError("Either http1 or http2 must be True.") - - if http2: - try: - import h2 # noqa: F401 - except ImportError: - raise ImportError( - "Attempted to use http2=True, but the 'h2' " - "package is not installed. Use 'pip install httpcore[http2]'." - ) - - @property - def _connection_semaphore(self) -> AsyncSemaphore: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_internal_semaphore"): - if self._max_connections is not None: - self._internal_semaphore = self._backend.create_semaphore( - self._max_connections, exc_class=PoolTimeout - ) - else: - self._internal_semaphore = NullSemaphore() - - return self._internal_semaphore + self._local_address = local_address + self._uds = uds - @property - def _connection_acquiry_lock(self) -> AsyncLock: - if not hasattr(self, "_internal_connection_acquiry_lock"): - self._internal_connection_acquiry_lock = self._backend.create_lock() - return self._internal_connection_acquiry_lock + self._pool: List[AsyncConnectionInterface] = [] + self._requests: List[RequestStatus] = [] + self._pool_lock = AsyncLock() + self._network_backend = ( + AutoBackend() if network_backend is None else network_backend + ) - def _create_connection( - self, - origin: Tuple[bytes, bytes, int], - ) -> AsyncHTTPConnection: + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: return AsyncHTTPConnection( origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, http1=self._http1, http2=self._http2, - keepalive_expiry=self._keepalive_expiry, - uds=self._uds, - ssl_context=self._ssl_context, - local_address=self._local_address, retries=self._retries, - backend=self._backend, + local_address=self._local_address, + uds=self._uds, + network_backend=self._network_backend, ) - async def handle_async_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - if not url[0]: + @property + def connections(self) -> List[AsyncConnectionInterface]: + """ + Return a list of the connections currently in the pool. + + For example: + + ```python + >>> pool.connections + [ + , + , + , + ] + ``` + """ + return list(self._pool) + + async def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool: + """ + Attempt to provide a connection that can handle the given origin. + """ + origin = status.request.url.origin + + # If there are queued requests in front of us, then don't acquire a + # connection. We handle requests strictly in order. 
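+        #
+        # For example, if requests [A, B, C] are queued and none of them holds
+        # a connection yet, only A may acquire one here; B and C continue to
+        # wait until A has been handed a connection.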
+        waiting = [s for s in self._requests if s.connection is None]
+        if waiting and waiting[0] is not status:
+            return False
+
+        # Reuse an existing connection if one is currently available.
+        for idx, connection in enumerate(self._pool):
+            if connection.can_handle_request(origin) and connection.is_available():
+                self._pool.pop(idx)
+                self._pool.insert(0, connection)
+                status.set_connection(connection)
+                return True
+
+        # If the pool is currently full, attempt to close one idle connection.
+        if len(self._pool) >= self._max_connections:
+            for idx, connection in reversed(list(enumerate(self._pool))):
+                if connection.is_idle():
+                    await connection.aclose()
+                    self._pool.pop(idx)
+                    break
+
+        # If the pool is still full, then we cannot acquire a connection.
+        if len(self._pool) >= self._max_connections:
+            return False
+
+        # Otherwise create a new connection.
+        connection = self.create_connection(origin)
+        self._pool.insert(0, connection)
+        status.set_connection(connection)
+        return True
+
+    async def _close_expired_connections(self) -> None:
+        """
+        Clean up the connection pool by closing off any connections that have expired.
+        """
+        # Close any connections that have expired their keep-alive time.
+        for idx, connection in reversed(list(enumerate(self._pool))):
+            if connection.has_expired():
+                await connection.aclose()
+                self._pool.pop(idx)
+
+        # If the pool size exceeds the maximum number of allowed keep-alive connections,
+        # then close off idle connections as required.
+        pool_size = len(self._pool)
+        for idx, connection in reversed(list(enumerate(self._pool))):
+            if connection.is_idle() and pool_size > self._max_keepalive_connections:
+                await connection.aclose()
+                self._pool.pop(idx)
+                pool_size -= 1
+
+    async def handle_async_request(self, request: Request) -> Response:
+        """
+        Send an HTTP request, and return an HTTP response.
+
+        This is the core implementation that is called into by `.request()` or `.stream()`.
+        """
+        scheme = request.url.scheme.decode()
+        if scheme == "":
             raise UnsupportedProtocol(
-                "Request URL missing either an 'http://' or 'https://' protocol."
+                "Request URL is missing an 'http://' or 'https://' protocol."
             )
-
-        if url[0] not in (b"http", b"https"):
-            protocol = url[0].decode("ascii")
+        if scheme not in ("http", "https"):
             raise UnsupportedProtocol(
-                f"Request URL has an unsupported protocol '{protocol}://'."
+                f"Request URL has an unsupported protocol '{scheme}://'."
             )
 
-        if not url[1]:
-            raise LocalProtocolError("Missing hostname in URL.")
+        status = RequestStatus(request)
 
-        origin = url_to_origin(url)
-        timeout = cast(TimeoutDict, extensions.get("timeout", {}))
+        async with self._pool_lock:
+            self._requests.append(status)
+            await self._close_expired_connections()
+            await self._attempt_to_acquire_connection(status)
 
-        await self._keepalive_sweep()
+        while True:
+            timeouts = request.extensions.get("timeout", {})
+            timeout = timeouts.get("pool", None)
+            connection = await status.wait_for_connection(timeout=timeout)
+            try:
+                response = await connection.handle_async_request(request)
+            except ConnectionNotAvailable:
+                # The ConnectionNotAvailable exception is a special case that
+                # indicates we need to retry the request on a new connection.
+                #
+                # The most common case where this can occur is when multiple
+                # requests are queued waiting for a single connection, which
+                # might end up as an HTTP/2 connection, but which actually ends
+                # up as HTTP/1.1.
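+                #
+                # For example: several requests are queued against a connection
+                # that was expected to negotiate HTTP/2. If ALPN settles on
+                # HTTP/1.1 instead, the connection can only service a single
+                # request at a time, and the remaining requests are re-queued
+                # here to wait for another connection.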
+                async with self._pool_lock:
+                    # Maintain our position in the request queue, but reset the
+                    # status so that the request becomes queued again.
+                    status.unset_connection()
+                    await self._attempt_to_acquire_connection(status)
+            except Exception as exc:
+                await self.response_closed(status)
+                raise exc
+            else:
+                break
+
+        # When we return the response, we wrap the stream in a special class
+        # that handles notifying the connection pool once the response
+        # has been released.
+        assert isinstance(response.stream, AsyncIterable)
+        return Response(
+            status=response.status,
+            headers=response.headers,
+            content=ConnectionPoolByteStream(response.stream, self, status),
+            extensions=response.extensions,
+        )
 
-        connection: Optional[AsyncHTTPConnection] = None
-        while connection is None:
-            async with self._connection_acquiry_lock:
-                # We get-or-create a connection as an atomic operation, to ensure
-                # that HTTP/2 requests issued in close concurrency will end up
-                # on the same connection.
-                logger.trace("get_connection_from_pool=%r", origin)
-                connection = await self._get_connection_from_pool(origin)
+    async def response_closed(self, status: RequestStatus) -> None:
+        """
+        This method acts as a callback once the request/response cycle is complete.
 
-            if connection is None:
-                connection = self._create_connection(origin=origin)
-                logger.trace("created connection=%r", connection)
-                await self._add_to_pool(connection, timeout=timeout)
-            else:
-                logger.trace("reuse connection=%r", connection)
+        It is called into from the `ConnectionPoolByteStream.aclose()` method.
+        """
+        assert status.connection is not None
+        connection = status.connection
+
+        async with self._pool_lock:
+            # Update the state of the connection pool.
+            self._requests.remove(status)
+
+            if connection.is_closed():
+                self._pool.remove(connection)
+
+            # Since we've had a response closed, it's possible we'll now be able
+            # to service one or more requests that are currently pending.
+            for status in self._requests:
+                if status.connection is None:
+                    acquired = await self._attempt_to_acquire_connection(status)
+                    # If we could not acquire a connection for a queued request
+                    # then we don't need to check any more requests that are
+                    # queued later behind it.
+                    if not acquired:
+                        break
+
+            # Housekeeping.
+            await self._close_expired_connections()
 
-        try:
-            response = await connection.handle_async_request(
-                method, url, headers=headers, stream=stream, extensions=extensions
-            )
-        except NewConnectionRequired:
-            connection = None
-        except BaseException:  # noqa: PIE786
-            # See https://github.com/encode/httpcore/pull/305 for motivation
-            # behind catching 'BaseException' rather than 'Exception' here.
-            logger.trace("remove from pool connection=%r", connection)
-            await self._remove_from_pool(connection)
-            raise
-
-        status_code, headers, stream, extensions = response
-        wrapped_stream = ResponseByteStream(
-            stream, connection=connection, callback=self._response_closed
-        )
-        return status_code, headers, wrapped_stream, extensions
-
-    async def _get_connection_from_pool(
-        self, origin: Origin
-    ) -> Optional[AsyncHTTPConnection]:
-        # Determine expired keep alive connections on this origin.
-        reuse_connection = None
-        connections_to_close = set()
-
-        for connection in self._connections_for_origin(origin):
-            if connection.should_close():
-                connections_to_close.add(connection)
-                await self._remove_from_pool(connection)
-            elif connection.is_available():
-                reuse_connection = connection
-
-        # Close any dropped connections.
- for connection in connections_to_close: - await connection.aclose() - - return reuse_connection - - async def _response_closed(self, connection: AsyncHTTPConnection) -> None: - remove_from_pool = False - close_connection = False - - if connection.is_closed(): - remove_from_pool = True - elif connection.is_idle(): - num_connections = len(self._get_all_connections()) - if ( - self._max_keepalive_connections is not None - and num_connections > self._max_keepalive_connections - ): - remove_from_pool = True - close_connection = True - - if remove_from_pool: - await self._remove_from_pool(connection) - - if close_connection: - await connection.aclose() - - async def _keepalive_sweep(self) -> None: + async def aclose(self) -> None: """ - Remove any IDLE connections that have expired past their keep-alive time. + Close any connections in the pool. """ - if self._keepalive_expiry is None: - return + async with self._pool_lock: + for connection in self._pool: + await connection.aclose() + self._pool = [] + self._requests = [] - now = await self._backend.time() - if now < self._next_keepalive_check: - return + async def __aenter__(self) -> "AsyncConnectionPool": + return self - self._next_keepalive_check = now + min(1.0, self._keepalive_expiry) - connections_to_close = set() + async def __aexit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.aclose() - for connection in self._get_all_connections(): - if connection.should_close(): - connections_to_close.add(connection) - await self._remove_from_pool(connection) - for connection in connections_to_close: - await connection.aclose() +class ConnectionPoolByteStream: + """ + A wrapper around the response byte stream, that additionally handles + notifying the connection pool when the response has been closed. 
+ """ - async def _add_to_pool( - self, connection: AsyncHTTPConnection, timeout: TimeoutDict + def __init__( + self, + stream: AsyncIterable[bytes], + pool: AsyncConnectionPool, + status: RequestStatus, ) -> None: - logger.trace("adding connection to pool=%r", connection) - await self._connection_semaphore.acquire(timeout=timeout.get("pool", None)) - async with self._thread_lock: - self._connections.setdefault(connection.origin, set()) - self._connections[connection.origin].add(connection) - - async def _remove_from_pool(self, connection: AsyncHTTPConnection) -> None: - logger.trace("removing connection from pool=%r", connection) - async with self._thread_lock: - if connection in self._connections.get(connection.origin, set()): - await self._connection_semaphore.release() - self._connections[connection.origin].remove(connection) - if not self._connections[connection.origin]: - del self._connections[connection.origin] - - def _connections_for_origin(self, origin: Origin) -> Set[AsyncHTTPConnection]: - return set(self._connections.get(origin, set())) - - def _get_all_connections(self) -> Set[AsyncHTTPConnection]: - connections: Set[AsyncHTTPConnection] = set() - for connection_set in self._connections.values(): - connections |= connection_set - return connections + self._stream = stream + self._pool = pool + self._status = status - async def aclose(self) -> None: - connections = self._get_all_connections() - for connection in connections: - await self._remove_from_pool(connection) - - # Close all connections - for connection in connections: - await connection.aclose() - - async def get_connection_info(self) -> Dict[str, List[str]]: - """ - Returns a dict of origin URLs to a list of summary strings for each connection. - """ - await self._keepalive_sweep() + async def __aiter__(self) -> AsyncIterator[bytes]: + async for part in self._stream: + yield part - stats = {} - for origin, connections in self._connections.items(): - stats[origin_to_url_string(origin)] = sorted( - [connection.info() for connection in connections] - ) - return stats + async def aclose(self) -> None: + try: + if hasattr(self._stream, "aclose"): + await self._stream.aclose() # type: ignore + finally: + await self._pool.response_closed(self._status) diff --git a/httpcore/_async/http.py b/httpcore/_async/http.py deleted file mode 100644 index 06270f0f..00000000 --- a/httpcore/_async/http.py +++ /dev/null @@ -1,42 +0,0 @@ -from ssl import SSLContext - -from .._backends.auto import AsyncSocketStream -from .._types import TimeoutDict -from .base import AsyncHTTPTransport - - -class AsyncBaseHTTPConnection(AsyncHTTPTransport): - def info(self) -> str: - raise NotImplementedError() # pragma: nocover - - def should_close(self) -> bool: - """ - Return `True` if the connection is in a state where it should be closed. - """ - raise NotImplementedError() # pragma: nocover - - def is_idle(self) -> bool: - """ - Return `True` if the connection is currently idle. - """ - raise NotImplementedError() # pragma: nocover - - def is_closed(self) -> bool: - """ - Return `True` if the connection has been closed. - """ - raise NotImplementedError() # pragma: nocover - - def is_available(self) -> bool: - """ - Return `True` if the connection is currently able to accept an outgoing request. - """ - raise NotImplementedError() # pragma: nocover - - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> AsyncSocketStream: - """ - Upgrade the underlying socket to TLS. 
- """ - raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_async/http11.py b/httpcore/_async/http11.py index a265657c..fdf4c5e9 100644 --- a/httpcore/_async/http11.py +++ b/httpcore/_async/http11.py @@ -1,17 +1,21 @@ import enum import time -from ssl import SSLContext -from typing import AsyncIterator, List, Optional, Tuple, Union, cast +from types import TracebackType +from typing import AsyncIterable, AsyncIterator, List, Optional, Tuple, Type, Union import h11 -from .._backends.auto import AsyncSocketStream -from .._bytestreams import AsyncIteratorByteStream -from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions -from .._types import URL, Headers, TimeoutDict -from .._utils import get_logger -from .base import AsyncByteStream, NewConnectionRequired -from .http import AsyncBaseHTTPConnection +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, + map_exceptions, +) +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock +from .._trace import Trace +from ..backends.base import AsyncNetworkStream +from .interfaces import AsyncConnectionInterface H11Event = Union[ h11.Request, @@ -23,170 +27,120 @@ ] -class ConnectionState(enum.IntEnum): +class HTTPConnectionState(enum.IntEnum): NEW = 0 ACTIVE = 1 IDLE = 2 CLOSED = 3 -logger = get_logger(__name__) - - -class AsyncHTTP11Connection(AsyncBaseHTTPConnection): +class AsyncHTTP11Connection(AsyncConnectionInterface): READ_NUM_BYTES = 64 * 1024 - def __init__(self, socket: AsyncSocketStream, keepalive_expiry: float = None): - self.socket = socket - + def __init__( + self, origin: Origin, stream: AsyncNetworkStream, keepalive_expiry: float = None + ) -> None: + self._origin = origin + self._network_stream = stream self._keepalive_expiry: Optional[float] = keepalive_expiry - self._should_expire_at: Optional[float] = None + self._expire_at: Optional[float] = None + self._state = HTTPConnectionState.NEW + self._state_lock = AsyncLock() + self._request_count = 0 self._h11_state = h11.Connection(our_role=h11.CLIENT) - self._state = ConnectionState.NEW - - def __repr__(self) -> str: - return f"" - - def _now(self) -> float: - return time.monotonic() - - def _server_disconnected(self) -> bool: - """ - Return True if the connection is idle, and the underlying socket is readable. - The only valid state the socket can be readable here is when the b"" - EOF marker is about to be returned, indicating a server disconnect. - """ - return self._state == ConnectionState.IDLE and self.socket.is_readable() - - def _keepalive_expired(self) -> bool: - """ - Return True if the connection is idle, and has passed it's keepalive - expiry time. - """ - return ( - self._state == ConnectionState.IDLE - and self._should_expire_at is not None - and self._now() >= self._should_expire_at - ) - - def info(self) -> str: - return f"HTTP/1.1, {self._state.name}" - - def should_close(self) -> bool: - """ - Return `True` if the connection is in a state where it should be closed. - """ - return self._server_disconnected() or self._keepalive_expired() - - def is_idle(self) -> bool: - """ - Return `True` if the connection is currently idle. - """ - return self._state == ConnectionState.IDLE - - def is_closed(self) -> bool: - """ - Return `True` if the connection has been closed. - """ - return self._state == ConnectionState.CLOSED - def is_available(self) -> bool: - """ - Return `True` if the connection is currently able to accept an outgoing request. 
- """ - return self._state == ConnectionState.IDLE + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + async with self._state_lock: + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): + self._request_count += 1 + self._state = HTTPConnectionState.ACTIVE + self._expire_at = None + else: + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request} + async with Trace("http11.send_request_headers", request, kwargs) as trace: + await self._send_request_headers(**kwargs) + async with Trace("http11.send_request_body", request, kwargs) as trace: + await self._send_request_body(**kwargs) + async with Trace( + "http11.receive_response_headers", request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + ) = await self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": self._network_stream, + }, + ) + except BaseException as exc: + async with Trace("http11.response_closed", request) as trace: + await self._response_closed() + raise exc + + # Sending the request... + + async def _send_request_headers(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) - async def handle_async_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - """ - Send a single HTTP/1.1 request. - - Note that there is no kind of task/thread locking at this layer of interface. - Dealing with locking for concurrency is handled by the `AsyncHTTPConnection`. - """ - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - if self._state in (ConnectionState.NEW, ConnectionState.IDLE): - self._state = ConnectionState.ACTIVE - self._should_expire_at = None - else: - raise NewConnectionRequired() - - await self._send_request(method, url, headers, timeout) - await self._send_request_body(stream, timeout) - ( - http_version, - status_code, - reason_phrase, - headers, - ) = await self._receive_response(timeout) - response_stream = AsyncIteratorByteStream( - aiterator=self._receive_response_data(timeout), - aclose_func=self._response_closed, - ) - extensions = { - "http_version": http_version, - "reason_phrase": reason_phrase, - } - return (status_code, headers, response_stream, extensions) - - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> AsyncSocketStream: - timeout = {} if timeout is None else timeout - self.socket = await self.socket.start_tls(hostname, ssl_context, timeout) - return self.socket - - async def _send_request( - self, method: bytes, url: URL, headers: Headers, timeout: TimeoutDict - ) -> None: - """ - Send the request line and headers. 
- """ - logger.trace("send_request method=%r url=%r headers=%s", method, url, headers) - _scheme, _host, _port, target = url with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): - event = h11.Request(method=method, target=target, headers=headers) - await self._send_event(event, timeout) - - async def _send_request_body( - self, stream: AsyncByteStream, timeout: TimeoutDict - ) -> None: - """ - Send the request body. - """ - # Send the request body. - async for chunk in stream: - logger.trace("send_data=Data(<%d bytes>)", len(chunk)) + event = h11.Request( + method=request.method, + target=request.url.target, + headers=request.headers, + ) + await self._send_event(event, timeout=timeout) + + async def _send_request_body(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + assert isinstance(request.stream, AsyncIterable) + async for chunk in request.stream: event = h11.Data(data=chunk) - await self._send_event(event, timeout) + await self._send_event(event, timeout=timeout) - # Finalize sending the request. event = h11.EndOfMessage() - await self._send_event(event, timeout) + await self._send_event(event, timeout=timeout) - async def _send_event(self, event: H11Event, timeout: TimeoutDict) -> None: - """ - Send a single `h11` event to the network, waiting for the data to - drain before returning. - """ + async def _send_event(self, event: H11Event, timeout: float = None) -> None: bytes_to_send = self._h11_state.send(event) - await self.socket.write(bytes_to_send, timeout) + await self._network_stream.write(bytes_to_send, timeout=timeout) + + # Receiving the response... - async def _receive_response( - self, timeout: TimeoutDict + async def _receive_response_headers( + self, request: Request ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]: - """ - Read the response status and headers from the network. - """ + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + while True: - event = await self._receive_event(timeout) + event = await self._receive_event(timeout=timeout) if isinstance(event, h11.Response): break @@ -198,72 +152,127 @@ async def _receive_response( return http_version, event.status_code, event.reason, headers - async def _receive_response_data( - self, timeout: TimeoutDict - ) -> AsyncIterator[bytes]: - """ - Read the response data from the network. - """ + async def _receive_response_body(self, request: Request) -> AsyncIterator[bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + while True: - event = await self._receive_event(timeout) + event = await self._receive_event(timeout=timeout) if isinstance(event, h11.Data): - logger.trace("receive_event=Data(<%d bytes>)", len(event.data)) yield bytes(event.data) elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): - logger.trace("receive_event=%r", event) break - async def _receive_event(self, timeout: TimeoutDict) -> H11Event: - """ - Read a single `h11` event, reading more data from the network if needed. 
- """ + async def _receive_event(self, timeout: float = None) -> H11Event: while True: with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): event = self._h11_state.next_event() if event is h11.NEED_DATA: - data = await self.socket.read(self.READ_NUM_BYTES, timeout) - - # If we feed this case through h11 we'll raise an exception like: - # - # httpcore.RemoteProtocolError: can't handle event type - # ConnectionClosed when role=SERVER and state=SEND_RESPONSE - # - # Which is accurate, but not very informative from an end-user - # perspective. Instead we handle messaging for this case distinctly. - if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: - msg = "Server disconnected without sending a response." - raise RemoteProtocolError(msg) - + data = await self._network_stream.read( + self.READ_NUM_BYTES, timeout=timeout + ) self._h11_state.receive_data(data) else: - assert event is not h11.NEED_DATA - break - return event + return event async def _response_closed(self) -> None: - logger.trace( - "response_closed our_state=%r their_state=%r", - self._h11_state.our_state, - self._h11_state.their_state, - ) - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._h11_state.start_next_cycle() - self._state = ConnectionState.IDLE - if self._keepalive_expiry is not None: - self._should_expire_at = self._now() + self._keepalive_expiry - else: - await self.aclose() + async with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + await self.aclose() + + # Once the connection is no longer required... async def aclose(self) -> None: - if self._state != ConnectionState.CLOSED: - self._state = ConnectionState.CLOSED + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # The AsyncConnectionInterface methods provide information about the state of + # the connection, allowing for a connection pooling implementation to + # determine when to reuse and when to close the connection... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + # Note that HTTP/1.1 connections in the "NEW" state are not treated as + # being "available". The control flow which created the connection will + # be able to send an outgoing request, but the connection will not be + # acquired from the connection pool for any other request. + return self._state == HTTPConnectionState.IDLE + + def has_expired(self) -> bool: + now = time.monotonic() + keepalive_expired = self._expire_at is not None and now > self._expire_at + + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. 
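+        #
+        # (A sketch of the scenario: the server drops an idle keep-alive
+        # connection. With no request in flight, the socket still polls as
+        # readable, because the next read would return the b"" EOF marker.)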
+ server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. - if self._h11_state.our_state is h11.MUST_CLOSE: - event = h11.ConnectionClosed() - self._h11_state.send(event) + async def __aenter__(self) -> "AsyncHTTP11Connection": + return self - await self.socket.aclose() + async def __aexit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + await self.aclose() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: AsyncHTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + + async def __aiter__(self) -> AsyncIterator[bytes]: + kwargs = {"request": self._request} + async with Trace("http11.receive_response_body", self._request, kwargs): + async for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + + async def aclose(self) -> None: + async with Trace("http11.response_closed", self._request): + await self._connection._response_closed() diff --git a/httpcore/_async/http2.py b/httpcore/_async/http2.py index 35a4e091..23ced4f1 100644 --- a/httpcore/_async/http2.py +++ b/httpcore/_async/http2.py @@ -1,175 +1,119 @@ import enum import time -from ssl import SSLContext -from typing import AsyncIterator, Dict, List, Optional, Tuple, cast +import types +import typing +import h2.config import h2.connection import h2.events -from h2.config import H2Configuration -from h2.exceptions import NoAvailableStreamIDError -from h2.settings import SettingCodes, Settings - -from .._backends.auto import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream -from .._bytestreams import AsyncIteratorByteStream -from .._exceptions import LocalProtocolError, PoolTimeout, RemoteProtocolError -from .._types import URL, Headers, TimeoutDict -from .._utils import get_logger -from .base import AsyncByteStream, NewConnectionRequired -from .http import AsyncBaseHTTPConnection - -logger = get_logger(__name__) +import h2.exceptions +import h2.settings + +from .._exceptions import ConnectionNotAvailable, RemoteProtocolError +from .._models import Origin, Request, Response +from .._synchronization import AsyncLock, AsyncSemaphore +from .._trace import Trace +from ..backends.base import AsyncNetworkStream +from .interfaces import AsyncConnectionInterface + + +def has_body_headers(request: Request) -> bool: + return any( + [ + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + ] + ) -class ConnectionState(enum.IntEnum): - IDLE = 0 +class HTTPConnectionState(enum.IntEnum): ACTIVE = 1 - CLOSED = 2 + IDLE = 2 + CLOSED = 3 -class AsyncHTTP2Connection(AsyncBaseHTTPConnection): +class 
AsyncHTTP2Connection(AsyncConnectionInterface): READ_NUM_BYTES = 64 * 1024 - CONFIG = H2Configuration(validate_inbound_headers=False) + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) def __init__( - self, - socket: AsyncSocketStream, - backend: AsyncBackend, - keepalive_expiry: float = None, + self, origin: Origin, stream: AsyncNetworkStream, keepalive_expiry: float = None ): - self.socket = socket - - self._backend = backend + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: typing.Optional[float] = keepalive_expiry self._h2_state = h2.connection.H2Connection(config=self.CONFIG) - + self._state = HTTPConnectionState.IDLE + self._expire_at: typing.Optional[float] = None + self._request_count = 0 + self._init_lock = AsyncLock() + self._state_lock = AsyncLock() + self._read_lock = AsyncLock() + self._write_lock = AsyncLock() self._sent_connection_init = False - self._streams: Dict[int, AsyncHTTP2Stream] = {} - self._events: Dict[int, List[h2.events.Event]] = {} - - self._keepalive_expiry: Optional[float] = keepalive_expiry - self._should_expire_at: Optional[float] = None - self._state = ConnectionState.ACTIVE - self._exhausted_available_stream_ids = False - - def __repr__(self) -> str: - return f"" - - def info(self) -> str: - return f"HTTP/2, {self._state.name}, {len(self._streams)} streams" - - def _now(self) -> float: - return time.monotonic() - - def should_close(self) -> bool: - """ - Return `True` if the connection is currently idle, and the keepalive - timeout has passed. - """ - return ( - self._state == ConnectionState.IDLE - and self._should_expire_at is not None - and self._now() >= self._should_expire_at - ) - - def is_idle(self) -> bool: - """ - Return `True` if the connection is currently idle. - """ - return self._state == ConnectionState.IDLE - - def is_closed(self) -> bool: - """ - Return `True` if the connection has been closed. - """ - return self._state == ConnectionState.CLOSED - - def is_available(self) -> bool: - """ - Return `True` if the connection is currently able to accept an outgoing request. - This occurs when any of the following occur: - - * The connection has not yet been opened, and HTTP/2 support is enabled. - We don't *know* at this point if we'll end up on an HTTP/2 connection or - not, but we *might* do, so we indicate availability. - * The connection has been opened, and is currently idle. - * The connection is open, and is an HTTP/2 connection. The connection must - also not have exhausted the maximum total number of stream IDs. - """ - return ( - self._state != ConnectionState.CLOSED - and not self._exhausted_available_stream_ids - ) - - @property - def init_lock(self) -> AsyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_initialization_lock"): - self._initialization_lock = self._backend.create_lock() - return self._initialization_lock - - @property - def read_lock(self) -> AsyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_read_lock"): - self._read_lock = self._backend.create_lock() - return self._read_lock - - @property - def max_streams_semaphore(self) -> AsyncSemaphore: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. 
- if not hasattr(self, "_max_streams_semaphore"): - max_streams = self._h2_state.local_settings.max_concurrent_streams - self._max_streams_semaphore = self._backend.create_semaphore( - max_streams, exc_class=PoolTimeout + self._used_all_stream_ids = False + self._events: typing.Dict[int, h2.events.Event] = {} + + async def handle_async_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise ConnectionNotAvailable( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" ) - return self._max_streams_semaphore - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> AsyncSocketStream: - raise NotImplementedError("TLS upgrade not supported on HTTP/2 connections.") + async with self._state_lock: + if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + self._request_count += 1 + self._expire_at = None + self._state = HTTPConnectionState.ACTIVE + else: + raise ConnectionNotAvailable() - async def handle_async_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - async with self.init_lock: + async with self._init_lock: if not self._sent_connection_init: - # The very first stream is responsible for initiating the connection. - self._state = ConnectionState.ACTIVE - await self.send_connection_init(timeout) + kwargs = {"request": request} + async with Trace("http2.send_connection_init", request, kwargs): + await self._send_connection_init(**kwargs) self._sent_connection_init = True + max_streams = self._h2_state.local_settings.max_concurrent_streams + self._max_streams_semaphore = AsyncSemaphore(max_streams) - await self.max_streams_semaphore.acquire() - try: - try: - stream_id = self._h2_state.get_next_available_stream_id() - except NoAvailableStreamIDError: - self._exhausted_available_stream_ids = True - raise NewConnectionRequired() - else: - self._state = ConnectionState.ACTIVE - self._should_expire_at = None + await self._max_streams_semaphore.acquire() - h2_stream = AsyncHTTP2Stream(stream_id=stream_id, connection=self) - self._streams[stream_id] = h2_stream + try: + stream_id = self._h2_state.get_next_available_stream_id() self._events[stream_id] = [] - return await h2_stream.handle_async_request( - method, url, headers, stream, extensions + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + async with Trace("http2.send_request_headers", request, kwargs): + await self._send_request_headers(request=request, stream_id=stream_id) + async with Trace("http2.send_request_body", request, kwargs): + await self._send_request_body(request=request, stream_id=stream_id) + async with Trace( + "http2.receive_response_headers", request, kwargs + ) as trace: + status, headers = await self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={"stream_id": stream_id, "http_version": b"HTTP/2"}, ) except Exception: # noqa: PIE786 - await self.max_streams_semaphore.release() + kwargs = {"stream_id": stream_id} + async with 
Trace("http2.response_closed", request, kwargs): + await self._response_closed(stream_id=stream_id) raise - async def send_connection_init(self, timeout: TimeoutDict) -> None: + async def _send_connection_init(self, request: Request) -> None: """ The HTTP/2 connection requires some initial setup before we can start using individual request/response streams on it. @@ -177,15 +121,15 @@ async def send_connection_init(self, timeout: TimeoutDict) -> None: # Need to set these manually here instead of manipulating via # __setitem__() otherwise the H2Connection will emit SettingsUpdate # frames in addition to sending the undesired defaults. - self._h2_state.local_settings = Settings( + self._h2_state.local_settings = h2.settings.Settings( client=True, initial_values={ # Disable PUSH_PROMISE frames from the server since we don't do anything # with them for now. Maybe when we support caching? - SettingCodes.ENABLE_PUSH: 0, + h2.settings.SettingCodes.ENABLE_PUSH: 0, # These two are taken from h2 for safe defaults - SettingCodes.MAX_CONCURRENT_STREAMS: 100, - SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, }, ) @@ -196,227 +140,63 @@ async def send_connection_init(self, timeout: TimeoutDict) -> None: h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL ] - logger.trace("initiate_connection=%r", self) self._h2_state.initiate_connection() self._h2_state.increment_flow_control_window(2 ** 24) - data_to_send = self._h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - - def is_socket_readable(self) -> bool: - return self.socket.is_readable() - - async def aclose(self) -> None: - logger.trace("close_connection=%r", self) - if self._state != ConnectionState.CLOSED: - self._state = ConnectionState.CLOSED - - await self.socket.aclose() - - async def wait_for_outgoing_flow(self, stream_id: int, timeout: TimeoutDict) -> int: - """ - Returns the maximum allowable outgoing flow for a given stream. - If the allowable flow is zero, then waits on the network until - WindowUpdated frames have increased the flow rate. - https://tools.ietf.org/html/rfc7540#section-6.9 - """ - local_flow = self._h2_state.local_flow_control_window(stream_id) - connection_flow = self._h2_state.max_outbound_frame_size - flow = min(local_flow, connection_flow) - while flow == 0: - await self.receive_events(timeout) - local_flow = self._h2_state.local_flow_control_window(stream_id) - connection_flow = self._h2_state.max_outbound_frame_size - flow = min(local_flow, connection_flow) - return flow - - async def wait_for_event( - self, stream_id: int, timeout: TimeoutDict - ) -> h2.events.Event: - """ - Returns the next event for a given stream. - If no events are available yet, then waits on the network until - an event is available. - """ - async with self.read_lock: - while not self._events[stream_id]: - await self.receive_events(timeout) - return self._events[stream_id].pop(0) - - async def receive_events(self, timeout: TimeoutDict) -> None: - """ - Read some data from the network, and update the H2 state. 
- """ - data = await self.socket.read(self.READ_NUM_BYTES, timeout) - if data == b"": - raise RemoteProtocolError("Server disconnected") - - events = self._h2_state.receive_data(data) - for event in events: - event_stream_id = getattr(event, "stream_id", 0) - logger.trace("receive_event stream_id=%r event=%s", event_stream_id, event) - - if hasattr(event, "error_code"): - raise RemoteProtocolError(event) - - if event_stream_id in self._events: - self._events[event_stream_id].append(event) - - data_to_send = self._h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - - async def send_headers( - self, stream_id: int, headers: Headers, end_stream: bool, timeout: TimeoutDict - ) -> None: - logger.trace("send_headers stream_id=%r headers=%r", stream_id, headers) - self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) - self._h2_state.increment_flow_control_window(2 ** 24, stream_id=stream_id) - data_to_send = self._h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - - async def send_data( - self, stream_id: int, chunk: bytes, timeout: TimeoutDict - ) -> None: - logger.trace("send_data stream_id=%r chunk=%r", stream_id, chunk) - self._h2_state.send_data(stream_id, chunk) - data_to_send = self._h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - - async def end_stream(self, stream_id: int, timeout: TimeoutDict) -> None: - logger.trace("end_stream stream_id=%r", stream_id) - self._h2_state.end_stream(stream_id) - data_to_send = self._h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) - - async def acknowledge_received_data( - self, stream_id: int, amount: int, timeout: TimeoutDict - ) -> None: - self._h2_state.acknowledge_received_data(amount, stream_id) - data_to_send = self._h2_state.data_to_send() - await self.socket.write(data_to_send, timeout) + await self._write_outgoing_data(request) - async def close_stream(self, stream_id: int) -> None: - try: - logger.trace("close_stream stream_id=%r", stream_id) - del self._streams[stream_id] - del self._events[stream_id] - - if not self._streams: - if self._state == ConnectionState.ACTIVE: - if self._exhausted_available_stream_ids: - await self.aclose() - else: - self._state = ConnectionState.IDLE - if self._keepalive_expiry is not None: - self._should_expire_at = ( - self._now() + self._keepalive_expiry - ) - finally: - await self.max_streams_semaphore.release() - - -class AsyncHTTP2Stream: - def __init__(self, stream_id: int, connection: AsyncHTTP2Connection) -> None: - self.stream_id = stream_id - self.connection = connection - - async def handle_async_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - headers = [(k.lower(), v) for (k, v) in headers] - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - # Send the request. - seen_headers = set(key for key, value in headers) - has_body = ( - b"content-length" in seen_headers or b"transfer-encoding" in seen_headers - ) - - await self.send_headers(method, url, headers, has_body, timeout) - if has_body: - await self.send_body(stream, timeout) - - # Receive the response. - status_code, headers = await self.receive_response(timeout) - response_stream = AsyncIteratorByteStream( - aiterator=self.body_iter(timeout), aclose_func=self._response_closed - ) + # Sending the request... 
- extensions = { - "http_version": b"HTTP/2", - } - return (status_code, headers, response_stream, extensions) - - async def send_headers( - self, - method: bytes, - url: URL, - headers: Headers, - has_body: bool, - timeout: TimeoutDict, - ) -> None: - scheme, hostname, port, path = url + async def _send_request_headers(self, request: Request, stream_id: int) -> None: + end_stream = not has_body_headers(request) # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require # HTTP/1.1 style headers, and map them appropriately if we end up on # an HTTP/2 connection. - authority = None - - for k, v in headers: - if k == b"host": - authority = v - break - - if authority is None: - # Mirror the same error we'd see with `h11`, so that the behaviour - # is consistent. Although we're dealing with an `:authority` - # pseudo-header by this point, from an end-user perspective the issue - # is that the outgoing request needed to include a `host` header. - raise LocalProtocolError("Missing mandatory Host: header") + authority = [v for k, v in request.headers if k.lower() == b"host"][0] headers = [ - (b":method", method), + (b":method", request.method), (b":authority", authority), - (b":scheme", scheme), - (b":path", path), + (b":scheme", request.url.scheme), + (b":path", request.url.target), ] + [ - (k, v) - for k, v in headers - if k + (k.lower(), v) + for k, v in request.headers + if k.lower() not in ( b"host", b"transfer-encoding", ) ] - end_stream = not has_body - await self.connection.send_headers(self.stream_id, headers, end_stream, timeout) + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2 ** 24, stream_id=stream_id) + await self._write_outgoing_data(request) - async def send_body(self, stream: AsyncByteStream, timeout: TimeoutDict) -> None: - async for data in stream: + async def _send_request_body(self, request: Request, stream_id: int) -> None: + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.AsyncIterable) + async for data in request.stream: while data: - max_flow = await self.connection.wait_for_outgoing_flow( - self.stream_id, timeout - ) + max_flow = await self._wait_for_outgoing_flow(request, stream_id) chunk_size = min(len(data), max_flow) chunk, data = data[:chunk_size], data[chunk_size:] - await self.connection.send_data(self.stream_id, chunk, timeout) + self._h2_state.send_data(stream_id, chunk) + await self._write_outgoing_data(request) - await self.connection.end_stream(self.stream_id, timeout) + self._h2_state.end_stream(stream_id) + await self._write_outgoing_data(request) - async def receive_response( - self, timeout: TimeoutDict - ) -> Tuple[int, List[Tuple[bytes, bytes]]]: - """ - Read the response status and headers from the network. - """ + # Receiving the response... 
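To make the `Host` to `:authority` mapping in `_send_request_headers` above concrete, here is a minimal standalone sketch of the same transformation. `to_h2_headers` is purely illustrative, not part of the `httpcore` API.

```python
from typing import List, Tuple

Headers = List[Tuple[bytes, bytes]]

def to_h2_headers(method: bytes, scheme: bytes, target: bytes, headers: Headers) -> Headers:
    # 'Host' becomes the ':authority' pseudo-header, and 'Transfer-Encoding'
    # is dropped, since HTTP/2 provides its own framing.
    authority = [v for k, v in headers if k.lower() == b"host"][0]
    return [
        (b":method", method),
        (b":authority", authority),
        (b":scheme", scheme),
        (b":path", target),
    ] + [
        (k.lower(), v)
        for k, v in headers
        if k.lower() not in (b"host", b"transfer-encoding")
    ]

print(to_h2_headers(b"GET", b"https", b"/", [(b"Host", b"example.org")]))
# [(b':method', b'GET'), (b':authority', b'example.org'),
#  (b':scheme', b'https'), (b':path', b'/')]
```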
+ + async def _receive_response( + self, request: Request, stream_id: int + ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]: while True: - event = await self.connection.wait_for_event(self.stream_id, timeout) + event = await self._receive_stream_event(request, stream_id) if isinstance(event, h2.events.ResponseReceived): break @@ -430,17 +210,167 @@ async def receive_response( return (status_code, headers) - async def body_iter(self, timeout: TimeoutDict) -> AsyncIterator[bytes]: + async def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.AsyncIterator[bytes]: while True: - event = await self.connection.wait_for_event(self.stream_id, timeout) + event = await self._receive_stream_event(request, stream_id) if isinstance(event, h2.events.DataReceived): amount = event.flow_controlled_length - await self.connection.acknowledge_received_data( - self.stream_id, amount, timeout - ) + self._h2_state.acknowledge_received_data(amount, stream_id) + await self._write_outgoing_data(request) yield event.data elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)): break - async def _response_closed(self) -> None: - await self.connection.close_stream(self.stream_id) + async def _receive_stream_event( + self, request: Request, stream_id: int + ) -> h2.events.Event: + while not self._events.get(stream_id): + await self._receive_events(request) + return self._events[stream_id].pop(0) + + async def _receive_events(self, request: Request) -> None: + events = await self._read_incoming_data(request) + for event in events: + event_stream_id = getattr(event, "stream_id", 0) + + if hasattr(event, "error_code"): + raise RemoteProtocolError(event) + + if event_stream_id in self._events: + self._events[event_stream_id].append(event) + + await self._write_outgoing_data(request) + + async def _response_closed(self, stream_id: int) -> None: + await self._max_streams_semaphore.release() + del self._events[stream_id] + async with self._state_lock: + if self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + await self.aclose() + + async def aclose(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + # For task-safe/thread-safe operations call into 'attempt_close' instead. + self._state = HTTPConnectionState.CLOSED + await self._network_stream.aclose() + + # Wrappers around network read/write operations... + + async def _read_incoming_data( + self, request: Request + ) -> typing.List[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + async with self._read_lock: + data = await self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + return self._h2_state.receive_data(data) + + async def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + async with self._write_lock: + data_to_send = self._h2_state.data_to_send() + await self._network_stream.write(data_to_send, timeout) + + # Flow control... 
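The `_read_incoming_data` and `_write_outgoing_data` wrappers above also show where per-request timeouts come from: a `"timeout"` dict carried in the request extensions. A usage sketch, with illustrative values:

```python
import asyncio
import httpcore

async def main() -> None:
    async with httpcore.AsyncConnectionPool() as pool:
        # "connect", "read" and "write" map onto the corresponding
        # network operations; omitted keys mean no timeout is applied.
        response = await pool.request(
            "GET",
            "https://www.example.com/",
            extensions={"timeout": {"connect": 5.0, "read": 5.0, "write": 5.0}},
        )
        print(response.status)

asyncio.run(main())
```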
+ + async def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. + https://tools.ietf.org/html/rfc7540#section-6.9 + """ + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + while flow == 0: + await self._receive_events(request) + local_flow = self._h2_state.local_flow_control_window(stream_id) + max_frame_size = self._h2_state.max_outbound_frame_size + flow = min(local_flow, max_frame_size) + return flow + + # Interface for connection pooling... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + return ( + self._state != HTTPConnectionState.CLOSED and not self._used_all_stream_ids + ) + + def has_expired(self) -> bool: + now = time.monotonic() + return self._expire_at is not None and now > self._expire_at + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/2, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
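As the comment above notes, the context manager support exists for driving a connection instance directly. A sketch of what that looks like, where `stream` stands in for an already-established `AsyncNetworkStream` (obtained from a network backend, or a test double), and `request()` is inherited from the `AsyncRequestInterface` defined later in this diff:

```python
from httpcore import Origin

async def fetch_over(stream) -> None:
    # Assumes `stream` is a connected AsyncNetworkStream that has
    # negotiated "h2" via ALPN.
    origin = Origin(scheme=b"https", host=b"example.org", port=443)
    async with AsyncHTTP2Connection(origin=origin, stream=stream) as connection:
        response = await connection.request("GET", "https://example.org/")
        print(response.status)
```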
+ + async def __aenter__(self) -> "AsyncHTTP2Connection": + return self + + async def __aexit__( + self, + exc_type: typing.Type[BaseException] = None, + exc_value: BaseException = None, + traceback: types.TracebackType = None, + ) -> None: + await self.aclose() + + +class HTTP2ConnectionByteStream: + def __init__( + self, connection: AsyncHTTP2Connection, request: Request, stream_id: int + ) -> None: + self._connection = connection + self._request = request + self._stream_id = stream_id + + async def __aiter__(self) -> typing.AsyncIterator[bytes]: + kwargs = {"request": self._request, "stream_id": self._stream_id} + async with Trace("http2.receive_response_body", self._request, kwargs): + async for chunk in self._connection._receive_response_body( + request=self._request, stream_id=self._stream_id + ): + yield chunk + + async def aclose(self) -> None: + kwargs = {"stream_id": self._stream_id} + async with Trace("http2.response_closed", self._request, kwargs): + await self._connection._response_closed(stream_id=self._stream_id) diff --git a/httpcore/_async/http_proxy.py b/httpcore/_async/http_proxy.py index 275bf214..ad99a6f8 100644 --- a/httpcore/_async/http_proxy.py +++ b/httpcore/_async/http_proxy.py @@ -1,31 +1,27 @@ -from http import HTTPStatus -from ssl import SSLContext -from typing import Tuple, cast +import ssl +from typing import Dict, List, Tuple, Union -from .._bytestreams import ByteStream from .._exceptions import ProxyError -from .._types import URL, Headers, TimeoutDict -from .._utils import get_logger, url_to_origin -from .base import AsyncByteStream +from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url +from .._ssl import default_ssl_context +from .._synchronization import AsyncLock +from ..backends.base import AsyncNetworkBackend from .connection import AsyncHTTPConnection -from .connection_pool import AsyncConnectionPool, ResponseByteStream +from .connection_pool import AsyncConnectionPool +from .http11 import AsyncHTTP11Connection +from .interfaces import AsyncConnectionInterface -logger = get_logger(__name__) - - -def get_reason_phrase(status_code: int) -> str: - try: - return HTTPStatus(status_code).phrase - except ValueError: - return "" +HeadersAsList = List[Tuple[Union[bytes, str], Union[bytes, str]]] +HeadersAsDict = Dict[Union[bytes, str], Union[bytes, str]] def merge_headers( - default_headers: Headers = None, override_headers: Headers = None -) -> Headers: + default_headers: List[Tuple[bytes, bytes]] = None, + override_headers: List[Tuple[bytes, bytes]] = None, +) -> List[Tuple[bytes, bytes]]: """ - Append default_headers and override_headers, de-duplicating if a key existing in - both cases. + Append default_headers and override_headers, de-duplicating if a key exists + in both cases. """ default_headers = [] if default_headers is None else default_headers override_headers = [] if override_headers is None else override_headers @@ -40,251 +36,227 @@ def merge_headers( class AsyncHTTPProxy(AsyncConnectionPool): """ - A connection pool for making HTTP requests via an HTTP proxy. - - Parameters - ---------- - proxy_url: - The URL of the proxy service as a 4-tuple of (scheme, host, port, path). - proxy_headers: - A list of proxy headers to include. - proxy_mode: - A proxy mode to operate in. May be "DEFAULT", "FORWARD_ONLY", or "TUNNEL_ONLY". - ssl_context: - An SSL context to use for verifying connections. - max_connections: - The maximum number of concurrent connections to allow. 
-    max_keepalive_connections:
-        The maximum number of connections to allow before closing keep-alive
-        connections.
-    http2:
-        Enable HTTP/2 support.
+    A connection pool that sends requests via an HTTP proxy.
    """

    def __init__(
        self,
-        proxy_url: URL,
-        proxy_headers: Headers = None,
-        proxy_mode: str = "DEFAULT",
-        ssl_context: SSLContext = None,
-        max_connections: int = None,
+        proxy_url: Union[URL, bytes, str],
+        proxy_headers: Union[HeadersAsDict, HeadersAsList] = None,
+        ssl_context: ssl.SSLContext = None,
+        max_connections: int = 10,
        max_keepalive_connections: int = None,
        keepalive_expiry: float = None,
-        http2: bool = False,
-        backend: str = "auto",
-        # Deprecated argument style:
-        max_keepalive: int = None,
-    ):
-        assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
-
-        self.proxy_origin = url_to_origin(proxy_url)
-        self.proxy_headers = [] if proxy_headers is None else proxy_headers
-        self.proxy_mode = proxy_mode
+        retries: int = 0,
+        local_address: str = None,
+        uds: str = None,
+        network_backend: AsyncNetworkBackend = None,
+    ) -> None:
+        """
+        A connection pool for making HTTP requests via an HTTP proxy.
+
+        Parameters:
+            proxy_url: The URL to use when connecting to the proxy server.
+                For example `"http://127.0.0.1:8080/"`.
+            proxy_headers: Any HTTP headers to use for the proxy requests.
+                For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
+            ssl_context: An SSL context to use for verifying connections.
+                If not specified, the default `httpcore.default_ssl_context()`
+                will be used.
+            max_connections: The maximum number of concurrent HTTP connections that
+                the pool should allow. Any attempt to send a request on a pool that
+                would exceed this amount will block until a connection is available.
+            max_keepalive_connections: The maximum number of idle HTTP connections
+                that will be maintained in the pool.
+            keepalive_expiry: The duration in seconds that an idle HTTP connection
+                may be maintained for before being expired from the pool.
+            retries: The maximum number of retries when trying to establish
+                a connection.
+            local_address: Local address to connect from. Can also be used to
+                connect using a particular address family. Using
+                `local_address="0.0.0.0"` will connect using an `AF_INET` address
+                (IPv4), while using `local_address="::"` will connect using an
+                `AF_INET6` address (IPv6).
+            uds: Path to a Unix Domain Socket to use instead of TCP sockets.
+            network_backend: A backend instance to use for handling network I/O.
+        """
+        if ssl_context is None:
+            ssl_context = default_ssl_context()
+
        super().__init__(
            ssl_context=ssl_context,
            max_connections=max_connections,
            max_keepalive_connections=max_keepalive_connections,
            keepalive_expiry=keepalive_expiry,
-            http2=http2,
-            backend=backend,
-            max_keepalive=max_keepalive,
+            network_backend=network_backend,
+            retries=retries,
+            local_address=local_address,
+            uds=uds,
        )
-
-    async def handle_async_request(
-        self,
-        method: bytes,
-        url: URL,
-        headers: Headers,
-        stream: AsyncByteStream,
-        extensions: dict,
-    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
-        if self._keepalive_expiry is not None:
-            await self._keepalive_sweep()
-
-        if (
-            self.proxy_mode == "DEFAULT" and url[0] == b"http"
-        ) or self.proxy_mode == "FORWARD_ONLY":
-            # By default HTTP requests should be forwarded.
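For reference, the reworked `AsyncHTTPProxy` documented above is driven like any other pool. A sketch, with a placeholder proxy address and credentials:

```python
import asyncio
import httpcore

async def main() -> None:
    async with httpcore.AsyncHTTPProxy(
        proxy_url="http://127.0.0.1:8080/",
        # base64("user:password"), a placeholder value.
        proxy_headers={"Proxy-Authorization": "Basic dXNlcjpwYXNzd29yZA=="},
    ) as proxy:
        # "http://..." origins are forwarded; "https://..." origins are
        # tunnelled through a CONNECT request.
        response = await proxy.request("GET", "https://www.example.com/")
        print(response.status)

asyncio.run(main())
```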
- logger.trace( - "forward_request proxy_origin=%r proxy_headers=%r method=%r url=%r", - self.proxy_origin, - self.proxy_headers, - method, - url, - ) - return await self._forward_request( - method, url, headers=headers, stream=stream, extensions=extensions - ) - else: - # By default HTTPS should be tunnelled. - logger.trace( - "tunnel_request proxy_origin=%r proxy_headers=%r method=%r url=%r", - self.proxy_origin, - self.proxy_headers, - method, - url, - ) - return await self._tunnel_request( - method, url, headers=headers, stream=stream, extensions=extensions - ) - - async def _forward_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - """ - Forwarded proxy requests include the entire URL as the HTTP target, - rather than just the path. - """ - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - origin = self.proxy_origin - connection = await self._get_connection_from_pool(origin) - - if connection is None: - connection = AsyncHTTPConnection( - origin=origin, - http2=self._http2, + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + + def create_connection(self, origin: Origin) -> AsyncConnectionInterface: + if origin.scheme == b"http": + return AsyncForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, keepalive_expiry=self._keepalive_expiry, - ssl_context=self._ssl_context, + network_backend=self._network_backend, ) - await self._add_to_pool(connection, timeout) - - # Issue a forwarded proxy request... - - # GET https://www.example.org/path HTTP/1.1 - # [proxy headers] - # [headers] - scheme, host, port, path = url - if port is None: - target = b"%b://%b%b" % (scheme, host, path) - else: - target = b"%b://%b:%d%b" % (scheme, host, port, path) - - url = self.proxy_origin + (target,) - headers = merge_headers(self.proxy_headers, headers) - - ( - status_code, - headers, - stream, - extensions, - ) = await connection.handle_async_request( - method, url, headers=headers, stream=stream, extensions=extensions + return AsyncTunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, ) - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed + +class AsyncForwardHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + proxy_headers: Union[HeadersAsDict, HeadersAsList] = None, + keepalive_expiry: float = None, + network_backend: AsyncNetworkBackend = None, + ) -> None: + self._connection = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + + async def handle_async_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return await self._connection.handle_async_request(proxy_request) - return status_code, 
headers, wrapped_stream, extensions + def can_handle_request(self, origin: Origin) -> bool: + return origin.scheme == b"http" - async def _tunnel_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: AsyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, AsyncByteStream, dict]: - """ - Tunnelled proxy requests require an initial CONNECT request to - establish the connection, and then send regular requests. - """ - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - origin = url_to_origin(url) - connection = await self._get_connection_from_pool(origin) + async def aclose(self) -> None: + await self._connection.aclose() - if connection is None: - scheme, host, port = origin + def info(self) -> str: + return self._connection.info() - # First, create a connection to the proxy server - proxy_connection = AsyncHTTPConnection( - origin=self.proxy_origin, - http2=self._http2, - keepalive_expiry=self._keepalive_expiry, - ssl_context=self._ssl_context, - ) + def is_available(self) -> bool: + return self._connection.is_available() - # Issue a CONNECT request... - - # CONNECT www.example.org:80 HTTP/1.1 - # [proxy-headers] - target = b"%b:%d" % (host, port) - connect_url = self.proxy_origin + (target,) - connect_headers = [(b"Host", target), (b"Accept", b"*/*")] - connect_headers = merge_headers(connect_headers, self.proxy_headers) - - try: - ( - proxy_status_code, - _, - proxy_stream, - _, - ) = await proxy_connection.handle_async_request( - b"CONNECT", - connect_url, - headers=connect_headers, - stream=ByteStream(b""), - extensions=extensions, - ) + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() - proxy_reason = get_reason_phrase(proxy_status_code) - logger.trace( - "tunnel_response proxy_status_code=%r proxy_reason=%r ", - proxy_status_code, - proxy_reason, + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + +class AsyncTunnelHTTPConnection(AsyncConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: ssl.SSLContext, + proxy_headers: List[Tuple[bytes, bytes]] = None, + keepalive_expiry: float = None, + network_backend: AsyncNetworkBackend = None, + ) -> None: + self._connection: AsyncConnectionInterface = AsyncHTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_headers = [] if proxy_headers is None else proxy_headers + self._keepalive_expiry = keepalive_expiry + self._connect_lock = AsyncLock() + self._connected = False + + async def handle_async_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + async with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = [(b"Host", target), (b"Accept", b"*/*")] + connect_request = Request( + method=b"CONNECT", url=connect_url, headers=connect_headers + ) + connect_response = await self._connection.handle_async_request( + connect_request ) - # Read the response data without 
closing the socket - async for _ in proxy_stream: - pass - # See if the tunnel was successfully established. - if proxy_status_code < 200 or proxy_status_code > 299: - msg = "%d %s" % (proxy_status_code, proxy_reason) + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + await self._connection.aclose() raise ProxyError(msg) - # Upgrade to TLS if required - # We assume the target speaks TLS on the specified port - if scheme == b"https": - await proxy_connection.start_tls(host, self._ssl_context, timeout) - except Exception as exc: - await proxy_connection.aclose() - raise ProxyError(exc) - - # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. - # This means the proxy connection is now unusable, and we must create - # a new one for regular requests, making sure to use the same socket to - # retain the tunnel. - connection = AsyncHTTPConnection( - origin=origin, - http2=self._http2, - keepalive_expiry=self._keepalive_expiry, - ssl_context=self._ssl_context, - socket=proxy_connection.socket, - ) - await self._add_to_pool(connection, timeout) - - # Once the connection has been established we can send requests on - # it as normal. - ( - status_code, - headers, - stream, - extensions, - ) = await connection.handle_async_request( - method, - url, - headers=headers, - stream=stream, - extensions=extensions, - ) + stream = connect_response.extensions["network_stream"] + stream = await stream.start_tls( + ssl_context=self._ssl_context, + server_hostname=self._remote_origin.host.decode("ascii"), + timeout=timeout, + ) + self._connection = AsyncHTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + self._connected = True + return await self._connection.handle_async_request(request) - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + async def aclose(self) -> None: + await self._connection.aclose() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() - return status_code, headers, wrapped_stream, extensions + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/httpcore/_async/interfaces.py b/httpcore/_async/interfaces.py new file mode 100644 index 00000000..bf24d67a --- /dev/null +++ b/httpcore/_async/interfaces.py @@ -0,0 +1,133 @@ +from typing import AsyncIterator, Union + +from .._compat import asynccontextmanager +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class AsyncRequestInterface: + async def request( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: Union[dict, list] = None, + content: Union[bytes, AsyncIterator[bytes]] = None, + extensions: dict = None, + ) -> Response: + # Strict type checking on our parameters. 
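Before the body of these wrappers, it may help to see the difference in use: `request()` buffers and closes the response, while `stream()` (below) hands it over unread. A sketch, assuming `Response.aiter_stream()` behaves as in the released 0.14 API:

```python
import asyncio
import httpcore

async def main() -> None:
    async with httpcore.AsyncConnectionPool() as pool:
        # request() reads and closes the response for you...
        response = await pool.request("GET", "https://www.example.com/")
        print(response.status, len(response.content))

        # ...while stream() leaves the body to be iterated incrementally.
        async with pool.stream("GET", "https://www.example.com/") as response:
            async for chunk in response.aiter_stream():
                print(len(chunk))

asyncio.run(main())
```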
+        method = enforce_bytes(method, name="method")
+        url = enforce_url(url, name="url")
+        headers = enforce_headers(headers, name="headers")
+
+        # Include Host header, and optionally Content-Length or Transfer-Encoding.
+        headers = include_request_headers(headers, url=url, content=content)
+
+        request = Request(
+            method=method,
+            url=url,
+            headers=headers,
+            content=content,
+            extensions=extensions,
+        )
+        response = await self.handle_async_request(request)
+        try:
+            await response.aread()
+        finally:
+            await response.aclose()
+        return response
+
+    @asynccontextmanager
+    async def stream(
+        self,
+        method: Union[bytes, str],
+        url: Union[URL, bytes, str],
+        *,
+        headers: Union[dict, list] = None,
+        content: Union[bytes, AsyncIterator[bytes]] = None,
+        extensions: dict = None,
+    ) -> AsyncIterator[Response]:
+        # Strict type checking on our parameters.
+        method = enforce_bytes(method, name="method")
+        url = enforce_url(url, name="url")
+        headers = enforce_headers(headers, name="headers")
+
+        # Include Host header, and optionally Content-Length or Transfer-Encoding.
+        headers = include_request_headers(headers, url=url, content=content)
+
+        request = Request(
+            method=method,
+            url=url,
+            headers=headers,
+            content=content,
+            extensions=extensions,
+        )
+        response = await self.handle_async_request(request)
+        try:
+            yield response
+        finally:
+            await response.aclose()
+
+    async def handle_async_request(self, request: Request) -> Response:
+        raise NotImplementedError()  # pragma: nocover
+
+
+class AsyncConnectionInterface(AsyncRequestInterface):
+    async def aclose(self) -> None:
+        raise NotImplementedError()  # pragma: nocover
+
+    def info(self) -> str:
+        raise NotImplementedError()  # pragma: nocover
+
+    def can_handle_request(self, origin: Origin) -> bool:
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_available(self) -> bool:
+        """
+        Return `True` if the connection is currently able to accept an
+        outgoing request.
+
+        An HTTP/1.1 connection will only be available if it is currently idle.
+
+        An HTTP/2 connection will be available so long as the stream ID space is
+        not yet exhausted, and the connection is not in an error state.
+
+        While the connection is being established we may not yet know if it is going
+        to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
+        treated as being available, but might ultimately raise `ConnectionNotAvailable`
+        if multiple requests are attempted over a connection that ends up being
+        established as HTTP/1.1.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def has_expired(self) -> bool:
+        """
+        Return `True` if the connection is in a state where it should be closed.
+
+        This either means that the connection is idle and it has passed the
+        expiry time on its keep-alive, or that the server has sent an EOF.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_idle(self) -> bool:
+        """
+        Return `True` if the connection is currently idle.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_closed(self) -> bool:
+        """
+        Return `True` if the connection has been closed.
+
+        Used when a response is closed to determine if the connection may be
+        returned to the connection pool or not.
+ """ + raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_backends/anyio.py b/httpcore/_backends/anyio.py deleted file mode 100644 index b1332a27..00000000 --- a/httpcore/_backends/anyio.py +++ /dev/null @@ -1,201 +0,0 @@ -from ssl import SSLContext -from typing import Optional - -import anyio.abc -from anyio import BrokenResourceError, EndOfStream -from anyio.abc import ByteStream, SocketAttribute -from anyio.streams.tls import TLSAttribute, TLSStream - -from .._exceptions import ( - ConnectError, - ConnectTimeout, - ReadError, - ReadTimeout, - WriteError, - WriteTimeout, - map_exceptions, -) -from .._types import TimeoutDict -from .._utils import is_socket_readable -from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream - - -class SocketStream(AsyncSocketStream): - def __init__(self, stream: ByteStream) -> None: - self.stream = stream - self.read_lock = anyio.Lock() - self.write_lock = anyio.Lock() - - def get_http_version(self) -> str: - alpn_protocol = self.stream.extra(TLSAttribute.alpn_protocol, None) - return "HTTP/2" if alpn_protocol == "h2" else "HTTP/1.1" - - async def start_tls( - self, - hostname: bytes, - ssl_context: SSLContext, - timeout: TimeoutDict, - ) -> "SocketStream": - connect_timeout = timeout.get("connect") - try: - with anyio.fail_after(connect_timeout): - ssl_stream = await TLSStream.wrap( - self.stream, - ssl_context=ssl_context, - hostname=hostname.decode("ascii"), - standard_compatible=False, - ) - except TimeoutError: - raise ConnectTimeout from None - except BrokenResourceError as exc: - raise ConnectError from exc - - return SocketStream(ssl_stream) - - async def read(self, n: int, timeout: TimeoutDict) -> bytes: - read_timeout = timeout.get("read") - async with self.read_lock: - try: - with anyio.fail_after(read_timeout): - return await self.stream.receive(n) - except TimeoutError: - await self.stream.aclose() - raise ReadTimeout from None - except BrokenResourceError as exc: - raise ReadError from exc - except EndOfStream: - return b"" - - async def write(self, data: bytes, timeout: TimeoutDict) -> None: - if not data: - return - - write_timeout = timeout.get("write") - async with self.write_lock: - try: - with anyio.fail_after(write_timeout): - return await self.stream.send(data) - except TimeoutError: - await self.stream.aclose() - raise WriteTimeout from None - except BrokenResourceError as exc: - raise WriteError from exc - - async def aclose(self) -> None: - async with self.write_lock: - try: - await self.stream.aclose() - except BrokenResourceError: - pass - - def is_readable(self) -> bool: - sock = self.stream.extra(SocketAttribute.raw_socket) - return is_socket_readable(sock) - - -class Lock(AsyncLock): - def __init__(self) -> None: - self._lock = anyio.Lock() - - async def release(self) -> None: - self._lock.release() - - async def acquire(self) -> None: - await self._lock.acquire() - - -class Semaphore(AsyncSemaphore): - def __init__(self, max_value: int, exc_class: type): - self.max_value = max_value - self.exc_class = exc_class - - @property - def semaphore(self) -> anyio.abc.Semaphore: - if not hasattr(self, "_semaphore"): - self._semaphore = anyio.Semaphore(self.max_value) - return self._semaphore - - async def acquire(self, timeout: float = None) -> None: - with anyio.move_on_after(timeout): - await self.semaphore.acquire() - return - - raise self.exc_class() - - async def release(self) -> None: - self.semaphore.release() - - -class AnyIOBackend(AsyncBackend): - async def open_tcp_stream( - self, - 
hostname: bytes, - port: int, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - *, - local_address: Optional[str], - ) -> AsyncSocketStream: - connect_timeout = timeout.get("connect") - unicode_host = hostname.decode("utf-8") - exc_map = { - TimeoutError: ConnectTimeout, - OSError: ConnectError, - BrokenResourceError: ConnectError, - } - - with map_exceptions(exc_map): - with anyio.fail_after(connect_timeout): - stream: anyio.abc.ByteStream - stream = await anyio.connect_tcp( - unicode_host, port, local_host=local_address - ) - if ssl_context: - stream = await TLSStream.wrap( - stream, - hostname=unicode_host, - ssl_context=ssl_context, - standard_compatible=False, - ) - - return SocketStream(stream=stream) - - async def open_uds_stream( - self, - path: str, - hostname: bytes, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - ) -> AsyncSocketStream: - connect_timeout = timeout.get("connect") - unicode_host = hostname.decode("utf-8") - exc_map = { - TimeoutError: ConnectTimeout, - OSError: ConnectError, - BrokenResourceError: ConnectError, - } - - with map_exceptions(exc_map): - with anyio.fail_after(connect_timeout): - stream: anyio.abc.ByteStream = await anyio.connect_unix(path) - if ssl_context: - stream = await TLSStream.wrap( - stream, - hostname=unicode_host, - ssl_context=ssl_context, - standard_compatible=False, - ) - - return SocketStream(stream=stream) - - def create_lock(self) -> AsyncLock: - return Lock() - - def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: - return Semaphore(max_value, exc_class=exc_class) - - async def time(self) -> float: - return float(anyio.current_time()) - - async def sleep(self, seconds: float) -> None: - await anyio.sleep(seconds) diff --git a/httpcore/_backends/asyncio.py b/httpcore/_backends/asyncio.py deleted file mode 100644 index 5142072e..00000000 --- a/httpcore/_backends/asyncio.py +++ /dev/null @@ -1,303 +0,0 @@ -import asyncio -import socket -from ssl import SSLContext -from typing import Optional - -from .._exceptions import ( - ConnectError, - ConnectTimeout, - ReadError, - ReadTimeout, - WriteError, - WriteTimeout, - map_exceptions, -) -from .._types import TimeoutDict -from .._utils import is_socket_readable -from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream - -SSL_MONKEY_PATCH_APPLIED = False - - -def ssl_monkey_patch() -> None: - """ - Monkey-patch for https://bugs.python.org/issue36709 - - This prevents console errors when outstanding HTTPS connections - still exist at the point of exiting. - - Clients which have been opened using a `with` block, or which have - had `close()` closed, will not exhibit this issue in the first place. - """ - MonkeyPatch = asyncio.selector_events._SelectorSocketTransport # type: ignore - - _write = MonkeyPatch.write - - def _fixed_write(self, data: bytes) -> None: # type: ignore - if self._loop and not self._loop.is_closed(): - _write(self, data) - - MonkeyPatch.write = _fixed_write - - -async def backport_start_tls( - transport: asyncio.BaseTransport, - protocol: asyncio.BaseProtocol, - ssl_context: SSLContext, - *, - server_side: bool = False, - server_hostname: str = None, - ssl_handshake_timeout: float = None, -) -> asyncio.Transport: # pragma: nocover (Since it's not used on all Python versions.) - """ - Python 3.6 asyncio doesn't have a start_tls() method on the loop - so we use this function in place of the loop's start_tls() method. 
- Adapted from this comment: - https://github.com/urllib3/urllib3/issues/1323#issuecomment-362494839 - """ - import asyncio.sslproto - - loop = asyncio.get_event_loop() - waiter = loop.create_future() - ssl_protocol = asyncio.sslproto.SSLProtocol( - loop, - protocol, - ssl_context, - waiter, - server_side=False, - server_hostname=server_hostname, - call_connection_made=False, - ) - - transport.set_protocol(ssl_protocol) - loop.call_soon(ssl_protocol.connection_made, transport) - loop.call_soon(transport.resume_reading) # type: ignore - - await waiter - return ssl_protocol._app_transport - - -class SocketStream(AsyncSocketStream): - def __init__( - self, stream_reader: asyncio.StreamReader, stream_writer: asyncio.StreamWriter - ): - self.stream_reader = stream_reader - self.stream_writer = stream_writer - self.read_lock = asyncio.Lock() - self.write_lock = asyncio.Lock() - - def get_http_version(self) -> str: - ssl_object = self.stream_writer.get_extra_info("ssl_object") - - if ssl_object is None: - return "HTTP/1.1" - - ident = ssl_object.selected_alpn_protocol() - return "HTTP/2" if ident == "h2" else "HTTP/1.1" - - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict - ) -> "SocketStream": - loop = asyncio.get_event_loop() - - stream_reader = asyncio.StreamReader() - protocol = asyncio.StreamReaderProtocol(stream_reader) - transport = self.stream_writer.transport - - loop_start_tls = getattr(loop, "start_tls", backport_start_tls) - - exc_map = {asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError} - - with map_exceptions(exc_map): - transport = await asyncio.wait_for( - loop_start_tls( - transport, - protocol, - ssl_context, - server_hostname=hostname.decode("ascii"), - ), - timeout=timeout.get("connect"), - ) - - # Initialize the protocol, so it is made aware of being tied to - # a TLS connection. - # See: https://github.com/encode/httpx/issues/859 - protocol.connection_made(transport) - - stream_writer = asyncio.StreamWriter( - transport=transport, protocol=protocol, reader=stream_reader, loop=loop - ) - - ssl_stream = SocketStream(stream_reader, stream_writer) - # When we return a new SocketStream with new StreamReader/StreamWriter instances - # we need to keep references to the old StreamReader/StreamWriter so that they - # are not garbage collected and closed while we're still using them. - ssl_stream._inner = self # type: ignore - return ssl_stream - - async def read(self, n: int, timeout: TimeoutDict) -> bytes: - exc_map = {asyncio.TimeoutError: ReadTimeout, OSError: ReadError} - async with self.read_lock: - with map_exceptions(exc_map): - try: - return await asyncio.wait_for( - self.stream_reader.read(n), timeout.get("read") - ) - except AttributeError as exc: # pragma: nocover - if "resume_reading" in str(exc): - # Python's asyncio has a bug that can occur when a - # connection has been closed, while it is paused. - # See: https://github.com/encode/httpx/issues/1213 - # - # Returning an empty byte-string to indicate connection - # close will eventually raise an httpcore.RemoteProtocolError - # to the user when this goes through our HTTP parsing layer. 
- return b"" - raise - - async def write(self, data: bytes, timeout: TimeoutDict) -> None: - if not data: - return - - exc_map = {asyncio.TimeoutError: WriteTimeout, OSError: WriteError} - async with self.write_lock: - with map_exceptions(exc_map): - self.stream_writer.write(data) - return await asyncio.wait_for( - self.stream_writer.drain(), timeout.get("write") - ) - - async def aclose(self) -> None: - # SSL connections should issue the close and then abort, rather than - # waiting for the remote end of the connection to signal the EOF. - # - # See: - # - # * https://bugs.python.org/issue39758 - # * https://github.com/python-trio/trio/blob/ - # 31e2ae866ad549f1927d45ce073d4f0ea9f12419/trio/_ssl.py#L779-L829 - # - # And related issues caused if we simply omit the 'wait_closed' call, - # without first using `.abort()` - # - # * https://github.com/encode/httpx/issues/825 - # * https://github.com/encode/httpx/issues/914 - is_ssl = self.stream_writer.get_extra_info("ssl_object") is not None - - async with self.write_lock: - try: - self.stream_writer.close() - if is_ssl: - # Give the connection a chance to write any data in the buffer, - # and then forcibly tear down the SSL connection. - await asyncio.sleep(0) - self.stream_writer.transport.abort() # type: ignore - if hasattr(self.stream_writer, "wait_closed"): - # Python 3.7+ only. - await self.stream_writer.wait_closed() # type: ignore - except OSError: - pass - - def is_readable(self) -> bool: - transport = self.stream_reader._transport # type: ignore - sock: Optional[socket.socket] = transport.get_extra_info("socket") - return is_socket_readable(sock) - - -class Lock(AsyncLock): - def __init__(self) -> None: - self._lock = asyncio.Lock() - - async def release(self) -> None: - self._lock.release() - - async def acquire(self) -> None: - await self._lock.acquire() - - -class Semaphore(AsyncSemaphore): - def __init__(self, max_value: int, exc_class: type) -> None: - self.max_value = max_value - self.exc_class = exc_class - - @property - def semaphore(self) -> asyncio.BoundedSemaphore: - if not hasattr(self, "_semaphore"): - self._semaphore = asyncio.BoundedSemaphore(value=self.max_value) - return self._semaphore - - async def acquire(self, timeout: float = None) -> None: - try: - await asyncio.wait_for(self.semaphore.acquire(), timeout) - except asyncio.TimeoutError: - raise self.exc_class() - - async def release(self) -> None: - self.semaphore.release() - - -class AsyncioBackend(AsyncBackend): - def __init__(self) -> None: - global SSL_MONKEY_PATCH_APPLIED - - if not SSL_MONKEY_PATCH_APPLIED: - ssl_monkey_patch() - SSL_MONKEY_PATCH_APPLIED = True - - async def open_tcp_stream( - self, - hostname: bytes, - port: int, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - *, - local_address: Optional[str], - ) -> SocketStream: - host = hostname.decode("ascii") - connect_timeout = timeout.get("connect") - local_addr = None if local_address is None else (local_address, 0) - - exc_map = {asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError} - with map_exceptions(exc_map): - stream_reader, stream_writer = await asyncio.wait_for( - asyncio.open_connection( - host, port, ssl=ssl_context, local_addr=local_addr - ), - connect_timeout, - ) - return SocketStream( - stream_reader=stream_reader, stream_writer=stream_writer - ) - - async def open_uds_stream( - self, - path: str, - hostname: bytes, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - ) -> AsyncSocketStream: - host = hostname.decode("ascii") - connect_timeout = 
timeout.get("connect") - kwargs: dict = {"server_hostname": host} if ssl_context is not None else {} - exc_map = {asyncio.TimeoutError: ConnectTimeout, OSError: ConnectError} - with map_exceptions(exc_map): - stream_reader, stream_writer = await asyncio.wait_for( - asyncio.open_unix_connection(path, ssl=ssl_context, **kwargs), - connect_timeout, - ) - return SocketStream( - stream_reader=stream_reader, stream_writer=stream_writer - ) - - def create_lock(self) -> AsyncLock: - return Lock() - - def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: - return Semaphore(max_value, exc_class=exc_class) - - async def time(self) -> float: - loop = asyncio.get_event_loop() - return loop.time() - - async def sleep(self, seconds: float) -> None: - await asyncio.sleep(seconds) diff --git a/httpcore/_backends/auto.py b/httpcore/_backends/auto.py deleted file mode 100644 index 5579ab46..00000000 --- a/httpcore/_backends/auto.py +++ /dev/null @@ -1,67 +0,0 @@ -from ssl import SSLContext -from typing import Optional - -import sniffio - -from .._types import TimeoutDict -from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream - -# The following line is imported from the _sync modules -from .sync import SyncBackend, SyncLock, SyncSemaphore, SyncSocketStream # noqa - - -class AutoBackend(AsyncBackend): - @property - def backend(self) -> AsyncBackend: - if not hasattr(self, "_backend_implementation"): - backend = sniffio.current_async_library() - - if backend == "asyncio": - from .anyio import AnyIOBackend - - self._backend_implementation: AsyncBackend = AnyIOBackend() - elif backend == "trio": - from .trio import TrioBackend - - self._backend_implementation = TrioBackend() - elif backend == "curio": - from .curio import CurioBackend - - self._backend_implementation = CurioBackend() - else: # pragma: nocover - raise RuntimeError(f"Unsupported concurrency backend {backend!r}") - return self._backend_implementation - - async def open_tcp_stream( - self, - hostname: bytes, - port: int, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - *, - local_address: Optional[str], - ) -> AsyncSocketStream: - return await self.backend.open_tcp_stream( - hostname, port, ssl_context, timeout, local_address=local_address - ) - - async def open_uds_stream( - self, - path: str, - hostname: bytes, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - ) -> AsyncSocketStream: - return await self.backend.open_uds_stream(path, hostname, ssl_context, timeout) - - def create_lock(self) -> AsyncLock: - return self.backend.create_lock() - - def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: - return self.backend.create_semaphore(max_value, exc_class=exc_class) - - async def time(self) -> float: - return await self.backend.time() - - async def sleep(self, seconds: float) -> None: - await self.backend.sleep(seconds) diff --git a/httpcore/_backends/base.py b/httpcore/_backends/base.py deleted file mode 100644 index 1ca6e31b..00000000 --- a/httpcore/_backends/base.py +++ /dev/null @@ -1,137 +0,0 @@ -from ssl import SSLContext -from types import TracebackType -from typing import TYPE_CHECKING, Optional, Type - -from .._types import TimeoutDict - -if TYPE_CHECKING: # pragma: no cover - from .sync import SyncBackend - - -def lookup_async_backend(name: str) -> "AsyncBackend": - if name == "auto": - from .auto import AutoBackend - - return AutoBackend() - elif name == "asyncio": - from .asyncio import AsyncioBackend - - return AsyncioBackend() - elif 
name == "trio": - from .trio import TrioBackend - - return TrioBackend() - elif name == "curio": - from .curio import CurioBackend - - return CurioBackend() - elif name == "anyio": - from .anyio import AnyIOBackend - - return AnyIOBackend() - - raise ValueError("Invalid backend name {name!r}") - - -def lookup_sync_backend(name: str) -> "SyncBackend": - from .sync import SyncBackend - - return SyncBackend() - - -class AsyncSocketStream: - """ - A socket stream with read/write operations. Abstracts away any asyncio-specific - interfaces into a more generic base class, that we can use with alternate - backends, or for stand-alone test cases. - """ - - def get_http_version(self) -> str: - raise NotImplementedError() # pragma: no cover - - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict - ) -> "AsyncSocketStream": - raise NotImplementedError() # pragma: no cover - - async def read(self, n: int, timeout: TimeoutDict) -> bytes: - raise NotImplementedError() # pragma: no cover - - async def write(self, data: bytes, timeout: TimeoutDict) -> None: - raise NotImplementedError() # pragma: no cover - - async def aclose(self) -> None: - raise NotImplementedError() # pragma: no cover - - def is_readable(self) -> bool: - raise NotImplementedError() # pragma: no cover - - -class AsyncLock: - """ - An abstract interface for Lock classes. - """ - - async def __aenter__(self) -> None: - await self.acquire() - - async def __aexit__( - self, - exc_type: Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - await self.release() - - async def release(self) -> None: - raise NotImplementedError() # pragma: no cover - - async def acquire(self) -> None: - raise NotImplementedError() # pragma: no cover - - -class AsyncSemaphore: - """ - An abstract interface for Semaphore classes. - Abstracts away any asyncio-specific interfaces. 
- """ - - async def acquire(self, timeout: float = None) -> None: - raise NotImplementedError() # pragma: no cover - - async def release(self) -> None: - raise NotImplementedError() # pragma: no cover - - -class AsyncBackend: - async def open_tcp_stream( - self, - hostname: bytes, - port: int, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - *, - local_address: Optional[str], - ) -> AsyncSocketStream: - raise NotImplementedError() # pragma: no cover - - async def open_uds_stream( - self, - path: str, - hostname: bytes, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - ) -> AsyncSocketStream: - raise NotImplementedError() # pragma: no cover - - def create_lock(self) -> AsyncLock: - raise NotImplementedError() # pragma: no cover - - def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: - raise NotImplementedError() # pragma: no cover - - async def time(self) -> float: - raise NotImplementedError() # pragma: no cover - - async def sleep(self, seconds: float) -> None: - raise NotImplementedError() # pragma: no cover diff --git a/httpcore/_backends/curio.py b/httpcore/_backends/curio.py deleted file mode 100644 index 99a7b2cc..00000000 --- a/httpcore/_backends/curio.py +++ /dev/null @@ -1,206 +0,0 @@ -from ssl import SSLContext, SSLSocket -from typing import Optional - -import curio -import curio.io - -from .._exceptions import ( - ConnectError, - ConnectTimeout, - ReadError, - ReadTimeout, - WriteError, - WriteTimeout, - map_exceptions, -) -from .._types import TimeoutDict -from .._utils import get_logger, is_socket_readable -from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream - -logger = get_logger(__name__) - -ONE_DAY_IN_SECONDS = float(60 * 60 * 24) - - -def convert_timeout(value: Optional[float]) -> float: - return value if value is not None else ONE_DAY_IN_SECONDS - - -class Lock(AsyncLock): - def __init__(self) -> None: - self._lock = curio.Lock() - - async def acquire(self) -> None: - await self._lock.acquire() - - async def release(self) -> None: - await self._lock.release() - - -class Semaphore(AsyncSemaphore): - def __init__(self, max_value: int, exc_class: type) -> None: - self.max_value = max_value - self.exc_class = exc_class - - @property - def semaphore(self) -> curio.Semaphore: - if not hasattr(self, "_semaphore"): - self._semaphore = curio.Semaphore(value=self.max_value) - return self._semaphore - - async def acquire(self, timeout: float = None) -> None: - timeout = convert_timeout(timeout) - - try: - return await curio.timeout_after(timeout, self.semaphore.acquire()) - except curio.TaskTimeout: - raise self.exc_class() - - async def release(self) -> None: - await self.semaphore.release() - - -class SocketStream(AsyncSocketStream): - def __init__(self, socket: curio.io.Socket) -> None: - self.read_lock = curio.Lock() - self.write_lock = curio.Lock() - self.socket = socket - self.stream = socket.as_stream() - - def get_http_version(self) -> str: - if hasattr(self.socket, "_socket"): - raw_socket = self.socket._socket - - if isinstance(raw_socket, SSLSocket): - ident = raw_socket.selected_alpn_protocol() - return "HTTP/2" if ident == "h2" else "HTTP/1.1" - - return "HTTP/1.1" - - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict - ) -> "AsyncSocketStream": - connect_timeout = convert_timeout(timeout.get("connect")) - exc_map = { - curio.TaskTimeout: ConnectTimeout, - curio.CurioError: ConnectError, - OSError: ConnectError, - } - - with map_exceptions(exc_map): - 
wrapped_sock = curio.io.Socket( - ssl_context.wrap_socket( - self.socket._socket, - do_handshake_on_connect=False, - server_hostname=hostname.decode("ascii"), - ) - ) - - await curio.timeout_after( - connect_timeout, - wrapped_sock.do_handshake(), - ) - - return SocketStream(wrapped_sock) - - async def read(self, n: int, timeout: TimeoutDict) -> bytes: - read_timeout = convert_timeout(timeout.get("read")) - exc_map = { - curio.TaskTimeout: ReadTimeout, - curio.CurioError: ReadError, - OSError: ReadError, - } - - with map_exceptions(exc_map): - async with self.read_lock: - return await curio.timeout_after(read_timeout, self.stream.read(n)) - - async def write(self, data: bytes, timeout: TimeoutDict) -> None: - write_timeout = convert_timeout(timeout.get("write")) - exc_map = { - curio.TaskTimeout: WriteTimeout, - curio.CurioError: WriteError, - OSError: WriteError, - } - - with map_exceptions(exc_map): - async with self.write_lock: - await curio.timeout_after(write_timeout, self.stream.write(data)) - - async def aclose(self) -> None: - await self.stream.close() - await self.socket.close() - - def is_readable(self) -> bool: - return is_socket_readable(self.socket) - - -class CurioBackend(AsyncBackend): - async def open_tcp_stream( - self, - hostname: bytes, - port: int, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - *, - local_address: Optional[str], - ) -> AsyncSocketStream: - connect_timeout = convert_timeout(timeout.get("connect")) - exc_map = { - curio.TaskTimeout: ConnectTimeout, - curio.CurioError: ConnectError, - OSError: ConnectError, - } - host = hostname.decode("ascii") - - kwargs: dict = {} - if ssl_context is not None: - kwargs["ssl"] = ssl_context - kwargs["server_hostname"] = host - if local_address is not None: - kwargs["source_addr"] = (local_address, 0) - - with map_exceptions(exc_map): - sock: curio.io.Socket = await curio.timeout_after( - connect_timeout, - curio.open_connection(hostname, port, **kwargs), - ) - - return SocketStream(sock) - - async def open_uds_stream( - self, - path: str, - hostname: bytes, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - ) -> AsyncSocketStream: - connect_timeout = convert_timeout(timeout.get("connect")) - exc_map = { - curio.TaskTimeout: ConnectTimeout, - curio.CurioError: ConnectError, - OSError: ConnectError, - } - host = hostname.decode("ascii") - kwargs = ( - {} if ssl_context is None else {"ssl": ssl_context, "server_hostname": host} - ) - - with map_exceptions(exc_map): - sock: curio.io.Socket = await curio.timeout_after( - connect_timeout, curio.open_unix_connection(path, **kwargs) - ) - - return SocketStream(sock) - - def create_lock(self) -> AsyncLock: - return Lock() - - def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: - return Semaphore(max_value, exc_class) - - async def time(self) -> float: - return await curio.clock() - - async def sleep(self, seconds: float) -> None: - await curio.sleep(seconds) diff --git a/httpcore/_backends/sync.py b/httpcore/_backends/sync.py deleted file mode 100644 index ee8f94b7..00000000 --- a/httpcore/_backends/sync.py +++ /dev/null @@ -1,178 +0,0 @@ -import socket -import threading -import time -from ssl import SSLContext -from types import TracebackType -from typing import Optional, Type - -from .._exceptions import ( - ConnectError, - ConnectTimeout, - ReadError, - ReadTimeout, - WriteError, - WriteTimeout, - map_exceptions, -) -from .._types import TimeoutDict -from .._utils import is_socket_readable - - -class SyncSocketStream: - """ 
- A socket stream with read/write operations. Abstracts away any asyncio-specific - interfaces into a more generic base class, that we can use with alternate - backends, or for stand-alone test cases. - """ - - def __init__(self, sock: socket.socket) -> None: - self.sock = sock - self.read_lock = threading.Lock() - self.write_lock = threading.Lock() - - def get_http_version(self) -> str: - selected_alpn_protocol = getattr(self.sock, "selected_alpn_protocol", None) - if selected_alpn_protocol is not None: - ident = selected_alpn_protocol() - return "HTTP/2" if ident == "h2" else "HTTP/1.1" - return "HTTP/1.1" - - def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict - ) -> "SyncSocketStream": - connect_timeout = timeout.get("connect") - exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError} - - with map_exceptions(exc_map): - self.sock.settimeout(connect_timeout) - wrapped = ssl_context.wrap_socket( - self.sock, server_hostname=hostname.decode("ascii") - ) - - return SyncSocketStream(wrapped) - - def read(self, n: int, timeout: TimeoutDict) -> bytes: - read_timeout = timeout.get("read") - exc_map = {socket.timeout: ReadTimeout, socket.error: ReadError} - - with self.read_lock: - with map_exceptions(exc_map): - self.sock.settimeout(read_timeout) - return self.sock.recv(n) - - def write(self, data: bytes, timeout: TimeoutDict) -> None: - write_timeout = timeout.get("write") - exc_map = {socket.timeout: WriteTimeout, socket.error: WriteError} - - with self.write_lock: - with map_exceptions(exc_map): - while data: - self.sock.settimeout(write_timeout) - n = self.sock.send(data) - data = data[n:] - - def close(self) -> None: - with self.write_lock: - try: - self.sock.close() - except socket.error: - pass - - def is_readable(self) -> bool: - return is_socket_readable(self.sock) - - -class SyncLock: - def __init__(self) -> None: - self._lock = threading.Lock() - - def __enter__(self) -> None: - self.acquire() - - def __exit__( - self, - exc_type: Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - self.release() - - def release(self) -> None: - self._lock.release() - - def acquire(self) -> None: - self._lock.acquire() - - -class SyncSemaphore: - def __init__(self, max_value: int, exc_class: type) -> None: - self.max_value = max_value - self.exc_class = exc_class - self._semaphore = threading.Semaphore(max_value) - - def acquire(self, timeout: float = None) -> None: - if not self._semaphore.acquire(timeout=timeout): # type: ignore - raise self.exc_class() - - def release(self) -> None: - self._semaphore.release() - - -class SyncBackend: - def open_tcp_stream( - self, - hostname: bytes, - port: int, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - *, - local_address: Optional[str], - ) -> SyncSocketStream: - address = (hostname.decode("ascii"), port) - connect_timeout = timeout.get("connect") - source_address = None if local_address is None else (local_address, 0) - exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError} - - with map_exceptions(exc_map): - sock = socket.create_connection( - address, connect_timeout, source_address=source_address # type: ignore - ) - if ssl_context is not None: - sock = ssl_context.wrap_socket( - sock, server_hostname=hostname.decode("ascii") - ) - return SyncSocketStream(sock=sock) - - def open_uds_stream( - self, - path: str, - hostname: bytes, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - ) -> SyncSocketStream: 
- connect_timeout = timeout.get("connect") - exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError} - - with map_exceptions(exc_map): - sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.settimeout(connect_timeout) - sock.connect(path) - - if ssl_context is not None: - sock = ssl_context.wrap_socket( - sock, server_hostname=hostname.decode("ascii") - ) - - return SyncSocketStream(sock=sock) - - def create_lock(self) -> SyncLock: - return SyncLock() - - def create_semaphore(self, max_value: int, exc_class: type) -> SyncSemaphore: - return SyncSemaphore(max_value, exc_class=exc_class) - - def time(self) -> float: - return time.monotonic() - - def sleep(self, seconds: float) -> None: - time.sleep(seconds) diff --git a/httpcore/_backends/trio.py b/httpcore/_backends/trio.py deleted file mode 100644 index d6e67c2e..00000000 --- a/httpcore/_backends/trio.py +++ /dev/null @@ -1,212 +0,0 @@ -from ssl import SSLContext -from typing import Optional - -import trio - -from .._exceptions import ( - ConnectError, - ConnectTimeout, - ReadError, - ReadTimeout, - WriteError, - WriteTimeout, - map_exceptions, -) -from .._types import TimeoutDict -from .base import AsyncBackend, AsyncLock, AsyncSemaphore, AsyncSocketStream - - -def none_as_inf(value: Optional[float]) -> float: - return value if value is not None else float("inf") - - -class SocketStream(AsyncSocketStream): - def __init__(self, stream: trio.abc.Stream) -> None: - self.stream = stream - self.read_lock = trio.Lock() - self.write_lock = trio.Lock() - - def get_http_version(self) -> str: - if not isinstance(self.stream, trio.SSLStream): - return "HTTP/1.1" - - ident = self.stream.selected_alpn_protocol() - return "HTTP/2" if ident == "h2" else "HTTP/1.1" - - async def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict - ) -> "SocketStream": - connect_timeout = none_as_inf(timeout.get("connect")) - exc_map = { - trio.TooSlowError: ConnectTimeout, - trio.BrokenResourceError: ConnectError, - } - ssl_stream = trio.SSLStream( - self.stream, - ssl_context=ssl_context, - server_hostname=hostname.decode("ascii"), - ) - - with map_exceptions(exc_map): - with trio.fail_after(connect_timeout): - await ssl_stream.do_handshake() - return SocketStream(ssl_stream) - - async def read(self, n: int, timeout: TimeoutDict) -> bytes: - read_timeout = none_as_inf(timeout.get("read")) - exc_map = {trio.TooSlowError: ReadTimeout, trio.BrokenResourceError: ReadError} - - async with self.read_lock: - with map_exceptions(exc_map): - try: - with trio.fail_after(read_timeout): - return await self.stream.receive_some(max_bytes=n) - except trio.TooSlowError as exc: - await self.stream.aclose() - raise exc - - async def write(self, data: bytes, timeout: TimeoutDict) -> None: - if not data: - return - - write_timeout = none_as_inf(timeout.get("write")) - exc_map = { - trio.TooSlowError: WriteTimeout, - trio.BrokenResourceError: WriteError, - } - - async with self.write_lock: - with map_exceptions(exc_map): - try: - with trio.fail_after(write_timeout): - return await self.stream.send_all(data) - except trio.TooSlowError as exc: - await self.stream.aclose() - raise exc - - async def aclose(self) -> None: - async with self.write_lock: - try: - await self.stream.aclose() - except trio.BrokenResourceError: - pass - - def is_readable(self) -> bool: - # Adapted from: https://github.com/encode/httpx/pull/143#issuecomment-515202982 - stream = self.stream - - # Peek through any SSLStream wrappers to get the underlying 
SocketStream. - while isinstance(stream, trio.SSLStream): - stream = stream.transport_stream - assert isinstance(stream, trio.SocketStream) - - return stream.socket.is_readable() - - -class Lock(AsyncLock): - def __init__(self) -> None: - self._lock = trio.Lock() - - async def release(self) -> None: - self._lock.release() - - async def acquire(self) -> None: - await self._lock.acquire() - - -class Semaphore(AsyncSemaphore): - def __init__(self, max_value: int, exc_class: type): - self.max_value = max_value - self.exc_class = exc_class - - @property - def semaphore(self) -> trio.Semaphore: - if not hasattr(self, "_semaphore"): - self._semaphore = trio.Semaphore(self.max_value, max_value=self.max_value) - return self._semaphore - - async def acquire(self, timeout: float = None) -> None: - timeout = none_as_inf(timeout) - - with trio.move_on_after(timeout): - await self.semaphore.acquire() - return - - raise self.exc_class() - - async def release(self) -> None: - self.semaphore.release() - - -class TrioBackend(AsyncBackend): - async def open_tcp_stream( - self, - hostname: bytes, - port: int, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - *, - local_address: Optional[str], - ) -> AsyncSocketStream: - connect_timeout = none_as_inf(timeout.get("connect")) - # Trio will support local_address from 0.16.1 onwards. - # We only include the keyword argument if a local_address - #  argument has been passed. - kwargs: dict = {} if local_address is None else {"local_address": local_address} - exc_map = { - OSError: ConnectError, - trio.TooSlowError: ConnectTimeout, - trio.BrokenResourceError: ConnectError, - } - - with map_exceptions(exc_map): - with trio.fail_after(connect_timeout): - stream: trio.abc.Stream = await trio.open_tcp_stream( - hostname, port, **kwargs - ) - - if ssl_context is not None: - stream = trio.SSLStream( - stream, ssl_context, server_hostname=hostname.decode("ascii") - ) - await stream.do_handshake() - - return SocketStream(stream=stream) - - async def open_uds_stream( - self, - path: str, - hostname: bytes, - ssl_context: Optional[SSLContext], - timeout: TimeoutDict, - ) -> AsyncSocketStream: - connect_timeout = none_as_inf(timeout.get("connect")) - exc_map = { - OSError: ConnectError, - trio.TooSlowError: ConnectTimeout, - trio.BrokenResourceError: ConnectError, - } - - with map_exceptions(exc_map): - with trio.fail_after(connect_timeout): - stream: trio.abc.Stream = await trio.open_unix_socket(path) - - if ssl_context is not None: - stream = trio.SSLStream( - stream, ssl_context, server_hostname=hostname.decode("ascii") - ) - await stream.do_handshake() - - return SocketStream(stream=stream) - - def create_lock(self) -> AsyncLock: - return Lock() - - def create_semaphore(self, max_value: int, exc_class: type) -> AsyncSemaphore: - return Semaphore(max_value, exc_class=exc_class) - - async def time(self) -> float: - return trio.current_time() - - async def sleep(self, seconds: float) -> None: - await trio.sleep(seconds) diff --git a/httpcore/_bytestreams.py b/httpcore/_bytestreams.py deleted file mode 100644 index 317f4110..00000000 --- a/httpcore/_bytestreams.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import AsyncIterator, Callable, Iterator - -from ._async.base import AsyncByteStream -from ._sync.base import SyncByteStream - - -class ByteStream(AsyncByteStream, SyncByteStream): - """ - A concrete implementation for either sync or async byte streams. 
- - Example:: - - stream = httpcore.ByteStream(b"123") - - Parameters - ---------- - content: - A plain byte string used as the content of the stream. - """ - - def __init__(self, content: bytes) -> None: - self._content = content - - def __iter__(self) -> Iterator[bytes]: - yield self._content - - async def __aiter__(self) -> AsyncIterator[bytes]: - yield self._content - - -class IteratorByteStream(SyncByteStream): - """ - A concrete implementation for sync byte streams. - - Example:: - - def generate_content(): - yield b"Hello, world!" - ... - - stream = httpcore.IteratorByteStream(generate_content()) - - Parameters - ---------- - iterator: - A sync byte iterator, used as the content of the stream. - close_func: - An optional function called when closing the stream. - """ - - def __init__(self, iterator: Iterator[bytes], close_func: Callable = None) -> None: - self._iterator = iterator - self._close_func = close_func - - def __iter__(self) -> Iterator[bytes]: - for chunk in self._iterator: - yield chunk - - def close(self) -> None: - if self._close_func is not None: - self._close_func() - - -class AsyncIteratorByteStream(AsyncByteStream): - """ - A concrete implementation for async byte streams. - - Example:: - - async def generate_content(): - yield b"Hello, world!" - ... - - stream = httpcore.AsyncIteratorByteStream(generate_content()) - - Parameters - ---------- - aiterator: - An async byte iterator, used as the content of the stream. - aclose_func: - An optional async function called when closing the stream. - """ - - def __init__( - self, aiterator: AsyncIterator[bytes], aclose_func: Callable = None - ) -> None: - self._aiterator = aiterator - self._aclose_func = aclose_func - - async def __aiter__(self) -> AsyncIterator[bytes]: - async for chunk in self._aiterator: - yield chunk - - async def aclose(self) -> None: - if self._aclose_func is not None: - await self._aclose_func() diff --git a/httpcore/_compat.py b/httpcore/_compat.py new file mode 100644 index 00000000..aa4f5bd3 --- /dev/null +++ b/httpcore/_compat.py @@ -0,0 +1,6 @@ +# `contextlib.asynccontextmanager` exists from Python 3.7 onwards. +# For 3.6 we require the `async_generator` package for a backported version. 
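+#
+# Downstream modules can then import the name in a version-agnostic way; a
+# minimal, illustrative sketch (the helpers here are hypothetical):
+#
+#     from httpcore._compat import asynccontextmanager
+#
+#     @asynccontextmanager
+#     async def managed_resource():
+#         resource = await acquire()   # hypothetical setup step
+#         try:
+#             yield resource
+#         finally:
+#             await release(resource)  # hypothetical teardown step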
+try:
+    from contextlib import asynccontextmanager  # type: ignore
+except ImportError:
+    from async_generator import asynccontextmanager  # type: ignore # noqa
diff --git a/httpcore/_exceptions.py b/httpcore/_exceptions.py
index ba568299..dd4637e3 100644
--- a/httpcore/_exceptions.py
+++ b/httpcore/_exceptions.py
@@ -9,8 +9,16 @@ def map_exceptions(map: Dict[Type[Exception], Type[Exception]]) -> Iterator[None
     except Exception as exc:  # noqa: PIE786
         for from_exc, to_exc in map.items():
             if isinstance(exc, from_exc):
-                raise to_exc(exc) from None
-        raise
+                raise to_exc(exc)
+        raise  # pragma: nocover
+
+
+class ConnectionNotAvailable(Exception):
+    pass
+
+
+class ProxyError(Exception):
+    pass
 
 
 class UnsupportedProtocol(Exception):
@@ -29,10 +37,6 @@ class LocalProtocolError(ProtocolError):
     pass
 
 
-class ProxyError(Exception):
-    pass
-
-
 # Timeout errors
 
 
@@ -73,7 +77,3 @@ class ReadError(NetworkError):
 
 class WriteError(NetworkError):
     pass
-
-
-class CloseError(NetworkError):
-    pass
diff --git a/httpcore/_models.py b/httpcore/_models.py
new file mode 100644
index 00000000..34569e2d
--- /dev/null
+++ b/httpcore/_models.py
@@ -0,0 +1,466 @@
+from typing import (
+    Any,
+    AsyncIterable,
+    AsyncIterator,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
+from urllib.parse import urlparse
+
+# Functions for typechecking...
+
+
+def enforce_bytes(value: Union[bytes, str], *, name: str) -> bytes:
+    """
+    Any arguments that are ultimately represented as bytes can be specified
+    either as bytes or as strings.
+
+    However, we enforce that any string arguments must only contain characters in
+    the plain ASCII range: chr(0)...chr(127). If you need to use characters
+    outside that range then be precise, and use a byte-wise argument.
+    """
+    if isinstance(value, str):
+        try:
+            return value.encode("ascii")
+        except UnicodeEncodeError:
+            raise TypeError(f"{name} strings may not include unicode characters.")
+    elif isinstance(value, bytes):
+        return value
+
+    seen_type = type(value).__name__
+    raise TypeError(f"{name} must be bytes or str, but got {seen_type}.")
+
+
+def enforce_url(value: Union["URL", bytes, str], *, name: str) -> "URL":
+    """
+    Type check for URL parameters.
+    """
+    if isinstance(value, (bytes, str)):
+        return URL(value)
+    elif isinstance(value, URL):
+        return value
+
+    seen_type = type(value).__name__
+    raise TypeError(f"{name} must be a URL, bytes, or str, but got {seen_type}.")
+
+
+def enforce_headers(
+    value: Union[dict, list] = None, *, name: str
+) -> List[Tuple[bytes, bytes]]:
+    """
+    Convenience function that ensures all items in request or response headers
+    are either bytes or strings in the plain ASCII range.
+ """ + if value is None: + return [] + elif isinstance(value, (list, tuple)): + return [ + ( + enforce_bytes(k, name="header name"), + enforce_bytes(v, name="header value"), + ) + for k, v in value + ] + elif isinstance(value, dict): + return [ + ( + enforce_bytes(k, name="header name"), + enforce_bytes(v, name="header value"), + ) + for k, v in value.items() + ] + + seen_type = type(value).__name__ + raise TypeError(f"{name} must be a list, but got {seen_type}.") + + +def enforce_stream( + value: Union[bytes, Iterable[bytes], AsyncIterable[bytes], None], *, name: str +) -> Union[Iterable[bytes], AsyncIterable[bytes]]: + if value is None: + return ByteStream(b"") + elif isinstance(value, bytes): + return ByteStream(value) + return value + + +# * https://tools.ietf.org/html/rfc3986#section-3.2.3 +# * https://url.spec.whatwg.org/#url-miscellaneous +# * https://url.spec.whatwg.org/#scheme-state +DEFAULT_PORTS = { + b"ftp": 21, + b"http": 80, + b"https": 443, + b"ws": 80, + b"wss": 443, +} + + +def include_request_headers( + headers: List[Tuple[bytes, bytes]], + *, + url: "URL", + content: Union[None, bytes, Iterable[bytes], AsyncIterable[bytes]], +) -> List[Tuple[bytes, bytes]]: + headers_set = set([k.lower() for k, v in headers]) + + if b"host" not in headers_set: + default_port = DEFAULT_PORTS.get(url.scheme) + if url.port is None or url.port == default_port: + header_value = url.host + else: + header_value = b"%b:%d" % (url.host, url.port) + headers = [(b"Host", header_value)] + headers + + if ( + content is not None + and b"content-length" not in headers_set + and b"transfer-encoding" not in headers_set + ): + if isinstance(content, bytes): + content_length = str(len(content)).encode("ascii") + headers += [(b"Content-Length", content_length)] + else: + headers += [(b"Transfer-Encoding", b"chunked")] # pragma: nocover + + return headers + + +# Interfaces for byte streams... + + +class ByteStream: + """ + A container for non-streaming content, and that supports both sync and async + stream iteration. + """ + + def __init__(self, content: bytes) -> None: + self._content = content + + def __iter__(self) -> Iterator[bytes]: + yield self._content + + async def __aiter__(self) -> AsyncIterator[bytes]: + yield self._content + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{len(self._content)} bytes]>" + + +class Origin: + def __init__(self, scheme: bytes, host: bytes, port: int) -> None: + self.scheme = scheme + self.host = host + self.port = port + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, Origin) + and self.scheme == other.scheme + and self.host == other.host + and self.port == other.port + ) + + def __str__(self) -> str: + scheme = self.scheme.decode("ascii") + host = self.host.decode("ascii") + port = str(self.port) + return f"{scheme}://{host}:{port}" + + +class URL: + """ + Represents the URL against which an HTTP request may be made. + + The URL may either be specified as a plain string, for convienence: + + ```python + url = httpcore.URL("https://www.example.com/") + ``` + + Or be constructed with explicitily pre-parsed components: + + ```python + url = httpcore.URL(scheme=b'https', host=b'www.example.com', port=None, target=b'/') + ``` + + Using this second more explicit style allows integrations that are using + `httpcore` to pass through URLs that have already been parsed in order to use + libraries such as `rfc-3986` rather than relying on the stdlib. 
It also ensures + that URL parsing is treated identically at both the networking level and at any + higher layers of abstraction. + + The four components are important here, as they allow the URL to be precisely + specified in a pre-parsed format. They also allow certain types of request to + be created that could not otherwise be expressed. + + For example, an HTTP request to `http://www.example.com/` forwarded via a proxy + at `http://localhost:8080`... + + ```python + # Constructs an HTTP request with a complete URL as the target: + # GET https://www.example.com/ HTTP/1.1 + url = httpcore.URL( + scheme=b'http', + host=b'localhost', + port=8080, + target=b'https://www.example.com/' + ) + request = httpcore.Request( + method="GET", + url=url + ) + ``` + + Another example is constructing an `OPTIONS *` request... + + ```python + # Constructs an 'OPTIONS *' HTTP request: + # OPTIONS * HTTP/1.1 + url = httpcore.URL(scheme=b'https', host=b'www.example.com', target=b'*') + request = httpcore.Request(method="OPTIONS", url=url) + ``` + + This kind of request is not possible to formulate with a URL string, + because the `/` delimiter is always used to demark the target from the + host/port portion of the URL. + + For convenience, string-like arguments may be specified either as strings or + as bytes. However, once a request is being issue over-the-wire, the URL + components are always ultimately required to be a bytewise representation. + + In order to avoid any ambiguity over character encodings, when strings are used + as arguments, they must be strictly limited to the ASCII range `chr(0)`-`chr(127)`. + If you require a bytewise representation that is outside this range you must + handle the character encoding directly, and pass a bytes instance. + """ + + def __init__( + self, + url: Union[bytes, str] = "", + *, + scheme: Union[bytes, str] = b"", + host: Union[bytes, str] = b"", + port: Optional[int] = None, + target: Union[bytes, str] = b"", + ) -> None: + """ + Parameters: + url: The complete URL as a string or bytes. + scheme: The URL scheme as a string or bytes. + Typically either `"http"` or `"https"`. + host: The URL host as a string or bytes. Such as `"www.example.com"`. + port: The port to connect to. Either an integer or `None`. + target: The target of the HTTP request. Such as `"/items?search=red"`. + """ + if url: + parsed = urlparse(enforce_bytes(url, name="url")) + self.scheme = parsed.scheme + self.host = parsed.hostname or b"" + self.port = parsed.port + self.target = (parsed.path or b"/") + ( + b"?" 
+ parsed.query if parsed.query else b"" + ) + else: + self.scheme = enforce_bytes(scheme, name="scheme") + self.host = enforce_bytes(host, name="host") + self.port = port + self.target = enforce_bytes(target, name="target") + + @property + def origin(self) -> Origin: + default_port = {b"http": 80, b"https": 443}[self.scheme] + return Origin( + scheme=self.scheme, host=self.host, port=self.port or default_port + ) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, URL) + and other.scheme == self.scheme + and other.host == self.host + and other.port == self.port + and other.target == self.target + ) + + def __bytes__(self) -> bytes: + if self.port is None: + return b"%b://%b%b" % (self.scheme, self.host, self.target) + return b"%b://%b:%d%b" % (self.scheme, self.host, self.port, self.target) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(scheme={self.scheme!r}, " + f"host={self.host!r}, port={self.port!r}, target={self.target!r})" + ) + + +class Request: + """ + An HTTP request. + """ + + def __init__( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: Union[dict, list] = None, + content: Union[bytes, Iterable[bytes], AsyncIterable[bytes]] = None, + extensions: dict = None, + ) -> None: + """ + Parameters: + method: The HTTP request method, either as a string or bytes. + For example: `GET`. + url: The request URL, either as a `URL` instance, or as a string or bytes. + For example: `"https://www.example.com".` + headers: The HTTP request headers. + content: The content of the response body. + extensions: A dictionary of optional extra information included on + the request. Possible keys include `"timeout"`, and `"trace"`. + """ + self.method: bytes = enforce_bytes(method, name="method") + self.url: URL = enforce_url(url, name="url") + self.headers: List[Tuple[bytes, bytes]] = enforce_headers( + headers, name="headers" + ) + self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream( + content, name="content" + ) + self.extensions = {} if extensions is None else extensions + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.method!r}]>" + + +class Response: + """ + An HTTP response. + """ + + def __init__( + self, + status: int, + *, + headers: Union[dict, list] = None, + content: Union[bytes, Iterable[bytes], AsyncIterable[bytes]] = None, + extensions: dict = None, + ) -> None: + """ + Parameters: + status: The HTTP status code of the response. For example `200`. + headers: The HTTP response headers. + content: The content of the response body. + extensions: A dictionary of optional extra information included on + the responseself.Possible keys include `"http_version"`, + `"reason_phrase"`, and `"network_stream"`. + """ + self.status: int = status + self.headers: List[Tuple[bytes, bytes]] = enforce_headers( + headers, name="headers" + ) + self.stream: Union[Iterable[bytes], AsyncIterable[bytes]] = enforce_stream( + content, name="content" + ) + self.extensions: dict = {} if extensions is None else extensions + + self._stream_consumed = False + + @property + def content(self) -> bytes: + if not hasattr(self, "_content"): + if isinstance(self.stream, Iterable): + raise RuntimeError( + "Attempted to access 'response.content' on a streaming response. " + "Call 'response.read()' first." + ) + else: + raise RuntimeError( + "Attempted to access 'response.content' on a streaming response. " + "Call 'await response.aread()' first." 
+                )
+        return self._content
+
+    def __repr__(self) -> str:
+        return f"<{self.__class__.__name__} [{self.status}]>"
+
+    # Sync interface...
+
+    def read(self) -> bytes:
+        if not isinstance(self.stream, Iterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to read an asynchronous response using 'response.read()'. "
+                "You should use 'await response.aread()' instead."
+            )
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part for part in self.iter_stream()])
+        return self._content
+
+    def iter_stream(self) -> Iterator[bytes]:
+        if not isinstance(self.stream, Iterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to stream an asynchronous response using 'for ... in "
+                "response.iter_stream()'. "
+                "You should use 'async for ... in response.aiter_stream()' instead."
+            )
+        if self._stream_consumed:
+            raise RuntimeError(
+                "Attempted to call 'for ... in response.iter_stream()' more than once."
+            )
+        self._stream_consumed = True
+        for chunk in self.stream:
+            yield chunk
+
+    def close(self) -> None:
+        if not isinstance(self.stream, Iterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to close an asynchronous response using 'response.close()'. "
+                "You should use 'await response.aclose()' instead."
+            )
+        if hasattr(self.stream, "close"):
+            self.stream.close()  # type: ignore
+
+    # Async interface...
+
+    async def aread(self) -> bytes:
+        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to read a synchronous response using "
+                "'await response.aread()'. "
+                "You should use 'response.read()' instead."
+            )
+        if not hasattr(self, "_content"):
+            self._content = b"".join([part async for part in self.aiter_stream()])
+        return self._content
+
+    async def aiter_stream(self) -> AsyncIterator[bytes]:
+        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to stream a synchronous response using 'async for ... in "
+                "response.aiter_stream()'. "
+                "You should use 'for ... in response.iter_stream()' instead."
+            )
+        if self._stream_consumed:
+            raise RuntimeError(
+                "Attempted to call 'async for ... in response.aiter_stream()' "
+                "more than once."
+            )
+        self._stream_consumed = True
+        async for chunk in self.stream:
+            yield chunk
+
+    async def aclose(self) -> None:
+        if not isinstance(self.stream, AsyncIterable):  # pragma: nocover
+            raise RuntimeError(
+                "Attempted to close a synchronous response using "
+                "'await response.aclose()'. "
+                "You should use 'response.close()' instead."
+ ) + if hasattr(self.stream, "aclose"): + await self.stream.aclose() # type: ignore diff --git a/httpcore/_ssl.py b/httpcore/_ssl.py new file mode 100644 index 00000000..c99c5a67 --- /dev/null +++ b/httpcore/_ssl.py @@ -0,0 +1,9 @@ +import ssl + +import certifi + + +def default_ssl_context() -> ssl.SSLContext: + context = ssl.create_default_context() + context.load_verify_locations(certifi.where()) + return context diff --git a/httpcore/_sync/__init__.py b/httpcore/_sync/__init__.py index e69de29b..6469b1a9 100644 --- a/httpcore/_sync/__init__.py +++ b/httpcore/_sync/__init__.py @@ -0,0 +1,20 @@ +from .connection import HTTPConnection +from .connection_pool import ConnectionPool +from .http11 import HTTP11Connection +from .http_proxy import HTTPProxy +from .interfaces import ConnectionInterface + +try: + from .http2 import HTTP2Connection +except ImportError: # pragma: nocover + pass + + +__all__ = [ + "HTTPConnection", + "ConnectionPool", + "HTTPProxy", + "HTTP11Connection", + "HTTP2Connection", + "ConnectionInterface", +] diff --git a/httpcore/_sync/base.py b/httpcore/_sync/base.py deleted file mode 100644 index 45ef4abf..00000000 --- a/httpcore/_sync/base.py +++ /dev/null @@ -1,122 +0,0 @@ -import enum -from types import TracebackType -from typing import Iterator, Tuple, Type - -from .._types import URL, Headers, T - - -class NewConnectionRequired(Exception): - pass - - -class ConnectionState(enum.IntEnum): - """ - PENDING READY - | | ^ - v V | - ACTIVE | - | | | - | V | - V IDLE-+ - FULL | - | | - V V - CLOSED - """ - - PENDING = 0 # Connection not yet acquired. - READY = 1 # Re-acquired from pool, about to send a request. - ACTIVE = 2 # Active requests. - FULL = 3 # Active requests, no more stream IDs available. - IDLE = 4 # No active requests. - CLOSED = 5 # Connection closed. - - -class SyncByteStream: - """ - The base interface for request and response bodies. - - Concrete implementations should subclass this class, and implement - the :meth:`__iter__` method, and optionally the :meth:`close` method. - """ - - def __iter__(self) -> Iterator[bytes]: - """ - Yield bytes representing the request or response body. - """ - yield b"" # pragma: nocover - - def close(self) -> None: - """ - Must be called by the client to indicate that the stream has been closed. - """ - pass # pragma: nocover - - def read(self) -> bytes: - try: - return b"".join([part for part in self]) - finally: - self.close() - - -class SyncHTTPTransport: - """ - The base interface for sending HTTP requests. - - Concrete implementations should subclass this class, and implement - the :meth:`handle_request` method, and optionally the :meth:`close` method. - """ - - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - """ - The interface for sending a single HTTP request, and returning a response. - - Parameters - ---------- - method: - The HTTP method, such as ``b'GET'``. - url: - The URL as a 4-tuple of (scheme, host, port, path). - headers: - Any HTTP headers to send with the request. - stream: - The body of the HTTP request. - extensions: - A dictionary of optional extensions. - - Returns - ------- - status_code: - The HTTP status code, such as ``200``. - headers: - Any HTTP headers included on the response. - stream: - The body of the HTTP response. - extensions: - A dictionary of optional extensions. 
- """ - raise NotImplementedError() # pragma: nocover - - def close(self) -> None: - """ - Close the implementation, which should close any outstanding response streams, - and any keep alive connections. - """ - - def __enter__(self: T) -> T: - return self - - def __exit__( - self, - exc_type: Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - self.close() diff --git a/httpcore/_sync/connection.py b/httpcore/_sync/connection.py index 382a4f9f..bd62885e 100644 --- a/httpcore/_sync/connection.py +++ b/httpcore/_sync/connection.py @@ -1,158 +1,96 @@ -from ssl import SSLContext -from typing import List, Optional, Tuple, cast +import itertools +import ssl +from types import TracebackType +from typing import Iterator, Optional, Type + +from .._exceptions import ConnectError, ConnectionNotAvailable, ConnectTimeout +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import Lock +from .._trace import Trace +from ..backends.sync import SyncBackend +from ..backends.base import NetworkBackend, NetworkStream +from .http11 import HTTP11Connection +from .interfaces import ConnectionInterface -from .._backends.sync import SyncBackend, SyncLock, SyncSocketStream, SyncBackend -from .._exceptions import ConnectError, ConnectTimeout -from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import exponential_backoff, get_logger, url_to_origin -from .base import SyncByteStream, SyncHTTPTransport, NewConnectionRequired -from .http import SyncBaseHTTPConnection -from .http11 import SyncHTTP11Connection +RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. -logger = get_logger(__name__) -RETRIES_BACKOFF_FACTOR = 0.5 # 0s, 0.5s, 1s, 2s, 4s, etc. 
+def exponential_backoff(factor: float) -> Iterator[float]:
+    yield 0
+    for n in itertools.count(2):
+        yield factor * (2 ** (n - 2))
 
-class SyncHTTPConnection(SyncHTTPTransport):
+class HTTPConnection(ConnectionInterface):
     def __init__(
         self,
         origin: Origin,
+        ssl_context: ssl.SSLContext = None,
+        keepalive_expiry: float = None,
         http1: bool = True,
         http2: bool = False,
-        keepalive_expiry: float = None,
-        uds: str = None,
-        ssl_context: SSLContext = None,
-        socket: SyncSocketStream = None,
-        local_address: str = None,
         retries: int = 0,
-        backend: SyncBackend = None,
-    ):
-        self.origin = origin
-        self._http1_enabled = http1
-        self._http2_enabled = http2
+        local_address: str = None,
+        uds: str = None,
+        network_backend: NetworkBackend = None,
+    ) -> None:
+        ssl_context = default_ssl_context() if ssl_context is None else ssl_context
+        alpn_protocols = ["http/1.1", "h2"] if http2 else ["http/1.1"]
+        ssl_context.set_alpn_protocols(alpn_protocols)
+
+        self._origin = origin
+        self._ssl_context = ssl_context
         self._keepalive_expiry = keepalive_expiry
-        self._uds = uds
-        self._ssl_context = SSLContext() if ssl_context is None else ssl_context
-        self.socket = socket
-        self._local_address = local_address
+        self._http1 = http1
+        self._http2 = http2
         self._retries = retries
+        self._local_address = local_address
+        self._uds = uds
 
-        alpn_protocols: List[str] = []
-        if http1:
-            alpn_protocols.append("http/1.1")
-        if http2:
-            alpn_protocols.append("h2")
-
-        self._ssl_context.set_alpn_protocols(alpn_protocols)
-
-        self.connection: Optional[SyncBaseHTTPConnection] = None
-        self._is_http11 = False
-        self._is_http2 = False
-        self._connect_failed = False
-        self._expires_at: Optional[float] = None
-        self._backend = SyncBackend() if backend is None else backend
-
-    def __repr__(self) -> str:
-        return f"<SyncHTTPConnection [{self.info()}]>"
-
-    def info(self) -> str:
-        if self.connection is None:
-            return "Connection failed" if self._connect_failed else "Connecting"
-        return self.connection.info()
-
-    def should_close(self) -> bool:
-        """
-        Return `True` if the connection is in a state where it should be closed.
-        This occurs when any of the following occur:
-
-        * There are no active requests on an HTTP/1.1 connection, and the underlying
-          socket is readable. The only valid state the socket can be readable in
-          if this occurs is when the b"" EOF marker is about to be returned,
-          indicating a server disconnect.
-        * There are no active requests being made and the keepalive timeout has passed.
-        """
-        if self.connection is None:
-            return False
-        return self.connection.should_close()
-
-    def is_idle(self) -> bool:
-        """
-        Return `True` if the connection is currently idle.
-        """
-        if self.connection is None:
-            return False
-        return self.connection.is_idle()
+        self._network_backend: NetworkBackend = (
+            SyncBackend() if network_backend is None else network_backend
+        )
+        self._connection: Optional[ConnectionInterface] = None
+        self._request_lock = Lock()
 
-    def is_closed(self) -> bool:
-        if self.connection is None:
-            return self._connect_failed
-        return self.connection.is_closed()
+    def handle_request(self, request: Request) -> Response:
+        if not self.can_handle_request(request.url.origin):
+            raise RuntimeError(
+                f"Attempted to send request to {request.url.origin} on connection to {self._origin}"
+            )
 
-    def is_available(self) -> bool:
-        """
-        Return `True` if the connection is currently able to accept an outgoing request.
-        This occurs when any of the following occur:
-
-        * The connection has not yet been opened, and HTTP/2 support is enabled.
- We don't *know* at this point if we'll end up on an HTTP/2 connection or - not, but we *might* do, so we indicate availability. - * The connection has been opened, and is currently idle. - * The connection is open, and is an HTTP/2 connection. The connection must - also not currently be exceeding the maximum number of allowable concurrent - streams and must not have exhausted the maximum total number of stream IDs. - """ - if self.connection is None: - return self._http2_enabled and not self.is_closed - return self.connection.is_available() - - @property - def request_lock(self) -> SyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_request_lock"): - self._request_lock = self._backend.create_lock() - return self._request_lock - - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - assert url_to_origin(url) == self.origin - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - with self.request_lock: - if self.connection is None: - if self._connect_failed: - raise NewConnectionRequired() - if not self.socket: - logger.trace( - "open_socket origin=%r timeout=%r", self.origin, timeout + with self._request_lock: + if self._connection is None: + stream = self._connect(request) + + ssl_object = stream.get_extra_info("ssl_object") + http2_negotiated = ( + ssl_object is not None + and ssl_object.selected_alpn_protocol() == "h2" + ) + if http2_negotiated or (self._http2 and not self._http1): + from .http2 import HTTP2Connection + + self._connection = HTTP2Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, ) - self.socket = self._open_socket(timeout) - self._create_connection(self.socket) - elif not self.connection.is_available(): - raise NewConnectionRequired() - - assert self.connection is not None - logger.trace( - "connection.handle_request method=%r url=%r headers=%r", - method, - url, - headers, - ) - return self.connection.handle_request( - method, url, headers, stream, extensions - ) + else: + self._connection = HTTP11Connection( + origin=self._origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + elif not self._connection.is_available(): + raise ConnectionNotAvailable() + + return self._connection.handle_request(request) - def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream: - scheme, hostname, port = self.origin - timeout = {} if timeout is None else timeout - ssl_context = self._ssl_context if scheme == b"https" else None + def _connect(self, request: Request) -> NetworkStream: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) retries_left = self._retries delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR) @@ -160,61 +98,98 @@ def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream: while True: try: if self._uds is None: - return self._backend.open_tcp_stream( - hostname, - port, - ssl_context, - timeout, - local_address=self._local_address, - ) + kwargs = { + "host": self._origin.host.decode("ascii"), + "port": self._origin.port, + "local_address": self._local_address, + "timeout": timeout, + } + with Trace( + "connection.connect_tcp", request, kwargs + ) as trace: + stream = self._network_backend.connect_tcp(**kwargs) + trace.return_value = stream else: - return self._backend.open_uds_stream( - self._uds, hostname, 
ssl_context, timeout - ) + kwargs = { + "path": self._uds, + "timeout": timeout, + } + with Trace( + "connection.connect_unix_socket", request, kwargs + ) as trace: + stream = self._network_backend.connect_unix_socket( + **kwargs + ) + trace.return_value = stream except (ConnectError, ConnectTimeout): if retries_left <= 0: - self._connect_failed = True raise retries_left -= 1 delay = next(delays) - self._backend.sleep(delay) - except Exception: # noqa: PIE786 - self._connect_failed = True - raise - - def _create_connection(self, socket: SyncSocketStream) -> None: - http_version = socket.get_http_version() - logger.trace( - "create_connection socket=%r http_version=%r", socket, http_version - ) - if http_version == "HTTP/2" or ( - self._http2_enabled and not self._http1_enabled - ): - from .http2 import SyncHTTP2Connection - - self._is_http2 = True - self.connection = SyncHTTP2Connection( - socket=socket, - keepalive_expiry=self._keepalive_expiry, - backend=self._backend, - ) - else: - self._is_http11 = True - self.connection = SyncHTTP11Connection( - socket=socket, keepalive_expiry=self._keepalive_expiry - ) - - def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> None: - if self.connection is not None: - logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout) - self.socket = self.connection.start_tls( - hostname, ssl_context, timeout - ) - logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout) + # TRACE 'retry' + self._network_backend.sleep(delay) + else: + break + + if self._origin.scheme == b"https": + kwargs = { + "ssl_context": self._ssl_context, + "server_hostname": self._origin.host.decode("ascii"), + "timeout": timeout, + } + with Trace("connection.start_tls", request, kwargs) as trace: + stream = stream.start_tls(**kwargs) + trace.return_value = stream + return stream + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin def close(self) -> None: - with self.request_lock: - if self.connection is not None: - self.connection.close() + if self._connection is not None: + self._connection.close() + + def is_available(self) -> bool: + if self._connection is None: + # If HTTP/2 support is enabled, and the resulting connection could + # end up as HTTP/2 then we should indicate the connection as being + # available to service multiple requests. + return self._http2 and (self._origin.scheme == b"https" or not self._http1) + return self._connection.is_available() + + def has_expired(self) -> bool: + if self._connection is None: + return False + return self._connection.has_expired() + + def is_idle(self) -> bool: + if self._connection is None: + return False + return self._connection.is_idle() + + def is_closed(self) -> bool: + if self._connection is None: + return False + return self._connection.is_closed() + + def info(self) -> str: + if self._connection is None: + return "CONNECTING" + return self._connection.info() + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. 
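+    #
+    # A minimal, illustrative sketch (assumes an `httpcore.Request` instance
+    # named `request`, with a URL on the same origin as the connection):
+    #
+    #     origin = Origin(scheme=b"https", host=b"example.com", port=443)
+    #     with HTTPConnection(origin=origin) as connection:
+    #         response = connection.handle_request(request)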
+ + def __enter__(self) -> "HTTPConnection": + return self + + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.close() diff --git a/httpcore/_sync/connection_pool.py b/httpcore/_sync/connection_pool.py index 0bd759db..11280ac1 100644 --- a/httpcore/_sync/connection_pool.py +++ b/httpcore/_sync/connection_pool.py @@ -1,362 +1,335 @@ -import warnings -from ssl import SSLContext -from typing import ( - Iterator, - Callable, - Dict, - List, - Optional, - Set, - Tuple, - Union, - cast, -) - -from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore -from .._backends.base import lookup_sync_backend -from .._exceptions import LocalProtocolError, PoolTimeout, UnsupportedProtocol -from .._threadlock import ThreadLock -from .._types import URL, Headers, Origin, TimeoutDict -from .._utils import get_logger, origin_to_url_string, url_to_origin -from .base import SyncByteStream, SyncHTTPTransport, NewConnectionRequired -from .connection import SyncHTTPConnection - -logger = get_logger(__name__) - - -class NullSemaphore(SyncSemaphore): - def __init__(self) -> None: - pass - - def acquire(self, timeout: float = None) -> None: - return - - def release(self) -> None: - return - - -class ResponseByteStream(SyncByteStream): - def __init__( - self, - stream: SyncByteStream, - connection: SyncHTTPConnection, - callback: Callable, - ) -> None: - """ - A wrapper around the response stream that we return from - `.handle_request()`. - - Ensures that when `stream.close()` is called, the connection pool - is notified via a callback. - """ - self.stream = stream +import ssl +from types import TracebackType +from typing import Iterable, Iterator, List, Optional, Type + +from .._exceptions import ConnectionNotAvailable, UnsupportedProtocol +from .._models import Origin, Request, Response +from .._ssl import default_ssl_context +from .._synchronization import Event, Lock +from ..backends.sync import SyncBackend +from ..backends.base import NetworkBackend +from .connection import HTTPConnection +from .interfaces import ConnectionInterface, RequestInterface + + +class RequestStatus: + def __init__(self, request: Request): + self.request = request + self.connection: Optional[ConnectionInterface] = None + self._connection_acquired = Event() + + def set_connection(self, connection: ConnectionInterface) -> None: + assert self.connection is None self.connection = connection - self.callback = callback + self._connection_acquired.set() - def __iter__(self) -> Iterator[bytes]: - for chunk in self.stream: - yield chunk + def unset_connection(self) -> None: + assert self.connection is not None + self.connection = None + self._connection_acquired = Event() - def close(self) -> None: - try: - # Call the underlying stream close callback. - # This will be a call to `SyncHTTP11Connection._response_closed()` - # or `SyncHTTP2Stream._response_closed()`. - self.stream.close() - finally: - # Call the connection pool close callback. - # This will be a call to `SyncConnectionPool._response_closed()`. - self.callback(self.connection) + def wait_for_connection( + self, timeout: float = None + ) -> ConnectionInterface: + self._connection_acquired.wait(timeout=timeout) + assert self.connection is not None + return self.connection -class SyncConnectionPool(SyncHTTPTransport): +class ConnectionPool(RequestInterface): """ A connection pool for making HTTP requests. 
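+
+    A minimal, illustrative sketch of typical usage (`.request()` and
+    `.stream()` are provided by the `RequestInterface` base class):
+
+    ```python
+    pool = httpcore.ConnectionPool()
+    with pool.stream("GET", "https://example.com/") as response:
+        for chunk in response.iter_stream():
+            ...  # handle each chunk of the response body
+    ```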
- - Parameters - ---------- - ssl_context: - An SSL context to use for verifying connections. - max_connections: - The maximum number of concurrent connections to allow. - max_keepalive_connections: - The maximum number of connections to allow before closing keep-alive - connections. - keepalive_expiry: - The maximum time to allow before closing a keep-alive connection. - http1: - Enable/Disable HTTP/1.1 support. Defaults to True. - http2: - Enable/Disable HTTP/2 support. Defaults to False. - uds: - Path to a Unix Domain Socket to use instead of TCP sockets. - local_address: - Local address to connect from. Can also be used to connect using a particular - address family. Using ``local_address="0.0.0.0"`` will connect using an - ``AF_INET`` address (IPv4), while using ``local_address="::"`` will connect - using an ``AF_INET6`` address (IPv6). - retries: - The maximum number of retries when trying to establish a connection. - backend: - A name indicating which concurrency backend to use. """ def __init__( self, - ssl_context: SSLContext = None, - max_connections: int = None, + ssl_context: ssl.SSLContext = None, + max_connections: int = 10, max_keepalive_connections: int = None, keepalive_expiry: float = None, http1: bool = True, http2: bool = False, - uds: str = None, - local_address: str = None, retries: int = 0, - max_keepalive: int = None, - backend: Union[SyncBackend, str] = "sync", - ): - if max_keepalive is not None: - warnings.warn( - "'max_keepalive' is deprecated. Use 'max_keepalive_connections'.", - DeprecationWarning, - ) - max_keepalive_connections = max_keepalive + local_address: str = None, + uds: str = None, + network_backend: NetworkBackend = None, + ) -> None: + """ + A connection pool for making HTTP requests. + + Parameters: + ssl_context: An SSL context to use for verifying connections. + If not specified, the default `httpcore.default_ssl_context()` + will be used. + max_connections: The maximum number of concurrent HTTP connections that + the pool should allow. Any attempt to send a request on a pool that + would exceed this amount will block until a connection is available. + max_keepalive_connections: The maximum number of idle HTTP connections + that will be maintained in the pool. + keepalive_expiry: The duration in seconds that an idle HTTP connection + may be maintained for before being expired from the pool. + http1: A boolean indicating if HTTP/1.1 requests should be supported + by the connection pool. Defaults to True. + http2: A boolean indicating if HTTP/2 requests should be supported by + the connection pool. Defaults to False. + retries: The maximum number of retries when trying to establish a + connection. + local_address: Local address to connect from. Can also be used to connect + using a particular address family. Using `local_address="0.0.0.0"` + will connect using an `AF_INET` address (IPv4), while using + `local_address="::"` will connect using an `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. 
+ """ + if max_keepalive_connections is None: + max_keepalive_connections = max_connections - if isinstance(backend, str): - backend = lookup_sync_backend(backend) + if ssl_context is None: + ssl_context = default_ssl_context() + + self._ssl_context = ssl_context - self._ssl_context = SSLContext() if ssl_context is None else ssl_context self._max_connections = max_connections - self._max_keepalive_connections = max_keepalive_connections + self._max_keepalive_connections = min( + max_keepalive_connections, max_connections + ) + self._keepalive_expiry = keepalive_expiry self._http1 = http1 self._http2 = http2 - self._uds = uds - self._local_address = local_address self._retries = retries - self._connections: Dict[Origin, Set[SyncHTTPConnection]] = {} - self._thread_lock = ThreadLock() - self._backend = backend - self._next_keepalive_check = 0.0 - - if not (http1 or http2): - raise ValueError("Either http1 or http2 must be True.") - - if http2: - try: - import h2 # noqa: F401 - except ImportError: - raise ImportError( - "Attempted to use http2=True, but the 'h2' " - "package is not installed. Use 'pip install httpcore[http2]'." - ) - - @property - def _connection_semaphore(self) -> SyncSemaphore: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_internal_semaphore"): - if self._max_connections is not None: - self._internal_semaphore = self._backend.create_semaphore( - self._max_connections, exc_class=PoolTimeout - ) - else: - self._internal_semaphore = NullSemaphore() - - return self._internal_semaphore + self._local_address = local_address + self._uds = uds - @property - def _connection_acquiry_lock(self) -> SyncLock: - if not hasattr(self, "_internal_connection_acquiry_lock"): - self._internal_connection_acquiry_lock = self._backend.create_lock() - return self._internal_connection_acquiry_lock + self._pool: List[ConnectionInterface] = [] + self._requests: List[RequestStatus] = [] + self._pool_lock = Lock() + self._network_backend = ( + SyncBackend() if network_backend is None else network_backend + ) - def _create_connection( - self, - origin: Tuple[bytes, bytes, int], - ) -> SyncHTTPConnection: - return SyncHTTPConnection( + def create_connection(self, origin: Origin) -> ConnectionInterface: + return HTTPConnection( origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, http1=self._http1, http2=self._http2, - keepalive_expiry=self._keepalive_expiry, - uds=self._uds, - ssl_context=self._ssl_context, - local_address=self._local_address, retries=self._retries, - backend=self._backend, + local_address=self._local_address, + uds=self._uds, + network_backend=self._network_backend, ) - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - if not url[0]: + @property + def connections(self) -> List[ConnectionInterface]: + """ + Return a list of the connections currently in the pool. + + For example: + + ```python + >>> pool.connections + [ + , + , + , + ] + ``` + """ + return list(self._pool) + + def _attempt_to_acquire_connection(self, status: RequestStatus) -> bool: + """ + Attempt to provide a connection that can handle the given origin. + """ + origin = status.request.url.origin + + # If there are queued requests in front of us, then don't acquire a + # connection. We handle requests strictly in order. 
+        waiting = [s for s in self._requests if s.connection is None]
+        if waiting and waiting[0] is not status:
+            return False
+
+        # Reuse an existing connection if one is currently available.
+        for idx, connection in enumerate(self._pool):
+            if connection.can_handle_request(origin) and connection.is_available():
+                self._pool.pop(idx)
+                self._pool.insert(0, connection)
+                status.set_connection(connection)
+                return True
+
+        # If the pool is currently full, attempt to close one idle connection.
+        if len(self._pool) >= self._max_connections:
+            for idx, connection in reversed(list(enumerate(self._pool))):
+                if connection.is_idle():
+                    connection.close()
+                    self._pool.pop(idx)
+                    break
+
+        # If the pool is still full, then we cannot acquire a connection.
+        if len(self._pool) >= self._max_connections:
+            return False
+
+        # Otherwise create a new connection.
+        connection = self.create_connection(origin)
+        self._pool.insert(0, connection)
+        status.set_connection(connection)
+        return True
+
+    def _close_expired_connections(self) -> None:
+        """
+        Clean up the connection pool by closing off any connections that have expired.
+        """
+        # Close any connections that have expired their keep-alive time.
+        for idx, connection in reversed(list(enumerate(self._pool))):
+            if connection.has_expired():
+                connection.close()
+                self._pool.pop(idx)
+
+        # If the pool size exceeds the maximum number of allowed keep-alive connections,
+        # then close off idle connections as required.
+        pool_size = len(self._pool)
+        for idx, connection in reversed(list(enumerate(self._pool))):
+            if connection.is_idle() and pool_size > self._max_keepalive_connections:
+                connection.close()
+                self._pool.pop(idx)
+                pool_size -= 1
+
+    def handle_request(self, request: Request) -> Response:
+        """
+        Send an HTTP request, and return an HTTP response.
+
+        This is the core implementation that is called into by `.request()` or `.stream()`.
+        """
+        scheme = request.url.scheme.decode()
+        if scheme == "":
             raise UnsupportedProtocol(
-                "Request URL missing either an 'http://' or 'https://' protocol."
+                "Request URL is missing an 'http://' or 'https://' protocol."
             )
-
-        if url[0] not in (b"http", b"https"):
-            protocol = url[0].decode("ascii")
+        if scheme not in ("http", "https"):
             raise UnsupportedProtocol(
-                f"Request URL has an unsupported protocol '{protocol}://'."
+                f"Request URL has an unsupported protocol '{scheme}://'."
             )
 
-        if not url[1]:
-            raise LocalProtocolError("Missing hostname in URL.")
+        status = RequestStatus(request)
 
-        origin = url_to_origin(url)
-        timeout = cast(TimeoutDict, extensions.get("timeout", {}))
+        with self._pool_lock:
+            self._requests.append(status)
+            self._close_expired_connections()
+            self._attempt_to_acquire_connection(status)
 
-        self._keepalive_sweep()
+        while True:
+            timeouts = request.extensions.get("timeout", {})
+            timeout = timeouts.get("pool", None)
+            connection = status.wait_for_connection(timeout=timeout)
+            try:
+                response = connection.handle_request(request)
+            except ConnectionNotAvailable:
+                # The ConnectionNotAvailable exception is a special case that
+                # indicates we need to retry the request on a new connection.
+                #
+                # The most common case where this can occur is when multiple
+                # requests are queued waiting for a single connection, which
+                # might end up as an HTTP/2 connection, but which actually ends
+                # up as HTTP/1.1.
+                with self._pool_lock:
+                    # Maintain our position in the request queue, but reset the
+                    # status so that the request becomes queued again.
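+                    #
+                    # For example: several requests queued onto a connection in
+                    # the expectation that it would negotiate HTTP/2 will, if
+                    # it settles on HTTP/1.1, see all but the first request
+                    # raise ConnectionNotAvailable and be re-queued here onto
+                    # another connection.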
+                    status.unset_connection()
+                    self._attempt_to_acquire_connection(status)
+            except Exception as exc:
+                self.response_closed(status)
+                raise exc
+            else:
+                break
+
+        # When we return the response, we wrap the stream in a special class
+        # that handles notifying the connection pool once the response
+        # has been released.
+        assert isinstance(response.stream, Iterable)
+        return Response(
+            status=response.status,
+            headers=response.headers,
+            content=ConnectionPoolByteStream(response.stream, self, status),
+            extensions=response.extensions,
+        )
 
-        connection: Optional[SyncHTTPConnection] = None
-        while connection is None:
-            with self._connection_acquiry_lock:
-                # We get-or-create a connection as an atomic operation, to ensure
-                # that HTTP/2 requests issued in close concurrency will end up
-                # on the same connection.
-                logger.trace("get_connection_from_pool=%r", origin)
-                connection = self._get_connection_from_pool(origin)
+    def response_closed(self, status: RequestStatus) -> None:
+        """
+        This method acts as a callback once the request/response cycle is complete.
 
-            if connection is None:
-                connection = self._create_connection(origin=origin)
-                logger.trace("created connection=%r", connection)
-                self._add_to_pool(connection, timeout=timeout)
-            else:
-                logger.trace("reuse connection=%r", connection)
+        It is called into from the `ConnectionPoolByteStream.close()` method.
+        """
+        assert status.connection is not None
+        connection = status.connection
+
+        with self._pool_lock:
+            # Update the state of the connection pool.
+            self._requests.remove(status)
+
+            if connection.is_closed():
+                self._pool.remove(connection)
+
+            # Since we've had a response closed, it's possible we'll now be able
+            # to service one or more requests that are currently pending.
+            for status in self._requests:
+                if status.connection is None:
+                    acquired = self._attempt_to_acquire_connection(status)
+                    # If we could not acquire a connection for a queued request
+                    # then we don't need to check any more requests that are
+                    # queued later behind it.
+                    if not acquired:
+                        break
+
+            # Housekeeping.
+            self._close_expired_connections()
 
-        try:
-            response = connection.handle_request(
-                method, url, headers=headers, stream=stream, extensions=extensions
-            )
-        except NewConnectionRequired:
-            connection = None
-        except BaseException:  # noqa: PIE786
-            # See https://github.com/encode/httpcore/pull/305 for motivation
-            # behind catching 'BaseException' rather than 'Exception' here.
-            logger.trace("remove from pool connection=%r", connection)
-            self._remove_from_pool(connection)
-            raise
-
-        status_code, headers, stream, extensions = response
-        wrapped_stream = ResponseByteStream(
-            stream, connection=connection, callback=self._response_closed
-        )
-        return status_code, headers, wrapped_stream, extensions
-
-    def _get_connection_from_pool(
-        self, origin: Origin
-    ) -> Optional[SyncHTTPConnection]:
-        # Determine expired keep alive connections on this origin.
-        reuse_connection = None
-        connections_to_close = set()
-
-        for connection in self._connections_for_origin(origin):
-            if connection.should_close():
-                connections_to_close.add(connection)
-                self._remove_from_pool(connection)
-            elif connection.is_available():
-                reuse_connection = connection
-
-        # Close any dropped connections.
- for connection in connections_to_close: - connection.close() - - return reuse_connection - - def _response_closed(self, connection: SyncHTTPConnection) -> None: - remove_from_pool = False - close_connection = False - - if connection.is_closed(): - remove_from_pool = True - elif connection.is_idle(): - num_connections = len(self._get_all_connections()) - if ( - self._max_keepalive_connections is not None - and num_connections > self._max_keepalive_connections - ): - remove_from_pool = True - close_connection = True - - if remove_from_pool: - self._remove_from_pool(connection) - - if close_connection: - connection.close() - - def _keepalive_sweep(self) -> None: + def close(self) -> None: """ - Remove any IDLE connections that have expired past their keep-alive time. + Close any connections in the pool. """ - if self._keepalive_expiry is None: - return + with self._pool_lock: + for connection in self._pool: + connection.close() + self._pool = [] + self._requests = [] - now = self._backend.time() - if now < self._next_keepalive_check: - return + def __enter__(self) -> "ConnectionPool": + return self - self._next_keepalive_check = now + min(1.0, self._keepalive_expiry) - connections_to_close = set() + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.close() - for connection in self._get_all_connections(): - if connection.should_close(): - connections_to_close.add(connection) - self._remove_from_pool(connection) - for connection in connections_to_close: - connection.close() +class ConnectionPoolByteStream: + """ + A wrapper around the response byte stream, that additionally handles + notifying the connection pool when the response has been closed. + """ - def _add_to_pool( - self, connection: SyncHTTPConnection, timeout: TimeoutDict + def __init__( + self, + stream: Iterable[bytes], + pool: ConnectionPool, + status: RequestStatus, ) -> None: - logger.trace("adding connection to pool=%r", connection) - self._connection_semaphore.acquire(timeout=timeout.get("pool", None)) - with self._thread_lock: - self._connections.setdefault(connection.origin, set()) - self._connections[connection.origin].add(connection) - - def _remove_from_pool(self, connection: SyncHTTPConnection) -> None: - logger.trace("removing connection from pool=%r", connection) - with self._thread_lock: - if connection in self._connections.get(connection.origin, set()): - self._connection_semaphore.release() - self._connections[connection.origin].remove(connection) - if not self._connections[connection.origin]: - del self._connections[connection.origin] - - def _connections_for_origin(self, origin: Origin) -> Set[SyncHTTPConnection]: - return set(self._connections.get(origin, set())) - - def _get_all_connections(self) -> Set[SyncHTTPConnection]: - connections: Set[SyncHTTPConnection] = set() - for connection_set in self._connections.values(): - connections |= connection_set - return connections + self._stream = stream + self._pool = pool + self._status = status - def close(self) -> None: - connections = self._get_all_connections() - for connection in connections: - self._remove_from_pool(connection) - - # Close all connections - for connection in connections: - connection.close() - - def get_connection_info(self) -> Dict[str, List[str]]: - """ - Returns a dict of origin URLs to a list of summary strings for each connection. 
- """ - self._keepalive_sweep() + def __iter__(self) -> Iterator[bytes]: + for part in self._stream: + yield part - stats = {} - for origin, connections in self._connections.items(): - stats[origin_to_url_string(origin)] = sorted( - [connection.info() for connection in connections] - ) - return stats + def close(self) -> None: + try: + if hasattr(self._stream, "close"): + self._stream.close() # type: ignore + finally: + self._pool.response_closed(self._status) diff --git a/httpcore/_sync/http.py b/httpcore/_sync/http.py deleted file mode 100644 index c128a96b..00000000 --- a/httpcore/_sync/http.py +++ /dev/null @@ -1,42 +0,0 @@ -from ssl import SSLContext - -from .._backends.sync import SyncSocketStream -from .._types import TimeoutDict -from .base import SyncHTTPTransport - - -class SyncBaseHTTPConnection(SyncHTTPTransport): - def info(self) -> str: - raise NotImplementedError() # pragma: nocover - - def should_close(self) -> bool: - """ - Return `True` if the connection is in a state where it should be closed. - """ - raise NotImplementedError() # pragma: nocover - - def is_idle(self) -> bool: - """ - Return `True` if the connection is currently idle. - """ - raise NotImplementedError() # pragma: nocover - - def is_closed(self) -> bool: - """ - Return `True` if the connection has been closed. - """ - raise NotImplementedError() # pragma: nocover - - def is_available(self) -> bool: - """ - Return `True` if the connection is currently able to accept an outgoing request. - """ - raise NotImplementedError() # pragma: nocover - - def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> SyncSocketStream: - """ - Upgrade the underlying socket to TLS. - """ - raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_sync/http11.py b/httpcore/_sync/http11.py index 5dbb42e0..1b6498a9 100644 --- a/httpcore/_sync/http11.py +++ b/httpcore/_sync/http11.py @@ -1,17 +1,21 @@ import enum import time -from ssl import SSLContext -from typing import Iterator, List, Optional, Tuple, Union, cast +from types import TracebackType +from typing import Iterable, Iterator, List, Optional, Tuple, Type, Union import h11 -from .._backends.sync import SyncSocketStream -from .._bytestreams import IteratorByteStream -from .._exceptions import LocalProtocolError, RemoteProtocolError, map_exceptions -from .._types import URL, Headers, TimeoutDict -from .._utils import get_logger -from .base import SyncByteStream, NewConnectionRequired -from .http import SyncBaseHTTPConnection +from .._exceptions import ( + ConnectionNotAvailable, + LocalProtocolError, + RemoteProtocolError, + map_exceptions, +) +from .._models import Origin, Request, Response +from .._synchronization import Lock +from .._trace import Trace +from ..backends.base import NetworkStream +from .interfaces import ConnectionInterface H11Event = Union[ h11.Request, @@ -23,170 +27,120 @@ ] -class ConnectionState(enum.IntEnum): +class HTTPConnectionState(enum.IntEnum): NEW = 0 ACTIVE = 1 IDLE = 2 CLOSED = 3 -logger = get_logger(__name__) - - -class SyncHTTP11Connection(SyncBaseHTTPConnection): +class HTTP11Connection(ConnectionInterface): READ_NUM_BYTES = 64 * 1024 - def __init__(self, socket: SyncSocketStream, keepalive_expiry: float = None): - self.socket = socket - + def __init__( + self, origin: Origin, stream: NetworkStream, keepalive_expiry: float = None + ) -> None: + self._origin = origin + self._network_stream = stream self._keepalive_expiry: Optional[float] = keepalive_expiry - self._should_expire_at: 
Optional[float] = None + self._expire_at: Optional[float] = None + self._state = HTTPConnectionState.NEW + self._state_lock = Lock() + self._request_count = 0 self._h11_state = h11.Connection(our_role=h11.CLIENT) - self._state = ConnectionState.NEW - - def __repr__(self) -> str: - return f"" - - def _now(self) -> float: - return time.monotonic() - - def _server_disconnected(self) -> bool: - """ - Return True if the connection is idle, and the underlying socket is readable. - The only valid state the socket can be readable here is when the b"" - EOF marker is about to be returned, indicating a server disconnect. - """ - return self._state == ConnectionState.IDLE and self.socket.is_readable() - - def _keepalive_expired(self) -> bool: - """ - Return True if the connection is idle, and has passed it's keepalive - expiry time. - """ - return ( - self._state == ConnectionState.IDLE - and self._should_expire_at is not None - and self._now() >= self._should_expire_at - ) - - def info(self) -> str: - return f"HTTP/1.1, {self._state.name}" - - def should_close(self) -> bool: - """ - Return `True` if the connection is in a state where it should be closed. - """ - return self._server_disconnected() or self._keepalive_expired() - - def is_idle(self) -> bool: - """ - Return `True` if the connection is currently idle. - """ - return self._state == ConnectionState.IDLE - - def is_closed(self) -> bool: - """ - Return `True` if the connection has been closed. - """ - return self._state == ConnectionState.CLOSED - def is_available(self) -> bool: - """ - Return `True` if the connection is currently able to accept an outgoing request. - """ - return self._state == ConnectionState.IDLE + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise RuntimeError( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" + ) + + with self._state_lock: + if self._state in (HTTPConnectionState.NEW, HTTPConnectionState.IDLE): + self._request_count += 1 + self._state = HTTPConnectionState.ACTIVE + self._expire_at = None + else: + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request} + with Trace("http11.send_request_headers", request, kwargs) as trace: + self._send_request_headers(**kwargs) + with Trace("http11.send_request_body", request, kwargs) as trace: + self._send_request_body(**kwargs) + with Trace( + "http11.receive_response_headers", request, kwargs + ) as trace: + ( + http_version, + status, + reason_phrase, + headers, + ) = self._receive_response_headers(**kwargs) + trace.return_value = ( + http_version, + status, + reason_phrase, + headers, + ) + + return Response( + status=status, + headers=headers, + content=HTTP11ConnectionByteStream(self, request), + extensions={ + "http_version": http_version, + "reason_phrase": reason_phrase, + "network_stream": self._network_stream, + }, + ) + except BaseException as exc: + with Trace("http11.response_closed", request) as trace: + self._response_closed() + raise exc + + # Sending the request... + + def _send_request_headers(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - """ - Send a single HTTP/1.1 request. - - Note that there is no kind of task/thread locking at this layer of interface. 
- Dealing with locking for concurrency is handled by the `SyncHTTPConnection`. - """ - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - if self._state in (ConnectionState.NEW, ConnectionState.IDLE): - self._state = ConnectionState.ACTIVE - self._should_expire_at = None - else: - raise NewConnectionRequired() - - self._send_request(method, url, headers, timeout) - self._send_request_body(stream, timeout) - ( - http_version, - status_code, - reason_phrase, - headers, - ) = self._receive_response(timeout) - response_stream = IteratorByteStream( - iterator=self._receive_response_data(timeout), - close_func=self._response_closed, - ) - extensions = { - "http_version": http_version, - "reason_phrase": reason_phrase, - } - return (status_code, headers, response_stream, extensions) - - def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> SyncSocketStream: - timeout = {} if timeout is None else timeout - self.socket = self.socket.start_tls(hostname, ssl_context, timeout) - return self.socket - - def _send_request( - self, method: bytes, url: URL, headers: Headers, timeout: TimeoutDict - ) -> None: - """ - Send the request line and headers. - """ - logger.trace("send_request method=%r url=%r headers=%s", method, url, headers) - _scheme, _host, _port, target = url with map_exceptions({h11.LocalProtocolError: LocalProtocolError}): - event = h11.Request(method=method, target=target, headers=headers) - self._send_event(event, timeout) - - def _send_request_body( - self, stream: SyncByteStream, timeout: TimeoutDict - ) -> None: - """ - Send the request body. - """ - # Send the request body. - for chunk in stream: - logger.trace("send_data=Data(<%d bytes>)", len(chunk)) + event = h11.Request( + method=request.method, + target=request.url.target, + headers=request.headers, + ) + self._send_event(event, timeout=timeout) + + def _send_request_body(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + assert isinstance(request.stream, Iterable) + for chunk in request.stream: event = h11.Data(data=chunk) - self._send_event(event, timeout) + self._send_event(event, timeout=timeout) - # Finalize sending the request. event = h11.EndOfMessage() - self._send_event(event, timeout) + self._send_event(event, timeout=timeout) - def _send_event(self, event: H11Event, timeout: TimeoutDict) -> None: - """ - Send a single `h11` event to the network, waiting for the data to - drain before returning. - """ + def _send_event(self, event: H11Event, timeout: float = None) -> None: bytes_to_send = self._h11_state.send(event) - self.socket.write(bytes_to_send, timeout) + self._network_stream.write(bytes_to_send, timeout=timeout) + + # Receiving the response... - def _receive_response( - self, timeout: TimeoutDict + def _receive_response_headers( + self, request: Request ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]]]: - """ - Read the response status and headers from the network. - """ + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + while True: - event = self._receive_event(timeout) + event = self._receive_event(timeout=timeout) if isinstance(event, h11.Response): break @@ -198,72 +152,127 @@ def _receive_response( return http_version, event.status_code, event.reason, headers - def _receive_response_data( - self, timeout: TimeoutDict - ) -> Iterator[bytes]: - """ - Read the response data from the network. 
- """ + def _receive_response_body(self, request: Request) -> Iterator[bytes]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + while True: - event = self._receive_event(timeout) + event = self._receive_event(timeout=timeout) if isinstance(event, h11.Data): - logger.trace("receive_event=Data(<%d bytes>)", len(event.data)) yield bytes(event.data) elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): - logger.trace("receive_event=%r", event) break - def _receive_event(self, timeout: TimeoutDict) -> H11Event: - """ - Read a single `h11` event, reading more data from the network if needed. - """ + def _receive_event(self, timeout: float = None) -> H11Event: while True: with map_exceptions({h11.RemoteProtocolError: RemoteProtocolError}): event = self._h11_state.next_event() if event is h11.NEED_DATA: - data = self.socket.read(self.READ_NUM_BYTES, timeout) - - # If we feed this case through h11 we'll raise an exception like: - # - # httpcore.RemoteProtocolError: can't handle event type - # ConnectionClosed when role=SERVER and state=SEND_RESPONSE - # - # Which is accurate, but not very informative from an end-user - # perspective. Instead we handle messaging for this case distinctly. - if data == b"" and self._h11_state.their_state == h11.SEND_RESPONSE: - msg = "Server disconnected without sending a response." - raise RemoteProtocolError(msg) - + data = self._network_stream.read( + self.READ_NUM_BYTES, timeout=timeout + ) self._h11_state.receive_data(data) else: - assert event is not h11.NEED_DATA - break - return event + return event def _response_closed(self) -> None: - logger.trace( - "response_closed our_state=%r their_state=%r", - self._h11_state.our_state, - self._h11_state.their_state, - ) - if ( - self._h11_state.our_state is h11.DONE - and self._h11_state.their_state is h11.DONE - ): - self._h11_state.start_next_cycle() - self._state = ConnectionState.IDLE - if self._keepalive_expiry is not None: - self._should_expire_at = self._now() + self._keepalive_expiry - else: - self.close() + with self._state_lock: + if ( + self._h11_state.our_state is h11.DONE + and self._h11_state.their_state is h11.DONE + ): + self._state = HTTPConnectionState.IDLE + self._h11_state.start_next_cycle() + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + else: + self.close() + + # Once the connection is no longer required... def close(self) -> None: - if self._state != ConnectionState.CLOSED: - self._state = ConnectionState.CLOSED + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + self._state = HTTPConnectionState.CLOSED + self._network_stream.close() + + # The ConnectionInterface methods provide information about the state of + # the connection, allowing for a connection pooling implementation to + # determine when to reuse and when to close the connection... + + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._origin + + def is_available(self) -> bool: + # Note that HTTP/1.1 connections in the "NEW" state are not treated as + # being "available". The control flow which created the connection will + # be able to send an outgoing request, but the connection will not be + # acquired from the connection pool for any other request. 
+ return self._state == HTTPConnectionState.IDLE + + def has_expired(self) -> bool: + now = time.monotonic() + keepalive_expired = self._expire_at is not None and now > self._expire_at + + # If the HTTP connection is idle but the socket is readable, then the + # only valid state is that the socket is about to return b"", indicating + # a server-initiated disconnect. + server_disconnected = ( + self._state == HTTPConnectionState.IDLE + and self._network_stream.get_extra_info("is_readable") + ) + + return keepalive_expired or server_disconnected + + def is_idle(self) -> bool: + return self._state == HTTPConnectionState.IDLE + + def is_closed(self) -> bool: + return self._state == HTTPConnectionState.CLOSED + + def info(self) -> str: + origin = str(self._origin) + return ( + f"{origin!r}, HTTP/1.1, {self._state.name}, " + f"Request Count: {self._request_count}" + ) + + def __repr__(self) -> str: + class_name = self.__class__.__name__ + origin = str(self._origin) + return ( + f"<{class_name} [{origin!r}, {self._state.name}, " + f"Request Count: {self._request_count}]>" + ) + + # These context managers are not used in the standard flow, but are + # useful for testing or working with connection instances directly. - if self._h11_state.our_state is h11.MUST_CLOSE: - event = h11.ConnectionClosed() - self._h11_state.send(event) + def __enter__(self) -> "HTTP11Connection": + return self - self.socket.close() + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self.close() + + +class HTTP11ConnectionByteStream: + def __init__(self, connection: HTTP11Connection, request: Request) -> None: + self._connection = connection + self._request = request + + def __iter__(self) -> Iterator[bytes]: + kwargs = {"request": self._request} + with Trace("http11.receive_response_body", self._request, kwargs): + for chunk in self._connection._receive_response_body(**kwargs): + yield chunk + + def close(self) -> None: + with Trace("http11.response_closed", self._request): + self._connection._response_closed() diff --git a/httpcore/_sync/http2.py b/httpcore/_sync/http2.py index 90caf5fa..ff66bb35 100644 --- a/httpcore/_sync/http2.py +++ b/httpcore/_sync/http2.py @@ -1,175 +1,119 @@ import enum import time -from ssl import SSLContext -from typing import Iterator, Dict, List, Optional, Tuple, cast +import types +import typing +import h2.config import h2.connection import h2.events -from h2.config import H2Configuration -from h2.exceptions import NoAvailableStreamIDError -from h2.settings import SettingCodes, Settings - -from .._backends.sync import SyncBackend, SyncLock, SyncSemaphore, SyncSocketStream -from .._bytestreams import IteratorByteStream -from .._exceptions import LocalProtocolError, PoolTimeout, RemoteProtocolError -from .._types import URL, Headers, TimeoutDict -from .._utils import get_logger -from .base import SyncByteStream, NewConnectionRequired -from .http import SyncBaseHTTPConnection - -logger = get_logger(__name__) +import h2.exceptions +import h2.settings + +from .._exceptions import ConnectionNotAvailable, RemoteProtocolError +from .._models import Origin, Request, Response +from .._synchronization import Lock, Semaphore +from .._trace import Trace +from ..backends.base import NetworkStream +from .interfaces import ConnectionInterface + + +def has_body_headers(request: Request) -> bool: + return any( + [ + k.lower() == b"content-length" or k.lower() == b"transfer-encoding" + for k, v in request.headers + 
] + ) -class ConnectionState(enum.IntEnum): - IDLE = 0 +class HTTPConnectionState(enum.IntEnum): ACTIVE = 1 - CLOSED = 2 + IDLE = 2 + CLOSED = 3 -class SyncHTTP2Connection(SyncBaseHTTPConnection): +class HTTP2Connection(ConnectionInterface): READ_NUM_BYTES = 64 * 1024 - CONFIG = H2Configuration(validate_inbound_headers=False) + CONFIG = h2.config.H2Configuration(validate_inbound_headers=False) def __init__( - self, - socket: SyncSocketStream, - backend: SyncBackend, - keepalive_expiry: float = None, + self, origin: Origin, stream: NetworkStream, keepalive_expiry: float = None ): - self.socket = socket - - self._backend = backend + self._origin = origin + self._network_stream = stream + self._keepalive_expiry: typing.Optional[float] = keepalive_expiry self._h2_state = h2.connection.H2Connection(config=self.CONFIG) - + self._state = HTTPConnectionState.IDLE + self._expire_at: typing.Optional[float] = None + self._request_count = 0 + self._init_lock = Lock() + self._state_lock = Lock() + self._read_lock = Lock() + self._write_lock = Lock() self._sent_connection_init = False - self._streams: Dict[int, SyncHTTP2Stream] = {} - self._events: Dict[int, List[h2.events.Event]] = {} - - self._keepalive_expiry: Optional[float] = keepalive_expiry - self._should_expire_at: Optional[float] = None - self._state = ConnectionState.ACTIVE - self._exhausted_available_stream_ids = False - - def __repr__(self) -> str: - return f"" - - def info(self) -> str: - return f"HTTP/2, {self._state.name}, {len(self._streams)} streams" - - def _now(self) -> float: - return time.monotonic() - - def should_close(self) -> bool: - """ - Return `True` if the connection is currently idle, and the keepalive - timeout has passed. - """ - return ( - self._state == ConnectionState.IDLE - and self._should_expire_at is not None - and self._now() >= self._should_expire_at - ) - - def is_idle(self) -> bool: - """ - Return `True` if the connection is currently idle. - """ - return self._state == ConnectionState.IDLE - - def is_closed(self) -> bool: - """ - Return `True` if the connection has been closed. - """ - return self._state == ConnectionState.CLOSED - - def is_available(self) -> bool: - """ - Return `True` if the connection is currently able to accept an outgoing request. - This occurs when any of the following occur: - - * The connection has not yet been opened, and HTTP/2 support is enabled. - We don't *know* at this point if we'll end up on an HTTP/2 connection or - not, but we *might* do, so we indicate availability. - * The connection has been opened, and is currently idle. - * The connection is open, and is an HTTP/2 connection. The connection must - also not have exhausted the maximum total number of stream IDs. - """ - return ( - self._state != ConnectionState.CLOSED - and not self._exhausted_available_stream_ids - ) - - @property - def init_lock(self) -> SyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_initialization_lock"): - self._initialization_lock = self._backend.create_lock() - return self._initialization_lock - - @property - def read_lock(self) -> SyncLock: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. - if not hasattr(self, "_read_lock"): - self._read_lock = self._backend.create_lock() - return self._read_lock - - @property - def max_streams_semaphore(self) -> SyncSemaphore: - # We do this lazily, to make sure backend autodetection always - # runs within an async context. 
- if not hasattr(self, "_max_streams_semaphore"): - max_streams = self._h2_state.local_settings.max_concurrent_streams - self._max_streams_semaphore = self._backend.create_semaphore( - max_streams, exc_class=PoolTimeout + self._used_all_stream_ids = False + self._events: typing.Dict[int, h2.events.Event] = {} + + def handle_request(self, request: Request) -> Response: + if not self.can_handle_request(request.url.origin): + raise ConnectionNotAvailable( + f"Attempted to send request to {request.url.origin} on connection " + f"to {self._origin}" ) - return self._max_streams_semaphore - def start_tls( - self, hostname: bytes, ssl_context: SSLContext, timeout: TimeoutDict = None - ) -> SyncSocketStream: - raise NotImplementedError("TLS upgrade not supported on HTTP/2 connections.") + with self._state_lock: + if self._state in (HTTPConnectionState.ACTIVE, HTTPConnectionState.IDLE): + self._request_count += 1 + self._expire_at = None + self._state = HTTPConnectionState.ACTIVE + else: + raise ConnectionNotAvailable() - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - with self.init_lock: + with self._init_lock: if not self._sent_connection_init: - # The very first stream is responsible for initiating the connection. - self._state = ConnectionState.ACTIVE - self.send_connection_init(timeout) + kwargs = {"request": request} + with Trace("http2.send_connection_init", request, kwargs): + self._send_connection_init(**kwargs) self._sent_connection_init = True + max_streams = self._h2_state.local_settings.max_concurrent_streams + self._max_streams_semaphore = Semaphore(max_streams) - self.max_streams_semaphore.acquire() - try: - try: - stream_id = self._h2_state.get_next_available_stream_id() - except NoAvailableStreamIDError: - self._exhausted_available_stream_ids = True - raise NewConnectionRequired() - else: - self._state = ConnectionState.ACTIVE - self._should_expire_at = None + self._max_streams_semaphore.acquire() - h2_stream = SyncHTTP2Stream(stream_id=stream_id, connection=self) - self._streams[stream_id] = h2_stream + try: + stream_id = self._h2_state.get_next_available_stream_id() self._events[stream_id] = [] - return h2_stream.handle_request( - method, url, headers, stream, extensions + except h2.exceptions.NoAvailableStreamIDError: # pragma: nocover + self._used_all_stream_ids = True + raise ConnectionNotAvailable() + + try: + kwargs = {"request": request, "stream_id": stream_id} + with Trace("http2.send_request_headers", request, kwargs): + self._send_request_headers(request=request, stream_id=stream_id) + with Trace("http2.send_request_body", request, kwargs): + self._send_request_body(request=request, stream_id=stream_id) + with Trace( + "http2.receive_response_headers", request, kwargs + ) as trace: + status, headers = self._receive_response( + request=request, stream_id=stream_id + ) + trace.return_value = (status, headers) + + return Response( + status=status, + headers=headers, + content=HTTP2ConnectionByteStream(self, request, stream_id=stream_id), + extensions={"stream_id": stream_id, "http_version": b"HTTP/2"}, ) except Exception: # noqa: PIE786 - self.max_streams_semaphore.release() + kwargs = {"stream_id": stream_id} + with Trace("http2.response_closed", request, kwargs): + self._response_closed(stream_id=stream_id) raise - def send_connection_init(self, timeout: TimeoutDict) -> None: + def 
_send_connection_init(self, request: Request) -> None: """ The HTTP/2 connection requires some initial setup before we can start using individual request/response streams on it. @@ -177,15 +121,15 @@ def send_connection_init(self, timeout: TimeoutDict) -> None: # Need to set these manually here instead of manipulating via # __setitem__() otherwise the H2Connection will emit SettingsUpdate # frames in addition to sending the undesired defaults. - self._h2_state.local_settings = Settings( + self._h2_state.local_settings = h2.settings.Settings( client=True, initial_values={ # Disable PUSH_PROMISE frames from the server since we don't do anything # with them for now. Maybe when we support caching? - SettingCodes.ENABLE_PUSH: 0, + h2.settings.SettingCodes.ENABLE_PUSH: 0, # These two are taken from h2 for safe defaults - SettingCodes.MAX_CONCURRENT_STREAMS: 100, - SettingCodes.MAX_HEADER_LIST_SIZE: 65536, + h2.settings.SettingCodes.MAX_CONCURRENT_STREAMS: 100, + h2.settings.SettingCodes.MAX_HEADER_LIST_SIZE: 65536, }, ) @@ -196,227 +140,63 @@ def send_connection_init(self, timeout: TimeoutDict) -> None: h2.settings.SettingCodes.ENABLE_CONNECT_PROTOCOL ] - logger.trace("initiate_connection=%r", self) self._h2_state.initiate_connection() self._h2_state.increment_flow_control_window(2 ** 24) - data_to_send = self._h2_state.data_to_send() - self.socket.write(data_to_send, timeout) - - def is_socket_readable(self) -> bool: - return self.socket.is_readable() - - def close(self) -> None: - logger.trace("close_connection=%r", self) - if self._state != ConnectionState.CLOSED: - self._state = ConnectionState.CLOSED - - self.socket.close() - - def wait_for_outgoing_flow(self, stream_id: int, timeout: TimeoutDict) -> int: - """ - Returns the maximum allowable outgoing flow for a given stream. - If the allowable flow is zero, then waits on the network until - WindowUpdated frames have increased the flow rate. - https://tools.ietf.org/html/rfc7540#section-6.9 - """ - local_flow = self._h2_state.local_flow_control_window(stream_id) - connection_flow = self._h2_state.max_outbound_frame_size - flow = min(local_flow, connection_flow) - while flow == 0: - self.receive_events(timeout) - local_flow = self._h2_state.local_flow_control_window(stream_id) - connection_flow = self._h2_state.max_outbound_frame_size - flow = min(local_flow, connection_flow) - return flow - - def wait_for_event( - self, stream_id: int, timeout: TimeoutDict - ) -> h2.events.Event: - """ - Returns the next event for a given stream. - If no events are available yet, then waits on the network until - an event is available. - """ - with self.read_lock: - while not self._events[stream_id]: - self.receive_events(timeout) - return self._events[stream_id].pop(0) - - def receive_events(self, timeout: TimeoutDict) -> None: - """ - Read some data from the network, and update the H2 state. 
- """ - data = self.socket.read(self.READ_NUM_BYTES, timeout) - if data == b"": - raise RemoteProtocolError("Server disconnected") - - events = self._h2_state.receive_data(data) - for event in events: - event_stream_id = getattr(event, "stream_id", 0) - logger.trace("receive_event stream_id=%r event=%s", event_stream_id, event) - - if hasattr(event, "error_code"): - raise RemoteProtocolError(event) - - if event_stream_id in self._events: - self._events[event_stream_id].append(event) - - data_to_send = self._h2_state.data_to_send() - self.socket.write(data_to_send, timeout) - - def send_headers( - self, stream_id: int, headers: Headers, end_stream: bool, timeout: TimeoutDict - ) -> None: - logger.trace("send_headers stream_id=%r headers=%r", stream_id, headers) - self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) - self._h2_state.increment_flow_control_window(2 ** 24, stream_id=stream_id) - data_to_send = self._h2_state.data_to_send() - self.socket.write(data_to_send, timeout) - - def send_data( - self, stream_id: int, chunk: bytes, timeout: TimeoutDict - ) -> None: - logger.trace("send_data stream_id=%r chunk=%r", stream_id, chunk) - self._h2_state.send_data(stream_id, chunk) - data_to_send = self._h2_state.data_to_send() - self.socket.write(data_to_send, timeout) - - def end_stream(self, stream_id: int, timeout: TimeoutDict) -> None: - logger.trace("end_stream stream_id=%r", stream_id) - self._h2_state.end_stream(stream_id) - data_to_send = self._h2_state.data_to_send() - self.socket.write(data_to_send, timeout) - - def acknowledge_received_data( - self, stream_id: int, amount: int, timeout: TimeoutDict - ) -> None: - self._h2_state.acknowledge_received_data(amount, stream_id) - data_to_send = self._h2_state.data_to_send() - self.socket.write(data_to_send, timeout) + self._write_outgoing_data(request) - def close_stream(self, stream_id: int) -> None: - try: - logger.trace("close_stream stream_id=%r", stream_id) - del self._streams[stream_id] - del self._events[stream_id] - - if not self._streams: - if self._state == ConnectionState.ACTIVE: - if self._exhausted_available_stream_ids: - self.close() - else: - self._state = ConnectionState.IDLE - if self._keepalive_expiry is not None: - self._should_expire_at = ( - self._now() + self._keepalive_expiry - ) - finally: - self.max_streams_semaphore.release() - - -class SyncHTTP2Stream: - def __init__(self, stream_id: int, connection: SyncHTTP2Connection) -> None: - self.stream_id = stream_id - self.connection = connection - - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - headers = [(k.lower(), v) for (k, v) in headers] - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - - # Send the request. - seen_headers = set(key for key, value in headers) - has_body = ( - b"content-length" in seen_headers or b"transfer-encoding" in seen_headers - ) - - self.send_headers(method, url, headers, has_body, timeout) - if has_body: - self.send_body(stream, timeout) - - # Receive the response. - status_code, headers = self.receive_response(timeout) - response_stream = IteratorByteStream( - iterator=self.body_iter(timeout), close_func=self._response_closed - ) + # Sending the request... 
- extensions = { - "http_version": b"HTTP/2", - } - return (status_code, headers, response_stream, extensions) - - def send_headers( - self, - method: bytes, - url: URL, - headers: Headers, - has_body: bool, - timeout: TimeoutDict, - ) -> None: - scheme, hostname, port, path = url + def _send_request_headers(self, request: Request, stream_id: int) -> None: + end_stream = not has_body_headers(request) # In HTTP/2 the ':authority' pseudo-header is used instead of 'Host'. # In order to gracefully handle HTTP/1.1 and HTTP/2 we always require # HTTP/1.1 style headers, and map them appropriately if we end up on # an HTTP/2 connection. - authority = None - - for k, v in headers: - if k == b"host": - authority = v - break - - if authority is None: - # Mirror the same error we'd see with `h11`, so that the behaviour - # is consistent. Although we're dealing with an `:authority` - # pseudo-header by this point, from an end-user perspective the issue - # is that the outgoing request needed to include a `host` header. - raise LocalProtocolError("Missing mandatory Host: header") + authority = [v for k, v in request.headers if k.lower() == b"host"][0] headers = [ - (b":method", method), + (b":method", request.method), (b":authority", authority), - (b":scheme", scheme), - (b":path", path), + (b":scheme", request.url.scheme), + (b":path", request.url.target), ] + [ - (k, v) - for k, v in headers - if k + (k.lower(), v) + for k, v in request.headers + if k.lower() not in ( b"host", b"transfer-encoding", ) ] - end_stream = not has_body - self.connection.send_headers(self.stream_id, headers, end_stream, timeout) + self._h2_state.send_headers(stream_id, headers, end_stream=end_stream) + self._h2_state.increment_flow_control_window(2 ** 24, stream_id=stream_id) + self._write_outgoing_data(request) - def send_body(self, stream: SyncByteStream, timeout: TimeoutDict) -> None: - for data in stream: + def _send_request_body(self, request: Request, stream_id: int) -> None: + if not has_body_headers(request): + return + + assert isinstance(request.stream, typing.Iterable) + for data in request.stream: while data: - max_flow = self.connection.wait_for_outgoing_flow( - self.stream_id, timeout - ) + max_flow = self._wait_for_outgoing_flow(request, stream_id) chunk_size = min(len(data), max_flow) chunk, data = data[:chunk_size], data[chunk_size:] - self.connection.send_data(self.stream_id, chunk, timeout) + self._h2_state.send_data(stream_id, chunk) + self._write_outgoing_data(request) - self.connection.end_stream(self.stream_id, timeout) + self._h2_state.end_stream(stream_id) + self._write_outgoing_data(request) - def receive_response( - self, timeout: TimeoutDict - ) -> Tuple[int, List[Tuple[bytes, bytes]]]: - """ - Read the response status and headers from the network. - """ + # Receiving the response... 
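+    #
+    # A rough sketch of the event sequence consumed below for a typical
+    # request: a ResponseReceived event, then zero or more DataReceived
+    # events, then StreamEnded. Events belonging to other streams on the
+    # same connection are buffered in `self._events` until those streams
+    # ask for them.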
+ + def _receive_response( + self, request: Request, stream_id: int + ) -> typing.Tuple[int, typing.List[typing.Tuple[bytes, bytes]]]: while True: - event = self.connection.wait_for_event(self.stream_id, timeout) + event = self._receive_stream_event(request, stream_id) if isinstance(event, h2.events.ResponseReceived): break @@ -430,17 +210,167 @@ def receive_response( return (status_code, headers) - def body_iter(self, timeout: TimeoutDict) -> Iterator[bytes]: + def _receive_response_body( + self, request: Request, stream_id: int + ) -> typing.Iterator[bytes]: while True: - event = self.connection.wait_for_event(self.stream_id, timeout) + event = self._receive_stream_event(request, stream_id) if isinstance(event, h2.events.DataReceived): amount = event.flow_controlled_length - self.connection.acknowledge_received_data( - self.stream_id, amount, timeout - ) + self._h2_state.acknowledge_received_data(amount, stream_id) + self._write_outgoing_data(request) yield event.data elif isinstance(event, (h2.events.StreamEnded, h2.events.StreamReset)): break - def _response_closed(self) -> None: - self.connection.close_stream(self.stream_id) + def _receive_stream_event( + self, request: Request, stream_id: int + ) -> h2.events.Event: + while not self._events.get(stream_id): + self._receive_events(request) + return self._events[stream_id].pop(0) + + def _receive_events(self, request: Request) -> None: + events = self._read_incoming_data(request) + for event in events: + event_stream_id = getattr(event, "stream_id", 0) + + if hasattr(event, "error_code"): + raise RemoteProtocolError(event) + + if event_stream_id in self._events: + self._events[event_stream_id].append(event) + + self._write_outgoing_data(request) + + def _response_closed(self, stream_id: int) -> None: + self._max_streams_semaphore.release() + del self._events[stream_id] + with self._state_lock: + if self._state == HTTPConnectionState.ACTIVE and not self._events: + self._state = HTTPConnectionState.IDLE + if self._keepalive_expiry is not None: + now = time.monotonic() + self._expire_at = now + self._keepalive_expiry + if self._used_all_stream_ids: # pragma: nocover + self.close() + + def close(self) -> None: + # Note that this method unilaterally closes the connection, and does + # not have any kind of locking in place around it. + # For task-safe/thread-safe operations call into 'attempt_close' instead. + self._state = HTTPConnectionState.CLOSED + self._network_stream.close() + + # Wrappers around network read/write operations... + + def _read_incoming_data( + self, request: Request + ) -> typing.List[h2.events.Event]: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("read", None) + + with self._read_lock: + data = self._network_stream.read(self.READ_NUM_BYTES, timeout) + if data == b"": + raise RemoteProtocolError("Server disconnected") + return self._h2_state.receive_data(data) + + def _write_outgoing_data(self, request: Request) -> None: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("write", None) + + with self._write_lock: + data_to_send = self._h2_state.data_to_send() + self._network_stream.write(data_to_send, timeout) + + # Flow control... + + def _wait_for_outgoing_flow(self, request: Request, stream_id: int) -> int: + """ + Returns the maximum allowable outgoing flow for a given stream. + + If the allowable flow is zero, then waits on the network until + WindowUpdated frames have increased the flow rate. 
+        https://tools.ietf.org/html/rfc7540#section-6.9
+        """
+        local_flow = self._h2_state.local_flow_control_window(stream_id)
+        max_frame_size = self._h2_state.max_outbound_frame_size
+        flow = min(local_flow, max_frame_size)
+        while flow == 0:
+            self._receive_events(request)
+            local_flow = self._h2_state.local_flow_control_window(stream_id)
+            max_frame_size = self._h2_state.max_outbound_frame_size
+            flow = min(local_flow, max_frame_size)
+        return flow
+
+    # Interface for connection pooling...
+
+    def can_handle_request(self, origin: Origin) -> bool:
+        return origin == self._origin
+
+    def is_available(self) -> bool:
+        return (
+            self._state != HTTPConnectionState.CLOSED and not self._used_all_stream_ids
+        )
+
+    def has_expired(self) -> bool:
+        now = time.monotonic()
+        return self._expire_at is not None and now > self._expire_at
+
+    def is_idle(self) -> bool:
+        return self._state == HTTPConnectionState.IDLE
+
+    def is_closed(self) -> bool:
+        return self._state == HTTPConnectionState.CLOSED
+
+    def info(self) -> str:
+        origin = str(self._origin)
+        return (
+            f"{origin!r}, HTTP/2, {self._state.name}, "
+            f"Request Count: {self._request_count}"
+        )
+
+    def __repr__(self) -> str:
+        class_name = self.__class__.__name__
+        origin = str(self._origin)
+        return (
+            f"<{class_name} [{origin!r}, {self._state.name}, "
+            f"Request Count: {self._request_count}]>"
+        )
+
+    # These context managers are not used in the standard flow, but are
+    # useful for testing or working with connection instances directly.
+
+    def __enter__(self) -> "HTTP2Connection":
+        return self
+
+    def __exit__(
+        self,
+        exc_type: typing.Type[BaseException] = None,
+        exc_value: BaseException = None,
+        traceback: types.TracebackType = None,
+    ) -> None:
+        self.close()
+
+
+class HTTP2ConnectionByteStream:
+    def __init__(
+        self, connection: HTTP2Connection, request: Request, stream_id: int
+    ) -> None:
+        self._connection = connection
+        self._request = request
+        self._stream_id = stream_id
+
+    def __iter__(self) -> typing.Iterator[bytes]:
+        kwargs = {"request": self._request, "stream_id": self._stream_id}
+        with Trace("http2.receive_response_body", self._request, kwargs):
+            for chunk in self._connection._receive_response_body(
+                request=self._request, stream_id=self._stream_id
+            ):
+                yield chunk
+
+    def close(self) -> None:
+        kwargs = {"stream_id": self._stream_id}
+        with Trace("http2.response_closed", self._request, kwargs):
+            self._connection._response_closed(stream_id=self._stream_id)
diff --git a/httpcore/_sync/http_proxy.py b/httpcore/_sync/http_proxy.py
index 78c02e29..ed273ac4 100644
--- a/httpcore/_sync/http_proxy.py
+++ b/httpcore/_sync/http_proxy.py
@@ -1,31 +1,27 @@
-from http import HTTPStatus
-from ssl import SSLContext
-from typing import Tuple, cast
+import ssl
+from typing import Dict, List, Tuple, Union
 
-from .._bytestreams import ByteStream
 from .._exceptions import ProxyError
-from .._types import URL, Headers, TimeoutDict
-from .._utils import get_logger, url_to_origin
-from .base import SyncByteStream
-from .connection import SyncHTTPConnection
-from .connection_pool import SyncConnectionPool, ResponseByteStream
+from .._models import URL, Origin, Request, Response, enforce_headers, enforce_url
+from .._ssl import default_ssl_context
+from .._synchronization import Lock
+from ..backends.base import NetworkBackend
+from .connection import HTTPConnection
+from .connection_pool import ConnectionPool
+from .http11 import HTTP11Connection
+from .interfaces import ConnectionInterface
 
-logger = get_logger(__name__)
-
-
-def get_reason_phrase(status_code: int) -> str:
-    try:
-        return HTTPStatus(status_code).phrase
-    except ValueError:
-        return ""
+HeadersAsList = List[Tuple[Union[bytes, str], Union[bytes, str]]]
+HeadersAsDict = Dict[Union[bytes, str], Union[bytes, str]]
 
 
 def merge_headers(
-    default_headers: Headers = None, override_headers: Headers = None
-) -> Headers:
+    default_headers: List[Tuple[bytes, bytes]] = None,
+    override_headers: List[Tuple[bytes, bytes]] = None,
+) -> List[Tuple[bytes, bytes]]:
     """
-    Append default_headers and override_headers, de-duplicating if a key existing in
-    both cases.
+    Append default_headers and override_headers, de-duplicating if a key exists
+    in both cases.
     """
     default_headers = [] if default_headers is None else default_headers
     override_headers = [] if override_headers is None else override_headers
@@ -38,253 +34,229 @@ def merge_headers(
     return default_headers + override_headers
 
 
-class SyncHTTPProxy(SyncConnectionPool):
+class HTTPProxy(ConnectionPool):
     """
-    A connection pool for making HTTP requests via an HTTP proxy.
-
-    Parameters
-    ----------
-    proxy_url:
-        The URL of the proxy service as a 4-tuple of (scheme, host, port, path).
-    proxy_headers:
-        A list of proxy headers to include.
-    proxy_mode:
-        A proxy mode to operate in. May be "DEFAULT", "FORWARD_ONLY", or "TUNNEL_ONLY".
-    ssl_context:
-        An SSL context to use for verifying connections.
-    max_connections:
-        The maximum number of concurrent connections to allow.
-    max_keepalive_connections:
-        The maximum number of connections to allow before closing keep-alive
-        connections.
-    http2:
-        Enable HTTP/2 support.
+    A connection pool that sends requests via an HTTP proxy.
     """
 
     def __init__(
         self,
-        proxy_url: URL,
-        proxy_headers: Headers = None,
-        proxy_mode: str = "DEFAULT",
-        ssl_context: SSLContext = None,
-        max_connections: int = None,
+        proxy_url: Union[URL, bytes, str],
+        proxy_headers: Union[HeadersAsDict, HeadersAsList] = None,
+        ssl_context: ssl.SSLContext = None,
+        max_connections: int = 10,
         max_keepalive_connections: int = None,
         keepalive_expiry: float = None,
-        http2: bool = False,
-        backend: str = "sync",
-        # Deprecated argument style:
-        max_keepalive: int = None,
-    ):
-        assert proxy_mode in ("DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY")
-
-        self.proxy_origin = url_to_origin(proxy_url)
-        self.proxy_headers = [] if proxy_headers is None else proxy_headers
-        self.proxy_mode = proxy_mode
+        retries: int = 0,
+        local_address: str = None,
+        uds: str = None,
+        network_backend: NetworkBackend = None,
+    ) -> None:
+        """
+        A connection pool for making HTTP requests.
+
+        Parameters:
+            proxy_url: The URL to use when connecting to the proxy server.
+                For example `"http://127.0.0.1:8080/"`.
+            proxy_headers: Any HTTP headers to use for the proxy requests.
+                For example `{"Proxy-Authorization": "Basic <username>:<password>"}`.
+            ssl_context: An SSL context to use for verifying connections.
+                If not specified, the default `httpcore.default_ssl_context()`
+                will be used.
+            max_connections: The maximum number of concurrent HTTP connections that
+                the pool should allow. Any attempt to send a request on a pool that
+                would exceed this amount will block until a connection is available.
+            max_keepalive_connections: The maximum number of idle HTTP connections
+                that will be maintained in the pool.
+            keepalive_expiry: The duration in seconds that an idle HTTP connection
+                may be maintained for before being expired from the pool.
+            retries: The maximum number of retries when trying to establish
+                a connection.
+ local_address: Local address to connect from. Can also be used to + connect using a particular address family. Using + `local_address="0.0.0.0"` will connect using an `AF_INET` address + (IPv4), while using `local_address="::"` will connect using an + `AF_INET6` address (IPv6). + uds: Path to a Unix Domain Socket to use instead of TCP sockets. + network_backend: A backend instance to use for handling network I/O. + """ + if ssl_context is None: + ssl_context = default_ssl_context() + super().__init__( ssl_context=ssl_context, max_connections=max_connections, max_keepalive_connections=max_keepalive_connections, keepalive_expiry=keepalive_expiry, - http2=http2, - backend=backend, - max_keepalive=max_keepalive, + network_backend=network_backend, + retries=retries, + local_address=local_address, + uds=uds, ) - - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - if self._keepalive_expiry is not None: - self._keepalive_sweep() - - if ( - self.proxy_mode == "DEFAULT" and url[0] == b"http" - ) or self.proxy_mode == "FORWARD_ONLY": - # By default HTTP requests should be forwarded. - logger.trace( - "forward_request proxy_origin=%r proxy_headers=%r method=%r url=%r", - self.proxy_origin, - self.proxy_headers, - method, - url, - ) - return self._forward_request( - method, url, headers=headers, stream=stream, extensions=extensions - ) - else: - # By default HTTPS should be tunnelled. - logger.trace( - "tunnel_request proxy_origin=%r proxy_headers=%r method=%r url=%r", - self.proxy_origin, - self.proxy_headers, - method, - url, - ) - return self._tunnel_request( - method, url, headers=headers, stream=stream, extensions=extensions - ) - - def _forward_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - """ - Forwarded proxy requests include the entire URL as the HTTP target, - rather than just the path. - """ - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - origin = self.proxy_origin - connection = self._get_connection_from_pool(origin) - - if connection is None: - connection = SyncHTTPConnection( - origin=origin, - http2=self._http2, + self._ssl_context = ssl_context + self._proxy_url = enforce_url(proxy_url, name="proxy_url") + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + + def create_connection(self, origin: Origin) -> ConnectionInterface: + if origin.scheme == b"http": + return ForwardHTTPConnection( + proxy_origin=self._proxy_url.origin, keepalive_expiry=self._keepalive_expiry, - ssl_context=self._ssl_context, + network_backend=self._network_backend, ) - self._add_to_pool(connection, timeout) - - # Issue a forwarded proxy request... 
- - # GET https://www.example.org/path HTTP/1.1 - # [proxy headers] - # [headers] - scheme, host, port, path = url - if port is None: - target = b"%b://%b%b" % (scheme, host, path) - else: - target = b"%b://%b:%d%b" % (scheme, host, port, path) - - url = self.proxy_origin + (target,) - headers = merge_headers(self.proxy_headers, headers) - - ( - status_code, - headers, - stream, - extensions, - ) = connection.handle_request( - method, url, headers=headers, stream=stream, extensions=extensions + return TunnelHTTPConnection( + proxy_origin=self._proxy_url.origin, + remote_origin=origin, + ssl_context=self._ssl_context, + keepalive_expiry=self._keepalive_expiry, + network_backend=self._network_backend, ) - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed + +class ForwardHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + proxy_headers: Union[HeadersAsDict, HeadersAsList] = None, + keepalive_expiry: float = None, + network_backend: NetworkBackend = None, + ) -> None: + self._connection = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + ) + self._proxy_origin = proxy_origin + self._proxy_headers = enforce_headers(proxy_headers, name="proxy_headers") + + def handle_request(self, request: Request) -> Response: + headers = merge_headers(self._proxy_headers, request.headers) + url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=bytes(request.url), ) + proxy_request = Request( + method=request.method, + url=url, + headers=headers, + content=request.stream, + extensions=request.extensions, + ) + return self._connection.handle_request(proxy_request) - return status_code, headers, wrapped_stream, extensions + def can_handle_request(self, origin: Origin) -> bool: + return origin.scheme == b"http" - def _tunnel_request( - self, - method: bytes, - url: URL, - headers: Headers, - stream: SyncByteStream, - extensions: dict, - ) -> Tuple[int, Headers, SyncByteStream, dict]: - """ - Tunnelled proxy requests require an initial CONNECT request to - establish the connection, and then send regular requests. - """ - timeout = cast(TimeoutDict, extensions.get("timeout", {})) - origin = url_to_origin(url) - connection = self._get_connection_from_pool(origin) + def close(self) -> None: + self._connection.close() - if connection is None: - scheme, host, port = origin + def info(self) -> str: + return self._connection.info() - # First, create a connection to the proxy server - proxy_connection = SyncHTTPConnection( - origin=self.proxy_origin, - http2=self._http2, - keepalive_expiry=self._keepalive_expiry, - ssl_context=self._ssl_context, - ) + def is_available(self) -> bool: + return self._connection.is_available() - # Issue a CONNECT request... 
- - # CONNECT www.example.org:80 HTTP/1.1 - # [proxy-headers] - target = b"%b:%d" % (host, port) - connect_url = self.proxy_origin + (target,) - connect_headers = [(b"Host", target), (b"Accept", b"*/*")] - connect_headers = merge_headers(connect_headers, self.proxy_headers) - - try: - ( - proxy_status_code, - _, - proxy_stream, - _, - ) = proxy_connection.handle_request( - b"CONNECT", - connect_url, - headers=connect_headers, - stream=ByteStream(b""), - extensions=extensions, - ) + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() - proxy_reason = get_reason_phrase(proxy_status_code) - logger.trace( - "tunnel_response proxy_status_code=%r proxy_reason=%r ", - proxy_status_code, - proxy_reason, + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" + + +class TunnelHTTPConnection(ConnectionInterface): + def __init__( + self, + proxy_origin: Origin, + remote_origin: Origin, + ssl_context: ssl.SSLContext, + proxy_headers: List[Tuple[bytes, bytes]] = None, + keepalive_expiry: float = None, + network_backend: NetworkBackend = None, + ) -> None: + self._connection: ConnectionInterface = HTTPConnection( + origin=proxy_origin, + keepalive_expiry=keepalive_expiry, + network_backend=network_backend, + ) + self._proxy_origin = proxy_origin + self._remote_origin = remote_origin + self._ssl_context = ssl_context + self._proxy_headers = [] if proxy_headers is None else proxy_headers + self._keepalive_expiry = keepalive_expiry + self._connect_lock = Lock() + self._connected = False + + def handle_request(self, request: Request) -> Response: + timeouts = request.extensions.get("timeout", {}) + timeout = timeouts.get("connect", None) + + with self._connect_lock: + if not self._connected: + target = b"%b:%d" % (self._remote_origin.host, self._remote_origin.port) + + connect_url = URL( + scheme=self._proxy_origin.scheme, + host=self._proxy_origin.host, + port=self._proxy_origin.port, + target=target, + ) + connect_headers = [(b"Host", target), (b"Accept", b"*/*")] + connect_request = Request( + method=b"CONNECT", url=connect_url, headers=connect_headers + ) + connect_response = self._connection.handle_request( + connect_request ) - # Read the response data without closing the socket - for _ in proxy_stream: - pass - # See if the tunnel was successfully established. - if proxy_status_code < 200 or proxy_status_code > 299: - msg = "%d %s" % (proxy_status_code, proxy_reason) + if connect_response.status < 200 or connect_response.status > 299: + reason_bytes = connect_response.extensions.get("reason_phrase", b"") + reason_str = reason_bytes.decode("ascii", errors="ignore") + msg = "%d %s" % (connect_response.status, reason_str) + self._connection.close() raise ProxyError(msg) - # Upgrade to TLS if required - # We assume the target speaks TLS on the specified port - if scheme == b"https": - proxy_connection.start_tls(host, self._ssl_context, timeout) - except Exception as exc: - proxy_connection.close() - raise ProxyError(exc) - - # The CONNECT request is successful, so we have now SWITCHED PROTOCOLS. - # This means the proxy connection is now unusable, and we must create - # a new one for regular requests, making sure to use the same socket to - # retain the tunnel. 
- connection = SyncHTTPConnection( - origin=origin, - http2=self._http2, - keepalive_expiry=self._keepalive_expiry, - ssl_context=self._ssl_context, - socket=proxy_connection.socket, - ) - self._add_to_pool(connection, timeout) - - # Once the connection has been established we can send requests on - # it as normal. - ( - status_code, - headers, - stream, - extensions, - ) = connection.handle_request( - method, - url, - headers=headers, - stream=stream, - extensions=extensions, - ) + stream = connect_response.extensions["network_stream"] + stream = stream.start_tls( + ssl_context=self._ssl_context, + server_hostname=self._remote_origin.host.decode("ascii"), + timeout=timeout, + ) + self._connection = HTTP11Connection( + origin=self._remote_origin, + stream=stream, + keepalive_expiry=self._keepalive_expiry, + ) + self._connected = True + return self._connection.handle_request(request) - wrapped_stream = ResponseByteStream( - stream, connection=connection, callback=self._response_closed - ) + def can_handle_request(self, origin: Origin) -> bool: + return origin == self._remote_origin + + def close(self) -> None: + self._connection.close() + + def info(self) -> str: + return self._connection.info() + + def is_available(self) -> bool: + return self._connection.is_available() + + def has_expired(self) -> bool: + return self._connection.has_expired() + + def is_idle(self) -> bool: + return self._connection.is_idle() + + def is_closed(self) -> bool: + return self._connection.is_closed() - return status_code, headers, wrapped_stream, extensions + def __repr__(self) -> str: + return f"<{self.__class__.__name__} [{self.info()}]>" diff --git a/httpcore/_sync/interfaces.py b/httpcore/_sync/interfaces.py new file mode 100644 index 00000000..831a07c5 --- /dev/null +++ b/httpcore/_sync/interfaces.py @@ -0,0 +1,133 @@ +from typing import Iterator, Union + +from contextlib import contextmanager +from .._models import ( + URL, + Origin, + Request, + Response, + enforce_bytes, + enforce_headers, + enforce_url, + include_request_headers, +) + + +class RequestInterface: + def request( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: Union[dict, list] = None, + content: Union[bytes, Iterator[bytes]] = None, + extensions: dict = None, + ) -> Response: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. + headers = include_request_headers(headers, url=url, content=content) + + request = Request( + method=method, + url=url, + headers=headers, + content=content, + extensions=extensions, + ) + response = self.handle_request(request) + try: + response.read() + finally: + response.close() + return response + + @contextmanager + def stream( + self, + method: Union[bytes, str], + url: Union[URL, bytes, str], + *, + headers: Union[dict, list] = None, + content: Union[bytes, Iterator[bytes]] = None, + extensions: dict = None, + ) -> Iterator[Response]: + # Strict type checking on our parameters. + method = enforce_bytes(method, name="method") + url = enforce_url(url, name="url") + headers = enforce_headers(headers, name="headers") + + # Include Host header, and optionally Content-Length or Transfer-Encoding. 
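+        # (A bytes body gets a Content-Length header; an iterator body is
+        # sent with Transfer-Encoding: chunked instead.)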
+        headers = include_request_headers(headers, url=url, content=content)
+
+        request = Request(
+            method=method,
+            url=url,
+            headers=headers,
+            content=content,
+            extensions=extensions,
+        )
+        response = self.handle_request(request)
+        try:
+            yield response
+        finally:
+            response.close()
+
+    def handle_request(self, request: Request) -> Response:
+        raise NotImplementedError()  # pragma: nocover
+
+
+class ConnectionInterface(RequestInterface):
+    def close(self) -> None:
+        raise NotImplementedError()  # pragma: nocover
+
+    def info(self) -> str:
+        raise NotImplementedError()  # pragma: nocover
+
+    def can_handle_request(self, origin: Origin) -> bool:
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_available(self) -> bool:
+        """
+        Return `True` if the connection is currently able to accept an
+        outgoing request.
+
+        An HTTP/1.1 connection will only be available if it is currently idle.
+
+        An HTTP/2 connection will be available so long as the stream ID space is
+        not yet exhausted, and the connection is not in an error state.
+
+        While the connection is being established we may not yet know if it is going
+        to result in an HTTP/1.1 or HTTP/2 connection. The connection should be
+        treated as being available, but might ultimately raise `ConnectionNotAvailable`
+        exceptions if multiple requests are attempted over a connection
+        that ends up being established as HTTP/1.1.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def has_expired(self) -> bool:
+        """
+        Return `True` if the connection is in a state where it should be closed.
+
+        This either means that the connection is idle and it has passed the
+        expiry time on its keep-alive, or that the server has sent an EOF.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_idle(self) -> bool:
+        """
+        Return `True` if the connection is currently idle.
+        """
+        raise NotImplementedError()  # pragma: nocover
+
+    def is_closed(self) -> bool:
+        """
+        Return `True` if the connection has been closed.
+
+        Used when a response is closed to determine if the connection may be
+        returned to the connection pool or not.
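+
+        Closed connections are discarded rather than returned to the pool.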
+ """ + raise NotImplementedError() # pragma: nocover diff --git a/httpcore/_synchronization.py b/httpcore/_synchronization.py new file mode 100644 index 00000000..e7cb1d60 --- /dev/null +++ b/httpcore/_synchronization.py @@ -0,0 +1,89 @@ +import threading +from types import TracebackType +from typing import Type + +import anyio + +from ._exceptions import PoolTimeout, map_exceptions + + +class AsyncLock: + def __init__(self) -> None: + self._lock = anyio.Lock() + + async def __aenter__(self) -> "AsyncLock": + await self._lock.acquire() + return self + + async def __aexit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self._lock.release() + + +class AsyncEvent: + def __init__(self) -> None: + self._event = anyio.Event() + + def set(self) -> None: + self._event.set() + + async def wait(self, timeout: float = None) -> None: + exc_map: dict = {TimeoutError: PoolTimeout} + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + await self._event.wait() + + +class AsyncSemaphore: + def __init__(self, bound: int) -> None: + self._semaphore = anyio.Semaphore(initial_value=bound, max_value=bound) + + async def acquire(self) -> None: + await self._semaphore.acquire() + + async def release(self) -> None: + self._semaphore.release() + + +class Lock: + def __init__(self) -> None: + self._lock = threading.Lock() + + def __enter__(self) -> "Lock": + self._lock.acquire() + return self + + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + self._lock.release() + + +class Event: + def __init__(self) -> None: + self._event = threading.Event() + + def set(self) -> None: + self._event.set() + + def wait(self, timeout: float = None) -> None: + if not self._event.wait(timeout=timeout): + raise PoolTimeout() # pragma: nocover + + +class Semaphore: + def __init__(self, bound: int) -> None: + self._semaphore = threading.Semaphore(value=bound) + + def acquire(self) -> None: + self._semaphore.acquire() + + def release(self) -> None: + self._semaphore.release() diff --git a/httpcore/_threadlock.py b/httpcore/_threadlock.py deleted file mode 100644 index 2ff2bc37..00000000 --- a/httpcore/_threadlock.py +++ /dev/null @@ -1,35 +0,0 @@ -import threading -from types import TracebackType -from typing import Type - - -class ThreadLock: - """ - Provides thread safety when used as a sync context manager, or a - no-op when used as an async context manager. 
- """ - - def __init__(self) -> None: - self.lock = threading.Lock() - - def __enter__(self) -> None: - self.lock.acquire() - - def __exit__( - self, - exc_type: Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - self.lock.release() - - async def __aenter__(self) -> None: - pass - - async def __aexit__( - self, - exc_type: Type[BaseException] = None, - exc_value: BaseException = None, - traceback: TracebackType = None, - ) -> None: - pass diff --git a/httpcore/_trace.py b/httpcore/_trace.py new file mode 100644 index 00000000..b6fb64ac --- /dev/null +++ b/httpcore/_trace.py @@ -0,0 +1,52 @@ +from types import TracebackType +from typing import Any, Type + +from ._models import Request + + +class Trace: + def __init__(self, name: str, request: Request, kwargs: dict = None) -> None: + self.name = name + self.trace = request.extensions.get("trace") + self.kwargs = kwargs or {} + self.return_value: Any = None + + def __enter__(self) -> "Trace": + if self.trace is not None: + info = self.kwargs + self.trace(f"{self.name}.started", info) + return self + + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + if self.trace is not None: + if exc_value is None: + info: dict = {"return_value": self.return_value} + self.trace(f"{self.name}.complete", info) + else: + info = {"exception": exc_value} + self.trace(f"{self.name}.failed", info) + + async def __aenter__(self) -> "Trace": + if self.trace is not None: + info = self.kwargs + await self.trace(f"{self.name}.started", info) + return self + + async def __aexit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + if self.trace is not None: + if exc_value is None: + info: dict = {"return_value": self.return_value} + await self.trace(f"{self.name}.complete", info) + else: + info = {"exception": exc_value} + await self.trace(f"{self.name}.failed", info) diff --git a/httpcore/_types.py b/httpcore/_types.py deleted file mode 100644 index 2f9eeba7..00000000 --- a/httpcore/_types.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Type definitions for type checking purposes. -""" - -from typing import List, Mapping, Optional, Tuple, TypeVar, Union - -T = TypeVar("T") -StrOrBytes = Union[str, bytes] -Origin = Tuple[bytes, bytes, int] -URL = Tuple[bytes, bytes, Optional[int], bytes] -Headers = List[Tuple[bytes, bytes]] -TimeoutDict = Mapping[str, Optional[float]] diff --git a/httpcore/_utils.py b/httpcore/_utils.py index 978b87a2..df5dea8f 100644 --- a/httpcore/_utils.py +++ b/httpcore/_utils.py @@ -1,83 +1,12 @@ -import itertools -import logging -import os import select import socket import sys import typing -from ._types import URL, Origin - -_LOGGER_INITIALIZED = False -TRACE_LOG_LEVEL = 5 -DEFAULT_PORTS = {b"http": 80, b"https": 443} - - -class Logger(logging.Logger): - # Stub for type checkers. - def trace(self, message: str, *args: typing.Any, **kwargs: typing.Any) -> None: - ... # pragma: nocover - - -def get_logger(name: str) -> Logger: - """ - Get a `logging.Logger` instance, and optionally - set up debug logging based on the HTTPCORE_LOG_LEVEL or HTTPX_LOG_LEVEL - environment variables. 
- """ - global _LOGGER_INITIALIZED - if not _LOGGER_INITIALIZED: - _LOGGER_INITIALIZED = True - logging.addLevelName(TRACE_LOG_LEVEL, "TRACE") - - log_level = os.environ.get( - "HTTPCORE_LOG_LEVEL", os.environ.get("HTTPX_LOG_LEVEL", "") - ).upper() - if log_level in ("DEBUG", "TRACE"): - logger = logging.getLogger("httpcore") - logger.setLevel(logging.DEBUG if log_level == "DEBUG" else TRACE_LOG_LEVEL) - handler = logging.StreamHandler(sys.stderr) - handler.setFormatter( - logging.Formatter( - fmt="%(levelname)s [%(asctime)s] %(name)s - %(message)s", - datefmt="%Y-%m-%d %H:%M:%S", - ) - ) - logger.addHandler(handler) - - logger = logging.getLogger(name) - - def trace(message: str, *args: typing.Any, **kwargs: typing.Any) -> None: - logger.log(TRACE_LOG_LEVEL, message, *args, **kwargs) - - logger.trace = trace # type: ignore - - return typing.cast(Logger, logger) - - -def url_to_origin(url: URL) -> Origin: - scheme, host, explicit_port = url[:3] - default_port = DEFAULT_PORTS[scheme] - port = default_port if explicit_port is None else explicit_port - return scheme, host, port - - -def origin_to_url_string(origin: Origin) -> str: - scheme, host, explicit_port = origin - port = f":{explicit_port}" if explicit_port != DEFAULT_PORTS[scheme] else "" - return f"{scheme.decode('ascii')}://{host.decode('ascii')}{port}" - - -def exponential_backoff(factor: float) -> typing.Iterator[float]: - yield 0 - for n in itertools.count(2): - yield factor * (2 ** (n - 2)) - def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool: """ Return whether a socket, as identifed by its file descriptor, is readable. - "A socket is readable" means that the read buffer isn't empty, i.e. that calling .recv() on it would immediately return some data. """ @@ -88,7 +17,7 @@ def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool: # descriptor, we treat it as being readable, as if it the next read operation # on it is ready to return the terminating `b""`. sock_fd = None if sock is None else sock.fileno() - if sock_fd is None or sock_fd < 0: + if sock_fd is None or sock_fd < 0: # pragma: nocover return True # The implementation below was stolen from: @@ -97,7 +26,9 @@ def is_socket_readable(sock: typing.Optional[socket.socket]) -> bool: # Use select.select on Windows, and when poll is unavailable and select.poll # everywhere else. (E.g. When eventlet is in use. 
See #327) - if sys.platform == "win32" or getattr(select, "poll", None) is None: + if ( + sys.platform == "win32" or getattr(select, "poll", None) is None + ): # pragma: nocover rready, _, _ = select.select([sock_fd], [], [], 0) return bool(rready) p = select.poll() diff --git a/httpcore/_backends/__init__.py b/httpcore/backends/__init__.py similarity index 100% rename from httpcore/_backends/__init__.py rename to httpcore/backends/__init__.py diff --git a/httpcore/backends/asyncio.py b/httpcore/backends/asyncio.py new file mode 100644 index 00000000..568a9400 --- /dev/null +++ b/httpcore/backends/asyncio.py @@ -0,0 +1,118 @@ +import ssl +import typing + +import anyio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import AsyncNetworkBackend, AsyncNetworkStream + + +class AsyncIOStream(AsyncNetworkStream): + def __init__(self, stream: anyio.abc.ByteStream) -> None: + self._stream = stream + + async def read(self, max_bytes: int, timeout: float = None) -> bytes: + exc_map = { + TimeoutError: ReadTimeout, + anyio.BrokenResourceError: ReadError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + try: + return await self._stream.receive(max_bytes=max_bytes) + except anyio.EndOfStream: # pragma: nocover + return b"" + + async def write(self, buffer: bytes, timeout: float = None) -> None: + if not buffer: + return + + exc_map = { + TimeoutError: WriteTimeout, + anyio.BrokenResourceError: WriteError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + await self._stream.send(item=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str = None, + timeout: float = None, + ) -> AsyncNetworkStream: + exc_map = { + TimeoutError: ConnectTimeout, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + ssl_stream = await anyio.streams.tls.TLSStream.wrap( + self._stream, + ssl_context=ssl_context, + hostname=server_hostname, + standard_compatible=False, + server_side=False, + ) + return AsyncIOStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object": + return self._stream.extra(anyio.streams.tls.TLSAttribute.ssl_object, None) + if info == "client_addr": + return self._stream.extra(anyio.abc.SocketAttribute.local_address, None) + if info == "server_addr": + return self._stream.extra(anyio.abc.SocketAttribute.remote_address, None) + if info == "socket": + return self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + if info == "is_readable": + sock = self._stream.extra(anyio.abc.SocketAttribute.raw_socket, None) + return is_socket_readable(sock) + return None + + +class AsyncIOBackend(AsyncNetworkBackend): + async def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> AsyncNetworkStream: + exc_map = { + TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_tcp( + remote_host=host, + remote_port=port, + local_host=local_address, + ) + return AsyncIOStream(stream) + + async def connect_unix_socket( + self, path: str, timeout: float = None + ) -> AsyncNetworkStream: # pragma: nocover + exc_map = { + 
TimeoutError: ConnectTimeout, + OSError: ConnectError, + anyio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with anyio.fail_after(timeout): + stream: anyio.abc.ByteStream = await anyio.connect_unix(path) + return AsyncIOStream(stream) + + async def sleep(self, seconds: float) -> None: + await anyio.sleep(seconds) # pragma: nocover diff --git a/httpcore/backends/auto.py b/httpcore/backends/auto.py new file mode 100644 index 00000000..d2a92de6 --- /dev/null +++ b/httpcore/backends/auto.py @@ -0,0 +1,35 @@ +import sniffio + +from .base import AsyncNetworkBackend, AsyncNetworkStream + + +class AutoBackend(AsyncNetworkBackend): + async def _init_backend(self) -> None: + if not (hasattr(self, "_backend")): + backend = sniffio.current_async_library() + if backend == "trio": + from .trio import TrioBackend + + self._backend: AsyncNetworkBackend = TrioBackend() + else: + from .asyncio import AsyncIOBackend + + self._backend = AsyncIOBackend() + + async def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> AsyncNetworkStream: + await self._init_backend() + return await self._backend.connect_tcp( + host, port, timeout=timeout, local_address=local_address + ) + + async def connect_unix_socket( + self, path: str, timeout: float = None + ) -> AsyncNetworkStream: # pragma: nocover + await self._init_backend() + return await self._backend.connect_unix_socket(path, timeout=timeout) + + async def sleep(self, seconds: float) -> None: # pragma: nocover + await self._init_backend() + return await self._backend.sleep(seconds) diff --git a/httpcore/backends/base.py b/httpcore/backends/base.py new file mode 100644 index 00000000..45136ffa --- /dev/null +++ b/httpcore/backends/base.py @@ -0,0 +1,75 @@ +import ssl +import time +import typing + + +class NetworkStream: + def read(self, max_bytes: int, timeout: float = None) -> bytes: + raise NotImplementedError() # pragma: nocover + + def write(self, buffer: bytes, timeout: float = None) -> None: + raise NotImplementedError() # pragma: nocover + + def close(self) -> None: + raise NotImplementedError() # pragma: nocover + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str = None, + timeout: float = None, + ) -> "NetworkStream": + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class NetworkBackend: + def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def connect_unix_socket(self, path: str, timeout: float = None) -> NetworkStream: + raise NotImplementedError() # pragma: nocover + + def sleep(self, seconds: float) -> None: + time.sleep(seconds) # pragma: nocover + + +class AsyncNetworkStream: + async def read(self, max_bytes: int, timeout: float = None) -> bytes: + raise NotImplementedError() # pragma: nocover + + async def write(self, buffer: bytes, timeout: float = None) -> None: + raise NotImplementedError() # pragma: nocover + + async def aclose(self) -> None: + raise NotImplementedError() # pragma: nocover + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str = None, + timeout: float = None, + ) -> "AsyncNetworkStream": + raise NotImplementedError() # pragma: nocover + + def get_extra_info(self, info: str) -> typing.Any: + return None # pragma: nocover + + +class AsyncNetworkBackend: + async def connect_tcp( + self, host: 
str, port: int, timeout: float = None, local_address: str = None + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def connect_unix_socket( + self, path: str, timeout: float = None + ) -> AsyncNetworkStream: + raise NotImplementedError() # pragma: nocover + + async def sleep(self, seconds: float) -> None: + raise NotImplementedError() # pragma: nocover diff --git a/httpcore/backends/mock.py b/httpcore/backends/mock.py new file mode 100644 index 00000000..06d9694a --- /dev/null +++ b/httpcore/backends/mock.py @@ -0,0 +1,105 @@ +import ssl +import typing + +from .base import AsyncNetworkBackend, AsyncNetworkStream, NetworkBackend, NetworkStream + + +class MockSSLObject: + def __init__(self, http2: bool): + self._http2 = http2 + + def selected_alpn_protocol(self) -> str: + return "h2" if self._http2 else "http/1.1" + + +class MockStream(NetworkStream): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + def read(self, max_bytes: int, timeout: float = None) -> bytes: + if not self._buffer: + return b"" + return self._buffer.pop(0) + + def write(self, buffer: bytes, timeout: float = None) -> None: + pass + + def close(self) -> None: + pass + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str = None, + timeout: float = None, + ) -> NetworkStream: + return self + + def get_extra_info(self, info: str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == "ssl_object" else None + + +class MockBackend(NetworkBackend): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def connect_unix_socket(self, path: str, timeout: float = None) -> NetworkStream: + return MockStream(list(self._buffer), http2=self._http2) + + def sleep(self, seconds: float) -> None: + pass + + +class AsyncMockStream(AsyncNetworkStream): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._original_buffer = buffer + self._current_buffer = list(self._original_buffer) + self._http2 = http2 + + async def read(self, max_bytes: int, timeout: float = None) -> bytes: + if not self._current_buffer: + self._current_buffer = list(self._original_buffer) + return self._current_buffer.pop(0) + + async def write(self, buffer: bytes, timeout: float = None) -> None: + pass + + async def aclose(self) -> None: + pass + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str = None, + timeout: float = None, + ) -> AsyncNetworkStream: + return self + + def get_extra_info(self, info: str) -> typing.Any: + return MockSSLObject(http2=self._http2) if info == "ssl_object" else None + + +class AsyncMockBackend(AsyncNetworkBackend): + def __init__(self, buffer: typing.List[bytes], http2: bool = False) -> None: + self._buffer = buffer + self._http2 = http2 + + async def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def connect_unix_socket( + self, path: str, timeout: float = None + ) -> AsyncNetworkStream: + return AsyncMockStream(list(self._buffer), http2=self._http2) + + async def sleep(self, seconds: float) -> None: + pass diff --git 
a/httpcore/backends/sync.py b/httpcore/backends/sync.py new file mode 100644 index 00000000..0c17aef7 --- /dev/null +++ b/httpcore/backends/sync.py @@ -0,0 +1,89 @@ +import socket +import ssl +import typing + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .._utils import is_socket_readable +from .base import NetworkBackend, NetworkStream + + +class SyncStream(NetworkStream): + def __init__(self, sock: socket.socket) -> None: + self._sock = sock + + def read(self, max_bytes: int, timeout: float = None) -> bytes: + exc_map = {socket.timeout: ReadTimeout, socket.error: ReadError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + return self._sock.recv(max_bytes) + + def write(self, buffer: bytes, timeout: float = None) -> None: + if not buffer: + return + + exc_map = {socket.timeout: WriteTimeout, socket.error: WriteError} + with map_exceptions(exc_map): + while buffer: + self._sock.settimeout(timeout) + n = self._sock.send(buffer) + buffer = buffer[n:] + + def close(self) -> None: + self._sock.close() + + def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str = None, + timeout: float = None, + ) -> NetworkStream: + exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError} + with map_exceptions(exc_map): + self._sock.settimeout(timeout) + sock = ssl_context.wrap_socket(self._sock, server_hostname=server_hostname) + return SyncStream(sock) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._sock, ssl.SSLSocket): + return self._sock._sslobj # type: ignore + if info == "client_addr": + return self._sock.getsockname() + if info == "server_addr": + return self._sock.getpeername() + if info == "socket": + return self._sock + if info == "is_readable": + return is_socket_readable(self._sock) + return None + + +class SyncBackend(NetworkBackend): + def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> NetworkStream: + address = (host, port) + source_address = None if local_address is None else (local_address, 0) + exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError} + with map_exceptions(exc_map): + sock = socket.create_connection( + address, timeout, source_address=source_address + ) + return SyncStream(sock) + + def connect_unix_socket( + self, path: str, timeout: float = None + ) -> NetworkStream: # pragma: nocover + exc_map = {socket.timeout: ConnectTimeout, socket.error: ConnectError} + with map_exceptions(exc_map): + sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + sock.settimeout(timeout) + sock.connect(path) + return SyncStream(sock) diff --git a/httpcore/backends/trio.py b/httpcore/backends/trio.py new file mode 100644 index 00000000..c2afd5cd --- /dev/null +++ b/httpcore/backends/trio.py @@ -0,0 +1,128 @@ +import ssl +import typing + +import trio + +from .._exceptions import ( + ConnectError, + ConnectTimeout, + ReadError, + ReadTimeout, + WriteError, + WriteTimeout, + map_exceptions, +) +from .base import AsyncNetworkBackend, AsyncNetworkStream + + +class TrioStream(AsyncNetworkStream): + def __init__(self, stream: trio.abc.Stream) -> None: + self._stream = stream + + async def read(self, max_bytes: int, timeout: float = None) -> bytes: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map = {trio.TooSlowError: ReadTimeout, trio.BrokenResourceError: ReadError} + with map_exceptions(exc_map): + 
with trio.fail_after(timeout_or_inf): + return await self._stream.receive_some(max_bytes=max_bytes) + + async def write(self, buffer: bytes, timeout: float = None) -> None: + if not buffer: + return + + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map = { + trio.TooSlowError: WriteTimeout, + trio.BrokenResourceError: WriteError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + await self._stream.send_all(data=buffer) + + async def aclose(self) -> None: + await self._stream.aclose() + + async def start_tls( + self, + ssl_context: ssl.SSLContext, + server_hostname: str = None, + timeout: float = None, + ) -> AsyncNetworkStream: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + } + ssl_stream = trio.SSLStream( + self._stream, + ssl_context=ssl_context, + server_hostname=server_hostname, + https_compatible=True, + server_side=False, + ) + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + await ssl_stream.do_handshake() + return TrioStream(ssl_stream) + + def get_extra_info(self, info: str) -> typing.Any: + if info == "ssl_object" and isinstance(self._stream, trio.SSLStream): + return self._stream._ssl_object # type: ignore + if info == "client_addr": + return self._get_socket_stream().socket.getsockname() + if info == "server_addr": + return self._get_socket_stream().socket.getpeername() + if info == "socket": + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream.socket + if info == "is_readable": + socket = self.get_extra_info("socket") + return socket.is_readable() + return None + + def _get_socket_stream(self) -> trio.SocketStream: + stream = self._stream + while isinstance(stream, trio.SSLStream): + stream = stream.transport_stream + assert isinstance(stream, trio.SocketStream) + return stream + + +class TrioBackend(AsyncNetworkBackend): + async def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> AsyncNetworkStream: + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + } + # Trio supports 'local_address' from 0.16.1 onwards. + # We only include the keyword argument if a local_address + # argument has been passed. + kwargs: dict = {} if local_address is None else {"local_address": local_address} + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_tcp_stream( + host=host, port=port, **kwargs + ) + return TrioStream(stream) + + async def connect_unix_socket( + self, path: str, timeout: float = None + ) -> AsyncNetworkStream: # pragma: nocover + timeout_or_inf = float("inf") if timeout is None else timeout + exc_map = { + trio.TooSlowError: ConnectTimeout, + trio.BrokenResourceError: ConnectError, + } + with map_exceptions(exc_map): + with trio.fail_after(timeout_or_inf): + stream: trio.abc.Stream = await trio.open_unix_socket(path) + return TrioStream(stream) + + async def sleep(self, seconds: float) -> None: + await trio.sleep(seconds) # pragma: nocover diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 00000000..bd3e933f --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,34 @@ +site_name: HTTPCore +site_description: A minimal HTTP client for Python. 
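+# The mkdocstrings plugin below renders the API reference pages from the
+# package's docstrings.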
+site_url: https://www.encode.io/httpcore/ + +repo_name: encode/httpcore +repo_url: https://github.com/encode/httpcore/ + +nav: + - Introduction: 'index.md' + - Quickstart: 'quickstart.md' + - Connection Pools: 'connection-pools.md' + - Proxies: 'proxies.md' + - HTTP/2: 'http2.md' + - Async Support: 'async.md' + - Extensions: 'extensions.md' + - Exceptions: 'exceptions.md' + +theme: + name: "material" + +plugins: + - search + - mkdocstrings: + default_handler: python + watch: + - httpcore + handlers: + python: + members_order: + - "source" + +markdown_extensions: + - codehilite: + css_class: highlight diff --git a/requirements.txt b/requirements.txt index d007703d..144dd248 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,15 +7,11 @@ curio==1.4; python_version < '3.7' curio==1.5; python_version >= '3.7' # Docs -# https://github.com/sphinx-doc/sphinx/issues/9505 -sphinx @ https://github.com/sphinx-doc/sphinx/archive/03bf83365eb7a5180e93a3ccd5c050f5da36c489.tar.gz; python_version >= '3.10' -sphinx==4.1.2; python_version < '3.10' -sphinx-autobuild==2021.3.14 -myst-parser==0.15.2 -furo==2021.10.9 -ghp-import==2.0.2 -# myst-parser + docutils==0.17 has a bug: https://github.com/executablebooks/MyST-Parser/issues/343 -docutils==0.17.1 +mkdocs==1.2.2 +mkdocs-autorefs==0.3.0 +mkdocs-material==7.3.0 +mkdocs-material-extensions==1.0.3 +mkdocstrings==0.16.1 # Packaging twine==3.4.2 @@ -34,7 +30,9 @@ isort==5.9.3 mypy==0.910 pproxy==2.7.8 pytest==6.2.5 +pytest-httpbin==1.0.0 pytest-trio==0.7.0 pytest-asyncio==0.15.1 trustme==0.9.0 +types-certifi==2021.10.8.0 uvicorn==0.12.1; python_version < '3.7' diff --git a/scripts/build b/scripts/build index 2192b90b..602eab0a 100755 --- a/scripts/build +++ b/scripts/build @@ -12,4 +12,4 @@ set -x ${PREFIX}python setup.py sdist bdist_wheel ${PREFIX}twine check dist/* -scripts/docs build +${PREFIX}mkdocs build diff --git a/scripts/docs b/scripts/docs deleted file mode 100755 index b4c9f971..00000000 --- a/scripts/docs +++ /dev/null @@ -1,23 +0,0 @@ -#!/bin/bash -e - -export PREFIX="" -if [ -d 'venv' ] ; then - export PREFIX="venv/bin/" -fi - -SOURCE_DIR="docs" -OUT_DIR="build/html" - -COMMAND="$1" -ARGS="${@:2}" - -set -x - -if [ "$COMMAND" = "build" ]; then - ${PREFIX}sphinx-build $SOURCE_DIR $OUT_DIR -elif [ "$COMMAND" = "gh-deploy" ]; then - scripts/docs build - ${PREFIX}ghp-import $OUT_DIR -np -m "Deployed $(git rev-parse --short HEAD)" $ARGS -else - ${PREFIX}sphinx-autobuild $SOURCE_DIR $OUT_DIR --watch httpcore/ $ARGS -fi diff --git a/setup.cfg b/setup.cfg index 0e23b062..110a6eec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,7 +1,7 @@ [flake8] ignore = W503, E203, B305 -max-line-length = 88 -exclude = httpcore/_sync,tests/sync_tests +max-line-length = 120 +exclude = httpcore/_sync,tests/_sync [mypy] disallow_untyped_defs = True @@ -17,7 +17,7 @@ profile = black combine_as_imports = True known_first_party = httpcore,tests known_third_party = brotli,certifi,chardet,cryptography,h11,h2,hstspreload,pytest,rfc3986,setuptools,sniffio,trio,trustme,urllib3,uvicorn -skip = httpcore/_sync/,tests/sync_tests/ +skip = httpcore/_sync/,tests/_sync [tool:pytest] addopts = -rxXs @@ -25,5 +25,5 @@ markers = copied_from(source, changes=None): mark test as copied from somewhere else, along with a description of changes made to accodomate e.g. 
our test setup [coverage:run] -omit = venv/* +omit = venv/*, httpcore/_sync/*, httpcore/_compat.py include = httpcore/*, tests/* diff --git a/tests/async_tests/__init__.py b/tests/_async/__init__.py similarity index 100% rename from tests/async_tests/__init__.py rename to tests/_async/__init__.py diff --git a/tests/_async/test_connection.py b/tests/_async/test_connection.py new file mode 100644 index 00000000..89ebb2a4 --- /dev/null +++ b/tests/_async/test_connection.py @@ -0,0 +1,187 @@ +import hpack +import hyperframe.frame +import pytest + +from httpcore import AsyncHTTPConnection, ConnectError, ConnectionNotAvailable, Origin +from httpcore.backends.base import AsyncNetworkStream +from httpcore.backends.mock import AsyncMockBackend + + +@pytest.mark.anyio +async def test_http_connection(): + origin = Origin(b"https", b"example.com", 443) + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend, keepalive_expiry=5.0 + ) as conn: + assert not conn.is_idle() + assert not conn.is_closed() + assert not conn.is_available() + assert not conn.has_expired() + assert repr(conn) == "" + + async with conn.stream("GET", "https://example.com/") as response: + assert ( + repr(conn) + == "" + ) + await response.aread() + + assert response.status == 200 + assert response.content == b"Hello, world!" + + assert conn.is_idle() + assert not conn.is_closed() + assert conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + +@pytest.mark.anyio +async def test_concurrent_requests_not_available_on_http11_connections(): + """ + Attempting to issue a request against an already active HTTP/1.1 connection + will raise a `ConnectionNotAvailable` exception. + """ + origin = Origin(b"https", b"example.com", 443) + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend, keepalive_expiry=5.0 + ) as conn: + async with conn.stream("GET", "https://example.com/"): + with pytest.raises(ConnectionNotAvailable): + await conn.request("GET", "https://example.com/") + + +@pytest.mark.anyio +async def test_http2_connection(): + origin = Origin(b"https", b"example.com", 443) + network_backend = AsyncMockBackend( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ], + http2=True, + ) + + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend, http2=True + ) as conn: + response = await conn.request("GET", "https://example.com/") + + assert response.status == 200 + assert response.content == b"Hello, world!" + assert response.extensions["http_version"] == b"HTTP/2" + + +@pytest.mark.anyio +async def test_request_to_incorrect_origin(): + """ + A connection can only send requests whichever origin it is connected to. 
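+    (Requests for any other origin raise a `RuntimeError`.)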
+ """ + origin = Origin(b"https", b"example.com", 443) + network_backend = AsyncMockBackend([]) + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend + ) as conn: + with pytest.raises(RuntimeError): + await conn.request("GET", "https://other.com/") + + +class NeedsRetryBackend(AsyncMockBackend): + def __init__(self, *args, **kwargs) -> None: + self._retry = 2 + super().__init__(*args, **kwargs) + + async def connect_tcp( + self, host: str, port: int, timeout: float = None, local_address: str = None + ) -> AsyncNetworkStream: + if self._retry > 0: + self._retry -= 1 + raise ConnectError() + + return await super().connect_tcp( + host, port, timeout=timeout, local_address=local_address + ) + + +@pytest.mark.anyio +async def test_connection_retries(): + origin = Origin(b"https", b"example.com", 443) + content = [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + + network_backend = NeedsRetryBackend(content) + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend, retries=3 + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + + network_backend = NeedsRetryBackend(content) + async with AsyncHTTPConnection( + origin=origin, + network_backend=network_backend, + ) as conn: + with pytest.raises(ConnectError): + await conn.request("GET", "https://example.com/") + + +@pytest.mark.anyio +async def test_uds_connections(): + # We're not actually testing Unix Domain Sockets here, because we're just + # using a mock backend, but at least we're covering the UDS codepath + # in `connection.py` which we may as well do. + origin = Origin(b"https", b"example.com", 443) + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + async with AsyncHTTPConnection( + origin=origin, network_backend=network_backend, uds="/mock/example" + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 diff --git a/tests/_async/test_connection_pool.py b/tests/_async/test_connection_pool.py new file mode 100644 index 00000000..ca4091cd --- /dev/null +++ b/tests/_async/test_connection_pool.py @@ -0,0 +1,399 @@ +from typing import List + +import pytest +import trio as concurrency + +from httpcore import AsyncConnectionPool, UnsupportedProtocol +from httpcore.backends.mock import AsyncMockBackend + + +@pytest.mark.anyio +async def test_connection_pool_with_keepalive(): + """ + By default HTTP/1.1 requests should be returned to the connection pool. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async with AsyncConnectionPool( + network_backend=network_backend, + ) as pool: + # Sending an intial request, which once complete will return to the pool, IDLE. + async with pool.stream("GET", "https://example.com/") as response: + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + await response.aread() + + assert response.status == 200 + assert response.content == b"Hello, world!" 
+ info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + # Sending a second request to the same origin will reuse the existing IDLE connection. + async with pool.stream("GET", "https://example.com/") as response: + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + await response.aread() + + assert response.status == 200 + assert response.content == b"Hello, world!" + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + + # Sending a request to a different origin will not reuse the existing IDLE connection. + async with pool.stream("GET", "http://example.com/") as response: + info = [repr(c) for c in pool.connections] + assert info == [ + "", + "", + ] + await response.aread() + + assert response.status == 200 + assert response.content == b"Hello, world!" + info = [repr(c) for c in pool.connections] + assert info == [ + "", + "", + ] + + +@pytest.mark.anyio +async def test_connection_pool_with_close(): + """ + HTTP/1.1 requests that include a 'Connection: Close' header should + not be returned to the connection pool. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async with AsyncConnectionPool(network_backend=network_backend) as pool: + # Sending an intial request, which once complete will not return to the pool. + async with pool.stream( + "GET", "https://example.com/", headers={"Connection": "close"} + ) as response: + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + await response.aread() + + assert response.status == 200 + assert response.content == b"Hello, world!" + info = [repr(c) for c in pool.connections] + assert info == [] + + +@pytest.mark.anyio +async def test_trace_request(): + """ + The 'trace' request extension allows for a callback function to inspect the + internal events that occur while sending a request. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + called = [] + + async def trace(name, kwargs): + called.append(name) + + async with AsyncConnectionPool(network_backend=network_backend) as pool: + await pool.request("GET", "https://example.com/", extensions={"trace": trace}) + + assert called == [ + "connection.connect_tcp.started", + "connection.connect_tcp.complete", + "connection.start_tls.started", + "connection.start_tls.complete", + "http11.send_request_headers.started", + "http11.send_request_headers.complete", + "http11.send_request_body.started", + "http11.send_request_body.complete", + "http11.receive_response_headers.started", + "http11.receive_response_headers.complete", + "http11.receive_response_body.started", + "http11.receive_response_body.complete", + "http11.response_closed.started", + "http11.response_closed.complete", + ] + + +@pytest.mark.anyio +async def test_connection_pool_with_exception(): + """ + HTTP/1.1 requests that result in an exception should not be returned to the + connection pool. + """ + network_backend = AsyncMockBackend([b"Wait, this isn't valid HTTP!"]) + + called = [] + + async def trace(name, kwargs): + called.append(name) + + async with AsyncConnectionPool(network_backend=network_backend) as pool: + # Sending an initial request, which once complete will not return to the pool. 
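+        # The backend replies with bytes that aren't valid HTTP, so reading
+        # the response fails and the connection is discarded.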
+ with pytest.raises(Exception): + await pool.request( + "GET", "https://example.com/", extensions={"trace": trace} + ) + + info = [repr(c) for c in pool.connections] + assert info == [] + + assert called == [ + "connection.connect_tcp.started", + "connection.connect_tcp.complete", + "connection.start_tls.started", + "connection.start_tls.complete", + "http11.send_request_headers.started", + "http11.send_request_headers.complete", + "http11.send_request_body.started", + "http11.send_request_body.complete", + "http11.receive_response_headers.started", + "http11.receive_response_headers.failed", + "http11.response_closed.started", + "http11.response_closed.complete", + ] + + +@pytest.mark.anyio +async def test_connection_pool_with_immediate_expiry(): + """ + Connection pools with keepalive_expiry=0.0 should immediately expire + keep alive connections. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async with AsyncConnectionPool( + keepalive_expiry=0.0, + network_backend=network_backend, + ) as pool: + # Sending an intial request, which once complete will not return to the pool. + async with pool.stream("GET", "https://example.com/") as response: + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + await response.aread() + + assert response.status == 200 + assert response.content == b"Hello, world!" + info = [repr(c) for c in pool.connections] + assert info == [] + + +@pytest.mark.anyio +async def test_connection_pool_with_no_keepalive_connections_allowed(): + """ + When 'max_keepalive_connections=0' is used, IDLE connections should not + be returned to the pool. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async with AsyncConnectionPool( + max_keepalive_connections=0, network_backend=network_backend + ) as pool: + # Sending an intial request, which once complete will not return to the pool. + async with pool.stream("GET", "https://example.com/") as response: + info = [repr(c) for c in pool.connections] + assert info == [ + "" + ] + await response.aread() + + assert response.status == 200 + assert response.content == b"Hello, world!" + info = [repr(c) for c in pool.connections] + assert info == [] + + +@pytest.mark.trio +async def test_connection_pool_concurrency(): + """ + HTTP/1.1 requests made in concurrency must not ever exceed the maximum number + of allowable connection in the pool. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async def fetch(pool, domain, info_list): + async with pool.stream("GET", f"http://{domain}/") as response: + info = [repr(c) for c in pool.connections] + info_list.append(info) + await response.aread() + + async with AsyncConnectionPool( + max_connections=1, network_backend=network_backend + ) as pool: + info_list: List[str] = [] + async with concurrency.open_nursery() as nursery: + for domain in ["a.com", "b.com", "c.com", "d.com", "e.com"]: + nursery.start_soon(fetch, pool, domain, info_list) + + # Check that each time we inspect the connection pool, only a + # single connection was established. 
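+        # (With max_connections=1 each request has to wait for the single
+        # connection slot to become available.)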
+ for item in info_list: + assert len(item) == 1 + assert item[0] in [ + "", + "", + "", + "", + "", + ] + + +@pytest.mark.trio +async def test_connection_pool_concurrency_same_domain_closing(): + """ + HTTP/1.1 requests made in concurrency must not ever exceed the maximum number + of allowable connection in the pool. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + + async def fetch(pool, domain, info_list): + async with pool.stream("GET", f"https://{domain}/") as response: + info = [repr(c) for c in pool.connections] + info_list.append(info) + await response.aread() + + async with AsyncConnectionPool( + max_connections=1, network_backend=network_backend, http2=True + ) as pool: + info_list: List[str] = [] + async with concurrency.open_nursery() as nursery: + for domain in ["a.com", "a.com", "a.com", "a.com", "a.com"]: + nursery.start_soon(fetch, pool, domain, info_list) + + # Check that each time we inspect the connection pool, only a + # single connection was established. + for item in info_list: + assert len(item) == 1 + assert item[0] in [ + "", + "", + "", + "", + "", + ] + + +@pytest.mark.trio +async def test_connection_pool_concurrency_same_domain_keepalive(): + """ + HTTP/1.1 requests made in concurrency must not ever exceed the maximum number + of allowable connection in the pool. + """ + network_backend = AsyncMockBackend( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + * 5 + ) + + async def fetch(pool, domain, info_list): + async with pool.stream("GET", f"https://{domain}/") as response: + info = [repr(c) for c in pool.connections] + info_list.append(info) + await response.aread() + + async with AsyncConnectionPool( + max_connections=1, network_backend=network_backend, http2=True + ) as pool: + info_list: List[str] = [] + async with concurrency.open_nursery() as nursery: + for domain in ["a.com", "a.com", "a.com", "a.com", "a.com"]: + nursery.start_soon(fetch, pool, domain, info_list) + + # Check that each time we inspect the connection pool, only a + # single connection was established. + for item in info_list: + assert len(item) == 1 + assert item[0] in [ + "", + "", + "", + "", + "", + ] + + +@pytest.mark.anyio +async def test_unsupported_protocol(): + async with AsyncConnectionPool() as pool: + with pytest.raises(UnsupportedProtocol): + await pool.request("GET", "ftp://www.example.com/") + + with pytest.raises(UnsupportedProtocol): + await pool.request("GET", "://www.example.com/") diff --git a/tests/_async/test_http11.py b/tests/_async/test_http11.py new file mode 100644 index 00000000..6798262d --- /dev/null +++ b/tests/_async/test_http11.py @@ -0,0 +1,179 @@ +import pytest + +from httpcore import ( + AsyncHTTP11Connection, + ConnectionNotAvailable, + LocalProtocolError, + Origin, + RemoteProtocolError, +) +from httpcore.backends.mock import AsyncMockStream + + +@pytest.mark.anyio +async def test_http11_connection(): + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + async with AsyncHTTP11Connection( + origin=origin, stream=stream, keepalive_expiry=5.0 + ) as conn: + response = await conn.request("GET", "https://example.com/") + assert response.status == 200 + assert response.content == b"Hello, world!" 
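+        # At this point the request has completed and the connection is back
+        # in an idle, reusable state, as the assertions below verify.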
+ + assert conn.is_idle() + assert not conn.is_closed() + assert conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + +@pytest.mark.anyio +async def test_http11_connection_unread_response(): + """ + If the client releases the response without reading it to termination, + then the connection will not be reusable. + """ + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + async with AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + async with conn.stream("GET", "https://example.com/") as response: + assert response.status == 200 + + assert not conn.is_idle() + assert conn.is_closed() + assert not conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + +@pytest.mark.anyio +async def test_http11_connection_with_remote_protocol_error(): + """ + If a remote protocol error occurs, then no response will be returned, + and the connection will not be reusable. + """ + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream([b"Wait, this isn't valid HTTP!", b""]) + async with AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + with pytest.raises(RemoteProtocolError): + await conn.request("GET", "https://example.com/") + + assert not conn.is_idle() + assert conn.is_closed() + assert not conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + +@pytest.mark.anyio +async def test_http11_connection_with_local_protocol_error(): + """ + If a local protocol error occurs, then no response will be returned, + and the connection will not be reusable. + """ + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + async with AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + with pytest.raises(LocalProtocolError) as exc_info: + await conn.request("GET", "https://example.com/", headers={"Host": "\0"}) + + assert str(exc_info.value) == "Illegal header value b'\\x00'" + + assert not conn.is_idle() + assert conn.is_closed() + assert not conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + +@pytest.mark.anyio +async def test_http11_connection_handles_one_active_request(): + """ + Attempting to send a request while one is already in-flight will raise + a ConnectionNotAvailable exception. + """ + origin = Origin(b"https", b"example.com", 443) + stream = AsyncMockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + async with AsyncHTTP11Connection(origin=origin, stream=stream) as conn: + async with conn.stream("GET", "https://example.com/"): + with pytest.raises(ConnectionNotAvailable): + await conn.request("GET", "https://example.com/") + + +@pytest.mark.anyio +async def test_http11_connection_attempt_close(): + """ + A connection can only be closed when it is idle. 
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+    async with AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
+        async with conn.stream("GET", "https://example.com/") as response:
+            await response.aread()
+            assert response.status == 200
+            assert response.content == b"Hello, world!"
+
+
+@pytest.mark.anyio
+async def test_http11_request_to_incorrect_origin():
+    """
+    A connection can only send requests to whichever origin it is connected to.
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream([])
+    async with AsyncHTTP11Connection(origin=origin, stream=stream) as conn:
+        with pytest.raises(RuntimeError):
+            await conn.request("GET", "https://other.com/")
diff --git a/tests/_async/test_http2.py b/tests/_async/test_http2.py
new file mode 100644
index 00000000..b40ee742
--- /dev/null
+++ b/tests/_async/test_http2.py
@@ -0,0 +1,233 @@
+import hpack
+import hyperframe.frame
+import pytest
+
+from httpcore import (
+    AsyncHTTP2Connection,
+    ConnectionNotAvailable,
+    Origin,
+    RemoteProtocolError,
+)
+from httpcore.backends.mock import AsyncMockStream
+
+
+@pytest.mark.anyio
+async def test_http2_connection():
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.DataFrame(
+                stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
+            ).serialize(),
+        ]
+    )
+    async with AsyncHTTP2Connection(
+        origin=origin, stream=stream, keepalive_expiry=5.0
+    ) as conn:
+        response = await conn.request("GET", "https://example.com/")
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+
+        assert conn.is_idle()
+        assert conn.is_available()
+        assert not conn.is_closed()
+        assert not conn.has_expired()
+        assert (
+            conn.info() == "'https://example.com:443', HTTP/2, IDLE, Request Count: 1"
+        )
+        assert (
+            repr(conn)
+            == ""
+        )
+
+
+@pytest.mark.anyio
+async def test_http2_connection_post_request():
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.DataFrame(
+                stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
+            ).serialize(),
+        ]
+    )
+    async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
+        response = await conn.request(
+            "POST",
+            "https://example.com/",
+            headers={b"content-length": b"18"},
+            content=b'{"data": "upload"}',
+        )
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+
+
+@pytest.mark.anyio
+async def test_http2_connection_with_remote_protocol_error():
+    """
+    If a remote protocol error occurs, then no response will be returned,
+    and the connection will not be reusable.
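+
+    Here the mock stream yields plaintext rather than a valid HTTP/2 frame
+    sequence, so frame parsing fails as soon as the response is read.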
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream([b"Wait, this isn't valid HTTP!", b""])
+    async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
+        with pytest.raises(RemoteProtocolError):
+            await conn.request("GET", "https://example.com/")
+
+
+@pytest.mark.anyio
+async def test_http2_connection_with_stream_cancelled():
+    """
+    If the remote server resets the stream while the response is being
+    received, then no response will be returned and the connection will
+    not be reusable.
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.RstStreamFrame(stream_id=1, error_code=8).serialize(),
+            b"",
+        ]
+    )
+    async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
+        with pytest.raises(RemoteProtocolError):
+            await conn.request("GET", "https://example.com/")
+
+
+@pytest.mark.anyio
+async def test_http2_connection_with_flow_control():
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            # Available flow: 65,535
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=0, window_increment=10_000
+            ).serialize(),
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=1, window_increment=10_000
+            ).serialize(),
+            # Available flow: 75,535
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=0, window_increment=10_000
+            ).serialize(),
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=1, window_increment=10_000
+            ).serialize(),
+            # Available flow: 85,535
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=0, window_increment=10_000
+            ).serialize(),
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=1, window_increment=10_000
+            ).serialize(),
+            # Available flow: 95,535
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=0, window_increment=10_000
+            ).serialize(),
+            hyperframe.frame.WindowUpdateFrame(
+                stream_id=1, window_increment=10_000
+            ).serialize(),
+            # Available flow: 105,535
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.DataFrame(
+                stream_id=1, data=b"100,000 bytes received", flags=["END_STREAM"]
+            ).serialize(),
+        ]
+    )
+    async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
+        response = await conn.request(
+            "POST",
+            "https://example.com/",
+            content=b"x" * 100_000,
+        )
+        assert response.status == 200
+        assert response.content == b"100,000 bytes received"
+
+
+@pytest.mark.anyio
+async def test_http2_connection_attempt_close():
+    """
+    A connection can only be closed when it is idle.
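+
+    Here the connection is closed once the response has been read; any
+    further request attempt then raises ConnectionNotAvailable.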
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.DataFrame(
+                stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
+            ).serialize(),
+        ]
+    )
+    async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
+        async with conn.stream("GET", "https://example.com/") as response:
+            await response.aread()
+            assert response.status == 200
+            assert response.content == b"Hello, world!"
+
+        await conn.aclose()
+        with pytest.raises(ConnectionNotAvailable):
+            await conn.request("GET", "https://example.com/")
+
+
+@pytest.mark.anyio
+async def test_http2_request_to_incorrect_origin():
+    """
+    A connection can only send requests to whichever origin it is connected to.
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = AsyncMockStream([])
+    async with AsyncHTTP2Connection(origin=origin, stream=stream) as conn:
+        with pytest.raises(ConnectionNotAvailable):
+            await conn.request("GET", "https://other.com/")
diff --git a/tests/_async/test_http_proxy.py b/tests/_async/test_http_proxy.py
new file mode 100644
index 00000000..bdd72891
--- /dev/null
+++ b/tests/_async/test_http_proxy.py
@@ -0,0 +1,133 @@
+import pytest
+
+from httpcore import AsyncHTTPProxy, Origin, ProxyError
+from httpcore.backends.mock import AsyncMockBackend
+
+
+@pytest.mark.anyio
+async def test_proxy_forwarding():
+    """
+    Send an HTTP request via a proxy.
+    """
+    network_backend = AsyncMockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    async with AsyncHTTPProxy(
+        proxy_url="http://localhost:8080/",
+        max_connections=10,
+        network_backend=network_backend,
+    ) as proxy:
+        # Sending an initial request, which once complete will return to the pool, IDLE.
+        async with proxy.stream("GET", "http://example.com/") as response:
+            info = [repr(c) for c in proxy.connections]
+            assert info == [
+                ""
+            ]
+            await response.aread()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in proxy.connections]
+        assert info == [
+            ""
+        ]
+        assert proxy.connections[0].is_idle()
+        assert proxy.connections[0].is_available()
+        assert not proxy.connections[0].is_closed()
+
+        # A connection on a forwarding proxy can handle HTTP requests to any host.
+        assert proxy.connections[0].can_handle_request(
+            Origin(b"http", b"example.com", 80)
+        )
+        assert proxy.connections[0].can_handle_request(
+            Origin(b"http", b"other.com", 80)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"https", b"example.com", 443)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"https", b"other.com", 443)
+        )
+
+
+@pytest.mark.anyio
+async def test_proxy_tunneling():
+    """
+    Send an HTTPS request via a proxy.
+    """
+    network_backend = AsyncMockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n" b"\r\n",
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    async with AsyncHTTPProxy(
+        proxy_url="http://localhost:8080/",
+        max_connections=10,
+        network_backend=network_backend,
+    ) as proxy:
+        # Sending an initial request, which once complete will return to the pool, IDLE.
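+        # The tunnel is set up in two steps against the mock backend: the
+        # first buffered "200 OK" is consumed by the CONNECT handshake with
+        # the proxy, and the remaining bytes form the tunnelled HTTPS response.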
+        async with proxy.stream("GET", "https://example.com/") as response:
+            info = [repr(c) for c in proxy.connections]
+            assert info == [
+                ""
+            ]
+            await response.aread()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in proxy.connections]
+        assert info == [
+            ""
+        ]
+        assert proxy.connections[0].is_idle()
+        assert proxy.connections[0].is_available()
+        assert not proxy.connections[0].is_closed()
+
+        # A connection on a tunneled proxy can only handle HTTPS requests to the same origin.
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"http", b"example.com", 80)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"http", b"other.com", 80)
+        )
+        assert proxy.connections[0].can_handle_request(
+            Origin(b"https", b"example.com", 443)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"https", b"other.com", 443)
+        )
+
+
+@pytest.mark.anyio
+async def test_proxy_tunneling_with_403():
+    """
+    Send an HTTPS request via a proxy, where the proxy denies the CONNECT
+    request with a 403 response.
+    """
+    network_backend = AsyncMockBackend(
+        [
+            b"HTTP/1.1 403 Permission Denied\r\n" b"\r\n",
+        ]
+    )
+
+    async with AsyncHTTPProxy(
+        proxy_url="http://localhost:8080/",
+        max_connections=10,
+        network_backend=network_backend,
+    ) as proxy:
+        with pytest.raises(ProxyError) as exc_info:
+            await proxy.request("GET", "https://example.com/")
+        assert str(exc_info.value) == "403 Permission Denied"
+        assert not proxy.connections
diff --git a/tests/_async/test_integration.py b/tests/_async/test_integration.py
new file mode 100644
index 00000000..ec7cb61b
--- /dev/null
+++ b/tests/_async/test_integration.py
@@ -0,0 +1,51 @@
+import ssl
+
+import pytest
+
+from httpcore import AsyncConnectionPool
+
+
+@pytest.mark.anyio
+async def test_request(httpbin):
+    async with AsyncConnectionPool() as pool:
+        response = await pool.request("GET", httpbin.url)
+        assert response.status == 200
+
+
+@pytest.mark.anyio
+async def test_ssl_request(httpbin_secure):
+    ssl_context = ssl.create_default_context()
+    ssl_context.check_hostname = False
+    ssl_context.verify_mode = ssl.CERT_NONE
+    async with AsyncConnectionPool(ssl_context=ssl_context) as pool:
+        response = await pool.request("GET", httpbin_secure.url)
+        assert response.status == 200
+
+
+@pytest.mark.anyio
+async def test_extra_info(httpbin_secure):
+    ssl_context = ssl.create_default_context()
+    ssl_context.check_hostname = False
+    ssl_context.verify_mode = ssl.CERT_NONE
+    async with AsyncConnectionPool(ssl_context=ssl_context) as pool:
+        async with pool.stream("GET", httpbin_secure.url) as response:
+            assert response.status == 200
+            stream = response.extensions["network_stream"]
+
+            ssl_object = stream.get_extra_info("ssl_object")
+            assert ssl_object.version() == "TLSv1.3"
+
+            local_addr = stream.get_extra_info("client_addr")
+            assert local_addr[0] == "127.0.0.1"
+
+            remote_addr = stream.get_extra_info("server_addr")
+            assert "https://%s:%d" % remote_addr == httpbin_secure.url
+
+            sock = stream.get_extra_info("socket")
+            assert hasattr(sock, "family")
+            assert hasattr(sock, "type")
+
+            invalid = stream.get_extra_info("invalid")
+            assert invalid is None
+
+            stream.get_extra_info("is_readable")
diff --git a/tests/backend_tests/__init__.py b/tests/_sync/__init__.py
similarity index 100%
rename from tests/backend_tests/__init__.py
rename to tests/_sync/__init__.py
diff --git a/tests/_sync/test_connection.py b/tests/_sync/test_connection.py
new file mode 100644
index 00000000..be91ebce
--- /dev/null
+++ b/tests/_sync/test_connection.py
@@ -0,0 +1,187 @@
+import hpack
+import hyperframe.frame
+import pytest
+
+from httpcore import HTTPConnection, ConnectError, ConnectionNotAvailable, Origin
+from httpcore.backends.base import NetworkStream
+from httpcore.backends.mock import MockBackend
+
+
+def test_http_connection():
+    origin = Origin(b"https", b"example.com", 443)
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with HTTPConnection(
+        origin=origin, network_backend=network_backend, keepalive_expiry=5.0
+    ) as conn:
+        assert not conn.is_idle()
+        assert not conn.is_closed()
+        assert not conn.is_available()
+        assert not conn.has_expired()
+        assert repr(conn) == ""
+
+        with conn.stream("GET", "https://example.com/") as response:
+            assert (
+                repr(conn)
+                == ""
+            )
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+
+        assert conn.is_idle()
+        assert not conn.is_closed()
+        assert conn.is_available()
+        assert not conn.has_expired()
+        assert (
+            repr(conn)
+            == ""
+        )
+
+
+def test_concurrent_requests_not_available_on_http11_connections():
+    """
+    Attempting to issue a request against an already active HTTP/1.1 connection
+    will raise a `ConnectionNotAvailable` exception.
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with HTTPConnection(
+        origin=origin, network_backend=network_backend, keepalive_expiry=5.0
+    ) as conn:
+        with conn.stream("GET", "https://example.com/"):
+            with pytest.raises(ConnectionNotAvailable):
+                conn.request("GET", "https://example.com/")
+
+
+def test_http2_connection():
+    origin = Origin(b"https", b"example.com", 443)
+    network_backend = MockBackend(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.DataFrame(
+                stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
+            ).serialize(),
+        ],
+        http2=True,
+    )
+
+    with HTTPConnection(
+        origin=origin, network_backend=network_backend, http2=True
+    ) as conn:
+        response = conn.request("GET", "https://example.com/")
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        assert response.extensions["http_version"] == b"HTTP/2"
+
+
+def test_request_to_incorrect_origin():
+    """
+    A connection can only send requests to whichever origin it is connected to.
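+
+    Requests to any other origin raise a RuntimeError.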
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    network_backend = MockBackend([])
+    with HTTPConnection(
+        origin=origin, network_backend=network_backend
+    ) as conn:
+        with pytest.raises(RuntimeError):
+            conn.request("GET", "https://other.com/")
+
+
+class NeedsRetryBackend(MockBackend):
+    def __init__(self, *args, **kwargs) -> None:
+        self._retry = 2
+        super().__init__(*args, **kwargs)
+
+    def connect_tcp(
+        self, host: str, port: int, timeout: float = None, local_address: str = None
+    ) -> NetworkStream:
+        if self._retry > 0:
+            self._retry -= 1
+            raise ConnectError()
+
+        return super().connect_tcp(
+            host, port, timeout=timeout, local_address=local_address
+        )
+
+
+def test_connection_retries():
+    origin = Origin(b"https", b"example.com", 443)
+    content = [
+        b"HTTP/1.1 200 OK\r\n",
+        b"Content-Type: plain/text\r\n",
+        b"Content-Length: 13\r\n",
+        b"\r\n",
+        b"Hello, world!",
+    ]
+
+    network_backend = NeedsRetryBackend(content)
+    with HTTPConnection(
+        origin=origin, network_backend=network_backend, retries=3
+    ) as conn:
+        response = conn.request("GET", "https://example.com/")
+        assert response.status == 200
+
+    network_backend = NeedsRetryBackend(content)
+    with HTTPConnection(
+        origin=origin,
+        network_backend=network_backend,
+    ) as conn:
+        with pytest.raises(ConnectError):
+            conn.request("GET", "https://example.com/")
+
+
+def test_uds_connections():
+    # We're not actually testing Unix Domain Sockets here, because we're just
+    # using a mock backend, but at least we're covering the UDS codepath
+    # in `connection.py` which we may as well do.
+    origin = Origin(b"https", b"example.com", 443)
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+    with HTTPConnection(
+        origin=origin, network_backend=network_backend, uds="/mock/example"
+    ) as conn:
+        response = conn.request("GET", "https://example.com/")
+        assert response.status == 200
diff --git a/tests/_sync/test_connection_pool.py b/tests/_sync/test_connection_pool.py
new file mode 100644
index 00000000..0d2e595d
--- /dev/null
+++ b/tests/_sync/test_connection_pool.py
@@ -0,0 +1,399 @@
+from typing import List
+
+import pytest
+from tests import concurrency
+
+from httpcore import ConnectionPool, UnsupportedProtocol
+from httpcore.backends.mock import MockBackend
+
+
+def test_connection_pool_with_keepalive():
+    """
+    By default, HTTP/1.1 connections should be returned to the connection
+    pool once the request completes.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with ConnectionPool(
+        network_backend=network_backend,
+    ) as pool:
+        # Sending an initial request, which once complete will return to the pool, IDLE.
+        with pool.stream("GET", "https://example.com/") as response:
+            info = [repr(c) for c in pool.connections]
+            assert info == [
+                ""
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in pool.connections]
+        assert info == [
+            ""
+        ]
+
+        # Sending a second request to the same origin will reuse the existing IDLE connection.
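+        # (Reusing the connection means pool.connections still holds a single
+        # entry; a new connection is only opened for a different origin, as
+        # demonstrated further below.)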
+        with pool.stream("GET", "https://example.com/") as response:
+            info = [repr(c) for c in pool.connections]
+            assert info == [
+                ""
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in pool.connections]
+        assert info == [
+            ""
+        ]
+
+        # Sending a request to a different origin will not reuse the existing IDLE connection.
+        with pool.stream("GET", "http://example.com/") as response:
+            info = [repr(c) for c in pool.connections]
+            assert info == [
+                "",
+                "",
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in pool.connections]
+        assert info == [
+            "",
+            "",
+        ]
+
+
+def test_connection_pool_with_close():
+    """
+    HTTP/1.1 connections for requests that include a 'Connection: close'
+    header should not be returned to the connection pool.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with ConnectionPool(network_backend=network_backend) as pool:
+        # Sending an initial request, which once complete will not return to the pool.
+        with pool.stream(
+            "GET", "https://example.com/", headers={"Connection": "close"}
+        ) as response:
+            info = [repr(c) for c in pool.connections]
+            assert info == [
+                ""
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in pool.connections]
+        assert info == []
+
+
+def test_trace_request():
+    """
+    The 'trace' request extension allows a callback function to inspect the
+    internal events that occur while sending a request.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    called = []
+
+    def trace(name, kwargs):
+        called.append(name)
+
+    with ConnectionPool(network_backend=network_backend) as pool:
+        pool.request("GET", "https://example.com/", extensions={"trace": trace})
+
+    assert called == [
+        "connection.connect_tcp.started",
+        "connection.connect_tcp.complete",
+        "connection.start_tls.started",
+        "connection.start_tls.complete",
+        "http11.send_request_headers.started",
+        "http11.send_request_headers.complete",
+        "http11.send_request_body.started",
+        "http11.send_request_body.complete",
+        "http11.receive_response_headers.started",
+        "http11.receive_response_headers.complete",
+        "http11.receive_response_body.started",
+        "http11.receive_response_body.complete",
+        "http11.response_closed.started",
+        "http11.response_closed.complete",
+    ]
+
+
+def test_connection_pool_with_exception():
+    """
+    Connections for HTTP/1.1 requests that result in an exception should not
+    be returned to the connection pool.
+    """
+    network_backend = MockBackend([b"Wait, this isn't valid HTTP!"])
+
+    called = []
+
+    def trace(name, kwargs):
+        called.append(name)
+
+    with ConnectionPool(network_backend=network_backend) as pool:
+        # Sending an initial request, which once complete will not return to the pool.
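+        # The mock backend replies with bytes that cannot be parsed as an
+        # HTTP response, so the request raises, and the trace events below
+        # show "receive_response_headers.failed" rather than ".complete".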
+        with pytest.raises(Exception):
+            pool.request(
+                "GET", "https://example.com/", extensions={"trace": trace}
+            )
+
+        info = [repr(c) for c in pool.connections]
+        assert info == []
+
+    assert called == [
+        "connection.connect_tcp.started",
+        "connection.connect_tcp.complete",
+        "connection.start_tls.started",
+        "connection.start_tls.complete",
+        "http11.send_request_headers.started",
+        "http11.send_request_headers.complete",
+        "http11.send_request_body.started",
+        "http11.send_request_body.complete",
+        "http11.receive_response_headers.started",
+        "http11.receive_response_headers.failed",
+        "http11.response_closed.started",
+        "http11.response_closed.complete",
+    ]
+
+
+def test_connection_pool_with_immediate_expiry():
+    """
+    Connection pools with keepalive_expiry=0.0 should immediately expire
+    keep-alive connections.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with ConnectionPool(
+        keepalive_expiry=0.0,
+        network_backend=network_backend,
+    ) as pool:
+        # Sending an initial request, which once complete will not return to the pool.
+        with pool.stream("GET", "https://example.com/") as response:
+            info = [repr(c) for c in pool.connections]
+            assert info == [
+                ""
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in pool.connections]
+        assert info == []
+
+
+def test_connection_pool_with_no_keepalive_connections_allowed():
+    """
+    When 'max_keepalive_connections=0' is used, IDLE connections should not
+    be returned to the pool.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with ConnectionPool(
+        max_keepalive_connections=0, network_backend=network_backend
+    ) as pool:
+        # Sending an initial request, which once complete will not return to the pool.
+        with pool.stream("GET", "https://example.com/") as response:
+            info = [repr(c) for c in pool.connections]
+            assert info == [
+                ""
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in pool.connections]
+        assert info == []
+
+
+def test_connection_pool_concurrency():
+    """
+    HTTP/1.1 requests made concurrently must never exceed the maximum number
+    of allowable connections in the pool.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    def fetch(pool, domain, info_list):
+        with pool.stream("GET", f"http://{domain}/") as response:
+            info = [repr(c) for c in pool.connections]
+            info_list.append(info)
+            response.read()
+
+    with ConnectionPool(
+        max_connections=1, network_backend=network_backend
+    ) as pool:
+        info_list: List[List[str]] = []
+        with concurrency.open_nursery() as nursery:
+            for domain in ["a.com", "b.com", "c.com", "d.com", "e.com"]:
+                nursery.start_soon(fetch, pool, domain, info_list)
+
+        # Check that each time we inspect the connection pool, only a
+        # single connection was established.
+        for item in info_list:
+            assert len(item) == 1
+            assert item[0] in [
+                "",
+                "",
+                "",
+                "",
+                "",
+            ]
+
+
+def test_connection_pool_concurrency_same_domain_closing():
+    """
+    HTTP/1.1 requests made concurrently must never exceed the maximum number
+    of allowable connections in the pool.
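+
+    Here all five fetches target the same origin; each mocked connection
+    serves a single response and is then closed, so every request ends up
+    opening a fresh connection.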
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    def fetch(pool, domain, info_list):
+        with pool.stream("GET", f"https://{domain}/") as response:
+            info = [repr(c) for c in pool.connections]
+            info_list.append(info)
+            response.read()
+
+    with ConnectionPool(
+        max_connections=1, network_backend=network_backend, http2=True
+    ) as pool:
+        info_list: List[List[str]] = []
+        with concurrency.open_nursery() as nursery:
+            for domain in ["a.com", "a.com", "a.com", "a.com", "a.com"]:
+                nursery.start_soon(fetch, pool, domain, info_list)
+
+        # Check that each time we inspect the connection pool, only a
+        # single connection was established.
+        for item in info_list:
+            assert len(item) == 1
+            assert item[0] in [
+                "",
+                "",
+                "",
+                "",
+                "",
+            ]
+
+
+def test_connection_pool_concurrency_same_domain_keepalive():
+    """
+    HTTP/1.1 requests made concurrently must never exceed the maximum number
+    of allowable connections in the pool.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+        * 5
+    )
+
+    def fetch(pool, domain, info_list):
+        with pool.stream("GET", f"https://{domain}/") as response:
+            info = [repr(c) for c in pool.connections]
+            info_list.append(info)
+            response.read()
+
+    with ConnectionPool(
+        max_connections=1, network_backend=network_backend, http2=True
+    ) as pool:
+        info_list: List[List[str]] = []
+        with concurrency.open_nursery() as nursery:
+            for domain in ["a.com", "a.com", "a.com", "a.com", "a.com"]:
+                nursery.start_soon(fetch, pool, domain, info_list)
+
+        # Check that each time we inspect the connection pool, only a
+        # single connection was established.
+        for item in info_list:
+            assert len(item) == 1
+            assert item[0] in [
+                "",
+                "",
+                "",
+                "",
+                "",
+            ]
+
+
+def test_unsupported_protocol():
+    with ConnectionPool() as pool:
+        with pytest.raises(UnsupportedProtocol):
+            pool.request("GET", "ftp://www.example.com/")
+
+        with pytest.raises(UnsupportedProtocol):
+            pool.request("GET", "://www.example.com/")
diff --git a/tests/_sync/test_http11.py b/tests/_sync/test_http11.py
new file mode 100644
index 00000000..6ad7814c
--- /dev/null
+++ b/tests/_sync/test_http11.py
@@ -0,0 +1,179 @@
+import pytest
+
+from httpcore import (
+    HTTP11Connection,
+    ConnectionNotAvailable,
+    LocalProtocolError,
+    Origin,
+    RemoteProtocolError,
+)
+from httpcore.backends.mock import MockStream
+
+
+def test_http11_connection():
+    origin = Origin(b"https", b"example.com", 443)
+    stream = MockStream(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+    with HTTP11Connection(
+        origin=origin, stream=stream, keepalive_expiry=5.0
+    ) as conn:
+        response = conn.request("GET", "https://example.com/")
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+
+        assert conn.is_idle()
+        assert not conn.is_closed()
+        assert conn.is_available()
+        assert not conn.has_expired()
+        assert (
+            repr(conn)
+            == ""
+        )
+
+
+def test_http11_connection_unread_response():
+    """
+    If the client releases the response without reading it to termination,
+    then the connection will not be reusable.
+ """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + with HTTP11Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/") as response: + assert response.status == 200 + + assert not conn.is_idle() + assert conn.is_closed() + assert not conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + + +def test_http11_connection_with_remote_protocol_error(): + """ + If a remote protocol error occurs, then no response will be returned, + and the connection will not be reusable. + """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream([b"Wait, this isn't valid HTTP!", b""]) + with HTTP11Connection(origin=origin, stream=stream) as conn: + with pytest.raises(RemoteProtocolError): + conn.request("GET", "https://example.com/") + + assert not conn.is_idle() + assert conn.is_closed() + assert not conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + + +def test_http11_connection_with_local_protocol_error(): + """ + If a local protocol error occurs, then no response will be returned, + and the connection will not be reusable. + """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + with HTTP11Connection(origin=origin, stream=stream) as conn: + with pytest.raises(LocalProtocolError) as exc_info: + conn.request("GET", "https://example.com/", headers={"Host": "\0"}) + + assert str(exc_info.value) == "Illegal header value b'\\x00'" + + assert not conn.is_idle() + assert conn.is_closed() + assert not conn.is_available() + assert not conn.has_expired() + assert ( + repr(conn) + == "" + ) + + + +def test_http11_connection_handles_one_active_request(): + """ + Attempting to send a request while one is already in-flight will raise + a ConnectionNotAvailable exception. + """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + with HTTP11Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/"): + with pytest.raises(ConnectionNotAvailable): + conn.request("GET", "https://example.com/") + + + +def test_http11_connection_attempt_close(): + """ + A connection can only be closed when it is idle. + """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + b"HTTP/1.1 200 OK\r\n", + b"Content-Type: plain/text\r\n", + b"Content-Length: 13\r\n", + b"\r\n", + b"Hello, world!", + ] + ) + with HTTP11Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/") as response: + response.read() + assert response.status == 200 + assert response.content == b"Hello, world!" + + + +def test_http11_request_to_incorrect_origin(): + """ + A connection can only send requests to whichever origin it is connected to. 
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = MockStream([])
+    with HTTP11Connection(origin=origin, stream=stream) as conn:
+        with pytest.raises(RuntimeError):
+            conn.request("GET", "https://other.com/")
diff --git a/tests/_sync/test_http2.py b/tests/_sync/test_http2.py
new file mode 100644
index 00000000..062e68d7
--- /dev/null
+++ b/tests/_sync/test_http2.py
@@ -0,0 +1,233 @@
+import hpack
+import hyperframe.frame
+import pytest
+
+from httpcore import (
+    HTTP2Connection,
+    ConnectionNotAvailable,
+    Origin,
+    RemoteProtocolError,
+)
+from httpcore.backends.mock import MockStream
+
+
+def test_http2_connection():
+    origin = Origin(b"https", b"example.com", 443)
+    stream = MockStream(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.DataFrame(
+                stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
+            ).serialize(),
+        ]
+    )
+    with HTTP2Connection(
+        origin=origin, stream=stream, keepalive_expiry=5.0
+    ) as conn:
+        response = conn.request("GET", "https://example.com/")
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+
+        assert conn.is_idle()
+        assert conn.is_available()
+        assert not conn.is_closed()
+        assert not conn.has_expired()
+        assert (
+            conn.info() == "'https://example.com:443', HTTP/2, IDLE, Request Count: 1"
+        )
+        assert (
+            repr(conn)
+            == ""
+        )
+
+
+def test_http2_connection_post_request():
+    origin = Origin(b"https", b"example.com", 443)
+    stream = MockStream(
+        [
+            hyperframe.frame.SettingsFrame().serialize(),
+            hyperframe.frame.HeadersFrame(
+                stream_id=1,
+                data=hpack.Encoder().encode(
+                    [
+                        (b":status", b"200"),
+                        (b"content-type", b"plain/text"),
+                    ]
+                ),
+                flags=["END_HEADERS"],
+            ).serialize(),
+            hyperframe.frame.DataFrame(
+                stream_id=1, data=b"Hello, world!", flags=["END_STREAM"]
+            ).serialize(),
+        ]
+    )
+    with HTTP2Connection(origin=origin, stream=stream) as conn:
+        response = conn.request(
+            "POST",
+            "https://example.com/",
+            headers={b"content-length": b"18"},
+            content=b'{"data": "upload"}',
+        )
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+
+
+def test_http2_connection_with_remote_protocol_error():
+    """
+    If a remote protocol error occurs, then no response will be returned,
+    and the connection will not be reusable.
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = MockStream([b"Wait, this isn't valid HTTP!", b""])
+    with HTTP2Connection(origin=origin, stream=stream) as conn:
+        with pytest.raises(RemoteProtocolError):
+            conn.request("GET", "https://example.com/")
+
+
+def test_http2_connection_with_stream_cancelled():
+    """
+    If the remote server resets the stream while the response is being
+    received, then no response will be returned and the connection will
+    not be reusable.
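+
+    (The mock stream replies with an RST_STREAM frame using error code 8,
+    i.e. CANCEL.)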
+ """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.RstStreamFrame(stream_id=1, error_code=8).serialize(), + b"", + ] + ) + with HTTP2Connection(origin=origin, stream=stream) as conn: + with pytest.raises(RemoteProtocolError): + conn.request("GET", "https://example.com/") + + + +def test_http2_connection_with_flow_control(): + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + # Available flow: 65,535 + hyperframe.frame.WindowUpdateFrame( + stream_id=0, window_increment=10_000 + ).serialize(), + hyperframe.frame.WindowUpdateFrame( + stream_id=1, window_increment=10_000 + ).serialize(), + # Available flow: 75,535 + hyperframe.frame.WindowUpdateFrame( + stream_id=0, window_increment=10_000 + ).serialize(), + hyperframe.frame.WindowUpdateFrame( + stream_id=1, window_increment=10_000 + ).serialize(), + # Available flow: 85,535 + hyperframe.frame.WindowUpdateFrame( + stream_id=0, window_increment=10_000 + ).serialize(), + hyperframe.frame.WindowUpdateFrame( + stream_id=1, window_increment=10_000 + ).serialize(), + # Available flow: 95,535 + hyperframe.frame.WindowUpdateFrame( + stream_id=0, window_increment=10_000 + ).serialize(), + hyperframe.frame.WindowUpdateFrame( + stream_id=1, window_increment=10_000 + ).serialize(), + # Available flow: 105,535 + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"100,000 bytes received", flags=["END_STREAM"] + ).serialize(), + ] + ) + with HTTP2Connection(origin=origin, stream=stream) as conn: + response = conn.request( + "POST", + "https://example.com/", + content=b"x" * 100_000, + ) + assert response.status == 200 + assert response.content == b"100,000 bytes received" + + + +def test_http2_connection_attempt_close(): + """ + A connection can only be closed when it is idle. + """ + origin = Origin(b"https", b"example.com", 443) + stream = MockStream( + [ + hyperframe.frame.SettingsFrame().serialize(), + hyperframe.frame.HeadersFrame( + stream_id=1, + data=hpack.Encoder().encode( + [ + (b":status", b"200"), + (b"content-type", b"plain/text"), + ] + ), + flags=["END_HEADERS"], + ).serialize(), + hyperframe.frame.DataFrame( + stream_id=1, data=b"Hello, world!", flags=["END_STREAM"] + ).serialize(), + ] + ) + with HTTP2Connection(origin=origin, stream=stream) as conn: + with conn.stream("GET", "https://example.com/") as response: + response.read() + assert response.status == 200 + assert response.content == b"Hello, world!" + + conn.close() + with pytest.raises(ConnectionNotAvailable): + conn.request("GET", "https://example.com/") + + + +def test_http2_request_to_incorrect_origin(): + """ + A connection can only send requests to whichever origin it is connected to. 
+    """
+    origin = Origin(b"https", b"example.com", 443)
+    stream = MockStream([])
+    with HTTP2Connection(origin=origin, stream=stream) as conn:
+        with pytest.raises(ConnectionNotAvailable):
+            conn.request("GET", "https://other.com/")
diff --git a/tests/_sync/test_http_proxy.py b/tests/_sync/test_http_proxy.py
new file mode 100644
index 00000000..28ea9a21
--- /dev/null
+++ b/tests/_sync/test_http_proxy.py
@@ -0,0 +1,133 @@
+import pytest
+
+from httpcore import HTTPProxy, Origin, ProxyError
+from httpcore.backends.mock import MockBackend
+
+
+def test_proxy_forwarding():
+    """
+    Send an HTTP request via a proxy.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with HTTPProxy(
+        proxy_url="http://localhost:8080/",
+        max_connections=10,
+        network_backend=network_backend,
+    ) as proxy:
+        # Sending an initial request, which once complete will return to the pool, IDLE.
+        with proxy.stream("GET", "http://example.com/") as response:
+            info = [repr(c) for c in proxy.connections]
+            assert info == [
+                ""
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in proxy.connections]
+        assert info == [
+            ""
+        ]
+        assert proxy.connections[0].is_idle()
+        assert proxy.connections[0].is_available()
+        assert not proxy.connections[0].is_closed()
+
+        # A connection on a forwarding proxy can handle HTTP requests to any host.
+        assert proxy.connections[0].can_handle_request(
+            Origin(b"http", b"example.com", 80)
+        )
+        assert proxy.connections[0].can_handle_request(
+            Origin(b"http", b"other.com", 80)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"https", b"example.com", 443)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"https", b"other.com", 443)
+        )
+
+
+def test_proxy_tunneling():
+    """
+    Send an HTTPS request via a proxy.
+    """
+    network_backend = MockBackend(
+        [
+            b"HTTP/1.1 200 OK\r\n" b"\r\n",
+            b"HTTP/1.1 200 OK\r\n",
+            b"Content-Type: plain/text\r\n",
+            b"Content-Length: 13\r\n",
+            b"\r\n",
+            b"Hello, world!",
+        ]
+    )
+
+    with HTTPProxy(
+        proxy_url="http://localhost:8080/",
+        max_connections=10,
+        network_backend=network_backend,
+    ) as proxy:
+        # Sending an initial request, which once complete will return to the pool, IDLE.
+        with proxy.stream("GET", "https://example.com/") as response:
+            info = [repr(c) for c in proxy.connections]
+            assert info == [
+                ""
+            ]
+            response.read()
+
+        assert response.status == 200
+        assert response.content == b"Hello, world!"
+        info = [repr(c) for c in proxy.connections]
+        assert info == [
+            ""
+        ]
+        assert proxy.connections[0].is_idle()
+        assert proxy.connections[0].is_available()
+        assert not proxy.connections[0].is_closed()
+
+        # A connection on a tunneled proxy can only handle HTTPS requests to the same origin.
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"http", b"example.com", 80)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"http", b"other.com", 80)
+        )
+        assert proxy.connections[0].can_handle_request(
+            Origin(b"https", b"example.com", 443)
+        )
+        assert not proxy.connections[0].can_handle_request(
+            Origin(b"https", b"other.com", 443)
+        )
+
+
+def test_proxy_tunneling_with_403():
+    """
+    Send an HTTPS request via a proxy, where the proxy denies the CONNECT
+    request with a 403 response.
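+
+    The failed CONNECT is surfaced as a ProxyError, and no connection is
+    retained in the pool.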
+ """ + network_backend = MockBackend( + [ + b"HTTP/1.1 403 Permission Denied\r\n" b"\r\n", + ] + ) + + with HTTPProxy( + proxy_url="http://localhost:8080/", + max_connections=10, + network_backend=network_backend, + ) as proxy: + with pytest.raises(ProxyError) as exc_info: + proxy.request("GET", "https://example.com/") + assert str(exc_info.value) == "403 Permission Denied" + assert not proxy.connections diff --git a/tests/_sync/test_integration.py b/tests/_sync/test_integration.py new file mode 100644 index 00000000..42bf70b0 --- /dev/null +++ b/tests/_sync/test_integration.py @@ -0,0 +1,51 @@ +import ssl + +import pytest + +from httpcore import ConnectionPool + + + +def test_request(httpbin): + with ConnectionPool() as pool: + response = pool.request("GET", httpbin.url) + assert response.status == 200 + + + +def test_ssl_request(httpbin_secure): + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + with ConnectionPool(ssl_context=ssl_context) as pool: + response = pool.request("GET", httpbin_secure.url) + assert response.status == 200 + + + +def test_extra_info(httpbin_secure): + ssl_context = ssl.create_default_context() + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + with ConnectionPool(ssl_context=ssl_context) as pool: + with pool.stream("GET", httpbin_secure.url) as response: + assert response.status == 200 + stream = response.extensions["network_stream"] + + ssl_object = stream.get_extra_info("ssl_object") + assert ssl_object.version() == "TLSv1.3" + + local_addr = stream.get_extra_info("client_addr") + assert local_addr[0] == "127.0.0.1" + + remote_addr = stream.get_extra_info("server_addr") + assert "https://%s:%d" % remote_addr == httpbin_secure.url + + sock = stream.get_extra_info("socket") + assert hasattr(sock, "family") + assert hasattr(sock, "type") + + invalid = stream.get_extra_info("invalid") + assert invalid is None + + stream.get_extra_info("is_readable") diff --git a/tests/async_tests/test_connection_pool.py b/tests/async_tests/test_connection_pool.py deleted file mode 100644 index bb4693bb..00000000 --- a/tests/async_tests/test_connection_pool.py +++ /dev/null @@ -1,194 +0,0 @@ -from typing import AsyncIterator, Tuple - -import pytest - -import httpcore -from httpcore._async.base import ConnectionState -from httpcore._types import URL, Headers - - -class MockConnection(object): - def __init__(self, http_version): - self.origin = (b"http", b"example.org", 80) - self.state = ConnectionState.PENDING - self.is_http11 = http_version == "HTTP/1.1" - self.is_http2 = http_version == "HTTP/2" - self.stream_count = 0 - - async def handle_async_request( - self, - method: bytes, - url: URL, - headers: Headers = None, - stream: httpcore.AsyncByteStream = None, - extensions: dict = None, - ) -> Tuple[int, Headers, httpcore.AsyncByteStream, dict]: - self.state = ConnectionState.ACTIVE - self.stream_count += 1 - - async def on_close(): - self.stream_count -= 1 - if self.stream_count == 0: - self.state = ConnectionState.IDLE - - async def aiterator() -> AsyncIterator[bytes]: - yield b"" - - stream = httpcore.AsyncIteratorByteStream( - aiterator=aiterator(), aclose_func=on_close - ) - - return 200, [], stream, {} - - async def aclose(self): - pass - - def info(self) -> str: - return self.state.name - - def is_available(self): - if self.is_http11: - return self.state == ConnectionState.IDLE - else: - return self.state != ConnectionState.CLOSED - - def should_close(self): - return 
False - - def is_idle(self): - return self.state == ConnectionState.IDLE - - def is_closed(self): - return False - - -class ConnectionPool(httpcore.AsyncConnectionPool): - def __init__(self, http_version: str): - super().__init__() - self.http_version = http_version - assert http_version in ("HTTP/1.1", "HTTP/2") - - def _create_connection(self, **kwargs): - return MockConnection(self.http_version) - - -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - body = [] - async for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - await stream.aclose() - - -@pytest.mark.trio -@pytest.mark.parametrize("http_version", ["HTTP/1.1", "HTTP/2"]) -async def test_sequential_requests(http_version) -> None: - async with ConnectionPool(http_version=http_version) as http: - info = await http.get_connection_info() - assert info == {} - - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - await read_body(stream) - info = await http.get_connection_info() - assert info == {"http://example.org": ["IDLE"]} - - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - await read_body(stream) - info = await http.get_connection_info() - assert info == {"http://example.org": ["IDLE"]} - - -@pytest.mark.trio -async def test_concurrent_requests_h11() -> None: - async with ConnectionPool(http_version="HTTP/1.1") as http: - info = await http.get_connection_info() - assert info == {} - - response_1 = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - response_2 = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE", "ACTIVE"]} - - await read_body(stream_1) - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE", "IDLE"]} - - await read_body(stream_2) - info = await http.get_connection_info() - assert info == {"http://example.org": ["IDLE", "IDLE"]} - - -@pytest.mark.trio -async def test_concurrent_requests_h2() -> None: - async with ConnectionPool(http_version="HTTP/2") as http: - info = await http.get_connection_info() - assert info == {} - - response_1 = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - response_2 = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - 
headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - await read_body(stream_1) - info = await http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - await read_body(stream_2) - info = await http.get_connection_info() - assert info == {"http://example.org": ["IDLE"]} diff --git a/tests/async_tests/test_http11.py b/tests/async_tests/test_http11.py deleted file mode 100644 index 0bcd1821..00000000 --- a/tests/async_tests/test_http11.py +++ /dev/null @@ -1,317 +0,0 @@ -import collections - -import pytest - -import httpcore -from httpcore._backends.auto import AsyncBackend, AsyncLock, AsyncSocketStream - - -class MockStream(AsyncSocketStream): - def __init__(self, http_buffer, disconnect): - self.read_buffer = collections.deque(http_buffer) - self.disconnect = disconnect - - def get_http_version(self) -> str: - return "HTTP/1.1" - - async def write(self, data, timeout): - pass - - async def read(self, n, timeout): - return self.read_buffer.popleft() - - async def aclose(self): - pass - - def is_readable(self): - return self.disconnect - - -class MockLock(AsyncLock): - async def release(self) -> None: - pass - - async def acquire(self) -> None: - pass - - -class MockBackend(AsyncBackend): - def __init__(self, http_buffer, disconnect=False): - self.http_buffer = http_buffer - self.disconnect = disconnect - - async def open_tcp_stream( - self, hostname, port, ssl_context, timeout, *, local_address - ): - return MockStream(self.http_buffer, self.disconnect) - - def create_lock(self): - return MockLock() - - -@pytest.mark.trio -async def test_get_request_with_connection_keepalive() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - ] - ) - - async with httpcore.AsyncConnectionPool(backend=backend) as http: - # We're sending a request with a standard keep-alive connection, so - # it will remain in the pool once we've sent the request. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." - assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - # This second request will go out over the same connection. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." 
- assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - -@pytest.mark.trio -async def test_get_request_with_connection_close_header() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - b"", # Terminate the connection. - ] - ) - - async with httpcore.AsyncConnectionPool(backend=backend) as http: - # We're sending a request with 'Connection: close', so the connection - # does not remain in the pool once we've sent the request. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org"), (b"Connection", b"close")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." - assert await http.get_connection_info() == {} - - # The second request will go out over a new connection. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org"), (b"Connection", b"close")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." - assert await http.get_connection_info() == {} - - -@pytest.mark.trio -async def test_get_request_with_socket_disconnect_between_requests() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - ], - disconnect=True, - ) - - async with httpcore.AsyncConnectionPool(backend=backend) as http: - # Send an initial request. We're using a standard keep-alive - # connection, so the connection remains in the pool after completion. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." - assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - # On sending this second request, at the point of pool re-acquiry the - # socket indicates that it has disconnected, and we'll send the request - # over a new connection. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." 
- assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - -@pytest.mark.trio -async def test_get_request_with_unclean_close_after_first_request() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - b"", # Terminate the connection. - ], - ) - - async with httpcore.AsyncConnectionPool(backend=backend) as http: - # Send an initial request. We're using a standard keep-alive - # connection, so the connection remains in the pool after completion. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." - assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - # At this point we successfully write another request, but the socket - # read returns `b""`, indicating a premature close. - with pytest.raises(httpcore.RemoteProtocolError) as excinfo: - await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert str(excinfo.value) == "Server disconnected without sending a response." - - -@pytest.mark.trio -async def test_request_with_missing_host_header() -> None: - backend = MockBackend(http_buffer=[]) - - async with httpcore.AsyncConnectionPool(backend=backend) as http: - with pytest.raises(httpcore.LocalProtocolError) as excinfo: - await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert str(excinfo.value) == "Missing mandatory Host: header" - - -@pytest.mark.trio -async def test_concurrent_get_requests() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - ] - ) - - async with httpcore.AsyncConnectionPool(backend=backend) as http: - # We're sending a request with a standard keep-alive connection, so - # it will remain in the pool once we've sent the request. 
- response_1 = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream_1, extensions = response_1 - assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, ACTIVE"] - } - - response_2 = await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream_2, extensions = response_2 - assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, ACTIVE", "HTTP/1.1, ACTIVE"] - } - - await stream_1.aread() - assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, ACTIVE", "HTTP/1.1, IDLE"] - } - - await stream_2.aread() - assert await http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE", "HTTP/1.1, IDLE"] - } diff --git a/tests/async_tests/test_http2.py b/tests/async_tests/test_http2.py deleted file mode 100644 index e5990067..00000000 --- a/tests/async_tests/test_http2.py +++ /dev/null @@ -1,249 +0,0 @@ -import collections - -import h2.config -import h2.connection -import pytest - -import httpcore -from httpcore._backends.auto import ( - AsyncBackend, - AsyncLock, - AsyncSemaphore, - AsyncSocketStream, -) - - -class MockStream(AsyncSocketStream): - def __init__(self, http_buffer, disconnect): - self.read_buffer = collections.deque(http_buffer) - self.disconnect = disconnect - - def get_http_version(self) -> str: - return "HTTP/2" - - async def write(self, data, timeout): - pass - - async def read(self, n, timeout): - return self.read_buffer.popleft() - - async def aclose(self): - pass - - def is_readable(self): - return self.disconnect - - -class MockLock(AsyncLock): - async def release(self): - pass - - async def acquire(self): - pass - - -class MockSemaphore(AsyncSemaphore): - def __init__(self): - pass - - async def acquire(self, timeout=None): - pass - - async def release(self): - pass - - -class MockBackend(AsyncBackend): - def __init__(self, http_buffer, disconnect=False): - self.http_buffer = http_buffer - self.disconnect = disconnect - - async def open_tcp_stream( - self, hostname, port, ssl_context, timeout, *, local_address - ): - return MockStream(self.http_buffer, self.disconnect) - - def create_lock(self): - return MockLock() - - def create_semaphore(self, max_value, exc_class): - return MockSemaphore() - - -class HTTP2BytesGenerator: - def __init__(self): - self.client_config = h2.config.H2Configuration(client_side=True) - self.client_conn = h2.connection.H2Connection(config=self.client_config) - self.server_config = h2.config.H2Configuration(client_side=False) - self.server_conn = h2.connection.H2Connection(config=self.server_config) - self.initialized = False - - def get_server_bytes( - self, request_headers, request_data, response_headers, response_data - ): - if not self.initialized: - self.client_conn.initiate_connection() - self.server_conn.initiate_connection() - self.initialized = True - - # Feed the request events to the client-side state machine - client_stream_id = self.client_conn.get_next_available_stream_id() - self.client_conn.send_headers(client_stream_id, headers=request_headers) - self.client_conn.send_data(client_stream_id, data=request_data, end_stream=True) - - # Determine the bytes that're sent out the client side, and feed them - # into the server-side 
state machine to get it into the correct state. - client_bytes = self.client_conn.data_to_send() - events = self.server_conn.receive_data(client_bytes) - server_stream_id = [ - event.stream_id - for event in events - if isinstance(event, h2.events.RequestReceived) - ][0] - - # Feed the response events to the server-side state machine - self.server_conn.send_headers(server_stream_id, headers=response_headers) - self.server_conn.send_data( - server_stream_id, data=response_data, end_stream=True - ) - - return self.server_conn.data_to_send() - - -@pytest.mark.trio -async def test_get_request() -> None: - bytes_generator = HTTP2BytesGenerator() - http_buffer = [ - bytes_generator.get_server_bytes( - request_headers=[ - (b":method", b"GET"), - (b":authority", b"www.example.com"), - (b":scheme", b"https"), - (b":path", "/"), - ], - request_data=b"", - response_headers=[ - (b":status", b"200"), - (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"), - (b"server", b"Apache"), - (b"content-length", b"13"), - (b"content-type", b"text/plain"), - ], - response_data=b"Hello, world.", - ), - bytes_generator.get_server_bytes( - request_headers=[ - (b":method", b"GET"), - (b":authority", b"www.example.com"), - (b":scheme", b"https"), - (b":path", "/"), - ], - request_data=b"", - response_headers=[ - (b":status", b"200"), - (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"), - (b"server", b"Apache"), - (b"content-length", b"13"), - (b"content-type", b"text/plain"), - ], - response_data=b"Hello, world.", - ), - ] - backend = MockBackend(http_buffer=http_buffer) - - async with httpcore.AsyncConnectionPool(http2=True, backend=backend) as http: - # We're sending a request with a standard keep-alive connection, so - # it will remain in the pool once we've sent the request. - response = await http.handle_async_request( - method=b"GET", - url=(b"https", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." - assert await http.get_connection_info() == { - "https://example.org": ["HTTP/2, IDLE, 0 streams"] - } - - # The second HTTP request will go out over the same connection. - response = await http.handle_async_request( - method=b"GET", - url=(b"https", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." 
- assert await http.get_connection_info() == { - "https://example.org": ["HTTP/2, IDLE, 0 streams"] - } - - -@pytest.mark.trio -async def test_post_request() -> None: - bytes_generator = HTTP2BytesGenerator() - bytes_to_send = bytes_generator.get_server_bytes( - request_headers=[ - (b":method", b"POST"), - (b":authority", b"www.example.com"), - (b":scheme", b"https"), - (b":path", "/"), - (b"content-length", b"13"), - ], - request_data=b"Hello, world.", - response_headers=[ - (b":status", b"200"), - (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"), - (b"server", b"Apache"), - (b"content-length", b"13"), - (b"content-type", b"text/plain"), - ], - response_data=b"Hello, world.", - ) - backend = MockBackend(http_buffer=[bytes_to_send]) - - async with httpcore.AsyncConnectionPool(http2=True, backend=backend) as http: - # We're sending a request with a standard keep-alive connection, so - # it will remain in the pool once we've sent the request. - response = await http.handle_async_request( - method=b"POST", - url=(b"https", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org"), (b"Content-length", b"13")], - stream=httpcore.ByteStream(b"Hello, world."), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = await stream.aread() - assert status_code == 200 - assert body == b"Hello, world." - assert await http.get_connection_info() == { - "https://example.org": ["HTTP/2, IDLE, 0 streams"] - } - - -@pytest.mark.trio -async def test_request_with_missing_host_header() -> None: - backend = MockBackend(http_buffer=[]) - - server_config = h2.config.H2Configuration(client_side=False) - server_conn = h2.connection.H2Connection(config=server_config) - server_conn.initiate_connection() - backend = MockBackend(http_buffer=[server_conn.data_to_send()]) - - async with httpcore.AsyncConnectionPool(backend=backend) as http: - with pytest.raises(httpcore.LocalProtocolError) as excinfo: - await http.handle_async_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert str(excinfo.value) == "Missing mandatory Host: header" diff --git a/tests/async_tests/test_interfaces.py b/tests/async_tests/test_interfaces.py deleted file mode 100644 index ce547f82..00000000 --- a/tests/async_tests/test_interfaces.py +++ /dev/null @@ -1,605 +0,0 @@ -import platform -from typing import Tuple - -import pytest - -import httpcore -from httpcore._types import URL -from tests.conftest import HTTPS_SERVER_URL -from tests.utils import Server, lookup_async_backend - - -@pytest.fixture(params=["auto", "anyio"]) -def backend(request): - return request.param - - -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - body = [] - async for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - await stream.aclose() - - -def test_must_configure_either_http1_or_http2() -> None: - with pytest.raises(ValueError): - httpcore.AsyncConnectionPool(http1=False, http2=False) - - -@pytest.mark.anyio -async def test_http_request(backend: str, server: Server) -> None: - async with httpcore.AsyncConnectionPool(backend=backend) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - 
"http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.anyio -async def test_https_request(backend: str, https_server: Server) -> None: - async with httpcore.AsyncConnectionPool(backend=backend) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"https", *https_server.netloc, b"/"), - headers=[https_server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if https_server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"https", *https_server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.anyio -@pytest.mark.parametrize( - "url", [(b"ftp", b"example.org", 443, b"/"), (b"", b"coolsite.org", 443, b"/")] -) -async def test_request_unsupported_protocol( - backend: str, url: Tuple[bytes, bytes, int, bytes] -) -> None: - async with httpcore.AsyncConnectionPool(backend=backend) as http: - with pytest.raises(httpcore.UnsupportedProtocol): - await http.handle_async_request( - method=b"GET", - url=url, - headers=[(b"host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - - -@pytest.mark.anyio -async def test_http2_request(backend: str, https_server: Server) -> None: - async with httpcore.AsyncConnectionPool(backend=backend, http2=True) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"https", *https_server.netloc, b"/"), - headers=[https_server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - assert extensions == {"http_version": b"HTTP/2"} - origin = (b"https", *https_server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.anyio -async def test_closing_http_request(backend: str, server: Server) -> None: - async with httpcore.AsyncConnectionPool(backend=backend) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header, (b"connection", b"close")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert origin not in http._connections # type: ignore - - -@pytest.mark.anyio -async def test_http_request_reuse_connection(backend: str, server: Server) -> None: - async with httpcore.AsyncConnectionPool(backend=backend) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - 
url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.anyio -async def test_https_request_reuse_connection( - backend: str, https_server: Server -) -> None: - async with httpcore.AsyncConnectionPool(backend=backend) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"https", *https_server.netloc, b"/"), - headers=[https_server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if https_server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"https", *https_server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"https", *https_server.netloc, b"/"), - headers=[https_server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if https_server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"https", *https_server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.anyio -async def test_http_request_cannot_reuse_dropped_connection( - backend: str, server: Server -) -> None: - async with httpcore.AsyncConnectionPool(backend=backend) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - # Mock the connection as having been dropped. 
- connection = list(http._connections[origin])[0] # type: ignore - connection.is_socket_readable = lambda: True # type: ignore - - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"]) -@pytest.mark.anyio -async def test_http_proxy( - proxy_server: URL, proxy_mode: str, backend: str, server: Server -) -> None: - max_connections = 1 - async with httpcore.AsyncHTTPProxy( - proxy_server, - proxy_mode=proxy_mode, - max_connections=max_connections, - backend=backend, - ) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - - -@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"]) -@pytest.mark.parametrize("protocol,port", [(b"http", 80), (b"https", 443)]) -@pytest.mark.trio -# Filter out ssl module deprecation warnings and asyncio module resource warning, -# convert other warnings to errors. -@pytest.mark.filterwarnings("ignore:.*(SSLContext|PROTOCOL_TLS):DeprecationWarning") -@pytest.mark.filterwarnings("ignore::ResourceWarning:asyncio") -@pytest.mark.filterwarnings("error") -async def test_proxy_socket_does_not_leak_when_the_connection_hasnt_been_added_to_pool( - proxy_server: URL, - server: Server, - proxy_mode: str, - protocol: bytes, - port: int, -): - async with httpcore.AsyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http: - for _ in range(100): - try: - _ = await http.handle_async_request( - method=b"GET", - url=(protocol, b"blockedhost.example.com", port, b"/"), - headers=[(b"host", b"blockedhost.example.com")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - except (httpcore.ProxyError, httpcore.RemoteProtocolError): - pass - - -@pytest.mark.anyio -async def test_http_request_local_address(backend: str, server: Server) -> None: - if backend == "auto" and lookup_async_backend() == "trio": - pytest.skip("The trio backend does not support local_address") - - async with httpcore.AsyncConnectionPool( - backend=backend, local_address="0.0.0.0" - ) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -# mitmproxy does not support forwarding HTTPS requests -@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "TUNNEL_ONLY"]) 
-@pytest.mark.parametrize("http2", [False, True]) -@pytest.mark.anyio -async def test_proxy_https_requests( - proxy_server: URL, - proxy_mode: str, - http2: bool, - https_server: Server, -) -> None: - max_connections = 1 - async with httpcore.AsyncHTTPProxy( - proxy_server, - proxy_mode=proxy_mode, - max_connections=max_connections, - http2=http2, - ) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"https", *https_server.netloc, b"/"), - headers=[https_server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - _ = await read_body(stream) - - assert status_code == 200 - assert extensions["http_version"] == b"HTTP/2" if http2 else b"HTTP/1.1" - assert extensions.get("reason_phrase", b"") == b"" if http2 else b"OK" - - -@pytest.mark.parametrize( - "http2,keepalive_expiry,expected_during_active,expected_during_idle", - [ - ( - False, - 60.0, - {HTTPS_SERVER_URL: ["HTTP/1.1, ACTIVE", "HTTP/1.1, ACTIVE"]}, - {HTTPS_SERVER_URL: ["HTTP/1.1, IDLE", "HTTP/1.1, IDLE"]}, - ), - ( - True, - 60.0, - {HTTPS_SERVER_URL: ["HTTP/2, ACTIVE, 2 streams"]}, - {HTTPS_SERVER_URL: ["HTTP/2, IDLE, 0 streams"]}, - ), - ( - False, - 0.0, - {HTTPS_SERVER_URL: ["HTTP/1.1, ACTIVE", "HTTP/1.1, ACTIVE"]}, - {}, - ), - ( - True, - 0.0, - {HTTPS_SERVER_URL: ["HTTP/2, ACTIVE, 2 streams"]}, - {}, - ), - ], -) -@pytest.mark.anyio -async def test_connection_pool_get_connection_info( - http2: bool, - keepalive_expiry: float, - expected_during_active: dict, - expected_during_idle: dict, - backend: str, - https_server: Server, -) -> None: - async with httpcore.AsyncConnectionPool( - http2=http2, keepalive_expiry=keepalive_expiry, backend=backend - ) as http: - _, _, stream_1, _ = await http.handle_async_request( - method=b"GET", - url=(b"https", *https_server.netloc, b"/"), - headers=[https_server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - _, _, stream_2, _ = await http.handle_async_request( - method=b"GET", - url=(b"https", *https_server.netloc, b"/"), - headers=[https_server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - - try: - stats = await http.get_connection_info() - assert stats == expected_during_active - finally: - await read_body(stream_1) - await read_body(stream_2) - - stats = await http.get_connection_info() - assert stats == expected_during_idle - - stats = await http.get_connection_info() - assert stats == {} - - -@pytest.mark.skipif( - platform.system() not in ("Linux", "Darwin"), - reason="Unix Domain Sockets only exist on Unix", -) -@pytest.mark.anyio -async def test_http_request_unix_domain_socket( - uds_server: Server, backend: str -) -> None: - uds = uds_server.uds - async with httpcore.AsyncConnectionPool(uds=uds, backend=backend) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", b"localhost", None, b"/"), - headers=[(b"host", b"localhost")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert status_code == 200 - reason_phrase = b"OK" if uds_server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - body = await read_body(stream) - assert body == b"Hello, world!" 
- - -@pytest.mark.parametrize("max_keepalive", [1, 3, 5]) -@pytest.mark.parametrize("connections_number", [4]) -@pytest.mark.anyio -async def test_max_keepalive_connections_handled_correctly( - max_keepalive: int, connections_number: int, backend: str, server: Server -) -> None: - async with httpcore.AsyncConnectionPool( - max_keepalive_connections=max_keepalive, keepalive_expiry=60, backend=backend - ) as http: - connections_streams = [] - for _ in range(connections_number): - _, _, stream, _ = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - connections_streams.append(stream) - - try: - for i in range(len(connections_streams)): - await read_body(connections_streams[i]) - finally: - stats = await http.get_connection_info() - - connections_in_pool = next(iter(stats.values())) - assert len(connections_in_pool) == min(connections_number, max_keepalive) - - -@pytest.mark.anyio -async def test_explicit_backend_name(server: Server) -> None: - async with httpcore.AsyncConnectionPool(backend=lookup_async_backend()) as http: - status_code, headers, stream, extensions = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.anyio -@pytest.mark.usefixtures("too_many_open_files_minus_one") -@pytest.mark.skipif(platform.system() != "Linux", reason="Only a problem on Linux") -async def test_broken_socket_detection_many_open_files( - backend: str, server: Server -) -> None: - """ - Regression test for: https://github.com/encode/httpcore/issues/182 - """ - async with httpcore.AsyncConnectionPool(backend=backend) as http: - # * First attempt will be successful because it will grab the last - # available fd before what select() supports on the platform. - # * Second attempt would have failed without a fix, due to a "filedescriptor - # out of range in select()" exception. - for _ in range(2): - ( - status_code, - response_headers, - stream, - extensions, - ) = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - await read_body(stream) - - assert status_code == 200 - reason_phrase = b"OK" if server.sends_reason else b"" - assert extensions == { - "http_version": b"HTTP/1.1", - "reason_phrase": reason_phrase, - } - origin = (b"http", *server.netloc) - assert len(http._connections[origin]) == 1 # type: ignore - - -@pytest.mark.anyio -@pytest.mark.parametrize( - "url", - [ - pytest.param((b"http", b"localhost", 12345, b"/"), id="connection-refused"), - pytest.param( - (b"http", b"doesnotexistatall.org", None, b"/"), id="dns-resolution-failed" - ), - ], -) -async def test_cannot_connect_tcp(backend: str, url) -> None: - """ - A properly wrapped error is raised when connecting to the server fails. 
- """ - async with httpcore.AsyncConnectionPool(backend=backend) as http: - with pytest.raises(httpcore.ConnectError): - await http.handle_async_request( - method=b"GET", - url=url, - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - - -@pytest.mark.anyio -async def test_cannot_connect_uds(backend: str) -> None: - """ - A properly wrapped error is raised when connecting to the UDS server fails. - """ - uds = "/tmp/doesnotexist.sock" - async with httpcore.AsyncConnectionPool(backend=backend, uds=uds) as http: - with pytest.raises(httpcore.ConnectError): - await http.handle_async_request( - method=b"GET", - url=(b"http", b"localhost", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) diff --git a/tests/async_tests/test_retries.py b/tests/async_tests/test_retries.py deleted file mode 100644 index dc479890..00000000 --- a/tests/async_tests/test_retries.py +++ /dev/null @@ -1,200 +0,0 @@ -import queue -import time -from typing import Any, List, Optional - -import pytest - -import httpcore -from httpcore._backends.auto import AsyncSocketStream, AutoBackend -from tests.utils import Server - - -class AsyncMockBackend(AutoBackend): - def __init__(self) -> None: - super().__init__() - self._exceptions: queue.Queue[Optional[Exception]] = queue.Queue() - self._timestamps: List[float] = [] - - def push(self, *exceptions: Optional[Exception]) -> None: - for exc in exceptions: - self._exceptions.put(exc) - - def pop_open_tcp_stream_intervals(self) -> list: - intervals = [b - a for a, b in zip(self._timestamps, self._timestamps[1:])] - self._timestamps.clear() - return intervals - - async def open_tcp_stream(self, *args: Any, **kwargs: Any) -> AsyncSocketStream: - self._timestamps.append(time.time()) - exc = None if self._exceptions.empty() else self._exceptions.get_nowait() - if exc is not None: - raise exc - return await super().open_tcp_stream(*args, **kwargs) - - -async def read_body(stream: httpcore.AsyncByteStream) -> bytes: - try: - return b"".join([chunk async for chunk in stream]) - finally: - await stream.aclose() - - -@pytest.mark.anyio -async def test_no_retries(server: Server) -> None: - """ - By default, connection failures are not retried on. - """ - backend = AsyncMockBackend() - - async with httpcore.AsyncConnectionPool( - max_keepalive_connections=0, backend=backend - ) as http: - response = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) - - backend.push(httpcore.ConnectTimeout(), httpcore.ConnectError()) - - with pytest.raises(httpcore.ConnectTimeout): - await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - - with pytest.raises(httpcore.ConnectError): - await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - - -@pytest.mark.anyio -async def test_retries_enabled(server: Server) -> None: - """ - When retries are enabled, connection failures are retried on with - a fixed exponential backoff. - """ - backend = AsyncMockBackend() - retries = 10 # Large enough to not run out of retries within this test. 
- - async with httpcore.AsyncConnectionPool( - retries=retries, max_keepalive_connections=0, backend=backend - ) as http: - # Standard case, no failures. - response = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert backend.pop_open_tcp_stream_intervals() == [] - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) - - # One failure, then success. - backend.push(httpcore.ConnectError(), None) - response = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert backend.pop_open_tcp_stream_intervals() == [ - pytest.approx(0, abs=5e-3), # Retry immediately. - ] - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) - - # Three failures, then success. - backend.push( - httpcore.ConnectError(), - httpcore.ConnectTimeout(), - httpcore.ConnectTimeout(), - None, - ) - response = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert backend.pop_open_tcp_stream_intervals() == [ - pytest.approx(0, abs=5e-3), # Retry immediately. - pytest.approx(0.5, rel=0.1), # First backoff. - pytest.approx(1.0, rel=0.1), # Second (increased) backoff. - ] - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) - - # Non-connect exceptions are not retried on. - backend.push(httpcore.ReadTimeout(), httpcore.NetworkError()) - with pytest.raises(httpcore.ReadTimeout): - await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - with pytest.raises(httpcore.NetworkError): - await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - - -@pytest.mark.anyio -async def test_retries_exceeded(server: Server) -> None: - """ - When retries are enabled and connecting failures more than the configured number - of retries, connect exceptions are raised. - """ - backend = AsyncMockBackend() - retries = 1 - - async with httpcore.AsyncConnectionPool( - retries=retries, max_keepalive_connections=0, backend=backend - ) as http: - response = await http.handle_async_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, _, stream, _ = response - assert status_code == 200 - await read_body(stream) - - # First failure is retried on, second one isn't. 
-        backend.push(httpcore.ConnectError(), httpcore.ConnectTimeout())
-        with pytest.raises(httpcore.ConnectTimeout):
-            await http.handle_async_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
diff --git a/tests/backend_tests/test_asyncio.py b/tests/backend_tests/test_asyncio.py
deleted file mode 100644
index 37c5232c..00000000
--- a/tests/backend_tests/test_asyncio.py
+++ /dev/null
@@ -1,32 +0,0 @@
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from httpcore._backends.asyncio import SocketStream
-
-
-class MockSocket:
-    def fileno(self):
-        return 1
-
-
-class TestSocketStream:
-    class TestIsReadable:
-        @pytest.mark.asyncio
-        async def test_returns_true_when_transport_has_no_socket(self):
-            stream_reader = MagicMock()
-            stream_reader._transport.get_extra_info.return_value = None
-            sock_stream = SocketStream(stream_reader, MagicMock())
-
-            assert sock_stream.is_readable()
-
-        @pytest.mark.asyncio
-        async def test_returns_true_when_socket_is_readable(self):
-            stream_reader = MagicMock()
-            stream_reader._transport.get_extra_info.return_value = MockSocket()
-            sock_stream = SocketStream(stream_reader, MagicMock())
-
-            with patch(
-                "httpcore._utils.is_socket_readable", MagicMock(return_value=True)
-            ):
-                assert sock_stream.is_readable()
diff --git a/tests/concurrency.py b/tests/concurrency.py
new file mode 100644
index 00000000..aa80b3a8
--- /dev/null
+++ b/tests/concurrency.py
@@ -0,0 +1,40 @@
+"""
+Some of our tests require branching of flow control.
+
+We'd like to have the same kind of test for both async and sync environments,
+and so we have functionality here that replicates Trio's `open_nursery` API,
+but in a plain old multi-threaded context.
+
+We don't do any smarts around cancellations, or managing exceptions from
+children, because we don't need that for our use-case.
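+
+As an illustrative usage sketch (any callable works in place of `print`):
+
+    with open_nursery() as nursery:
+        nursery.start_soon(print, "hello")
+        nursery.start_soon(print, "world")
+
+Note that `start_soon` only queues a thread; all queued threads are started
+and then joined when the `with` block exits.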
+""" +import threading +from types import TracebackType +from typing import List, Type + + +class Nursery: + def __init__(self) -> None: + self._threads: List[threading.Thread] = [] + + def __enter__(self) -> "Nursery": + return self + + def __exit__( + self, + exc_type: Type[BaseException] = None, + exc_value: BaseException = None, + traceback: TracebackType = None, + ) -> None: + for thread in self._threads: + thread.start() + for thread in self._threads: + thread.join() + + def start_soon(self, func, *args): + thread = threading.Thread(target=func, args=args) + self._threads.append(thread) + + +def open_nursery() -> Nursery: + return Nursery() diff --git a/tests/conftest.py b/tests/conftest.py deleted file mode 100644 index 0ba96fdc..00000000 --- a/tests/conftest.py +++ /dev/null @@ -1,187 +0,0 @@ -import contextlib -import os -import threading -import time -import typing - -import pytest -import trustme - -from httpcore._types import URL - -from .utils import HypercornServer, LiveServer, Server, http_proxy_server - -try: - import hypercorn -except ImportError: # pragma: no cover # Python 3.6 - hypercorn = None # type: ignore - SERVER_HOST = "example.org" - SERVER_HTTP_PORT = 80 - SERVER_HTTPS_PORT = 443 - HTTPS_SERVER_URL = "https://example.org" -else: - SERVER_HOST = "localhost" - SERVER_HTTP_PORT = 8002 - SERVER_HTTPS_PORT = 8003 - HTTPS_SERVER_URL = f"https://localhost:{SERVER_HTTPS_PORT}" - - -@pytest.fixture(scope="session") -def proxy_server() -> typing.Iterator[URL]: - proxy_host = "127.0.0.1" - proxy_port = 8080 - - with http_proxy_server(proxy_host, proxy_port) as proxy_url: - yield proxy_url - - -async def app(scope: dict, receive: typing.Callable, send: typing.Callable) -> None: - assert scope["type"] == "http" - await send( - { - "type": "http.response.start", - "status": 200, - "headers": [[b"content-type", b"text/plain"]], - } - ) - await send({"type": "http.response.body", "body": b"Hello, world!"}) - - -@pytest.fixture(scope="session") -def uds() -> typing.Iterator[str]: - uds = "test_server.sock" - try: - yield uds - finally: - os.remove(uds) - - -@pytest.fixture(scope="session") -def uds_server(uds: str) -> typing.Iterator[Server]: - if hypercorn is not None: - server = HypercornServer(app=app, bind=f"unix:{uds}") - with server.serve_in_thread(): - yield server - else: - # On Python 3.6, use Uvicorn as a fallback. - import uvicorn - - class UvicornServer(Server, uvicorn.Server): - sends_reason = True - - @property - def uds(self) -> str: - uds = self.config.uds - assert uds is not None - return uds - - def install_signal_handlers(self) -> None: - pass - - @contextlib.contextmanager - def serve_in_thread(self) -> typing.Iterator[None]: - thread = threading.Thread(target=self.run) - thread.start() - try: - while not self.started: - time.sleep(1e-3) - yield - finally: - self.should_exit = True - thread.join() - - config = uvicorn.Config(app=app, lifespan="off", loop="asyncio", uds=uds) - server = UvicornServer(config=config) - with server.serve_in_thread(): - yield server - - -@pytest.fixture(scope="session") -def server() -> typing.Iterator[Server]: # pragma: no cover - server: Server # Please mypy. 
- - if hypercorn is None: - server = LiveServer(host=SERVER_HOST, port=SERVER_HTTP_PORT) - yield server - return - - server = HypercornServer(app=app, bind=f"{SERVER_HOST}:{SERVER_HTTP_PORT}") - with server.serve_in_thread(): - yield server - - -@pytest.fixture(scope="session") -def cert_authority() -> trustme.CA: - return trustme.CA() - - -@pytest.fixture(scope="session") -def localhost_cert(cert_authority: trustme.CA) -> trustme.LeafCert: - return cert_authority.issue_cert("localhost") - - -@pytest.fixture(scope="session") -def localhost_cert_path(localhost_cert: trustme.LeafCert) -> typing.Iterator[str]: - with localhost_cert.private_key_and_cert_chain_pem.tempfile() as tmp: - yield tmp - - -@pytest.fixture(scope="session") -def localhost_cert_pem_file(localhost_cert: trustme.LeafCert) -> typing.Iterator[str]: - with localhost_cert.cert_chain_pems[0].tempfile() as tmp: - yield tmp - - -@pytest.fixture(scope="session") -def localhost_cert_private_key_file( - localhost_cert: trustme.LeafCert, -) -> typing.Iterator[str]: - with localhost_cert.private_key_pem.tempfile() as tmp: - yield tmp - - -@pytest.fixture(scope="session") -def https_server( - localhost_cert_pem_file: str, localhost_cert_private_key_file: str -) -> typing.Iterator[Server]: # pragma: no cover - server: Server # Please mypy. - - if hypercorn is None: - server = LiveServer(host=SERVER_HOST, port=SERVER_HTTPS_PORT) - yield server - return - - server = HypercornServer( - app=app, - bind=f"{SERVER_HOST}:{SERVER_HTTPS_PORT}", - certfile=localhost_cert_pem_file, - keyfile=localhost_cert_private_key_file, - ) - with server.serve_in_thread(): - yield server - - -@pytest.fixture(scope="function") -def too_many_open_files_minus_one() -> typing.Iterator[None]: - # Fixture for test regression on https://github.com/encode/httpcore/issues/182 - # Max number of descriptors chosen according to: - # See: https://man7.org/linux/man-pages/man2/select.2.html#top_of_page - # "To monitor file descriptors greater than 1023, use poll or epoll instead." - max_num_descriptors = 1023 - - files = [] - - while True: - f = open("/dev/null") - # Leave one file descriptor available for a transport to perform - # a successful request. 
- if f.fileno() > max_num_descriptors - 1: - f.close() - break - files.append(f) - - try: - yield - finally: - for f in files: - f.close() diff --git a/tests/sync_tests/__init__.py b/tests/sync_tests/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/sync_tests/test_connection_pool.py b/tests/sync_tests/test_connection_pool.py deleted file mode 100644 index 7e49f712..00000000 --- a/tests/sync_tests/test_connection_pool.py +++ /dev/null @@ -1,194 +0,0 @@ -from typing import Iterator, Tuple - -import pytest - -import httpcore -from httpcore._async.base import ConnectionState -from httpcore._types import URL, Headers - - -class MockConnection(object): - def __init__(self, http_version): - self.origin = (b"http", b"example.org", 80) - self.state = ConnectionState.PENDING - self.is_http11 = http_version == "HTTP/1.1" - self.is_http2 = http_version == "HTTP/2" - self.stream_count = 0 - - def handle_request( - self, - method: bytes, - url: URL, - headers: Headers = None, - stream: httpcore.SyncByteStream = None, - extensions: dict = None, - ) -> Tuple[int, Headers, httpcore.SyncByteStream, dict]: - self.state = ConnectionState.ACTIVE - self.stream_count += 1 - - def on_close(): - self.stream_count -= 1 - if self.stream_count == 0: - self.state = ConnectionState.IDLE - - def iterator() -> Iterator[bytes]: - yield b"" - - stream = httpcore.IteratorByteStream( - iterator=iterator(), close_func=on_close - ) - - return 200, [], stream, {} - - def close(self): - pass - - def info(self) -> str: - return self.state.name - - def is_available(self): - if self.is_http11: - return self.state == ConnectionState.IDLE - else: - return self.state != ConnectionState.CLOSED - - def should_close(self): - return False - - def is_idle(self): - return self.state == ConnectionState.IDLE - - def is_closed(self): - return False - - -class ConnectionPool(httpcore.SyncConnectionPool): - def __init__(self, http_version: str): - super().__init__() - self.http_version = http_version - assert http_version in ("HTTP/1.1", "HTTP/2") - - def _create_connection(self, **kwargs): - return MockConnection(self.http_version) - - -def read_body(stream: httpcore.SyncByteStream) -> bytes: - try: - body = [] - for chunk in stream: - body.append(chunk) - return b"".join(body) - finally: - stream.close() - - - -@pytest.mark.parametrize("http_version", ["HTTP/1.1", "HTTP/2"]) -def test_sequential_requests(http_version) -> None: - with ConnectionPool(http_version=http_version) as http: - info = http.get_connection_info() - assert info == {} - - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - read_body(stream) - info = http.get_connection_info() - assert info == {"http://example.org": ["IDLE"]} - - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - read_body(stream) - info = http.get_connection_info() - assert info == {"http://example.org": ["IDLE"]} - - - -def test_concurrent_requests_h11() -> None: - with ConnectionPool(http_version="HTTP/1.1") as http: - info = http.get_connection_info() - assert info 
== {} - - response_1 = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - response_2 = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE", "ACTIVE"]} - - read_body(stream_1) - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE", "IDLE"]} - - read_body(stream_2) - info = http.get_connection_info() - assert info == {"http://example.org": ["IDLE", "IDLE"]} - - - -def test_concurrent_requests_h2() -> None: - with ConnectionPool(http_version="HTTP/2") as http: - info = http.get_connection_info() - assert info == {} - - response_1 = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_1, headers_1, stream_1, ext_1 = response_1 - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - response_2 = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code_2, headers_2, stream_2, ext_2 = response_2 - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - read_body(stream_1) - info = http.get_connection_info() - assert info == {"http://example.org": ["ACTIVE"]} - - read_body(stream_2) - info = http.get_connection_info() - assert info == {"http://example.org": ["IDLE"]} diff --git a/tests/sync_tests/test_http11.py b/tests/sync_tests/test_http11.py deleted file mode 100644 index 5e00f692..00000000 --- a/tests/sync_tests/test_http11.py +++ /dev/null @@ -1,317 +0,0 @@ -import collections - -import pytest - -import httpcore -from httpcore._backends.sync import SyncBackend, SyncLock, SyncSocketStream - - -class MockStream(SyncSocketStream): - def __init__(self, http_buffer, disconnect): - self.read_buffer = collections.deque(http_buffer) - self.disconnect = disconnect - - def get_http_version(self) -> str: - return "HTTP/1.1" - - def write(self, data, timeout): - pass - - def read(self, n, timeout): - return self.read_buffer.popleft() - - def close(self): - pass - - def is_readable(self): - return self.disconnect - - -class MockLock(SyncLock): - def release(self) -> None: - pass - - def acquire(self) -> None: - pass - - -class MockBackend(SyncBackend): - def __init__(self, http_buffer, disconnect=False): - self.http_buffer = http_buffer - self.disconnect = disconnect - - def open_tcp_stream( - self, hostname, port, ssl_context, timeout, *, local_address - ): - return MockStream(self.http_buffer, self.disconnect) - - def create_lock(self): - return MockLock() - - - -def test_get_request_with_connection_keepalive() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: 
text/plain\r\n", - b"\r\n", - b"Hello, world.", - ] - ) - - with httpcore.SyncConnectionPool(backend=backend) as http: - # We're sending a request with a standard keep-alive connection, so - # it will remain in the pool once we've sent the request. - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = stream.read() - assert status_code == 200 - assert body == b"Hello, world." - assert http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - # This second request will go out over the same connection. - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = stream.read() - assert status_code == 200 - assert body == b"Hello, world." - assert http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - - -def test_get_request_with_connection_close_header() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - b"", # Terminate the connection. - ] - ) - - with httpcore.SyncConnectionPool(backend=backend) as http: - # We're sending a request with 'Connection: close', so the connection - # does not remain in the pool once we've sent the request. - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org"), (b"Connection", b"close")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = stream.read() - assert status_code == 200 - assert body == b"Hello, world." - assert http.get_connection_info() == {} - - # The second request will go out over a new connection. - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org"), (b"Connection", b"close")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = stream.read() - assert status_code == 200 - assert body == b"Hello, world." - assert http.get_connection_info() == {} - - - -def test_get_request_with_socket_disconnect_between_requests() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - ], - disconnect=True, - ) - - with httpcore.SyncConnectionPool(backend=backend) as http: - # Send an initial request. We're using a standard keep-alive - # connection, so the connection remains in the pool after completion. - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = stream.read() - assert status_code == 200 - assert body == b"Hello, world." 
- assert http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - # On sending this second request, at the point of pool re-acquiry the - # socket indicates that it has disconnected, and we'll send the request - # over a new connection. - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = stream.read() - assert status_code == 200 - assert body == b"Hello, world." - assert http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - - -def test_get_request_with_unclean_close_after_first_request() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - b"", # Terminate the connection. - ], - ) - - with httpcore.SyncConnectionPool(backend=backend) as http: - # Send an initial request. We're using a standard keep-alive - # connection, so the connection remains in the pool after completion. - response = http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, headers, stream, extensions = response - body = stream.read() - assert status_code == 200 - assert body == b"Hello, world." - assert http.get_connection_info() == { - "http://example.org": ["HTTP/1.1, IDLE"] - } - - # At this point we successfully write another request, but the socket - # read returns `b""`, indicating a premature close. - with pytest.raises(httpcore.RemoteProtocolError) as excinfo: - http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[(b"Host", b"example.org")], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert str(excinfo.value) == "Server disconnected without sending a response." - - - -def test_request_with_missing_host_header() -> None: - backend = MockBackend(http_buffer=[]) - - with httpcore.SyncConnectionPool(backend=backend) as http: - with pytest.raises(httpcore.LocalProtocolError) as excinfo: - http.handle_request( - method=b"GET", - url=(b"http", b"example.org", None, b"/"), - headers=[], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - assert str(excinfo.value) == "Missing mandatory Host: header" - - - -def test_concurrent_get_requests() -> None: - backend = MockBackend( - http_buffer=[ - b"HTTP/1.1 200 OK\r\n", - b"Date: Sat, 06 Oct 2049 12:34:56 GMT\r\n", - b"Server: Apache\r\n", - b"Content-Length: 13\r\n", - b"Content-Type: text/plain\r\n", - b"\r\n", - b"Hello, world.", - ] - ) - - with httpcore.SyncConnectionPool(backend=backend) as http: - # We're sending a request with a standard keep-alive connection, so - # it will remain in the pool once we've sent the request. 
diff --git a/tests/sync_tests/test_http2.py b/tests/sync_tests/test_http2.py
deleted file mode 100644
index 10b9badd..00000000
--- a/tests/sync_tests/test_http2.py
+++ /dev/null
@@ -1,249 +0,0 @@
-import collections
-
-import h2.config
-import h2.connection
-import pytest
-
-import httpcore
-from httpcore._backends.sync import (
-    SyncBackend,
-    SyncLock,
-    SyncSemaphore,
-    SyncSocketStream,
-)
-
-
-class MockStream(SyncSocketStream):
-    def __init__(self, http_buffer, disconnect):
-        self.read_buffer = collections.deque(http_buffer)
-        self.disconnect = disconnect
-
-    def get_http_version(self) -> str:
-        return "HTTP/2"
-
-    def write(self, data, timeout):
-        pass
-
-    def read(self, n, timeout):
-        return self.read_buffer.popleft()
-
-    def close(self):
-        pass
-
-    def is_readable(self):
-        return self.disconnect
-
-
-class MockLock(SyncLock):
-    def release(self):
-        pass
-
-    def acquire(self):
-        pass
-
-
-class MockSemaphore(SyncSemaphore):
-    def __init__(self):
-        pass
-
-    def acquire(self, timeout=None):
-        pass
-
-    def release(self):
-        pass
-
-
-class MockBackend(SyncBackend):
-    def __init__(self, http_buffer, disconnect=False):
-        self.http_buffer = http_buffer
-        self.disconnect = disconnect
-
-    def open_tcp_stream(
-        self, hostname, port, ssl_context, timeout, *, local_address
-    ):
-        return MockStream(self.http_buffer, self.disconnect)
-
-    def create_lock(self):
-        return MockLock()
-
-    def create_semaphore(self, max_value, exc_class):
-        return MockSemaphore()
-
-
-class HTTP2BytesGenerator:
-    def __init__(self):
-        self.client_config = h2.config.H2Configuration(client_side=True)
-        self.client_conn = h2.connection.H2Connection(config=self.client_config)
-        self.server_config = h2.config.H2Configuration(client_side=False)
-        self.server_conn = h2.connection.H2Connection(config=self.server_config)
-        self.initialized = False
-
-    def get_server_bytes(
-        self, request_headers, request_data, response_headers, response_data
-    ):
-        if not self.initialized:
-            self.client_conn.initiate_connection()
-            self.server_conn.initiate_connection()
-            self.initialized = True
-
-        # Feed the request events to the client-side state machine
-        client_stream_id = self.client_conn.get_next_available_stream_id()
-        self.client_conn.send_headers(client_stream_id, headers=request_headers)
-        self.client_conn.send_data(client_stream_id, data=request_data, end_stream=True)
-
-        # Determine the bytes that are sent out the client side, and feed them
-        # into the server-side state machine to get it into the correct state.
-        client_bytes = self.client_conn.data_to_send()
-        events = self.server_conn.receive_data(client_bytes)
-        server_stream_id = [
-            event.stream_id
-            for event in events
-            if isinstance(event, h2.events.RequestReceived)
-        ][0]
-
-        # Feed the response events to the server-side state machine
-        self.server_conn.send_headers(server_stream_id, headers=response_headers)
-        self.server_conn.send_data(
-            server_stream_id, data=response_data, end_stream=True
-        )
-
-        return self.server_conn.data_to_send()
-
-
-def test_get_request() -> None:
-    bytes_generator = HTTP2BytesGenerator()
-    http_buffer = [
-        bytes_generator.get_server_bytes(
-            request_headers=[
-                (b":method", b"GET"),
-                (b":authority", b"www.example.com"),
-                (b":scheme", b"https"),
-                (b":path", "/"),
-            ],
-            request_data=b"",
-            response_headers=[
-                (b":status", b"200"),
-                (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"),
-                (b"server", b"Apache"),
-                (b"content-length", b"13"),
-                (b"content-type", b"text/plain"),
-            ],
-            response_data=b"Hello, world.",
-        ),
-        bytes_generator.get_server_bytes(
-            request_headers=[
-                (b":method", b"GET"),
-                (b":authority", b"www.example.com"),
-                (b":scheme", b"https"),
-                (b":path", "/"),
-            ],
-            request_data=b"",
-            response_headers=[
-                (b":status", b"200"),
-                (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"),
-                (b"server", b"Apache"),
-                (b"content-length", b"13"),
-                (b"content-type", b"text/plain"),
-            ],
-            response_data=b"Hello, world.",
-        ),
-    ]
-    backend = MockBackend(http_buffer=http_buffer)
-
-    with httpcore.SyncConnectionPool(http2=True, backend=backend) as http:
-        # We're sending a request with a standard keep-alive connection, so
-        # it will remain in the pool once we've sent the request.
-        response = http.handle_request(
-            method=b"GET",
-            url=(b"https", b"example.org", None, b"/"),
-            headers=[(b"Host", b"example.org")],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        status_code, headers, stream, extensions = response
-        body = stream.read()
-        assert status_code == 200
-        assert body == b"Hello, world."
-        assert http.get_connection_info() == {
-            "https://example.org": ["HTTP/2, IDLE, 0 streams"]
-        }
-
-        # The second HTTP request will go out over the same connection.
-        response = http.handle_request(
-            method=b"GET",
-            url=(b"https", b"example.org", None, b"/"),
-            headers=[(b"Host", b"example.org")],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        status_code, headers, stream, extensions = response
-        body = stream.read()
-        assert status_code == 200
-        assert body == b"Hello, world."
-        assert http.get_connection_info() == {
-            "https://example.org": ["HTTP/2, IDLE, 0 streams"]
-        }
-
-
-def test_post_request() -> None:
-    bytes_generator = HTTP2BytesGenerator()
-    bytes_to_send = bytes_generator.get_server_bytes(
-        request_headers=[
-            (b":method", b"POST"),
-            (b":authority", b"www.example.com"),
-            (b":scheme", b"https"),
-            (b":path", "/"),
-            (b"content-length", b"13"),
-        ],
-        request_data=b"Hello, world.",
-        response_headers=[
-            (b":status", b"200"),
-            (b"date", b"Sat, 06 Oct 2049 12:34:56 GMT"),
-            (b"server", b"Apache"),
-            (b"content-length", b"13"),
-            (b"content-type", b"text/plain"),
-        ],
-        response_data=b"Hello, world.",
-    )
-    backend = MockBackend(http_buffer=[bytes_to_send])
-
-    with httpcore.SyncConnectionPool(http2=True, backend=backend) as http:
-        # We're sending a request with a standard keep-alive connection, so
-        # it will remain in the pool once we've sent the request.
-        response = http.handle_request(
-            method=b"POST",
-            url=(b"https", b"example.org", None, b"/"),
-            headers=[(b"Host", b"example.org"), (b"Content-length", b"13")],
-            stream=httpcore.ByteStream(b"Hello, world."),
-            extensions={},
-        )
-        status_code, headers, stream, extensions = response
-        body = stream.read()
-        assert status_code == 200
-        assert body == b"Hello, world."
-        assert http.get_connection_info() == {
-            "https://example.org": ["HTTP/2, IDLE, 0 streams"]
-        }
-
-
-def test_request_with_missing_host_header() -> None:
-    backend = MockBackend(http_buffer=[])
-
-    server_config = h2.config.H2Configuration(client_side=False)
-    server_conn = h2.connection.H2Connection(config=server_config)
-    server_conn.initiate_connection()
-    backend = MockBackend(http_buffer=[server_conn.data_to_send()])
-
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        with pytest.raises(httpcore.LocalProtocolError) as excinfo:
-            http.handle_request(
-                method=b"GET",
-                url=(b"http", b"example.org", None, b"/"),
-                headers=[],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-        assert str(excinfo.value) == "Missing mandatory Host: header"
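The mock above synthesizes the server side of the wire traffic by running real h2 client and server state machines back to back. The complementary client-side step, which the connection pool performs internally, looks roughly like this. `decode_response` is a hypothetical helper for illustration, not part of httpcore:

```python
import h2.connection
import h2.events

def decode_response(client_conn: h2.connection.H2Connection, server_bytes: bytes):
    # Feed the synthesized server bytes into the client-side state machine,
    # then collect the response headers and body from the emitted events.
    headers, body = None, b""
    for event in client_conn.receive_data(server_bytes):
        if isinstance(event, h2.events.ResponseReceived):
            headers = event.headers
        elif isinstance(event, h2.events.DataReceived):
            body += event.data
            # Return flow-control credit so the peer may keep sending.
            client_conn.acknowledge_received_data(
                event.flow_controlled_length, event.stream_id
            )
    return headers, body
```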
diff --git a/tests/sync_tests/test_interfaces.py b/tests/sync_tests/test_interfaces.py
deleted file mode 100644
index cc31ab34..00000000
--- a/tests/sync_tests/test_interfaces.py
+++ /dev/null
@@ -1,605 +0,0 @@
-import platform
-from typing import Tuple
-
-import pytest
-
-import httpcore
-from httpcore._types import URL
-from tests.conftest import HTTPS_SERVER_URL
-from tests.utils import Server, lookup_sync_backend
-
-
-@pytest.fixture(params=["sync"])
-def backend(request):
-    return request.param
-
-
-def read_body(stream: httpcore.SyncByteStream) -> bytes:
-    try:
-        body = []
-        for chunk in stream:
-            body.append(chunk)
-        return b"".join(body)
-    finally:
-        stream.close()
-
-
-def test_must_configure_either_http1_or_http2() -> None:
-    with pytest.raises(ValueError):
-        httpcore.SyncConnectionPool(http1=False, http2=False)
-
-
-def test_http_request(backend: str, server: Server) -> None:
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-def test_https_request(backend: str, https_server: Server) -> None:
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"https", *https_server.netloc, b"/"),
-            headers=[https_server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if https_server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"https", *https_server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-@pytest.mark.parametrize(
-    "url", [(b"ftp", b"example.org", 443, b"/"), (b"", b"coolsite.org", 443, b"/")]
-)
-def test_request_unsupported_protocol(
-    backend: str, url: Tuple[bytes, bytes, int, bytes]
-) -> None:
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        with pytest.raises(httpcore.UnsupportedProtocol):
-            http.handle_request(
-                method=b"GET",
-                url=url,
-                headers=[(b"host", b"example.org")],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-
-
-def test_http2_request(backend: str, https_server: Server) -> None:
-    with httpcore.SyncConnectionPool(backend=backend, http2=True) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"https", *https_server.netloc, b"/"),
-            headers=[https_server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        assert extensions == {"http_version": b"HTTP/2"}
-        origin = (b"https", *https_server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-def test_closing_http_request(backend: str, server: Server) -> None:
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header, (b"connection", b"close")],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert origin not in http._connections  # type: ignore
-
-
-def test_http_request_reuse_connection(backend: str, server: Server) -> None:
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-def test_https_request_reuse_connection(
-    backend: str, https_server: Server
-) -> None:
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"https", *https_server.netloc, b"/"),
-            headers=[https_server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if https_server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"https", *https_server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"https", *https_server.netloc, b"/"),
-            headers=[https_server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if https_server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"https", *https_server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-def test_http_request_cannot_reuse_dropped_connection(
-    backend: str, server: Server
-) -> None:
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-        # Mock the connection as having been dropped.
-        connection = list(http._connections[origin])[0]  # type: ignore
-        connection.is_socket_readable = lambda: True  # type: ignore
-
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"])
-def test_http_proxy(
-    proxy_server: URL, proxy_mode: str, backend: str, server: Server
-) -> None:
-    max_connections = 1
-    with httpcore.SyncHTTPProxy(
-        proxy_server,
-        proxy_mode=proxy_mode,
-        max_connections=max_connections,
-        backend=backend,
-    ) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-
-
-@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "FORWARD_ONLY", "TUNNEL_ONLY"])
-@pytest.mark.parametrize("protocol,port", [(b"http", 80), (b"https", 443)])
-# Filter out ssl module deprecation warnings and asyncio module resource warning,
-# convert other warnings to errors.
-@pytest.mark.filterwarnings("ignore:.*(SSLContext|PROTOCOL_TLS):DeprecationWarning")
-@pytest.mark.filterwarnings("ignore::ResourceWarning:asyncio")
-@pytest.mark.filterwarnings("error")
-def test_proxy_socket_does_not_leak_when_the_connection_hasnt_been_added_to_pool(
-    proxy_server: URL,
-    server: Server,
-    proxy_mode: str,
-    protocol: bytes,
-    port: int,
-):
-    with httpcore.SyncHTTPProxy(proxy_server, proxy_mode=proxy_mode) as http:
-        for _ in range(100):
-            try:
-                _ = http.handle_request(
-                    method=b"GET",
-                    url=(protocol, b"blockedhost.example.com", port, b"/"),
-                    headers=[(b"host", b"blockedhost.example.com")],
-                    stream=httpcore.ByteStream(b""),
-                    extensions={},
-                )
-            except (httpcore.ProxyError, httpcore.RemoteProtocolError):
-                pass
-
-
-def test_http_request_local_address(backend: str, server: Server) -> None:
-    if backend == "sync" and lookup_sync_backend() == "trio":
-        pytest.skip("The trio backend does not support local_address")
-
-    with httpcore.SyncConnectionPool(
-        backend=backend, local_address="0.0.0.0"
-    ) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-# mitmproxy does not support forwarding HTTPS requests
-@pytest.mark.parametrize("proxy_mode", ["DEFAULT", "TUNNEL_ONLY"])
-@pytest.mark.parametrize("http2", [False, True])
-def test_proxy_https_requests(
-    proxy_server: URL,
-    proxy_mode: str,
-    http2: bool,
-    https_server: Server,
-) -> None:
-    max_connections = 1
-    with httpcore.SyncHTTPProxy(
-        proxy_server,
-        proxy_mode=proxy_mode,
-        max_connections=max_connections,
-        http2=http2,
-    ) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"https", *https_server.netloc, b"/"),
-            headers=[https_server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        _ = read_body(stream)
-
-        assert status_code == 200
-        assert extensions["http_version"] == b"HTTP/2" if http2 else b"HTTP/1.1"
-        assert extensions.get("reason_phrase", b"") == b"" if http2 else b"OK"
-
-
-@pytest.mark.parametrize(
-    "http2,keepalive_expiry,expected_during_active,expected_during_idle",
-    [
-        (
-            False,
-            60.0,
-            {HTTPS_SERVER_URL: ["HTTP/1.1, ACTIVE", "HTTP/1.1, ACTIVE"]},
-            {HTTPS_SERVER_URL: ["HTTP/1.1, IDLE", "HTTP/1.1, IDLE"]},
-        ),
-        (
-            True,
-            60.0,
-            {HTTPS_SERVER_URL: ["HTTP/2, ACTIVE, 2 streams"]},
-            {HTTPS_SERVER_URL: ["HTTP/2, IDLE, 0 streams"]},
-        ),
-        (
-            False,
-            0.0,
-            {HTTPS_SERVER_URL: ["HTTP/1.1, ACTIVE", "HTTP/1.1, ACTIVE"]},
-            {},
-        ),
-        (
-            True,
-            0.0,
-            {HTTPS_SERVER_URL: ["HTTP/2, ACTIVE, 2 streams"]},
-            {},
-        ),
-    ],
-)
-def test_connection_pool_get_connection_info(
-    http2: bool,
-    keepalive_expiry: float,
-    expected_during_active: dict,
-    expected_during_idle: dict,
-    backend: str,
-    https_server: Server,
-) -> None:
-    with httpcore.SyncConnectionPool(
-        http2=http2, keepalive_expiry=keepalive_expiry, backend=backend
-    ) as http:
-        _, _, stream_1, _ = http.handle_request(
-            method=b"GET",
-            url=(b"https", *https_server.netloc, b"/"),
-            headers=[https_server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        _, _, stream_2, _ = http.handle_request(
-            method=b"GET",
-            url=(b"https", *https_server.netloc, b"/"),
-            headers=[https_server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-
-        try:
-            stats = http.get_connection_info()
-            assert stats == expected_during_active
-        finally:
-            read_body(stream_1)
-            read_body(stream_2)
-
-        stats = http.get_connection_info()
-        assert stats == expected_during_idle
-
-    stats = http.get_connection_info()
-    assert stats == {}
-
-
-@pytest.mark.skipif(
-    platform.system() not in ("Linux", "Darwin"),
-    reason="Unix Domain Sockets only exist on Unix",
-)
-def test_http_request_unix_domain_socket(
-    uds_server: Server, backend: str
-) -> None:
-    uds = uds_server.uds
-    with httpcore.SyncConnectionPool(uds=uds, backend=backend) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", b"localhost", None, b"/"),
-            headers=[(b"host", b"localhost")],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        assert status_code == 200
-        reason_phrase = b"OK" if uds_server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        body = read_body(stream)
-        assert body == b"Hello, world!"
-
-
-@pytest.mark.parametrize("max_keepalive", [1, 3, 5])
-@pytest.mark.parametrize("connections_number", [4])
-def test_max_keepalive_connections_handled_correctly(
-    max_keepalive: int, connections_number: int, backend: str, server: Server
-) -> None:
-    with httpcore.SyncConnectionPool(
-        max_keepalive_connections=max_keepalive, keepalive_expiry=60, backend=backend
-    ) as http:
-        connections_streams = []
-        for _ in range(connections_number):
-            _, _, stream, _ = http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-            connections_streams.append(stream)
-
-        try:
-            for i in range(len(connections_streams)):
-                read_body(connections_streams[i])
-        finally:
-            stats = http.get_connection_info()
-
-            connections_in_pool = next(iter(stats.values()))
-            assert len(connections_in_pool) == min(connections_number, max_keepalive)
-
-
-def test_explicit_backend_name(server: Server) -> None:
-    with httpcore.SyncConnectionPool(backend=lookup_sync_backend()) as http:
-        status_code, headers, stream, extensions = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        read_body(stream)
-
-        assert status_code == 200
-        reason_phrase = b"OK" if server.sends_reason else b""
-        assert extensions == {
-            "http_version": b"HTTP/1.1",
-            "reason_phrase": reason_phrase,
-        }
-        origin = (b"http", *server.netloc)
-        assert len(http._connections[origin]) == 1  # type: ignore
-
-
-@pytest.mark.usefixtures("too_many_open_files_minus_one")
-@pytest.mark.skipif(platform.system() != "Linux", reason="Only a problem on Linux")
-def test_broken_socket_detection_many_open_files(
-    backend: str, server: Server
-) -> None:
-    """
-    Regression test for: https://github.com/encode/httpcore/issues/182
-    """
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        # * First attempt will be successful because it will grab the last
-        #   available fd before what select() supports on the platform.
-        # * Second attempt would have failed without a fix, due to a "filedescriptor
-        #   out of range in select()" exception.
-        for _ in range(2):
-            (
-                status_code,
-                response_headers,
-                stream,
-                extensions,
-            ) = http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-            read_body(stream)
-
-            assert status_code == 200
-            reason_phrase = b"OK" if server.sends_reason else b""
-            assert extensions == {
-                "http_version": b"HTTP/1.1",
-                "reason_phrase": reason_phrase,
-            }
-            origin = (b"http", *server.netloc)
-            assert len(http._connections[origin]) == 1  # type: ignore
-
-
-@pytest.mark.parametrize(
-    "url",
-    [
-        pytest.param((b"http", b"localhost", 12345, b"/"), id="connection-refused"),
-        pytest.param(
-            (b"http", b"doesnotexistatall.org", None, b"/"), id="dns-resolution-failed"
-        ),
-    ],
-)
-def test_cannot_connect_tcp(backend: str, url) -> None:
-    """
-    A properly wrapped error is raised when connecting to the server fails.
-    """
-    with httpcore.SyncConnectionPool(backend=backend) as http:
-        with pytest.raises(httpcore.ConnectError):
-            http.handle_request(
-                method=b"GET",
-                url=url,
-                headers=[],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-
-
-def test_cannot_connect_uds(backend: str) -> None:
-    """
-    A properly wrapped error is raised when connecting to the UDS server fails.
-    """
-    uds = "/tmp/doesnotexist.sock"
-    with httpcore.SyncConnectionPool(backend=backend, uds=uds) as http:
-        with pytest.raises(httpcore.ConnectError):
-            http.handle_request(
-                method=b"GET",
-                url=(b"http", b"localhost", None, b"/"),
-                headers=[],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
diff --git a/tests/sync_tests/test_retries.py b/tests/sync_tests/test_retries.py
deleted file mode 100644
index 1e266270..00000000
--- a/tests/sync_tests/test_retries.py
+++ /dev/null
@@ -1,200 +0,0 @@
-import queue
-import time
-from typing import Any, List, Optional
-
-import pytest
-
-import httpcore
-from httpcore._backends.sync import SyncSocketStream, SyncBackend
-from tests.utils import Server
-
-
-class SyncMockBackend(SyncBackend):
-    def __init__(self) -> None:
-        super().__init__()
-        self._exceptions: queue.Queue[Optional[Exception]] = queue.Queue()
-        self._timestamps: List[float] = []
-
-    def push(self, *exceptions: Optional[Exception]) -> None:
-        for exc in exceptions:
-            self._exceptions.put(exc)
-
-    def pop_open_tcp_stream_intervals(self) -> list:
-        intervals = [b - a for a, b in zip(self._timestamps, self._timestamps[1:])]
-        self._timestamps.clear()
-        return intervals
-
-    def open_tcp_stream(self, *args: Any, **kwargs: Any) -> SyncSocketStream:
-        self._timestamps.append(time.time())
-        exc = None if self._exceptions.empty() else self._exceptions.get_nowait()
-        if exc is not None:
-            raise exc
-        return super().open_tcp_stream(*args, **kwargs)
-
-
-def read_body(stream: httpcore.SyncByteStream) -> bytes:
-    try:
-        return b"".join([chunk for chunk in stream])
-    finally:
-        stream.close()
-
-
-def test_no_retries(server: Server) -> None:
-    """
-    By default, connection failures are not retried on.
-    """
-    backend = SyncMockBackend()
-
-    with httpcore.SyncConnectionPool(
-        max_keepalive_connections=0, backend=backend
-    ) as http:
-        response = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        status_code, _, stream, _ = response
-        assert status_code == 200
-        read_body(stream)
-
-        backend.push(httpcore.ConnectTimeout(), httpcore.ConnectError())
-
-        with pytest.raises(httpcore.ConnectTimeout):
-            http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-
-        with pytest.raises(httpcore.ConnectError):
-            http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-
-
-def test_retries_enabled(server: Server) -> None:
-    """
-    When retries are enabled, connection failures are retried on with
-    a fixed exponential backoff.
-    """
-    backend = SyncMockBackend()
-    retries = 10  # Large enough to not run out of retries within this test.
-
-    with httpcore.SyncConnectionPool(
-        retries=retries, max_keepalive_connections=0, backend=backend
-    ) as http:
-        # Standard case, no failures.
-        response = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        assert backend.pop_open_tcp_stream_intervals() == []
-        status_code, _, stream, _ = response
-        assert status_code == 200
-        read_body(stream)
-
-        # One failure, then success.
-        backend.push(httpcore.ConnectError(), None)
-        response = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        assert backend.pop_open_tcp_stream_intervals() == [
-            pytest.approx(0, abs=5e-3),  # Retry immediately.
-        ]
-        status_code, _, stream, _ = response
-        assert status_code == 200
-        read_body(stream)
-
-        # Three failures, then success.
-        backend.push(
-            httpcore.ConnectError(),
-            httpcore.ConnectTimeout(),
-            httpcore.ConnectTimeout(),
-            None,
-        )
-        response = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        assert backend.pop_open_tcp_stream_intervals() == [
-            pytest.approx(0, abs=5e-3),  # Retry immediately.
-            pytest.approx(0.5, rel=0.1),  # First backoff.
-            pytest.approx(1.0, rel=0.1),  # Second (increased) backoff.
-        ]
-        status_code, _, stream, _ = response
-        assert status_code == 200
-        read_body(stream)
-
-        # Non-connect exceptions are not retried on.
-        backend.push(httpcore.ReadTimeout(), httpcore.NetworkError())
-        with pytest.raises(httpcore.ReadTimeout):
-            http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-        with pytest.raises(httpcore.NetworkError):
-            http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-
-
-def test_retries_exceeded(server: Server) -> None:
-    """
-    When retries are enabled and connection failures exceed the configured number
-    of retries, connect exceptions are raised.
-    """
-    backend = SyncMockBackend()
-    retries = 1
-
-    with httpcore.SyncConnectionPool(
-        retries=retries, max_keepalive_connections=0, backend=backend
-    ) as http:
-        response = http.handle_request(
-            method=b"GET",
-            url=(b"http", *server.netloc, b"/"),
-            headers=[server.host_header],
-            stream=httpcore.ByteStream(b""),
-            extensions={},
-        )
-        status_code, _, stream, _ = response
-        assert status_code == 200
-        read_body(stream)
-
-        # First failure is retried on, second one isn't.
-        backend.push(httpcore.ConnectError(), httpcore.ConnectTimeout())
-        with pytest.raises(httpcore.ConnectTimeout):
-            http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
- """ - backend = SyncMockBackend() - retries = 1 - - with httpcore.SyncConnectionPool( - retries=retries, max_keepalive_connections=0, backend=backend - ) as http: - response = http.handle_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) - status_code, _, stream, _ = response - assert status_code == 200 - read_body(stream) - - # First failure is retried on, second one isn't. - backend.push(httpcore.ConnectError(), httpcore.ConnectTimeout()) - with pytest.raises(httpcore.ConnectTimeout): - http.handle_request( - method=b"GET", - url=(b"http", *server.netloc, b"/"), - headers=[server.host_header], - stream=httpcore.ByteStream(b""), - extensions={}, - ) diff --git a/tests/test_api.py b/tests/test_api.py new file mode 100644 index 00000000..18dfb280 --- /dev/null +++ b/tests/test_api.py @@ -0,0 +1,20 @@ +import json + +import httpcore + + +def test_request(httpbin): + response = httpcore.request("GET", httpbin.url) + assert response.status == 200 + + +def test_stream(httpbin): + with httpcore.stream("GET", httpbin.url) as response: + assert response.status == 200 + + +def test_request_with_content(httpbin): + url = f"{httpbin.url}/post" + response = httpcore.request("POST", url, content=b'{"hello":"world"}') + assert response.status == 200 + assert json.loads(response.content)["json"] == {"hello": "world"} diff --git a/tests/test_exported_members.py b/tests/test_exported_members.py deleted file mode 100644 index 8ff0b458..00000000 --- a/tests/test_exported_members.py +++ /dev/null @@ -1,8 +0,0 @@ -import httpcore -from httpcore import __all__ as exported_members - - -def test_all_imports_are_exported() -> None: - assert exported_members == sorted( - member for member in vars(httpcore).keys() if not member.startswith("_") - ) diff --git a/tests/test_map_exceptions.py b/tests/test_map_exceptions.py deleted file mode 100644 index 22ada95f..00000000 --- a/tests/test_map_exceptions.py +++ /dev/null @@ -1,21 +0,0 @@ -import pytest - -from httpcore._exceptions import map_exceptions - - -def test_map_single_exception() -> None: - with pytest.raises(TypeError): - with map_exceptions({ValueError: TypeError}): - raise ValueError("nope") - - -def test_map_multiple_exceptions() -> None: - with pytest.raises(ValueError): - with map_exceptions({IndexError: ValueError, KeyError: ValueError}): - raise KeyError("nope") - - -def test_unhandled_map_exception() -> None: - with pytest.raises(TypeError): - with map_exceptions({IndexError: ValueError, KeyError: ValueError}): - raise TypeError("nope") diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 00000000..cb61f460 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,160 @@ +from typing import AsyncIterator, Iterator, List + +import pytest + +import httpcore + +# URL + + +def test_url(): + url = httpcore.URL("https://www.example.com/") + assert url == httpcore.URL( + scheme="https", host="www.example.com", port=None, target="/" + ) + assert bytes(url) == b"https://www.example.com/" + + +def test_url_with_port(): + url = httpcore.URL("https://www.example.com:443/") + assert url == httpcore.URL( + scheme="https", host="www.example.com", port=443, target="/" + ) + assert bytes(url) == b"https://www.example.com:443/" + + +def test_url_with_invalid_argument(): + with pytest.raises(TypeError) as exc_info: + httpcore.URL(123) # type: ignore + assert str(exc_info.value) == "url must be bytes or str, but got int." 
diff --git a/tests/test_exported_members.py b/tests/test_exported_members.py
deleted file mode 100644
index 8ff0b458..00000000
--- a/tests/test_exported_members.py
+++ /dev/null
@@ -1,8 +0,0 @@
-import httpcore
-from httpcore import __all__ as exported_members
-
-
-def test_all_imports_are_exported() -> None:
-    assert exported_members == sorted(
-        member for member in vars(httpcore).keys() if not member.startswith("_")
-    )
diff --git a/tests/test_map_exceptions.py b/tests/test_map_exceptions.py
deleted file mode 100644
index 22ada95f..00000000
--- a/tests/test_map_exceptions.py
+++ /dev/null
@@ -1,21 +0,0 @@
-import pytest
-
-from httpcore._exceptions import map_exceptions
-
-
-def test_map_single_exception() -> None:
-    with pytest.raises(TypeError):
-        with map_exceptions({ValueError: TypeError}):
-            raise ValueError("nope")
-
-
-def test_map_multiple_exceptions() -> None:
-    with pytest.raises(ValueError):
-        with map_exceptions({IndexError: ValueError, KeyError: ValueError}):
-            raise KeyError("nope")
-
-
-def test_unhandled_map_exception() -> None:
-    with pytest.raises(TypeError):
-        with map_exceptions({IndexError: ValueError, KeyError: ValueError}):
-            raise TypeError("nope")
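The behaviour pinned down by these deleted tests, translating listed exception types while letting everything else propagate, can be captured in a small context manager. One plausible implementation consistent with the tests (a sketch, not necessarily the exact httpcore code):

```python
import contextlib
from typing import Iterator, Mapping, Type

@contextlib.contextmanager
def map_exceptions(
    mapping: Mapping[Type[Exception], Type[Exception]]
) -> Iterator[None]:
    try:
        yield
    except Exception as exc:
        for from_exc, to_exc in mapping.items():
            if isinstance(exc, from_exc):
                # Re-raise as the mapped exception type, chaining the original.
                raise to_exc(exc) from exc
        # Unlisted exception types propagate unchanged.
        raise
```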
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 00000000..cb61f460
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,160 @@
+from typing import AsyncIterator, Iterator, List
+
+import pytest
+
+import httpcore
+
+# URL
+
+
+def test_url():
+    url = httpcore.URL("https://www.example.com/")
+    assert url == httpcore.URL(
+        scheme="https", host="www.example.com", port=None, target="/"
+    )
+    assert bytes(url) == b"https://www.example.com/"
+
+
+def test_url_with_port():
+    url = httpcore.URL("https://www.example.com:443/")
+    assert url == httpcore.URL(
+        scheme="https", host="www.example.com", port=443, target="/"
+    )
+    assert bytes(url) == b"https://www.example.com:443/"
+
+
+def test_url_with_invalid_argument():
+    with pytest.raises(TypeError) as exc_info:
+        httpcore.URL(123)  # type: ignore
+    assert str(exc_info.value) == "url must be bytes or str, but got int."
+
+
+def test_url_cannot_include_unicode_strings():
+    """
+    URLs instantiated with strings outside of the plain ASCII range are disallowed,
+    but the explicit style allows for these ambiguous cases to be precisely expressed.
+    """
+    with pytest.raises(TypeError) as exc_info:
+        httpcore.URL("https://www.example.com/☺")
+    assert str(exc_info.value) == "url strings may not include unicode characters."
+
+    httpcore.URL(scheme=b"https", host=b"www.example.com", target="/☺".encode("utf-8"))
+
+
+# Request
+
+
+def test_request():
+    request = httpcore.Request("GET", "https://www.example.com/")
+    assert request.method == b"GET"
+    assert request.url == httpcore.URL("https://www.example.com/")
+    assert request.headers == []
+    assert request.extensions == {}
+    assert repr(request) == "<Request [b'GET']>"
+    assert (
+        repr(request.url)
+        == "URL(scheme=b'https', host=b'www.example.com', port=None, target=b'/')"
+    )
+    assert repr(request.stream) == "<httpcore.ByteStream [0 bytes]>"
+
+
+def test_request_with_invalid_method():
+    with pytest.raises(TypeError) as exc_info:
+        httpcore.Request(123, "https://www.example.com/")  # type: ignore
+    assert str(exc_info.value) == "method must be bytes or str, but got int."
+
+
+def test_request_with_invalid_url():
+    with pytest.raises(TypeError) as exc_info:
+        httpcore.Request("GET", 123)  # type: ignore
+    assert str(exc_info.value) == "url must be a URL, bytes, or str, but got int."
+
+
+def test_request_with_invalid_headers():
+    with pytest.raises(TypeError) as exc_info:
+        httpcore.Request("GET", "https://www.example.com/", headers=123)  # type: ignore
+    assert str(exc_info.value) == "headers must be a list, but got int."
+
+
+# Response
+
+
+def test_response():
+    response = httpcore.Response(200)
+    assert response.status == 200
+    assert response.headers == []
+    assert response.extensions == {}
+    assert repr(response) == "<Response [200]>"
+    assert repr(response.stream) == "<httpcore.ByteStream [0 bytes]>"
+
+
+# Tests for reading and streaming sync byte streams...
+
+
+class ByteIterator:
+    def __init__(self, chunks: List[bytes]) -> None:
+        self._chunks = chunks
+
+    def __iter__(self) -> Iterator[bytes]:
+        for chunk in self._chunks:
+            yield chunk
+
+
+def test_response_sync_read():
+    stream = ByteIterator([b"Hello, ", b"world!"])
+    response = httpcore.Response(200, content=stream)
+    assert response.read() == b"Hello, world!"
+    assert response.content == b"Hello, world!"
+
+
+def test_response_sync_streaming():
+    stream = ByteIterator([b"Hello, ", b"world!"])
+    response = httpcore.Response(200, content=stream)
+    content = b"".join([chunk for chunk in response.iter_stream()])
+    assert content == b"Hello, world!"
+
+    # We streamed the response rather than reading it, so .content is not available.
+    with pytest.raises(RuntimeError):
+        response.content
+
+    # Once we've streamed the response, we can't access the stream again.
+    with pytest.raises(RuntimeError):
+        for _chunk in response.iter_stream():
+            pass  # pragma: nocover
+
+
+# Tests for reading and streaming async byte streams...
+
+
+class AsyncByteIterator:
+    def __init__(self, chunks: List[bytes]) -> None:
+        self._chunks = chunks
+
+    async def __aiter__(self) -> AsyncIterator[bytes]:
+        for chunk in self._chunks:
+            yield chunk
+
+
+@pytest.mark.trio
+async def test_response_async_read():
+    stream = AsyncByteIterator([b"Hello, ", b"world!"])
+    response = httpcore.Response(200, content=stream)
+    assert await response.aread() == b"Hello, world!"
+    assert response.content == b"Hello, world!"
+
+
+@pytest.mark.trio
+async def test_response_async_streaming():
+    stream = AsyncByteIterator([b"Hello, ", b"world!"])
+    response = httpcore.Response(200, content=stream)
+    content = b"".join([chunk async for chunk in response.aiter_stream()])
+    assert content == b"Hello, world!"
+
+    # We streamed the response rather than reading it, so .content is not available.
+    with pytest.raises(RuntimeError):
+        response.content
+
+    # Once we've streamed the response, we can't access the stream again.
+    with pytest.raises(RuntimeError):
+        async for chunk in response.aiter_stream():
+            pass  # pragma: nocover
diff --git a/tests/test_threadsafety.py b/tests/test_threadsafety.py
deleted file mode 100644
index ae44cb85..00000000
--- a/tests/test_threadsafety.py
+++ /dev/null
@@ -1,49 +0,0 @@
-import concurrent.futures
-
-import pytest
-
-import httpcore
-
-from .utils import Server
-
-
-def read_body(stream: httpcore.SyncByteStream) -> bytes:
-    try:
-        return b"".join(chunk for chunk in stream)
-    finally:
-        stream.close()
-
-
-@pytest.mark.parametrize(
-    "http2", [pytest.param(False, id="h11"), pytest.param(True, id="h2")]
-)
-def test_threadsafe_basic(server: Server, http2: bool) -> None:
-    """
-    The sync connection pool can be used to perform requests concurrently using
-    threads.
-
-    Also a regression test for: https://github.com/encode/httpx/issues/1393
-    """
-    with httpcore.SyncConnectionPool(http2=http2) as http:
-
-        def request(http: httpcore.SyncHTTPTransport) -> int:
-            status_code, headers, stream, extensions = http.handle_request(
-                method=b"GET",
-                url=(b"http", *server.netloc, b"/"),
-                headers=[server.host_header],
-                stream=httpcore.ByteStream(b""),
-                extensions={},
-            )
-            read_body(stream)
-            return status_code
-
-        with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
-            futures = [executor.submit(request, http) for _ in range(10)]
-            num_results = 0
-
-            for future in concurrent.futures.as_completed(futures):
-                status_code = future.result()
-                assert status_code == 200
-                num_results += 1
-
-            assert num_results == 10
diff --git a/tests/test_utils.py b/tests/test_utils.py
deleted file mode 100644
index 57f23bf1..00000000
--- a/tests/test_utils.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import itertools
-from typing import List
-
-import pytest
-
-from httpcore._utils import exponential_backoff
-
-
-@pytest.mark.parametrize(
-    "factor, expected",
-    [
-        (0.1, [0, 0.1, 0.2, 0.4, 0.8]),
-        (0.2, [0, 0.2, 0.4, 0.8, 1.6]),
-        (0.5, [0, 0.5, 1.0, 2.0, 4.0]),
-    ],
-)
-def test_exponential_backoff(factor: float, expected: List[int]) -> None:
-    delays = list(itertools.islice(exponential_backoff(factor), 5))
-    assert delays == expected
diff --git a/tests/utils.py b/tests/utils.py
deleted file mode 100644
index de309125..00000000
--- a/tests/utils.py
+++ /dev/null
@@ -1,199 +0,0 @@
-import contextlib
-import functools
-import socket
-import subprocess
-import tempfile
-import threading
-import time
-from typing import Callable, Iterator, List, Tuple
-
-import sniffio
-import trio
-
-try:
-    from hypercorn import config as hypercorn_config, trio as hypercorn_trio
-except ImportError:  # pragma: no cover  # Python 3.6
-    hypercorn_config = None  # type: ignore
-    hypercorn_trio = None  # type: ignore
-
-
-def lookup_async_backend():
-    return sniffio.current_async_library()
-
-
-def lookup_sync_backend():
-    return "sync"
-
-
-def _wait_can_connect(host: str, port: int):
-    while True:
-        try:
-            sock = socket.create_connection((host, port))
-        except ConnectionRefusedError:
-            time.sleep(0.25)
-        else:
-            sock.close()
-            break
-
-
-class Server:
-    """
-    Base interface for servers we can test against.
-    """
-
-    @property
-    def sends_reason(self) -> bool:
-        raise NotImplementedError  # pragma: no cover
-
-    @property
-    def netloc(self) -> Tuple[bytes, int]:
-        raise NotImplementedError  # pragma: no cover
-
-    @property
-    def uds(self) -> str:
-        raise NotImplementedError  # pragma: no cover
-
-    @property
-    def host_header(self) -> Tuple[bytes, bytes]:
-        raise NotImplementedError  # pragma: no cover
-
-
-class LiveServer(Server):  # pragma: no cover  # Python 3.6 only
-    """
-    A test server running on a live location.
-    """
-
-    sends_reason = True
-
-    def __init__(self, host: str, port: int) -> None:
-        self._host = host
-        self._port = port
-
-    @property
-    def netloc(self) -> Tuple[bytes, int]:
-        return (self._host.encode("ascii"), self._port)
-
-    @property
-    def host_header(self) -> Tuple[bytes, bytes]:
-        return (b"host", self._host.encode("ascii"))
-
-
-class HypercornServer(Server):  # pragma: no cover  # Python 3.7+ only
-    """
-    A test server running in-process, powered by Hypercorn.
-    """
-
-    sends_reason = False
-
-    def __init__(
-        self,
-        app: Callable,
-        bind: str,
-        certfile: str = None,
-        keyfile: str = None,
-    ) -> None:
-        assert hypercorn_config is not None
-        self._app = app
-        self._config = hypercorn_config.Config()
-        self._config.bind = [bind]
-        self._config.certfile = certfile
-        self._config.keyfile = keyfile
-        self._config.worker_class = "asyncio"
-        self._started = False
-        self._should_exit = False
-
-    @property
-    def netloc(self) -> Tuple[bytes, int]:
-        bind = self._config.bind[0]
-        host, port = bind.split(":")
-        return (host.encode("ascii"), int(port))
-
-    @property
-    def host_header(self) -> Tuple[bytes, bytes]:
-        return (b"host", self.netloc[0])
-
-    @property
-    def uds(self) -> str:
-        bind = self._config.bind[0]
-        scheme, _, uds = bind.partition(":")
-        assert scheme == "unix"
-        return uds
-
-    def _run(self) -> None:
-        async def shutdown_trigger() -> None:
-            while not self._should_exit:
-                await trio.sleep(0.01)
-
-        serve = functools.partial(
-            hypercorn_trio.serve, shutdown_trigger=shutdown_trigger
-        )
-
-        async def main() -> None:
-            async with trio.open_nursery() as nursery:
-                await nursery.start(serve, self._app, self._config)
-                self._started = True
-
-        trio.run(main)
-
-    @contextlib.contextmanager
-    def serve_in_thread(self) -> Iterator[None]:
-        thread = threading.Thread(target=self._run)
-        thread.start()
-        try:
-            while not self._started:
-                time.sleep(1e-3)
-            yield
-        finally:
-            self._should_exit = True
-            thread.join()
-
-
-@contextlib.contextmanager
-def http_proxy_server(proxy_host: str, proxy_port: int):
-    """
-    This function launches pproxy process like this:
-    $ pproxy -b <blocked_hosts_file> -l http://127.0.0.1:8080
-    What does it mean?
-    It runs HTTP proxy on 127.0.0.1:8080 and blocks access to some external hosts,
-    specified in blocked_hosts_file
-
-    Relevant pproxy docs could be found in their github repo:
-    https://github.com/qwj/python-proxy
-    """
-    proc = None
-
-    with create_proxy_block_file(["blockedhost.example.com"]) as block_file_name:
-        try:
-            command = [
-                "pproxy",
-                "-b",
-                block_file_name,
-                "-l",
-                f"http://{proxy_host}:{proxy_port}/",
-            ]
-            proc = subprocess.Popen(command)
-
-            _wait_can_connect(proxy_host, proxy_port)
-
-            yield b"http", proxy_host.encode(), proxy_port, b"/"
-        finally:
-            if proc is not None:
-                proc.kill()
-
-
-@contextlib.contextmanager
-def create_proxy_block_file(blocked_domains: List[str]):
-    """
-    The context manager yields pproxy block file.
-    This file should contain line delimited hostnames. We use it in the following test:
-    test_proxy_socket_does_not_leak_when_the_connection_hasnt_been_added_to_pool
-    """
-    with tempfile.NamedTemporaryFile(delete=True, mode="w+") as file:
-
-        for domain in blocked_domains:
-            file.write(domain)
-            file.write("\n")
-
-        file.flush()
-
-        yield file.name
diff --git a/unasync.py b/unasync.py
index fb106ed6..a7b2e46f 100755
--- a/unasync.py
+++ b/unasync.py
@@ -4,10 +4,13 @@
 import sys
 
 SUBS = [
-    ('AsyncIteratorByteStream', 'IteratorByteStream'),
+    ('from .._compat import asynccontextmanager', 'from contextlib import contextmanager'),
+    ('from ..backends.auto import AutoBackend', 'from ..backends.sync import SyncBackend'),
+    ('import trio as concurrency', 'from tests import concurrency'),
+    ('AsyncByteStream', 'SyncByteStream'),
     ('AsyncIterator', 'Iterator'),
     ('AutoBackend', 'SyncBackend'),
-    ('Async([A-Z][A-Za-z0-9_]*)', r'Sync\2'),
+    ('Async([A-Z][A-Za-z0-9_]*)', r'\2'),
     ('async def', 'def'),
     ('async with', 'with'),
     ('async for', 'for'),
@@ -17,15 +20,13 @@
     ('aclose_func', 'close_func'),
     ('aiterator', 'iterator'),
     ('aread', 'read'),
+    ('asynccontextmanager', 'contextmanager'),
     ('__aenter__', '__enter__'),
     ('__aexit__', '__exit__'),
     ('__aiter__', '__iter__'),
     ('@pytest.mark.anyio', ''),
     ('@pytest.mark.trio', ''),
-    (r'@pytest.fixture\(params=\["auto", "anyio"\]\)',
-        '@pytest.fixture(params=["sync"])'),
-    ('lookup_async_backend', "lookup_sync_backend"),
-    ('auto', 'sync'),
+    ('AutoBackend', 'SyncBackend'),
 ]
 COMPILED_SUBS = [
     (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl)
@@ -78,7 +79,7 @@ def unasync_dir(in_dir, out_dir, check_only=False):
 def main():
     check_only = '--check' in sys.argv
     unasync_dir("httpcore/_async", "httpcore/_sync", check_only=check_only)
-    unasync_dir("tests/async_tests", "tests/sync_tests", check_only=check_only)
+    unasync_dir("tests/_async", "tests/_sync", check_only=check_only)
 
 
 if __name__ == '__main__':
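To illustrate the substitution table above: each pattern is wrapped in word-boundary anchors and applied line by line, which is how the `_async` sources are mechanically rewritten into their `_sync` counterparts. A toy example using just two entries from the table (note the `r'\2'` replacement refers to the second group, since the boundary anchor adds a group of its own):

```python
import re

SUBS = [
    ('async with', 'with'),
    ('Async([A-Z][A-Za-z0-9_]*)', r'\2'),
]
COMPILED_SUBS = [
    (re.compile(r'(^|\b)' + regex + r'($|\b)'), repl)
    for regex, repl in SUBS
]

line = "async with AsyncConnectionPool() as pool:"
for pattern, repl in COMPILED_SUBS:
    line = pattern.sub(repl, line)
print(line)  # -> "with ConnectionPool() as pool:"
```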