From 0ed7822a5fde4724a56259e986edc43bd8f599c7 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Wed, 10 Dec 2014 08:18:02 -0600 Subject: add __ne__ and __eq__ methods to RSA, DSA, and EC numbers classes fixes #1449 --- .../hazmat/primitives/asymmetric/dsa.py | 23 ++++++++ .../hazmat/primitives/asymmetric/ec.py | 20 +++++++ .../hazmat/primitives/asymmetric/rsa.py | 20 +++++++ tests/hazmat/primitives/test_dsa.py | 61 ++++++++++++++++++++++ tests/hazmat/primitives/test_ec.py | 35 +++++++++++++ tests/hazmat/primitives/test_rsa.py | 51 +++++++++++++++++- 6 files changed, 209 insertions(+), 1 deletion(-) diff --git a/src/cryptography/hazmat/primitives/asymmetric/dsa.py b/src/cryptography/hazmat/primitives/asymmetric/dsa.py index 5d942e04..03e49ba5 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/dsa.py +++ b/src/cryptography/hazmat/primitives/asymmetric/dsa.py @@ -59,6 +59,12 @@ class DSAParameterNumbers(object): def parameters(self, backend): return backend.load_dsa_parameter_numbers(self) + def __eq__(self, other): + return self.p == other.p and self.q == other.q and self.g == other.g + + def __ne__(self, other): + return not self == other + class DSAPublicNumbers(object): def __init__(self, y, parameter_numbers): @@ -79,6 +85,15 @@ class DSAPublicNumbers(object): def public_key(self, backend): return backend.load_dsa_public_numbers(self) + def __eq__(self, other): + return ( + self.y == other.y and + self.parameter_numbers == other.parameter_numbers + ) + + def __ne__(self, other): + return not self == other + class DSAPrivateNumbers(object): def __init__(self, x, public_numbers): @@ -97,3 +112,11 @@ class DSAPrivateNumbers(object): def private_key(self, backend): return backend.load_dsa_private_numbers(self) + + def __eq__(self, other): + return ( + self.x == other.x and self.public_numbers == other.public_numbers + ) + + def __ne__(self, other): + return not self == other diff --git a/src/cryptography/hazmat/primitives/asymmetric/ec.py b/src/cryptography/hazmat/primitives/asymmetric/ec.py index d9c41c19..cebd351c 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/ec.py +++ b/src/cryptography/hazmat/primitives/asymmetric/ec.py @@ -161,6 +161,17 @@ class EllipticCurvePublicNumbers(object): x = utils.read_only_property("_x") y = utils.read_only_property("_y") + def __eq__(self, other): + return ( + self.x == other.x and + self.y == other.y and + self.curve.name == other.curve.name and + self.curve.key_size == other.curve.key_size + ) + + def __ne__(self, other): + return not self == other + class EllipticCurvePrivateNumbers(object): def __init__(self, private_value, public_numbers): @@ -184,3 +195,12 @@ class EllipticCurvePrivateNumbers(object): private_value = utils.read_only_property("_private_value") public_numbers = utils.read_only_property("_public_numbers") + + def __eq__(self, other): + return ( + self.private_value == other.private_value and + self.public_numbers == other.public_numbers + ) + + def __ne__(self, other): + return not self == other diff --git a/src/cryptography/hazmat/primitives/asymmetric/rsa.py b/src/cryptography/hazmat/primitives/asymmetric/rsa.py index 6aeed006..ff397225 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/rsa.py +++ b/src/cryptography/hazmat/primitives/asymmetric/rsa.py @@ -160,6 +160,20 @@ class RSAPrivateNumbers(object): def private_key(self, backend): return backend.load_rsa_private_numbers(self) + def __eq__(self, other): + return ( + self.p == other.p and + self.q == other.q and + self.d == other.d and + self.dmp1 == other.dmp1 and + self.dmq1 == other.dmq1 and + self.iqmp == other.iqmp and + self.public_numbers == other.public_numbers + ) + + def __ne__(self, other): + return not self == other + class RSAPublicNumbers(object): def __init__(self, e, n): @@ -180,3 +194,9 @@ class RSAPublicNumbers(object): def __repr__(self): return "".format(self._e, self._n) + + def __eq__(self, other): + return self.e == other.e and self.n == other.n + + def __ne__(self, other): + return not self == other diff --git a/tests/hazmat/primitives/test_dsa.py b/tests/hazmat/primitives/test_dsa.py index f818f73b..5edb6cd6 100644 --- a/tests/hazmat/primitives/test_dsa.py +++ b/tests/hazmat/primitives/test_dsa.py @@ -705,3 +705,64 @@ class TestDSANumbers(object): with pytest.raises(TypeError): dsa.DSAPrivateNumbers(x=None, public_numbers=public_numbers) + + +class TestDSANumberEquality(object): + def test_parameter_numbers_eq(self): + param = dsa.DSAParameterNumbers(1, 2, 3) + assert param == dsa.DSAParameterNumbers(1, 2, 3) + + def test_parameter_numbers_ne(self): + param = dsa.DSAParameterNumbers(1, 2, 3) + assert param != dsa.DSAParameterNumbers(1, 2, 4) + assert param != dsa.DSAParameterNumbers(1, 1, 3) + assert param != dsa.DSAParameterNumbers(2, 2, 3) + + def test_public_numbers_eq(self): + pub = dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 3)) + assert pub == dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 3)) + + def test_public_numbers_ne(self): + pub = dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 3)) + assert pub != dsa.DSAPublicNumbers(2, dsa.DSAParameterNumbers(1, 2, 3)) + assert pub != dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(2, 2, 3)) + assert pub != dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 3, 3)) + assert pub != dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 4)) + + def test_private_numbers_eq(self): + pub = dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 3)) + priv = dsa.DSAPrivateNumbers(1, pub) + assert priv == dsa.DSAPrivateNumbers( + 1, dsa.DSAPublicNumbers( + 1, dsa.DSAParameterNumbers(1, 2, 3) + ) + ) + + def test_private_numbers_ne(self): + pub = dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 3)) + priv = dsa.DSAPrivateNumbers(1, pub) + assert priv != dsa.DSAPrivateNumbers( + 2, dsa.DSAPublicNumbers( + 1, dsa.DSAParameterNumbers(1, 2, 3) + ) + ) + assert priv != dsa.DSAPrivateNumbers( + 1, dsa.DSAPublicNumbers( + 2, dsa.DSAParameterNumbers(1, 2, 3) + ) + ) + assert priv != dsa.DSAPrivateNumbers( + 1, dsa.DSAPublicNumbers( + 1, dsa.DSAParameterNumbers(2, 2, 3) + ) + ) + assert priv != dsa.DSAPrivateNumbers( + 1, dsa.DSAPublicNumbers( + 1, dsa.DSAParameterNumbers(1, 3, 3) + ) + ) + assert priv != dsa.DSAPrivateNumbers( + 1, dsa.DSAPublicNumbers( + 1, dsa.DSAParameterNumbers(1, 2, 4) + ) + ) diff --git a/tests/hazmat/primitives/test_ec.py b/tests/hazmat/primitives/test_ec.py index a006f01f..4c09ceac 100644 --- a/tests/hazmat/primitives/test_ec.py +++ b/tests/hazmat/primitives/test_ec.py @@ -360,3 +360,38 @@ class TestECDSAVectors(object): numbers = ec.EllipticCurvePrivateNumbers(1, pub_numbers) assert numbers.private_key(b) == b"private_key" assert pub_numbers.public_key(b) == b"public_key" + + +class TestECNumbersEquality(object): + def test_public_numbers_eq(self): + pub = ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) + assert pub == ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) + + def test_public_numbers_ne(self): + pub = ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) + assert pub != ec.EllipticCurvePublicNumbers(1, 2, ec.SECP384R1()) + assert pub != ec.EllipticCurvePublicNumbers(1, 3, ec.SECP192R1()) + assert pub != ec.EllipticCurvePublicNumbers(2, 2, ec.SECP192R1()) + + def test_private_numbers_eq(self): + pub = ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) + priv = ec.EllipticCurvePrivateNumbers(1, pub) + assert priv == ec.EllipticCurvePrivateNumbers( + 1, ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) + ) + + def test_private_numbers_ne(self): + pub = ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) + priv = ec.EllipticCurvePrivateNumbers(1, pub) + assert priv != ec.EllipticCurvePrivateNumbers( + 2, ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) + ) + assert priv != ec.EllipticCurvePrivateNumbers( + 1, ec.EllipticCurvePublicNumbers(2, 2, ec.SECP192R1()) + ) + assert priv != ec.EllipticCurvePrivateNumbers( + 1, ec.EllipticCurvePublicNumbers(1, 3, ec.SECP192R1()) + ) + assert priv != ec.EllipticCurvePrivateNumbers( + 1, ec.EllipticCurvePublicNumbers(1, 2, ec.SECP521R1()) + ) diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index 581976ae..c0a8aace 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -18,7 +18,9 @@ from cryptography.exceptions import ( from cryptography.hazmat.backends.interfaces import RSABackend from cryptography.hazmat.primitives import hashes, interfaces from cryptography.hazmat.primitives.asymmetric import padding, rsa -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicNumbers +from cryptography.hazmat.primitives.asymmetric.rsa import ( + RSAPrivateNumbers, RSAPublicNumbers +) from .fixtures_rsa import ( RSA_KEY_1024, RSA_KEY_1025, RSA_KEY_1026, RSA_KEY_1027, RSA_KEY_1028, @@ -1647,3 +1649,50 @@ class TestRSANumbers(object): def test_public_number_repr(self): num = RSAPublicNumbers(1, 1) assert repr(num) == "" + + +class TestRSANumbersEquality(object): + def test_public_numbers_eq(self): + num = RSAPublicNumbers(1, 2) + num2 = RSAPublicNumbers(1, 2) + assert num == num2 + + def test_public_numbers_ne(self): + num = RSAPublicNumbers(1, 2) + assert num != RSAPublicNumbers(2, 2) + assert num != RSAPublicNumbers(1, 3) + + def test_private_numbers_eq(self): + pub = RSAPublicNumbers(1, 2) + num = RSAPrivateNumbers(1, 2, 3, 4, 5, 6, pub) + pub2 = RSAPublicNumbers(1, 2) + num2 = RSAPrivateNumbers(1, 2, 3, 4, 5, 6, pub2) + assert num == num2 + + def test_private_numbers_ne(self): + pub = RSAPublicNumbers(1, 2) + num = RSAPrivateNumbers(1, 2, 3, 4, 5, 6, pub) + assert num != RSAPrivateNumbers( + 1, 2, 3, 4, 5, 7, RSAPublicNumbers(1, 2) + ) + assert num != RSAPrivateNumbers( + 1, 2, 3, 4, 4, 6, RSAPublicNumbers(1, 2) + ) + assert num != RSAPrivateNumbers( + 1, 2, 3, 5, 5, 6, RSAPublicNumbers(1, 2) + ) + assert num != RSAPrivateNumbers( + 1, 2, 4, 4, 5, 6, RSAPublicNumbers(1, 2) + ) + assert num != RSAPrivateNumbers( + 1, 3, 3, 4, 5, 6, RSAPublicNumbers(1, 2) + ) + assert num != RSAPrivateNumbers( + 2, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 2) + ) + assert num != RSAPrivateNumbers( + 1, 2, 3, 4, 5, 6, RSAPublicNumbers(2, 2) + ) + assert num != RSAPrivateNumbers( + 1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 3) + ) -- cgit v1.2.3 From 285edf80abd3b1b59384e1021d10773150b1c3c3 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Wed, 10 Dec 2014 18:21:46 -0600 Subject: add NotImplemented handling --- src/cryptography/hazmat/primitives/asymmetric/dsa.py | 9 +++++++++ src/cryptography/hazmat/primitives/asymmetric/ec.py | 6 ++++++ src/cryptography/hazmat/primitives/asymmetric/rsa.py | 6 ++++++ tests/hazmat/primitives/test_dsa.py | 3 +++ tests/hazmat/primitives/test_ec.py | 2 ++ tests/hazmat/primitives/test_rsa.py | 2 ++ 6 files changed, 28 insertions(+) diff --git a/src/cryptography/hazmat/primitives/asymmetric/dsa.py b/src/cryptography/hazmat/primitives/asymmetric/dsa.py index 03e49ba5..9b06f3e6 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/dsa.py +++ b/src/cryptography/hazmat/primitives/asymmetric/dsa.py @@ -60,6 +60,9 @@ class DSAParameterNumbers(object): return backend.load_dsa_parameter_numbers(self) def __eq__(self, other): + if not isinstance(other, DSAParameterNumbers): + return NotImplemented + return self.p == other.p and self.q == other.q and self.g == other.g def __ne__(self, other): @@ -86,6 +89,9 @@ class DSAPublicNumbers(object): return backend.load_dsa_public_numbers(self) def __eq__(self, other): + if not isinstance(other, DSAPublicNumbers): + return NotImplemented + return ( self.y == other.y and self.parameter_numbers == other.parameter_numbers @@ -114,6 +120,9 @@ class DSAPrivateNumbers(object): return backend.load_dsa_private_numbers(self) def __eq__(self, other): + if not isinstance(other, DSAPrivateNumbers): + return NotImplemented + return ( self.x == other.x and self.public_numbers == other.public_numbers ) diff --git a/src/cryptography/hazmat/primitives/asymmetric/ec.py b/src/cryptography/hazmat/primitives/asymmetric/ec.py index cebd351c..202f1c97 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/ec.py +++ b/src/cryptography/hazmat/primitives/asymmetric/ec.py @@ -162,6 +162,9 @@ class EllipticCurvePublicNumbers(object): y = utils.read_only_property("_y") def __eq__(self, other): + if not isinstance(other, EllipticCurvePublicNumbers): + return NotImplemented + return ( self.x == other.x and self.y == other.y and @@ -197,6 +200,9 @@ class EllipticCurvePrivateNumbers(object): public_numbers = utils.read_only_property("_public_numbers") def __eq__(self, other): + if not isinstance(other, EllipticCurvePrivateNumbers): + return NotImplemented + return ( self.private_value == other.private_value and self.public_numbers == other.public_numbers diff --git a/src/cryptography/hazmat/primitives/asymmetric/rsa.py b/src/cryptography/hazmat/primitives/asymmetric/rsa.py index ff397225..0cc6b22b 100644 --- a/src/cryptography/hazmat/primitives/asymmetric/rsa.py +++ b/src/cryptography/hazmat/primitives/asymmetric/rsa.py @@ -161,6 +161,9 @@ class RSAPrivateNumbers(object): return backend.load_rsa_private_numbers(self) def __eq__(self, other): + if not isinstance(other, RSAPrivateNumbers): + return NotImplemented + return ( self.p == other.p and self.q == other.q and @@ -196,6 +199,9 @@ class RSAPublicNumbers(object): return "".format(self._e, self._n) def __eq__(self, other): + if not isinstance(other, RSAPublicNumbers): + return NotImplemented + return self.e == other.e and self.n == other.n def __ne__(self, other): diff --git a/tests/hazmat/primitives/test_dsa.py b/tests/hazmat/primitives/test_dsa.py index 5edb6cd6..8c0fb80c 100644 --- a/tests/hazmat/primitives/test_dsa.py +++ b/tests/hazmat/primitives/test_dsa.py @@ -717,6 +717,7 @@ class TestDSANumberEquality(object): assert param != dsa.DSAParameterNumbers(1, 2, 4) assert param != dsa.DSAParameterNumbers(1, 1, 3) assert param != dsa.DSAParameterNumbers(2, 2, 3) + assert param != object() def test_public_numbers_eq(self): pub = dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 3)) @@ -728,6 +729,7 @@ class TestDSANumberEquality(object): assert pub != dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(2, 2, 3)) assert pub != dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 3, 3)) assert pub != dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 4)) + assert pub != object() def test_private_numbers_eq(self): pub = dsa.DSAPublicNumbers(1, dsa.DSAParameterNumbers(1, 2, 3)) @@ -766,3 +768,4 @@ class TestDSANumberEquality(object): 1, dsa.DSAParameterNumbers(1, 2, 4) ) ) + assert priv != object() diff --git a/tests/hazmat/primitives/test_ec.py b/tests/hazmat/primitives/test_ec.py index 4c09ceac..84c447c1 100644 --- a/tests/hazmat/primitives/test_ec.py +++ b/tests/hazmat/primitives/test_ec.py @@ -372,6 +372,7 @@ class TestECNumbersEquality(object): assert pub != ec.EllipticCurvePublicNumbers(1, 2, ec.SECP384R1()) assert pub != ec.EllipticCurvePublicNumbers(1, 3, ec.SECP192R1()) assert pub != ec.EllipticCurvePublicNumbers(2, 2, ec.SECP192R1()) + assert pub != object() def test_private_numbers_eq(self): pub = ec.EllipticCurvePublicNumbers(1, 2, ec.SECP192R1()) @@ -395,3 +396,4 @@ class TestECNumbersEquality(object): assert priv != ec.EllipticCurvePrivateNumbers( 1, ec.EllipticCurvePublicNumbers(1, 2, ec.SECP521R1()) ) + assert priv != object() diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index c0a8aace..095ed037 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -1661,6 +1661,7 @@ class TestRSANumbersEquality(object): num = RSAPublicNumbers(1, 2) assert num != RSAPublicNumbers(2, 2) assert num != RSAPublicNumbers(1, 3) + assert num != object() def test_private_numbers_eq(self): pub = RSAPublicNumbers(1, 2) @@ -1696,3 +1697,4 @@ class TestRSANumbersEquality(object): assert num != RSAPrivateNumbers( 1, 2, 3, 4, 5, 6, RSAPublicNumbers(1, 3) ) + assert num != object() -- cgit v1.2.3