aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAlex Gaynor <alex.gaynor@gmail.com>2019-07-28 22:58:04 -0400
committerPaul Kehrer <paul.l.kehrer@gmail.com>2019-07-28 21:58:04 -0500
commit9cd41ac714d9bff819ece6d8cdcde064d403c671 (patch)
tree4ed2502ced1db85417fbf6fe214f59a1525893ff
parent2c83570f6310cb36553af274eb41dd8e2b96b58e (diff)
downloadcryptography-9cd41ac714d9bff819ece6d8cdcde064d403c671.tar.gz
cryptography-9cd41ac714d9bff819ece6d8cdcde064d403c671.tar.bz2
cryptography-9cd41ac714d9bff819ece6d8cdcde064d403c671.zip
Make DER reader into a context manager (#4957)
* Make DER reader into a context manager * Added another test case * flake8
-rw-r--r--src/cryptography/hazmat/_der.py12
-rw-r--r--src/cryptography/hazmat/primitives/asymmetric/utils.py9
-rw-r--r--src/cryptography/x509/extensions.py18
-rw-r--r--tests/hazmat/test_der.py8
-rw-r--r--tests/x509/test_x509.py72
5 files changed, 65 insertions, 54 deletions
diff --git a/src/cryptography/hazmat/_der.py b/src/cryptography/hazmat/_der.py
index 3a121a85..51518d64 100644
--- a/src/cryptography/hazmat/_der.py
+++ b/src/cryptography/hazmat/_der.py
@@ -36,6 +36,13 @@ class DERReader(object):
def __init__(self, data):
self.data = memoryview(data)
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_value is None:
+ self.check_empty()
+
def is_empty(self):
return len(self.data) == 0
@@ -100,9 +107,8 @@ class DERReader(object):
return body
def read_single_element(self, expected_tag):
- ret = self.read_element(expected_tag)
- self.check_empty()
- return ret
+ with self:
+ return self.read_element(expected_tag)
def read_optional_element(self, expected_tag):
if len(self.data) > 0 and six.indexbytes(self.data, 0) == expected_tag:
diff --git a/src/cryptography/hazmat/primitives/asymmetric/utils.py b/src/cryptography/hazmat/primitives/asymmetric/utils.py
index 43d5b9bf..14d2abee 100644
--- a/src/cryptography/hazmat/primitives/asymmetric/utils.py
+++ b/src/cryptography/hazmat/primitives/asymmetric/utils.py
@@ -12,11 +12,10 @@ from cryptography.hazmat.primitives import hashes
def decode_dss_signature(signature):
- seq = DERReader(signature).read_single_element(SEQUENCE)
- r = seq.read_element(INTEGER).as_integer()
- s = seq.read_element(INTEGER).as_integer()
- seq.check_empty()
- return r, s
+ with DERReader(signature).read_single_element(SEQUENCE) as seq:
+ r = seq.read_element(INTEGER).as_integer()
+ s = seq.read_element(INTEGER).as_integer()
+ return r, s
def encode_dss_signature(r, s):
diff --git a/src/cryptography/x509/extensions.py b/src/cryptography/x509/extensions.py
index c78c76c2..5bef9945 100644
--- a/src/cryptography/x509/extensions.py
+++ b/src/cryptography/x509/extensions.py
@@ -48,17 +48,17 @@ def _key_identifier_from_public_key(public_key):
serialization.PublicFormat.SubjectPublicKeyInfo
)
- public_key_info = DERReader(serialized).read_single_element(SEQUENCE)
- algorithm = public_key_info.read_element(SEQUENCE)
- public_key = public_key_info.read_element(BIT_STRING)
- public_key_info.check_empty()
+ reader = DERReader(serialized)
+ with reader.read_single_element(SEQUENCE) as public_key_info:
+ algorithm = public_key_info.read_element(SEQUENCE)
+ public_key = public_key_info.read_element(BIT_STRING)
# Double-check the algorithm structure.
- algorithm.read_element(OBJECT_IDENTIFIER)
- if not algorithm.is_empty():
- # Skip the optional parameters field.
- algorithm.read_any_element()
- algorithm.check_empty()
+ with algorithm:
+ algorithm.read_element(OBJECT_IDENTIFIER)
+ if not algorithm.is_empty():
+ # Skip the optional parameters field.
+ algorithm.read_any_element()
# BIT STRING contents begin with the number of padding bytes added. It
# must be zero for SubjectPublicKeyInfo structures.
diff --git a/tests/hazmat/test_der.py b/tests/hazmat/test_der.py
index d81c0d3e..d052802c 100644
--- a/tests/hazmat/test_der.py
+++ b/tests/hazmat/test_der.py
@@ -46,6 +46,14 @@ def test_der():
with pytest.raises(ValueError):
reader.check_empty()
+ with pytest.raises(ValueError):
+ with reader:
+ pass
+
+ with pytest.raises(ZeroDivisionError):
+ with DERReader(der):
+ raise ZeroDivisionError
+
# Parse the outer element.
outer = reader.read_element(SEQUENCE)
reader.check_empty()
diff --git a/tests/x509/test_x509.py b/tests/x509/test_x509.py
index bb0ad022..540a814a 100644
--- a/tests/x509/test_x509.py
+++ b/tests/x509/test_x509.py
@@ -76,36 +76,35 @@ ParsedCertificate = collections.namedtuple(
def _parse_cert(der):
# See the Certificate structured, defined in RFC 5280.
- cert = DERReader(der).read_single_element(SEQUENCE)
- tbs_cert = cert.read_element(SEQUENCE)
- # Skip outer signature algorithm
- _ = cert.read_element(SEQUENCE)
- # Skip signature
- _ = cert.read_element(BIT_STRING)
- cert.check_empty()
-
- # Skip version
- _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 0)
- # Skip serialNumber
- _ = tbs_cert.read_element(INTEGER)
- # Skip inner signature algorithm
- _ = tbs_cert.read_element(SEQUENCE)
- issuer = tbs_cert.read_element(SEQUENCE)
- validity = tbs_cert.read_element(SEQUENCE)
- subject = tbs_cert.read_element(SEQUENCE)
- # Skip subjectPublicKeyInfo
- _ = tbs_cert.read_element(SEQUENCE)
- # Skip issuerUniqueID
- _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 1)
- # Skip subjectUniqueID
- _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 2)
- # Skip extensions
- _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 3)
- tbs_cert.check_empty()
-
- not_before_tag, _ = validity.read_any_element()
- not_after_tag, _ = validity.read_any_element()
- validity.check_empty()
+ with DERReader(der).read_single_element(SEQUENCE) as cert:
+ tbs_cert = cert.read_element(SEQUENCE)
+ # Skip outer signature algorithm
+ _ = cert.read_element(SEQUENCE)
+ # Skip signature
+ _ = cert.read_element(BIT_STRING)
+
+ with tbs_cert:
+ # Skip version
+ _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 0)
+ # Skip serialNumber
+ _ = tbs_cert.read_element(INTEGER)
+ # Skip inner signature algorithm
+ _ = tbs_cert.read_element(SEQUENCE)
+ issuer = tbs_cert.read_element(SEQUENCE)
+ validity = tbs_cert.read_element(SEQUENCE)
+ subject = tbs_cert.read_element(SEQUENCE)
+ # Skip subjectPublicKeyInfo
+ _ = tbs_cert.read_element(SEQUENCE)
+ # Skip issuerUniqueID
+ _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 1)
+ # Skip subjectUniqueID
+ _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 2)
+ # Skip extensions
+ _ = tbs_cert.read_optional_element(CONTEXT_SPECIFIC | CONSTRUCTED | 3)
+
+ with validity:
+ not_before_tag, _ = validity.read_any_element()
+ not_after_tag, _ = validity.read_any_element()
return ParsedCertificate(
not_before_tag=not_before_tag,
@@ -1642,15 +1641,14 @@ class TestRSACertificateRequest(object):
issuer = parsed.issuer
def read_next_rdn_value_tag(reader):
- rdn = reader.read_element(SET)
- attribute = rdn.read_element(SEQUENCE)
# Assume each RDN has a single attribute.
- rdn.check_empty()
+ with reader.read_element(SET) as rdn:
+ attribute = rdn.read_element(SEQUENCE)
- _ = attribute.read_element(OBJECT_IDENTIFIER)
- tag, value = attribute.read_any_element()
- attribute.check_empty()
- return tag
+ with attribute:
+ _ = attribute.read_element(OBJECT_IDENTIFIER)
+ tag, value = attribute.read_any_element()
+ return tag
# Check that each value was encoded as an ASN.1 PRINTABLESTRING.
assert read_next_rdn_value_tag(subject) == PRINTABLE_STRING