From e5c5eec8cc5b0a4a1faa47013c0b90d01483b125 Mon Sep 17 00:00:00 2001 From: Alex Gaynor Date: Fri, 13 Dec 2013 08:10:20 -0800 Subject: Clean up test generation to not use generators anymore and use parametrization --- tests/hazmat/primitives/utils.py | 541 +++++++++++++++------------------------ 1 file changed, 209 insertions(+), 332 deletions(-) (limited to 'tests/hazmat') diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index 705983a0..7f67ad40 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -3,7 +3,6 @@ import os import pytest -from cryptography.hazmat.bindings import _ALL_BACKENDS from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.exceptions import ( @@ -13,378 +12,256 @@ from cryptography.exceptions import ( from ...utils import load_vectors_from_file +def _load_all_params(path, file_names, param_loader): + all_params = [] + for file_name in file_names: + all_params.extend( + load_vectors_from_file(os.path.join(path, file_name), param_loader) + ) + return all_params + + def generate_encrypt_test(param_loader, path, file_names, cipher_factory, mode_factory, only_if=lambda backend: True, skip_message=None): - def test_encryption(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - encrypt_test, - backend, - cipher_factory, - mode_factory, - params, - only_if, - skip_message - ) - return test_encryption - + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_encryption(self, backend, params): + if not only_if(backend): + pytest.skip(skip_message) + plaintext = params.pop("plaintext") + ciphertext = params.pop("ciphertext") + cipher = Cipher( + cipher_factory(**params), + mode_factory(**params), + backend=backend + ) + encryptor = cipher.encryptor() + actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) + actual_ciphertext += encryptor.finalize() + assert actual_ciphertext == binascii.unhexlify(ciphertext) + decryptor = cipher.decryptor() + actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) + actual_plaintext += decryptor.finalize() + assert actual_plaintext == binascii.unhexlify(plaintext) -def encrypt_test(backend, cipher_factory, mode_factory, params, only_if, - skip_message): - if not only_if(backend): - pytest.skip(skip_message) - plaintext = params.pop("plaintext") - ciphertext = params.pop("ciphertext") - cipher = Cipher( - cipher_factory(**params), - mode_factory(**params), - backend=backend - ) - encryptor = cipher.encryptor() - actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) - actual_ciphertext += encryptor.finalize() - assert actual_ciphertext == binascii.unhexlify(ciphertext) - decryptor = cipher.decryptor() - actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) - actual_plaintext += decryptor.finalize() - assert actual_plaintext == binascii.unhexlify(plaintext) + return test_encryption def generate_aead_test(param_loader, path, file_names, cipher_factory, mode_factory, only_if, skip_message): - def test_aead(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - aead_test, - backend, - cipher_factory, - mode_factory, - params, - only_if, - skip_message - ) + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_aead(self, backend, params): + if not only_if(backend): + pytest.skip(skip_message) + if params.get("pt") is not None: + plaintext = params.pop("pt") + ciphertext = params.pop("ct") + aad = params.pop("aad") + if params.get("fail") is True: + cipher = Cipher( + cipher_factory(binascii.unhexlify(params["key"])), + mode_factory(binascii.unhexlify(params["iv"]), + binascii.unhexlify(params["tag"])), + backend + ) + decryptor = cipher.decryptor() + decryptor.authenticate_additional_data(binascii.unhexlify(aad)) + actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) + with pytest.raises(InvalidTag): + decryptor.finalize() + else: + cipher = Cipher( + cipher_factory(binascii.unhexlify(params["key"])), + mode_factory(binascii.unhexlify(params["iv"]), None), + backend + ) + encryptor = cipher.encryptor() + encryptor.authenticate_additional_data(binascii.unhexlify(aad)) + actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) + actual_ciphertext += encryptor.finalize() + tag_len = len(params["tag"]) + assert binascii.hexlify(encryptor.tag)[:tag_len] == params["tag"] + cipher = Cipher( + cipher_factory(binascii.unhexlify(params["key"])), + mode_factory(binascii.unhexlify(params["iv"]), + binascii.unhexlify(params["tag"])), + backend + ) + decryptor = cipher.decryptor() + decryptor.authenticate_additional_data(binascii.unhexlify(aad)) + actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) + actual_plaintext += decryptor.finalize() + assert actual_plaintext == binascii.unhexlify(plaintext) + return test_aead -def aead_test(backend, cipher_factory, mode_factory, params, only_if, - skip_message): - if not only_if(backend): - pytest.skip(skip_message) - if params.get("pt") is not None: - plaintext = params.pop("pt") - ciphertext = params.pop("ct") - aad = params.pop("aad") - if params.get("fail") is True: - cipher = Cipher( - cipher_factory(binascii.unhexlify(params["key"])), - mode_factory(binascii.unhexlify(params["iv"]), - binascii.unhexlify(params["tag"])), - backend - ) - decryptor = cipher.decryptor() - decryptor.authenticate_additional_data(binascii.unhexlify(aad)) - actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) - with pytest.raises(InvalidTag): - decryptor.finalize() - else: - cipher = Cipher( - cipher_factory(binascii.unhexlify(params["key"])), - mode_factory(binascii.unhexlify(params["iv"]), None), - backend - ) +def generate_stream_encryption_test(param_loader, path, file_names, + cipher_factory, only_if=None, + skip_message=None): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_stream_encryption(self, backend, params): + if not only_if(backend): + pytest.skip(skip_message) + plaintext = params.pop("plaintext") + ciphertext = params.pop("ciphertext") + offset = params.pop("offset") + cipher = Cipher(cipher_factory(**params), None, backend=backend) encryptor = cipher.encryptor() - encryptor.authenticate_additional_data(binascii.unhexlify(aad)) + # throw away offset bytes + encryptor.update(b"\x00" * int(offset)) actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) actual_ciphertext += encryptor.finalize() - tag_len = len(params["tag"]) - assert binascii.hexlify(encryptor.tag)[:tag_len] == params["tag"] - cipher = Cipher( - cipher_factory(binascii.unhexlify(params["key"])), - mode_factory(binascii.unhexlify(params["iv"]), - binascii.unhexlify(params["tag"])), - backend - ) + assert actual_ciphertext == binascii.unhexlify(ciphertext) decryptor = cipher.decryptor() - decryptor.authenticate_additional_data(binascii.unhexlify(aad)) + decryptor.update(b"\x00" * int(offset)) actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) actual_plaintext += decryptor.finalize() assert actual_plaintext == binascii.unhexlify(plaintext) - -def generate_stream_encryption_test(param_loader, path, file_names, - cipher_factory, only_if=None, - skip_message=None): - def test_stream_encryption(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - stream_encryption_test, - backend, - cipher_factory, - params, - only_if, - skip_message - ) return test_stream_encryption -def stream_encryption_test(backend, cipher_factory, params, only_if, - skip_message): - if not only_if(backend): - pytest.skip(skip_message) - plaintext = params.pop("plaintext") - ciphertext = params.pop("ciphertext") - offset = params.pop("offset") - cipher = Cipher(cipher_factory(**params), None, backend=backend) - encryptor = cipher.encryptor() - # throw away offset bytes - encryptor.update(b"\x00" * int(offset)) - actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext)) - actual_ciphertext += encryptor.finalize() - assert actual_ciphertext == binascii.unhexlify(ciphertext) - decryptor = cipher.decryptor() - decryptor.update(b"\x00" * int(offset)) - actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext)) - actual_plaintext += decryptor.finalize() - assert actual_plaintext == binascii.unhexlify(plaintext) - - -def generate_hash_test(param_loader, path, file_names, hash_cls, - only_if=None, skip_message=None): - def test_hash(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - hash_test, - backend, - hash_cls, - params, - only_if, - skip_message - ) - return test_hash +def generate_hash_test(param_loader, path, file_names, algorithm, only_if=None, + skip_message=None): + all_params = _load_all_params(path, file_names, param_loader) + @pytest.mark.parametrize("params", all_params) + def test_hash(self, backend, params): + if only_if is not None and not only_if(backend): + pytest.skip(skip_message) + msg, md = params + m = hashes.Hash(algorithm, backend=backend) + m.update(binascii.unhexlify(msg)) + expected_md = md.replace(" ", "").lower().encode("ascii") + assert m.finalize() == binascii.unhexlify(expected_md) -def hash_test(backend, algorithm, params, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) - msg = params[0] - md = params[1] - m = hashes.Hash(algorithm, backend=backend) - m.update(binascii.unhexlify(msg)) - expected_md = md.replace(" ", "").lower().encode("ascii") - assert m.finalize() == binascii.unhexlify(expected_md) - - -def generate_base_hash_test(algorithm, digest_size, block_size, - only_if=None, skip_message=None): - def test_base_hash(self): - for backend in _ALL_BACKENDS: - yield ( - base_hash_test, - backend, - algorithm, - digest_size, - block_size, - only_if, - skip_message, - ) - return test_base_hash + return test_hash -def base_hash_test(backend, algorithm, digest_size, block_size, only_if, - skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) +def generate_base_hash_test(algorithm, digest_size, block_size, only_if=None, + skip_message=None): + def test_base_hash(self, backend): + if only_if is not None and not only_if(backend): + pytest.skip(skip_message) - m = hashes.Hash(algorithm, backend=backend) - assert m.algorithm.digest_size == digest_size - assert m.algorithm.block_size == block_size - m_copy = m.copy() - assert m != m_copy - assert m._ctx != m_copy._ctx + m = hashes.Hash(algorithm, backend=backend) + assert m.algorithm.digest_size == digest_size + assert m.algorithm.block_size == block_size + m_copy = m.copy() + assert m != m_copy + assert m._ctx != m_copy._ctx - m.update(b"abc") - copy = m.copy() - copy.update(b"123") - m.update(b"123") - assert copy.finalize() == m.finalize() + m.update(b"abc") + copy = m.copy() + copy.update(b"123") + m.update(b"123") + assert copy.finalize() == m.finalize() + + return test_base_hash -def generate_long_string_hash_test(hash_factory, md, only_if=None, +def generate_long_string_hash_test(algorithm, md, only_if=None, skip_message=None): - def test_long_string_hash(self): - for backend in _ALL_BACKENDS: - yield( - long_string_hash_test, - backend, - hash_factory, - md, - only_if, - skip_message - ) + def test_long_string_hash(self, backend): + if only_if is not None and not only_if(backend): + pytest.skip(skip_message) + m = hashes.Hash(algorithm, backend=backend) + m.update(b"a" * 1000000) + assert m.finalize() == binascii.unhexlify(md.lower().encode("ascii")) + return test_long_string_hash -def long_string_hash_test(backend, algorithm, md, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) - m = hashes.Hash(algorithm, backend=backend) - m.update(b"a" * 1000000) - assert m.finalize() == binascii.unhexlify(md.lower().encode("ascii")) - - -def generate_hmac_test(param_loader, path, file_names, algorithm, - only_if=None, skip_message=None): - def test_hmac(self): - for backend in _ALL_BACKENDS: - for file_name in file_names: - for params in load_vectors_from_file( - os.path.join(path, file_name), - param_loader - ): - yield ( - hmac_test, - backend, - algorithm, - params, - only_if, - skip_message - ) +def generate_hmac_test(param_loader, path, file_names, algorithm, only_if=None, + skip_message=None): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_hmac(self, backend, params): + if only_if is not None and not only_if(backend): + pytest.skip(skip_message) + msg, md, key = params + h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) + h.update(binascii.unhexlify(msg)) + assert h.finalize() == binascii.unhexlify(md.encode("ascii")) + return test_hmac -def hmac_test(backend, algorithm, params, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) - msg = params[0] - md = params[1] - key = params[2] - h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) - h.update(binascii.unhexlify(msg)) - assert h.finalize() == binascii.unhexlify(md.encode("ascii")) - - -def generate_base_hmac_test(hash_cls, only_if=None, skip_message=None): - def test_base_hmac(self): - for backend in _ALL_BACKENDS: - yield ( - base_hmac_test, - backend, - hash_cls, - only_if, - skip_message, - ) +def generate_base_hmac_test(algorithm, only_if=None, skip_message=None): + def test_base_hmac(self, backend): + if only_if is not None and not only_if(backend): + pytest.skip(skip_message) + key = b"ab" + h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) + h_copy = h.copy() + assert h != h_copy + assert h._ctx != h_copy._ctx + return test_base_hmac -def base_hmac_test(backend, algorithm, only_if, skip_message): - if only_if is not None and not only_if(backend): - pytest.skip(skip_message) - key = b"ab" - h = hmac.HMAC(binascii.unhexlify(key), algorithm, backend=backend) - h_copy = h.copy() - assert h != h_copy - assert h._ctx != h_copy._ctx - - -def generate_aead_exception_test(cipher_factory, mode_factory, - only_if, skip_message): - def test_aead_exception(self): - for backend in _ALL_BACKENDS: - yield ( - aead_exception_test, - backend, - cipher_factory, - mode_factory, - only_if, - skip_message - ) - return test_aead_exception +def generate_aead_exception_test(cipher_factory, mode_factory, only_if, + skip_message): + def test_aead_exception(self, backend): + if not only_if(backend): + pytest.skip(skip_message) + cipher = Cipher( + cipher_factory(binascii.unhexlify(b"0" * 32)), + mode_factory(binascii.unhexlify(b"0" * 24)), + backend + ) + encryptor = cipher.encryptor() + encryptor.update(b"a" * 16) + with pytest.raises(NotYetFinalized): + encryptor.tag + with pytest.raises(AlreadyUpdated): + encryptor.authenticate_additional_data(b"b" * 16) + encryptor.finalize() + with pytest.raises(AlreadyFinalized): + encryptor.authenticate_additional_data(b"b" * 16) + with pytest.raises(AlreadyFinalized): + encryptor.update(b"b" * 16) + with pytest.raises(AlreadyFinalized): + encryptor.finalize() + cipher = Cipher( + cipher_factory(binascii.unhexlify(b"0" * 32)), + mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), + backend + ) + decryptor = cipher.decryptor() + decryptor.update(b"a" * 16) + with pytest.raises(AttributeError): + decryptor.tag + return test_aead_exception -def aead_exception_test(backend, cipher_factory, mode_factory, - only_if, skip_message): - if not only_if(backend): - pytest.skip(skip_message) - cipher = Cipher( - cipher_factory(binascii.unhexlify(b"0" * 32)), - mode_factory(binascii.unhexlify(b"0" * 24)), - backend - ) - encryptor = cipher.encryptor() - encryptor.update(b"a" * 16) - with pytest.raises(NotYetFinalized): - encryptor.tag - with pytest.raises(AlreadyUpdated): - encryptor.authenticate_additional_data(b"b" * 16) - encryptor.finalize() - with pytest.raises(AlreadyFinalized): - encryptor.authenticate_additional_data(b"b" * 16) - with pytest.raises(AlreadyFinalized): - encryptor.update(b"b" * 16) - with pytest.raises(AlreadyFinalized): - encryptor.finalize() - cipher = Cipher( - cipher_factory(binascii.unhexlify(b"0" * 32)), - mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), - backend - ) - decryptor = cipher.decryptor() - decryptor.update(b"a" * 16) - with pytest.raises(AttributeError): - decryptor.tag - - -def generate_aead_tag_exception_test(cipher_factory, mode_factory, - only_if, skip_message): - def test_aead_tag_exception(self): - for backend in _ALL_BACKENDS: - yield ( - aead_tag_exception_test, - backend, - cipher_factory, - mode_factory, - only_if, - skip_message - ) - return test_aead_tag_exception +def generate_aead_tag_exception_test(cipher_factory, mode_factory, only_if, + skip_message): + def test_aead_tag_exception(self, backend): + if not only_if(backend): + pytest.skip(skip_message) + cipher = Cipher( + cipher_factory(binascii.unhexlify(b"0" * 32)), + mode_factory(binascii.unhexlify(b"0" * 24)), + backend + ) + with pytest.raises(ValueError): + cipher.decryptor() + cipher = Cipher( + cipher_factory(binascii.unhexlify(b"0" * 32)), + mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), + backend + ) + with pytest.raises(ValueError): + cipher.encryptor() -def aead_tag_exception_test(backend, cipher_factory, mode_factory, - only_if, skip_message): - if not only_if(backend): - pytest.skip(skip_message) - cipher = Cipher( - cipher_factory(binascii.unhexlify(b"0" * 32)), - mode_factory(binascii.unhexlify(b"0" * 24)), - backend - ) - with pytest.raises(ValueError): - cipher.decryptor() - cipher = Cipher( - cipher_factory(binascii.unhexlify(b"0" * 32)), - mode_factory(binascii.unhexlify(b"0" * 24), b"0" * 16), - backend - ) - with pytest.raises(ValueError): - cipher.encryptor() + return test_aead_tag_exception -- cgit v1.2.3