diff options
Diffstat (limited to 'tests/hazmat/primitives/utils.py')
-rw-r--r-- | tests/hazmat/primitives/utils.py | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index e148bc63..e45466d8 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -18,6 +18,9 @@ from cryptography.hazmat.primitives import hashes, hmac from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.kdf.hkdf import HKDF, HKDFExpand +from cryptography.hazmat.primitives.kdf.kbkdf import ( + CounterLocation, KBKDFHMAC, Mode +) from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC from ...utils import load_vectors_from_file @@ -370,6 +373,55 @@ def generate_hkdf_test(param_loader, path, file_names, algorithm): return test_hkdf +def generate_kbkdf_counter_mode_test(param_loader, path, file_names): + all_params = _load_all_params(path, file_names, param_loader) + + @pytest.mark.parametrize("params", all_params) + def test_kbkdf(self, backend, params): + kbkdf_counter_mode_test(backend, params) + return test_kbkdf + + +def kbkdf_counter_mode_test(backend, params): + supported_algorithms = { + 'hmac_sha1': hashes.SHA1, + 'hmac_sha224': hashes.SHA224, + 'hmac_sha256': hashes.SHA256, + 'hmac_sha384': hashes.SHA384, + 'hmac_sha512': hashes.SHA512, + } + + supportd_counter_locations = { + "before_fixed": CounterLocation.BeforeFixed, + "after_fixed": CounterLocation.AfterFixed, + } + + algorithm = supported_algorithms.get(params.get('prf')) + if algorithm is None or not backend.hmac_supported(algorithm()): + pytest.skip('Does not support algorithm') + + ctr_loc = supportd_counter_locations.get(params.get("ctrlocation")) + if ctr_loc is None or not isinstance(ctr_loc, CounterLocation): + pytest.skip("Does not support counter location".format( + location=params.get('ctrlocation') + )) + + ctrkdf = KBKDFHMAC( + algorithm(), + Mode.CounterMode, + params['l'] // 8, + params['rlen'] // 8, + None, + ctr_loc, + None, + None, + binascii.unhexlify(params['fixedinputdata']), + backend=backend) + + ko = ctrkdf.derive(binascii.unhexlify(params['ki'])) + assert binascii.hexlify(ko) == params["ko"] + + def generate_rsa_verification_test(param_loader, path, file_names, hash_alg, pad_factory): all_params = _load_all_params(path, file_names, param_loader) |