|
3 | 3 | import httpx
|
4 | 4 | import pytest
|
5 | 5 |
|
| 6 | +from tests.protocols.test_http import HTTP_PROTOCOLS |
6 | 7 | from tests.response import Response
|
| 8 | +from tests.utils import run_server |
7 | 9 | from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
|
| 10 | +from uvicorn.config import Config |
8 | 11 | from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
|
| 12 | +from uvicorn.protocols.websockets.wsproto_impl import WSProtocol |
| 13 | + |
| 14 | +try: |
| 15 | + import websockets.client |
| 16 | + |
| 17 | + from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol |
| 18 | + |
| 19 | + WS_PROTOCOLS = [WSProtocol, WebSocketProtocol] |
| 20 | +except ImportError: # pragma: nocover |
| 21 | + WS_PROTOCOLS = [] |
9 | 22 |
|
10 | 23 |
|
11 | 24 | async def app(
|
@@ -103,3 +116,34 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:
|
103 | 116 | response = await client.get("/", headers=headers)
|
104 | 117 | assert response.status_code == 200
|
105 | 118 | assert response.text == "Remote: https://1.2.3.4:0"
|
| 119 | + |
| 120 | + |
| 121 | +@pytest.mark.anyio |
| 122 | +@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS) |
| 123 | +@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS) |
| 124 | +@pytest.mark.skipif(not WS_PROTOCOLS, reason="websockets module not installed.") |
| 125 | +async def test_proxy_headers_websocket_x_forwarded_proto( |
| 126 | + ws_protocol_cls, http_protocol_cls, unused_tcp_port: int |
| 127 | +) -> None: |
| 128 | + async def websocket_app(scope, receive, send): |
| 129 | + scheme = scope["scheme"] |
| 130 | + host, port = scope["client"] |
| 131 | + addr = "%s://%s:%d" % (scheme, host, port) |
| 132 | + await send({"type": "websocket.accept"}) |
| 133 | + await send({"type": "websocket.send", "text": addr}) |
| 134 | + |
| 135 | + app_with_middleware = ProxyHeadersMiddleware(websocket_app, trusted_hosts="*") |
| 136 | + config = Config( |
| 137 | + app=app_with_middleware, |
| 138 | + ws=ws_protocol_cls, |
| 139 | + http=http_protocol_cls, |
| 140 | + lifespan="off", |
| 141 | + port=unused_tcp_port, |
| 142 | + ) |
| 143 | + |
| 144 | + async with run_server(config): |
| 145 | + url = f"ws://127.0.0.1:{unused_tcp_port}" |
| 146 | + headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"} |
| 147 | + async with websockets.client.connect(url, extra_headers=headers) as websocket: |
| 148 | + data = await websocket.recv() |
| 149 | + assert data == "wss://1.2.3.4:0" |
0 commit comments