Skip to content

Commit abe92fd

Browse files
author
Sergey Skripnick
committed
Check address family to fill wsgi env properly
1 parent 35b1a0a commit abe92fd

File tree

2 files changed

+46
-26
lines changed

2 files changed

+46
-26
lines changed

aiohttp/wsgi.py

+22-13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import inspect
1111
import io
1212
import os
13+
import socket
1314
import sys
1415
from urllib.parse import urlsplit
1516

@@ -90,20 +91,28 @@ def create_wsgi_environ(self, message, payload):
9091
# which this request is received from the client.
9192
# http://www.ietf.org/rfc/rfc3875
9293

93-
remote = self.transport.get_extra_info('peername')
94-
if remote:
95-
environ['REMOTE_ADDR'] = remote[0]
96-
environ['REMOTE_PORT'] = remote[1]
97-
_host, port = self.transport.get_extra_info('sockname')
98-
environ['SERVER_PORT'] = str(port)
99-
host = message.headers.get("HOST", None)
100-
# SERVER_NAME should be set to value of Host header, but this
101-
# header is not required. In this case we shoud set it to local
102-
# address of socket
103-
environ['SERVER_NAME'] = host.split(":")[0] if host else _host
94+
family = self.transport.get_extra_info('socket').family
95+
if family in (socket.AF_INET, socket.AF_INET6):
96+
peername = self.transport.get_extra_info('peername')
97+
environ['REMOTE_ADDR'] = peername[0]
98+
environ['REMOTE_PORT'] = str(peername[1])
99+
http_host = message.headers.get("HOST", None)
100+
if http_host:
101+
hostport = http_host.split(":")
102+
environ['SERVER_NAME'] = hostport[0]
103+
if len(hostport) > 1:
104+
environ['SERVER_PORT'] = str(hostport[1])
105+
else:
106+
environ['SERVER_PORT'] = '80'
107+
else:
108+
# SERVER_NAME should be set to value of Host header, but this
109+
# header is not required. In this case we shoud set it to local
110+
# address of socket
111+
sockname = self.transport.get_extra_info('sockname')
112+
environ['SERVER_NAME'] = sockname[0]
113+
environ['SERVER_PORT'] = str(sockname[1])
104114
else:
105-
# Dealing with unix socket, so request was received from client by
106-
# upstream server and this data may be found in the headers
115+
# We are behind reverse proxy, so get all vars from headers
107116
for header in ('REMOTE_ADDR', 'REMOTE_PORT',
108117
'SERVER_NAME', 'SERVER_PORT'):
109118
environ[header] = message.headers.get(header, '')

tests/test_wsgi.py

+24-13
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import io
44
import asyncio
5+
import socket
56
import unittest
67
import unittest.mock
78

@@ -22,8 +23,10 @@ def setUp(self):
2223
self.writer = unittest.mock.Mock()
2324
self.writer.drain.return_value = ()
2425
self.transport = unittest.mock.Mock()
25-
self.transport.get_extra_info.side_effect = [('1.2.3.4', 1234),
26-
('2.3.4.5', 80)]
26+
self.transport.get_extra_info.side_effect = [
27+
unittest.mock.Mock(family=socket.AF_INET),
28+
('1.2.3.4', 1234),
29+
('2.3.4.5', 80)]
2730

2831
self.headers = multidict.MultiDict({"HOST": "python.org"})
2932
self.message = protocol.RawRequestMessage(
@@ -78,26 +81,20 @@ def test_environ_headers(self):
7881
self.assertEqual(environ['SERVER_PORT'], '80')
7982
get_extra_info_calls = self.transport.get_extra_info.mock_calls
8083
expected_calls = [
84+
unittest.mock.call('socket'),
8185
unittest.mock.call('peername'),
82-
unittest.mock.call('sockname'),
8386
]
8487
self.assertEqual(expected_calls, get_extra_info_calls)
8588

8689
def test_environ_host_header_alternate_port(self):
87-
self.transport.get_extra_info = unittest.mock.Mock(
88-
side_effect=[('1.2.3.4', 1234), ('3.4.5.6', 82)]
89-
)
9090
self.headers.update({'HOST': 'example.com:9999'})
9191
environ = self._make_one()
92-
self.assertEqual(environ['SERVER_PORT'], '82')
92+
self.assertEqual(environ['SERVER_PORT'], '9999')
9393

9494
def test_environ_host_header_alternate_port_ssl(self):
95-
self.transport.get_extra_info = unittest.mock.Mock(
96-
side_effect=[('1.2.3.4', 1234), ('3.4.5.6', 82)]
97-
)
9895
self.headers.update({'HOST': 'example.com:9999'})
9996
environ = self._make_one(is_ssl=True)
100-
self.assertEqual(environ['SERVER_PORT'], '82')
97+
self.assertEqual(environ['SERVER_PORT'], '9999')
10198

10299
def test_wsgi_response(self):
103100
srv = self._make_srv()
@@ -276,8 +273,22 @@ def test_http_1_0_no_host(self):
276273
self.assertEqual(environ['SERVER_NAME'], '2.3.4.5')
277274
self.assertEqual(environ['SERVER_PORT'], '80')
278275

279-
def test_unix_socket(self):
280-
self.transport.get_extra_info = unittest.mock.Mock(return_value=None)
276+
def test_family_inet6(self):
277+
self.transport.get_extra_info.side_effect = [
278+
unittest.mock.Mock(family=socket.AF_INET6),
279+
("::", 1122, 0, 0),
280+
('2.3.4.5', 80)]
281+
self.message = protocol.RawRequestMessage(
282+
'GET', '/', (1, 0), self.headers, True, 'deflate')
283+
environ = self._make_one()
284+
self.assertEqual(environ['SERVER_NAME'], 'python.org')
285+
self.assertEqual(environ['SERVER_PORT'], '80')
286+
self.assertEqual(environ['REMOTE_ADDR'], '::')
287+
self.assertEqual(environ['REMOTE_PORT'], '1122')
288+
289+
def test_family_unix(self):
290+
self.transport.get_extra_info.side_effect = [
291+
unittest.mock.Mock(family=socket.AF_UNIX)]
281292
headers = multidict.MultiDict({
282293
'SERVER_NAME': '1.2.3.4', 'SERVER_PORT': '5678',
283294
'REMOTE_ADDR': '4.3.2.1', 'REMOTE_PORT': '8765'})

0 commit comments

Comments
 (0)