From e41e5cbfdd7b778e6f68e86658e95f9e413133cb Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Thu, 9 Apr 2015 19:35:40 -0700 Subject: netlib websockets --- netlib/http.py | 14 ++ netlib/utils.py | 3 + netlib/websockets/__init__.py | 1 + netlib/websockets/implementations.py | 81 ++++++++ netlib/websockets/websockets.py | 368 +++++++++++++++++++++++++++++++++++ 5 files changed, 467 insertions(+) create mode 100644 netlib/websockets/__init__.py create mode 100644 netlib/websockets/implementations.py create mode 100644 netlib/websockets/websockets.py (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 26438863..2c72621d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,6 +29,20 @@ def _is_valid_host(host): return None return True +def is_successful_upgrade(request, response): + """ + determines if a client and server successfully agreed to an HTTP protocol upgrade + + https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism + """ + http_switching_protocols_code = 101 + + if request and response: + responseUpgrade = request.headers.get("Upgrade") + requestUpgrade = response.headers.get("Upgrade") + if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: + return requestUpgrade[0] if len(requestUpgrade) > 0 else None + return None def parse_url(url): """ diff --git a/netlib/utils.py b/netlib/utils.py index 79077ac6..03a70977 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -8,6 +8,9 @@ def isascii(s): return False return True +# best way to do it in python 2.x +def bytes_to_int(i): + return int(i.encode('hex'), 16) def cleanBin(s, fixspacing=False): """ diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..9b4faa33 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1 @@ +from __future__ import (absolute_import, print_function, division) diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py new file mode 100644 index 00000000..78ae5be6 --- /dev/null +++ b/netlib/websockets/implementations.py @@ -0,0 +1,81 @@ +from netlib import tcp +from base64 import b64encode +from StringIO import StringIO +from . import websockets as ws +import struct +import SocketServer +import os + +# Simple websocket client and servers that are used to exercise the functionality in websockets.py +# These are *not* fully RFC6455 compliant + +class WebSocketsEchoHandler(tcp.BaseHandler): + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__(connection, address, server) + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + self.on_message(decoded) + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() + + def handshake(self): + client_hs = ws.read_handshake(self.rfile.read, 1) + key = ws.server_process_handshake(client_hs) + response = ws.create_server_handshake(key) + self.wfile.write(response) + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.version = "13" + self.key = b64encode(os.urandom(16)).decode('utf-8') + self.resource = "/" + + def connect(self): + super(WebSocketsClient, self).connect() + + handshake = ws.create_client_handshake( + self.address.host, + self.address.port, + self.key, + self.version, + self.resource + ) + + self.wfile.write(handshake) + self.wfile.flush() + + response = ws.read_handshake(self.rfile.read, 1) + + if not response: + self.close() + + def read_next_message(self): + try: + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + except IndexError: + self.close() + + def send_message(self, message): + frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) + self.wfile.write(frame.to_bytes()) + self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py new file mode 100644 index 00000000..b796ce39 --- /dev/null +++ b/netlib/websockets/websockets.py @@ -0,0 +1,368 @@ +from __future__ import absolute_import + +from base64 import b64encode +from hashlib import sha1 +from mimetools import Message +from netlib import tcp +from netlib import utils +from StringIO import StringIO +import os +import SocketServer +import struct +import io + +# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol +# Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +# +# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + +class WebSocketFrameValidationException(Exception): + pass + +class WebSocketsFrame(object): + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + def __init__( + self, + fin, # decmial integer 1 or 0 + opcode, # decmial integer 1 - 4 + mask_bit, # decimal integer 1 or 0 + payload_length_code, # decimal integer 1 - 127 + decoded_payload, # bytestring + rsv1 = 0, # decimal integer 1 or 0 + rsv2 = 0, # decimal integer 1 or 0 + rsv3 = 0, # decimal integer 1 or 0 + payload = None, # bytestring + masking_key = None, # 32 bit byte string + actual_payload_length = None, # any decimal integer + use_validation = True # indicates whether or not you care if this frame adheres to the spec + ): + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + self.opcode = opcode + self.mask_bit = mask_bit + self.payload_length_code = payload_length_code + self.masking_key = masking_key + self.payload = payload + self.decoded_payload = decoded_payload + self.actual_payload_length = actual_payload_length + self.use_validation = use_validation + + if self.use_validation: + self.validate_frame() + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use read_frame() directly + """ + self.from_byte_stream(io.BytesIO(bytestring).read) + + @classmethod + def default_frame_from_message(cls, message, from_client = False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + length_code, actual_length = get_payload_length_pair(message) + + if from_client: + mask_bit = 1 + masking_key = random_masking_key() + payload = apply_mask(message, masking_key) + else: + mask_bit = 0 + masking_key = None + payload = message + + return cls( + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, + actual_payload_length = actual_length + ) + + def validate_frame(self): + """ + Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame + has not been corrupted. + """ + try: + assert 0 <= self.fin <= 1 + assert 0 <= self.rsv1 <= 1 + assert 0 <= self.rsv2 <= 1 + assert 0 <= self.rsv3 <= 1 + assert 1 <= self.opcode <= 4 + assert 0 <= self.mask_bit <= 1 + assert 1 <= self.payload_length_code <= 127 + + if self.mask_bit == 1: + assert 1 <= len(self.masking_key) <= 4 + else: + assert self.masking_key == None + + assert self.actual_payload_length == len(self.payload) + + if self.payload is not None and self.masking_key is not None: + apply_mask(self.payload, self.masking_key) == self.decoded_payload + + except AssertionError: + raise WebSocketFrameValidationException() + + def human_readable(self): + return "\n".join([ + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)), + ("use_validation - " + str(self.use_validation))]) + + def to_bytes(self): + """ + Serialize the frame back into the wire format, returns a bytestring + """ + # validate enforces all the assumptions made by this serializer + # in the spritit of mitmproxy, it's possible to create and serialize invalid frames + # by skipping validation. + if self.use_validation: + self.validate_frame() + + max_16_bit_int = (1 << 16) + max_64_bit_int = (1 << 63) + + # break down of the bit-math used to construct the first byte from the frame's integer values + # first shift the significant bit into the correct position + # 00000001 << 7 = 10000000 + # ... + # then combine: + # + # 10000000 fin + # 01000000 res1 + # 00100000 res2 + # 00010000 res3 + # 00000001 opcode + # -------- OR + # 11110001 = first_byte + + first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + + second_byte = (self.mask_bit << 7) | self.payload_length_code + + bytes = chr(first_byte) + chr(second_byte) + + if self.actual_payload_length < 126: + pass + + elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short + bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length + + elif self.actual_payload_length < max_64_bit_int: + # '!Q' = pack as 64 bit unsigned long long + bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + + if self.masking_key is not None: + bytes += self.masking_key + + bytes += self.payload # already will be encoded if neccessary + + return bytes + + + @classmethod + def from_byte_stream(cls, read_bytes): + """ + read a websockets frame sent by a server or client + + read_bytes is a function that can be backed + by sockets or by any byte reader. So this + function may be used to read frames from disk/wire/memory + """ + first_byte = utils.bytes_to_int(read_bytes(1)) + second_byte = utils.bytes_to_int(read_bytes(1)) + + fin = first_byte >> 7 # grab the left most bit + opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 + mask_bit = second_byte >> 7 # grab left most bit + payload_length = second_byte & 127 # grab the next 7 bits + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if payload_length <= 125: + actual_payload_length = payload_length + + elif payload_length == 126: + actual_payload_length = utils.bytes_to_int(read_bytes(2)) + + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = read_bytes(4) + else: + masking_key = None + + payload = read_bytes(actual_payload_length) + + if mask_bit == 1: + decoded_payload = apply_mask(payload, masking_key) + else: + decoded_payload = payload + + return cls( + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, + actual_payload_length = actual_payload_length + ) + +def apply_mask(message, masking_key): + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + This method both encodes and decodes strings with the provided mask + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + masks = [utils.bytes_to_int(byte) for byte in masking_key] + result = "" + for char in message: + result += chr(ord(char) ^ masks[len(result) % 4]) + return result + +def random_masking_key(): + return os.urandom(4) + +def masking_key_list(masking_key): + return [utils.bytes_to_int(byte) for byte in masking_key] + +def create_client_handshake(host, port, key, version, resource): + """ + WebSockets connections are intiated by the client with a valid HTTP upgrade request + """ + headers = [ + ('Host', '%s:%s' % (host, port)), + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Key', key), + ('Sec-WebSocket-Version', version) + ] + request = "GET %s HTTP/1.1" % resource + return build_handshake(headers, request) + + +def create_server_handshake(key, magic = websockets_magic): + """ + The server response is a valid HTTP 101 response. + """ + digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) + headers = [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + ('Sec-WebSocket-Accept', digest) + ] + request = "HTTP/1.1 101 Switching Protocols" + return build_handshake(headers, request) + + +def build_handshake(headers, request): + handshake = [request.encode('utf-8')] + for header, value in headers: + handshake.append(("%s: %s" % (header, value)).encode('utf-8')) + handshake.append(b'\r\n') + return b'\r\n'.join(handshake) + + +def read_handshake(read_bytes, num_bytes_per_read): + """ + From provided function that reads bytes, read in a + complete HTTP request, which terminates with a CLRF + """ + response = b'' + doubleCLRF = b'\r\n\r\n' + while True: + bytes = read_bytes(num_bytes_per_read) + if not bytes: + break + response += bytes + if doubleCLRF in response: + break + return response + +def get_payload_length_pair(payload_bytestring): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is larger + than 125 + """ + actual_length = len(payload_bytestring) + + if actual_length <= 125: + length_code = actual_length + elif actual_length >= 126 and actual_length <= 65535: + length_code = 126 + else: + length_code = 127 + return (length_code, actual_length) + +def server_process_handshake(handshake): + headers = Message(StringIO(handshake.split('\r\n', 1)[1])) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Key'] + return key + +def generate_client_nounce(): + return b64encode(os.urandom(16)).decode('utf-8') + -- cgit v1.2.3 From 0edc04814e3affa71025938ac354707b9b4c481c Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 11:35:15 -0700 Subject: small cleanups, working on tests --- netlib/websockets/implementations.py | 10 +++++----- netlib/websockets/websockets.py | 35 +++++++++++++++++------------------ 2 files changed, 22 insertions(+), 23 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 78ae5be6..ff42ff65 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -26,8 +26,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = False) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() def handshake(self): @@ -47,7 +47,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) self.version = "13" - self.key = b64encode(os.urandom(16)).decode('utf-8') + self.key = ws.generate_client_nounce() self.resource = "/" def connect(self): @@ -76,6 +76,6 @@ class WebSocketsClient(tcp.TCPClient): self.close() def send_message(self, message): - frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True) - self.wfile.write(frame.to_bytes()) + frame = ws.WebSocketsFrame.default(message, from_client = True) + self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index b796ce39..527d55d6 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -65,7 +65,6 @@ class WebSocketsFrame(object): payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer - use_validation = True # indicates whether or not you care if this frame adheres to the spec ): self.fin = fin self.rsv1 = rsv1 @@ -78,21 +77,18 @@ class WebSocketsFrame(object): self.payload = payload self.decoded_payload = decoded_payload self.actual_payload_length = actual_payload_length - self.use_validation = use_validation - - if self.use_validation: - self.validate_frame() @classmethod def from_bytes(cls, bytestring): """ Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use read_frame() directly + to construct a frame from a stream of bytes, use from_byte_stream() directly """ self.from_byte_stream(io.BytesIO(bytestring).read) + @classmethod - def default_frame_from_message(cls, message, from_client = False): + def default(cls, message, from_client = False): """ Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. @@ -119,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def validate_frame(self): + def frame_is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -141,10 +137,11 @@ class WebSocketsFrame(object): assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: - apply_mask(self.payload, self.masking_key) == self.decoded_payload + assert apply_mask(self.payload, self.masking_key) == self.decoded_payload + return True except AssertionError: - raise WebSocketFrameValidationException() + return False def human_readable(self): return "\n".join([ @@ -161,15 +158,19 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length)), ("use_validation - " + str(self.use_validation))]) + def safe_to_bytes(self): + try: + assert self.frame_is_valid() + return self.to_bytes() + except: + raise WebSocketFrameValidationException() + def to_bytes(self): """ Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees that the + serialized bytes will be correct. see safe_to_bytes() """ - # validate enforces all the assumptions made by this serializer - # in the spritit of mitmproxy, it's possible to create and serialize invalid frames - # by skipping validation. - if self.use_validation: - self.validate_frame() max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -198,6 +199,7 @@ class WebSocketsFrame(object): pass elif self.actual_payload_length < max_16_bit_int: + # '!H' pack as 16 bit unsigned short bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length @@ -284,9 +286,6 @@ def apply_mask(message, masking_key): def random_masking_key(): return os.urandom(4) -def masking_key_list(masking_key): - return [utils.bytes_to_int(byte) for byte in masking_key] - def create_client_handshake(host, port, key, version, resource): """ WebSockets connections are intiated by the client with a valid HTTP upgrade request -- cgit v1.2.3 From 73ce169e3d11eeabeb78143bd86edfdbc3e07fd9 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 10:26:09 +1200 Subject: Initial outline of a cookie parsing and serialization module. --- netlib/http_cookies.py | 133 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 133 insertions(+) create mode 100644 netlib/http_cookies.py (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py new file mode 100644 index 00000000..e11e0f90 --- /dev/null +++ b/netlib/http_cookies.py @@ -0,0 +1,133 @@ +""" +A flexible module for cookie parsing and manipulation. + +We try to be as permissive as possible. Parsing accepts formats from RFC6265 an +RFC2109. Serialization follows RFC6265 strictly. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 +""" + +import re + +import odict + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start+1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i+1], i+1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. + """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start+1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + pass + else: + ret.append(s[i]) + return "".join(ret), i+1 + + +def _read_value(s, start): + """ + Reads a value - the RHS of a token/value pair in a cookie. + """ + if s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, ";,") + + +def _read_pairs(s): + """ + Read pairs of lhs=rhs values. + """ + off = 0 + vals = [] + while 1: + lhs, off = _read_token(s, off) + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off+1) + vals.append([lhs.lstrip(), rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +ESCAPE = re.compile(r"([\"\\])") +SPECIAL = re.compile(r"^\w+$") + + +def _format_pairs(lst): + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + match = SPECIAL.search(v) + if match: + v = ESCAPE.sub(r"\1", v) + vals.append("%s=%s"%(k, v)) + return "; ".join(vals) + + +def parse_cookies(s): + """ + Parses a Cookie header value. + Returns an ODict object. + """ + pairs, off = _read_pairs(s) + return odict.ODict(pairs) + + +def unparse_cookies(od): + """ + Formats a Cookie header value. + """ + vals = [] + for i in od.lst: + vals.append("%s=%s"%(i[0], i[1])) + return "; ".join(vals) + + + +def parse_set_cookies(s): + start = 0 + + +def unparse_set_cookies(s): + pass -- cgit v1.2.3 From 2630da7263242411d413b5e4b2c520d29848c918 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Sun, 12 Apr 2015 11:26:02 +1200 Subject: cookies: Cater for special values, fix some bugs found in real-world testing --- netlib/http_cookies.py | 48 +++++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 15 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index e11e0f90..82675418 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -59,29 +59,39 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start): +def _read_value(s, start, special): """ Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. """ - if s[start] == '"': + if start >= len(s): + return "", start + elif s[start] == '"': return _read_quoted_string(s, start) + elif special: + return _read_until(s, start, ";") else: return _read_until(s, start, ";,") -def _read_pairs(s): +def _read_pairs(s, specials=()): """ Read pairs of lhs=rhs values. + + specials: A lower-cased list of keys that may contain commas. """ off = 0 vals = [] while 1: lhs, off = _read_token(s, off) + lhs = lhs.lstrip() rhs = None if off < len(s): if s[off] == "=": - rhs, off = _read_value(s, off+1) - vals.append([lhs.lstrip(), rhs]) + rhs, off = _read_value(s, off+1, lhs.lower() in specials) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break @@ -89,18 +99,30 @@ def _read_pairs(s): ESCAPE = re.compile(r"([\"\\])") -SPECIAL = re.compile(r"^\w+$") -def _format_pairs(lst): +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +def _format_pairs(lst, specials=()): + """ + specials: A lower-cased list of keys that will not be quoted. + """ vals = [] for k, v in lst: if v is None: vals.append(k) else: - match = SPECIAL.search(v) - if match: - v = ESCAPE.sub(r"\1", v) + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"'%v vals.append("%s=%s"%(k, v)) return "; ".join(vals) @@ -118,11 +140,7 @@ def unparse_cookies(od): """ Formats a Cookie header value. """ - vals = [] - for i in od.lst: - vals.append("%s=%s"%(i[0], i[1])) - return "; ".join(vals) - + return _format_pairs(od.lst) def parse_set_cookies(s): -- cgit v1.2.3 From f131f9b855e77554072415c925ed112ec74ee48a Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Sat, 11 Apr 2015 15:40:18 -0700 Subject: handshake tests, serialization test --- netlib/websockets/implementations.py | 19 +++++++++----- netlib/websockets/websockets.py | 51 ++++++++++++++++++++++++++---------- 2 files changed, 49 insertions(+), 21 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index ff42ff65..73a84690 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -32,7 +32,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) - key = ws.server_process_handshake(client_hs) + key = ws.process_handshake_from_client(client_hs) response = ws.create_server_handshake(key) self.wfile.write(response) self.wfile.flush() @@ -46,9 +46,9 @@ class WebSocketsEchoHandler(tcp.BaseHandler): class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) - self.version = "13" - self.key = ws.generate_client_nounce() - self.resource = "/" + self.version = "13" + self.client_nounce = ws.create_client_nounce() + self.resource = "/" def connect(self): super(WebSocketsClient, self).connect() @@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient): handshake = ws.create_client_handshake( self.address.host, self.address.port, - self.key, + self.client_nounce, self.version, self.resource ) @@ -64,9 +64,14 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.write(handshake) self.wfile.flush() - response = ws.read_handshake(self.rfile.read, 1) + server_handshake = ws.read_handshake(self.rfile.read, 1) - if not response: + if not server_handshake: + self.close() + + server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) + + if not server_nounce == ws.create_server_nounce(self.client_nounce): self.close() def read_next_message(self): diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 527d55d6..cf9a68aa 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -84,7 +84,7 @@ class WebSocketsFrame(object): Construct a websocket frame from an in-memory bytestring to construct a frame from a stream of bytes, use from_byte_stream() directly """ - self.from_byte_stream(io.BytesIO(bytestring).read) + return cls.from_byte_stream(io.BytesIO(bytestring).read) @classmethod @@ -115,7 +115,7 @@ class WebSocketsFrame(object): actual_payload_length = actual_length ) - def frame_is_valid(self): + def is_valid(self): """ Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame has not been corrupted. @@ -155,12 +155,11 @@ class WebSocketsFrame(object): ("masking_key - " + str(self.masking_key)), ("payload - " + str(self.payload)), ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length)), - ("use_validation - " + str(self.use_validation))]) + ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): try: - assert self.frame_is_valid() + assert self.is_valid() return self.to_bytes() except: raise WebSocketFrameValidationException() @@ -197,7 +196,7 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - + elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short @@ -267,6 +266,20 @@ class WebSocketsFrame(object): actual_payload_length = actual_payload_length ) + def __eq__(self, other): + return ( + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length) + def apply_mask(message, masking_key): """ Data sent from the server must be masked to prevent malicious clients @@ -300,16 +313,14 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) - -def create_server_handshake(key, magic = websockets_magic): +def create_server_handshake(key): """ The server response is a valid HTTP 101 response. """ - digest = b64encode(sha1(key + magic).hexdigest().decode('hex')) headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', digest) + ('Sec-WebSocket-Accept', create_server_nounce(key)) ] request = "HTTP/1.1 101 Switching Protocols" return build_handshake(headers, request) @@ -322,7 +333,6 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) - def read_handshake(read_bytes, num_bytes_per_read): """ From provided function that reads bytes, read in a @@ -355,13 +365,26 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) -def server_process_handshake(handshake): - headers = Message(StringIO(handshake.split('\r\n', 1)[1])) +def process_handshake_from_client(handshake): + headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": return key = headers['Sec-WebSocket-Key'] return key -def generate_client_nounce(): +def process_handshake_from_server(handshake, client_nounce): + headers = headers_from_http_message(handshake) + if headers.get("Upgrade", None) != "websocket": + return + key = headers['Sec-WebSocket-Accept'] + return key + +def headers_from_http_message(http_message): + return Message(StringIO(http_message.split('\r\n', 1)[1])) + +def create_server_nounce(client_nounce): + return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + +def create_client_nounce(): return b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 2d72a1b6b56f1643cd1d8be59eee55aa7ca2f17f Mon Sep 17 00:00:00 2001 From: Chandler Abraham Date: Mon, 13 Apr 2015 13:36:09 -0700 Subject: 100% test coverage, though still need plenty more --- netlib/http.py | 14 -------------- netlib/websockets/implementations.py | 10 ++-------- netlib/websockets/websockets.py | 9 ++++----- 3 files changed, 6 insertions(+), 27 deletions(-) (limited to 'netlib') diff --git a/netlib/http.py b/netlib/http.py index 2c72621d..26438863 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -29,20 +29,6 @@ def _is_valid_host(host): return None return True -def is_successful_upgrade(request, response): - """ - determines if a client and server successfully agreed to an HTTP protocol upgrade - - https://developer.mozilla.org/en-US/docs/Web/HTTP/Protocol_upgrade_mechanism - """ - http_switching_protocols_code = 101 - - if request and response: - responseUpgrade = request.headers.get("Upgrade") - requestUpgrade = response.headers.get("Upgrade") - if response.code == http_switching_protocols_code and responseUpgrade == requestUpgrade: - return requestUpgrade[0] if len(requestUpgrade) > 0 else None - return None def parse_url(url): """ diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 73a84690..1ded3b85 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -65,9 +65,6 @@ class WebSocketsClient(tcp.TCPClient): self.wfile.flush() server_handshake = ws.read_handshake(self.rfile.read, 1) - - if not server_handshake: - self.close() server_nounce = ws.process_handshake_from_server(server_handshake, self.client_nounce) @@ -75,11 +72,8 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - try: - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload - except IndexError: - self.close() - + return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + def send_message(self, message): frame = ws.WebSocketsFrame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index cf9a68aa..ea3db21d 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -158,11 +158,10 @@ class WebSocketsFrame(object): ("actual_payload_length - " + str(self.actual_payload_length))]) def safe_to_bytes(self): - try: - assert self.is_valid() - return self.to_bytes() - except: - raise WebSocketFrameValidationException() + if self.is_valid(): + return self.to_bytes() + else: + raise WebSocketFrameValidationException() def to_bytes(self): """ -- cgit v1.2.3 From de9e7411253c4f67ea4d0b96f6f9e952024c5fa3 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:02:10 +1200 Subject: Firm up cookie parsing and formatting API Make a tough call: we won't support old-style comma-separated set-cookie headers. Real world testing has shown that the latest rfc (6265) is often violated in ways that make the parsing problem indeterminate. Since this is much more common than the old style deprecated set-cookie variant, we focus on the most useful case. --- netlib/http_cookies.py | 112 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 83 insertions(+), 29 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 82675418..a1f240f5 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -1,13 +1,27 @@ """ A flexible module for cookie parsing and manipulation. -We try to be as permissive as possible. Parsing accepts formats from RFC6265 an -RFC2109. Serialization follows RFC6265 strictly. +This module differs from usual standards-compliant cookie modules in a number of +ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple cookies +to be set in a single header. Technically this should be feasible, but it turns +out that violations of RFC6265 that makes the parsing problem indeterminate are +much more common than genuine occurences of the multi-cookie variants. +Serialization follows RFC6265. http://tools.ietf.org/html/rfc6265 http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 """ +# TODO +# - Disallow LHS-only Cookie values + import re import odict @@ -59,7 +73,7 @@ def _read_quoted_string(s, start): return "".join(ret), i+1 -def _read_value(s, start, special): +def _read_value(s, start, delims): """ Reads a value - the RHS of a token/value pair in a cookie. @@ -70,37 +84,41 @@ def _read_value(s, start, special): return "", start elif s[start] == '"': return _read_quoted_string(s, start) - elif special: - return _read_until(s, start, ";") else: - return _read_until(s, start, ";,") + return _read_until(s, start, delims) -def _read_pairs(s, specials=()): +def _read_pairs(s, off=0, term=None, specials=()): """ Read pairs of lhs=rhs values. - specials: A lower-cased list of keys that may contain commas. + off: start offset + term: if True, treat a comma as a terminator for the pairs lists + specials: a lower-cased list of keys that may contain commas if term is + True """ - off = 0 vals = [] while 1: lhs, off = _read_token(s, off) lhs = lhs.lstrip() - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off+1, lhs.lower() in specials) - vals.append([lhs, rhs]) + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + if term and lhs.lower() not in specials: + delims = ";," + else: + delims = ";" + rhs, off = _read_value(s, off+1, delims) + vals.append([lhs, rhs]) off += 1 if not off < len(s): break + if term and s[off-1] == ",": + break return vals, off -ESCAPE = re.compile(r"([\"\\])") - - def _has_special(s): for i in s: if i in '",;\\': @@ -111,6 +129,9 @@ def _has_special(s): return False +ESCAPE = re.compile(r"([\"\\])") + + def _format_pairs(lst, specials=()): """ specials: A lower-cased list of keys that will not be quoted. @@ -127,25 +148,58 @@ def _format_pairs(lst, specials=()): return "; ".join(vals) -def parse_cookies(s): +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials = ("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): """ - Parses a Cookie header value. - Returns an ODict object. + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. """ - pairs, off = _read_pairs(s) - return odict.ODict(pairs) + pairs, off = _read_pairs( + s, + specials = ("expires", "path") + ) + return pairs -def unparse_cookies(od): +def parse_set_cookie_header(str): """ - Formats a Cookie header value. + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. """ - return _format_pairs(od.lst) + pairs = _parse_set_cookie_pairs(str) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) -def parse_set_cookies(s): - start = 0 +def parse_cookie_header(str): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off = _read_pairs(str) + return odict.ODict(pairs) -def unparse_set_cookies(s): - pass +def format_cookie_header(od): + """ + Formats a Cookie header value. + """ + return _format_pairs(od.lst) -- cgit v1.2.3 From 6db5e0a4a133e6e6150f9cab87cd56b40d6db0b2 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 10:13:03 +1200 Subject: Remove old-style set-cookie cruft, unit tests to 100% --- netlib/http_cookies.py | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index a1f240f5..297efb80 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -88,14 +88,12 @@ def _read_value(s, start, delims): return _read_until(s, start, delims) -def _read_pairs(s, off=0, term=None, specials=()): +def _read_pairs(s, off=0, specials=()): """ Read pairs of lhs=rhs values. off: start offset - term: if True, treat a comma as a terminator for the pairs lists - specials: a lower-cased list of keys that may contain commas if term is - True + specials: a lower-cased list of keys that may contain commas """ vals = [] while 1: @@ -105,17 +103,11 @@ def _read_pairs(s, off=0, term=None, specials=()): rhs = None if off < len(s): if s[off] == "=": - if term and lhs.lower() not in specials: - delims = ";," - else: - delims = ";" - rhs, off = _read_value(s, off+1, delims) + rhs, off = _read_value(s, off+1, ";") vals.append([lhs, rhs]) off += 1 if not off < len(s): break - if term and s[off-1] == ",": - break return vals, off -- cgit v1.2.3 From d739882bf2dc65925c001c5bf848f5664640d299 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 13:50:57 +1200 Subject: Add an .extend method for ODicts --- netlib/odict.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 7a2f611b..7a54f282 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -108,6 +108,12 @@ class ODict(object): lst = copy.deepcopy(self.lst) return self.__class__(lst) + def extend(self, other): + """ + Add the contents of other, preserving any duplicates. + """ + self.lst.extend(other.lst) + def __repr__(self): elements = [] for itm in self.lst: -- cgit v1.2.3 From aeebf31927eb3ff74824525005c7b146024de6d5 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Tue, 14 Apr 2015 16:20:02 +1200 Subject: odict: don't convert values to strings when added --- netlib/odict.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index 7a54f282..a0ea9e53 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -84,7 +84,7 @@ class ODict(object): return False def add(self, key, value): - self.lst.append([key, str(value)]) + self.lst.append([key, value]) def get(self, k, d=None): if k in self: @@ -117,7 +117,7 @@ class ODict(object): def __repr__(self): elements = [] for itm in self.lst: - elements.append(itm[0] + ": " + itm[1]) + elements.append(itm[0] + ": " + str(itm[1])) elements.append("") return "\r\n".join(elements) -- cgit v1.2.3 From 0c85c72dc43d0d017e2bf5af9c2def46968d0499 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 15 Apr 2015 10:28:17 +1200 Subject: ODict improvements - Setting values now tries to preserve the existing order, rather than just appending to the end. - __repr__ now returns a repr of the tuple list. The old repr becomes a .format() method. This is clearer, makes troubleshooting easier, and doesn't assume all data in ODicts are header-like --- netlib/odict.py | 25 +++++++++++++++++++------ netlib/wsgi.py | 29 ++++++++++++++++++----------- 2 files changed, 37 insertions(+), 17 deletions(-) (limited to 'netlib') diff --git a/netlib/odict.py b/netlib/odict.py index a0ea9e53..dd738c55 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -13,7 +13,8 @@ def safe_subn(pattern, repl, target, *args, **kwargs): class ODict(object): """ - A dictionary-like object for managing ordered (key, value) data. + A dictionary-like object for managing ordered (key, value) data. Think + about it as a convenient interface to a list of (key, value) tuples. """ def __init__(self, lst=None): self.lst = lst or [] @@ -64,11 +65,20 @@ class ODict(object): key, they are cleared. """ if isinstance(valuelist, basestring): - raise ValueError("Expected list of values instead of string. Example: odict['Host'] = ['www.example.com']") - - new = self._filter_lst(k, self.lst) - for i in valuelist: - new.append([k, i]) + raise ValueError( + "Expected list of values instead of string. " + "Example: odict['Host'] = ['www.example.com']" + ) + kc = self._kconv(k) + new = [] + for i in self.lst: + if self._kconv(i[0]) == kc: + if valuelist: + new.append([k, valuelist.pop(0)]) + else: + new.append(i) + while valuelist: + new.append([k, valuelist.pop(0)]) self.lst = new def __delitem__(self, k): @@ -115,6 +125,9 @@ class ODict(object): self.lst.extend(other.lst) def __repr__(self): + return repr(self.lst) + + def format(self): elements = [] for itm in self.lst: elements.append(itm[0] + ": " + str(itm[1])) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index bac27d5a..1b979608 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -1,5 +1,8 @@ from __future__ import (absolute_import, print_function, division) -import cStringIO, urllib, time, traceback +import cStringIO +import urllib +import time +import traceback from . import odict, tcp @@ -23,15 +26,18 @@ class Request(object): def date_time_string(): """Return the current date and time formatted for a message header.""" WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] - MONTHS = [None, - 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', - 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] + MONTHS = [ + None, + 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', + 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec' + ] now = time.time() year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( - WEEKS[wd], - day, MONTHS[month], year, - hh, mm, ss) + WEEKS[wd], + day, MONTHS[month], year, + hh, mm, ss + ) return s @@ -100,6 +106,7 @@ class WSGIAdaptor(object): status = None, headers = None ) + def write(data): if not state["headers_sent"]: soc.write("HTTP/1.1 %s\r\n"%state["status"]) @@ -108,7 +115,7 @@ class WSGIAdaptor(object): h["Server"] = [self.sversion] if 'date' not in h: h["Date"] = [date_time_string()] - soc.write(str(h)) + soc.write(h.format()) soc.write("\r\n") state["headers_sent"] = True if data: @@ -130,7 +137,9 @@ class WSGIAdaptor(object): errs = cStringIO.StringIO() try: - dataiter = self.app(self.make_environ(request, errs, **env), start_response) + dataiter = self.app( + self.make_environ(request, errs, **env), start_response + ) for i in dataiter: write(i) if not state["headers_sent"]: @@ -143,5 +152,3 @@ class WSGIAdaptor(object): except Exception: # pragma: no cover pass return errs.getvalue() - - -- cgit v1.2.3 From c53d89fd7fad6c46458ab3d0140528e344de605f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Thu, 16 Apr 2015 08:30:54 +1200 Subject: Improve flexibility of http_cookies._format_pairs --- netlib/http_cookies.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'netlib') diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 297efb80..dab95ed0 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -124,7 +124,7 @@ def _has_special(s): ESCAPE = re.compile(r"([\"\\])") -def _format_pairs(lst, specials=()): +def _format_pairs(lst, specials=(), sep="; "): """ specials: A lower-cased list of keys that will not be quoted. """ @@ -137,7 +137,7 @@ def _format_pairs(lst, specials=()): v = ESCAPE.sub(r"\\\1", v) v = '"%s"'%v vals.append("%s=%s"%(k, v)) - return "; ".join(vals) + return sep.join(vals) def _format_set_cookie_pairs(lst): -- cgit v1.2.3 From 488c25d812a321f5a03253b62ab33b61ecc13de1 Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 13:57:39 +1200 Subject: websockets: whitespace, PEP8 --- netlib/websockets/websockets.py | 169 +++++++++++++++++++++++----------------- 1 file changed, 96 insertions(+), 73 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index ea3db21d..8782ea49 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -1,31 +1,34 @@ from __future__ import absolute_import -from base64 import b64encode -from hashlib import sha1 -from mimetools import Message -from netlib import tcp -from netlib import utils -from StringIO import StringIO +import base64 +import hashlib +import mimetools +import StringIO import os -import SocketServer import struct import io -# Colleciton of utility functions that implement small portions of the RFC6455 WebSockets Protocol -# Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or completeness +from .. import utils + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. # -# This is a work in progress and does not yet contain all the utilites need to create fully complient client/servers +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness # +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # # Spec: https://tools.ietf.org/html/rfc6455 -# The magic sha that websocket servers must know to prove they understand RFC6455 +# The magic sha that websocket servers must know to prove they understand +# RFC6455 websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' + class WebSocketFrameValidationException(Exception): pass + class WebSocketsFrame(object): """ Represents one websockets frame. @@ -33,7 +36,7 @@ class WebSocketsFrame(object): from_bytes() is also avaliable. WebSockets Frame as defined in RFC6455 - + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-------+-+-------------+-------------------------------+ |F|R|R|R| opcode|M| Payload len | Extended payload length | @@ -62,7 +65,7 @@ class WebSocketsFrame(object): rsv1 = 0, # decimal integer 1 or 0 rsv2 = 0, # decimal integer 1 or 0 rsv3 = 0, # decimal integer 1 or 0 - payload = None, # bytestring + payload = None, # bytestring masking_key = None, # 32 bit byte string actual_payload_length = None, # any decimal integer ): @@ -81,18 +84,17 @@ class WebSocketsFrame(object): @classmethod def from_bytes(cls, bytestring): """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_byte_stream() directly - """ + Construct a websocket frame from an in-memory bytestring to construct + a frame from a stream of bytes, use from_byte_stream() directly + """ return cls.from_byte_stream(io.BytesIO(bytestring).read) - @classmethod def default(cls, message, from_client = False): """ - Construct a basic websocket frame from some default values. + Construct a basic websocket frame from some default values. Creates a non-fragmented text frame. - """ + """ length_code, actual_length = get_payload_length_pair(message) if from_client: @@ -103,7 +105,7 @@ class WebSocketsFrame(object): mask_bit = 0 masking_key = None payload = message - + return cls( fin = 1, # final frame opcode = 1, # text @@ -117,10 +119,10 @@ class WebSocketsFrame(object): def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame - has not been corrupted. - """ - try: + Validate websocket frame invariants, call at anytime to ensure the + WebSocketsFrame has not been corrupted. + """ + try: assert 0 <= self.fin <= 1 assert 0 <= self.rsv1 <= 1 assert 0 <= self.rsv2 <= 1 @@ -128,18 +130,18 @@ class WebSocketsFrame(object): assert 1 <= self.opcode <= 4 assert 0 <= self.mask_bit <= 1 assert 1 <= self.payload_length_code <= 127 - + if self.mask_bit == 1: assert 1 <= len(self.masking_key) <= 4 else: - assert self.masking_key == None - + assert self.masking_key is None + assert self.actual_payload_length == len(self.payload) if self.payload is not None and self.masking_key is not None: assert apply_mask(self.payload, self.masking_key) == self.decoded_payload - return True + return True except AssertionError: return False @@ -165,30 +167,32 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring - If you haven't checked is_valid_frame() then there's no guarentees that the - serialized bytes will be correct. see safe_to_bytes() - """ + Serialize the frame back into the wire format, returns a bytestring If + you haven't checked is_valid_frame() then there's no guarentees that + the serialized bytes will be correct. see safe_to_bytes() + """ max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) - # break down of the bit-math used to construct the first byte from the frame's integer values - # first shift the significant bit into the correct position + # break down of the bit-math used to construct the first byte from the + # frame's integer values first shift the significant bit into the + # correct position # 00000001 << 7 = 10000000 # ... # then combine: - # + # # 10000000 fin # 01000000 res1 # 00100000 res2 # 00010000 res3 # 00000001 opcode - # -------- OR + # -------- OR # 11110001 = first_byte - first_byte = (self.fin << 7) | (self.rsv1 << 6) | (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode - + first_byte = (self.fin << 7) | (self.rsv1 << 6) |\ + (self.rsv2 << 4) | (self.rsv3 << 4) | self.opcode + second_byte = (self.mask_bit << 7) | self.payload_length_code bytes = chr(first_byte) + chr(second_byte) @@ -199,11 +203,13 @@ class WebSocketsFrame(object): elif self.actual_payload_length < max_16_bit_int: # '!H' pack as 16 bit unsigned short - bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length - + # add 2 byte extended payload length + bytes += struct.pack('!H', self.actual_payload_length) + elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long - bytes += struct.pack('!Q', self.actual_payload_length) # add 8 bytes extended payload length + # add 8 bytes extended payload length + bytes += struct.pack('!Q', self.actual_payload_length) if self.masking_key is not None: bytes += self.masking_key @@ -212,43 +218,46 @@ class WebSocketsFrame(object): return bytes - @classmethod def from_byte_stream(cls, read_bytes): """ read a websockets frame sent by a server or client - + read_bytes is a function that can be backed - by sockets or by any byte reader. So this + by sockets or by any byte reader. So this function may be used to read frames from disk/wire/memory - """ - first_byte = utils.bytes_to_int(read_bytes(1)) + """ + first_byte = utils.bytes_to_int(read_bytes(1)) second_byte = utils.bytes_to_int(read_bytes(1)) - - fin = first_byte >> 7 # grab the left most bit - opcode = first_byte & 15 # grab right most 4 bits by and-ing with 00001111 - mask_bit = second_byte >> 7 # grab left most bit - payload_length = second_byte & 127 # grab the next 7 bits + + # grab the left most bit + fin = first_byte >> 7 + # grab right most 4 bits by and-ing with 00001111 + opcode = first_byte & 15 + # grab left most bit + mask_bit = second_byte >> 7 + # grab the next 7 bits + payload_length = second_byte & 127 # payload_lengthy > 125 indicates you need to read more bytes # to get the actual payload length if payload_length <= 125: - actual_payload_length = payload_length + actual_payload_length = payload_length elif payload_length == 126: - actual_payload_length = utils.bytes_to_int(read_bytes(2)) + actual_payload_length = utils.bytes_to_int(read_bytes(2)) - elif payload_length == 127: - actual_payload_length = utils.bytes_to_int(read_bytes(8)) + elif payload_length == 127: + actual_payload_length = utils.bytes_to_int(read_bytes(8)) # masking key only present if mask bit set if mask_bit == 1: masking_key = read_bytes(4) else: masking_key = None - + payload = read_bytes(actual_payload_length) - + if mask_bit == 1: decoded_payload = apply_mask(payload, masking_key) else: @@ -295,12 +304,15 @@ def apply_mask(message, masking_key): result += chr(ord(char) ^ masks[len(result) % 4]) return result + def random_masking_key(): return os.urandom(4) + def create_client_handshake(host, port, key, version, resource): """ - WebSockets connections are intiated by the client with a valid HTTP upgrade request + WebSockets connections are intiated by the client with a valid HTTP + upgrade request """ headers = [ ('Host', '%s:%s' % (host, port)), @@ -312,10 +324,11 @@ def create_client_handshake(host, port, key, version, resource): request = "GET %s HTTP/1.1" % resource return build_handshake(headers, request) + def create_server_handshake(key): """ - The server response is a valid HTTP 101 response. - """ + The server response is a valid HTTP 101 response. + """ headers = [ ('Connection', 'Upgrade'), ('Upgrade', 'websocket'), @@ -332,12 +345,13 @@ def build_handshake(headers, request): handshake.append(b'\r\n') return b'\r\n'.join(handshake) + def read_handshake(read_bytes, num_bytes_per_read): """ - From provided function that reads bytes, read in a + From provided function that reads bytes, read in a complete HTTP request, which terminates with a CLRF - """ - response = b'' + """ + response = b'' doubleCLRF = b'\r\n\r\n' while True: bytes = read_bytes(num_bytes_per_read) @@ -348,14 +362,15 @@ def read_handshake(read_bytes, num_bytes_per_read): break return response + def get_payload_length_pair(payload_bytestring): """ A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is larger - than 125 - """ + extended length code to represent the actual length if length code is + larger than 125 + """ actual_length = len(payload_bytestring) - + if actual_length <= 125: length_code = actual_length elif actual_length >= 126 and actual_length <= 65535: @@ -364,6 +379,7 @@ def get_payload_length_pair(payload_bytestring): length_code = 127 return (length_code, actual_length) + def process_handshake_from_client(handshake): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -371,6 +387,7 @@ def process_handshake_from_client(handshake): key = headers['Sec-WebSocket-Key'] return key + def process_handshake_from_server(handshake, client_nounce): headers = headers_from_http_message(handshake) if headers.get("Upgrade", None) != "websocket": @@ -378,12 +395,18 @@ def process_handshake_from_server(handshake, client_nounce): key = headers['Sec-WebSocket-Accept'] return key + def headers_from_http_message(http_message): - return Message(StringIO(http_message.split('\r\n', 1)[1])) + return mimetools.Message( + StringIO.StringIO(http_message.split('\r\n', 1)[1]) + ) + def create_server_nounce(client_nounce): - return b64encode(sha1(client_nounce + websockets_magic).hexdigest().decode('hex')) + return base64.b64encode( + hashlib.sha1(client_nounce + websockets_magic).hexdigest().decode('hex') + ) -def create_client_nounce(): - return b64encode(os.urandom(16)).decode('utf-8') +def create_client_nounce(): + return base64.b64encode(os.urandom(16)).decode('utf-8') -- cgit v1.2.3 From 7defb5be862a4251da9d7c530593f7e9be3e739e Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Fri, 17 Apr 2015 14:29:20 +1200 Subject: websockets: more whitespace, WebSocketFrame -> Frame --- netlib/websockets/implementations.py | 12 ++--- netlib/websockets/websockets.py | 100 +++++++++++++++++------------------ 2 files changed, 55 insertions(+), 57 deletions(-) (limited to 'netlib') diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py index 1ded3b85..337c5496 100644 --- a/netlib/websockets/implementations.py +++ b/netlib/websockets/implementations.py @@ -9,7 +9,7 @@ import os # Simple websocket client and servers that are used to exercise the functionality in websockets.py # These are *not* fully RFC6455 compliant -class WebSocketsEchoHandler(tcp.BaseHandler): +class WebSocketsEchoHandler(tcp.BaseHandler): def __init__(self, connection, address, server): super(WebSocketsEchoHandler, self).__init__(connection, address, server) self.handshake_done = False @@ -22,14 +22,14 @@ class WebSocketsEchoHandler(tcp.BaseHandler): self.read_next_message() def read_next_message(self): - decoded = ws.WebSocketsFrame.from_byte_stream(self.rfile.read).decoded_payload + decoded = ws.Frame.from_byte_stream(self.rfile.read).decoded_payload self.on_message(decoded) def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = False) + frame = ws.Frame.default(message, from_client = False) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() - + def handshake(self): client_hs = ws.read_handshake(self.rfile.read, 1) key = ws.process_handshake_from_client(client_hs) @@ -72,9 +72,9 @@ class WebSocketsClient(tcp.TCPClient): self.close() def read_next_message(self): - return ws.WebSocketsFrame.from_byte_stream(self.rfile.read).payload + return ws.Frame.from_byte_stream(self.rfile.read).payload def send_message(self, message): - frame = ws.WebSocketsFrame.default(message, from_client = True) + frame = ws.Frame.default(message, from_client = True) self.wfile.write(frame.safe_to_bytes()) self.wfile.flush() diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py index 8782ea49..86d98caf 100644 --- a/netlib/websockets/websockets.py +++ b/netlib/websockets/websockets.py @@ -29,7 +29,7 @@ class WebSocketFrameValidationException(Exception): pass -class WebSocketsFrame(object): +class Frame(object): """ Represents one websockets frame. Constructor takes human readable forms of the frame components @@ -98,29 +98,29 @@ class WebSocketsFrame(object): length_code, actual_length = get_payload_length_pair(message) if from_client: - mask_bit = 1 + mask_bit = 1 masking_key = random_masking_key() - payload = apply_mask(message, masking_key) + payload = apply_mask(message, masking_key) else: - mask_bit = 0 + mask_bit = 0 masking_key = None - payload = message + payload = message return cls( - fin = 1, # final frame - opcode = 1, # text - mask_bit = mask_bit, - payload_length_code = length_code, - payload = payload, - masking_key = masking_key, - decoded_payload = message, + fin = 1, # final frame + opcode = 1, # text + mask_bit = mask_bit, + payload_length_code = length_code, + payload = payload, + masking_key = masking_key, + decoded_payload = message, actual_payload_length = actual_length ) def is_valid(self): """ - Validate websocket frame invariants, call at anytime to ensure the - WebSocketsFrame has not been corrupted. + Validate websocket frame invariants, call at anytime to ensure the + Frame has not been corrupted. """ try: assert 0 <= self.fin <= 1 @@ -147,17 +147,18 @@ class WebSocketsFrame(object): def human_readable(self): return "\n".join([ - ("fin - " + str(self.fin)), - ("rsv1 - " + str(self.rsv1)), - ("rsv2 - " + str(self.rsv2)), - ("rsv3 - " + str(self.rsv3)), - ("opcode - " + str(self.opcode)), - ("mask_bit - " + str(self.mask_bit)), - ("payload_length_code - " + str(self.payload_length_code)), - ("masking_key - " + str(self.masking_key)), - ("payload - " + str(self.payload)), - ("decoded_payload - " + str(self.decoded_payload)), - ("actual_payload_length - " + str(self.actual_payload_length))]) + ("fin - " + str(self.fin)), + ("rsv1 - " + str(self.rsv1)), + ("rsv2 - " + str(self.rsv2)), + ("rsv3 - " + str(self.rsv3)), + ("opcode - " + str(self.opcode)), + ("mask_bit - " + str(self.mask_bit)), + ("payload_length_code - " + str(self.payload_length_code)), + ("masking_key - " + str(self.masking_key)), + ("payload - " + str(self.payload)), + ("decoded_payload - " + str(self.decoded_payload)), + ("actual_payload_length - " + str(self.actual_payload_length)) + ]) def safe_to_bytes(self): if self.is_valid(): @@ -167,11 +168,10 @@ class WebSocketsFrame(object): def to_bytes(self): """ - Serialize the frame back into the wire format, returns a bytestring If - you haven't checked is_valid_frame() then there's no guarentees that - the serialized bytes will be correct. see safe_to_bytes() + Serialize the frame back into the wire format, returns a bytestring + If you haven't checked is_valid_frame() then there's no guarentees + that the serialized bytes will be correct. see safe_to_bytes() """ - max_16_bit_int = (1 << 16) max_64_bit_int = (1 << 63) @@ -199,13 +199,10 @@ class WebSocketsFrame(object): if self.actual_payload_length < 126: pass - elif self.actual_payload_length < max_16_bit_int: - # '!H' pack as 16 bit unsigned short # add 2 byte extended payload length bytes += struct.pack('!H', self.actual_payload_length) - elif self.actual_payload_length < max_64_bit_int: # '!Q' = pack as 64 bit unsigned long long # add 8 bytes extended payload length @@ -215,7 +212,6 @@ class WebSocketsFrame(object): bytes += self.masking_key bytes += self.payload # already will be encoded if neccessary - return bytes @classmethod @@ -264,29 +260,31 @@ class WebSocketsFrame(object): decoded_payload = payload return cls( - fin = fin, - opcode = opcode, - mask_bit = mask_bit, - payload_length_code = payload_length, - payload = payload, - masking_key = masking_key, - decoded_payload = decoded_payload, + fin = fin, + opcode = opcode, + mask_bit = mask_bit, + payload_length_code = payload_length, + payload = payload, + masking_key = masking_key, + decoded_payload = decoded_payload, actual_payload_length = actual_payload_length ) def __eq__(self, other): return ( - self.fin == other.fin and - self.rsv1 == other.rsv1 and - self.rsv2 == other.rsv2 and - self.rsv3 == other.rsv3 and - self.opcode == other.opcode and - self.mask_bit == other.mask_bit and - self.payload_length_code == other.payload_length_code and - self.masking_key == other.masking_key and - self.payload == other.payload and - self.decoded_payload == other.decoded_payload and - self.actual_payload_length == other.actual_payload_length) + self.fin == other.fin and + self.rsv1 == other.rsv1 and + self.rsv2 == other.rsv2 and + self.rsv3 == other.rsv3 and + self.opcode == other.opcode and + self.mask_bit == other.mask_bit and + self.payload_length_code == other.payload_length_code and + self.masking_key == other.masking_key and + self.payload == other.payload and + self.decoded_payload == other.decoded_payload and + self.actual_payload_length == other.actual_payload_length + ) + def apply_mask(message, masking_key): """ -- cgit v1.2.3