diff options
Diffstat (limited to 'netlib/websockets')
-rw-r--r-- | netlib/websockets/implementations.py | 19 | ||||
-rw-r--r-- | netlib/websockets/websockets.py | 51 |
2 files changed, 49 insertions, 21 deletions
diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index ff42ff65..73a84690 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -32,7 +32,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.server_process_handshake(client_hs) + key = ws.process_handshake_from_client(client_hs) response = ws.create_server_handshake(key) self.wfile.write(response) self.wfile.flush() @@ -46,9 +46,9 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.key = ws.generate_client_nounce() - self.resource = "/" + self.version = "13" + self.client_nounce = ws.create_client_nounce() + self.resource = "/" def connect(self): super(WebSocketsClient, self).connect() @@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient): handshake = ws.create_client_handshake( self.address.host, self.address.port, - self.key, + self.client_nounce, self.version, self.resource ) @@ -64,9 +64,14 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(handshake) self.wfile.flush() - response = ws.read_handshake(self.rfile.read, 1) + server_handshake = ws.read_handshake(self.rfile.read, 1) - if not response: + if not server_handshake: + self.close() + + server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) + + if not server_nounce == ws.create_server_nounce(self.client_nounce): self.close() def read_next_message(self): diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 527d55d6..cf9a68aa 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -84,7 +84,7 @@ class WebSocketsFrame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_byte_stream() directly """ - self.from_byte_stream(io.BytesIO(bytestring).read) + return cls.from_byte_stream(io.BytesIO(bytestring).read) @classmethod @@ -115,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def frame_is_valid(self): + def is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -155,12 +155,11 @@ class WebSocketsFrame(object): ("masking_key - " + str(self.masking_key)), ("payload - " + str(self.payload)), ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length)), - ("use_validation - " + str(self.use_validation))]) + ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): try: - assert self.frame_is_valid() + assert self.is_valid() return self.to_bytes() except: raise WebSocketFrameValidationException() @@ -197,7 +196,7 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - + elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short @@ -267,6 +266,20 @@ class WebSocketsFrame(object): actual_payload_length = actual_payload_length ) + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length) + def apply_mask(message, masking_key): """ Data sent from the server must be masked to prevent malicious clients @@ -300,16 +313,14 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) - -def create_server_handshake(key, magic = websockets_magic): +def create_server_handshake(key): """ The server response is a valid HTTP 101 response. """ - digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', digest) + ('Sec-WebSocket-Accept', create_server_nounce(key)) ] request = "HTTP/1.1 101 Switching Protocols" return build_handshake(headers, request) @@ -322,7 +333,6 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) - def read_handshake(read_bytes, num_bytes_per_read): """ From provided function that reads bytes, read in a @@ -355,13 +365,26 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) -def server_process_handshake(handshake): - headers = Message(StringIO(handshake.split('\r\n', 1)[1])) +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": return key = headers['Sec-WebSocket-Key'] return key -def generate_client_nounce(): +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + +def headers_from_http_message(http_message): + return Message(StringIO(http_message.split('\r\n', 1)[1])) + +def create_server_nounce(client_nounce): + return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + +def create_client_nounce(): return b64encode(os.urandom(16)).decode('utf-8') |