Skip to content

Commit 8661202

Browse files
committed
support expect_fingerprint as bytes
1 parent a24960c commit 8661202

File tree

2 files changed

+52
-22
lines changed

2 files changed

+52
-22
lines changed

aiohttp/connector.py

+29-11
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,20 @@
2828
PY_34 = sys.version_info >= (3, 4)
2929
PY_343 = sys.version_info >= (3, 4, 3)
3030

31-
HASHFUNC_BY_DIGESTLEN = {
31+
HASHFUNC_BY_HEXDIGESTLEN = {
3232
32: md5,
3333
40: sha1,
3434
64: sha256,
3535
}
36+
HASHFUNC_BY_BINDIGESTLEN = {
37+
16: md5,
38+
20: sha1,
39+
32: sha256,
40+
}
41+
HASHFUNCMAP_BY_DIGEST_TYPE = {
42+
str: HASHFUNC_BY_HEXDIGESTLEN,
43+
bytes: HASHFUNC_BY_BINDIGESTLEN,
44+
}
3645

3746

3847
class Connection(object):
@@ -356,9 +365,10 @@ class TCPConnector(BaseConnector):
356365
"""TCP connector.
357366
358367
:param bool verify_ssl: Set to True to check ssl certifications.
359-
:param str expect_fingerprint: Set to the md5, sha1, or sha256 fingerprint
360-
(as a hexadecimal string) of the expected certificate (DER-encoded)
361-
to verify the cert matches. May be interspersed with colons.
368+
:param str expect_fingerprint: Pass the md5, sha1, or sha256 fingerprint
369+
as either a hexadecimal string or binary bytestring of the expected
370+
certificate (in DER format) to verify the cert matches.
371+
If passing a hex string, colons and case are ignored.
362372
:param bool resolve: Set to True to do DNS lookup for host name.
363373
:param family: socket address family
364374
:param args: see :class:`BaseConnector`
@@ -378,15 +388,23 @@ def __init__(self, *, verify_ssl=True, expect_fingerprint=None,
378388
self._verify_ssl = verify_ssl
379389

380390
if expect_fingerprint:
381-
expect_fingerprint = expect_fingerprint.replace(':', '').lower()
382-
digestlen = len(expect_fingerprint)
383-
hashfunc = HASHFUNC_BY_DIGESTLEN.get(digestlen)
391+
xfp = expect_fingerprint
392+
digest_type = type(xfp)
393+
hashfuncmap = HASHFUNCMAP_BY_DIGEST_TYPE.get(digest_type)
394+
if not hashfuncmap:
395+
raise TypeError('expect_fingerprint must be str or bytes')
396+
is_str = digest_type is str
397+
if is_str:
398+
xfp = xfp.replace(':', '').lower()
399+
digestlen = len(xfp)
400+
hashfunc = hashfuncmap.get(digestlen)
384401
if not hashfunc:
385-
raise ValueError('Fingerprint is of invalid length.')
402+
raise ValueError('expect_fingerprint has invalid length')
386403
self._hashfunc = hashfunc
387-
self._fingerprint_bytes = unhexlify(expect_fingerprint)
388-
389-
self._expect_fingerprint = expect_fingerprint
404+
self._fingerprint_bytes = unhexlify(xfp) if is_str else xfp
405+
self._expect_fingerprint = xfp if is_str else hexlify(xfp)
406+
else:
407+
self._expect_fingerprint = None
390408
self._ssl_context = ssl_context
391409
self._family = family
392410
self._resolve = resolve

tests/test_connector.py

+23-11
Original file line numberDiff line numberDiff line change
@@ -468,19 +468,33 @@ def test_tcp_connector_expect_fingerprint_invalid_len(self):
468468
with self.assertRaises(ValueError):
469469
aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=invalid)
470470

471+
def test_tcp_connector_expect_fingerprint_invalid_type(self):
472+
invalid = 123
473+
with self.assertRaises(TypeError):
474+
aiohttp.TCPConnector(loop=self.loop, expect_fingerprint=invalid)
475+
471476
def test_tcp_connector_expect_fingerprint(self):
472-
# the even-index fingerprints below are for sample.crt.der,
473-
# the certificate presented by test_utils.run_server
477+
# The even-index fingerprints below are "expect success" cases
478+
# for ./sample.crt.der, the cert presented by test_utils.run_server.
479+
# The odd-index fingerprints are "expect fail" cases.
474480
testcases = (
475481
# md5
476-
'a20647adaaf5d85c4a995e62793b063d', # good
477-
'ffffffffffffffffffffffffffffffff', # bad
482+
'a2:06:47:ad:aa:f5:d8:5c:4a:99:5e:62:79:3b:06:3d', # good
483+
'ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff:ff', # bad
484+
485+
'A20647ADAAF5D85C4A995E62793B063D', # colons and case ignored
486+
'FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF',
487+
488+
b'\xa2\x06G\xad\xaa\xf5\xd8\\J\x99^by;\x06=', # bytes ok too
489+
b'\xff' * 16,
490+
478491
# sha1
479-
'7393fd3aed081d6fa9ae71391ae3c57f89e76cf9', # good
480-
'ffffffffffffffffffffffffffffffffffffffff', # bad
492+
'7393fd3aed081d6fa9ae71391ae3c57f89e76cf9',
493+
'ffffffffffffffffffffffffffffffffffffffff',
494+
481495
# sha256
482-
'309ac94483dc9127889111a16497fdcb7e37551444404c11ab99a8aeb714ee8b', # good # flake8: noqa
483-
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', # bad # flake8: noqa
496+
'309ac94483dc9127889111a16497fdcb7e37551444404c11ab99a8aeb714ee8b', # flake8: noqa
497+
'ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff', # flake8: noqa
484498
)
485499
for i, fingerprint in enumerate(testcases):
486500
expect_fail = i % 2
@@ -490,10 +504,8 @@ def test_tcp_connector_expect_fingerprint(self):
490504
coro = client.request('get', httpd.url('method', 'get'),
491505
connector=conn, loop=self.loop)
492506
if expect_fail:
493-
with self.assertRaises(FingerprintMismatch) as cm:
507+
with self.assertRaises(FingerprintMismatch):
494508
self.loop.run_until_complete(coro)
495-
self.assertEqual(cm.exception.expected, fingerprint)
496-
self.assertEqual(cm.exception.got, testcases[i-1])
497509
else:
498510
# should not raise
499511
self.loop.run_until_complete(coro)

0 commit comments

Comments
 (0)