Skip to content

Commit 806c227

Browse files
authored
Fix x_forwarded_proto for websockets (#2043)
1 parent 57c6d57 commit 806c227

File tree

2 files changed

+53
-2
lines changed

2 files changed

+53
-2
lines changed

tests/middleware/test_proxy_headers.py

+44
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,22 @@
33
import httpx
44
import pytest
55

6+
from tests.protocols.test_http import HTTP_PROTOCOLS
67
from tests.response import Response
8+
from tests.utils import run_server
79
from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope
10+
from uvicorn.config import Config
811
from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware
12+
from uvicorn.protocols.websockets.wsproto_impl import WSProtocol
13+
14+
try:
15+
import websockets.client
16+
17+
from uvicorn.protocols.websockets.websockets_impl import WebSocketProtocol
18+
19+
WS_PROTOCOLS = [WSProtocol, WebSocketProtocol]
20+
except ImportError: # pragma: nocover
21+
WS_PROTOCOLS = []
922

1023

1124
async def app(
@@ -103,3 +116,34 @@ async def test_proxy_headers_invalid_x_forwarded_for() -> None:
103116
response = await client.get("/", headers=headers)
104117
assert response.status_code == 200
105118
assert response.text == "Remote: https://1.2.3.4:0"
119+
120+
121+
@pytest.mark.anyio
122+
@pytest.mark.parametrize("ws_protocol_cls", WS_PROTOCOLS)
123+
@pytest.mark.parametrize("http_protocol_cls", HTTP_PROTOCOLS)
124+
@pytest.mark.skipif(not WS_PROTOCOLS, reason="websockets module not installed.")
125+
async def test_proxy_headers_websocket_x_forwarded_proto(
126+
ws_protocol_cls, http_protocol_cls, unused_tcp_port: int
127+
) -> None:
128+
async def websocket_app(scope, receive, send):
129+
scheme = scope["scheme"]
130+
host, port = scope["client"]
131+
addr = "%s://%s:%d" % (scheme, host, port)
132+
await send({"type": "websocket.accept"})
133+
await send({"type": "websocket.send", "text": addr})
134+
135+
app_with_middleware = ProxyHeadersMiddleware(websocket_app, trusted_hosts="*")
136+
config = Config(
137+
app=app_with_middleware,
138+
ws=ws_protocol_cls,
139+
http=http_protocol_cls,
140+
lifespan="off",
141+
port=unused_tcp_port,
142+
)
143+
144+
async with run_server(config):
145+
url = f"ws://127.0.0.1:{unused_tcp_port}"
146+
headers = {"X-Forwarded-Proto": "https", "X-Forwarded-For": "1.2.3.4"}
147+
async with websockets.client.connect(url, extra_headers=headers) as websocket:
148+
data = await websocket.recv()
149+
assert data == "wss://1.2.3.4:0"

uvicorn/middleware/proxy_headers.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,15 @@ async def __call__(
5959
if b"x-forwarded-proto" in headers:
6060
# Determine if the incoming request was http or https based on
6161
# the X-Forwarded-Proto header.
62-
x_forwarded_proto = headers[b"x-forwarded-proto"].decode("latin1")
63-
scope["scheme"] = x_forwarded_proto.strip()
62+
x_forwarded_proto = (
63+
headers[b"x-forwarded-proto"].decode("latin1").strip()
64+
)
65+
if scope["type"] == "websocket":
66+
scope["scheme"] = (
67+
"wss" if x_forwarded_proto == "https" else "ws"
68+
)
69+
else:
70+
scope["scheme"] = x_forwarded_proto
6471

6572
if b"x-forwarded-for" in headers:
6673
# Determine the client address from the last trusted IP in the

0 commit comments

Comments
 (0)