aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--src/cryptography/hazmat/primitives/keywrap.py53
1 files changed, 30 insertions, 23 deletions
diff --git a/src/cryptography/hazmat/primitives/keywrap.py b/src/cryptography/hazmat/primitives/keywrap.py
index 6e79ab6b..702a6932 100644
--- a/src/cryptography/hazmat/primitives/keywrap.py
+++ b/src/cryptography/hazmat/primitives/keywrap.py
@@ -12,20 +12,9 @@ from cryptography.hazmat.primitives.ciphers.modes import ECB
from cryptography.hazmat.primitives.constant_time import bytes_eq
-def aes_key_wrap(wrapping_key, key_to_wrap, backend):
- if len(wrapping_key) not in [16, 24, 32]:
- raise ValueError("The wrapping key must be a valid AES key length")
-
- if len(key_to_wrap) < 16:
- raise ValueError("The key to wrap must be at least 16 bytes")
-
- if len(key_to_wrap) % 8 != 0:
- raise ValueError("The key to wrap must be a multiple of 8 bytes")
-
+def _wrap_core(wrapping_key, a, r, backend):
# RFC 3394 Key Wrap - 2.2.1 (index method)
encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor()
- a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
- r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)]
n = len(r)
for j in range(6):
for i in range(n):
@@ -44,22 +33,24 @@ def aes_key_wrap(wrapping_key, key_to_wrap, backend):
return a + b"".join(r)
-def aes_key_unwrap(wrapping_key, wrapped_key, backend):
- if len(wrapped_key) < 24:
- raise ValueError("Must be at least 24 bytes")
-
- if len(wrapped_key) % 8 != 0:
- raise ValueError("The wrapped key must be a multiple of 8 bytes")
-
+def aes_key_wrap(wrapping_key, key_to_wrap, backend):
if len(wrapping_key) not in [16, 24, 32]:
raise ValueError("The wrapping key must be a valid AES key length")
+ if len(key_to_wrap) < 16:
+ raise ValueError("The key to wrap must be at least 16 bytes")
+
+ if len(key_to_wrap) % 8 != 0:
+ raise ValueError("The key to wrap must be a multiple of 8 bytes")
+
+ a = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
+ r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)]
+ return _wrap_core(wrapping_key, a, r, backend)
+
+
+def _unwrap_core(wrapping_key, a, r, backend):
# Implement RFC 3394 Key Unwrap - 2.2.2 (index method)
decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor()
- aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
-
- r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)]
- a = r.pop(0)
n = len(r)
for j in reversed(range(6)):
for i in reversed(range(n)):
@@ -74,7 +65,23 @@ def aes_key_unwrap(wrapping_key, wrapped_key, backend):
r[i] = b[-8:]
assert decryptor.finalize() == b""
+ return a, r
+
+
+def aes_key_unwrap(wrapping_key, wrapped_key, backend):
+ if len(wrapped_key) < 24:
+ raise ValueError("Must be at least 24 bytes")
+
+ if len(wrapped_key) % 8 != 0:
+ raise ValueError("The wrapped key must be a multiple of 8 bytes")
+ if len(wrapping_key) not in [16, 24, 32]:
+ raise ValueError("The wrapping key must be a valid AES key length")
+
+ aiv = b"\xa6\xa6\xa6\xa6\xa6\xa6\xa6\xa6"
+ r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)]
+ a = r.pop(0)
+ a, r = _unwrap_core(wrapping_key, a, r, backend)
if not bytes_eq(a, aiv):
raise InvalidUnwrap()