From 363f5b099ba6e9e551d20e39242bd657bd0703f1 Mon Sep 17 00:00:00 2001 From: Paul Kehrer Date: Tue, 12 Sep 2017 08:03:17 +0800 Subject: refactor AES keywrap into a wrap core and unwrap core (#3901) * refactor AES keywrap into a wrap core and unwrap core This refactor makes adding AES keywrap with padding much simpler. * remove an unneeded arg --- src/cryptography/hazmat/primitives/keywrap.py | 53 +++++++++++++++------------ 1 file changed, 30 insertions(+), 23 deletions(-) (limited to 'src') diff --git a/src/cryptography/hazmat/primitives/keywrap.py b/src/cryptography/hazmat/primitives/keywrap.py index 6e79ab6b..702a6932 100644 --- a/src/cryptography/hazmat/primitives/keywrap.py +++ b/src/cryptography/hazmat/primitives/keywrap.py @@ -12,20 +12,9 @@ from cryptography.hazmat.primitives.ciphers.modes import ECB from cryptography.hazmat.primitives.constant_time import bytes_eq -def aes_key_wrap(wrapping_key, key_to_wrap, backend): - if len(wrapping_key) not in [16, 24, 32]: - raise ValueError("The wrapping key must be a valid AES key length") - - if len(key_to_wrap) < 16: - raise ValueError("The key to wrap must be at least 16 bytes") - - if len(key_to_wrap) % 8 != 0: - raise ValueError("The key to wrap must be a multiple of 8 bytes") - +def _wrap_core(wrapping_key, a, r, backend): # RFC 3394 Key Wrap - 2.2.1 (index method) encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor() - a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" - r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)] n = len(r) for j in range(6): for i in range(n): @@ -44,22 +33,24 @@ def aes_key_wrap(wrapping_key, key_to_wrap, backend): return a + b"".join(r) -def aes_key_unwrap(wrapping_key, wrapped_key, backend): - if len(wrapped_key) < 24: - raise ValueError("Must be at least 24 bytes") - - if len(wrapped_key) % 8 != 0: - raise ValueError("The wrapped key must be a multiple of 8 bytes") - +def aes_key_wrap(wrapping_key, key_to_wrap, backend): if len(wrapping_key) not in [16, 24, 32]: raise ValueError("The wrapping key must be a valid AES key length") + if len(key_to_wrap) < 16: + raise ValueError("The key to wrap must be at least 16 bytes") + + if len(key_to_wrap) % 8 != 0: + raise ValueError("The key to wrap must be a multiple of 8 bytes") + + a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" + r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)] + return _wrap_core(wrapping_key, a, r, backend) + + +def _unwrap_core(wrapping_key, a, r, backend): # Implement RFC 3394 Key Unwrap - 2.2.2 (index method) decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor() - aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" - - r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)] - a = r.pop(0) n = len(r) for j in reversed(range(6)): for i in reversed(range(n)): @@ -74,7 +65,23 @@ def aes_key_unwrap(wrapping_key, wrapped_key, backend): r[i] = b[-8:] assert decryptor.finalize() == b"" + return a, r + + +def aes_key_unwrap(wrapping_key, wrapped_key, backend): + if len(wrapped_key) < 24: + raise ValueError("Must be at least 24 bytes") + + if len(wrapped_key) % 8 != 0: + raise ValueError("The wrapped key must be a multiple of 8 bytes") + if len(wrapping_key) not in [16, 24, 32]: + raise ValueError("The wrapping key must be a valid AES key length") + + aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6" + r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)] + a = r.pop(0) + a, r = _unwrap_core(wrapping_key, a, r, backend) if not bytes_eq(a, aiv): raise InvalidUnwrap() -- cgit v1.2.3