From 4724d61be546f900298c7594d3bdb942b39a919f Mon Sep 17 00:00:00 2001 From: Mark Adams Date: Sun, 14 Dec 2014 00:16:03 -0600 Subject: Added better parsing for RFC 4251 string and mpint values. Also moved several of the SSH key splitting and validation checks up into the load_ssh_public_key method since they will apply to more than just RSA. Added additional checks to make sure the key doesn't contain extraneous data --- .../hazmat/primitives/serialization.py | 62 ++++++++++++++-------- tests/hazmat/primitives/test_serialization.py | 45 ++++++++++++++++ 2 files changed, 85 insertions(+), 22 deletions(-) diff --git a/src/cryptography/hazmat/primitives/serialization.py b/src/cryptography/hazmat/primitives/serialization.py index 38c541cb..f20d9f56 100644 --- a/src/cryptography/hazmat/primitives/serialization.py +++ b/src/cryptography/hazmat/primitives/serialization.py @@ -46,42 +46,60 @@ def load_pem_public_key(data, backend): def load_ssh_public_key(data, backend): - if not data.startswith(b'ssh-'): - raise ValueError('SSH-formatted keys must begin with ssh-') + key_parts = data.split(b' ') - if not data.startswith(b'ssh-rsa'): + if len(key_parts) < 2 or len(key_parts) > 3: + raise ValueError( + 'Key is not in the proper format or contains extra data.') + + key_type = key_parts[0] + key_body = key_parts[1] + + if not key_type.startswith(b'ssh-'): + raise ValueError('SSH-formatted keys must begin with \'ssh-\'.') + + if not key_type.startswith(b'ssh-rsa'): raise UnsupportedAlgorithm('Only RSA keys are currently supported.') - return _load_ssh_rsa_public_key(data, backend) + return _load_ssh_rsa_public_key(key_type, key_body, backend) + +def _load_ssh_rsa_public_key(key_type, key_body, backend): + assert key_type == b'ssh-rsa' -def _load_ssh_rsa_public_key(data, backend): - assert data.startswith(b'ssh-rsa ') + data = base64.b64decode(key_body) - parts = data.split(b' ') - data = base64.b64decode(parts[1]) + key_body_type, rest = _read_next_string(data) + e, rest = _read_next_mpint(rest) + n, rest = _read_next_mpint(rest) - cert_data = [] + if key_type != key_body_type: + raise ValueError( + 'Key header and key body contain different key type values.') - while len(data) > 0: - str_len = struct.unpack('>I', data[0:4])[0] - cert_data.append(data[4:4 + str_len]) - data = data[4 + str_len:] + if len(rest) != 0: + raise ValueError('Key body contains extra bytes.') - e = _bytes_to_int(cert_data[1]) - n = _bytes_to_int(cert_data[2]) return backend.load_rsa_public_numbers(RSAPublicNumbers(e, n)) -def _bytes_to_int(data): - if len(data) % 4 != 0: +def _read_next_string(data): + """Retrieves the next RFC 4251 string value from the data.""" + str_len = struct.unpack('>I', data[0:4])[0] + return data[4:4 + str_len], data[4 + str_len:] + + +def _read_next_mpint(data): + mpint_data, rest = _read_next_string(data) + + if len(mpint_data) % 4 != 0: # Pad the bytes with 0x00 to a block size of 4 - data = (b'\x00' * (4 - (len(data) % 4))) + data + mpint_data = (b'\x00' * (4 - (len(mpint_data) % 4))) + mpint_data result = 0 - while len(data) > 0: - result = (result << 32) + struct.unpack('>I', data[0:4])[0] - data = data[4:] + while len(mpint_data) > 0: + result = (result << 32) + struct.unpack('>I', mpint_data[0:4])[0] + mpint_data = mpint_data[4:] - return result + return result, rest diff --git a/tests/hazmat/primitives/test_serialization.py b/tests/hazmat/primitives/test_serialization.py index ffe3d7df..9180b9aa 100644 --- a/tests/hazmat/primitives/test_serialization.py +++ b/tests/hazmat/primitives/test_serialization.py @@ -697,6 +697,51 @@ class TestSSHSerialization(object): with pytest.raises(ValueError): load_ssh_public_key(ssh_key, backend) + def test_load_ssh_public_key_rsa_too_short(self, backend): + ssh_key = b'ssh-rsa' + + with pytest.raises(ValueError): + load_ssh_public_key(ssh_key, backend) + + def test_load_ssh_public_key_rsa_key_types_dont_match(self, backend): + ssh_key = textwrap.dedent("""\ + ssh-bad AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk + FiiuJLW6RU0j3KKVZSs1I7Z3UmU9/9aVh/rZV43WQG8jaR6kkcP4stOR0DEtll + PDA7ZRBnrfiHpSQYQ874AZaAoIjgkv7DBfsE6gcDQLub0PFjWyrYQUJhtOLQEK + vY/G0vt2iRL3juawWmCFdTK3W3XvwAdgGk71i6lHt+deOPNEPN2H58E4odrZ2f + sxn/adpDqfb2sM0kPwQs0aWvrrKGvUaustkivQE4XWiSFnB0oJB/lKK/CKVKuy + ///ImSCGHQRvhwariN2tvZ6CBNSLh3iQgeB0AkyJlng7MXB2qYq/Ci2FUOryCX + 2MzHvnbv testkey@localhost extra""").encode() # ssh-bad + + with pytest.raises(ValueError): + load_ssh_public_key(ssh_key, backend) + + def test_load_ssh_public_key_rsa_extra_string_after_comment(self, backend): + ssh_key = textwrap.dedent("""\ + ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk + FiiuJLW6RU0j3KKVZSs1I7Z3UmU9/9aVh/rZV43WQG8jaR6kkcP4stOR0DEtll + PDA7ZRBnrfiHpSQYQ874AZaAoIjgkv7DBfsE6gcDQLub0PFjWyrYQUJhtOLQEK + vY/G0vt2iRL3juawWmCFdTK3W3XvwAdgGk71i6lHt+deOPNEPN2H58E4odrZ2f + sxn/adpDqfb2sM0kPwQs0aWvrrKGvUaustkivQE4XWiSFnB0oJB/lKK/CKVKuy + ///ImSCGHQRvhwariN2tvZ6CBNSLh3iQgeB0AkyJlng7MXB2qYq/Ci2FUOryCX + 2MzHvnbv testkey@localhost extra""").encode() # Extra appended + + with pytest.raises(ValueError): + load_ssh_public_key(ssh_key, backend) + + def test_load_ssh_public_key_rsa_extra_data_after_modulo(self, backend): + ssh_key = textwrap.dedent("""\ + ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk + FiiuJLW6RU0j3KKVZSs1I7Z3UmU9/9aVh/rZV43WQG8jaR6kkcP4stOR0DEtll + PDA7ZRBnrfiHpSQYQ874AZaAoIjgkv7DBfsE6gcDQLub0PFjWyrYQUJhtOLQEK + vY/G0vt2iRL3juawWmCFdTK3W3XvwAdgGk71i6lHt+deOPNEPN2H58E4odrZ2f + sxn/adpDqfb2sM0kPwQs0aWvrrKGvUaustkivQE4XWiSFnB0oJB/lKK/CKVKuy + ///ImSCGHQRvhwariN2tvZ6CBNSLh3iQgeB0AkyJlng7MXB2qYq/Ci2FUOryCX + 2MzHvnbvAQ== testkey@localhost""").encode() # Extra 0x01 appended + + with pytest.raises(ValueError): + load_ssh_public_key(ssh_key, backend) + def test_load_ssh_public_key_rsa(self, backend): ssh_key = textwrap.dedent("""\ ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk -- cgit v1.2.3