aboutsummaryrefslogtreecommitdiffstats
path: root/tests/hazmat/primitives/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/hazmat/primitives/utils.py')
-rw-r--r--tests/hazmat/primitives/utils.py105
1 files changed, 103 insertions, 2 deletions
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py
index 6c67ddb3..839ff822 100644
--- a/tests/hazmat/primitives/utils.py
+++ b/tests/hazmat/primitives/utils.py
@@ -4,9 +4,11 @@ import os
import pytest
from cryptography.hazmat.bindings import _ALL_BACKENDS
-from cryptography.hazmat.primitives import hashes
-from cryptography.hazmat.primitives import hmac
+from cryptography.hazmat.primitives import hashes, hmac
from cryptography.hazmat.primitives.ciphers import Cipher
+from cryptography.exceptions import (
+ AlreadyFinalized, NotFinalized,
+)
from ...utils import load_vectors_from_file
@@ -54,6 +56,72 @@ def encrypt_test(backend, cipher_factory, mode_factory, params, only_if,
assert actual_plaintext == binascii.unhexlify(plaintext)
+def generate_aead_test(param_loader, path, file_names, cipher_factory,
+ mode_factory, only_if, skip_message):
+ def test_aead(self):
+ for backend in _ALL_BACKENDS:
+ for file_name in file_names:
+ for params in load_vectors_from_file(
+ os.path.join(path, file_name),
+ param_loader
+ ):
+ yield (
+ aead_test,
+ backend,
+ cipher_factory,
+ mode_factory,
+ params,
+ only_if,
+ skip_message
+ )
+ return test_aead
+
+
+def aead_test(backend, cipher_factory, mode_factory, params, only_if,
+ skip_message):
+ if not only_if(backend):
+ pytest.skip(skip_message)
+ if params.get("pt") is not None:
+ plaintext = params.pop("pt")
+ ciphertext = params.pop("ct")
+ aad = params.pop("aad")
+ if params.get("fail") is True:
+ cipher = Cipher(
+ cipher_factory(binascii.unhexlify(params["key"])),
+ mode_factory(binascii.unhexlify(params["iv"]),
+ binascii.unhexlify(params["tag"])),
+ backend
+ )
+ decryptor = cipher.decryptor()
+ decryptor.add_data(binascii.unhexlify(aad))
+ actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext))
+ with pytest.raises(AssertionError):
+ decryptor.finalize()
+ else:
+ cipher = Cipher(
+ cipher_factory(binascii.unhexlify(params["key"])),
+ mode_factory(binascii.unhexlify(params["iv"]), None),
+ backend
+ )
+ encryptor = cipher.encryptor()
+ encryptor.add_data(binascii.unhexlify(aad))
+ actual_ciphertext = encryptor.update(binascii.unhexlify(plaintext))
+ actual_ciphertext += encryptor.finalize()
+ tag_len = len(params["tag"])
+ assert binascii.hexlify(encryptor.tag)[:tag_len] == params["tag"]
+ cipher = Cipher(
+ cipher_factory(binascii.unhexlify(params["key"])),
+ mode_factory(binascii.unhexlify(params["iv"]),
+ binascii.unhexlify(params["tag"])),
+ backend
+ )
+ decryptor = cipher.decryptor()
+ decryptor.add_data(binascii.unhexlify(aad))
+ actual_plaintext = decryptor.update(binascii.unhexlify(ciphertext))
+ actual_plaintext += decryptor.finalize()
+ assert actual_plaintext == binascii.unhexlify(plaintext)
+
+
def generate_stream_encryption_test(param_loader, path, file_names,
cipher_factory, only_if=None,
skip_message=None):
@@ -237,3 +305,36 @@ def base_hmac_test(backend, algorithm, only_if, skip_message):
h_copy = h.copy()
assert h != h_copy
assert h._ctx != h_copy._ctx
+
+
+def generate_aead_use_after_finalize_test(cipher_factory, mode_factory,
+ only_if, skip_message):
+ def test_aead_use_after_finalize(self):
+ for backend in _ALL_BACKENDS:
+ yield (
+ aead_use_after_finalize_test,
+ backend,
+ cipher_factory,
+ mode_factory,
+ only_if,
+ skip_message
+ )
+ return test_aead_use_after_finalize
+
+
+def aead_use_after_finalize_test(backend, cipher_factory, mode_factory,
+ only_if, skip_message):
+ if not only_if(backend):
+ pytest.skip(skip_message)
+ cipher = Cipher(
+ cipher_factory(binascii.unhexlify(b"0" * 32)),
+ mode_factory(binascii.unhexlify(b"0" * 24)),
+ backend
+ )
+ encryptor = cipher.encryptor()
+ encryptor.update(b"a" * 16)
+ with pytest.raises(NotFinalized):
+ encryptor.tag
+ encryptor.finalize()
+ with pytest.raises(AlreadyFinalized):
+ encryptor.add_data(b"b" * 16)