From 40f1999de74a3bf44f000486a0ce1a58c82827e6 Mon Sep 17 00:00:00 2001 From: Marti Date: Fri, 26 Aug 2016 04:26:31 +0300 Subject: Allow passing iterators where collections are expected (#3078) Iterators can only be enumerated once, breaking code like this in Python 3 for example: san = SubjectAlternativeName(map(DNSName, lst)) This is also a slight behavior change if the caller modifies the list after passing it to the constructor, because input lists are now copied. Which seems like a good thing. Also: * Name now checks that attributes elements are of type NameAttribute * NoticeReference now allows notice_numbers to be any iterable --- src/cryptography/x509/extensions.py | 72 +++++++++++++----------- src/cryptography/x509/name.py | 4 ++ tests/test_x509.py | 12 ++++ tests/test_x509_ext.py | 108 ++++++++++++++++++++++++++++++++---- 4 files changed, 154 insertions(+), 42 deletions(-) diff --git a/src/cryptography/x509/extensions.py b/src/cryptography/x509/extensions.py index b7ea72cd..c0705a3a 100644 --- a/src/cryptography/x509/extensions.py +++ b/src/cryptography/x509/extensions.py @@ -174,13 +174,15 @@ class AuthorityKeyIdentifier(object): "must both be present or both None" ) - if authority_cert_issuer is not None and not all( - isinstance(x, GeneralName) for x in authority_cert_issuer - ): - raise TypeError( - "authority_cert_issuer must be a list of GeneralName " - "objects" - ) + if authority_cert_issuer is not None: + authority_cert_issuer = list(authority_cert_issuer) + if not all( + isinstance(x, GeneralName) for x in authority_cert_issuer + ): + raise TypeError( + "authority_cert_issuer must be a list of GeneralName " + "objects" + ) if authority_cert_serial_number is not None and not isinstance( authority_cert_serial_number, six.integer_types @@ -273,6 +275,7 @@ class AuthorityInformationAccess(object): oid = ExtensionOID.AUTHORITY_INFORMATION_ACCESS def __init__(self, descriptions): + descriptions = list(descriptions) if not all(isinstance(x, AccessDescription) for x in descriptions): raise TypeError( "Every item in the descriptions list must be an " @@ -386,6 +389,7 @@ class CRLDistributionPoints(object): oid = ExtensionOID.CRL_DISTRIBUTION_POINTS def __init__(self, distribution_points): + distribution_points = list(distribution_points) if not all( isinstance(x, DistributionPoint) for x in distribution_points ): @@ -426,22 +430,22 @@ class DistributionPoint(object): "least one must be None." ) - if full_name and not all( - isinstance(x, GeneralName) for x in full_name - ): - raise TypeError( - "full_name must be a list of GeneralName objects" - ) + if full_name: + full_name = list(full_name) + if not all(isinstance(x, GeneralName) for x in full_name): + raise TypeError( + "full_name must be a list of GeneralName objects" + ) if relative_name and not isinstance(relative_name, Name): raise TypeError("relative_name must be a Name") - if crl_issuer and not all( - isinstance(x, GeneralName) for x in crl_issuer - ): - raise TypeError( - "crl_issuer must be None or a list of general names" - ) + if crl_issuer: + crl_issuer = list(crl_issuer) + if not all(isinstance(x, GeneralName) for x in crl_issuer): + raise TypeError( + "crl_issuer must be None or a list of general names" + ) if reasons and (not isinstance(reasons, frozenset) or not all( isinstance(x, ReasonFlags) for x in reasons @@ -569,6 +573,7 @@ class CertificatePolicies(object): oid = ExtensionOID.CERTIFICATE_POLICIES def __init__(self, policies): + policies = list(policies) if not all(isinstance(x, PolicyInformation) for x in policies): raise TypeError( "Every item in the policies list must be a " @@ -605,15 +610,17 @@ class PolicyInformation(object): raise TypeError("policy_identifier must be an ObjectIdentifier") self._policy_identifier = policy_identifier - if policy_qualifiers and not all( - isinstance( - x, (six.text_type, UserNotice) - ) for x in policy_qualifiers - ): - raise TypeError( - "policy_qualifiers must be a list of strings and/or UserNotice" - " objects or None" - ) + + if policy_qualifiers: + policy_qualifiers = list(policy_qualifiers) + if not all( + isinstance(x, (six.text_type, UserNotice)) + for x in policy_qualifiers + ): + raise TypeError( + "policy_qualifiers must be a list of strings and/or " + "UserNotice objects or None" + ) self._policy_qualifiers = policy_qualifiers @@ -676,9 +683,8 @@ class UserNotice(object): class NoticeReference(object): def __init__(self, organization, notice_numbers): self._organization = organization - if not isinstance(notice_numbers, list) or not all( - isinstance(x, int) for x in notice_numbers - ): + notice_numbers = list(notice_numbers) + if not all(isinstance(x, int) for x in notice_numbers): raise TypeError( "notice_numbers must be a list of integers" ) @@ -712,6 +718,7 @@ class ExtendedKeyUsage(object): oid = ExtensionOID.EXTENDED_KEY_USAGE def __init__(self, usages): + usages = list(usages) if not all(isinstance(x, ObjectIdentifier) for x in usages): raise TypeError( "Every item in the usages list must be an ObjectIdentifier" @@ -866,6 +873,7 @@ class NameConstraints(object): def __init__(self, permitted_subtrees, excluded_subtrees): if permitted_subtrees is not None: + permitted_subtrees = list(permitted_subtrees) if not all( isinstance(x, GeneralName) for x in permitted_subtrees ): @@ -877,6 +885,7 @@ class NameConstraints(object): self._validate_ip_name(permitted_subtrees) if excluded_subtrees is not None: + excluded_subtrees = list(excluded_subtrees) if not all( isinstance(x, GeneralName) for x in excluded_subtrees ): @@ -965,6 +974,7 @@ class Extension(object): class GeneralNames(object): def __init__(self, general_names): + general_names = list(general_names) if not all(isinstance(x, GeneralName) for x in general_names): raise TypeError( "Every item in the general_names list must be an " diff --git a/src/cryptography/x509/name.py b/src/cryptography/x509/name.py index d62341d7..7e55f6e3 100644 --- a/src/cryptography/x509/name.py +++ b/src/cryptography/x509/name.py @@ -54,6 +54,10 @@ class NameAttribute(object): class Name(object): def __init__(self, attributes): + attributes = list(attributes) + if not all(isinstance(x, NameAttribute) for x in attributes): + raise TypeError("attributes must be a list of NameAttribute") + self._attributes = attributes def get_attributes_for_oid(self, oid): diff --git a/tests/test_x509.py b/tests/test_x509.py index b1d627c3..47b81cb5 100644 --- a/tests/test_x509.py +++ b/tests/test_x509.py @@ -3680,6 +3680,14 @@ class TestName(object): assert hash(name1) == hash(name2) assert hash(name1) != hash(name3) + def test_iter_input(self): + attrs = [ + x509.NameAttribute(x509.ObjectIdentifier('2.999.1'), u'value1') + ] + name = x509.Name(iter(attrs)) + assert list(name) == attrs + assert list(name) == attrs + def test_repr(self): name = x509.Name([ x509.NameAttribute(NameOID.COMMON_NAME, u'cryptography.io'), @@ -3700,3 +3708,7 @@ class TestName(object): "=, val" "ue=u'PyCA')>])>" ) + + def test_not_nameattribute(self): + with pytest.raises(TypeError): + x509.Name(["not-a-NameAttribute"]) diff --git a/tests/test_x509_ext.py b/tests/test_x509_ext.py index f3677d11..749e52f1 100644 --- a/tests/test_x509_ext.py +++ b/tests/test_x509_ext.py @@ -268,6 +268,11 @@ class TestNoticeReference(object): with pytest.raises(TypeError): x509.NoticeReference("org", None) + def test_iter_input(self): + numbers = [1, 3, 4] + nr = x509.NoticeReference(u"org", iter(numbers)) + assert list(nr.notice_numbers) == numbers + def test_repr(self): nr = x509.NoticeReference(u"org", [1, 3, 4]) @@ -357,6 +362,11 @@ class TestPolicyInformation(object): with pytest.raises(TypeError): x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), [1, 2]) + def test_iter_input(self): + qual = [u"foo", u"bar"] + pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), iter(qual)) + assert list(pi.policy_qualifiers) == qual + def test_repr(self): pq = [u"string", x509.UserNotice(None, u"hi")] pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), pq) @@ -414,6 +424,13 @@ class TestCertificatePolicies(object): for policyinfo in cp: assert policyinfo == pi + def test_iter_input(self): + policies = [ + x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), [u"string"]) + ] + cp = x509.CertificatePolicies(iter(policies)) + assert list(cp) == policies + def test_repr(self): pq = [u"string"] pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), pq) @@ -859,6 +876,15 @@ class TestAuthorityKeyIdentifier(object): assert aki.authority_cert_issuer == [dns] assert aki.authority_cert_serial_number == 0 + def test_iter_input(self): + dirnames = [ + x509.DirectoryName( + x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u'myCN')]) + ) + ] + aki = x509.AuthorityKeyIdentifier(b"digest", iter(dirnames), 1234) + assert list(aki.authority_cert_issuer) == dirnames + def test_repr(self): dirname = x509.DirectoryName( x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u'myCN')]) @@ -973,6 +999,14 @@ class TestExtendedKeyUsage(object): ExtendedKeyUsageOID.CLIENT_AUTH ] + def test_iter_input(self): + usages = [ + x509.ObjectIdentifier("1.3.6.1.5.5.7.3.1"), + x509.ObjectIdentifier("1.3.6.1.5.5.7.3.2"), + ] + aia = x509.ExtendedKeyUsage(iter(usages)) + assert list(aia) == usages + def test_repr(self): eku = x509.ExtendedKeyUsage([ x509.ObjectIdentifier("1.3.6.1.5.5.7.3.1"), @@ -1417,18 +1451,16 @@ class TestDirectoryName(object): def test_repr(self): name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u'value1')]) - gn = x509.DirectoryName(x509.Name([name])) + gn = x509.DirectoryName(name) if six.PY3: assert repr(gn) == ( - ", value='value1')>])" - ">])>)>" + ", value='value1')>])>)>" ) else: assert repr(gn) == ( - ", value=u'value1')>]" - ")>])>)>" + ", value=u'value1')>])>)>" ) def test_eq(self): @@ -1438,8 +1470,8 @@ class TestDirectoryName(object): name2 = x509.Name([ x509.NameAttribute(x509.ObjectIdentifier('2.999.1'), u'value1') ]) - gn = x509.DirectoryName(x509.Name([name])) - gn2 = x509.DirectoryName(x509.Name([name2])) + gn = x509.DirectoryName(name) + gn2 = x509.DirectoryName(name2) assert gn == gn2 def test_ne(self): @@ -1449,8 +1481,8 @@ class TestDirectoryName(object): name2 = x509.Name([ x509.NameAttribute(x509.ObjectIdentifier('2.999.2'), u'value2') ]) - gn = x509.DirectoryName(x509.Name([name])) - gn2 = x509.DirectoryName(x509.Name([name2])) + gn = x509.DirectoryName(name) + gn2 = x509.DirectoryName(name2) assert gn != gn2 assert gn != object() @@ -1649,6 +1681,14 @@ class TestGeneralNames(object): x509.DNSName(u"crypto.local"), ] + def test_iter_input(self): + names = [ + x509.DNSName(u"cryptography.io"), + x509.DNSName(u"crypto.local"), + ] + gns = x509.GeneralNames(iter(names)) + assert list(gns) == names + def test_indexing(self): gn = x509.GeneralNames([ x509.DNSName(u"cryptography.io"), @@ -2371,6 +2411,16 @@ class TestAuthorityInformationAccess(object): ) ] + def test_iter_input(self): + desc = [ + x509.AccessDescription( + AuthorityInformationAccessOID.OCSP, + x509.UniformResourceIdentifier(u"http://ocsp.domain.com") + ) + ] + aia = x509.AuthorityInformationAccess(iter(desc)) + assert list(aia) == desc + def test_repr(self): aia = x509.AuthorityInformationAccess([ x509.AccessDescription( @@ -2743,6 +2793,12 @@ class TestNameConstraints(object): assert nc.permitted_subtrees is not None assert nc.excluded_subtrees is None + def test_iter_input(self): + subtrees = [x509.IPAddress(ipaddress.IPv4Network(u"192.168.0.0/24"))] + nc = x509.NameConstraints(iter(subtrees), iter(subtrees)) + assert list(nc.permitted_subtrees) == subtrees + assert list(nc.excluded_subtrees) == subtrees + def test_repr(self): permitted = [x509.DNSName(u"name.local"), x509.DNSName(u"name2.local")] nc = x509.NameConstraints( @@ -3050,6 +3106,24 @@ class TestDistributionPoint(object): assert dp != dp2 assert dp != object() + def test_iter_input(self): + name = [x509.UniformResourceIdentifier(u"http://crypt.og/crl")] + issuer = [ + x509.DirectoryName( + x509.Name([ + x509.NameAttribute(NameOID.COMMON_NAME, u"Important CA") + ]) + ) + ] + dp = x509.DistributionPoint( + iter(name), + None, + frozenset([x509.ReasonFlags.ca_compromise]), + iter(issuer), + ) + assert list(dp.full_name) == name + assert list(dp.crl_issuer) == issuer + def test_repr(self): dp = x509.DistributionPoint( None, @@ -3129,6 +3203,18 @@ class TestCRLDistributionPoints(object): ), ] + def test_iter_input(self): + points = [ + x509.DistributionPoint( + [x509.UniformResourceIdentifier(u"http://domain")], + None, + None, + None + ), + ] + cdp = x509.CRLDistributionPoints(iter(points)) + assert list(cdp) == points + def test_repr(self): cdp = x509.CRLDistributionPoints([ x509.DistributionPoint( -- cgit v1.2.3