From 3491974380605fe7890bfeb87f1e59458f2d0bed Mon Sep 17 00:00:00 2001 From: AntonKueltz Date: Mon, 13 Apr 2020 23:56:13 -0700 Subject: [PATCH] Fix error parsing length in DER signature decoding --- fastecdsa/encoding/asn1.py | 46 ++++++++++++++++++++ fastecdsa/encoding/der.py | 63 ++++++++++++++-------------- fastecdsa/tests/encoding/test_der.py | 8 ++-- 3 files changed, 83 insertions(+), 34 deletions(-) diff --git a/fastecdsa/encoding/asn1.py b/fastecdsa/encoding/asn1.py index da1913a..a865f81 100644 --- a/fastecdsa/encoding/asn1.py +++ b/fastecdsa/encoding/asn1.py @@ -13,6 +13,10 @@ PUBLIC_KEY = b'\xa1' +class ASN1EncodingError(Exception): + pass + + def _asn1_len(data: bytes) -> bytes: # https://www.itu.int/ITU-T/studygroups/com17/languages/X.690-0207.pdf # section 8.1.3.3 @@ -72,6 +76,16 @@ def asn1_public_key(Q: Point) -> bytes: def parse_asn1_length(data: bytes) -> (int, bytes, bytes): + """ + Parse an ASN.1 encoded structure. + + Args: + data (bytes): A sequence of bytes representing an ASN.1 encoded structure + + Returns: + (int, bytes, bytes): A tuple of the integer length in bytes, the byte representation of the integer, + and the remaining bytes after the integer bytes in the sequence + """ (initial_byte,) = unpack('=B', data[:1]) data = data[1:] @@ -84,4 +98,36 @@ def parse_asn1_length(data: bytes) -> (int, bytes, bytes): (length,) = unpack(fmt[count], data[:count]) data = data[count:] + if length > len(data): + raise ASN1EncodingError(f"Parsed length of ASN.1 structure to be {length} bytes but only {len(data)} bytes" + f"remain in the provided data") + return length, data[:length], data[length:] + + +def parse_asn1_int(data: bytes) -> (int, bytes, bytes): + """ + Parse an ASN.1 encoded integer. + + Args: + data (bytes): A sequence of bytes whose start is an ASN.1 integer encoding + + Returns: + (int, bytes, bytes): A tuple of the integer length in bytes, the byte representation of the integer, + and the remaining bytes after the integer bytes in the sequence + """ + + # encoding needs at least the type, length and data + if len(data) < 3: + raise ASN1EncodingError("ASN.1 encoded integer must be at least 3 bytes long") + # integer should be identified as ASN.1 integer + if data[0] != ord(INTEGER): + raise ASN1EncodingError("Value should be a ASN.1 INTEGER") + + length, data, remaining = parse_asn1_length(data[1:]) + + # integer length should match length indicated + if length != len(data): + raise ASN1EncodingError(f"Expected ASN.1 INTEGER to be {length} bytes, got {len(data)} bytes") + + return length, data, remaining diff --git a/fastecdsa/encoding/der.py b/fastecdsa/encoding/der.py index 70777b0..ac0ba3d 100644 --- a/fastecdsa/encoding/der.py +++ b/fastecdsa/encoding/der.py @@ -1,7 +1,7 @@ from struct import pack from . import SigEncoder -from .asn1 import INTEGER, SEQUENCE +from .asn1 import ASN1EncodingError, INTEGER, SEQUENCE, parse_asn1_int, parse_asn1_length from .util import bytes_to_int, int_to_bytes @@ -40,33 +40,34 @@ def decode_signature(sig: bytes) -> (int, int): Returns (r,s) """ - if len(sig) < 8: - raise InvalidDerSignature("bytestring too small") - if sig[0] != ord(SEQUENCE): - raise InvalidDerSignature("missing SEQUENCE marker") - if sig[1] != len(sig) - 2: - raise InvalidDerSignature("invalid length") - length_r = sig[3] - if 5 + length_r >= len(sig): - raise InvalidDerSignature("invalid length") - length_s = sig[5 + length_r] - if length_r + length_s + 6 != len(sig): - raise InvalidDerSignature("invalid length") - if sig[2] != ord(INTEGER): - raise InvalidDerSignature("invalid r marker") - if length_r == 0: - raise InvalidDerSignature("invalid r value") - if sig[4] & 0x80: - raise InvalidDerSignature("invalid r value") - if length_r > 1 and (sig[4] == 0x00) and not (sig[5] & 0x80): - raise InvalidDerSignature("invalid r value") - if sig[length_r + 4] != ord(INTEGER): - raise InvalidDerSignature("invalid s marker") - if length_s == 0: - raise InvalidDerSignature("invalid s value") - if sig[length_r + 6] & 0x80: - raise InvalidDerSignature("invalid s value") - if length_s > 1 and (sig[length_r + 6] == 0x00) and not (sig[length_r + 7] & 0x80): - raise InvalidDerSignature("invalid s value") - r_data, s_data = sig[4:4 + length_r], sig[6 + length_r:] - return bytes_to_int(r_data), bytes_to_int(s_data) + def _validate_int_bytes(data: bytes): + # check for negative values, indicated by leading 1 bit + if data[0] & 0x80: + raise InvalidDerSignature("Signature contains a negative value") + + # check for leading 0x00s that aren't there to disambiguate possible negative values + if data[0] == 0x00 and not data[1] & 0x80: + raise InvalidDerSignature("Invalid leading 0x00 byte in ASN.1 integer") + + # overarching structure must be a sequence + if not sig or sig[0] != ord(SEQUENCE): + raise InvalidDerSignature("First byte should be ASN.1 SEQUENCE") + + try: + seqlen, sequence, leftover = parse_asn1_length(sig[1:]) + except ASN1EncodingError as asn1_error: + raise InvalidDerSignature(asn1_error) + + # sequence should be entirety remaining data + if leftover: + raise InvalidDerSignature(f"Expected a sequence of {seqlen} bytes, got {len(sequence + leftover)}") + + try: + rlen, r, sdata = parse_asn1_int(sequence) + slen, s, _ = parse_asn1_int(sdata) + except ASN1EncodingError as asn1_error: + raise InvalidDerSignature(asn1_error) + + _validate_int_bytes(r) + _validate_int_bytes(s) + return bytes_to_int(r), bytes_to_int(s) diff --git a/fastecdsa/tests/encoding/test_der.py b/fastecdsa/tests/encoding/test_der.py index 2d4f2dc..93ac04c 100644 --- a/fastecdsa/tests/encoding/test_der.py +++ b/fastecdsa/tests/encoding/test_der.py @@ -59,15 +59,17 @@ def test_encode_signature(self): def test_decode_signature(self): with self.assertRaises(InvalidDerSignature): - DEREncoder.decode_signature(b"") # length to shot + DEREncoder.decode_signature(b"") # length too short with self.assertRaises(InvalidDerSignature): DEREncoder.decode_signature(b"\x31\x06\x02\x01\x01\x02\x01\x02") # invalid SEQUENCE marker with self.assertRaises(InvalidDerSignature): - DEREncoder.decode_signature(b"\x30\x07\x02\x01\x01\x02\x01\x02") # invalid length + DEREncoder.decode_signature(b"\x30\x07\x02\x01\x01\x02\x01\x02") # invalid length (too short) + with self.assertRaises(InvalidDerSignature): + DEREncoder.decode_signature(b"\x30\x05\x02\x01\x01\x02\x01\x02") # invalid length (too long) with self.assertRaises(InvalidDerSignature): DEREncoder.decode_signature(b"\x30\x06\x02\x03\x01\x02\x01\x02") # invalid length of r with self.assertRaises(InvalidDerSignature): - DEREncoder.decode_signature(b"\x30\x06\x02\x01\x01\x03\x01\x02") # invalid length of s + DEREncoder.decode_signature(b"\x30\x06\x02\x01\x01\x02\x03\x02") # invalid length of s with self.assertRaises(InvalidDerSignature): DEREncoder.decode_signature(b"\x30\x06\x03\x01\x01\x02\x01\x02") # invalid INTEGER marker for r with self.assertRaises(InvalidDerSignature):