diff options
author | Alex Gaynor <alex.gaynor@gmail.com> | 2019-07-28 22:58:04 -0400 |
---|---|---|
committer | Paul Kehrer <paul.l.kehrer@gmail.com> | 2019-07-28 21:58:04 -0500 |
commit | 9cd41ac714d9bff819ece6d8cdcde064d403c671 (patch) | |
tree | 4ed2502ced1db85417fbf6fe214f59a1525893ff | |
parent | 2c83570f6310cb36553af274eb41dd8e2b96b58e (diff) | |
download | cryptography-9cd41ac714d9bff819ece6d8cdcde064d403c671.tar.gz cryptography-9cd41ac714d9bff819ece6d8cdcde064d403c671.tar.bz2 cryptography-9cd41ac714d9bff819ece6d8cdcde064d403c671.zip |
Make DER reader into a context manager (#4957)
* Make DER reader into a context manager
* Added another test case
* flake8
-rw-r--r-- | src/cryptography/hazmat/_der.py | 12 | ||||
-rw-r--r-- | src/cryptography/hazmat/primitives/asymmetric/utils.py | 9 | ||||
-rw-r--r-- | src/cryptography/x509/extensions.py | 18 | ||||
-rw-r--r-- | tests/hazmat/test_der.py | 8 | ||||
-rw-r--r-- | tests/x509/test_x509.py | 72 |
5 files changed, 65 insertions, 54 deletions
diff --git a/src/cryptography/hazmat/_der.py b/src/cryptography/hazmat/_der.py index 3a121a85..51518d64 100644 --- a/src/cryptography/hazmat/_der.py +++ b/src/cryptography/hazmat/_der.py @@ -36,6 +36,13 @@ class DERReader(object): def __init__(self, data): self.data = memoryview(data) + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, tb): + if exc_value is None: + self.check_empty() + def is_empty(self): return len(self.data) == 0 @@ -100,9 +107,8 @@ class DERReader(object): return body def read_single_element(self, expected_tag): - ret = self.read_element(expected_tag) - self.check_empty() - return ret + with self: + return self.read_element(expected_tag) def read_optional_element(self, expected_tag): if len(self.data) > 0 and six.indexbytes(self.data, 0) == expected_tag: diff --git a/src/cryptography/hazmat/primitives/asymmetric/utils.py b/src/cryptography/hazmat/primitives/asymmetric/utils.py index 43d5b9bf..14d2abee 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/utils.py +++ b/src/cryptography/hazmat/primitives/asymmetric/utils.py @@ -12,11 +12,10 @@ from cryptography.hazmat.primitives import hashes def decode_dss_signature(signature): - seq = DERReader(signature).read_single_element(SEQUENCE) - r = seq.read_element(INTEGER).as_integer() - s = seq.read_element(INTEGER).as_integer() - seq.check_empty() - return r, s + with DERReader(signature).read_single_element(SEQUENCE) as seq: + r = seq.read_element(INTEGER).as_integer() + s = seq.read_element(INTEGER).as_integer() + return r, s def encode_dss_signature(r, s): diff --git a/src/cryptography/x509/extensions.py b/src/cryptography/x509/extensions.py index c78c76c2..5bef9945 100644 --- a/src/cryptography/x509/extensions.py +++ b/src/cryptography/x509/extensions.py @@ -48,17 +48,17 @@ def _key_identifier_from_public_key(public_key): serialization.PublicFormat.SubjectPublicKeyInfo ) - public_key_info = DERReader(serialized).read_single_element(SEQUENCE) - algorithm = public_key_info.read_element(SEQUENCE) - public_key = public_key_info.read_element(BIT_STRING) - public_key_info.check_empty() + reader = DERReader(serialized) + with reader.read_single_element(SEQUENCE) as public_key_info: + algorithm = public_key_info.read_element(SEQUENCE) + public_key = public_key_info.read_element(BIT_STRING) # Double-check the algorithm structure. - algorithm.read_element(OBJECT_IDENTIFIER) - if not algorithm.is_empty(): - # Skip the optional parameters field. - algorithm.read_any_element() - algorithm.check_empty() + with algorithm: + algorithm.read_element(OBJECT_IDENTIFIER) + if not algorithm.is_empty(): + # Skip the optional parameters field. + algorithm.read_any_element() # BIT STRING contents begin with the number of padding bytes added. It # must be zero for SubjectPublicKeyInfo structures. diff --git a/tests/hazmat/test_der.py b/tests/hazmat/test_der.py index d81c0d3e..d052802c 100644 --- a/tests/hazmat/test_der.py +++ b/tests/hazmat/test_der.py @@ -46,6 +46,14 @@ def test_der(): with pytest.raises(ValueError): reader.check_empty() + with pytest.raises(ValueError): + with reader: + pass + + with pytest.raises(ZeroDivisionError): + with DERReader(der): + raise ZeroDivisionError + # Parse the outer element. outer = reader.read_element(SEQUENCE) reader.check_empty() diff --git a/tests/x509/test_x509.py b/tests/x509/test_x509.py index bb0ad022..540a814a 100644 --- a/tests/x509/test_x509.py +++ b/tests/x509/test_x509.py @@ -76,36 +76,35 @@ ParsedCertificate = collections.namedtuple( def _parse_cert(der): # See the Certificate structured, defined in RFC 5280. - cert = DERReader(der).read_single_element(SEQUENCE) - tbs_cert = cert.read_element(SEQUENCE) - # Skip outer signature algorithm - _ = cert.read_element(SEQUENCE) - # Skip signature - _ = cert.read_element(BIT_STRING) - cert.check_empty() - - # Skip version - _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 0) - # Skip serialNumber - _ = tbs_cert.read_element(INTEGER) - # Skip inner signature algorithm - _ = tbs_cert.read_element(SEQUENCE) - issuer = tbs_cert.read_element(SEQUENCE) - validity = tbs_cert.read_element(SEQUENCE) - subject = tbs_cert.read_element(SEQUENCE) - # Skip subjectPublicKeyInfo - _ = tbs_cert.read_element(SEQUENCE) - # Skip issuerUniqueID - _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 1) - # Skip subjectUniqueID - _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 2) - # Skip extensions - _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 3) - tbs_cert.check_empty() - - not_before_tag, _ = validity.read_any_element() - not_after_tag, _ = validity.read_any_element() - validity.check_empty() + with DERReader(der).read_single_element(SEQUENCE) as cert: + tbs_cert = cert.read_element(SEQUENCE) + # Skip outer signature algorithm + _ = cert.read_element(SEQUENCE) + # Skip signature + _ = cert.read_element(BIT_STRING) + + with tbs_cert: + # Skip version + _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 0) + # Skip serialNumber + _ = tbs_cert.read_element(INTEGER) + # Skip inner signature algorithm + _ = tbs_cert.read_element(SEQUENCE) + issuer = tbs_cert.read_element(SEQUENCE) + validity = tbs_cert.read_element(SEQUENCE) + subject = tbs_cert.read_element(SEQUENCE) + # Skip subjectPublicKeyInfo + _ = tbs_cert.read_element(SEQUENCE) + # Skip issuerUniqueID + _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 1) + # Skip subjectUniqueID + _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 2) + # Skip extensions + _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 3) + + with validity: + not_before_tag, _ = validity.read_any_element() + not_after_tag, _ = validity.read_any_element() return ParsedCertificate( not_before_tag=not_before_tag, @@ -1642,15 +1641,14 @@ class TestRSACertificateRequest(object): issuer = parsed.issuer def read_next_rdn_value_tag(reader): - rdn = reader.read_element(SET) - attribute = rdn.read_element(SEQUENCE) # Assume each RDN has a single attribute. - rdn.check_empty() + with reader.read_element(SET) as rdn: + attribute = rdn.read_element(SEQUENCE) - _ = attribute.read_element(OBJECT_IDENTIFIER) - tag, value = attribute.read_any_element() - attribute.check_empty() - return tag + with attribute: + _ = attribute.read_element(OBJECT_IDENTIFIER) + tag, value = attribute.read_any_element() + return tag # Check that each value was encoded as an ASN.1 PRINTABLESTRING. assert read_next_rdn_value_tag(subject) == PRINTABLE_STRING |