aboutsummaryrefslogtreecommitdiffstats
path: root/src/cryptography/hazmat/primitives/serialization/ssh.py
blob: a1d6c8c9fcc2690c4908c461321f17c11916ee36 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
# This file is dual licensed under the terms of the Apache License, Version
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
# for complete details.

from __future__ import absolute_import, division, print_function

import base64
import struct

import six

from cryptography import utils
from cryptography.exceptions import UnsupportedAlgorithm
from cryptography.hazmat.primitives.asymmetric import dsa, ec, ed25519, rsa


def load_ssh_public_key(data, backend):
    """
    Load a public key from SSH ``<key type> <base64 body> [anything]`` bytes.

    The data is split on single spaces into at most three fields. The first
    selects the loader, the second is base64 decoded and parsed, and a third
    field, if present, is ignored.

    :param data: The serialized public key as bytes.
    :param backend: Passed through unchanged to the key type specific loader.
    :returns: Whatever the loader selected for the key type returns.
    :raises ValueError: If there are fewer than two fields, the body is not
        valid base64, the body is truncated, or the key type stored inside
        the body differs from the one in the first field.
    :raises UnsupportedAlgorithm: If the key type is not one handled here.
    """
    key_parts = data.split(b' ', 2)

    if len(key_parts) < 2:
        raise ValueError(
            'Key is not in the proper format or contains extra data.')

    key_type = key_parts[0]

    if key_type == b'ssh-rsa':
        loader = _load_ssh_rsa_public_key
    elif key_type == b'ssh-dss':
        loader = _load_ssh_dss_public_key
    elif key_type in [
        b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521',
    ]:
        loader = _load_ssh_ecdsa_public_key
    elif key_type == b'ssh-ed25519':
        loader = _load_ssh_ed25519_public_key
    else:
        raise UnsupportedAlgorithm('Key type is not supported.')

    key_body = key_parts[1]

    try:
        decoded_data = base64.b64decode(key_body)
    except (TypeError, ValueError):
        # Python 2 reports malformed base64 as TypeError, while Python 3
        # raises binascii.Error, which is a ValueError subclass. Catch both
        # so callers always get the same ValueError and message.
        raise ValueError('Key is not in the proper format.')

    # The decoded body repeats the key type as its first length-prefixed
    # string; it must agree with the plain-text header.
    inner_key_type, rest = _ssh_read_next_string(decoded_data)

    if inner_key_type != key_type:
        raise ValueError(
            'Key header and key body contain different key type values.'
        )

    return loader(key_type, rest, backend)


def _load_ssh_rsa_public_key(key_type, decoded_data, backend):
    """
    Build an RSA public key from the bytes that follow the key type string.

    Two mpints are read in order, the public exponent and then the modulus.
    Any bytes left after them make the key invalid.

    :raises ValueError: If the body is truncated or has trailing bytes.
    """
    public_exponent, remainder = _ssh_read_next_mpint(decoded_data)
    modulus, remainder = _ssh_read_next_mpint(remainder)

    if remainder:
        raise ValueError('Key body contains extra bytes.')

    public_numbers = rsa.RSAPublicNumbers(public_exponent, modulus)
    return public_numbers.public_key(backend)


def _load_ssh_dss_public_key(key_type, decoded_data, backend):
    """
    Build a DSA public key from the bytes that follow the key type string.

    Four mpints are read in order: p, q, g and then the public value y. Any
    bytes left after them make the key invalid.

    :raises ValueError: If the body is truncated or has trailing bytes.
    """
    values = []
    remainder = decoded_data
    for _ in range(4):
        value, remainder = _ssh_read_next_mpint(remainder)
        values.append(value)

    if remainder:
        raise ValueError('Key body contains extra bytes.')

    p, q, g, y = values
    return dsa.DSAPublicNumbers(
        y, dsa.DSAParameterNumbers(p, q, g)
    ).public_key(backend)


def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend):
    """
    Build an EC public key from the bytes that follow the key type string.

    The body holds two length-prefixed strings: the curve name (for example
    ``nistp256``) and the encoded curve point.

    :param expected_key_type: The key type from the header; it must equal
        ``b"ecdsa-sha2-"`` followed by the curve name found in the body.
    :param decoded_data: The decoded key body after the key type string.
    :param backend: Unused by this loader.
    :raises ValueError: If the body is truncated, the curve name does not
        match the key type, there are trailing bytes, or the point is empty.
    :raises NotImplementedError: If the point does not start with the byte
        ``0x04``, i.e. it is not in uncompressed form.
    """
    curve_name, rest = _ssh_read_next_string(decoded_data)
    data, rest = _ssh_read_next_string(rest)

    if expected_key_type != b"ecdsa-sha2-" + curve_name:
        raise ValueError(
            'Key header and key body contain different key type values.'
        )

    if rest:
        raise ValueError('Key body contains extra bytes.')

    curve = {
        b"nistp256": ec.SECP256R1,
        b"nistp384": ec.SECP384R1,
        b"nistp521": ec.SECP521R1,
    }[curve_name]()

    # A zero-length point string would otherwise make the index lookup below
    # fail with IndexError; report it like any other malformed key instead.
    if not data:
        raise ValueError('Key is not in the proper format')

    if six.indexbytes(data, 0) != 4:
        raise NotImplementedError(
            "Compressed elliptic curve points are not supported"
        )

    return ec.EllipticCurvePublicKey.from_encoded_point(curve, data)


def _load_ssh_ed25519_public_key(expected_key_type, decoded_data, backend):
    """
    Build an Ed25519 public key from the bytes after the key type string.

    The body must consist of exactly one length-prefixed string, which is
    handed to ``Ed25519PublicKey.from_public_bytes`` as the raw key.

    :raises ValueError: If the body is truncated or has trailing bytes.
    """
    raw_key, trailing = _ssh_read_next_string(decoded_data)

    if trailing:
        raise ValueError('Key body contains extra bytes.')

    public_key_cls = ed25519.Ed25519PublicKey
    return public_key_cls.from_public_bytes(raw_key)


def _ssh_read_next_string(data):
    """
    Retrieves the next RFC 4251 string value from the data.

    While the RFC calls these strings, in Python they are bytes objects.
    """
    if len(data) < 4:
        raise ValueError("Key is not in the proper format")

    str_len, = struct.unpack('>I', data[:4])
    if len(data) < str_len + 4:
        raise ValueError("Key is not in the proper format")

    return data[4:4 + str_len], data[4 + str_len:]


def _ssh_read_next_mpint(data):
    """
    Split the leading mpint off ``data`` and convert it to an integer.

    The mpint is read as a length-prefixed string whose bytes are decoded as
    a big-endian integer. Currently, every mpint is treated as unsigned.

    :returns: An ``(integer, remaining_data)`` tuple.
    """
    raw_value, remaining = _ssh_read_next_string(data)
    number = utils.int_from_bytes(raw_value, byteorder='big', signed=False)
    return number, remaining


def _ssh_write_string(data):
    return struct.pack(">I", len(data)) + data


def _ssh_write_mpint(value):
    """
    Encode the integer ``value`` as a length-prefixed mpint.

    The integer is converted to bytes with ``utils.int_to_bytes``. When the
    high bit of the first byte is set, a zero byte is put in front before the
    result is written as a length-prefixed string.
    """
    encoded = utils.int_to_bytes(value)
    high_bit_set = six.indexbytes(encoded, 0) & 0x80
    if high_bit_set:
        # RFC 4251 mpints are two's complement, so a leading byte with the
        # top bit set needs a zero byte ahead of it to stay non-negative.
        encoded = b"\x00" + encoded
    return _ssh_write_string(encoded)