Skip to content

Commit 4b79ec3

Browse files
committed
Merge pull request #898 from vaskalas/master
added headers to ClientSession.ws_connnect #785
2 parents 649dc17 + 05dc1d9 commit 4b79ec3

File tree

2 files changed

+59
-5
lines changed

2 files changed

+59
-5
lines changed

aiohttp/client.py

+15-4
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,8 @@ def ws_connect(self, url, *,
250250
autoclose=True,
251251
autoping=True,
252252
auth=None,
253-
origin=None):
253+
origin=None,
254+
headers=None):
254255
"""Initiate websocket connection."""
255256
return _WSRequestContextManager(
256257
self._ws_connect(url,
@@ -259,7 +260,8 @@ def ws_connect(self, url, *,
259260
autoclose=autoclose,
260261
autoping=autoping,
261262
auth=auth,
262-
origin=origin))
263+
origin=origin,
264+
headers=headers))
263265

264266
@asyncio.coroutine
265267
def _ws_connect(self, url, *,
@@ -268,16 +270,25 @@ def _ws_connect(self, url, *,
268270
autoclose=True,
269271
autoping=True,
270272
auth=None,
271-
origin=None):
273+
origin=None,
274+
headers=None):
272275

273276
sec_key = base64.b64encode(os.urandom(16))
274277

275-
headers = {
278+
if headers is None:
279+
headers = CIMultiDict()
280+
281+
default_headers = {
276282
hdrs.UPGRADE: hdrs.WEBSOCKET,
277283
hdrs.CONNECTION: hdrs.UPGRADE,
278284
hdrs.SEC_WEBSOCKET_VERSION: '13',
279285
hdrs.SEC_WEBSOCKET_KEY: sec_key.decode(),
280286
}
287+
288+
for key, value in default_headers.items():
289+
if key not in headers:
290+
headers[key] = value
291+
281292
if protocols:
282293
headers[hdrs.SEC_WEBSOCKET_PROTOCOL] = ','.join(protocols)
283294
if origin is not None:

tests/test_websocket_client_functional.py

+44-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import aiohttp
22
import asyncio
33
import pytest
4-
from aiohttp import helpers, web
4+
from aiohttp import helpers, hdrs, web
55

66

77
@pytest.mark.run_loop
@@ -283,3 +283,46 @@ def handler(request):
283283
yield from asyncio.sleep(0.1, loop=loop)
284284
assert resp.closed
285285
assert resp.exception() is None
286+
287+
288+
@pytest.mark.run_loop
289+
def test_override_default_headers(create_app_and_client, loop):
290+
291+
@asyncio.coroutine
292+
def handler(request):
293+
assert request.headers[hdrs.SEC_WEBSOCKET_VERSION] == '8'
294+
ws = web.WebSocketResponse()
295+
yield from ws.prepare(request)
296+
297+
ws.send_str('answer')
298+
yield from ws.close()
299+
return ws
300+
301+
app, client = yield from create_app_and_client()
302+
app.router.add_route('GET', '/', handler)
303+
headers = {hdrs.SEC_WEBSOCKET_VERSION: '8'}
304+
resp = yield from client.ws_connect('/', headers=headers)
305+
msg = yield from resp.receive()
306+
assert msg.data == 'answer'
307+
yield from resp.close()
308+
309+
310+
@pytest.mark.run_loop
311+
def test_additional_headers(create_app_and_client, loop):
312+
313+
@asyncio.coroutine
314+
def handler(request):
315+
assert request.headers['x-hdr'] == 'xtra'
316+
ws = web.WebSocketResponse()
317+
yield from ws.prepare(request)
318+
319+
ws.send_str('answer')
320+
yield from ws.close()
321+
return ws
322+
323+
app, client = yield from create_app_and_client()
324+
app.router.add_route('GET', '/', handler)
325+
resp = yield from client.ws_connect('/', headers={'x-hdr': 'xtra'})
326+
msg = yield from resp.receive()
327+
assert msg.data == 'answer'
328+
yield from resp.close()

0 commit comments

Comments
 (0)