diff options
Diffstat (limited to 'tests/hazmat')
| -rw-r--r-- | tests/hazmat/primitives/test_rsa.py | 145 | 
1 files changed, 103 insertions, 42 deletions
diff --git a/tests/hazmat/primitives/test_rsa.py b/tests/hazmat/primitives/test_rsa.py index c458a662..cc87d981 100644 --- a/tests/hazmat/primitives/test_rsa.py +++ b/tests/hazmat/primitives/test_rsa.py @@ -39,7 +39,7 @@ class DummyPadding(object):  class DummyMGF(object): -    pass +    _salt_length = 0  def _modinv(e, m): @@ -454,10 +454,8 @@ class TestRSASignature(object):          )          signer = private_key.signer(              padding.PSS( -                mgf=padding.MGF1( -                    algorithm=hashes.SHA1(), -                    salt_length=padding.MGF1.MAX_LENGTH -                ) +                mgf=padding.MGF1(algorithm=hashes.SHA1()), +                salt_length=padding.PSS.MAX_LENGTH              ),              hashes.SHA1(),              backend @@ -471,6 +469,23 @@ class TestRSASignature(object):          verifier = public_key.verifier(              signature,              padding.PSS( +                mgf=padding.MGF1(algorithm=hashes.SHA1()), +                salt_length=padding.PSS.MAX_LENGTH +            ), +            hashes.SHA1(), +            backend +        ) +        verifier.update(binascii.unhexlify(example["message"])) +        verifier.verify() + +    def test_deprecated_pss_mgf1_salt_length(self, backend): +        private_key = rsa.RSAPrivateKey.generate( +            public_exponent=65537, +            key_size=512, +            backend=backend +        ) +        signer = private_key.signer( +            padding.PSS(                  mgf=padding.MGF1(                      algorithm=hashes.SHA1(),                      salt_length=padding.MGF1.MAX_LENGTH @@ -479,7 +494,21 @@ class TestRSASignature(object):              hashes.SHA1(),              backend          ) -        verifier.update(binascii.unhexlify(example["message"])) +        signer.update(b"so deprecated") +        signature = signer.finalize() +        assert len(signature) == math.ceil(private_key.key_size / 8.0) +        verifier = private_key.public_key().verifier( +            signature, +            padding.PSS( +                mgf=padding.MGF1( +                    algorithm=hashes.SHA1(), +                    salt_length=padding.MGF1.MAX_LENGTH +                ) +            ), +            hashes.SHA1(), +            backend +        ) +        verifier.update(b"so deprecated")          verifier.verify()      @pytest.mark.parametrize( @@ -498,10 +527,8 @@ class TestRSASignature(object):          )          public_key = private_key.public_key()          pss = padding.PSS( -            mgf=padding.MGF1( -                algorithm=hash_alg, -                salt_length=padding.MGF1.MAX_LENGTH -            ) +            mgf=padding.MGF1(hash_alg), +            salt_length=padding.PSS.MAX_LENGTH          )          signer = private_key.signer(              pss, @@ -531,10 +558,8 @@ class TestRSASignature(object):          )          signer = private_key.signer(              padding.PSS( -                mgf=padding.MGF1( -                    algorithm=hashes.SHA1(), -                    salt_length=padding.MGF1.MAX_LENGTH -                ) +                mgf=padding.MGF1(hashes.SHA1()), +                salt_length=padding.PSS.MAX_LENGTH              ),              hashes.SHA512(),              backend @@ -555,10 +580,8 @@ class TestRSASignature(object):          with pytest.raises(ValueError):              private_key.signer(                  padding.PSS( -                    mgf=padding.MGF1( -                        algorithm=hashes.SHA1(), -                        salt_length=padding.MGF1.MAX_LENGTH -                    ) +                    mgf=padding.MGF1(hashes.SHA1()), +                    salt_length=padding.PSS.MAX_LENGTH                  ),                  hashes.SHA512(),                  backend @@ -572,10 +595,8 @@ class TestRSASignature(object):          )          signer = private_key.signer(              padding.PSS( -                mgf=padding.MGF1( -                    algorithm=hashes.SHA1(), -                    salt_length=1000000 -                ) +                mgf=padding.MGF1(hashes.SHA1()), +                salt_length=1000000              ),              hashes.SHA1(),              backend @@ -722,10 +743,8 @@ class TestRSAVerification(object):          verifier = public_key.verifier(              binascii.unhexlify(example["signature"]),              padding.PSS( -                mgf=padding.MGF1( -                    algorithm=hashes.SHA1(), -                    salt_length=20 -                ) +                mgf=padding.MGF1(algorithm=hashes.SHA1()), +                salt_length=20              ),              hashes.SHA1(),              backend @@ -749,10 +768,8 @@ class TestRSAVerification(object):          verifier = public_key.verifier(              signature,              padding.PSS( -                mgf=padding.MGF1( -                    algorithm=hashes.SHA1(), -                    salt_length=padding.MGF1.MAX_LENGTH -                ) +                mgf=padding.MGF1(algorithm=hashes.SHA1()), +                salt_length=padding.PSS.MAX_LENGTH              ),              hashes.SHA1(),              backend @@ -779,10 +796,8 @@ class TestRSAVerification(object):          verifier = public_key.verifier(              signature,              padding.PSS( -                mgf=padding.MGF1( -                    algorithm=hashes.SHA1(), -                    salt_length=padding.MGF1.MAX_LENGTH -                ) +                mgf=padding.MGF1(algorithm=hashes.SHA1()), +                salt_length=padding.PSS.MAX_LENGTH              ),              hashes.SHA1(),              backend @@ -809,10 +824,8 @@ class TestRSAVerification(object):          verifier = public_key.verifier(              signature,              padding.PSS( -                mgf=padding.MGF1( -                    algorithm=hashes.SHA1(), -                    salt_length=padding.MGF1.MAX_LENGTH -                ) +                mgf=padding.MGF1(algorithm=hashes.SHA1()), +                salt_length=padding.PSS.MAX_LENGTH              ),              hashes.SHA1(),              backend @@ -904,10 +917,8 @@ class TestRSAVerification(object):              public_key.verifier(                  signature,                  padding.PSS( -                    mgf=padding.MGF1( -                        algorithm=hashes.SHA1(), -                        salt_length=padding.MGF1.MAX_LENGTH -                    ) +                    mgf=padding.MGF1(algorithm=hashes.SHA1()), +                    salt_length=padding.PSS.MAX_LENGTH                  ),                  hashes.SHA512(),                  backend @@ -1113,7 +1124,57 @@ class TestRSAPKCS1Verification(object):      )) +class TestPSS(object): +    def test_deprecation_warning(self): +        pytest.deprecated_call( +            padding.PSS, +            mgf=padding.MGF1(hashes.SHA1(), 20) +        ) + +    def test_invalid_salt_length_not_integer(self): +        with pytest.raises(TypeError): +            padding.PSS( +                mgf=padding.MGF1( +                    hashes.SHA1() +                ), +                salt_length=b"not_a_length" +            ) + +    def test_invalid_salt_length_negative_integer(self): +        with pytest.raises(ValueError): +            padding.PSS( +                mgf=padding.MGF1( +                    hashes.SHA1() +                ), +                salt_length=-1 +            ) + +    def test_no_salt_length_supplied_pss_or_mgf1(self): +        with pytest.raises(ValueError): +            padding.PSS(mgf=padding.MGF1(hashes.SHA1())) + +    def test_valid_pss_parameters(self): +        algorithm = hashes.SHA1() +        salt_length = algorithm.digest_size +        mgf = padding.MGF1(algorithm) +        pss = padding.PSS(mgf=mgf, salt_length=salt_length) +        assert pss._mgf == mgf +        assert pss._salt_length == salt_length + +    def test_valid_pss_parameters_maximum(self): +        algorithm = hashes.SHA1() +        mgf = padding.MGF1(algorithm) +        pss = padding.PSS(mgf=mgf, salt_length=padding.PSS.MAX_LENGTH) +        assert pss._mgf == mgf +        assert pss._salt_length == padding.PSS.MAX_LENGTH + +  class TestMGF1(object): +    def test_deprecation_warning(self): +        pytest.deprecated_call( +            padding.MGF1, algorithm=hashes.SHA1(), salt_length=20 +        ) +      def test_invalid_hash_algorithm(self):          with pytest.raises(TypeError):              padding.MGF1(b"not_a_hash", 0)  | 
