diff options
Diffstat (limited to 'tests/hazmat/primitives/utils.py')
| -rw-r--r-- | tests/hazmat/primitives/utils.py | 50 | 
1 files changed, 41 insertions, 9 deletions
diff --git a/tests/hazmat/primitives/utils.py b/tests/hazmat/primitives/utils.py index e546fa79..963838eb 100644 --- a/tests/hazmat/primitives/utils.py +++ b/tests/hazmat/primitives/utils.py @@ -1,12 +1,16 @@  import binascii  import os +import itertools +  import pytest  from cryptography.hazmat.primitives import hashes, hmac  from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC  from cryptography.hazmat.primitives.ciphers import Cipher -from cryptography.hazmat.primitives.kdf.hkdf import hkdf_derive +from cryptography.hazmat.primitives.kdf.hkdf import ( +    hkdf_derive, hkdf_extract, hkdf_expand +)  from cryptography.exceptions import (      AlreadyFinalized, NotYetFinalized, AlreadyUpdated, InvalidTag, @@ -301,12 +305,8 @@ def aead_tag_exception_test(backend, cipher_factory, mode_factory):          cipher.encryptor() -def hkdf_test(backend, algorithm, params): -    ikm = params[0] -    salt = params[1] -    info = params[2] -    length = params[3] -    expected_okm = params[4] +def hkdf_derive_test(backend, algorithm, params): +    ikm, salt, info, length, prk, expected_okm = params      okm = hkdf_derive(          binascii.unhexlify(ikm), @@ -320,11 +320,43 @@ def hkdf_test(backend, algorithm, params):      assert binascii.hexlify(okm) == expected_okm +def hkdf_extract_test(backend, algorithm, params): +    ikm, salt, info, length, expected_prk, okm = params + +    prk = hkdf_extract( +        algorithm, +        binascii.unhexlify(ikm), +        binascii.unhexlify(salt), +        backend=backend +    ) + +    assert prk == binascii.unhexlify(expected_prk) + + +def hkdf_expand_test(backend, algorithm, params): +    ikm, salt, info, length, prk, expected_okm = params + +    okm = hkdf_expand( +        algorithm, +        binascii.unhexlify(prk), +        binascii.unhexlify(info), +        length, +        backend=backend +    ) + +    assert okm == binascii.unhexlify(expected_okm) + +  def generate_hkdf_test(param_loader, path, file_names, algorithm):      all_params = _load_all_params(path, file_names, param_loader) -    @pytest.mark.parametrize("params", all_params) -    def test_hkdf(self, backend, params): +    all_tests = [hkdf_extract_test, hkdf_expand_test, hkdf_derive_test] + +    @pytest.mark.parametrize( +        ("params", "hkdf_test"), +        itertools.product(all_params, all_tests) +    ) +    def test_hkdf(self, backend, params, hkdf_test):          hkdf_test(backend, algorithm, params)      return test_hkdf  | 
