aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMarti <marti@juffo.org>2016-08-26 04:26:31 +0300
committerPaul Kehrer <paul.l.kehrer@gmail.com>2016-08-26 09:26:31 +0800
commit40f1999de74a3bf44f000486a0ce1a58c82827e6 (patch)
treed7c3cb6ea4f0b3846cc8685669c75d963f43db64
parenteafc4ee77f92d4e6e208351fd17e9cb1ae045677 (diff)
downloadcryptography-40f1999de74a3bf44f000486a0ce1a58c82827e6.tar.gz
cryptography-40f1999de74a3bf44f000486a0ce1a58c82827e6.tar.bz2
cryptography-40f1999de74a3bf44f000486a0ce1a58c82827e6.zip
Allow passing iterators where collections are expected (#3078)
Iterators can only be enumerated once, breaking code like this in Python 3 for example: san = SubjectAlternativeName(map(DNSName, lst)) This is also a slight behavior change if the caller modifies the list after passing it to the constructor, because input lists are now copied. Which seems like a good thing. Also: * Name now checks that attributes elements are of type NameAttribute * NoticeReference now allows notice_numbers to be any iterable
-rw-r--r--src/cryptography/x509/extensions.py72
-rw-r--r--src/cryptography/x509/name.py4
-rw-r--r--tests/test_x509.py12
-rw-r--r--tests/test_x509_ext.py108
4 files changed, 154 insertions, 42 deletions
diff --git a/src/cryptography/x509/extensions.py b/src/cryptography/x509/extensions.py
index b7ea72cd..c0705a3a 100644
--- a/src/cryptography/x509/extensions.py
+++ b/src/cryptography/x509/extensions.py
@@ -174,13 +174,15 @@ class AuthorityKeyIdentifier(object):
"must both be present or both None"
)
- if authority_cert_issuer is not None and not all(
- isinstance(x, GeneralName) for x in authority_cert_issuer
- ):
- raise TypeError(
- "authority_cert_issuer must be a list of GeneralName "
- "objects"
- )
+ if authority_cert_issuer is not None:
+ authority_cert_issuer = list(authority_cert_issuer)
+ if not all(
+ isinstance(x, GeneralName) for x in authority_cert_issuer
+ ):
+ raise TypeError(
+ "authority_cert_issuer must be a list of GeneralName "
+ "objects"
+ )
if authority_cert_serial_number is not None and not isinstance(
authority_cert_serial_number, six.integer_types
@@ -273,6 +275,7 @@ class AuthorityInformationAccess(object):
oid = ExtensionOID.AUTHORITY_INFORMATION_ACCESS
def __init__(self, descriptions):
+ descriptions = list(descriptions)
if not all(isinstance(x, AccessDescription) for x in descriptions):
raise TypeError(
"Every item in the descriptions list must be an "
@@ -386,6 +389,7 @@ class CRLDistributionPoints(object):
oid = ExtensionOID.CRL_DISTRIBUTION_POINTS
def __init__(self, distribution_points):
+ distribution_points = list(distribution_points)
if not all(
isinstance(x, DistributionPoint) for x in distribution_points
):
@@ -426,22 +430,22 @@ class DistributionPoint(object):
"least one must be None."
)
- if full_name and not all(
- isinstance(x, GeneralName) for x in full_name
- ):
- raise TypeError(
- "full_name must be a list of GeneralName objects"
- )
+ if full_name:
+ full_name = list(full_name)
+ if not all(isinstance(x, GeneralName) for x in full_name):
+ raise TypeError(
+ "full_name must be a list of GeneralName objects"
+ )
if relative_name and not isinstance(relative_name, Name):
raise TypeError("relative_name must be a Name")
- if crl_issuer and not all(
- isinstance(x, GeneralName) for x in crl_issuer
- ):
- raise TypeError(
- "crl_issuer must be None or a list of general names"
- )
+ if crl_issuer:
+ crl_issuer = list(crl_issuer)
+ if not all(isinstance(x, GeneralName) for x in crl_issuer):
+ raise TypeError(
+ "crl_issuer must be None or a list of general names"
+ )
if reasons and (not isinstance(reasons, frozenset) or not all(
isinstance(x, ReasonFlags) for x in reasons
@@ -569,6 +573,7 @@ class CertificatePolicies(object):
oid = ExtensionOID.CERTIFICATE_POLICIES
def __init__(self, policies):
+ policies = list(policies)
if not all(isinstance(x, PolicyInformation) for x in policies):
raise TypeError(
"Every item in the policies list must be a "
@@ -605,15 +610,17 @@ class PolicyInformation(object):
raise TypeError("policy_identifier must be an ObjectIdentifier")
self._policy_identifier = policy_identifier
- if policy_qualifiers and not all(
- isinstance(
- x, (six.text_type, UserNotice)
- ) for x in policy_qualifiers
- ):
- raise TypeError(
- "policy_qualifiers must be a list of strings and/or UserNotice"
- " objects or None"
- )
+
+ if policy_qualifiers:
+ policy_qualifiers = list(policy_qualifiers)
+ if not all(
+ isinstance(x, (six.text_type, UserNotice))
+ for x in policy_qualifiers
+ ):
+ raise TypeError(
+ "policy_qualifiers must be a list of strings and/or "
+ "UserNotice objects or None"
+ )
self._policy_qualifiers = policy_qualifiers
@@ -676,9 +683,8 @@ class UserNotice(object):
class NoticeReference(object):
def __init__(self, organization, notice_numbers):
self._organization = organization
- if not isinstance(notice_numbers, list) or not all(
- isinstance(x, int) for x in notice_numbers
- ):
+ notice_numbers = list(notice_numbers)
+ if not all(isinstance(x, int) for x in notice_numbers):
raise TypeError(
"notice_numbers must be a list of integers"
)
@@ -712,6 +718,7 @@ class ExtendedKeyUsage(object):
oid = ExtensionOID.EXTENDED_KEY_USAGE
def __init__(self, usages):
+ usages = list(usages)
if not all(isinstance(x, ObjectIdentifier) for x in usages):
raise TypeError(
"Every item in the usages list must be an ObjectIdentifier"
@@ -866,6 +873,7 @@ class NameConstraints(object):
def __init__(self, permitted_subtrees, excluded_subtrees):
if permitted_subtrees is not None:
+ permitted_subtrees = list(permitted_subtrees)
if not all(
isinstance(x, GeneralName) for x in permitted_subtrees
):
@@ -877,6 +885,7 @@ class NameConstraints(object):
self._validate_ip_name(permitted_subtrees)
if excluded_subtrees is not None:
+ excluded_subtrees = list(excluded_subtrees)
if not all(
isinstance(x, GeneralName) for x in excluded_subtrees
):
@@ -965,6 +974,7 @@ class Extension(object):
class GeneralNames(object):
def __init__(self, general_names):
+ general_names = list(general_names)
if not all(isinstance(x, GeneralName) for x in general_names):
raise TypeError(
"Every item in the general_names list must be an "
diff --git a/src/cryptography/x509/name.py b/src/cryptography/x509/name.py
index d62341d7..7e55f6e3 100644
--- a/src/cryptography/x509/name.py
+++ b/src/cryptography/x509/name.py
@@ -54,6 +54,10 @@ class NameAttribute(object):
class Name(object):
def __init__(self, attributes):
+ attributes = list(attributes)
+ if not all(isinstance(x, NameAttribute) for x in attributes):
+ raise TypeError("attributes must be a list of NameAttribute")
+
self._attributes = attributes
def get_attributes_for_oid(self, oid):
diff --git a/tests/test_x509.py b/tests/test_x509.py
index b1d627c3..47b81cb5 100644
--- a/tests/test_x509.py
+++ b/tests/test_x509.py
@@ -3680,6 +3680,14 @@ class TestName(object):
assert hash(name1) == hash(name2)
assert hash(name1) != hash(name3)
+ def test_iter_input(self):
+ attrs = [
+ x509.NameAttribute(x509.ObjectIdentifier('2.999.1'), u'value1')
+ ]
+ name = x509.Name(iter(attrs))
+ assert list(name) == attrs
+ assert list(name) == attrs
+
def test_repr(self):
name = x509.Name([
x509.NameAttribute(NameOID.COMMON_NAME, u'cryptography.io'),
@@ -3700,3 +3708,7 @@ class TestName(object):
"=<ObjectIdentifier(oid=2.5.4.10, name=organizationName)>, val"
"ue=u'PyCA')>])>"
)
+
+ def test_not_nameattribute(self):
+ with pytest.raises(TypeError):
+ x509.Name(["not-a-NameAttribute"])
diff --git a/tests/test_x509_ext.py b/tests/test_x509_ext.py
index f3677d11..749e52f1 100644
--- a/tests/test_x509_ext.py
+++ b/tests/test_x509_ext.py
@@ -268,6 +268,11 @@ class TestNoticeReference(object):
with pytest.raises(TypeError):
x509.NoticeReference("org", None)
+ def test_iter_input(self):
+ numbers = [1, 3, 4]
+ nr = x509.NoticeReference(u"org", iter(numbers))
+ assert list(nr.notice_numbers) == numbers
+
def test_repr(self):
nr = x509.NoticeReference(u"org", [1, 3, 4])
@@ -357,6 +362,11 @@ class TestPolicyInformation(object):
with pytest.raises(TypeError):
x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), [1, 2])
+ def test_iter_input(self):
+ qual = [u"foo", u"bar"]
+ pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), iter(qual))
+ assert list(pi.policy_qualifiers) == qual
+
def test_repr(self):
pq = [u"string", x509.UserNotice(None, u"hi")]
pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), pq)
@@ -414,6 +424,13 @@ class TestCertificatePolicies(object):
for policyinfo in cp:
assert policyinfo == pi
+ def test_iter_input(self):
+ policies = [
+ x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), [u"string"])
+ ]
+ cp = x509.CertificatePolicies(iter(policies))
+ assert list(cp) == policies
+
def test_repr(self):
pq = [u"string"]
pi = x509.PolicyInformation(x509.ObjectIdentifier("1.2.3"), pq)
@@ -859,6 +876,15 @@ class TestAuthorityKeyIdentifier(object):
assert aki.authority_cert_issuer == [dns]
assert aki.authority_cert_serial_number == 0
+ def test_iter_input(self):
+ dirnames = [
+ x509.DirectoryName(
+ x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u'myCN')])
+ )
+ ]
+ aki = x509.AuthorityKeyIdentifier(b"digest", iter(dirnames), 1234)
+ assert list(aki.authority_cert_issuer) == dirnames
+
def test_repr(self):
dirname = x509.DirectoryName(
x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u'myCN')])
@@ -973,6 +999,14 @@ class TestExtendedKeyUsage(object):
ExtendedKeyUsageOID.CLIENT_AUTH
]
+ def test_iter_input(self):
+ usages = [
+ x509.ObjectIdentifier("1.3.6.1.5.5.7.3.1"),
+ x509.ObjectIdentifier("1.3.6.1.5.5.7.3.2"),
+ ]
+ aia = x509.ExtendedKeyUsage(iter(usages))
+ assert list(aia) == usages
+
def test_repr(self):
eku = x509.ExtendedKeyUsage([
x509.ObjectIdentifier("1.3.6.1.5.5.7.3.1"),
@@ -1417,18 +1451,16 @@ class TestDirectoryName(object):
def test_repr(self):
name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, u'value1')])
- gn = x509.DirectoryName(x509.Name([name]))
+ gn = x509.DirectoryName(name)
if six.PY3:
assert repr(gn) == (
- "<DirectoryName(value=<Name([<Name([<NameAttribute(oid=<Object"
- "Identifier(oid=2.5.4.3, name=commonName)>, value='value1')>])"
- ">])>)>"
+ "<DirectoryName(value=<Name([<NameAttribute(oid=<ObjectIdentif"
+ "ier(oid=2.5.4.3, name=commonName)>, value='value1')>])>)>"
)
else:
assert repr(gn) == (
- "<DirectoryName(value=<Name([<Name([<NameAttribute(oid=<Object"
- "Identifier(oid=2.5.4.3, name=commonName)>, value=u'value1')>]"
- ")>])>)>"
+ "<DirectoryName(value=<Name([<NameAttribute(oid=<ObjectIdentif"
+ "ier(oid=2.5.4.3, name=commonName)>, value=u'value1')>])>)>"
)
def test_eq(self):
@@ -1438,8 +1470,8 @@ class TestDirectoryName(object):
name2 = x509.Name([
x509.NameAttribute(x509.ObjectIdentifier('2.999.1'), u'value1')
])
- gn = x509.DirectoryName(x509.Name([name]))
- gn2 = x509.DirectoryName(x509.Name([name2]))
+ gn = x509.DirectoryName(name)
+ gn2 = x509.DirectoryName(name2)
assert gn == gn2
def test_ne(self):
@@ -1449,8 +1481,8 @@ class TestDirectoryName(object):
name2 = x509.Name([
x509.NameAttribute(x509.ObjectIdentifier('2.999.2'), u'value2')
])
- gn = x509.DirectoryName(x509.Name([name]))
- gn2 = x509.DirectoryName(x509.Name([name2]))
+ gn = x509.DirectoryName(name)
+ gn2 = x509.DirectoryName(name2)
assert gn != gn2
assert gn != object()
@@ -1649,6 +1681,14 @@ class TestGeneralNames(object):
x509.DNSName(u"crypto.local"),
]
+ def test_iter_input(self):
+ names = [
+ x509.DNSName(u"cryptography.io"),
+ x509.DNSName(u"crypto.local"),
+ ]
+ gns = x509.GeneralNames(iter(names))
+ assert list(gns) == names
+
def test_indexing(self):
gn = x509.GeneralNames([
x509.DNSName(u"cryptography.io"),
@@ -2371,6 +2411,16 @@ class TestAuthorityInformationAccess(object):
)
]
+ def test_iter_input(self):
+ desc = [
+ x509.AccessDescription(
+ AuthorityInformationAccessOID.OCSP,
+ x509.UniformResourceIdentifier(u"http://ocsp.domain.com")
+ )
+ ]
+ aia = x509.AuthorityInformationAccess(iter(desc))
+ assert list(aia) == desc
+
def test_repr(self):
aia = x509.AuthorityInformationAccess([
x509.AccessDescription(
@@ -2743,6 +2793,12 @@ class TestNameConstraints(object):
assert nc.permitted_subtrees is not None
assert nc.excluded_subtrees is None
+ def test_iter_input(self):
+ subtrees = [x509.IPAddress(ipaddress.IPv4Network(u"192.168.0.0/24"))]
+ nc = x509.NameConstraints(iter(subtrees), iter(subtrees))
+ assert list(nc.permitted_subtrees) == subtrees
+ assert list(nc.excluded_subtrees) == subtrees
+
def test_repr(self):
permitted = [x509.DNSName(u"name.local"), x509.DNSName(u"name2.local")]
nc = x509.NameConstraints(
@@ -3050,6 +3106,24 @@ class TestDistributionPoint(object):
assert dp != dp2
assert dp != object()
+ def test_iter_input(self):
+ name = [x509.UniformResourceIdentifier(u"http://crypt.og/crl")]
+ issuer = [
+ x509.DirectoryName(
+ x509.Name([
+ x509.NameAttribute(NameOID.COMMON_NAME, u"Important CA")
+ ])
+ )
+ ]
+ dp = x509.DistributionPoint(
+ iter(name),
+ None,
+ frozenset([x509.ReasonFlags.ca_compromise]),
+ iter(issuer),
+ )
+ assert list(dp.full_name) == name
+ assert list(dp.crl_issuer) == issuer
+
def test_repr(self):
dp = x509.DistributionPoint(
None,
@@ -3129,6 +3203,18 @@ class TestCRLDistributionPoints(object):
),
]
+ def test_iter_input(self):
+ points = [
+ x509.DistributionPoint(
+ [x509.UniformResourceIdentifier(u"http://domain")],
+ None,
+ None,
+ None
+ ),
+ ]
+ cdp = x509.CRLDistributionPoints(iter(points))
+ assert list(cdp) == points
+
def test_repr(self):
cdp = x509.CRLDistributionPoints([
x509.DistributionPoint(