diff options
Diffstat (limited to 'tests/hazmat/primitives/utils.py')
| -rw-r--r-- | tests/hazmat/primitives/utils.py | 101 |
1 files changed, 75 insertions, 26 deletions
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index e148bc63..4aa5ce71 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -18,6 +18,9 @@ from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand +from cryptography.hazmat.primitives.kdf.kbkdf import ( + CounterLocation, KBKDFHMAC, Mode +) from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from ...utils import load_vectors_from_file @@ -44,6 +47,10 @@ def generate_encrypt_test(param_loader, path, file_names, cipher_factory, def encrypt_test(backend, cipher_factory, mode_factory, params): + assert backend.cipher_supported( + cipher_factory(**params), mode_factory(**params) + ) + plaintext = params["plaintext"] ciphertext = params["ciphertext"] cipher = Cipher( @@ -161,16 +168,15 @@ def hash_test(backend, algorithm, params): assert m.finalize() == binascii.unhexlify(expected_md) -def generate_base_hash_test(algorithm, digest_size, block_size): +def generate_base_hash_test(algorithm, digest_size): def test_base_hash(self, backend): - base_hash_test(backend, algorithm, digest_size, block_size) + base_hash_test(backend, algorithm, digest_size) return test_base_hash -def base_hash_test(backend, algorithm, digest_size, block_size): +def base_hash_test(backend, algorithm, digest_size): m = hashes.Hash(algorithm, backend=backend) assert m.algorithm.digest_size == digest_size - assert m.algorithm.block_size == block_size m_copy = m.copy() assert m != m_copy assert m._ctx != m_copy._ctx @@ -182,18 +188,6 @@ def base_hash_test(backend, algorithm, digest_size, block_size): assert copy.finalize() == m.finalize() -def generate_long_string_hash_test(hash_factory, md): - def test_long_string_hash(self, backend): - long_string_hash_test(backend, hash_factory, md) - return test_long_string_hash - - -def long_string_hash_test(backend, algorithm, md): - m = hashes.Hash(algorithm, backend=backend) - m.update(b"a" * 1000000) - assert m.finalize() == binascii.unhexlify(md.lower().encode("ascii")) - - def generate_base_hmac_test(hash_cls): def test_base_hmac(self, backend): base_hmac_test(backend, hash_cls) @@ -296,8 +290,6 @@ def aead_tag_exception_test(backend, cipher_factory, mode_factory): mode_factory(binascii.unhexlify(b"0" * 24)), backend ) - with pytest.raises(ValueError): - cipher.decryptor() with pytest.raises(ValueError): mode_factory(binascii.unhexlify(b"0" * 24), b"000") @@ -370,6 +362,57 @@ def generate_hkdf_test(param_loader, path, file_names, algorithm): return test_hkdf +def generate_kbkdf_counter_mode_test(param_loader, path, file_names): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_kbkdf(self, backend, params): + kbkdf_counter_mode_test(backend, params) + return test_kbkdf + + +def kbkdf_counter_mode_test(backend, params): + supported_algorithms = { + 'hmac_sha1': hashes.SHA1, + 'hmac_sha224': hashes.SHA224, + 'hmac_sha256': hashes.SHA256, + 'hmac_sha384': hashes.SHA384, + 'hmac_sha512': hashes.SHA512, + } + + supported_counter_locations = { + "before_fixed": CounterLocation.BeforeFixed, + "after_fixed": CounterLocation.AfterFixed, + } + + algorithm = supported_algorithms.get(params.get('prf')) + if algorithm is None or not backend.hmac_supported(algorithm()): + pytest.skip("KBKDF does not support algorithm: {}".format( + params.get('prf') + )) + + ctr_loc = supported_counter_locations.get(params.get("ctrlocation")) + if ctr_loc is None or not isinstance(ctr_loc, CounterLocation): + pytest.skip("Does not support counter location: {}".format( + params.get('ctrlocation') + )) + + ctrkdf = KBKDFHMAC( + algorithm(), + Mode.CounterMode, + params['l'] // 8, + params['rlen'] // 8, + None, + ctr_loc, + None, + None, + binascii.unhexlify(params['fixedinputdata']), + backend=backend) + + ko = ctrkdf.derive(binascii.unhexlify(params['ki'])) + assert binascii.hexlify(ko) == params["ko"] + + def generate_rsa_verification_test(param_loader, path, file_names, hash_alg, pad_factory): all_params = _load_all_params(path, file_names, param_loader) @@ -390,17 +433,23 @@ def rsa_verification_test(backend, params, hash_alg, pad_factory): ) public_key = public_numbers.public_key(backend) pad = pad_factory(params, hash_alg) - verifier = public_key.verifier( - binascii.unhexlify(params["s"]), - pad, - hash_alg - ) - verifier.update(binascii.unhexlify(params["msg"])) + signature = binascii.unhexlify(params["s"]) + msg = binascii.unhexlify(params["msg"]) if params["fail"]: with pytest.raises(InvalidSignature): - verifier.verify() + public_key.verify( + signature, + msg, + pad, + hash_alg + ) else: - verifier.verify() + public_key.verify( + signature, + msg, + pad, + hash_alg + ) def _check_rsa_private_numbers(skey): |
