aboutsummaryrefslogtreecommitdiffstats
path: root/src/cryptography/hazmat/primitives/keywrap.py
blob: f55c519cff3305fde940e23848262823373ee608 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

from __future__ import absolute_import, division, print_function

import struct

from cryptography.hazmat.primitives.ciphers import Cipher
from cryptography.hazmat.primitives.ciphers.algorithms import AES
from cryptography.hazmat.primitives.ciphers.modes import ECB
from cryptography.hazmat.primitives.constant_time import bytes_eq


def _wrap_core(wrapping_key, a, r, backend):
    """Run the RFC 3394 wrapping rounds (section 2.2.1, index method).

    :param wrapping_key: AES key used for every block encryption.
    :param a: The 8-byte initial value (integrity check register).
    :param r: List of 8-byte plaintext blocks; it is overwritten in place
        with the ciphertext blocks.
    :param backend: Passed through to ``Cipher``.
    :return: ``A || R[1] || ... || R[n]`` after six rounds.
    """
    enc = Cipher(AES(wrapping_key), ECB(), backend).encryptor()
    block_count = len(r)
    # step equals t = (n * j) + i + 1 from the RFC; it increases by one per
    # block encryption across all six rounds.
    step = 0
    for _round in range(6):
        for idx, chunk in enumerate(r):
            step += 1
            # Each call feeds exactly one 16-byte AES block, so reusing one
            # ECB encryptor for the whole run never leaves buffered data.
            out = enc.update(a + chunk)
            # Both halves are 64 bits wide, so pack/unpack with ">Q" is safe.
            (msb,) = struct.unpack(">Q", out[:8])
            a = struct.pack(">Q", msb ^ step)
            r[idx] = out[-8:]

    assert enc.finalize() == b""

    return a + b"".join(r)


def aes_key_wrap(wrapping_key, key_to_wrap, backend):
    """Wrap ``key_to_wrap`` under ``wrapping_key`` using RFC 3394.

    :param wrapping_key: AES key-encryption key (16, 24 or 32 bytes).
    :param key_to_wrap: Key material; at least 16 bytes and a multiple of 8.
    :param backend: Passed through to ``Cipher``.
    :return: The wrapped key (8 bytes longer than ``key_to_wrap``).
    :raises ValueError: On an invalid wrapping key or key-to-wrap length.
    """
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    size = len(key_to_wrap)
    if size < 16:
        raise ValueError("The key to wrap must be at least 16 bytes")
    if size % 8:
        raise ValueError("The key to wrap must be a multiple of 8 bytes")

    # RFC 3394 section 2.2.3.1 default initial value: eight 0xA6 octets.
    default_iv = b"\xa6" * 8
    blocks = [key_to_wrap[pos:pos + 8] for pos in range(0, size, 8)]
    return _wrap_core(wrapping_key, default_iv, blocks, backend)


def _unwrap_core(wrapping_key, a, r, backend):
    """Run the RFC 3394 unwrapping rounds (section 2.2.2, index method).

    :param wrapping_key: AES key used for every block decryption.
    :param a: The first 8-byte block of the wrapped data.
    :param r: List of the remaining 8-byte wrapped blocks; it is overwritten
        in place with the recovered blocks.
    :param backend: Passed through to ``Cipher``.
    :return: Tuple ``(A, R)`` of the recovered integrity value and block list.
    """
    dec = Cipher(AES(wrapping_key), ECB(), backend).decryptor()
    block_count = len(r)
    # step equals t = (n * j) + i + 1 from the RFC.  Walking j and i
    # backwards makes it run from 6 * n down to 1, one per decryption.
    step = 6 * block_count
    for _round in range(6):
        for idx in range(block_count - 1, -1, -1):
            # Both halves are 64 bits wide, so pack/unpack with ">Q" is safe.
            (msb,) = struct.unpack(">Q", a)
            # Each call feeds one 16-byte block, so a single ECB decryptor
            # can be reused for the whole run.
            out = dec.update(struct.pack(">Q", msb ^ step) + r[idx])
            step -= 1
            a = out[:8]
            r[idx] = out[-8:]

    assert dec.finalize() == b""
    return a, r


def aes_key_wrap_with_padding(wrapping_key, key_to_wrap, backend):
    """Wrap ``key_to_wrap`` using RFC 5649 (AES Key Wrap with Padding).

    :param wrapping_key: AES key-encryption key; must be 16, 24 or 32 bytes.
    :param key_to_wrap: Key material to wrap; 1 to 2**32 - 1 bytes long.
    :param backend: Passed through to ``Cipher``.
    :return: The wrapped key: 8 bytes longer than ``key_to_wrap`` after it
        has been zero-padded to a multiple of 8 bytes.
    :raises ValueError: If ``wrapping_key`` is not a valid AES key length, or
        ``key_to_wrap`` is empty or too long for the 32-bit MLI field.
    """
    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    if len(key_to_wrap) == 0:
        # An empty key would give _wrap_core zero blocks, so it would return
        # the unencrypted AIV as the "wrapped" key.
        raise ValueError("The key to wrap must be at least 1 byte")

    if len(key_to_wrap) > 0xFFFFFFFF:
        raise ValueError("The key to wrap must be less than 2**32 bytes")

    # Alternative Initial Value (RFC 5649 section 3): the constant A65959A6
    # followed by the Message Length Indicator as a 32-bit *unsigned*
    # big-endian integer (">I", the same format the unwrap side parses).
    aiv = b"\xA6\x59\x59\xA6" + struct.pack(">I", len(key_to_wrap))
    # Zero-pad the key up to the next multiple of 8 bytes.
    pad = (8 - (len(key_to_wrap) % 8)) % 8
    key_to_wrap = key_to_wrap + b"\x00" * pad
    if len(key_to_wrap) == 8:
        # RFC 5649 - 4.1 - exactly 8 octets after padding: a single AES-ECB
        # encryption of AIV || P is the whole output.
        encryptor = Cipher(AES(wrapping_key), ECB(), backend).encryptor()
        b = encryptor.update(aiv + key_to_wrap)
        assert encryptor.finalize() == b""
        return b
    else:
        r = [key_to_wrap[i:i + 8] for i in range(0, len(key_to_wrap), 8)]
        return _wrap_core(wrapping_key, aiv, r, backend)


def aes_key_unwrap_with_padding(wrapping_key, wrapped_key, backend):
    if len(wrapped_key) < 16:
        raise InvalidUnwrap("Must be at least 16 bytes")

    if len(wrapping_key) not in [16, 24, 32]:
        raise ValueError("The wrapping key must be a valid AES key length")

    if len(wrapped_key) == 16:
        # RFC 5649 - 4.2 - exactly two 64-bit blocks
        decryptor = Cipher(AES(wrapping_key), ECB(), backend).decryptor()
        b = decryptor.update(wrapped_key)
        assert decryptor.finalize() == b""
        a = b[:8]
        data = b[8:]
        n = 1
    else:
        r = [wrapped_key[i:i + 8] for i in range(0, len(wrapped_key), 8)]
        encrypted_aiv = r.pop(0)
        n = len(r)
        a, r = _unwrap_core(wrapping_key, encrypted_aiv, r, backend)
        data = b"".join(r)

    # 1) Check that MSB(32,A) = A65959A6.
    # 2) Check that 8*(n-1) < LSB(32,A) <= 8*n.  If so, let
    #    MLI = LSB(32,A).
    # 3) Let b = (8*n)-MLI, and then check that the rightmost b octets of
    #    the output data are zero.
    (mli,) = struct.unpack(">I", a[4:])
    b = (8 * n) - mli
    if (
        not bytes_eq(a[:4], b"\xa6\x59\x59\xa6") or not
        8 * (n - 1) < mli <= 8 * n or (
            b != 0 and not bytes_eq(data[-b:], b"\x00" * b)
        )
    ):
        raise InvalidUnwrap()

    if b == 0:
        return data
    else:
        return data[:-b]


def aes_key_unwrap(wrapping_key, wrapped_key, backend):
    """Unwrap a key that was wrapped with RFC 3394 (AES Key Wrap).

    :param wrapping_key: AES key-encryption key (16, 24 or 32 bytes).
    :param wrapped_key: Wrapped key; at least 24 bytes, a multiple of 8.
    :param backend: Passed through to ``Cipher``.
    :return: The unwrapped key bytes.
    :raises InvalidUnwrap: On a bad wrapped-key length or integrity failure.
    :raises ValueError: If ``wrapping_key`` is not a valid AES key length.
    """
    total = len(wrapped_key)
    if total < 24:
        raise InvalidUnwrap("Must be at least 24 bytes")
    if total % 8:
        raise InvalidUnwrap("The wrapped key must be a multiple of 8 bytes")
    if len(wrapping_key) not in (16, 24, 32):
        raise ValueError("The wrapping key must be a valid AES key length")

    chunks = [wrapped_key[off:off + 8] for off in range(0, total, 8)]
    recovered_iv, chunks = _unwrap_core(
        wrapping_key, chunks[0], chunks[1:], backend
    )
    # The recovered value must equal the RFC 3394 default IV (8 x 0xA6);
    # compare in constant time.
    if not bytes_eq(recovered_iv, b"\xa6" * 8):
        raise InvalidUnwrap()

    return b"".join(chunks)


class InvalidUnwrap(Exception):
    """Raised by the AES key-unwrap functions on a malformed or forged input.

    The unwrap functions raise it when the wrapped key has an unacceptable
    length or fails the RFC 3394 / RFC 5649 integrity checks.
    """