diff options
Diffstat (limited to 'src')
4 files changed, 158 insertions, 148 deletions
| diff --git a/src/cryptography/hazmat/backends/openssl/backend.py b/src/cryptography/hazmat/backends/openssl/backend.py index 5d73a7e8..44c2e3cd 100644 --- a/src/cryptography/hazmat/backends/openssl/backend.py +++ b/src/cryptography/hazmat/backends/openssl/backend.py @@ -75,9 +75,7 @@ from cryptography.hazmat.primitives.ciphers.modes import (      CBC, CFB, CFB8, CTR, ECB, GCM, OFB, XTS  )  from cryptography.hazmat.primitives.kdf import scrypt -from cryptography.hazmat.primitives.serialization.base import ( -    _ssh_write_mpint, _ssh_write_string -) +from cryptography.hazmat.primitives.serialization import ssh  from cryptography.x509 import ocsp @@ -1798,19 +1796,19 @@ class Backend(object):          if isinstance(key, rsa.RSAPublicKey):              public_numbers = key.public_numbers()              return b"ssh-rsa " + base64.b64encode( -                _ssh_write_string(b"ssh-rsa") + -                _ssh_write_mpint(public_numbers.e) + -                _ssh_write_mpint(public_numbers.n) +                ssh._ssh_write_string(b"ssh-rsa") + +                ssh._ssh_write_mpint(public_numbers.e) + +                ssh._ssh_write_mpint(public_numbers.n)              )          elif isinstance(key, dsa.DSAPublicKey):              public_numbers = key.public_numbers()              parameter_numbers = public_numbers.parameter_numbers              return b"ssh-dss " + base64.b64encode( -                _ssh_write_string(b"ssh-dss") + -                _ssh_write_mpint(parameter_numbers.p) + -                _ssh_write_mpint(parameter_numbers.q) + -                _ssh_write_mpint(parameter_numbers.g) + -                _ssh_write_mpint(public_numbers.y) +                ssh._ssh_write_string(b"ssh-dss") + +                ssh._ssh_write_mpint(parameter_numbers.p) + +                ssh._ssh_write_mpint(parameter_numbers.q) + +                ssh._ssh_write_mpint(parameter_numbers.g) + +                ssh._ssh_write_mpint(public_numbers.y)              )          else:              assert isinstance(key, ec.EllipticCurvePublicKey) @@ -1827,9 +1825,9 @@ class Backend(object):                      "supported by the SSH public key format"                  )              return b"ecdsa-sha2-" + curve_name + b" " + base64.b64encode( -                _ssh_write_string(b"ecdsa-sha2-" + curve_name) + -                _ssh_write_string(curve_name) + -                _ssh_write_string(public_numbers.encode_point()) +                ssh._ssh_write_string(b"ecdsa-sha2-" + curve_name) + +                ssh._ssh_write_string(curve_name) + +                ssh._ssh_write_string(public_numbers.encode_point())              )      def _parameter_bytes(self, encoding, format, cdata): diff --git a/src/cryptography/hazmat/primitives/serialization/__init__.py b/src/cryptography/hazmat/primitives/serialization/__init__.py index cff775b8..3a34bf8f 100644 --- a/src/cryptography/hazmat/primitives/serialization/__init__.py +++ b/src/cryptography/hazmat/primitives/serialization/__init__.py @@ -9,7 +9,9 @@ from cryptography.hazmat.primitives.serialization.base import (      NoEncryption, ParameterFormat, PrivateFormat, PublicFormat,      load_der_parameters, load_der_private_key, load_der_public_key,      load_pem_parameters, load_pem_private_key, load_pem_public_key, -    load_ssh_public_key, +) +from cryptography.hazmat.primitives.serialization.ssh import ( +    load_ssh_public_key  )  __all__ = [ diff --git a/src/cryptography/hazmat/primitives/serialization/base.py b/src/cryptography/hazmat/primitives/serialization/base.py index bd09e6e3..5dd0c639 100644 --- a/src/cryptography/hazmat/primitives/serialization/base.py +++ b/src/cryptography/hazmat/primitives/serialization/base.py @@ -5,15 +5,11 @@  from __future__ import absolute_import, division, print_function  import abc -import base64 -import struct  from enum import Enum  import six  from cryptography import utils -from cryptography.exceptions import UnsupportedAlgorithm -from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa  def load_pem_private_key(data, password, backend): @@ -40,135 +36,6 @@ def load_der_parameters(data, backend):      return backend.load_der_parameters(data) -def load_ssh_public_key(data, backend): -    key_parts = data.split(b' ', 2) - -    if len(key_parts) < 2: -        raise ValueError( -            'Key is not in the proper format or contains extra data.') - -    key_type = key_parts[0] - -    if key_type == b'ssh-rsa': -        loader = _load_ssh_rsa_public_key -    elif key_type == b'ssh-dss': -        loader = _load_ssh_dss_public_key -    elif key_type in [ -        b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521', -    ]: -        loader = _load_ssh_ecdsa_public_key -    else: -        raise UnsupportedAlgorithm('Key type is not supported.') - -    key_body = key_parts[1] - -    try: -        decoded_data = base64.b64decode(key_body) -    except TypeError: -        raise ValueError('Key is not in the proper format.') - -    inner_key_type, rest = _ssh_read_next_string(decoded_data) - -    if inner_key_type != key_type: -        raise ValueError( -            'Key header and key body contain different key type values.' -        ) - -    return loader(key_type, rest, backend) - - -def _load_ssh_rsa_public_key(key_type, decoded_data, backend): -    e, rest = _ssh_read_next_mpint(decoded_data) -    n, rest = _ssh_read_next_mpint(rest) - -    if rest: -        raise ValueError('Key body contains extra bytes.') - -    return rsa.RSAPublicNumbers(e, n).public_key(backend) - - -def _load_ssh_dss_public_key(key_type, decoded_data, backend): -    p, rest = _ssh_read_next_mpint(decoded_data) -    q, rest = _ssh_read_next_mpint(rest) -    g, rest = _ssh_read_next_mpint(rest) -    y, rest = _ssh_read_next_mpint(rest) - -    if rest: -        raise ValueError('Key body contains extra bytes.') - -    parameter_numbers = dsa.DSAParameterNumbers(p, q, g) -    public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) - -    return public_numbers.public_key(backend) - - -def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend): -    curve_name, rest = _ssh_read_next_string(decoded_data) -    data, rest = _ssh_read_next_string(rest) - -    if expected_key_type != b"ecdsa-sha2-" + curve_name: -        raise ValueError( -            'Key header and key body contain different key type values.' -        ) - -    if rest: -        raise ValueError('Key body contains extra bytes.') - -    curve = { -        b"nistp256": ec.SECP256R1, -        b"nistp384": ec.SECP384R1, -        b"nistp521": ec.SECP521R1, -    }[curve_name]() - -    if six.indexbytes(data, 0) != 4: -        raise NotImplementedError( -            "Compressed elliptic curve points are not supported" -        ) - -    numbers = ec.EllipticCurvePublicNumbers.from_encoded_point(curve, data) -    return numbers.public_key(backend) - - -def _ssh_read_next_string(data): -    """ -    Retrieves the next RFC 4251 string value from the data. - -    While the RFC calls these strings, in Python they are bytes objects. -    """ -    if len(data) < 4: -        raise ValueError("Key is not in the proper format") - -    str_len, = struct.unpack('>I', data[:4]) -    if len(data) < str_len + 4: -        raise ValueError("Key is not in the proper format") - -    return data[4:4 + str_len], data[4 + str_len:] - - -def _ssh_read_next_mpint(data): -    """ -    Reads the next mpint from the data. - -    Currently, all mpints are interpreted as unsigned. -    """ -    mpint_data, rest = _ssh_read_next_string(data) - -    return ( -        utils.int_from_bytes(mpint_data, byteorder='big', signed=False), rest -    ) - - -def _ssh_write_string(data): -    return struct.pack(">I", len(data)) + data - - -def _ssh_write_mpint(value): -    data = utils.int_to_bytes(value) -    if six.indexbytes(data, 0) & 0x80: -        data = b"\x00" + data -    return _ssh_write_string(data) - -  class Encoding(Enum):      PEM = "PEM"      DER = "DER" diff --git a/src/cryptography/hazmat/primitives/serialization/ssh.py b/src/cryptography/hazmat/primitives/serialization/ssh.py new file mode 100644 index 00000000..f58ff074 --- /dev/null +++ b/src/cryptography/hazmat/primitives/serialization/ssh.py @@ -0,0 +1,143 @@ +# This file is dual licensed under the terms of the Apache License, Version +# 2.0, and the BSD License. See the LICENSE file in the root of this repository +# for complete details. + +from __future__ import absolute_import, division, print_function + +import base64 +import struct + +import six + +from cryptography import utils +from cryptography.exceptions import UnsupportedAlgorithm +from cryptography.hazmat.primitives.asymmetric import dsa, ec, rsa + + +def load_ssh_public_key(data, backend): +    key_parts = data.split(b' ', 2) + +    if len(key_parts) < 2: +        raise ValueError( +            'Key is not in the proper format or contains extra data.') + +    key_type = key_parts[0] + +    if key_type == b'ssh-rsa': +        loader = _load_ssh_rsa_public_key +    elif key_type == b'ssh-dss': +        loader = _load_ssh_dss_public_key +    elif key_type in [ +        b'ecdsa-sha2-nistp256', b'ecdsa-sha2-nistp384', b'ecdsa-sha2-nistp521', +    ]: +        loader = _load_ssh_ecdsa_public_key +    else: +        raise UnsupportedAlgorithm('Key type is not supported.') + +    key_body = key_parts[1] + +    try: +        decoded_data = base64.b64decode(key_body) +    except TypeError: +        raise ValueError('Key is not in the proper format.') + +    inner_key_type, rest = _ssh_read_next_string(decoded_data) + +    if inner_key_type != key_type: +        raise ValueError( +            'Key header and key body contain different key type values.' +        ) + +    return loader(key_type, rest, backend) + + +def _load_ssh_rsa_public_key(key_type, decoded_data, backend): +    e, rest = _ssh_read_next_mpint(decoded_data) +    n, rest = _ssh_read_next_mpint(rest) + +    if rest: +        raise ValueError('Key body contains extra bytes.') + +    return rsa.RSAPublicNumbers(e, n).public_key(backend) + + +def _load_ssh_dss_public_key(key_type, decoded_data, backend): +    p, rest = _ssh_read_next_mpint(decoded_data) +    q, rest = _ssh_read_next_mpint(rest) +    g, rest = _ssh_read_next_mpint(rest) +    y, rest = _ssh_read_next_mpint(rest) + +    if rest: +        raise ValueError('Key body contains extra bytes.') + +    parameter_numbers = dsa.DSAParameterNumbers(p, q, g) +    public_numbers = dsa.DSAPublicNumbers(y, parameter_numbers) + +    return public_numbers.public_key(backend) + + +def _load_ssh_ecdsa_public_key(expected_key_type, decoded_data, backend): +    curve_name, rest = _ssh_read_next_string(decoded_data) +    data, rest = _ssh_read_next_string(rest) + +    if expected_key_type != b"ecdsa-sha2-" + curve_name: +        raise ValueError( +            'Key header and key body contain different key type values.' +        ) + +    if rest: +        raise ValueError('Key body contains extra bytes.') + +    curve = { +        b"nistp256": ec.SECP256R1, +        b"nistp384": ec.SECP384R1, +        b"nistp521": ec.SECP521R1, +    }[curve_name]() + +    if six.indexbytes(data, 0) != 4: +        raise NotImplementedError( +            "Compressed elliptic curve points are not supported" +        ) + +    numbers = ec.EllipticCurvePublicNumbers.from_encoded_point(curve, data) +    return numbers.public_key(backend) + + +def _ssh_read_next_string(data): +    """ +    Retrieves the next RFC 4251 string value from the data. + +    While the RFC calls these strings, in Python they are bytes objects. +    """ +    if len(data) < 4: +        raise ValueError("Key is not in the proper format") + +    str_len, = struct.unpack('>I', data[:4]) +    if len(data) < str_len + 4: +        raise ValueError("Key is not in the proper format") + +    return data[4:4 + str_len], data[4 + str_len:] + + +def _ssh_read_next_mpint(data): +    """ +    Reads the next mpint from the data. + +    Currently, all mpints are interpreted as unsigned. +    """ +    mpint_data, rest = _ssh_read_next_string(data) + +    return ( +        utils.int_from_bytes(mpint_data, byteorder='big', signed=False), rest +    ) + + +def _ssh_write_string(data): +    return struct.pack(">I", len(data)) + data + + +def _ssh_write_mpint(value): +    data = utils.int_to_bytes(value) +    if six.indexbytes(data, 0) & 0x80: +        data = b"\x00" + data +    return _ssh_write_string(data) | 
