diff options
author | Alex Gaynor <alex.gaynor@gmail.com> | 2014-01-01 08:11:13 -0800 |
---|---|---|
committer | Alex Gaynor <alex.gaynor@gmail.com> | 2014-01-01 08:11:13 -0800 |
commit | 2a160d6a159817dd9d08a84e77d102e328f9af4f (patch) | |
tree | 461b7c607367f243bb46996ec16ad7424e48d440 /tests/hazmat/primitives/utils.py | |
parent | 62aefffb1396190930074bf04c91459d1536bd0e (diff) | |
parent | 522487e5a7dd3004747da85c9f6c53fc5dc4de06 (diff) | |
download | cryptography-2a160d6a159817dd9d08a84e77d102e328f9af4f.tar.gz cryptography-2a160d6a159817dd9d08a84e77d102e328f9af4f.tar.bz2 cryptography-2a160d6a159817dd9d08a84e77d102e328f9af4f.zip |
Merge branch 'master' into validate-iv
Conflicts:
tests/hazmat/backends/test_openssl.py
tests/hazmat/primitives/test_block.py
Diffstat (limited to 'tests/hazmat/primitives/utils.py')
-rw-r--r-- | tests/hazmat/primitives/utils.py | 292 |
1 files changed, 91 insertions, 201 deletions
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index b06f9b29..cdcf84cb 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -3,7 +3,6 @@ import os import pytest -from cryptography.hazmat.backends import _ALL_BACKENDS from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.exceptions import ( @@ -13,34 +12,29 @@ from cryptography.exceptions import ( from ...utils import load_vectors_from_file +def _load_all_params(path, file_names, param_loader): + all_params = [] + for file_name in file_names: + all_params.extend( + load_vectors_from_file(os.path.join(path, file_name), param_loader) + ) + return all_params + + def generate_encrypt_test(param_loader, path, file_names, cipher_factory, - mode_factory, only_if=lambda backend: True, - skip_message=None): - def test_encryption(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - encrypt_test, - backend, - cipher_factory, - mode_factory, - params, - only_if, - skip_message - ) + mode_factory): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_encryption(self, backend, params): + encrypt_test(backend, cipher_factory, mode_factory, params) + return test_encryption -def encrypt_test(backend, cipher_factory, mode_factory, params, only_if, - skip_message): - if not only_if(backend): - pytest.skip(skip_message) - plaintext = params.pop("plaintext") - ciphertext = params.pop("ciphertext") +def encrypt_test(backend, cipher_factory, mode_factory, params): + plaintext = params["plaintext"] + ciphertext = params["ciphertext"] cipher = Cipher( cipher_factory(**params), mode_factory(**params), @@ -57,34 +51,21 @@ def encrypt_test(backend, cipher_factory, mode_factory, params, only_if, def generate_aead_test(param_loader, path, file_names, cipher_factory, - mode_factory, only_if, skip_message): - def test_aead(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - aead_test, - backend, - cipher_factory, - mode_factory, - params, - only_if, - skip_message - ) + mode_factory): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_aead(self, backend, params): + aead_test(backend, cipher_factory, mode_factory, params) + return test_aead -def aead_test(backend, cipher_factory, mode_factory, params, only_if, - skip_message): - if not only_if(backend): - pytest.skip(skip_message) +def aead_test(backend, cipher_factory, mode_factory, params): if params.get("pt") is not None: - plaintext = params.pop("pt") - ciphertext = params.pop("ct") - aad = params.pop("aad") + plaintext = params["pt"] + ciphertext = params["ct"] + aad = params["aad"] if params.get("fail") is True: cipher = Cipher( cipher_factory(binascii.unhexlify(params["key"])), @@ -123,33 +104,19 @@ def aead_test(backend, cipher_factory, mode_factory, params, only_if, def generate_stream_encryption_test(param_loader, path, file_names, - cipher_factory, only_if=None, - skip_message=None): - def test_stream_encryption(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - stream_encryption_test, - backend, - cipher_factory, - params, - only_if, - skip_message - ) + cipher_factory): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_stream_encryption(self, backend, params): + stream_encryption_test(backend, cipher_factory, params) return test_stream_encryption -def stream_encryption_test(backend, cipher_factory, params, only_if, - skip_message): - if not only_if(backend): - pytest.skip(skip_message) - plaintext = params.pop("plaintext") - ciphertext = params.pop("ciphertext") - offset = params.pop("offset") +def stream_encryption_test(backend, cipher_factory, params): + plaintext = params["plaintext"] + ciphertext = params["ciphertext"] + offset = params["offset"] cipher = Cipher(cipher_factory(**params), None, backend=backend) encryptor = cipher.encryptor() # throw away offset bytes @@ -164,29 +131,16 @@ def stream_encryption_test(backend, cipher_factory, params, only_if, assert actual_plaintext == binascii.unhexlify(plaintext) -def generate_hash_test(param_loader, path, file_names, hash_cls, - only_if=None, skip_message=None): - def test_hash(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - hash_test, - backend, - hash_cls, - params, - only_if, - skip_message - ) +def generate_hash_test(param_loader, path, file_names, hash_cls): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_hash(self, backend, params): + hash_test(backend, hash_cls, params) return test_hash -def hash_test(backend, algorithm, params, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) +def hash_test(backend, algorithm, params): msg = params[0] md = params[1] m = hashes.Hash(algorithm, backend=backend) @@ -195,27 +149,13 @@ def hash_test(backend, algorithm, params, only_if, skip_message): assert m.finalize() == binascii.unhexlify(expected_md) -def generate_base_hash_test(algorithm, digest_size, block_size, - only_if=None, skip_message=None): - def test_base_hash(self): - for backend in _ALL_BACKENDS: - yield ( - base_hash_test, - backend, - algorithm, - digest_size, - block_size, - only_if, - skip_message, - ) +def generate_base_hash_test(algorithm, digest_size, block_size): + def test_base_hash(self, backend): + base_hash_test(backend, algorithm, digest_size, block_size) return test_base_hash -def base_hash_test(backend, algorithm, digest_size, block_size, only_if, - skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) - +def base_hash_test(backend, algorithm, digest_size, block_size): m = hashes.Hash(algorithm, backend=backend) assert m.algorithm.digest_size == digest_size assert m.algorithm.block_size == block_size @@ -230,52 +170,42 @@ def base_hash_test(backend, algorithm, digest_size, block_size, only_if, assert copy.finalize() == m.finalize() -def generate_long_string_hash_test(hash_factory, md, only_if=None, - skip_message=None): - def test_long_string_hash(self): - for backend in _ALL_BACKENDS: - yield( - long_string_hash_test, - backend, - hash_factory, - md, - only_if, - skip_message - ) +def generate_long_string_hash_test(hash_factory, md): + def test_long_string_hash(self, backend): + long_string_hash_test(backend, hash_factory, md) return test_long_string_hash -def long_string_hash_test(backend, algorithm, md, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) +def long_string_hash_test(backend, algorithm, md): m = hashes.Hash(algorithm, backend=backend) m.update(b"a" * 1000000) assert m.finalize() == binascii.unhexlify(md.lower().encode("ascii")) -def generate_hmac_test(param_loader, path, file_names, algorithm, - only_if=None, skip_message=None): - def test_hmac(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - hmac_test, - backend, - algorithm, - params, - only_if, - skip_message - ) +def generate_base_hmac_test(hash_cls): + def test_base_hmac(self, backend): + base_hmac_test(backend, hash_cls) + return test_base_hmac + + +def base_hmac_test(backend, algorithm): + key = b"ab" + h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) + h_copy = h.copy() + assert h != h_copy + assert h._ctx != h_copy._ctx + + +def generate_hmac_test(param_loader, path, file_names, algorithm): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_hmac(self, backend, params): + hmac_test(backend, algorithm, params) return test_hmac -def hmac_test(backend, algorithm, params, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) +def hmac_test(backend, algorithm, params): msg = params[0] md = params[1] key = params[2] @@ -284,48 +214,13 @@ def hmac_test(backend, algorithm, params, only_if, skip_message): assert h.finalize() == binascii.unhexlify(md.encode("ascii")) -def generate_base_hmac_test(hash_cls, only_if=None, skip_message=None): - def test_base_hmac(self): - for backend in _ALL_BACKENDS: - yield ( - base_hmac_test, - backend, - hash_cls, - only_if, - skip_message, - ) - return test_base_hmac - - -def base_hmac_test(backend, algorithm, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) - key = b"ab" - h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) - h_copy = h.copy() - assert h != h_copy - assert h._ctx != h_copy._ctx - - -def generate_aead_exception_test(cipher_factory, mode_factory, - only_if, skip_message): - def test_aead_exception(self): - for backend in _ALL_BACKENDS: - yield ( - aead_exception_test, - backend, - cipher_factory, - mode_factory, - only_if, - skip_message - ) +def generate_aead_exception_test(cipher_factory, mode_factory): + def test_aead_exception(self, backend): + aead_exception_test(backend, cipher_factory, mode_factory) return test_aead_exception -def aead_exception_test(backend, cipher_factory, mode_factory, - only_if, skip_message): - if not only_if(backend): - pytest.skip(skip_message) +def aead_exception_test(backend, cipher_factory, mode_factory): cipher = Cipher( cipher_factory(binascii.unhexlify(b"0" * 32)), mode_factory(binascii.unhexlify(b"0" * 24)), @@ -355,25 +250,13 @@ def aead_exception_test(backend, cipher_factory, mode_factory, decryptor.tag -def generate_aead_tag_exception_test(cipher_factory, mode_factory, - only_if, skip_message): - def test_aead_tag_exception(self): - for backend in _ALL_BACKENDS: - yield ( - aead_tag_exception_test, - backend, - cipher_factory, - mode_factory, - only_if, - skip_message - ) +def generate_aead_tag_exception_test(cipher_factory, mode_factory): + def test_aead_tag_exception(self, backend): + aead_tag_exception_test(backend, cipher_factory, mode_factory) return test_aead_tag_exception -def aead_tag_exception_test(backend, cipher_factory, mode_factory, - only_if, skip_message): - if not only_if(backend): - pytest.skip(skip_message) +def aead_tag_exception_test(backend, cipher_factory, mode_factory): cipher = Cipher( cipher_factory(binascii.unhexlify(b"0" * 32)), mode_factory(binascii.unhexlify(b"0" * 24)), @@ -383,6 +266,13 @@ def aead_tag_exception_test(backend, cipher_factory, mode_factory, cipher.decryptor() cipher = Cipher( cipher_factory(binascii.unhexlify(b"0" * 32)), + mode_factory(binascii.unhexlify(b"0" * 24), b"000"), + backend + ) + with pytest.raises(ValueError): + cipher.decryptor() + cipher = Cipher( + cipher_factory(binascii.unhexlify(b"0" * 32)), mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), backend ) |