aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMark Adams <mark@markadams.me>2014-12-14 00:16:03 -0600
committerMark Adams <mark@markadams.me>2014-12-14 00:30:39 -0600
commit4724d61be546f900298c7594d3bdb942b39a919f (patch)
treea19791b43dfaf648d310447605e90243cad59a70
parentdfa57bf7821a63c65ef0f83234c79f611fab46db (diff)
downloadcryptography-4724d61be546f900298c7594d3bdb942b39a919f.tar.gz
cryptography-4724d61be546f900298c7594d3bdb942b39a919f.tar.bz2
cryptography-4724d61be546f900298c7594d3bdb942b39a919f.zip
Added better parsing for RFC 4251 string and mpint values.
Also moved several of the SSH key splitting and validation checks up into the load_ssh_public_key method since they will apply to more than just RSA. Added additional checks to make sure the key doesn't contain extraneous data
-rw-r--r--src/cryptography/hazmat/primitives/serialization.py62
-rw-r--r--tests/hazmat/primitives/test_serialization.py45
2 files changed, 85 insertions, 22 deletions
diff --git a/src/cryptography/hazmat/primitives/serialization.py b/src/cryptography/hazmat/primitives/serialization.py
index 38c541cb..f20d9f56 100644
--- a/src/cryptography/hazmat/primitives/serialization.py
+++ b/src/cryptography/hazmat/primitives/serialization.py
@@ -46,42 +46,60 @@ def load_pem_public_key(data, backend):
def load_ssh_public_key(data, backend):
- if not data.startswith(b'ssh-'):
- raise ValueError('SSH-formatted keys must begin with ssh-')
+ key_parts = data.split(b' ')
- if not data.startswith(b'ssh-rsa'):
+ if len(key_parts) < 2 or len(key_parts) > 3:
+ raise ValueError(
+ 'Key is not in the proper format or contains extra data.')
+
+ key_type = key_parts[0]
+ key_body = key_parts[1]
+
+ if not key_type.startswith(b'ssh-'):
+ raise ValueError('SSH-formatted keys must begin with \'ssh-\'.')
+
+ if not key_type.startswith(b'ssh-rsa'):
raise UnsupportedAlgorithm('Only RSA keys are currently supported.')
- return _load_ssh_rsa_public_key(data, backend)
+ return _load_ssh_rsa_public_key(key_type, key_body, backend)
+
+def _load_ssh_rsa_public_key(key_type, key_body, backend):
+ assert key_type == b'ssh-rsa'
-def _load_ssh_rsa_public_key(data, backend):
- assert data.startswith(b'ssh-rsa ')
+ data = base64.b64decode(key_body)
- parts = data.split(b' ')
- data = base64.b64decode(parts[1])
+ key_body_type, rest = _read_next_string(data)
+ e, rest = _read_next_mpint(rest)
+ n, rest = _read_next_mpint(rest)
- cert_data = []
+ if key_type != key_body_type:
+ raise ValueError(
+ 'Key header and key body contain different key type values.')
- while len(data) > 0:
- str_len = struct.unpack('>I', data[0:4])[0]
- cert_data.append(data[4:4 + str_len])
- data = data[4 + str_len:]
+ if len(rest) != 0:
+ raise ValueError('Key body contains extra bytes.')
- e = _bytes_to_int(cert_data[1])
- n = _bytes_to_int(cert_data[2])
return backend.load_rsa_public_numbers(RSAPublicNumbers(e, n))
-def _bytes_to_int(data):
- if len(data) % 4 != 0:
+def _read_next_string(data):
+ """Retrieves the next RFC 4251 string value from the data."""
+ str_len = struct.unpack('>I', data[0:4])[0]
+ return data[4:4 + str_len], data[4 + str_len:]
+
+
+def _read_next_mpint(data):
+ mpint_data, rest = _read_next_string(data)
+
+ if len(mpint_data) % 4 != 0:
# Pad the bytes with 0x00 to a block size of 4
- data = (b'\x00' * (4 - (len(data) % 4))) + data
+ mpint_data = (b'\x00' * (4 - (len(mpint_data) % 4))) + mpint_data
result = 0
- while len(data) > 0:
- result = (result << 32) + struct.unpack('>I', data[0:4])[0]
- data = data[4:]
+ while len(mpint_data) > 0:
+ result = (result << 32) + struct.unpack('>I', mpint_data[0:4])[0]
+ mpint_data = mpint_data[4:]
- return result
+ return result, rest
diff --git a/tests/hazmat/primitives/test_serialization.py b/tests/hazmat/primitives/test_serialization.py
index ffe3d7df..9180b9aa 100644
--- a/tests/hazmat/primitives/test_serialization.py
+++ b/tests/hazmat/primitives/test_serialization.py
@@ -697,6 +697,51 @@ class TestSSHSerialization(object):
with pytest.raises(ValueError):
load_ssh_public_key(ssh_key, backend)
+ def test_load_ssh_public_key_rsa_too_short(self, backend):
+ ssh_key = b'ssh-rsa'
+
+ with pytest.raises(ValueError):
+ load_ssh_public_key(ssh_key, backend)
+
+ def test_load_ssh_public_key_rsa_key_types_dont_match(self, backend):
+ ssh_key = textwrap.dedent("""\
+ ssh-bad AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk
+ FiiuJLW6RU0j3KKVZSs1I7Z3UmU9/9aVh/rZV43WQG8jaR6kkcP4stOR0DEtll
+ PDA7ZRBnrfiHpSQYQ874AZaAoIjgkv7DBfsE6gcDQLub0PFjWyrYQUJhtOLQEK
+ vY/G0vt2iRL3juawWmCFdTK3W3XvwAdgGk71i6lHt+deOPNEPN2H58E4odrZ2f
+ sxn/adpDqfb2sM0kPwQs0aWvrrKGvUaustkivQE4XWiSFnB0oJB/lKK/CKVKuy
+ ///ImSCGHQRvhwariN2tvZ6CBNSLh3iQgeB0AkyJlng7MXB2qYq/Ci2FUOryCX
+ 2MzHvnbv testkey@localhost extra""").encode() # ssh-bad
+
+ with pytest.raises(ValueError):
+ load_ssh_public_key(ssh_key, backend)
+
+ def test_load_ssh_public_key_rsa_extra_string_after_comment(self, backend):
+ ssh_key = textwrap.dedent("""\
+ ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk
+ FiiuJLW6RU0j3KKVZSs1I7Z3UmU9/9aVh/rZV43WQG8jaR6kkcP4stOR0DEtll
+ PDA7ZRBnrfiHpSQYQ874AZaAoIjgkv7DBfsE6gcDQLub0PFjWyrYQUJhtOLQEK
+ vY/G0vt2iRL3juawWmCFdTK3W3XvwAdgGk71i6lHt+deOPNEPN2H58E4odrZ2f
+ sxn/adpDqfb2sM0kPwQs0aWvrrKGvUaustkivQE4XWiSFnB0oJB/lKK/CKVKuy
+ ///ImSCGHQRvhwariN2tvZ6CBNSLh3iQgeB0AkyJlng7MXB2qYq/Ci2FUOryCX
+ 2MzHvnbv testkey@localhost extra""").encode() # Extra appended
+
+ with pytest.raises(ValueError):
+ load_ssh_public_key(ssh_key, backend)
+
+ def test_load_ssh_public_key_rsa_extra_data_after_modulo(self, backend):
+ ssh_key = textwrap.dedent("""\
+ ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk
+ FiiuJLW6RU0j3KKVZSs1I7Z3UmU9/9aVh/rZV43WQG8jaR6kkcP4stOR0DEtll
+ PDA7ZRBnrfiHpSQYQ874AZaAoIjgkv7DBfsE6gcDQLub0PFjWyrYQUJhtOLQEK
+ vY/G0vt2iRL3juawWmCFdTK3W3XvwAdgGk71i6lHt+deOPNEPN2H58E4odrZ2f
+ sxn/adpDqfb2sM0kPwQs0aWvrrKGvUaustkivQE4XWiSFnB0oJB/lKK/CKVKuy
+ ///ImSCGHQRvhwariN2tvZ6CBNSLh3iQgeB0AkyJlng7MXB2qYq/Ci2FUOryCX
+ 2MzHvnbvAQ== testkey@localhost""").encode() # Extra 0x01 appended
+
+ with pytest.raises(ValueError):
+ load_ssh_public_key(ssh_key, backend)
+
def test_load_ssh_public_key_rsa(self, backend):
ssh_key = textwrap.dedent("""\
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDDu/XRP1kyK6Cgt36gts9XAk