aboutsummaryrefslogtreecommitdiffstats
path: root/src/cryptography/hazmat/primitives/serialization.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/cryptography/hazmat/primitives/serialization.py')
-rw-r--r--src/cryptography/hazmat/primitives/serialization.py62
1 files changed, 40 insertions, 22 deletions
diff --git a/src/cryptography/hazmat/primitives/serialization.py b/src/cryptography/hazmat/primitives/serialization.py
index 38c541cb..f20d9f56 100644
--- a/src/cryptography/hazmat/primitives/serialization.py
+++ b/src/cryptography/hazmat/primitives/serialization.py
@@ -46,42 +46,60 @@ def load_pem_public_key(data, backend):
def load_ssh_public_key(data, backend):
- if not data.startswith(b'ssh-'):
- raise ValueError('SSH-formatted keys must begin with ssh-')
+ key_parts = data.split(b' ')
- if not data.startswith(b'ssh-rsa'):
+ if len(key_parts) < 2 or len(key_parts) > 3:
+ raise ValueError(
+ 'Key is not in the proper format or contains extra data.')
+
+ key_type = key_parts[0]
+ key_body = key_parts[1]
+
+ if not key_type.startswith(b'ssh-'):
+ raise ValueError('SSH-formatted keys must begin with \'ssh-\'.')
+
+ if not key_type.startswith(b'ssh-rsa'):
raise UnsupportedAlgorithm('Only RSA keys are currently supported.')
- return _load_ssh_rsa_public_key(data, backend)
+ return _load_ssh_rsa_public_key(key_type, key_body, backend)
+
+def _load_ssh_rsa_public_key(key_type, key_body, backend):
+ assert key_type == b'ssh-rsa'
-def _load_ssh_rsa_public_key(data, backend):
- assert data.startswith(b'ssh-rsa ')
+ data = base64.b64decode(key_body)
- parts = data.split(b' ')
- data = base64.b64decode(parts[1])
+ key_body_type, rest = _read_next_string(data)
+ e, rest = _read_next_mpint(rest)
+ n, rest = _read_next_mpint(rest)
- cert_data = []
+ if key_type != key_body_type:
+ raise ValueError(
+ 'Key header and key body contain different key type values.')
- while len(data) > 0:
- str_len = struct.unpack('>I', data[0:4])[0]
- cert_data.append(data[4:4 + str_len])
- data = data[4 + str_len:]
+ if len(rest) != 0:
+ raise ValueError('Key body contains extra bytes.')
- e = _bytes_to_int(cert_data[1])
- n = _bytes_to_int(cert_data[2])
return backend.load_rsa_public_numbers(RSAPublicNumbers(e, n))
-def _bytes_to_int(data):
- if len(data) % 4 != 0:
+def _read_next_string(data):
+ """Retrieves the next RFC 4251 string value from the data."""
+ str_len = struct.unpack('>I', data[0:4])[0]
+ return data[4:4 + str_len], data[4 + str_len:]
+
+
+def _read_next_mpint(data):
+ mpint_data, rest = _read_next_string(data)
+
+ if len(mpint_data) % 4 != 0:
# Pad the bytes with 0x00 to a block size of 4
- data = (b'\x00' * (4 - (len(data) % 4))) + data
+ mpint_data = (b'\x00' * (4 - (len(mpint_data) % 4))) + mpint_data
result = 0
- while len(data) > 0:
- result = (result << 32) + struct.unpack('>I', data[0:4])[0]
- data = data[4:]
+ while len(mpint_data) > 0:
+ result = (result << 32) + struct.unpack('>I', mpint_data[0:4])[0]
+ mpint_data = mpint_data[4:]
- return result
+ return result, rest