From 6dcfc35011208f4bfde7f37a63d7b980f6c41ce0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 8 Jul 2015 09:20:25 +0200 Subject: introduce http_semantics module used for generic HTTP representation everything should apply for HTTP/1 and HTTP/2 --- netlib/http.py | 16 ++-------------- netlib/http_semantics.py | 23 +++++++++++++++++++++++ test/test_http.py | 14 +++++++------- 3 files changed, 32 insertions(+), 21 deletions(-) create mode 100644 netlib/http_semantics.py diff --git a/netlib/http.py b/netlib/http.py index a2af9e49..073e9a3f 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -4,7 +4,7 @@ import string import urlparse import binascii import sys -from . import odict, utils, tcp, http_status +from . import odict, utils, tcp, http_semantics, http_status class HttpError(Exception): @@ -527,18 +527,6 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): ) -Response = collections.namedtuple( - "Response", - [ - "httpversion", - "code", - "msg", - "headers", - "content" - ] -) - - def read_response(rfile, request_method, body_size_limit, include_body=True): """ Return an (httpversion, code, msg, headers, content) tuple. @@ -580,7 +568,7 @@ def read_response(rfile, request_method, body_size_limit, include_body=True): # if include_body==False then a None content means the body should be # read separately content = None - return Response(httpversion, code, msg, headers, content) + return http_semantics.Response(httpversion, code, msg, headers, content) def request_preamble(method, resource, http_major="1", http_minor="1"): diff --git a/netlib/http_semantics.py b/netlib/http_semantics.py new file mode 100644 index 00000000..e8313e3c --- /dev/null +++ b/netlib/http_semantics.py @@ -0,0 +1,23 @@ +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + content, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.content = content + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) diff --git a/test/test_http.py b/test/test_http.py index 2ad81d24..bbc78847 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -1,7 +1,7 @@ import cStringIO import textwrap import binascii -from netlib import http, odict, tcp +from netlib import http, http_semantics, odict, tcp from . 
import tutils, tservers @@ -307,13 +307,13 @@ def test_read_response(): data = """ HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http_semantics.Response( (1, 1), 200, 'OK', odict.ODictCaseless(), '' ) data = """ HTTP/1.1 200 """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http_semantics.Response( (1, 1), 200, '', odict.ODictCaseless(), '' ) data = """ @@ -330,7 +330,7 @@ def test_read_response(): HTTP/1.1 200 OK """ - assert tst(data, "GET", None) == ( + assert tst(data, "GET", None) == http_semantics.Response( (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' ) @@ -340,8 +340,8 @@ def test_read_response(): foo """ - assert tst(data, "GET", None)[4] == 'foo' - assert tst(data, "HEAD", None)[4] == '' + assert tst(data, "GET", None).content == 'foo' + assert tst(data, "HEAD", None).content == '' data = """ HTTP/1.1 200 OK @@ -357,7 +357,7 @@ def test_read_response(): foo """ - assert tst(data, "GET", None, include_body=False)[4] is None + assert tst(data, "GET", None, include_body=False).content is None def test_parse_url(): -- cgit v1.2.3 From bd5ee212840e3be731ea93e14ef1375745383d88 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 8 Jul 2015 09:34:10 +0200 Subject: refactor websockets into protocol --- netlib/websockets.py | 381 ------------------------------------------ netlib/websockets/__init__.py | 2 + netlib/websockets/frame.py | 288 +++++++++++++++++++++++++++++++ netlib/websockets/protocol.py | 111 ++++++++++++ test/test_websockets.py | 31 ++-- 5 files changed, 419 insertions(+), 394 deletions(-) delete mode 100644 netlib/websockets.py create mode 100644 netlib/websockets/__init__.py create mode 100644 netlib/websockets/frame.py create mode 100644 netlib/websockets/protocol.py diff --git a/netlib/websockets.py b/netlib/websockets.py deleted file mode 100644 index c45db4df..00000000 --- a/netlib/websockets.py +++ /dev/null @@ -1,381 +0,0 @@ -from __future__ import absolute_import -import base64 -import hashlib -import os -import struct -import io - -from . import utils, odict, tcp - -# Colleciton of utility functions that implement small portions of the RFC6455 -# WebSockets Protocol Useful for building WebSocket clients and servers. -# -# Emphassis is on readabilty, simplicity and modularity, not performance or -# completeness -# -# This is a work in progress and does not yet contain all the utilites need to -# create fully complient client/servers # -# Spec: https://tools.ietf.org/html/rfc6455 - -# The magic sha that websocket servers must know to prove they understand -# RFC6455 -websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' -VERSION = "13" -MAX_16_BIT_INT = (1 << 16) -MAX_64_BIT_INT = (1 << 64) - - -OPCODE = utils.BiDi( - CONTINUE=0x00, - TEXT=0x01, - BINARY=0x02, - CLOSE=0x08, - PING=0x09, - PONG=0x0a -) - - -class Masker(object): - - """ - Data sent from the server must be masked to prevent malicious clients - from sending data over the wire in predictable patterns - - Servers do not have to mask data they send to the client. 
- https://tools.ietf.org/html/rfc6455#section-5.3 - """ - - def __init__(self, key): - self.key = key - self.masks = [utils.bytes_to_int(byte) for byte in key] - self.offset = 0 - - def mask(self, offset, data): - result = "" - for c in data: - result += chr(ord(c) ^ self.masks[offset % 4]) - offset += 1 - return result - - def __call__(self, data): - ret = self.mask(self.offset, data) - self.offset += len(ret) - return ret - - -def client_handshake_headers(key=None, version=VERSION): - """ - Create the headers for a valid HTTP upgrade request. If Key is not - specified, it is generated, and can be found in sec-websocket-key in - the returned header set. - - Returns an instance of ODictCaseless - """ - if not key: - key = base64.b64encode(os.urandom(16)).decode('utf-8') - return odict.ODictCaseless([ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Key', key), - ('Sec-WebSocket-Version', version) - ]) - - -def server_handshake_headers(key): - """ - The server response is a valid HTTP 101 response. - """ - return odict.ODictCaseless( - [ - ('Connection', 'Upgrade'), - ('Upgrade', 'websocket'), - ('Sec-WebSocket-Accept', create_server_nonce(key)) - ] - ) - - -def make_length_code(length): - """ - A websockets frame contains an initial length_code, and an optional - extended length code to represent the actual length if length code is - larger than 125 - """ - if length <= 125: - return length - elif length >= 126 and length <= 65535: - return 126 - else: - return 127 - - -def check_client_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-key') - - -def check_server_handshake(headers): - if headers.get_first("upgrade", None) != "websocket": - return - return headers.get_first('sec-websocket-accept') - - -def create_server_nonce(client_nonce): - return base64.b64encode( - hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') - ) - - -DEFAULT = object() - - -class FrameHeader(object): - - def __init__( - self, - opcode=OPCODE.TEXT, - payload_length=0, - fin=False, - rsv1=False, - rsv2=False, - rsv3=False, - masking_key=DEFAULT, - mask=DEFAULT, - length_code=DEFAULT - ): - if not 0 <= opcode < 2 ** 4: - raise ValueError("opcode must be 0-16") - self.opcode = opcode - self.payload_length = payload_length - self.fin = fin - self.rsv1 = rsv1 - self.rsv2 = rsv2 - self.rsv3 = rsv3 - - if length_code is DEFAULT: - self.length_code = make_length_code(self.payload_length) - else: - self.length_code = length_code - - if mask is DEFAULT and masking_key is DEFAULT: - self.mask = False - self.masking_key = "" - elif mask is DEFAULT: - self.mask = 1 - self.masking_key = masking_key - elif masking_key is DEFAULT: - self.mask = mask - self.masking_key = os.urandom(4) - else: - self.mask = mask - self.masking_key = masking_key - - if self.masking_key and len(self.masking_key) != 4: - raise ValueError("Masking key must be 4 bytes.") - - def human_readable(self): - vals = [ - "ws frame:", - OPCODE.get_name(self.opcode, hex(self.opcode)).lower() - ] - flags = [] - for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: - if getattr(self, i): - flags.append(i) - if flags: - vals.extend([":", "|".join(flags)]) - if self.masking_key: - vals.append(":key=%s" % repr(self.masking_key)) - if self.payload_length: - vals.append(" %s" % utils.pretty_size(self.payload_length)) - return "".join(vals) - - def to_bytes(self): - first_byte = utils.setbit(0, 7, self.fin) - first_byte = utils.setbit(first_byte, 6, 
self.rsv1) - first_byte = utils.setbit(first_byte, 5, self.rsv2) - first_byte = utils.setbit(first_byte, 4, self.rsv3) - first_byte = first_byte | self.opcode - - second_byte = utils.setbit(self.length_code, 7, self.mask) - - b = chr(first_byte) + chr(second_byte) - - if self.payload_length < 126: - pass - elif self.payload_length < MAX_16_BIT_INT: - # '!H' pack as 16 bit unsigned short - # add 2 byte extended payload length - b += struct.pack('!H', self.payload_length) - elif self.payload_length < MAX_64_BIT_INT: - # '!Q' = pack as 64 bit unsigned long long - # add 8 bytes extended payload length - b += struct.pack('!Q', self.payload_length) - if self.masking_key is not None: - b += self.masking_key - return b - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame header - """ - first_byte = utils.bytes_to_int(fp.safe_read(1)) - second_byte = utils.bytes_to_int(fp.safe_read(1)) - - fin = utils.getbit(first_byte, 7) - rsv1 = utils.getbit(first_byte, 6) - rsv2 = utils.getbit(first_byte, 5) - rsv3 = utils.getbit(first_byte, 4) - # grab right-most 4 bits - opcode = first_byte & 15 - mask_bit = utils.getbit(second_byte, 7) - # grab the next 7 bits - length_code = second_byte & 127 - - # payload_lengthy > 125 indicates you need to read more bytes - # to get the actual payload length - if length_code <= 125: - payload_length = length_code - elif length_code == 126: - payload_length = utils.bytes_to_int(fp.safe_read(2)) - elif length_code == 127: - payload_length = utils.bytes_to_int(fp.safe_read(8)) - - # masking key only present if mask bit set - if mask_bit == 1: - masking_key = fp.safe_read(4) - else: - masking_key = None - - return cls( - fin=fin, - rsv1=rsv1, - rsv2=rsv2, - rsv3=rsv3, - opcode=opcode, - mask=mask_bit, - length_code=length_code, - payload_length=payload_length, - masking_key=masking_key, - ) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class Frame(object): - - """ - Represents one websockets frame. - Constructor takes human readable forms of the frame components - from_bytes() is also avaliable. - - WebSockets Frame as defined in RFC6455 - - 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 - +-+-+-+-+-------+-+-------------+-------------------------------+ - |F|R|R|R| opcode|M| Payload len | Extended payload length | - |I|S|S|S| (4) |A| (7) | (16/64) | - |N|V|V|V| |S| | (if payload len==126/127) | - | |1|2|3| |K| | | - +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + - | Extended payload length continued, if payload len == 127 | - + - - - - - - - - - - - - - - - +-------------------------------+ - | |Masking-key, if MASK set to 1 | - +-------------------------------+-------------------------------+ - | Masking-key (continued) | Payload Data | - +-------------------------------- - - - - - - - - - - - - - - - + - : Payload Data continued ... : - + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + - | Payload Data continued ... | - +---------------------------------------------------------------+ - """ - - def __init__(self, payload="", **kwargs): - self.payload = payload - kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) - self.header = FrameHeader(**kwargs) - - @classmethod - def default(cls, message, from_client=False): - """ - Construct a basic websocket frame from some default values. - Creates a non-fragmented text frame. 
- """ - if from_client: - mask_bit = 1 - masking_key = os.urandom(4) - else: - mask_bit = 0 - masking_key = None - - return cls( - message, - fin=1, # final frame - opcode=OPCODE.TEXT, # text - mask=mask_bit, - masking_key=masking_key, - ) - - @classmethod - def from_bytes(cls, bytestring): - """ - Construct a websocket frame from an in-memory bytestring - to construct a frame from a stream of bytes, use from_file() directly - """ - return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) - - def human_readable(self): - ret = self.header.human_readable() - if self.payload: - ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) - return ret - - def __repr__(self): - return self.header.human_readable() - - def to_bytes(self): - """ - Serialize the frame to wire format. Returns a string. - """ - b = self.header.to_bytes() - if self.header.masking_key: - b += Masker(self.header.masking_key)(self.payload) - else: - b += self.payload - return b - - def to_file(self, writer): - writer.write(self.to_bytes()) - writer.flush() - - @classmethod - def from_file(cls, fp): - """ - read a websockets frame sent by a server or client - - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - header = FrameHeader.from_file(fp) - payload = fp.safe_read(header.payload_length) - - if header.mask == 1 and header.masking_key: - payload = Masker(header.masking_key)(payload) - - return cls( - payload, - fin=header.fin, - opcode=header.opcode, - mask=header.mask, - payload_length=header.payload_length, - masking_key=header.masking_key, - rsv1=header.rsv1, - rsv2=header.rsv2, - rsv3=header.rsv3, - length_code=header.length_code - ) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/websockets/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py new file mode 100644 index 00000000..d41059fa --- /dev/null +++ b/netlib/websockets/frame.py @@ -0,0 +1,288 @@ +from __future__ import absolute_import +import base64 +import hashlib +import os +import struct +import io + +from .protocol import Masker +from .. 
import utils, odict, tcp + +DEFAULT = object() + +MAX_16_BIT_INT = (1 << 16) +MAX_64_BIT_INT = (1 << 64) + +OPCODE = utils.BiDi( + CONTINUE=0x00, + TEXT=0x01, + BINARY=0x02, + CLOSE=0x08, + PING=0x09, + PONG=0x0a +) + +class FrameHeader(object): + + def __init__( + self, + opcode=OPCODE.TEXT, + payload_length=0, + fin=False, + rsv1=False, + rsv2=False, + rsv3=False, + masking_key=DEFAULT, + mask=DEFAULT, + length_code=DEFAULT + ): + if not 0 <= opcode < 2 ** 4: + raise ValueError("opcode must be 0-16") + self.opcode = opcode + self.payload_length = payload_length + self.fin = fin + self.rsv1 = rsv1 + self.rsv2 = rsv2 + self.rsv3 = rsv3 + + if length_code is DEFAULT: + self.length_code = self._make_length_code(self.payload_length) + else: + self.length_code = length_code + + if mask is DEFAULT and masking_key is DEFAULT: + self.mask = False + self.masking_key = "" + elif mask is DEFAULT: + self.mask = 1 + self.masking_key = masking_key + elif masking_key is DEFAULT: + self.mask = mask + self.masking_key = os.urandom(4) + else: + self.mask = mask + self.masking_key = masking_key + + if self.masking_key and len(self.masking_key) != 4: + raise ValueError("Masking key must be 4 bytes.") + + @classmethod + def _make_length_code(self, length): + """ + A websockets frame contains an initial length_code, and an optional + extended length code to represent the actual length if length code is + larger than 125 + """ + if length <= 125: + return length + elif length >= 126 and length <= 65535: + return 126 + else: + return 127 + + def human_readable(self): + vals = [ + "ws frame:", + OPCODE.get_name(self.opcode, hex(self.opcode)).lower() + ] + flags = [] + for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]: + if getattr(self, i): + flags.append(i) + if flags: + vals.extend([":", "|".join(flags)]) + if self.masking_key: + vals.append(":key=%s" % repr(self.masking_key)) + if self.payload_length: + vals.append(" %s" % utils.pretty_size(self.payload_length)) + return "".join(vals) + + def to_bytes(self): + first_byte = utils.setbit(0, 7, self.fin) + first_byte = utils.setbit(first_byte, 6, self.rsv1) + first_byte = utils.setbit(first_byte, 5, self.rsv2) + first_byte = utils.setbit(first_byte, 4, self.rsv3) + first_byte = first_byte | self.opcode + + second_byte = utils.setbit(self.length_code, 7, self.mask) + + b = chr(first_byte) + chr(second_byte) + + if self.payload_length < 126: + pass + elif self.payload_length < MAX_16_BIT_INT: + # '!H' pack as 16 bit unsigned short + # add 2 byte extended payload length + b += struct.pack('!H', self.payload_length) + elif self.payload_length < MAX_64_BIT_INT: + # '!Q' = pack as 64 bit unsigned long long + # add 8 bytes extended payload length + b += struct.pack('!Q', self.payload_length) + if self.masking_key is not None: + b += self.masking_key + return b + + @classmethod + def from_file(cls, fp): + """ + read a websockets frame header + """ + first_byte = utils.bytes_to_int(fp.safe_read(1)) + second_byte = utils.bytes_to_int(fp.safe_read(1)) + + fin = utils.getbit(first_byte, 7) + rsv1 = utils.getbit(first_byte, 6) + rsv2 = utils.getbit(first_byte, 5) + rsv3 = utils.getbit(first_byte, 4) + # grab right-most 4 bits + opcode = first_byte & 15 + mask_bit = utils.getbit(second_byte, 7) + # grab the next 7 bits + length_code = second_byte & 127 + + # payload_lengthy > 125 indicates you need to read more bytes + # to get the actual payload length + if length_code <= 125: + payload_length = length_code + elif length_code == 126: + payload_length = 
utils.bytes_to_int(fp.safe_read(2)) + elif length_code == 127: + payload_length = utils.bytes_to_int(fp.safe_read(8)) + + # masking key only present if mask bit set + if mask_bit == 1: + masking_key = fp.safe_read(4) + else: + masking_key = None + + return cls( + fin=fin, + rsv1=rsv1, + rsv2=rsv2, + rsv3=rsv3, + opcode=opcode, + mask=mask_bit, + length_code=length_code, + payload_length=payload_length, + masking_key=masking_key, + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class Frame(object): + + """ + Represents one websockets frame. + Constructor takes human readable forms of the frame components + from_bytes() is also avaliable. + + WebSockets Frame as defined in RFC6455 + + 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + +-+-+-+-+-------+-+-------------+-------------------------------+ + |F|R|R|R| opcode|M| Payload len | Extended payload length | + |I|S|S|S| (4) |A| (7) | (16/64) | + |N|V|V|V| |S| | (if payload len==126/127) | + | |1|2|3| |K| | | + +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + | Extended payload length continued, if payload len == 127 | + + - - - - - - - - - - - - - - - +-------------------------------+ + | |Masking-key, if MASK set to 1 | + +-------------------------------+-------------------------------+ + | Masking-key (continued) | Payload Data | + +-------------------------------- - - - - - - - - - - - - - - - + + : Payload Data continued ... : + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + | Payload Data continued ... | + +---------------------------------------------------------------+ + """ + + def __init__(self, payload="", **kwargs): + self.payload = payload + kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) + self.header = FrameHeader(**kwargs) + + @classmethod + def default(cls, message, from_client=False): + """ + Construct a basic websocket frame from some default values. + Creates a non-fragmented text frame. + """ + if from_client: + mask_bit = 1 + masking_key = os.urandom(4) + else: + mask_bit = 0 + masking_key = None + + return cls( + message, + fin=1, # final frame + opcode=OPCODE.TEXT, # text + mask=mask_bit, + masking_key=masking_key, + ) + + @classmethod + def from_bytes(cls, bytestring): + """ + Construct a websocket frame from an in-memory bytestring + to construct a frame from a stream of bytes, use from_file() directly + """ + return cls.from_file(tcp.Reader(io.BytesIO(bytestring))) + + def human_readable(self): + ret = self.header.human_readable() + if self.payload: + ret = ret + "\nPayload:\n" + utils.cleanBin(self.payload) + return ret + + def __repr__(self): + return self.header.human_readable() + + def to_bytes(self): + """ + Serialize the frame to wire format. Returns a string. 
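# A minimal usage sketch for the relocated Frame class, assuming the package
# re-exports it via netlib/websockets/__init__.py as shown above.
from netlib import websockets

frame = websockets.Frame.default("hello", from_client=True)  # masked text frame
wire = frame.to_bytes()                                       # serialize to wire format
assert websockets.Frame.from_bytes(wire) == frame             # __eq__ compares wire bytes
print frame.human_readable()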
+ """ + b = self.header.to_bytes() + if self.header.masking_key: + b += Masker(self.header.masking_key)(self.payload) + else: + b += self.payload + return b + + def to_file(self, writer): + writer.write(self.to_bytes()) + writer.flush() + + @classmethod + def from_file(cls, fp): + """ + read a websockets frame sent by a server or client + + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + header = FrameHeader.from_file(fp) + payload = fp.safe_read(header.payload_length) + + if header.mask == 1 and header.masking_key: + payload = Masker(header.masking_key)(payload) + + return cls( + payload, + fin=header.fin, + opcode=header.opcode, + mask=header.mask, + payload_length=header.payload_length, + masking_key=header.masking_key, + rsv1=header.rsv1, + rsv2=header.rsv2, + rsv3=header.rsv3, + length_code=header.length_code + ) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py new file mode 100644 index 00000000..dcab53fb --- /dev/null +++ b/netlib/websockets/protocol.py @@ -0,0 +1,111 @@ +from __future__ import absolute_import +import base64 +import hashlib +import os +import struct +import io + +from .. import utils, odict, tcp + +# Colleciton of utility functions that implement small portions of the RFC6455 +# WebSockets Protocol Useful for building WebSocket clients and servers. +# +# Emphassis is on readabilty, simplicity and modularity, not performance or +# completeness +# +# This is a work in progress and does not yet contain all the utilites need to +# create fully complient client/servers # +# Spec: https://tools.ietf.org/html/rfc6455 + +# The magic sha that websocket servers must know to prove they understand +# RFC6455 +websockets_magic = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11' +VERSION = "13" + +HEADER_WEBSOCKET_KEY = 'sec-websocket-key' +HEADER_WEBSOCKET_ACCEPT = 'sec-websocket-accept' +HEADER_WEBSOCKET_VERSION = 'sec-websocket-version' + +class Masker(object): + + """ + Data sent from the server must be masked to prevent malicious clients + from sending data over the wire in predictable patterns + + Servers do not have to mask data they send to the client. + https://tools.ietf.org/html/rfc6455#section-5.3 + """ + + def __init__(self, key): + self.key = key + self.masks = [utils.bytes_to_int(byte) for byte in key] + self.offset = 0 + + def mask(self, offset, data): + result = "" + for c in data: + result += chr(ord(c) ^ self.masks[offset % 4]) + offset += 1 + return result + + def __call__(self, data): + ret = self.mask(self.offset, data) + self.offset += len(ret) + return ret + +class WebsocketsProtocol(object): + + def __init__(self): + pass + + @classmethod + def client_handshake_headers(self, key=None, version=VERSION): + """ + Create the headers for a valid HTTP upgrade request. If Key is not + specified, it is generated, and can be found in sec-websocket-key in + the returned header set. + + Returns an instance of ODictCaseless + """ + if not key: + key = base64.b64encode(os.urandom(16)).decode('utf-8') + return odict.ODictCaseless([ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_KEY, key), + (HEADER_WEBSOCKET_VERSION, version) + ]) + + @classmethod + def server_handshake_headers(self, key): + """ + The server response is a valid HTTP 101 response. 
+ """ + return odict.ODictCaseless( + [ + ('Connection', 'Upgrade'), + ('Upgrade', 'websocket'), + (HEADER_WEBSOCKET_ACCEPT, self.create_server_nonce(key)) + ] + ) + + + @classmethod + def check_client_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_KEY) + + + @classmethod + def check_server_handshake(self, headers): + if headers.get_first("upgrade", None) != "websocket": + return + return headers.get_first(HEADER_WEBSOCKET_ACCEPT) + + + @classmethod + def create_server_nonce(self, client_nonce): + return base64.b64encode( + hashlib.sha1(client_nonce + websockets_magic).hexdigest().decode('hex') + ) diff --git a/test/test_websockets.py b/test/test_websockets.py index 9956543b..ae0a5e33 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -12,6 +12,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): super(WebSocketsEchoHandler, self).__init__( connection, address, server ) + self.protocol = websockets.WebsocketsProtocol() self.handshake_done = False def handle(self): @@ -31,10 +32,10 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): req = http.read_request(self.rfile) - key = websockets.check_client_handshake(req.headers) + key = self.protocol.check_client_handshake(req.headers) self.wfile.write(http.response_preamble(101) + "\r\n") - headers = websockets.server_handshake_headers(key) + headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True @@ -48,6 +49,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) + self.protocol = websockets.WebsocketsProtocol() self.client_nonce = None def connect(self): @@ -55,15 +57,15 @@ class WebSocketsClient(tcp.TCPClient): preamble = http.request_preamble("GET", "/") self.wfile.write(preamble + "\r\n") - headers = websockets.client_handshake_headers() + headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() resp = http.read_response(self.rfile, "get", None) - server_nonce = websockets.check_server_handshake(resp.headers) + server_nonce = self.protocol.check_server_handshake(resp.headers) - if not server_nonce == websockets.create_server_nonce( + if not server_nonce == self.protocol.create_server_nonce( self.client_nonce): self.close() @@ -78,6 +80,9 @@ class WebSocketsClient(tcp.TCPClient): class TestWebSockets(tservers.ServerTestBase): handler = WebSocketsEchoHandler + def __init__(self): + self.protocol = websockets.WebsocketsProtocol() + def random_bytes(self, n=100): return os.urandom(n) @@ -130,26 +135,26 @@ class TestWebSockets(tservers.ServerTestBase): assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes def test_check_server_handshake(self): - headers = websockets.server_handshake_headers("key") - assert websockets.check_server_handshake(headers) + headers = self.protocol.server_handshake_headers("key") + assert self.protocol.check_server_handshake(headers) headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_server_handshake(headers) + assert not self.protocol.check_server_handshake(headers) def test_check_client_handshake(self): - headers = websockets.client_handshake_headers("key") - assert websockets.check_client_handshake(headers) == "key" + headers = 
self.protocol.client_handshake_headers("key") + assert self.protocol.check_client_handshake(headers) == "key" headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_client_handshake(headers) + assert not self.protocol.check_client_handshake(headers) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): client_hs = http.read_request(self.rfile) - websockets.check_client_handshake(client_hs.headers) + self.protocol.check_client_handshake(client_hs.headers) self.wfile.write(http.response_preamble(101) + "\r\n") - headers = websockets.server_handshake_headers("malformed key") + headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True -- cgit v1.2.3 From f50deb7b763d093a22a4d331e16465a2fb0329cf Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Tue, 14 Jul 2015 23:02:14 +0200 Subject: move bits around --- netlib/http.py | 583 ------------------------------ netlib/http/__init__.py | 2 + netlib/http/authentication.py | 149 ++++++++ netlib/http/cookies.py | 193 ++++++++++ netlib/http/exceptions.py | 9 + netlib/http/http1/__init__.py | 1 + netlib/http/http1/protocol.py | 518 +++++++++++++++++++++++++++ netlib/http/http2/__init__.py | 2 + netlib/http/http2/frame.py | 636 +++++++++++++++++++++++++++++++++ netlib/http/http2/protocol.py | 240 +++++++++++++ netlib/http/semantics.py | 94 +++++ netlib/http/status_codes.py | 104 ++++++ netlib/http/user_agents.py | 52 +++ netlib/http2/__init__.py | 2 - netlib/http2/frame.py | 636 --------------------------------- netlib/http2/protocol.py | 240 ------------- netlib/http_auth.py | 148 -------- netlib/http_cookies.py | 193 ---------- netlib/http_semantics.py | 23 -- netlib/http_status.py | 104 ------ netlib/http_uastrings.py | 52 --- netlib/websockets/frame.py | 2 +- netlib/websockets/protocol.py | 2 +- test/http/__init__.py | 0 test/http/http1/__init__.py | 0 test/http/http1/test_protocol.py | 445 +++++++++++++++++++++++ test/http/http2/__init__.py | 0 test/http/http2/test_frames.py | 704 +++++++++++++++++++++++++++++++++++++ test/http/http2/test_protocol.py | 325 +++++++++++++++++ test/http/test_authentication.py | 110 ++++++ test/http/test_cookies.py | 219 ++++++++++++ test/http/test_semantics.py | 54 +++ test/http/test_user_agents.py | 6 + test/http2/__init__.py | 0 test/http2/test_frames.py | 704 ------------------------------------- test/http2/test_protocol.py | 326 ----------------- test/test_http.py | 491 -------------------------- test/test_http_auth.py | 109 ------ test/test_http_cookies.py | 219 ------------ test/test_http_uastrings.py | 6 - test/test_websockets.py | 261 -------------- test/websockets/__init__.py | 0 test/websockets/test_websockets.py | 262 ++++++++++++++ 43 files changed, 4127 insertions(+), 4099 deletions(-) delete mode 100644 netlib/http.py create mode 100644 netlib/http/__init__.py create mode 100644 netlib/http/authentication.py create mode 100644 netlib/http/cookies.py create mode 100644 netlib/http/exceptions.py create mode 100644 netlib/http/http1/__init__.py create mode 100644 netlib/http/http1/protocol.py create mode 100644 netlib/http/http2/__init__.py create mode 100644 netlib/http/http2/frame.py create mode 100644 netlib/http/http2/protocol.py create mode 100644 netlib/http/semantics.py create mode 100644 netlib/http/status_codes.py create mode 100644 netlib/http/user_agents.py delete mode 100644 netlib/http2/__init__.py delete mode 100644 netlib/http2/frame.py delete mode 100644 
netlib/http2/protocol.py delete mode 100644 netlib/http_auth.py delete mode 100644 netlib/http_cookies.py delete mode 100644 netlib/http_semantics.py delete mode 100644 netlib/http_status.py delete mode 100644 netlib/http_uastrings.py create mode 100644 test/http/__init__.py create mode 100644 test/http/http1/__init__.py create mode 100644 test/http/http1/test_protocol.py create mode 100644 test/http/http2/__init__.py create mode 100644 test/http/http2/test_frames.py create mode 100644 test/http/http2/test_protocol.py create mode 100644 test/http/test_authentication.py create mode 100644 test/http/test_cookies.py create mode 100644 test/http/test_semantics.py create mode 100644 test/http/test_user_agents.py delete mode 100644 test/http2/__init__.py delete mode 100644 test/http2/test_frames.py delete mode 100644 test/http2/test_protocol.py delete mode 100644 test/test_http.py delete mode 100644 test/test_http_auth.py delete mode 100644 test/test_http_cookies.py delete mode 100644 test/test_http_uastrings.py delete mode 100644 test/test_websockets.py create mode 100644 test/websockets/__init__.py create mode 100644 test/websockets/test_websockets.py diff --git a/netlib/http.py b/netlib/http.py deleted file mode 100644 index 073e9a3f..00000000 --- a/netlib/http.py +++ /dev/null @@ -1,583 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import collections -import string -import urlparse -import binascii -import sys -from . import odict, utils, tcp, http_semantics, http_status - - -class HttpError(Exception): - - def __init__(self, code, message): - super(HttpError, self).__init__(message) - self.code = code - - -class HttpErrorConnClosed(HttpError): - pass - - -def _is_valid_port(port): - if not 0 <= port <= 65535: - return False - return True - - -def _is_valid_host(host): - try: - host.decode("idna") - except ValueError: - return False - if "\0" in host: - return None - return True - - -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - - -def parse_url(url): - """ - Returns a (scheme, host, port, path) tuple, or None on error. - - Checks that: - port is an integer 0-65535 - host is a valid IDNA-encoded hostname with no null-bytes - path is valid ASCII - """ - try: - scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) - except ValueError: - return None - if not scheme: - return None - if '@' in netloc: - # FIXME: Consider what to do with the discarded credentials here Most - # probably we should extend the signature to return these as a separate - # value. - _, netloc = string.rsplit(netloc, '@', maxsplit=1) - if ':' in netloc: - host, port = string.rsplit(netloc, ':', maxsplit=1) - try: - port = int(port) - except ValueError: - return None - else: - host = netloc - if scheme == "https": - port = 443 - else: - port = 80 - path = urlparse.urlunparse(('', '', path, params, query, fragment)) - if not path.startswith("/"): - path = "/" + path - if not _is_valid_host(host): - return None - if not utils.isascii(path): - return None - if not _is_valid_port(port): - return None - return scheme, host, port, path - - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. 
- """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. - if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) - else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. - """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_http_basic_auth(s): - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. 
- http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not _is_valid_port(port): - return None - if not _is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) - - -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. 
Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. - """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. 
- """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - wfile.flush() - del headers['expect'] - - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - - return Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Return an (httpversion, code, msg, headers, content) tuple. - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. 
it's a - response to a HEAD request) - """ - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http_semantics.Response(httpversion, code, msg, headers, content) - - -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = http_status.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py new file mode 100644 index 00000000..9b4b0e6b --- /dev/null +++ b/netlib/http/__init__.py @@ -0,0 +1,2 @@ +from exceptions import * +from semantics import * diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py new file mode 100644 index 00000000..26e3c2c4 --- /dev/null +++ b/netlib/http/authentication.py @@ -0,0 +1,149 @@ +from __future__ import (absolute_import, print_function, division) +from argparse import Action, ArgumentTypeError + +from .. import http + + +class NullProxyAuth(object): + + """ + No proxy auth at all (returns empty challange headers) + """ + + def __init__(self, password_manager): + self.password_manager = password_manager + + def clean(self, headers_): + """ + Clean up authentication headers, so they're not passed upstream. + """ + pass + + def authenticate(self, headers_): + """ + Tests that the user is allowed to use the proxy + """ + return True + + def auth_challenge_headers(self): + """ + Returns a dictionary containing the headers require to challenge the user + """ + return {} + + +class BasicProxyAuth(NullProxyAuth): + CHALLENGE_HEADER = 'Proxy-Authenticate' + AUTH_HEADER = 'Proxy-Authorization' + + def __init__(self, password_manager, realm): + NullProxyAuth.__init__(self, password_manager) + self.realm = realm + + def clean(self, headers): + del headers[self.AUTH_HEADER] + + def authenticate(self, headers): + auth_value = headers.get(self.AUTH_HEADER, []) + if not auth_value: + return False + parts = http.http1.parse_http_basic_auth(auth_value[0]) + if not parts: + return False + scheme, username, password = parts + if scheme.lower() != 'basic': + return False + if not self.password_manager.test(username, password): + return False + self.username = username + return True + + def auth_challenge_headers(self): + return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} + + +class PassMan(object): + + def test(self, username_, password_token_): + return False + + +class PassManNonAnon(PassMan): + + """ + Ensure the user specifies a username, accept any password. 
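# A hedged sketch of the relocated proxy-auth helpers; it assumes
# assemble_http_basic_auth is re-exported by netlib.http.http1 (see protocol.py below).
from netlib import odict
from netlib.http import authentication
from netlib.http.http1 import assemble_http_basic_auth

passman = authentication.PassManSingleUser("user", "secret")
auth = authentication.BasicProxyAuth(passman, "mitmproxy")
headers = odict.ODictCaseless()
headers["Proxy-Authorization"] = [assemble_http_basic_auth("Basic", "user", "secret")]
assert auth.authenticate(headers)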
+ """ + + def test(self, username, password_token_): + if username: + return True + return False + + +class PassManHtpasswd(PassMan): + + """ + Read usernames and passwords from an htpasswd file + """ + + def __init__(self, path): + """ + Raises ValueError if htpasswd file is invalid. + """ + import passlib.apache + self.htpasswd = passlib.apache.HtpasswdFile(path) + + def test(self, username, password_token): + return bool(self.htpasswd.check_password(username, password_token)) + + +class PassManSingleUser(PassMan): + + def __init__(self, username, password): + self.username, self.password = username, password + + def test(self, username, password_token): + return self.username == username and self.password == password_token + + +class AuthAction(Action): + + """ + Helper class to allow seamless integration int argparse. Example usage: + parser.add_argument( + "--nonanonymous", + action=NonanonymousAuthAction, nargs=0, + help="Allow access to any user long as a credentials are specified." + ) + """ + + def __call__(self, parser, namespace, values, option_string=None): + passman = self.getPasswordManager(values) + authenticator = BasicProxyAuth(passman, "mitmproxy") + setattr(namespace, self.dest, authenticator) + + def getPasswordManager(self, s): # pragma: nocover + raise NotImplementedError() + + +class SingleuserAuthAction(AuthAction): + + def getPasswordManager(self, s): + if len(s.split(':')) != 2: + raise ArgumentTypeError( + "Invalid single-user specification. Please use the format username:password" + ) + username, password = s.split(':') + return PassManSingleUser(username, password) + + +class NonanonymousAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManNonAnon() + + +class HtpasswdAuthAction(AuthAction): + + def getPasswordManager(self, s): + return PassManHtpasswd(s) diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py new file mode 100644 index 00000000..b77e3503 --- /dev/null +++ b/netlib/http/cookies.py @@ -0,0 +1,193 @@ +import re + +from .. import odict + +""" +A flexible module for cookie parsing and manipulation. + +This module differs from usual standards-compliant cookie modules in a number +of ways. We try to be as permissive as possible, and to retain even mal-formed +information. Duplicate cookies are preserved in parsing, and can be set in +formatting. We do attempt to escape and quote values where needed, but will not +reject data that violate the specs. + +Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do +not parse the comma-separated variant of Set-Cookie that allows multiple +cookies to be set in a single header. Technically this should be feasible, but +it turns out that violations of RFC6265 that makes the parsing problem +indeterminate are much more common than genuine occurences of the multi-cookie +variants. Serialization follows RFC6265. + + http://tools.ietf.org/html/rfc6265 + http://tools.ietf.org/html/rfc2109 + http://tools.ietf.org/html/rfc2965 +""" + +# TODO +# - Disallow LHS-only Cookie values + + +def _read_until(s, start, term): + """ + Read until one of the characters in term is reached. + """ + if start == len(s): + return "", start + 1 + for i in range(start, len(s)): + if s[i] in term: + return s[start:i], i + return s[start:i + 1], i + 1 + + +def _read_token(s, start): + """ + Read a token - the LHS of a token/value pair in a cookie. 
+ """ + return _read_until(s, start, ";=") + + +def _read_quoted_string(s, start): + """ + start: offset to the first quote of the string to be read + + A sort of loose super-set of the various quoted string specifications. + + RFC6265 disallows backslashes or double quotes within quoted strings. + Prior RFCs use backslashes to escape. This leaves us free to apply + backslash escaping by default and be compatible with everything. + """ + escaping = False + ret = [] + # Skip the first quote + for i in range(start + 1, len(s)): + if escaping: + ret.append(s[i]) + escaping = False + elif s[i] == '"': + break + elif s[i] == "\\": + escaping = True + else: + ret.append(s[i]) + return "".join(ret), i + 1 + + +def _read_value(s, start, delims): + """ + Reads a value - the RHS of a token/value pair in a cookie. + + special: If the value is special, commas are premitted. Else comma + terminates. This helps us support old and new style values. + """ + if start >= len(s): + return "", start + elif s[start] == '"': + return _read_quoted_string(s, start) + else: + return _read_until(s, start, delims) + + +def _read_pairs(s, off=0): + """ + Read pairs of lhs=rhs values. + + off: start offset + specials: a lower-cased list of keys that may contain commas + """ + vals = [] + while True: + lhs, off = _read_token(s, off) + lhs = lhs.lstrip() + if lhs: + rhs = None + if off < len(s): + if s[off] == "=": + rhs, off = _read_value(s, off + 1, ";") + vals.append([lhs, rhs]) + off += 1 + if not off < len(s): + break + return vals, off + + +def _has_special(s): + for i in s: + if i in '",;\\': + return True + o = ord(i) + if o < 0x21 or o > 0x7e: + return True + return False + + +ESCAPE = re.compile(r"([\"\\])") + + +def _format_pairs(lst, specials=(), sep="; "): + """ + specials: A lower-cased list of keys that will not be quoted. + """ + vals = [] + for k, v in lst: + if v is None: + vals.append(k) + else: + if k.lower() not in specials and _has_special(v): + v = ESCAPE.sub(r"\\\1", v) + v = '"%s"' % v + vals.append("%s=%s" % (k, v)) + return sep.join(vals) + + +def _format_set_cookie_pairs(lst): + return _format_pairs( + lst, + specials=("expires", "path") + ) + + +def _parse_set_cookie_pairs(s): + """ + For Set-Cookie, we support multiple cookies as described in RFC2109. + This function therefore returns a list of lists. + """ + pairs, off_ = _read_pairs(s) + return pairs + + +def parse_set_cookie_header(line): + """ + Parse a Set-Cookie header value + + Returns a (name, value, attrs) tuple, or None, where attrs is an + ODictCaseless set of attributes. No attempt is made to parse attribute + values - they are treated purely as strings. + """ + pairs = _parse_set_cookie_pairs(line) + if pairs: + return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) + + +def format_set_cookie_header(name, value, attrs): + """ + Formats a Set-Cookie header value. + """ + pairs = [[name, value]] + pairs.extend(attrs.lst) + return _format_set_cookie_pairs(pairs) + + +def parse_cookie_header(line): + """ + Parse a Cookie header value. + Returns a (possibly empty) ODict object. + """ + pairs, off_ = _read_pairs(line) + return odict.ODict(pairs) + + +def format_cookie_header(od): + """ + Formats a Cookie header value. 
+ """ + return _format_pairs(od.lst) diff --git a/netlib/http/exceptions.py b/netlib/http/exceptions.py new file mode 100644 index 00000000..8a2bbebc --- /dev/null +++ b/netlib/http/exceptions.py @@ -0,0 +1,9 @@ +class HttpError(Exception): + + def __init__(self, code, message): + super(HttpError, self).__init__(message) + self.code = code + + +class HttpErrorConnClosed(HttpError): + pass diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py new file mode 100644 index 00000000..6b5043af --- /dev/null +++ b/netlib/http/http1/__init__.py @@ -0,0 +1 @@ +from protocol import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py new file mode 100644 index 00000000..0f7a0bd3 --- /dev/null +++ b/netlib/http/http1/protocol.py @@ -0,0 +1,518 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from netlib import odict, utils, tcp, http +from .. import status_codes +from ..exceptions import * + + +def get_request_line(fp): + """ + Get a line, possibly preceded by a blank. + """ + line = fp.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = fp.readline() + return line + +def read_headers(fp): + """ + Read a set of headers from a file pointer. Stop once a blank line is + reached. Return a ODictCaseless object, or None if headers are invalid. + """ + ret = [] + name = '' + while True: + line = fp.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() + else: + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + +def read_chunked(fp, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = fp.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = fp.read(length) + suffix = fp.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. + """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks + + +def has_chunked_encoding(headers): + return "chunked" in [ + i.lower() for i in get_header_tokens(headers, "transfer-encoding") + ] + + +def parse_http_protocol(s): + """ + Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or + None. 
+ """ + if not s.startswith("HTTP/"): + return None + _, version = s.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor + + +def parse_http_basic_auth(s): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + + +def parse_init(line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion + + +def parse_init_connect(line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + if method.upper() != 'CONNECT': + return None + try: + host, port = url.split(":") + except ValueError: + return None + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + +def parse_init_proxy(line): + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + +def parse_init_http(line): + """ + Returns (method, url, httpversion) + """ + v = parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + +def connection_close(httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. 
+ if "connection" in headers: + toks = get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + if httpversion == (1, 1): + return False + return True + + +def parse_response_line(line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + +def read_http_body(*args, **kwargs): + return "".join( + content for _, content, _ in read_http_body_chunked(*args, **kwargs) + ) + + +def read_http_body_chunked( + rfile, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None +): + """ + Read an HTTP message body: + + rfile: A file descriptor to read from + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if has_chunked_encoding(headers): + # Python 3: yield from + for x in read_chunked(rfile, limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + +def expected_http_body_size(headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. 
+ """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + +# TODO: make this a regular class - just like Response +Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] +) + + +def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): + """ + Parse an HTTP request from a file stream + + Args: + rfile (file): Input file to read from + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = get_request_line(rfile) + if not request_line: + raise tcp.NetLibDisconnect() + + request_line_parts = parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = read_headers(rfile) + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + wfile.flush() + del headers['expect'] + + if include_body: + content = read_http_body( + rfile, headers, body_size_limit, method, None, True + ) + + return Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) + + +def read_response(rfile, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. 
it's a + response to a HEAD request) + """ + + line = rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = read_headers(rfile) + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = read_http_body( + rfile, + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) + + +def request_preamble(method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + +def response_preamble(code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py new file mode 100644 index 00000000..5acf7696 --- /dev/null +++ b/netlib/http/http2/__init__.py @@ -0,0 +1,2 @@ +from frame import * +from protocol import * diff --git a/netlib/http/http2/frame.py b/netlib/http/http2/frame.py new file mode 100644 index 00000000..f7e60471 --- /dev/null +++ b/netlib/http/http2/frame.py @@ -0,0 +1,636 @@ +import sys +import struct +from hpack.hpack import Encoder, Decoder + +from .. import utils + + +class FrameSizeError(Exception): + pass + + +class Frame(object): + + """ + Baseclass Frame + contains header + payload is defined in subclasses + """ + + FLAG_NO_FLAGS = 0x0 + FLAG_ACK = 0x1 + FLAG_END_STREAM = 0x1 + FLAG_END_HEADERS = 0x4 + FLAG_PADDED = 0x8 + FLAG_PRIORITY = 0x20 + + def __init__( + self, + state=None, + length=0, + flags=FLAG_NO_FLAGS, + stream_id=0x0): + valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) + if flags | valid_flags != valid_flags: + raise ValueError('invalid flags detected.') + + if state is None: + class State(object): + pass + + state = State() + state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() + state.encoder = Encoder() + state.decoder = Decoder() + + self.state = state + + self.length = length + self.type = self.TYPE + self.flags = flags + self.stream_id = stream_id + + @classmethod + def _check_frame_size(cls, length, state): + if state: + settings = state.http2_settings + else: + settings = HTTP2_DEFAULT_SETTINGS.copy() + + max_frame_size = settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] + + if length > max_frame_size: + raise FrameSizeError( + "Frame size exceeded: %d, but only %d allowed." % ( + length, max_frame_size)) + + @classmethod + def from_file(cls, fp, state=None): + """ + read a HTTP/2 frame sent by a server or client + fp is a "file like" object that could be backed by a network + stream or a disk or an in memory stream reader + """ + raw_header = fp.safe_read(9) + + fields = struct.unpack("!HBBBL", raw_header) + length = (fields[0] << 8) + fields[1] + flags = fields[3] + stream_id = fields[4] + + if raw_header[:4] == b'HTTP': # pragma no cover + print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" 
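read_response is the server-side counterpart and hands back the shared Response object (assuming netlib.http re-exports it from the new semantics module); a rough sketch:

    import cStringIO
    from netlib.http.http1 import protocol  # import path assumed

    raw = (
        "HTTP/1.1 200 OK\r\n"
        "Content-Length: 2\r\n"
        "\r\n"
        "OK"
    )
    resp = protocol.read_response(cStringIO.StringIO(raw), "GET", None)
    assert resp.status_code == 200 and resp.msg == "OK"
    assert resp.content == "OK"
    # The preamble helpers fill in the reason phrase from status_codes:
    assert protocol.response_preamble(404) == "HTTP/1.1 404 Not Found"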
+ + cls._check_frame_size(length, state) + + payload = fp.safe_read(length) + return FRAMES[fields[2]].from_bytes( + state, + length, + flags, + stream_id, + payload) + + def to_bytes(self): + payload = self.payload_bytes() + self.length = len(payload) + + self._check_frame_size(self.length, self.state) + + b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) + b += struct.pack('!B', self.TYPE) + b += struct.pack('!B', self.flags) + b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) + b += payload + + return b + + def payload_bytes(self): # pragma: no cover + raise NotImplementedError() + + def payload_human_readable(self): # pragma: no cover + raise NotImplementedError() + + def human_readable(self, direction="-"): + self.length = len(self.payload_bytes()) + + return "\n".join([ + "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( + direction, self.__class__.__name__, self.length, self.flags, self.stream_id), + self.payload_human_readable(), + "===============================================================", + ]) + + def __eq__(self, other): + return self.to_bytes() == other.to_bytes() + + +class DataFrame(Frame): + TYPE = 0x0 + VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'', + pad_length=0): + super(DataFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + self.pad_length = pad_length + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.payload = payload[1:-f.pad_length] + else: + f.payload = payload + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('DATA frames MUST be associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += bytes(self.payload) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + return "payload: %s" % str(self.payload) + + +class HeadersFrame(Frame): + TYPE = 0x1 + VALID_FLAGS = [ + Frame.FLAG_END_STREAM, + Frame.FLAG_END_HEADERS, + Frame.FLAG_PADDED, + Frame.FLAG_PRIORITY] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b'', + pad_length=0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(HeadersFrame, self).__init__(state, length, flags, stream_id) + + self.header_block_fragment = header_block_fragment + self.pad_length = pad_length + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length = struct.unpack('!B', payload[0])[0] + f.header_block_fragment = payload[1:-f.pad_length] + else: + f.header_block_fragment = payload[0:] + + if f.flags & Frame.FLAG_PRIORITY: + f.stream_dependency, f.weight = struct.unpack( + '!LB', f.header_block_fragment[:5]) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + f.header_block_fragment = f.header_block_fragment[5:] + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError('HEADERS frames MUST be 
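Frame.to_bytes assembles the fixed 9-byte header (24-bit length, type, flags, 31-bit stream id) in front of the payload; a small sketch with a DATA frame (hpack must be installed, since a default state builds an Encoder/Decoder):

    from netlib.http.http2 import frame  # import path assumed

    f = frame.DataFrame(stream_id=0x1, payload=b"ping")
    raw = f.to_bytes()
    assert len(raw) == 9 + 4      # header plus payload
    assert raw[-4:] == b"ping"
    # DATA frames must belong to a stream, so stream id 0 is rejected:
    try:
        frame.DataFrame(stream_id=0x0, payload=b"x").to_bytes()
    except ValueError:
        pass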
associated with a stream.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + if self.flags & self.FLAG_PRIORITY: + b += struct.pack('!LB', + (int(self.exclusive) << 31) | self.stream_dependency, + self.weight) + + b += self.header_block_fragment + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PRIORITY: + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PriorityFrame(Frame): + TYPE = 0x2 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + exclusive=False, + stream_dependency=0x0, + weight=0): + super(PriorityFrame, self).__init__(state, length, flags, stream_id) + self.exclusive = exclusive + self.stream_dependency = stream_dependency + self.weight = weight + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.stream_dependency, f.weight = struct.unpack('!LB', payload) + f.exclusive = bool(f.stream_dependency >> 31) + f.stream_dependency &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PRIORITY frames MUST be associated with a stream.') + + if self.stream_dependency == 0x0: + raise ValueError('stream dependency is invalid.') + + return struct.pack( + '!LB', + (int( + self.exclusive) << 31) | self.stream_dependency, + self.weight) + + def payload_human_readable(self): + s = [] + s.append("exclusive: %d" % self.exclusive) + s.append("stream dependency: %#x" % self.stream_dependency) + s.append("weight: %d" % self.weight) + return "\n".join(s) + + +class RstStreamFrame(Frame): + TYPE = 0x3 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + error_code=0x0): + super(RstStreamFrame, self).__init__(state, length, flags, stream_id) + self.error_code = error_code + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.error_code = struct.unpack('!L', payload)[0] + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'RST_STREAM frames MUST be associated with a stream.') + + return struct.pack('!L', self.error_code) + + def payload_human_readable(self): + return "error code: %#x" % self.error_code + + +class SettingsFrame(Frame): + TYPE = 0x4 + VALID_FLAGS = [Frame.FLAG_ACK] + + SETTINGS = utils.BiDi( + SETTINGS_HEADER_TABLE_SIZE=0x1, + SETTINGS_ENABLE_PUSH=0x2, + SETTINGS_MAX_CONCURRENT_STREAMS=0x3, + SETTINGS_INITIAL_WINDOW_SIZE=0x4, + SETTINGS_MAX_FRAME_SIZE=0x5, + SETTINGS_MAX_HEADER_LIST_SIZE=0x6, + ) + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings=None): + super(SettingsFrame, self).__init__(state, length, flags, stream_id) + + if settings is None: + settings = {} + + self.settings = settings + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + for i in 
xrange(0, len(payload), 6): + identifier, value = struct.unpack("!HL", payload[i:i + 6]) + f.settings[identifier] = value + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'SETTINGS frames MUST NOT be associated with a stream.') + + b = b'' + for identifier, value in self.settings.items(): + b += struct.pack("!HL", identifier & 0xFF, value) + + return b + + def payload_human_readable(self): + s = [] + + for identifier, value in self.settings.items(): + s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) + + if not s: + return "settings: None" + else: + return "\n".join(s) + + +class PushPromiseFrame(Frame): + TYPE = 0x5 + VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x0, + header_block_fragment=b'', + pad_length=0): + super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) + self.pad_length = pad_length + self.promised_stream = promised_stream + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + if f.flags & Frame.FLAG_PADDED: + f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) + f.header_block_fragment = payload[5:-f.pad_length] + else: + f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) + f.header_block_fragment = payload[4:] + + f.promised_stream &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'PUSH_PROMISE frames MUST be associated with a stream.') + + if self.promised_stream == 0x0: + raise ValueError('Promised stream id not valid.') + + b = b'' + if self.flags & self.FLAG_PADDED: + b += struct.pack('!B', self.pad_length) + + b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) + b += bytes(self.header_block_fragment) + + if self.flags & self.FLAG_PADDED: + b += b'\0' * self.pad_length + + return b + + def payload_human_readable(self): + s = [] + + if self.flags & self.FLAG_PADDED: + s.append("padding: %d" % self.pad_length) + + s.append("promised stream: %#x" % self.promised_stream) + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + + return "\n".join(s) + + +class PingFrame(Frame): + TYPE = 0x6 + VALID_FLAGS = [Frame.FLAG_ACK] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b''): + super(PingFrame, self).__init__(state, length, flags, stream_id) + self.payload = payload + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.payload = payload + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'PING frames MUST NOT be associated with a stream.') + + b = self.payload[0:8] + b += b'\0' * (8 - len(b)) + return b + + def payload_human_readable(self): + return "opaque data: %s" % str(self.payload) + + +class GoAwayFrame(Frame): + TYPE = 0x7 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x0, + error_code=0x0, + data=b''): + super(GoAwayFrame, self).__init__(state, length, flags, stream_id) + self.last_stream = last_stream + self.error_code = error_code + self.data = data + + @classmethod + def from_bytes(cls, state, length, 
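SETTINGS frames live on stream 0 and carry 6-byte identifier/value pairs; a serialization sketch:

    from netlib.http.http2 import frame  # import path assumed

    settings = {
        frame.SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 100,
        frame.SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1,
    }
    frm = frame.SettingsFrame(settings=settings)
    raw = frm.to_bytes()
    assert len(raw) == 9 + 6 * len(settings)
    dump = frm.human_readable(">>")   # the text used by dump_frames output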
flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) + f.last_stream &= 0x7FFFFFFF + f.data = payload[8:] + + return f + + def payload_bytes(self): + if self.stream_id != 0x0: + raise ValueError( + 'GOAWAY frames MUST NOT be associated with a stream.') + + b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) + b += bytes(self.data) + return b + + def payload_human_readable(self): + s = [] + s.append("last stream: %#x" % self.last_stream) + s.append("error code: %d" % self.error_code) + s.append("debug data: %s" % str(self.data)) + return "\n".join(s) + + +class WindowUpdateFrame(Frame): + TYPE = 0x8 + VALID_FLAGS = [] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x0): + super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) + self.window_size_increment = window_size_increment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + + f.window_size_increment = struct.unpack("!L", payload)[0] + f.window_size_increment &= 0x7FFFFFFF + + return f + + def payload_bytes(self): + if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: + raise ValueError( + 'Window Szie Increment MUST be greater than 0 and less than 2^31.') + + return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) + + def payload_human_readable(self): + return "window size increment: %#x" % self.window_size_increment + + +class ContinuationFrame(Frame): + TYPE = 0x9 + VALID_FLAGS = [Frame.FLAG_END_HEADERS] + + def __init__( + self, + state=None, + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment=b''): + super(ContinuationFrame, self).__init__(state, length, flags, stream_id) + self.header_block_fragment = header_block_fragment + + @classmethod + def from_bytes(cls, state, length, flags, stream_id, payload): + f = cls(state=state, length=length, flags=flags, stream_id=stream_id) + f.header_block_fragment = payload + return f + + def payload_bytes(self): + if self.stream_id == 0x0: + raise ValueError( + 'CONTINUATION frames MUST be associated with a stream.') + + return self.header_block_fragment + + def payload_human_readable(self): + s = [] + s.append( + "header_block_fragment: %s" % + self.header_block_fragment.encode('hex')) + return "\n".join(s) + +_FRAME_CLASSES = [ + DataFrame, + HeadersFrame, + PriorityFrame, + RstStreamFrame, + SettingsFrame, + PushPromiseFrame, + PingFrame, + GoAwayFrame, + WindowUpdateFrame, + ContinuationFrame +] +FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} + + +HTTP2_DEFAULT_SETTINGS = { + SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, + SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, +} diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py new file mode 100644 index 00000000..8e5f5429 --- /dev/null +++ b/netlib/http/http2/protocol.py @@ -0,0 +1,240 @@ +from __future__ import (absolute_import, print_function, division) +import itertools + +from hpack.hpack import Encoder, Decoder +from .. import utils +from . 
import frame + + +class HTTP2Protocol(object): + + ERROR_CODES = utils.BiDi( + NO_ERROR=0x0, + PROTOCOL_ERROR=0x1, + INTERNAL_ERROR=0x2, + FLOW_CONTROL_ERROR=0x3, + SETTINGS_TIMEOUT=0x4, + STREAM_CLOSED=0x5, + FRAME_SIZE_ERROR=0x6, + REFUSED_STREAM=0x7, + CANCEL=0x8, + COMPRESSION_ERROR=0x9, + CONNECT_ERROR=0xa, + ENHANCE_YOUR_CALM=0xb, + INADEQUATE_SECURITY=0xc, + HTTP_1_1_REQUIRED=0xd + ) + + # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" + CLIENT_CONNECTION_PREFACE =\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + ALPN_PROTO_H2 = 'h2' + + def __init__(self, tcp_handler, is_server=False, dump_frames=False): + self.tcp_handler = tcp_handler + self.is_server = is_server + + self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() + self.current_stream_id = None + self.encoder = Encoder() + self.decoder = Decoder() + self.connection_preface_performed = False + self.dump_frames = dump_frames + + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def perform_server_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + magic_length = len(self.CLIENT_CONNECTION_PREFACE) + magic = self.tcp_handler.rfile.safe_read(magic_length) + assert magic == self.CLIENT_CONNECTION_PREFACE + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def perform_client_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + self.connection_preface_performed = True + + self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) + + self.send_frame(frame.SettingsFrame(state=self), hide=True) + self._receive_settings(hide=True) + + def next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + + def send_frame(self, frm, hide=False): + raw_bytes = frm.to_bytes() + self.tcp_handler.wfile.write(raw_bytes) + self.tcp_handler.wfile.flush() + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + def read_frame(self, hide=False): + frm = frame.Frame.from_file(self.tcp_handler.rfile, self) + if not hide and self.dump_frames: # pragma no cover + print(frm.human_readable("<<")) + if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: + self._apply_settings(frm.settings, hide) + + return frm + + def _apply_settings(self, settings, hide=False): + for setting, value in settings.items(): + old_value = self.http2_settings[setting] + if not old_value: + old_value = '-' + self.http2_settings[setting] = value + + frm = frame.SettingsFrame( + state=self, + flags=frame.Frame.FLAG_ACK) + self.send_frame(frm, hide) + + # be liberal in what we expect from the other end + # to be more strict use: self._read_settings_ack(hide) + + def 
_create_headers(self, headers, stream_id, end_stream=True): + # TODO: implement max frame size checks and sending in chunks + + flags = frame.Frame.FLAG_END_HEADERS + if end_stream: + flags |= frame.Frame.FLAG_END_STREAM + + header_block_fragment = self.encoder.encode(headers) + + frm = frame.HeadersFrame( + state=self, + flags=flags, + stream_id=stream_id, + header_block_fragment=header_block_fragment) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + def _create_body(self, body, stream_id): + if body is None or len(body) == 0: + return b'' + + # TODO: implement max frame size checks and sending in chunks + # TODO: implement flow-control window + + frm = frame.DataFrame( + state=self, + flags=frame.Frame.FLAG_END_STREAM, + stream_id=stream_id, + payload=body) + + if self.dump_frames: # pragma no cover + print(frm.human_readable(">>")) + + return [frm.to_bytes()] + + + def create_request(self, method, path, headers=None, body=None): + if headers is None: + headers = [] + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = [ + (b':method', bytes(method)), + (b':path', bytes(path)), + (b':scheme', b'https'), + (b':authority', authority), + ] + headers + + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id))) + + def read_response(self): + stream_id_, headers, body = self._receive_transmission() + return headers[':status'], headers, body + + def read_request(self): + return self._receive_transmission() + + def _receive_transmission(self): + body_expected = True + + stream_id = 0 + header_block_fragment = b'' + body = b'' + + while True: + frm = self.read_frame() + if isinstance(frm, frame.HeadersFrame)\ + or isinstance(frm, frame.ContinuationFrame): + stream_id = frm.stream_id + header_block_fragment += frm.header_block_fragment + if frm.flags & frame.Frame.FLAG_END_STREAM: + body_expected = False + if frm.flags & frame.Frame.FLAG_END_HEADERS: + break + + while body_expected: + frm = self.read_frame() + if isinstance(frm, frame.DataFrame): + body += frm.payload + if frm.flags & frame.Frame.FLAG_END_STREAM: + break + # TODO: implement window update & flow + + headers = {} + for header, value in self.decoder.decode(header_block_fragment): + headers[header] = value + + return stream_id, headers, body + + def create_response(self, code, stream_id=None, headers=None, body=None): + if headers is None: + headers = [] + + headers = [(b':status', bytes(str(code)))] + headers + + if not stream_id: + stream_id = self.next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(body is None)), + self._create_body(body, stream_id), + )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py new file mode 100644 index 00000000..e7e84fe3 --- /dev/null +++ b/netlib/http/semantics.py @@ -0,0 +1,94 @@ +from __future__ import (absolute_import, print_function, division) +import binascii +import collections +import string +import sys +import urlparse + +from .. 
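Putting HTTP2Protocol together on the client side, a sketch only: `client` stands for an already connected netlib.tcp.TCPClient that negotiated ALPN "h2"; connection setup, error handling and flow control are omitted:

    # Sketch with assumptions as stated above; not a complete client.
    from netlib.http.http2 import HTTP2Protocol  # re-exported by http2/__init__.py

    proto = HTTP2Protocol(client, is_server=False)
    proto.perform_client_connection_preface()
    for raw in proto.create_request("GET", "/", headers=[("x-example", "1")]):
        client.wfile.write(raw)
    client.wfile.flush()
    status, headers, body = proto.read_response()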
import utils + +class Response(object): + + def __init__( + self, + httpversion, + status_code, + msg, + headers, + content, + sslinfo=None, + ): + self.httpversion = httpversion + self.status_code = status_code + self.msg = msg + self.headers = headers + self.content = content + self.sslinfo = sslinfo + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Response(%s - %s)" % (self.status_code, self.msg) + + + +def is_valid_port(port): + if not 0 <= port <= 65535: + return False + return True + + +def is_valid_host(host): + try: + host.decode("idna") + except ValueError: + return False + if "\0" in host: + return None + return True + + + +def parse_url(url): + """ + Returns a (scheme, host, port, path) tuple, or None on error. + + Checks that: + port is an integer 0-65535 + host is a valid IDNA-encoded hostname with no null-bytes + path is valid ASCII + """ + try: + scheme, netloc, path, params, query, fragment = urlparse.urlparse(url) + except ValueError: + return None + if not scheme: + return None + if '@' in netloc: + # FIXME: Consider what to do with the discarded credentials here Most + # probably we should extend the signature to return these as a separate + # value. + _, netloc = string.rsplit(netloc, '@', maxsplit=1) + if ':' in netloc: + host, port = string.rsplit(netloc, ':', maxsplit=1) + try: + port = int(port) + except ValueError: + return None + else: + host = netloc + if scheme == "https": + port = 443 + else: + port = 80 + path = urlparse.urlunparse(('', '', path, params, query, fragment)) + if not path.startswith("/"): + path = "/" + path + if not is_valid_host(host): + return None + if not utils.isascii(path): + return None + if not is_valid_port(port): + return None + return scheme, host, port, path diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py new file mode 100644 index 00000000..dc09f465 --- /dev/null +++ b/netlib/http/status_codes.py @@ -0,0 +1,104 @@ +from __future__ import (absolute_import, print_function, division) + +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 + +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 + +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415 +REQUESTED_RANGE_NOT_SATISFIABLE = 416 +EXPECTATION_FAILED = 417 + +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510 + +RESPONSES = { + # 100 + CONTINUE: "Continue", + SWITCHING: "Switching Protocols", + + # 200 + OK: "OK", + CREATED: "Created", + ACCEPTED: "Accepted", + NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", + NO_CONTENT: "No Content", + RESET_CONTENT: "Reset Content.", + PARTIAL_CONTENT: "Partial Content", + MULTI_STATUS: "Multi-Status", + + # 300 + MULTIPLE_CHOICE: "Multiple Choices", + MOVED_PERMANENTLY: "Moved Permanently", + FOUND: "Found", + SEE_OTHER: "See Other", + NOT_MODIFIED: "Not Modified", + USE_PROXY: 
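parse_url fills in the default port for the scheme and re-joins path, params and query; roughly:

    from netlib.http import semantics  # import path assumed

    assert semantics.parse_url("https://example.com/") == ("https", "example.com", 443, "/")
    assert semantics.parse_url("http://example.com:8080/p?q=1") == ("http", "example.com", 8080, "/p?q=1")
    assert semantics.parse_url("example.com/p") is None   # no scheme, rejected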
"Use Proxy", + # 306 not defined?? + TEMPORARY_REDIRECT: "Temporary Redirect", + + # 400 + BAD_REQUEST: "Bad Request", + UNAUTHORIZED: "Unauthorized", + PAYMENT_REQUIRED: "Payment Required", + FORBIDDEN: "Forbidden", + NOT_FOUND: "Not Found", + NOT_ALLOWED: "Method Not Allowed", + NOT_ACCEPTABLE: "Not Acceptable", + PROXY_AUTH_REQUIRED: "Proxy Authentication Required", + REQUEST_TIMEOUT: "Request Time-out", + CONFLICT: "Conflict", + GONE: "Gone", + LENGTH_REQUIRED: "Length Required", + PRECONDITION_FAILED: "Precondition Failed", + REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", + REQUEST_URI_TOO_LONG: "Request-URI Too Long", + UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", + REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", + EXPECTATION_FAILED: "Expectation Failed", + + # 500 + INTERNAL_SERVER_ERROR: "Internal Server Error", + NOT_IMPLEMENTED: "Not Implemented", + BAD_GATEWAY: "Bad Gateway", + SERVICE_UNAVAILABLE: "Service Unavailable", + GATEWAY_TIMEOUT: "Gateway Time-out", + HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", + INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", + NOT_EXTENDED: "Not Extended" +} diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py new file mode 100644 index 00000000..e8681908 --- /dev/null +++ b/netlib/http/user_agents.py @@ -0,0 +1,52 @@ +from __future__ import (absolute_import, print_function, division) + +""" + A small collection of useful user-agent header strings. These should be + kept reasonably current to reflect common usage. +""" + +# pylint: line-too-long + +# A collection of (name, shortcut, string) tuples. + +UASTRINGS = [ + ("android", + "a", + "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa + ("blackberry", + "l", + "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa + ("bingbot", + "b", + "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa + ("chrome", + "c", + "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa + ("firefox", + "f", + "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa + ("googlebot", + "g", + "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa + ("ie9", + "i", + "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa + ("ipad", + "p", + "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa + ("iphone", + "h", + "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa + ("safari", + "s", + "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa +] + + +def get_by_shortcut(s): + """ + Retrieve a user agent entry by shortcut. 
+ """ + for i in UASTRINGS: + if s == i[1]: + return i diff --git a/netlib/http2/__init__.py b/netlib/http2/__init__.py deleted file mode 100644 index 5acf7696..00000000 --- a/netlib/http2/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from frame import * -from protocol import * diff --git a/netlib/http2/frame.py b/netlib/http2/frame.py deleted file mode 100644 index f7e60471..00000000 --- a/netlib/http2/frame.py +++ /dev/null @@ -1,636 +0,0 @@ -import sys -import struct -from hpack.hpack import Encoder, Decoder - -from .. import utils - - -class FrameSizeError(Exception): - pass - - -class Frame(object): - - """ - Baseclass Frame - contains header - payload is defined in subclasses - """ - - FLAG_NO_FLAGS = 0x0 - FLAG_ACK = 0x1 - FLAG_END_STREAM = 0x1 - FLAG_END_HEADERS = 0x4 - FLAG_PADDED = 0x8 - FLAG_PRIORITY = 0x20 - - def __init__( - self, - state=None, - length=0, - flags=FLAG_NO_FLAGS, - stream_id=0x0): - valid_flags = reduce(lambda x, y: x | y, self.VALID_FLAGS, 0x0) - if flags | valid_flags != valid_flags: - raise ValueError('invalid flags detected.') - - if state is None: - class State(object): - pass - - state = State() - state.http2_settings = HTTP2_DEFAULT_SETTINGS.copy() - state.encoder = Encoder() - state.decoder = Decoder() - - self.state = state - - self.length = length - self.type = self.TYPE - self.flags = flags - self.stream_id = stream_id - - @classmethod - def _check_frame_size(cls, length, state): - if state: - settings = state.http2_settings - else: - settings = HTTP2_DEFAULT_SETTINGS.copy() - - max_frame_size = settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE] - - if length > max_frame_size: - raise FrameSizeError( - "Frame size exceeded: %d, but only %d allowed." % ( - length, max_frame_size)) - - @classmethod - def from_file(cls, fp, state=None): - """ - read a HTTP/2 frame sent by a server or client - fp is a "file like" object that could be backed by a network - stream or a disk or an in memory stream reader - """ - raw_header = fp.safe_read(9) - - fields = struct.unpack("!HBBBL", raw_header) - length = (fields[0] << 8) + fields[1] - flags = fields[3] - stream_id = fields[4] - - if raw_header[:4] == b'HTTP': # pragma no cover - print >> sys.stderr, "WARNING: This looks like an HTTP/1 connection!" 
- - cls._check_frame_size(length, state) - - payload = fp.safe_read(length) - return FRAMES[fields[2]].from_bytes( - state, - length, - flags, - stream_id, - payload) - - def to_bytes(self): - payload = self.payload_bytes() - self.length = len(payload) - - self._check_frame_size(self.length, self.state) - - b = struct.pack('!HB', self.length & 0xFFFF00, self.length & 0x0000FF) - b += struct.pack('!B', self.TYPE) - b += struct.pack('!B', self.flags) - b += struct.pack('!L', self.stream_id & 0x7FFFFFFF) - b += payload - - return b - - def payload_bytes(self): # pragma: no cover - raise NotImplementedError() - - def payload_human_readable(self): # pragma: no cover - raise NotImplementedError() - - def human_readable(self, direction="-"): - self.length = len(self.payload_bytes()) - - return "\n".join([ - "%s: %s | length: %d | flags: %#x | stream_id: %d" % ( - direction, self.__class__.__name__, self.length, self.flags, self.stream_id), - self.payload_human_readable(), - "===============================================================", - ]) - - def __eq__(self, other): - return self.to_bytes() == other.to_bytes() - - -class DataFrame(Frame): - TYPE = 0x0 - VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'', - pad_length=0): - super(DataFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - self.pad_length = pad_length - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.payload = payload[1:-f.pad_length] - else: - f.payload = payload - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('DATA frames MUST be associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += bytes(self.payload) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - return "payload: %s" % str(self.payload) - - -class HeadersFrame(Frame): - TYPE = 0x1 - VALID_FLAGS = [ - Frame.FLAG_END_STREAM, - Frame.FLAG_END_HEADERS, - Frame.FLAG_PADDED, - Frame.FLAG_PRIORITY] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b'', - pad_length=0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(HeadersFrame, self).__init__(state, length, flags, stream_id) - - self.header_block_fragment = header_block_fragment - self.pad_length = pad_length - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length = struct.unpack('!B', payload[0])[0] - f.header_block_fragment = payload[1:-f.pad_length] - else: - f.header_block_fragment = payload[0:] - - if f.flags & Frame.FLAG_PRIORITY: - f.stream_dependency, f.weight = struct.unpack( - '!LB', f.header_block_fragment[:5]) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - f.header_block_fragment = f.header_block_fragment[5:] - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError('HEADERS frames MUST be 
associated with a stream.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - if self.flags & self.FLAG_PRIORITY: - b += struct.pack('!LB', - (int(self.exclusive) << 31) | self.stream_dependency, - self.weight) - - b += self.header_block_fragment - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PRIORITY: - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PriorityFrame(Frame): - TYPE = 0x2 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - exclusive=False, - stream_dependency=0x0, - weight=0): - super(PriorityFrame, self).__init__(state, length, flags, stream_id) - self.exclusive = exclusive - self.stream_dependency = stream_dependency - self.weight = weight - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.stream_dependency, f.weight = struct.unpack('!LB', payload) - f.exclusive = bool(f.stream_dependency >> 31) - f.stream_dependency &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PRIORITY frames MUST be associated with a stream.') - - if self.stream_dependency == 0x0: - raise ValueError('stream dependency is invalid.') - - return struct.pack( - '!LB', - (int( - self.exclusive) << 31) | self.stream_dependency, - self.weight) - - def payload_human_readable(self): - s = [] - s.append("exclusive: %d" % self.exclusive) - s.append("stream dependency: %#x" % self.stream_dependency) - s.append("weight: %d" % self.weight) - return "\n".join(s) - - -class RstStreamFrame(Frame): - TYPE = 0x3 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - error_code=0x0): - super(RstStreamFrame, self).__init__(state, length, flags, stream_id) - self.error_code = error_code - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.error_code = struct.unpack('!L', payload)[0] - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'RST_STREAM frames MUST be associated with a stream.') - - return struct.pack('!L', self.error_code) - - def payload_human_readable(self): - return "error code: %#x" % self.error_code - - -class SettingsFrame(Frame): - TYPE = 0x4 - VALID_FLAGS = [Frame.FLAG_ACK] - - SETTINGS = utils.BiDi( - SETTINGS_HEADER_TABLE_SIZE=0x1, - SETTINGS_ENABLE_PUSH=0x2, - SETTINGS_MAX_CONCURRENT_STREAMS=0x3, - SETTINGS_INITIAL_WINDOW_SIZE=0x4, - SETTINGS_MAX_FRAME_SIZE=0x5, - SETTINGS_MAX_HEADER_LIST_SIZE=0x6, - ) - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings=None): - super(SettingsFrame, self).__init__(state, length, flags, stream_id) - - if settings is None: - settings = {} - - self.settings = settings - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - for i in 
xrange(0, len(payload), 6): - identifier, value = struct.unpack("!HL", payload[i:i + 6]) - f.settings[identifier] = value - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'SETTINGS frames MUST NOT be associated with a stream.') - - b = b'' - for identifier, value in self.settings.items(): - b += struct.pack("!HL", identifier & 0xFF, value) - - return b - - def payload_human_readable(self): - s = [] - - for identifier, value in self.settings.items(): - s.append("%s: %#x" % (self.SETTINGS.get_name(identifier), value)) - - if not s: - return "settings: None" - else: - return "\n".join(s) - - -class PushPromiseFrame(Frame): - TYPE = 0x5 - VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x0, - header_block_fragment=b'', - pad_length=0): - super(PushPromiseFrame, self).__init__(state, length, flags, stream_id) - self.pad_length = pad_length - self.promised_stream = promised_stream - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - if f.flags & Frame.FLAG_PADDED: - f.pad_length, f.promised_stream = struct.unpack('!BL', payload[:5]) - f.header_block_fragment = payload[5:-f.pad_length] - else: - f.promised_stream = int(struct.unpack("!L", payload[:4])[0]) - f.header_block_fragment = payload[4:] - - f.promised_stream &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'PUSH_PROMISE frames MUST be associated with a stream.') - - if self.promised_stream == 0x0: - raise ValueError('Promised stream id not valid.') - - b = b'' - if self.flags & self.FLAG_PADDED: - b += struct.pack('!B', self.pad_length) - - b += struct.pack('!L', self.promised_stream & 0x7FFFFFFF) - b += bytes(self.header_block_fragment) - - if self.flags & self.FLAG_PADDED: - b += b'\0' * self.pad_length - - return b - - def payload_human_readable(self): - s = [] - - if self.flags & self.FLAG_PADDED: - s.append("padding: %d" % self.pad_length) - - s.append("promised stream: %#x" % self.promised_stream) - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - - return "\n".join(s) - - -class PingFrame(Frame): - TYPE = 0x6 - VALID_FLAGS = [Frame.FLAG_ACK] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b''): - super(PingFrame, self).__init__(state, length, flags, stream_id) - self.payload = payload - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.payload = payload - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'PING frames MUST NOT be associated with a stream.') - - b = self.payload[0:8] - b += b'\0' * (8 - len(b)) - return b - - def payload_human_readable(self): - return "opaque data: %s" % str(self.payload) - - -class GoAwayFrame(Frame): - TYPE = 0x7 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x0, - error_code=0x0, - data=b''): - super(GoAwayFrame, self).__init__(state, length, flags, stream_id) - self.last_stream = last_stream - self.error_code = error_code - self.data = data - - @classmethod - def from_bytes(cls, state, length, 
flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.last_stream, f.error_code = struct.unpack("!LL", payload[:8]) - f.last_stream &= 0x7FFFFFFF - f.data = payload[8:] - - return f - - def payload_bytes(self): - if self.stream_id != 0x0: - raise ValueError( - 'GOAWAY frames MUST NOT be associated with a stream.') - - b = struct.pack('!LL', self.last_stream & 0x7FFFFFFF, self.error_code) - b += bytes(self.data) - return b - - def payload_human_readable(self): - s = [] - s.append("last stream: %#x" % self.last_stream) - s.append("error code: %d" % self.error_code) - s.append("debug data: %s" % str(self.data)) - return "\n".join(s) - - -class WindowUpdateFrame(Frame): - TYPE = 0x8 - VALID_FLAGS = [] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x0): - super(WindowUpdateFrame, self).__init__(state, length, flags, stream_id) - self.window_size_increment = window_size_increment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - - f.window_size_increment = struct.unpack("!L", payload)[0] - f.window_size_increment &= 0x7FFFFFFF - - return f - - def payload_bytes(self): - if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31: - raise ValueError( - 'Window Szie Increment MUST be greater than 0 and less than 2^31.') - - return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) - - def payload_human_readable(self): - return "window size increment: %#x" % self.window_size_increment - - -class ContinuationFrame(Frame): - TYPE = 0x9 - VALID_FLAGS = [Frame.FLAG_END_HEADERS] - - def __init__( - self, - state=None, - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment=b''): - super(ContinuationFrame, self).__init__(state, length, flags, stream_id) - self.header_block_fragment = header_block_fragment - - @classmethod - def from_bytes(cls, state, length, flags, stream_id, payload): - f = cls(state=state, length=length, flags=flags, stream_id=stream_id) - f.header_block_fragment = payload - return f - - def payload_bytes(self): - if self.stream_id == 0x0: - raise ValueError( - 'CONTINUATION frames MUST be associated with a stream.') - - return self.header_block_fragment - - def payload_human_readable(self): - s = [] - s.append( - "header_block_fragment: %s" % - self.header_block_fragment.encode('hex')) - return "\n".join(s) - -_FRAME_CLASSES = [ - DataFrame, - HeadersFrame, - PriorityFrame, - RstStreamFrame, - SettingsFrame, - PushPromiseFrame, - PingFrame, - GoAwayFrame, - WindowUpdateFrame, - ContinuationFrame -] -FRAMES = {cls.TYPE: cls for cls in _FRAME_CLASSES} - - -HTTP2_DEFAULT_SETTINGS = { - SettingsFrame.SETTINGS.SETTINGS_HEADER_TABLE_SIZE: 4096, - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: None, - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 2 ** 16 - 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_FRAME_SIZE: 2 ** 14, - SettingsFrame.SETTINGS.SETTINGS_MAX_HEADER_LIST_SIZE: None, -} diff --git a/netlib/http2/protocol.py b/netlib/http2/protocol.py deleted file mode 100644 index 8e5f5429..00000000 --- a/netlib/http2/protocol.py +++ /dev/null @@ -1,240 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -import itertools - -from hpack.hpack import Encoder, Decoder -from .. import utils -from . 
import frame - - -class HTTP2Protocol(object): - - ERROR_CODES = utils.BiDi( - NO_ERROR=0x0, - PROTOCOL_ERROR=0x1, - INTERNAL_ERROR=0x2, - FLOW_CONTROL_ERROR=0x3, - SETTINGS_TIMEOUT=0x4, - STREAM_CLOSED=0x5, - FRAME_SIZE_ERROR=0x6, - REFUSED_STREAM=0x7, - CANCEL=0x8, - COMPRESSION_ERROR=0x9, - CONNECT_ERROR=0xa, - ENHANCE_YOUR_CALM=0xb, - INADEQUATE_SECURITY=0xc, - HTTP_1_1_REQUIRED=0xd - ) - - # "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" - CLIENT_CONNECTION_PREFACE =\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - ALPN_PROTO_H2 = 'h2' - - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler - self.is_server = is_server - - self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() - self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() - self.connection_preface_performed = False - self.dump_frames = dump_frames - - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True - - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break - - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break - - def perform_server_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - magic_length = len(self.CLIENT_CONNECTION_PREFACE) - magic = self.tcp_handler.rfile.safe_read(magic_length) - assert magic == self.CLIENT_CONNECTION_PREFACE - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def perform_client_connection_preface(self, force=False): - if force or not self.connection_preface_performed: - self.connection_preface_performed = True - - self.tcp_handler.wfile.write(self.CLIENT_CONNECTION_PREFACE) - - self.send_frame(frame.SettingsFrame(state=self), hide=True) - self._receive_settings(hide=True) - - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - - def send_frame(self, frm, hide=False): - raw_bytes = frm.to_bytes() - self.tcp_handler.wfile.write(raw_bytes) - self.tcp_handler.wfile.flush() - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - def read_frame(self, hide=False): - frm = frame.Frame.from_file(self.tcp_handler.rfile, self) - if not hide and self.dump_frames: # pragma no cover - print(frm.human_readable("<<")) - if isinstance(frm, frame.SettingsFrame) and not frm.flags & frame.Frame.FLAG_ACK: - self._apply_settings(frm.settings, hide) - - return frm - - def _apply_settings(self, settings, hide=False): - for setting, value in settings.items(): - old_value = self.http2_settings[setting] - if not old_value: - old_value = '-' - self.http2_settings[setting] = value - - frm = frame.SettingsFrame( - state=self, - flags=frame.Frame.FLAG_ACK) - self.send_frame(frm, hide) - - # be liberal in what we expect from the other end - # to be more strict use: self._read_settings_ack(hide) - - def 
_create_headers(self, headers, stream_id, end_stream=True): - # TODO: implement max frame size checks and sending in chunks - - flags = frame.Frame.FLAG_END_HEADERS - if end_stream: - flags |= frame.Frame.FLAG_END_STREAM - - header_block_fragment = self.encoder.encode(headers) - - frm = frame.HeadersFrame( - state=self, - flags=flags, - stream_id=stream_id, - header_block_fragment=header_block_fragment) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - def _create_body(self, body, stream_id): - if body is None or len(body) == 0: - return b'' - - # TODO: implement max frame size checks and sending in chunks - # TODO: implement flow-control window - - frm = frame.DataFrame( - state=self, - flags=frame.Frame.FLAG_END_STREAM, - stream_id=stream_id, - payload=body) - - if self.dump_frames: # pragma no cover - print(frm.human_readable(">>")) - - return [frm.to_bytes()] - - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self): - stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body - - def read_request(self): - return self._receive_transmission() - - def _receive_transmission(self): - body_expected = True - - stream_id = 0 - header_block_fragment = b'' - body = b'' - - while True: - frm = self.read_frame() - if isinstance(frm, frame.HeadersFrame)\ - or isinstance(frm, frame.ContinuationFrame): - stream_id = frm.stream_id - header_block_fragment += frm.header_block_fragment - if frm.flags & frame.Frame.FLAG_END_STREAM: - body_expected = False - if frm.flags & frame.Frame.FLAG_END_HEADERS: - break - - while body_expected: - frm = self.read_frame() - if isinstance(frm, frame.DataFrame): - body += frm.payload - if frm.flags & frame.Frame.FLAG_END_STREAM: - break - # TODO: implement window update & flow - - headers = {} - for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value - - return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http_auth.py b/netlib/http_auth.py deleted file mode 100644 index adab4aed..00000000 --- a/netlib/http_auth.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import (absolute_import, print_function, division) -from argparse import Action, ArgumentTypeError -from . import http - - -class NullProxyAuth(object): - - """ - No proxy auth at all (returns empty challange headers) - """ - - def __init__(self, password_manager): - self.password_manager = password_manager - - def clean(self, headers_): - """ - Clean up authentication headers, so they're not passed upstream. 
- """ - pass - - def authenticate(self, headers_): - """ - Tests that the user is allowed to use the proxy - """ - return True - - def auth_challenge_headers(self): - """ - Returns a dictionary containing the headers require to challenge the user - """ - return {} - - -class BasicProxyAuth(NullProxyAuth): - CHALLENGE_HEADER = 'Proxy-Authenticate' - AUTH_HEADER = 'Proxy-Authorization' - - def __init__(self, password_manager, realm): - NullProxyAuth.__init__(self, password_manager) - self.realm = realm - - def clean(self, headers): - del headers[self.AUTH_HEADER] - - def authenticate(self, headers): - auth_value = headers.get(self.AUTH_HEADER, []) - if not auth_value: - return False - parts = http.parse_http_basic_auth(auth_value[0]) - if not parts: - return False - scheme, username, password = parts - if scheme.lower() != 'basic': - return False - if not self.password_manager.test(username, password): - return False - self.username = username - return True - - def auth_challenge_headers(self): - return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm} - - -class PassMan(object): - - def test(self, username_, password_token_): - return False - - -class PassManNonAnon(PassMan): - - """ - Ensure the user specifies a username, accept any password. - """ - - def test(self, username, password_token_): - if username: - return True - return False - - -class PassManHtpasswd(PassMan): - - """ - Read usernames and passwords from an htpasswd file - """ - - def __init__(self, path): - """ - Raises ValueError if htpasswd file is invalid. - """ - import passlib.apache - self.htpasswd = passlib.apache.HtpasswdFile(path) - - def test(self, username, password_token): - return bool(self.htpasswd.check_password(username, password_token)) - - -class PassManSingleUser(PassMan): - - def __init__(self, username, password): - self.username, self.password = username, password - - def test(self, username, password_token): - return self.username == username and self.password == password_token - - -class AuthAction(Action): - - """ - Helper class to allow seamless integration int argparse. Example usage: - parser.add_argument( - "--nonanonymous", - action=NonanonymousAuthAction, nargs=0, - help="Allow access to any user long as a credentials are specified." - ) - """ - - def __call__(self, parser, namespace, values, option_string=None): - passman = self.getPasswordManager(values) - authenticator = BasicProxyAuth(passman, "mitmproxy") - setattr(namespace, self.dest, authenticator) - - def getPasswordManager(self, s): # pragma: nocover - raise NotImplementedError() - - -class SingleuserAuthAction(AuthAction): - - def getPasswordManager(self, s): - if len(s.split(':')) != 2: - raise ArgumentTypeError( - "Invalid single-user specification. Please use the format username:password" - ) - username, password = s.split(':') - return PassManSingleUser(username, password) - - -class NonanonymousAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManNonAnon() - - -class HtpasswdAuthAction(AuthAction): - - def getPasswordManager(self, s): - return PassManHtpasswd(s) diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py deleted file mode 100644 index e91ee5c0..00000000 --- a/netlib/http_cookies.py +++ /dev/null @@ -1,193 +0,0 @@ -""" -A flexible module for cookie parsing and manipulation. - -This module differs from usual standards-compliant cookie modules in a number -of ways. We try to be as permissive as possible, and to retain even mal-formed -information. 
Duplicate cookies are preserved in parsing, and can be set in -formatting. We do attempt to escape and quote values where needed, but will not -reject data that violate the specs. - -Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We do -not parse the comma-separated variant of Set-Cookie that allows multiple -cookies to be set in a single header. Technically this should be feasible, but -it turns out that violations of RFC6265 that makes the parsing problem -indeterminate are much more common than genuine occurences of the multi-cookie -variants. Serialization follows RFC6265. - - http://tools.ietf.org/html/rfc6265 - http://tools.ietf.org/html/rfc2109 - http://tools.ietf.org/html/rfc2965 -""" - -# TODO -# - Disallow LHS-only Cookie values - -import re - -import odict - - -def _read_until(s, start, term): - """ - Read until one of the characters in term is reached. - """ - if start == len(s): - return "", start + 1 - for i in range(start, len(s)): - if s[i] in term: - return s[start:i], i - return s[start:i + 1], i + 1 - - -def _read_token(s, start): - """ - Read a token - the LHS of a token/value pair in a cookie. - """ - return _read_until(s, start, ";=") - - -def _read_quoted_string(s, start): - """ - start: offset to the first quote of the string to be read - - A sort of loose super-set of the various quoted string specifications. - - RFC6265 disallows backslashes or double quotes within quoted strings. - Prior RFCs use backslashes to escape. This leaves us free to apply - backslash escaping by default and be compatible with everything. - """ - escaping = False - ret = [] - # Skip the first quote - for i in range(start + 1, len(s)): - if escaping: - ret.append(s[i]) - escaping = False - elif s[i] == '"': - break - elif s[i] == "\\": - escaping = True - else: - ret.append(s[i]) - return "".join(ret), i + 1 - - -def _read_value(s, start, delims): - """ - Reads a value - the RHS of a token/value pair in a cookie. - - special: If the value is special, commas are premitted. Else comma - terminates. This helps us support old and new style values. - """ - if start >= len(s): - return "", start - elif s[start] == '"': - return _read_quoted_string(s, start) - else: - return _read_until(s, start, delims) - - -def _read_pairs(s, off=0): - """ - Read pairs of lhs=rhs values. - - off: start offset - specials: a lower-cased list of keys that may contain commas - """ - vals = [] - while True: - lhs, off = _read_token(s, off) - lhs = lhs.lstrip() - if lhs: - rhs = None - if off < len(s): - if s[off] == "=": - rhs, off = _read_value(s, off + 1, ";") - vals.append([lhs, rhs]) - off += 1 - if not off < len(s): - break - return vals, off - - -def _has_special(s): - for i in s: - if i in '",;\\': - return True - o = ord(i) - if o < 0x21 or o > 0x7e: - return True - return False - - -ESCAPE = re.compile(r"([\"\\])") - - -def _format_pairs(lst, specials=(), sep="; "): - """ - specials: A lower-cased list of keys that will not be quoted. - """ - vals = [] - for k, v in lst: - if v is None: - vals.append(k) - else: - if k.lower() not in specials and _has_special(v): - v = ESCAPE.sub(r"\\\1", v) - v = '"%s"' % v - vals.append("%s=%s" % (k, v)) - return sep.join(vals) - - -def _format_set_cookie_pairs(lst): - return _format_pairs( - lst, - specials=("expires", "path") - ) - - -def _parse_set_cookie_pairs(s): - """ - For Set-Cookie, we support multiple cookies as described in RFC2109. - This function therefore returns a list of lists. 
- """ - pairs, off_ = _read_pairs(s) - return pairs - - -def parse_set_cookie_header(line): - """ - Parse a Set-Cookie header value - - Returns a (name, value, attrs) tuple, or None, where attrs is an - ODictCaseless set of attributes. No attempt is made to parse attribute - values - they are treated purely as strings. - """ - pairs = _parse_set_cookie_pairs(line) - if pairs: - return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) - - -def format_set_cookie_header(name, value, attrs): - """ - Formats a Set-Cookie header value. - """ - pairs = [[name, value]] - pairs.extend(attrs.lst) - return _format_set_cookie_pairs(pairs) - - -def parse_cookie_header(line): - """ - Parse a Cookie header value. - Returns a (possibly empty) ODict object. - """ - pairs, off_ = _read_pairs(line) - return odict.ODict(pairs) - - -def format_cookie_header(od): - """ - Formats a Cookie header value. - """ - return _format_pairs(od.lst) diff --git a/netlib/http_semantics.py b/netlib/http_semantics.py deleted file mode 100644 index e8313e3c..00000000 --- a/netlib/http_semantics.py +++ /dev/null @@ -1,23 +0,0 @@ -class Response(object): - - def __init__( - self, - httpversion, - status_code, - msg, - headers, - content, - sslinfo=None, - ): - self.httpversion = httpversion - self.status_code = status_code - self.msg = msg - self.headers = headers - self.content = content - self.sslinfo = sslinfo - - def __eq__(self, other): - return self.__dict__ == other.__dict__ - - def __repr__(self): - return "Response(%s - %s)" % (self.status_code, self.msg) diff --git a/netlib/http_status.py b/netlib/http_status.py deleted file mode 100644 index dc09f465..00000000 --- a/netlib/http_status.py +++ /dev/null @@ -1,104 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -CONTINUE = 100 -SWITCHING = 101 -OK = 200 -CREATED = 201 -ACCEPTED = 202 -NON_AUTHORITATIVE_INFORMATION = 203 -NO_CONTENT = 204 -RESET_CONTENT = 205 -PARTIAL_CONTENT = 206 -MULTI_STATUS = 207 - -MULTIPLE_CHOICE = 300 -MOVED_PERMANENTLY = 301 -FOUND = 302 -SEE_OTHER = 303 -NOT_MODIFIED = 304 -USE_PROXY = 305 -TEMPORARY_REDIRECT = 307 - -BAD_REQUEST = 400 -UNAUTHORIZED = 401 -PAYMENT_REQUIRED = 402 -FORBIDDEN = 403 -NOT_FOUND = 404 -NOT_ALLOWED = 405 -NOT_ACCEPTABLE = 406 -PROXY_AUTH_REQUIRED = 407 -REQUEST_TIMEOUT = 408 -CONFLICT = 409 -GONE = 410 -LENGTH_REQUIRED = 411 -PRECONDITION_FAILED = 412 -REQUEST_ENTITY_TOO_LARGE = 413 -REQUEST_URI_TOO_LONG = 414 -UNSUPPORTED_MEDIA_TYPE = 415 -REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED = 417 - -INTERNAL_SERVER_ERROR = 500 -NOT_IMPLEMENTED = 501 -BAD_GATEWAY = 502 -SERVICE_UNAVAILABLE = 503 -GATEWAY_TIMEOUT = 504 -HTTP_VERSION_NOT_SUPPORTED = 505 -INSUFFICIENT_STORAGE_SPACE = 507 -NOT_EXTENDED = 510 - -RESPONSES = { - # 100 - CONTINUE: "Continue", - SWITCHING: "Switching Protocols", - - # 200 - OK: "OK", - CREATED: "Created", - ACCEPTED: "Accepted", - NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information", - NO_CONTENT: "No Content", - RESET_CONTENT: "Reset Content.", - PARTIAL_CONTENT: "Partial Content", - MULTI_STATUS: "Multi-Status", - - # 300 - MULTIPLE_CHOICE: "Multiple Choices", - MOVED_PERMANENTLY: "Moved Permanently", - FOUND: "Found", - SEE_OTHER: "See Other", - NOT_MODIFIED: "Not Modified", - USE_PROXY: "Use Proxy", - # 306 not defined?? 
- TEMPORARY_REDIRECT: "Temporary Redirect", - - # 400 - BAD_REQUEST: "Bad Request", - UNAUTHORIZED: "Unauthorized", - PAYMENT_REQUIRED: "Payment Required", - FORBIDDEN: "Forbidden", - NOT_FOUND: "Not Found", - NOT_ALLOWED: "Method Not Allowed", - NOT_ACCEPTABLE: "Not Acceptable", - PROXY_AUTH_REQUIRED: "Proxy Authentication Required", - REQUEST_TIMEOUT: "Request Time-out", - CONFLICT: "Conflict", - GONE: "Gone", - LENGTH_REQUIRED: "Length Required", - PRECONDITION_FAILED: "Precondition Failed", - REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large", - REQUEST_URI_TOO_LONG: "Request-URI Too Long", - UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type", - REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable", - EXPECTATION_FAILED: "Expectation Failed", - - # 500 - INTERNAL_SERVER_ERROR: "Internal Server Error", - NOT_IMPLEMENTED: "Not Implemented", - BAD_GATEWAY: "Bad Gateway", - SERVICE_UNAVAILABLE: "Service Unavailable", - GATEWAY_TIMEOUT: "Gateway Time-out", - HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported", - INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space", - NOT_EXTENDED: "Not Extended" -} diff --git a/netlib/http_uastrings.py b/netlib/http_uastrings.py deleted file mode 100644 index e8681908..00000000 --- a/netlib/http_uastrings.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import (absolute_import, print_function, division) - -""" - A small collection of useful user-agent header strings. These should be - kept reasonably current to reflect common usage. -""" - -# pylint: line-too-long - -# A collection of (name, shortcut, string) tuples. - -UASTRINGS = [ - ("android", - "a", - "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa - ("blackberry", - "l", - "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa - ("bingbot", - "b", - "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa - ("chrome", - "c", - "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa - ("firefox", - "f", - "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa - ("googlebot", - "g", - "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa - ("ie9", - "i", - "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa - ("ipad", - "p", - "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa - ("iphone", - "h", - "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa - ("safari", - "s", - "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa -] - - -def get_by_shortcut(s): - """ - Retrieve a user agent entry by shortcut. - """ - for i in UASTRINGS: - if s == i[1]: - return i diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index d41059fa..49d8ee10 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -6,7 +6,7 @@ import struct import io from .protocol import Masker -from .. 
import utils, odict, tcp +from netlib import utils, odict, tcp DEFAULT = object() diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index dcab53fb..29b4db3d 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -5,7 +5,7 @@ import os import struct import io -from .. import utils, odict, tcp +from netlib import utils, odict, tcp # Colleciton of utility functions that implement small portions of the RFC6455 # WebSockets Protocol Useful for building WebSocket clients and servers. diff --git a/test/http/__init__.py b/test/http/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http1/__init__.py b/test/http/http1/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py new file mode 100644 index 00000000..05e82831 --- /dev/null +++ b/test/http/http1/test_protocol.py @@ -0,0 +1,445 @@ +import cStringIO +import textwrap +import binascii + +from netlib import http, odict, tcp +from netlib.http.http1 import protocol +from ... import tutils, tservers + + +def test_has_chunked_encoding(): + h = odict.ODictCaseless() + assert not protocol.has_chunked_encoding(h) + h["transfer-encoding"] = ["chunked"] + assert protocol.has_chunked_encoding(h) + + +def test_read_chunked(): + + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] + s = cStringIO.StringIO("1\r\na\r\n0\r\n") + + tutils.raises( + "malformed chunked body", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") + assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + + s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") + assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + + s = cStringIO.StringIO("\r\n") + tutils.raises( + "closed prematurely", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("1\r\nfoo") + tutils.raises( + "malformed chunked body", + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("foo\r\nfoo") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", None, True + ) + + s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") + tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True) + + +def test_connection_close(): + h = odict.ODictCaseless() + assert protocol.connection_close((1, 0), h) + assert not protocol.connection_close((1, 1), h) + + h["connection"] = ["keep-alive"] + assert not protocol.connection_close((1, 1), h) + + h["connection"] = ["close"] + assert protocol.connection_close((1, 1), h) + + +def test_get_header_tokens(): + h = odict.ODictCaseless() + assert protocol.get_header_tokens(h, "foo") == [] + h["foo"] = ["bar"] + assert protocol.get_header_tokens(h, "foo") == ["bar"] + h["foo"] = ["bar, voing"] + assert protocol.get_header_tokens(h, "foo") == ["bar", "voing"] + h["foo"] = ["bar, voing", "oink"] + assert protocol.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + + +def test_read_http_body_request(): + h = odict.ODictCaseless() + r = cStringIO.StringIO("testing") + assert protocol.read_http_body(r, h, None, "GET", None, True) == "" + + +def test_read_http_body_response(): + h = odict.ODictCaseless() + s = tcp.Reader(cStringIO.StringIO("testing")) + assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + + +def test_read_http_body(): + # test default case + h = odict.ODictCaseless() + 
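# Editor's aside (not part of the patch): a self-contained sketch of the
# body-size decision that test_expected_http_body_size further down asserts
# on. It takes a plain dict mapping lower-cased header names to value lists;
# the helper name and signature are illustrative only and do not claim to
# match the real protocol module.
def sketch_expected_body_size(headers, is_request):
    values = headers.get("content-length", [])
    if values:
        try:
            n = int(values[0])
        except ValueError:
            return None               # unparsable Content-Length is invalid
        return n if n >= 0 else None  # a negative length is equally invalid
    if is_request:
        return 0    # a request without an explicit length carries no body
    return -1       # a response without a length is read until the connection closes

assert sketch_expected_body_size({"content-length": ["5"]}, False) == 5
assert sketch_expected_body_size({"content-length": ["foo"]}, False) is None
assert sketch_expected_body_size({}, True) == 0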
h["content-length"] = [7] + s = cStringIO.StringIO("testing") + assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + + # test content length: invalid header + h["content-length"] = ["foo"] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", 200, False + ) + + # test content length: invalid header #2 + h["content-length"] = [-1] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, None, "GET", 200, False + ) + + # test content length: content length > actual content + h["content-length"] = [5] + s = cStringIO.StringIO("testing") + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, 4, "GET", 200, False + ) + + # test content length: content length < actual content + s = cStringIO.StringIO("testing") + assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5 + + # test no content length: limit > actual content + h = odict.ODictCaseless() + s = tcp.Reader(cStringIO.StringIO("testing")) + assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7 + + # test no content length: limit < actual content + s = tcp.Reader(cStringIO.StringIO("testing")) + tutils.raises( + protocol.HttpError, + protocol.read_http_body, + s, h, 4, "GET", 200, False + ) + + # test chunked + h = odict.ODictCaseless() + h["transfer-encoding"] = ["chunked"] + s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) + assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + + +def test_expected_http_body_size(): + # gibber in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["foo"] + assert protocol.expected_http_body_size(h, False, "GET", 200) is None + # negative number in the content-length field + h = odict.ODictCaseless() + h["content-length"] = ["-7"] + assert protocol.expected_http_body_size(h, False, "GET", 200) is None + # explicit length + h = odict.ODictCaseless() + h["content-length"] = ["5"] + assert protocol.expected_http_body_size(h, False, "GET", 200) == 5 + # no length + h = odict.ODictCaseless() + assert protocol.expected_http_body_size(h, False, "GET", 200) == -1 + # no length request + h = odict.ODictCaseless() + assert protocol.expected_http_body_size(h, True, "GET", None) == 0 + + +def test_parse_http_protocol(): + assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not protocol.parse_http_protocol("HTTP/a.1") + assert not protocol.parse_http_protocol("HTTP/1.a") + assert not protocol.parse_http_protocol("foo/0.0") + assert not protocol.parse_http_protocol("HTTP/x") + + +def test_parse_init_connect(): + assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not protocol.parse_init_connect("bogus") + assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + + +def test_parse_init_proxy(): + u = "GET http://foo.com:8888/test HTTP/1.1" + m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + assert m == 
"GET" + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + assert httpversion == (1, 1) + + u = "G\xfeET http://foo.com:8888/test HTTP/1.1" + assert not protocol.parse_init_proxy(u) + + assert not protocol.parse_init_proxy("invalid") + assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + + +def test_parse_init_http(): + u = "GET /test HTTP/1.1" + m, u, httpversion = protocol.parse_init_http(u) + assert m == "GET" + assert u == "/test" + assert httpversion == (1, 1) + + u = "G\xfeET /test HTTP/1.1" + assert not protocol.parse_init_http(u) + + assert not protocol.parse_init_http("invalid") + assert not protocol.parse_init_http("GET invalid HTTP/1.1") + assert not protocol.parse_init_http("GET /test foo/1.1") + assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + + +class TestReadHeaders: + + def _read(self, data, verbatim=False): + if not verbatim: + data = textwrap.dedent(data) + data = data.strip() + s = cStringIO.StringIO(data) + return protocol.read_headers(s) + + def test_read_simple(self): + data = """ + Header: one + Header2: two + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one"], ["Header2", "two"]] + + def test_read_multi(self): + data = """ + Header: one + Header: two + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one"], ["Header", "two"]] + + def test_read_continued(self): + data = """ + Header: one + \ttwo + Header2: three + \r\n + """ + h = self._read(data) + assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] + + def test_read_continued_err(self): + data = "\tfoo: bar\r\n" + assert self._read(data, True) is None + + def test_read_err(self): + data = """ + foo + """ + assert self._read(data) is None + + +class NoContentLengthHTTPHandler(tcp.BaseHandler): + + def handle(self): + self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") + self.wfile.flush() + + +class TestReadResponseNoContentLength(tservers.ServerTestBase): + handler = NoContentLengthHTTPHandler + + def test_no_content_length(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + resp = protocol.read_response(c.rfile, "GET", None) + assert resp.content == "bar\r\n\r\n" + + +def test_read_response(): + def tst(data, method, limit, include_body=True): + data = textwrap.dedent(data) + r = cStringIO.StringIO(data) + return protocol.read_response( + r, method, limit, include_body=include_body + ) + + tutils.raises("server disconnect", tst, "", "GET", None) + tutils.raises("invalid server response", tst, "foo", "GET", None) + data = """ + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 200, 'OK', odict.ODictCaseless(), '' + ) + data = """ + HTTP/1.1 200 + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 200, '', odict.ODictCaseless(), '' + ) + data = """ + HTTP/x 200 OK + """ + tutils.raises("invalid http version", tst, data, "GET", None) + data = """ + HTTP/1.1 xx OK + """ + tutils.raises("invalid server response", tst, data, "GET", None) + + data = """ + HTTP/1.1 100 CONTINUE + + HTTP/1.1 200 OK + """ + assert tst(data, "GET", None) == http.Response( + (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' + ) + + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None).content == 'foo' + assert tst(data, "HEAD", None).content == '' + + data = """ + HTTP/1.1 200 OK + \tContent-Length: 3 + + foo + """ + tutils.raises("invalid headers", 
tst, data, "GET", None) + + data = """ + HTTP/1.1 200 OK + Content-Length: 3 + + foo + """ + assert tst(data, "GET", None, include_body=False).content is None + + +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert protocol.parse_http_basic_auth( + protocol.assemble_http_basic_auth(*vals) + ) == vals + assert not protocol.parse_http_basic_auth("") + assert not protocol.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not protocol.parse_http_basic_auth(v) + + +def test_get_request_line(): + r = cStringIO.StringIO("\nfoo") + assert protocol.get_request_line(r) == "foo" + assert not protocol.get_request_line(r) + + +class TestReadRequest(): + + def tst(self, data, **kwargs): + r = cStringIO.StringIO(data) + return protocol.read_request(r, **kwargs) + + def test_invalid(self): + tutils.raises( + "bad http request", + self.tst, + "xxx" + ) + tutils.raises( + "bad http request line", + self.tst, + "get /\xff HTTP/1.1" + ) + tutils.raises( + "invalid headers", + self.tst, + "get / HTTP/1.1\r\nfoo" + ) + tutils.raises( + tcp.NetLibDisconnect, + self.tst, + "\r\n" + ) + + def test_asterisk_form_in(self): + v = self.tst("OPTIONS * HTTP/1.1") + assert v.form_in == "relative" + assert v.method == "OPTIONS" + + def test_absolute_form_in(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "GET oops-no-protocol.com HTTP/1.1" + ) + v = self.tst("GET http://address:22/ HTTP/1.1") + assert v.form_in == "absolute" + assert v.port == 22 + assert v.host == "address" + assert v.scheme == "http" + + def test_connect(self): + tutils.raises( + "Bad HTTP request line", + self.tst, + "CONNECT oops-no-port.com HTTP/1.1" + ) + v = self.tst("CONNECT foo.com:443 HTTP/1.1") + assert v.form_in == "authority" + assert v.method == "CONNECT" + assert v.port == 443 + assert v.host == "foo.com" + + def test_expect(self): + w = cStringIO.StringIO() + r = cStringIO.StringIO( + "GET / HTTP/1.1\r\n" + "Content-Length: 3\r\n" + "Expect: 100-continue\r\n\r\n" + "foobar", + ) + v = protocol.read_request(r, wfile=w) + assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + assert v.content == "foo" + assert r.read(3) == "bar" diff --git a/test/http/http2/__init__.py b/test/http/http2/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/http/http2/test_frames.py b/test/http/http2/test_frames.py new file mode 100644 index 00000000..ee2edc39 --- /dev/null +++ b/test/http/http2/test_frames.py @@ -0,0 +1,704 @@ +import cStringIO +from test import tutils +from nose.tools import assert_equal +from netlib import tcp +from netlib.http.http2.frame import * + + +def hex_to_file(data): + data = data.decode('hex') + return tcp.Reader(cStringIO.StringIO(data)) + + +def test_invalid_flags(): + tutils.raises( + ValueError, + DataFrame, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + payload='foobar') + + +def test_frame_equality(): + a = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + b = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(a, b) + + +def test_too_large_frames(): + f = DataFrame( + length=9000, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar' * 3000) + tutils.raises(FrameSizeError, f.to_bytes) + + +def test_data_frame_to_bytes(): + f = DataFrame( + length=6, + flags=Frame.FLAG_END_STREAM, + stream_id=0x1234567, + payload='foobar') + assert_equal(f.to_bytes().encode('hex'), 
'000006000101234567666f6f626172') + + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000a00090123456703666f6f626172000000') + + f = DataFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_data_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) + assert isinstance(f, DataFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, DataFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.payload, 'foobar') + + +def test_data_frame_human_readable(): + f = DataFrame( + length=11, + flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), + stream_id=0x1234567, + payload='foobar', + pad_length=3) + assert f.human_readable() + + +def test_headers_frame_to_bytes(): + f = HeadersFrame( + length=6, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex')) + assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PADDED), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000b01080123456703668594e75e31d9000000') + + f = HeadersFrame( + length=10, + flags=(HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00000c012001234567876543212a668594e75e31d9') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703876543212a668594e75e31d9000000') + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert_equal( + f.to_bytes().encode('hex'), + '00001001280123456703076543212a668594e75e31d9000000') + + f = HeadersFrame( + length=6, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + header_block_fragment='668594e75e31d9'.decode('hex')) + tutils.raises(ValueError, f.to_bytes) + + +def test_headers_frame_from_bytes(): + f = Frame.from_file(hex_to_file( + '000007010001234567668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 7) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(hex_to_file( + '00000b01080123456703668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 11) + assert_equal(f.TYPE, 
HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + + f = Frame.from_file(hex_to_file( + '00000c012001234567876543212a668594e75e31d9')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file( + '00001001280123456703876543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file( + '00001001280123456703076543212a668594e75e31d9000000')) + assert isinstance(f, HeadersFrame) + assert_equal(f.length, 16) + assert_equal(f.TYPE, HeadersFrame.TYPE) + assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + +def test_headers_frame_human_readable(): + f = HeadersFrame( + length=7, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment=b'', + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + f = HeadersFrame( + length=14, + flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), + stream_id=0x1234567, + header_block_fragment='668594e75e31d9'.decode('hex'), + pad_length=3, + exclusive=False, + stream_dependency=0x7654321, + weight=42) + assert f.human_readable() + + +def test_priority_frame_to_bytes(): + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=True, + stream_dependency=0x7654321, + weight=42) + assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') + + f = PriorityFrame( + length=5, + flags=(Frame.FLAG_NO_FLAGS), + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + stream_dependency=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + stream_dependency=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_priority_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000005020001234567876543212a')) + assert isinstance(f, PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, True) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 42) + + f = Frame.from_file(hex_to_file('0000050200012345670765432115')) + assert isinstance(f, 
PriorityFrame) + assert_equal(f.length, 5) + assert_equal(f.TYPE, PriorityFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.exclusive, False) + assert_equal(f.stream_dependency, 0x7654321) + assert_equal(f.weight, 21) + + +def test_priority_frame_human_readable(): + f = PriorityFrame( + length=5, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + exclusive=False, + stream_dependency=0x7654321, + weight=21) + assert f.human_readable() + + +def test_rst_stream_frame_to_bytes(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') + + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_rst_stream_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000403000123456707654321')) + assert isinstance(f, RstStreamFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, RstStreamFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.error_code, 0x07654321) + + +def test_rst_stream_frame_human_readable(): + f = RstStreamFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + error_code=0x7654321) + assert f.human_readable() + + +def test_settings_frame_to_bytes(): + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040000000000') + + f = SettingsFrame( + length=0, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0) + assert_equal(f.to_bytes().encode('hex'), '000000040100000000') + + f = SettingsFrame( + length=6, + flags=SettingsFrame.FLAG_ACK, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) + assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert_equal( + f.to_bytes().encode('hex'), + '00000c040000000000000200000001000312345678') + + f = SettingsFrame( + length=0, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_settings_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000000040000000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(hex_to_file('000000040100000000')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 0) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + + f = Frame.from_file(hex_to_file('000006040100000000000200000001')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) + assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 1) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + + f = Frame.from_file(hex_to_file( + '00000c040000000000000200000001000312345678')) + assert isinstance(f, SettingsFrame) + assert_equal(f.length, 12) + assert_equal(f.TYPE, SettingsFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + 
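# Editor's aside (not part of the patch): the SETTINGS payload in the fixture
# above is a sequence of 6-byte entries, each a 16-bit identifier followed by
# a 32-bit value (RFC 7540, section 6.5.1). This standalone snippet decodes
# the two entries that the assertions below then check against the parsed frame.
import struct
payload = '000200000001000312345678'.decode('hex')
entries = [struct.unpack('!HL', payload[i:i + 6])
           for i in range(0, len(payload), 6)]
# identifier 0x2 (ENABLE_PUSH) -> 1, identifier 0x3 (MAX_CONCURRENT_STREAMS) -> 0x12345678
assert entries == [(0x2, 1), (0x3, 0x12345678)]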
assert_equal(f.stream_id, 0x0) + assert_equal(len(f.settings), 2) + assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) + assert_equal( + f.settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], + 0x12345678) + + +def test_settings_frame_human_readable(): + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={}) + assert f.human_readable() + + f = SettingsFrame( + length=12, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + settings={ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) + assert f.human_readable() + + +def test_push_promise_frame_to_bytes(): + f = PushPromiseFrame( + length=10, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000a05000123456707654321666f6f626172') + + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert_equal( + f.to_bytes().encode('hex'), + '00000e0508012345670307654321666f6f626172000000') + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + promised_stream=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + f = PushPromiseFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + promised_stream=0x0) + tutils.raises(ValueError, f.to_bytes) + + +def test_push_promise_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 10) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + f = Frame.from_file(hex_to_file( + '00000e0508012345670307654321666f6f626172000000')) + assert isinstance(f, PushPromiseFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, PushPromiseFrame.TYPE) + assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_push_promise_frame_human_readable(): + f = PushPromiseFrame( + length=14, + flags=HeadersFrame.FLAG_PADDED, + stream_id=0x1234567, + promised_stream=0x7654321, + header_block_fragment='foobar', + pad_length=3) + assert f.human_readable() + + +def test_ping_frame_to_bytes(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '000008060100000000666f6f6261720000') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + payload=b'foobardeadbeef') + assert_equal( + f.to_bytes().encode('hex'), + '000008060000000000666f6f6261726465') + + f = PingFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567) + tutils.raises(ValueError, f.to_bytes) + + +def test_ping_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + assert_equal(f.flags, PingFrame.FLAG_ACK) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobar\0\0') + + f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) + assert isinstance(f, PingFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, PingFrame.TYPE) + 
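# Editor's aside (not part of the patch): every fixture in these from_bytes
# tests starts with the common 9-octet frame header: a 24-bit length, an
# 8-bit type, an 8-bit flags field, and a reserved bit plus a 31-bit stream
# identifier. This standalone snippet decodes the header of the PING fixture
# directly above.
import struct
header = '000008060000000000'.decode('hex')
length = struct.unpack('!L', '\x00' + header[0:3])[0]          # pad the 24-bit length to 32 bits
frame_type, flags = struct.unpack('!BB', header[3:5])
stream_id = struct.unpack('!L', header[5:9])[0] & 0x7FFFFFFF   # mask off the reserved bit
assert (length, frame_type, flags, stream_id) == (8, 0x6, 0x0, 0x0)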
assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.payload, b'foobarde') + + +def test_ping_frame_human_readable(): + f = PingFrame( + length=8, + flags=PingFrame.FLAG_ACK, + stream_id=0x0, + payload=b'foobar') + assert f.human_readable() + + +def test_goaway_frame_to_bytes(): + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'') + assert_equal( + f.to_bytes().encode('hex'), + '0000080700000000000123456787654321') + + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert_equal( + f.to_bytes().encode('hex'), + '00000e0700000000000123456787654321666f6f626172') + + f = GoAwayFrame( + length=8, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + last_stream=0x1234567, + error_code=0x87654321) + tutils.raises(ValueError, f.to_bytes) + + +def test_goaway_frame_from_bytes(): + f = Frame.from_file(hex_to_file( + '0000080700000000000123456787654321')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 8) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'') + + f = Frame.from_file(hex_to_file( + '00000e0700000000000123456787654321666f6f626172')) + assert isinstance(f, GoAwayFrame) + assert_equal(f.length, 14) + assert_equal(f.TYPE, GoAwayFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.last_stream, 0x1234567) + assert_equal(f.error_code, 0x87654321) + assert_equal(f.data, b'foobar') + + +def test_go_away_frame_human_readable(): + f = GoAwayFrame( + length=14, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + last_stream=0x1234567, + error_code=0x87654321, + data=b'foobar') + assert f.human_readable() + + +def test_window_update_frame_to_bytes(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0x1234567) + assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') + + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x0, + window_size_increment=0xdeadbeef) + tutils.raises(ValueError, f.to_bytes) + + f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) + tutils.raises(ValueError, f.to_bytes) + + +def test_window_update_frame_from_bytes(): + f = Frame.from_file(hex_to_file('00000408000000000001234567')) + assert isinstance(f, WindowUpdateFrame) + assert_equal(f.length, 4) + assert_equal(f.TYPE, WindowUpdateFrame.TYPE) + assert_equal(f.flags, Frame.FLAG_NO_FLAGS) + assert_equal(f.stream_id, 0x0) + assert_equal(f.window_size_increment, 0x1234567) + + +def test_window_update_frame_human_readable(): + f = WindowUpdateFrame( + length=4, + flags=Frame.FLAG_NO_FLAGS, + stream_id=0x1234567, + window_size_increment=0x7654321) + assert f.human_readable() + + +def test_continuation_frame_to_bytes(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') + + f = ContinuationFrame( + 
length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x0, + header_block_fragment='foobar') + tutils.raises(ValueError, f.to_bytes) + + +def test_continuation_frame_from_bytes(): + f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) + assert isinstance(f, ContinuationFrame) + assert_equal(f.length, 6) + assert_equal(f.TYPE, ContinuationFrame.TYPE) + assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) + assert_equal(f.stream_id, 0x1234567) + assert_equal(f.header_block_fragment, 'foobar') + + +def test_continuation_frame_human_readable(): + f = ContinuationFrame( + length=6, + flags=ContinuationFrame.FLAG_END_HEADERS, + stream_id=0x1234567, + header_block_fragment='foobar') + assert f.human_readable() diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py new file mode 100644 index 00000000..f607860e --- /dev/null +++ b/test/http/http2/test_protocol.py @@ -0,0 +1,325 @@ +import OpenSSL + +from netlib import tcp +from netlib.http import http2 +from netlib.http.http2.frame import * +from ... import tutils, tservers + + +class EchoHandler(tcp.BaseHandler): + sni = None + + def handle(self): + while True: + v = self.rfile.safe_read(1) + self.wfile.write(v) + self.wfile.flush() + + +class TestCheckALPNMatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + assert protocol.check_alpn() + + +class TestCheckALPNMismatch(tservers.ServerTestBase): + handler = EchoHandler + ssl = dict( + alpn_select=None, + ) + + if OpenSSL._util.lib.Cryptography_HAS_ALPN: + + def test_check_alpn(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) + protocol = http2.HTTP2Protocol(c) + tutils.raises(NotImplementedError, protocol.check_alpn) + + +class TestPerformServerConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # send magic + self.wfile.write( + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) + self.wfile.flush() + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_server_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_server_connection_preface() + + +class TestPerformClientConnectionPreface(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check magic + assert self.rfile.read(24) ==\ + '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') + + # check empty settings frame + assert self.rfile.read(9) ==\ + '000000040000000000'.decode('hex') + + # send empty settings frame + self.wfile.write('000000040000000000'.decode('hex')) + self.wfile.flush() + + # check settings acknowledgement + assert self.rfile.read(9) == \ + '000000040100000000'.decode('hex') + + # send settings 
acknowledgement + self.wfile.write('000000040100000000'.decode('hex')) + self.wfile.flush() + + def test_perform_client_connection_preface(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + protocol = http2.HTTP2Protocol(c) + protocol.perform_client_connection_preface() + + +class TestClientStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_client_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 1 + assert self.protocol.current_stream_id == 1 + assert self.protocol.next_stream_id() == 3 + assert self.protocol.current_stream_id == 3 + assert self.protocol.next_stream_id() == 5 + assert self.protocol.current_stream_id == 5 + + +class TestServerStreamIds(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c, is_server=True) + + def test_server_stream_ids(self): + assert self.protocol.current_stream_id is None + assert self.protocol.next_stream_id() == 2 + assert self.protocol.current_stream_id == 2 + assert self.protocol.next_stream_id() == 4 + assert self.protocol.current_stream_id == 4 + assert self.protocol.next_stream_id() == 6 + assert self.protocol.current_stream_id == 6 + + +class TestApplySettings(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + # check settings acknowledgement + assert self.rfile.read(9) == '000000040100000000'.decode('hex') + self.wfile.write("OK") + self.wfile.flush() + + ssl = True + + def test_apply_settings(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + protocol._apply_settings({ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', + }) + + assert c.rfile.safe_read(2) == "OK" + + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' + assert protocol.http2_settings[ + SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' + + +class TestCreateHeaders(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_headers(self): + headers = [ + (b':method', b'GET'), + (b':path', b'index.html'), + (b':scheme', b'https'), + (b'foo', b'bar')] + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=True) + assert b''.join(bytes) ==\ + '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + bytes = http2.HTTP2Protocol(self.c)._create_headers( + headers, 1, end_stream=False) + assert b''.join(bytes) ==\ + '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ + .decode('hex') + + # TODO: add test for too large header_block_fragments + + +class TestCreateBody(): + c = tcp.TCPClient(("127.0.0.1", 0)) + protocol = http2.HTTP2Protocol(c) + + def test_create_body_empty(self): + bytes = self.protocol._create_body(b'', 1) + assert b''.join(bytes) == ''.decode('hex') + + def test_create_body_single_frame(self): + bytes = self.protocol._create_body('foobar', 1) + assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') + + def test_create_body_multiple_frames(self): + pass + # bytes = self.protocol._create_body('foobar' * 3000, 1) + # TODO: add test for too large frames + + +class TestCreateRequest(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def 
test_create_request_simple(self): + bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') + assert len(bytes) == 1 + assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') + + def test_create_request_with_body(self): + bytes = http2.HTTP2Protocol(self.c).create_request( + 'GET', '/', [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') + + +class TestReadResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801040000000188628594e78c767f'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'foobar' + + +class TestReadEmptyResponse(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'00000801050000000188628594e78c767f'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_empty_response(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c) + + status, headers, body = protocol.read_response() + + assert headers == {':status': '200', 'etag': 'foobar'} + assert status == "200" + assert body == b'' + + +class TestReadRequest(tservers.ServerTestBase): + class handler(tcp.BaseHandler): + + def handle(self): + self.wfile.write( + b'000003010400000001828487'.decode('hex')) + self.wfile.write( + b'000006000100000001666f6f626172'.decode('hex')) + self.wfile.flush() + + ssl = True + + def test_read_request(self): + c = tcp.TCPClient(("127.0.0.1", self.port)) + c.connect() + c.convert_to_ssl() + protocol = http2.HTTP2Protocol(c, is_server=True) + + stream_id, headers, body = protocol.read_request() + + assert stream_id + assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert body == b'foobar' + + +class TestCreateResponse(): + c = tcp.TCPClient(("127.0.0.1", 0)) + + def test_create_response_simple(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) + assert len(bytes) == 1 + assert bytes[0] ==\ + '00000101050000000288'.decode('hex') + + def test_create_response_with_body(self): + bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( + 200, 1, [(b'foo', b'bar')], 'foobar') + assert len(bytes) == 2 + assert bytes[0] ==\ + '00000901040000000188408294e7838c767f'.decode('hex') + assert bytes[1] ==\ + '000006000100000001666f6f626172'.decode('hex') diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py new file mode 100644 index 00000000..c0dae1a2 --- /dev/null +++ b/test/http/test_authentication.py @@ -0,0 +1,110 @@ +from netlib import odict, http +from netlib.http import authentication +from .. 
import tutils + + +class TestPassManNonAnon: + + def test_simple(self): + p = authentication.PassManNonAnon() + assert not p.test("", "") + assert p.test("user", "") + + +class TestPassManHtpasswd: + + def test_file_errors(self): + tutils.raises( + "malformed htpasswd file", + authentication.PassManHtpasswd, + tutils.test_data.path("data/server.crt")) + + def test_simple(self): + pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) + + vals = ("basic", "test", "test") + http.http1.assemble_http_basic_auth(*vals) + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + assert not pm.test("test", "") + assert not pm.test("", "") + + +class TestPassManSingleUser: + + def test_simple(self): + pm = authentication.PassManSingleUser("test", "test") + assert pm.test("test", "test") + assert not pm.test("test", "foo") + assert not pm.test("foo", "test") + + +class TestNullProxyAuth: + + def test_simple(self): + na = authentication.NullProxyAuth(authentication.PassManNonAnon()) + assert not na.auth_challenge_headers() + assert na.authenticate("foo") + na.clean({}) + + +class TestBasicProxyAuth: + + def test_simple(self): + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") + h = odict.ODictCaseless() + assert ba.auth_challenge_headers() + assert not ba.authenticate(h) + + def test_authenticate_clean(self): + ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test") + + hdrs = odict.ODictCaseless() + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert ba.authenticate(hdrs) + + ba.clean(hdrs) + assert not ba.AUTH_HEADER in hdrs + + hdrs[ba.AUTH_HEADER] = [""] + assert not ba.authenticate(hdrs) + + hdrs[ba.AUTH_HEADER] = ["foo"] + assert not ba.authenticate(hdrs) + + vals = ("foo", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + + ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") + vals = ("basic", "foo", "bar") + hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + assert not ba.authenticate(hdrs) + + +class Bunch: + pass + + +class TestAuthAction: + + def test_nonanonymous(self): + m = Bunch() + aa = authentication.NonanonymousAuthAction(None, "authenticator") + aa(None, m, None, None) + assert m.authenticator + + def test_singleuser(self): + m = Bunch() + aa = authentication.SingleuserAuthAction(None, "authenticator") + aa(None, m, "foo:bar", None) + assert m.authenticator + tutils.raises("invalid", aa, None, m, "foo", None) + + def test_httppasswd(self): + m = Bunch() + aa = authentication.HtpasswdAuthAction(None, "authenticator") + aa(None, m, tutils.test_data.path("data/htpasswd"), None) + assert m.authenticator diff --git a/test/http/test_cookies.py b/test/http/test_cookies.py new file mode 100644 index 00000000..4f99593a --- /dev/null +++ b/test/http/test_cookies.py @@ -0,0 +1,219 @@ +import nose.tools + +from netlib.http import cookies + + +def test_read_token(): + tokens = [ + [("foo", 0), ("foo", 3)], + [("foo", 1), ("oo", 3)], + [(" foo", 1), ("foo", 4)], + [(" foo;", 1), ("foo", 4)], + [(" foo=", 1), ("foo", 4)], + [(" foo=bar", 1), ("foo", 4)], + ] + for q, a in tokens: + nose.tools.eq_(cookies._read_token(*q), a) + + +def test_read_quoted_string(): + tokens = [ + [('"foo" x', 0), ("foo", 5)], + [('"f\oo" x', 0), ("foo", 6)], + [(r'"f\\o" x', 0), (r"f\o", 6)], + [(r'"f\\" x', 0), (r"f" + '\\', 5)], + 
[('"fo\\\"" x', 0), ("fo\"", 6)], + ] + for q, a in tokens: + nose.tools.eq_(cookies._read_quoted_string(*q), a) + + +def test_read_pairs(): + vals = [ + [ + "one", + [["one", None]] + ], + [ + "one=two", + [["one", "two"]] + ], + [ + "one=", + [["one", ""]] + ], + [ + 'one="two"', + [["one", "two"]] + ], + [ + 'one="two"; three=four', + [["one", "two"], ["three", "four"]] + ], + [ + 'one="two"; three=four; five', + [["one", "two"], ["three", "four"], ["five", None]] + ], + [ + 'one="\\"two"; three=four', + [["one", '"two'], ["three", "four"]] + ], + ] + for s, lst in vals: + ret, off = cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + + +def test_pairs_roundtrips(): + pairs = [ + [ + "", + [] + ], + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one", + [["one", None]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="uno"; two="\due"', + [["one", "uno"], ["two", "due"]] + ], + [ + 'one="un\\"o"', + [["one", 'un"o']] + ], + [ + 'one="uno,due"', + [["one", 'uno,due']] + ], + [ + "one=uno; two; three=tre", + [["one", "uno"], ["two", None], ["three", "tre"]] + ], + [ + "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " + "_rcc2=53VdltWl+Ov6ordflA==;", + [ + ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], + ["_rcc2", "53VdltWl+Ov6ordflA=="] + ] + ] + ] + for s, lst in pairs: + ret, off = cookies._read_pairs(s) + nose.tools.eq_(ret, lst) + s2 = cookies._format_pairs(lst) + ret, off = cookies._read_pairs(s2) + nose.tools.eq_(ret, lst) + + +def test_cookie_roundtrips(): + pairs = [ + [ + "one=uno", + [["one", "uno"]] + ], + [ + "one=uno; two=due", + [["one", "uno"], ["two", "due"]] + ], + ] + for s, lst in pairs: + ret = cookies.parse_cookie_header(s) + nose.tools.eq_(ret.lst, lst) + s2 = cookies.format_cookie_header(ret) + ret = cookies.parse_cookie_header(s2) + nose.tools.eq_(ret.lst, lst) + + +def test_parse_set_cookie_pairs(): + pairs = [ + [ + "one=uno", + [ + ["one", "uno"] + ] + ], + [ + "one=un\x20", + [ + ["one", "un\x20"] + ] + ], + [ + "one=uno; foo", + [ + ["one", "uno"], + ["foo", None] + ] + ], + [ + "mun=1.390.f60; " + "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " + "domain=b.aol.com", + [ + ["mun", "1.390.f60"], + ["expires", "sun, 11-oct-2015 12:38:31 gmt"], + ["path", "/"], + ["domain", "b.aol.com"] + ] + ], + [ + r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' + 'domain=.rubiconproject.com; ' + 'expires=mon, 11-may-2015 21:54:57 gmt; ' + 'path=/', + [ + ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], + ['domain', '.rubiconproject.com'], + ['expires', 'mon, 11-may-2015 21:54:57 gmt'], + ['path', '/'] + ] + ], + ] + for s, lst in pairs: + ret = cookies._parse_set_cookie_pairs(s) + nose.tools.eq_(ret, lst) + s2 = cookies._format_set_cookie_pairs(ret) + ret2 = cookies._parse_set_cookie_pairs(s2) + nose.tools.eq_(ret2, lst) + + +def test_parse_set_cookie_header(): + vals = [ + [ + "", None + ], + [ + ";", None + ], + [ + "one=uno", + ("one", "uno", []) + ], + [ + "one=uno; foo=bar", + ("one", "uno", [["foo", "bar"]]) + ] + ] + for s, expected in vals: + ret = cookies.parse_set_cookie_header(s) + if expected: + assert ret[0] == expected[0] + assert ret[1] == expected[1] + nose.tools.eq_(ret[2].lst, expected[2]) + s2 = cookies.format_set_cookie_header(*ret) + ret2 = cookies.parse_set_cookie_header(s2) + assert ret2[0] == expected[0] + assert ret2[1] == expected[1] + nose.tools.eq_(ret2[2].lst, expected[2]) + else: + assert ret is None diff --git a/test/http/test_semantics.py b/test/http/test_semantics.py new file mode 100644 
index 00000000..c4605302 --- /dev/null +++ b/test/http/test_semantics.py @@ -0,0 +1,54 @@ +import cStringIO +import textwrap +import binascii + +from netlib import http, odict, tcp +from netlib.http import http1 +from .. import tutils, tservers + +def test_httperror(): + e = http.exceptions.HttpError(404, "Not found") + assert str(e) + + +def test_parse_url(): + assert not http.parse_url("") + + u = "http://foo.com:8888/test" + s, h, po, pa = http.parse_url(u) + assert s == "http" + assert h == "foo.com" + assert po == 8888 + assert pa == "/test" + + s, h, po, pa = http.parse_url("http://foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://user:pass@foo/bar") + assert s == "http" + assert h == "foo" + assert po == 80 + assert pa == "/bar" + + s, h, po, pa = http.parse_url("http://foo") + assert pa == "/" + + s, h, po, pa = http.parse_url("https://foo") + assert po == 443 + + assert not http.parse_url("https://foo:bar") + assert not http.parse_url("https://foo:") + + # Invalid IDNA + assert not http.parse_url("http://\xfafoo") + # Invalid PATH + assert not http.parse_url("http:/\xc6/localhost:56121") + # Null byte in host + assert not http.parse_url("http://foo\0") + # Port out of range + assert not http.parse_url("http://foo:999999") + # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt + assert not http.parse_url('http://lo[calhost') diff --git a/test/http/test_user_agents.py b/test/http/test_user_agents.py new file mode 100644 index 00000000..0bf1bba7 --- /dev/null +++ b/test/http/test_user_agents.py @@ -0,0 +1,6 @@ +from netlib.http import user_agents + + +def test_get_shortcut(): + assert user_agents.get_by_shortcut("c")[0] == "chrome" + assert not user_agents.get_by_shortcut("_") diff --git a/test/http2/__init__.py b/test/http2/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/test/http2/test_frames.py b/test/http2/test_frames.py deleted file mode 100644 index 76a4b712..00000000 --- a/test/http2/test_frames.py +++ /dev/null @@ -1,704 +0,0 @@ -import cStringIO -from test import tutils -from nose.tools import assert_equal -from netlib import tcp -from netlib.http2.frame import * - - -def hex_to_file(data): - data = data.decode('hex') - return tcp.Reader(cStringIO.StringIO(data)) - - -def test_invalid_flags(): - tutils.raises( - ValueError, - DataFrame, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - payload='foobar') - - -def test_frame_equality(): - a = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - b = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(a, b) - - -def test_too_large_frames(): - f = DataFrame( - length=9000, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar' * 3000) - tutils.raises(FrameSizeError, f.to_bytes) - - -def test_data_frame_to_bytes(): - f = DataFrame( - length=6, - flags=Frame.FLAG_END_STREAM, - stream_id=0x1234567, - payload='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006000101234567666f6f626172') - - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000a00090123456703666f6f626172000000') - - f = DataFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def 
test_data_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006000101234567666f6f626172')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - f = Frame.from_file(hex_to_file('00000a00090123456703666f6f626172000000')) - assert isinstance(f, DataFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, DataFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_END_STREAM | Frame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.payload, 'foobar') - - -def test_data_frame_human_readable(): - f = DataFrame( - length=11, - flags=(Frame.FLAG_END_STREAM | Frame.FLAG_PADDED), - stream_id=0x1234567, - payload='foobar', - pad_length=3) - assert f.human_readable() - - -def test_headers_frame_to_bytes(): - f = HeadersFrame( - length=6, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex')) - assert_equal(f.to_bytes().encode('hex'), '000007010001234567668594e75e31d9') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PADDED), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000b01080123456703668594e75e31d9000000') - - f = HeadersFrame( - length=10, - flags=(HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00000c012001234567876543212a668594e75e31d9') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703876543212a668594e75e31d9000000') - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert_equal( - f.to_bytes().encode('hex'), - '00001001280123456703076543212a668594e75e31d9000000') - - f = HeadersFrame( - length=6, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - header_block_fragment='668594e75e31d9'.decode('hex')) - tutils.raises(ValueError, f.to_bytes) - - -def test_headers_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '000007010001234567668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 7) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(hex_to_file( - '00000b01080123456703668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 11) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - - f = Frame.from_file(hex_to_file( - '00000c012001234567876543212a668594e75e31d9')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, 
HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file( - '00001001280123456703876543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file( - '00001001280123456703076543212a668594e75e31d9000000')) - assert isinstance(f, HeadersFrame) - assert_equal(f.length, 16) - assert_equal(f.TYPE, HeadersFrame.TYPE) - assert_equal(f.flags, HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, '668594e75e31d9'.decode('hex')) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - -def test_headers_frame_human_readable(): - f = HeadersFrame( - length=7, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment=b'', - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - f = HeadersFrame( - length=14, - flags=(HeadersFrame.FLAG_PADDED | HeadersFrame.FLAG_PRIORITY), - stream_id=0x1234567, - header_block_fragment='668594e75e31d9'.decode('hex'), - pad_length=3, - exclusive=False, - stream_dependency=0x7654321, - weight=42) - assert f.human_readable() - - -def test_priority_frame_to_bytes(): - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=True, - stream_dependency=0x7654321, - weight=42) - assert_equal(f.to_bytes().encode('hex'), '000005020001234567876543212a') - - f = PriorityFrame( - length=5, - flags=(Frame.FLAG_NO_FLAGS), - stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert_equal(f.to_bytes().encode('hex'), '0000050200012345670765432115') - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - stream_dependency=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - stream_dependency=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_priority_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000005020001234567876543212a')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, True) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 42) - - f = Frame.from_file(hex_to_file('0000050200012345670765432115')) - assert isinstance(f, PriorityFrame) - assert_equal(f.length, 5) - assert_equal(f.TYPE, PriorityFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.exclusive, False) - assert_equal(f.stream_dependency, 0x7654321) - assert_equal(f.weight, 21) - - -def test_priority_frame_human_readable(): - f = PriorityFrame( - length=5, - flags=Frame.FLAG_NO_FLAGS, - 
stream_id=0x1234567, - exclusive=False, - stream_dependency=0x7654321, - weight=21) - assert f.human_readable() - - -def test_rst_stream_frame_to_bytes(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000403000123456707654321') - - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_rst_stream_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000403000123456707654321')) - assert isinstance(f, RstStreamFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, RstStreamFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.error_code, 0x07654321) - - -def test_rst_stream_frame_human_readable(): - f = RstStreamFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - error_code=0x7654321) - assert f.human_readable() - - -def test_settings_frame_to_bytes(): - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040000000000') - - f = SettingsFrame( - length=0, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0) - assert_equal(f.to_bytes().encode('hex'), '000000040100000000') - - f = SettingsFrame( - length=6, - flags=SettingsFrame.FLAG_ACK, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1}) - assert_equal(f.to_bytes().encode('hex'), '000006040100000000000200000001') - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert_equal( - f.to_bytes().encode('hex'), - '00000c040000000000000200000001000312345678') - - f = SettingsFrame( - length=0, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_settings_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000000040000000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(hex_to_file('000000040100000000')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 0) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - - f = Frame.from_file(hex_to_file('000006040100000000000200000001')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, SettingsFrame.FLAG_ACK, 0x0) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 1) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - - f = Frame.from_file(hex_to_file( - '00000c040000000000000200000001000312345678')) - assert isinstance(f, SettingsFrame) - assert_equal(f.length, 12) - assert_equal(f.TYPE, SettingsFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(len(f.settings), 2) - assert_equal(f.settings[SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH], 1) - assert_equal( - f.settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS], - 0x12345678) - - -def test_settings_frame_human_readable(): - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={}) - assert 
f.human_readable() - - f = SettingsFrame( - length=12, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - settings={ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 1, - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 0x12345678}) - assert f.human_readable() - - -def test_push_promise_frame_to_bytes(): - f = PushPromiseFrame( - length=10, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000a05000123456707654321666f6f626172') - - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert_equal( - f.to_bytes().encode('hex'), - '00000e0508012345670307654321666f6f626172000000') - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - promised_stream=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - f = PushPromiseFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - promised_stream=0x0) - tutils.raises(ValueError, f.to_bytes) - - -def test_push_promise_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000a05000123456707654321666f6f626172')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 10) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - f = Frame.from_file(hex_to_file( - '00000e0508012345670307654321666f6f626172000000')) - assert isinstance(f, PushPromiseFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, PushPromiseFrame.TYPE) - assert_equal(f.flags, PushPromiseFrame.FLAG_PADDED) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_push_promise_frame_human_readable(): - f = PushPromiseFrame( - length=14, - flags=HeadersFrame.FLAG_PADDED, - stream_id=0x1234567, - promised_stream=0x7654321, - header_block_fragment='foobar', - pad_length=3) - assert f.human_readable() - - -def test_ping_frame_to_bytes(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '000008060100000000666f6f6261720000') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - payload=b'foobardeadbeef') - assert_equal( - f.to_bytes().encode('hex'), - '000008060000000000666f6f6261726465') - - f = PingFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567) - tutils.raises(ValueError, f.to_bytes) - - -def test_ping_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000008060100000000666f6f6261720000')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, PingFrame.FLAG_ACK) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobar\0\0') - - f = Frame.from_file(hex_to_file('000008060000000000666f6f6261726465')) - assert isinstance(f, PingFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, PingFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.payload, b'foobarde') - - -def test_ping_frame_human_readable(): - f = PingFrame( - length=8, - flags=PingFrame.FLAG_ACK, - stream_id=0x0, - payload=b'foobar') - assert f.human_readable() - - -def test_goaway_frame_to_bytes(): - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - 
last_stream=0x1234567, - error_code=0x87654321, - data=b'') - assert_equal( - f.to_bytes().encode('hex'), - '0000080700000000000123456787654321') - - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert_equal( - f.to_bytes().encode('hex'), - '00000e0700000000000123456787654321666f6f626172') - - f = GoAwayFrame( - length=8, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - last_stream=0x1234567, - error_code=0x87654321) - tutils.raises(ValueError, f.to_bytes) - - -def test_goaway_frame_from_bytes(): - f = Frame.from_file(hex_to_file( - '0000080700000000000123456787654321')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 8) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'') - - f = Frame.from_file(hex_to_file( - '00000e0700000000000123456787654321666f6f626172')) - assert isinstance(f, GoAwayFrame) - assert_equal(f.length, 14) - assert_equal(f.TYPE, GoAwayFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.last_stream, 0x1234567) - assert_equal(f.error_code, 0x87654321) - assert_equal(f.data, b'foobar') - - -def test_go_away_frame_human_readable(): - f = GoAwayFrame( - length=14, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - last_stream=0x1234567, - error_code=0x87654321, - data=b'foobar') - assert f.human_readable() - - -def test_window_update_frame_to_bytes(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0x1234567) - assert_equal(f.to_bytes().encode('hex'), '00000408000000000001234567') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert_equal(f.to_bytes().encode('hex'), '00000408000123456707654321') - - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x0, - window_size_increment=0xdeadbeef) - tutils.raises(ValueError, f.to_bytes) - - f = WindowUpdateFrame(4, Frame.FLAG_NO_FLAGS, 0x0, window_size_increment=0) - tutils.raises(ValueError, f.to_bytes) - - -def test_window_update_frame_from_bytes(): - f = Frame.from_file(hex_to_file('00000408000000000001234567')) - assert isinstance(f, WindowUpdateFrame) - assert_equal(f.length, 4) - assert_equal(f.TYPE, WindowUpdateFrame.TYPE) - assert_equal(f.flags, Frame.FLAG_NO_FLAGS) - assert_equal(f.stream_id, 0x0) - assert_equal(f.window_size_increment, 0x1234567) - - -def test_window_update_frame_human_readable(): - f = WindowUpdateFrame( - length=4, - flags=Frame.FLAG_NO_FLAGS, - stream_id=0x1234567, - window_size_increment=0x7654321) - assert f.human_readable() - - -def test_continuation_frame_to_bytes(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert_equal(f.to_bytes().encode('hex'), '000006090401234567666f6f626172') - - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x0, - header_block_fragment='foobar') - tutils.raises(ValueError, f.to_bytes) - - -def test_continuation_frame_from_bytes(): - f = Frame.from_file(hex_to_file('000006090401234567666f6f626172')) - assert isinstance(f, ContinuationFrame) - assert_equal(f.length, 6) - assert_equal(f.TYPE, ContinuationFrame.TYPE) - 
assert_equal(f.flags, ContinuationFrame.FLAG_END_HEADERS) - assert_equal(f.stream_id, 0x1234567) - assert_equal(f.header_block_fragment, 'foobar') - - -def test_continuation_frame_human_readable(): - f = ContinuationFrame( - length=6, - flags=ContinuationFrame.FLAG_END_HEADERS, - stream_id=0x1234567, - header_block_fragment='foobar') - assert f.human_readable() diff --git a/test/http2/test_protocol.py b/test/http2/test_protocol.py deleted file mode 100644 index 5e2af34e..00000000 --- a/test/http2/test_protocol.py +++ /dev/null @@ -1,326 +0,0 @@ -import OpenSSL - -from netlib import http2 -from netlib import tcp -from netlib.http2.frame import * -from test import tutils -from .. import tservers - - -class EchoHandler(tcp.BaseHandler): - sni = None - - def handle(self): - while True: - v = self.rfile.safe_read(1) - self.wfile.write(v) - self.wfile.flush() - - -class TestCheckALPNMatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=http2.HTTP2Protocol.ALPN_PROTO_H2, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - assert protocol.check_alpn() - - -class TestCheckALPNMismatch(tservers.ServerTestBase): - handler = EchoHandler - ssl = dict( - alpn_select=None, - ) - - if OpenSSL._util.lib.Cryptography_HAS_ALPN: - - def test_check_alpn(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl(alpn_protos=[http2.HTTP2Protocol.ALPN_PROTO_H2]) - protocol = http2.HTTP2Protocol(c) - tutils.raises(NotImplementedError, protocol.check_alpn) - - -class TestPerformServerConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # send magic - self.wfile.write( - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex')) - self.wfile.flush() - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_server_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_server_connection_preface() - - -class TestPerformClientConnectionPreface(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check magic - assert self.rfile.read(24) ==\ - '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'.decode('hex') - - # check empty settings frame - assert self.rfile.read(9) ==\ - '000000040000000000'.decode('hex') - - # send empty settings frame - self.wfile.write('000000040000000000'.decode('hex')) - self.wfile.flush() - - # check settings acknowledgement - assert self.rfile.read(9) == \ - '000000040100000000'.decode('hex') - - # send settings acknowledgement - self.wfile.write('000000040100000000'.decode('hex')) - self.wfile.flush() - - def test_perform_client_connection_preface(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - protocol = http2.HTTP2Protocol(c) - protocol.perform_client_connection_preface() - - -class TestClientStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = 
http2.HTTP2Protocol(c) - - def test_client_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 1 - assert self.protocol.current_stream_id == 1 - assert self.protocol.next_stream_id() == 3 - assert self.protocol.current_stream_id == 3 - assert self.protocol.next_stream_id() == 5 - assert self.protocol.current_stream_id == 5 - - -class TestServerStreamIds(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c, is_server=True) - - def test_server_stream_ids(self): - assert self.protocol.current_stream_id is None - assert self.protocol.next_stream_id() == 2 - assert self.protocol.current_stream_id == 2 - assert self.protocol.next_stream_id() == 4 - assert self.protocol.current_stream_id == 4 - assert self.protocol.next_stream_id() == 6 - assert self.protocol.current_stream_id == 6 - - -class TestApplySettings(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - # check settings acknowledgement - assert self.rfile.read(9) == '000000040100000000'.decode('hex') - self.wfile.write("OK") - self.wfile.flush() - - ssl = True - - def test_apply_settings(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - protocol._apply_settings({ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH: 'foo', - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS: 'bar', - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE: 'deadbeef', - }) - - assert c.rfile.safe_read(2) == "OK" - - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_ENABLE_PUSH] == 'foo' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_MAX_CONCURRENT_STREAMS] == 'bar' - assert protocol.http2_settings[ - SettingsFrame.SETTINGS.SETTINGS_INITIAL_WINDOW_SIZE] == 'deadbeef' - - -class TestCreateHeaders(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_headers(self): - headers = [ - (b':method', b'GET'), - (b':path', b'index.html'), - (b':scheme', b'https'), - (b'foo', b'bar')] - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=True) - assert b''.join(bytes) ==\ - '000014010500000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - bytes = http2.HTTP2Protocol(self.c)._create_headers( - headers, 1, end_stream=False) - assert b''.join(bytes) ==\ - '000014010400000001824488355217caf3a69a3f87408294e7838c767f'\ - .decode('hex') - - # TODO: add test for too large header_block_fragments - - -class TestCreateBody(): - c = tcp.TCPClient(("127.0.0.1", 0)) - protocol = http2.HTTP2Protocol(c) - - def test_create_body_empty(self): - bytes = self.protocol._create_body(b'', 1) - assert b''.join(bytes) == ''.decode('hex') - - def test_create_body_single_frame(self): - bytes = self.protocol._create_body('foobar', 1) - assert b''.join(bytes) == '000006000100000001666f6f626172'.decode('hex') - - def test_create_body_multiple_frames(self): - pass - # bytes = self.protocol._create_body('foobar' * 3000, 1) - # TODO: add test for too large frames - - -class TestCreateRequest(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_request_simple(self): - bytes = http2.HTTP2Protocol(self.c).create_request('GET', '/') - assert len(bytes) == 1 - assert bytes[0] == '00000d0105000000018284874188089d5c0b8170dc07'.decode('hex') - - def test_create_request_with_body(self): - bytes = http2.HTTP2Protocol(self.c).create_request( - 'GET', '/', [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - 
'0000150104000000018284874188089d5c0b8170dc07408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') - - -class TestReadResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801040000000188628594e78c767f'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'foobar' - - -class TestReadEmptyResponse(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'00000801050000000188628594e78c767f'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_empty_response(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c) - - status, headers, body = protocol.read_response() - - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'' - - -class TestReadRequest(tservers.ServerTestBase): - class handler(tcp.BaseHandler): - - def handle(self): - self.wfile.write( - b'000003010400000001828487'.decode('hex')) - self.wfile.write( - b'000006000100000001666f6f626172'.decode('hex')) - self.wfile.flush() - - ssl = True - - def test_read_request(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - c.convert_to_ssl() - protocol = http2.HTTP2Protocol(c, is_server=True) - - stream_id, headers, body = protocol.read_request() - - assert stream_id - assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert body == b'foobar' - - -class TestCreateResponse(): - c = tcp.TCPClient(("127.0.0.1", 0)) - - def test_create_response_simple(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response(200) - assert len(bytes) == 1 - assert bytes[0] ==\ - '00000101050000000288'.decode('hex') - - def test_create_response_with_body(self): - bytes = http2.HTTP2Protocol(self.c, is_server=True).create_response( - 200, 1, [(b'foo', b'bar')], 'foobar') - assert len(bytes) == 2 - assert bytes[0] ==\ - '00000901040000000188408294e7838c767f'.decode('hex') - assert bytes[1] ==\ - '000006000100000001666f6f626172'.decode('hex') diff --git a/test/test_http.py b/test/test_http.py deleted file mode 100644 index bbc78847..00000000 --- a/test/test_http.py +++ /dev/null @@ -1,491 +0,0 @@ -import cStringIO -import textwrap -import binascii -from netlib import http, http_semantics, odict, tcp -from . 
import tutils, tservers - - -def test_httperror(): - e = http.HttpError(404, "Not found") - assert str(e) - - -def test_has_chunked_encoding(): - h = odict.ODictCaseless() - assert not http.has_chunked_encoding(h) - h["transfer-encoding"] = ["chunked"] - assert http.has_chunked_encoding(h) - - -def test_read_chunked(): - - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("1\r\na\r\n0\r\n") - - tutils.raises( - "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" - - s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert http.read_http_body(s, h, None, "GET", None, True) == "a" - - s = cStringIO.StringIO("\r\n") - tutils.raises( - "closed prematurely", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("1\r\nfoo") - tutils.raises( - "malformed chunked body", - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("foo\r\nfoo") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", None, True - ) - - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True) - - -def test_connection_close(): - h = odict.ODictCaseless() - assert http.connection_close((1, 0), h) - assert not http.connection_close((1, 1), h) - - h["connection"] = ["keep-alive"] - assert not http.connection_close((1, 1), h) - - h["connection"] = ["close"] - assert http.connection_close((1, 1), h) - - -def test_get_header_tokens(): - h = odict.ODictCaseless() - assert http.get_header_tokens(h, "foo") == [] - h["foo"] = ["bar"] - assert http.get_header_tokens(h, "foo") == ["bar"] - h["foo"] = ["bar, voing"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing"] - h["foo"] = ["bar, voing", "oink"] - assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] - - -def test_read_http_body_request(): - h = odict.ODictCaseless() - r = cStringIO.StringIO("testing") - assert http.read_http_body(r, h, None, "GET", None, True) == "" - - -def test_read_http_body_response(): - h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" - - -def test_read_http_body(): - # test default case - h = odict.ODictCaseless() - h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert http.read_http_body(s, h, None, "GET", 200, False) == "testing" - - # test content length: invalid header - h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False - ) - - # test content length: invalid header #2 - h["content-length"] = [-1] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, None, "GET", 200, False - ) - - # test content length: content length > actual content - h["content-length"] = [5] - s = cStringIO.StringIO("testing") - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False - ) - - # test content length: content length < actual content - s = cStringIO.StringIO("testing") - assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5 - - # test no content length: limit > actual content - h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert len(http.read_http_body(s, h, 100, 
"GET", 200, False)) == 7 - - # test no content length: limit < actual content - s = tcp.Reader(cStringIO.StringIO("testing")) - tutils.raises( - http.HttpError, - http.read_http_body, - s, h, 4, "GET", 200, False - ) - - # test chunked - h = odict.ODictCaseless() - h["transfer-encoding"] = ["chunked"] - s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) - assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" - - -def test_expected_http_body_size(): - # gibber in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["foo"] - assert http.expected_http_body_size(h, False, "GET", 200) is None - # negative number in the content-length field - h = odict.ODictCaseless() - h["content-length"] = ["-7"] - assert http.expected_http_body_size(h, False, "GET", 200) is None - # explicit length - h = odict.ODictCaseless() - h["content-length"] = ["5"] - assert http.expected_http_body_size(h, False, "GET", 200) == 5 - # no length - h = odict.ODictCaseless() - assert http.expected_http_body_size(h, False, "GET", 200) == -1 - # no length request - h = odict.ODictCaseless() - assert http.expected_http_body_size(h, True, "GET", None) == 0 - - -def test_parse_http_protocol(): - assert http.parse_http_protocol("HTTP/1.1") == (1, 1) - assert http.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not http.parse_http_protocol("HTTP/a.1") - assert not http.parse_http_protocol("HTTP/1.a") - assert not http.parse_http_protocol("foo/0.0") - assert not http.parse_http_protocol("HTTP/x") - - -def test_parse_init_connect(): - assert http.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not http.parse_init_connect("bogus") - assert not http.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not http.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not http.parse_init_connect("CONNECT host.com:foo HTTP/1.0") - - -def test_parse_init_proxy(): - u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = http.parse_init_proxy(u) - assert m == "GET" - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not http.parse_init_proxy(u) - - assert not http.parse_init_proxy("invalid") - assert not http.parse_init_proxy("GET invalid HTTP/1.1") - assert not http.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") - - -def test_parse_init_http(): - u = "GET /test HTTP/1.1" - m, u, httpversion = http.parse_init_http(u) - assert m == "GET" - assert u == "/test" - assert httpversion == (1, 1) - - u = "G\xfeET /test HTTP/1.1" - assert not http.parse_init_http(u) - - assert not http.parse_init_http("invalid") - assert not http.parse_init_http("GET invalid HTTP/1.1") - assert not http.parse_init_http("GET /test foo/1.1") - assert not http.parse_init_http("GET /test\xc0 HTTP/1.1") - - -class TestReadHeaders: - - def _read(self, data, verbatim=False): - if not verbatim: - data = textwrap.dedent(data) - data = data.strip() - s = cStringIO.StringIO(data) - return http.read_headers(s) - - def test_read_simple(self): - data = """ - Header: one - Header2: two - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header2", 
"two"]] - - def test_read_multi(self): - data = """ - Header: one - Header: two - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one"], ["Header", "two"]] - - def test_read_continued(self): - data = """ - Header: one - \ttwo - Header2: three - \r\n - """ - h = self._read(data) - assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] - - def test_read_continued_err(self): - data = "\tfoo: bar\r\n" - assert self._read(data, True) is None - - def test_read_err(self): - data = """ - foo - """ - assert self._read(data) is None - - -class NoContentLengthHTTPHandler(tcp.BaseHandler): - - def handle(self): - self.wfile.write("HTTP/1.1 200 OK\r\n\r\nbar\r\n\r\n") - self.wfile.flush() - - -class TestReadResponseNoContentLength(tservers.ServerTestBase): - handler = NoContentLengthHTTPHandler - - def test_no_content_length(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - c.connect() - resp = http.read_response(c.rfile, "GET", None) - assert resp.content == "bar\r\n\r\n" - - -def test_read_response(): - def tst(data, method, limit, include_body=True): - data = textwrap.dedent(data) - r = cStringIO.StringIO(data) - return http.read_response( - r, method, limit, include_body=include_body - ) - - tutils.raises("server disconnect", tst, "", "GET", None) - tutils.raises("invalid server response", tst, "foo", "GET", None) - data = """ - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 200, 'OK', odict.ODictCaseless(), '' - ) - data = """ - HTTP/1.1 200 - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 200, '', odict.ODictCaseless(), '' - ) - data = """ - HTTP/x 200 OK - """ - tutils.raises("invalid http version", tst, data, "GET", None) - data = """ - HTTP/1.1 xx OK - """ - tutils.raises("invalid server response", tst, data, "GET", None) - - data = """ - HTTP/1.1 100 CONTINUE - - HTTP/1.1 200 OK - """ - assert tst(data, "GET", None) == http_semantics.Response( - (1, 1), 100, 'CONTINUE', odict.ODictCaseless(), '' - ) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None).content == 'foo' - assert tst(data, "HEAD", None).content == '' - - data = """ - HTTP/1.1 200 OK - \tContent-Length: 3 - - foo - """ - tutils.raises("invalid headers", tst, data, "GET", None) - - data = """ - HTTP/1.1 200 OK - Content-Length: 3 - - foo - """ - assert tst(data, "GET", None, include_body=False).content is None - - -def test_parse_url(): - assert not http.parse_url("") - - u = "http://foo.com:8888/test" - s, h, po, pa = http.parse_url(u) - assert s == "http" - assert h == "foo.com" - assert po == 8888 - assert pa == "/test" - - s, h, po, pa = http.parse_url("http://foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://user:pass@foo/bar") - assert s == "http" - assert h == "foo" - assert po == 80 - assert pa == "/bar" - - s, h, po, pa = http.parse_url("http://foo") - assert pa == "/" - - s, h, po, pa = http.parse_url("https://foo") - assert po == 443 - - assert not http.parse_url("https://foo:bar") - assert not http.parse_url("https://foo:") - - # Invalid IDNA - assert not http.parse_url("http://\xfafoo") - # Invalid PATH - assert not http.parse_url("http:/\xc6/localhost:56121") - # Null byte in host - assert not http.parse_url("http://foo\0") - # Port out of range - assert not http.parse_url("http://foo:999999") - # Invalid IPv6 URL - see http://www.ietf.org/rfc/rfc2732.txt - assert not 
http.parse_url('http://lo[calhost') - - -def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") - assert http.parse_http_basic_auth( - http.assemble_http_basic_auth(*vals) - ) == vals - assert not http.parse_http_basic_auth("") - assert not http.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") - assert not http.parse_http_basic_auth(v) - - -def test_get_request_line(): - r = cStringIO.StringIO("\nfoo") - assert http.get_request_line(r) == "foo" - assert not http.get_request_line(r) - - -class TestReadRequest(): - - def tst(self, data, **kwargs): - r = cStringIO.StringIO(data) - return http.read_request(r, **kwargs) - - def test_invalid(self): - tutils.raises( - "bad http request", - self.tst, - "xxx" - ) - tutils.raises( - "bad http request line", - self.tst, - "get /\xff HTTP/1.1" - ) - tutils.raises( - "invalid headers", - self.tst, - "get / HTTP/1.1\r\nfoo" - ) - tutils.raises( - tcp.NetLibDisconnect, - self.tst, - "\r\n" - ) - - def test_asterisk_form_in(self): - v = self.tst("OPTIONS * HTTP/1.1") - assert v.form_in == "relative" - assert v.method == "OPTIONS" - - def test_absolute_form_in(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "GET oops-no-protocol.com HTTP/1.1" - ) - v = self.tst("GET http://address:22/ HTTP/1.1") - assert v.form_in == "absolute" - assert v.port == 22 - assert v.host == "address" - assert v.scheme == "http" - - def test_connect(self): - tutils.raises( - "Bad HTTP request line", - self.tst, - "CONNECT oops-no-port.com HTTP/1.1" - ) - v = self.tst("CONNECT foo.com:443 HTTP/1.1") - assert v.form_in == "authority" - assert v.method == "CONNECT" - assert v.port == 443 - assert v.host == "foo.com" - - def test_expect(self): - w = cStringIO.StringIO() - r = cStringIO.StringIO( - "GET / HTTP/1.1\r\n" - "Content-Length: 3\r\n" - "Expect: 100-continue\r\n\r\n" - "foobar", - ) - v = http.read_request(r, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.content == "foo" - assert r.read(3) == "bar" diff --git a/test/test_http_auth.py b/test/test_http_auth.py deleted file mode 100644 index c842925b..00000000 --- a/test/test_http_auth.py +++ /dev/null @@ -1,109 +0,0 @@ -from netlib import odict, http_auth, http -import tutils - - -class TestPassManNonAnon: - - def test_simple(self): - p = http_auth.PassManNonAnon() - assert not p.test("", "") - assert p.test("user", "") - - -class TestPassManHtpasswd: - - def test_file_errors(self): - tutils.raises( - "malformed htpasswd file", - http_auth.PassManHtpasswd, - tutils.test_data.path("data/server.crt")) - - def test_simple(self): - pm = http_auth.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) - - vals = ("basic", "test", "test") - http.assemble_http_basic_auth(*vals) - assert pm.test("test", "test") - assert not pm.test("test", "foo") - assert not pm.test("foo", "test") - assert not pm.test("test", "") - assert not pm.test("", "") - - -class TestPassManSingleUser: - - def test_simple(self): - pm = http_auth.PassManSingleUser("test", "test") - assert pm.test("test", "test") - assert not pm.test("test", "foo") - assert not pm.test("foo", "test") - - -class TestNullProxyAuth: - - def test_simple(self): - na = http_auth.NullProxyAuth(http_auth.PassManNonAnon()) - assert not na.auth_challenge_headers() - assert na.authenticate("foo") - na.clean({}) - - -class TestBasicProxyAuth: - - def test_simple(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") - h = odict.ODictCaseless() - assert ba.auth_challenge_headers() - 
assert not ba.authenticate(h) - - def test_authenticate_clean(self): - ba = http_auth.BasicProxyAuth(http_auth.PassManNonAnon(), "test") - - hdrs = odict.ODictCaseless() - vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert ba.authenticate(hdrs) - - ba.clean(hdrs) - assert not ba.AUTH_HEADER in hdrs - - hdrs[ba.AUTH_HEADER] = [""] - assert not ba.authenticate(hdrs) - - hdrs[ba.AUTH_HEADER] = ["foo"] - assert not ba.authenticate(hdrs) - - vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) - - ba = http_auth.BasicProxyAuth(http_auth.PassMan(), "test") - vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.assemble_http_basic_auth(*vals)] - assert not ba.authenticate(hdrs) - - -class Bunch: - pass - - -class TestAuthAction: - - def test_nonanonymous(self): - m = Bunch() - aa = http_auth.NonanonymousAuthAction(None, "authenticator") - aa(None, m, None, None) - assert m.authenticator - - def test_singleuser(self): - m = Bunch() - aa = http_auth.SingleuserAuthAction(None, "authenticator") - aa(None, m, "foo:bar", None) - assert m.authenticator - tutils.raises("invalid", aa, None, m, "foo", None) - - def test_httppasswd(self): - m = Bunch() - aa = http_auth.HtpasswdAuthAction(None, "authenticator") - aa(None, m, tutils.test_data.path("data/htpasswd"), None) - assert m.authenticator diff --git a/test/test_http_cookies.py b/test/test_http_cookies.py deleted file mode 100644 index 070849cf..00000000 --- a/test/test_http_cookies.py +++ /dev/null @@ -1,219 +0,0 @@ -import nose.tools - -from netlib import http_cookies - - -def test_read_token(): - tokens = [ - [("foo", 0), ("foo", 3)], - [("foo", 1), ("oo", 3)], - [(" foo", 1), ("foo", 4)], - [(" foo;", 1), ("foo", 4)], - [(" foo=", 1), ("foo", 4)], - [(" foo=bar", 1), ("foo", 4)], - ] - for q, a in tokens: - nose.tools.eq_(http_cookies._read_token(*q), a) - - -def test_read_quoted_string(): - tokens = [ - [('"foo" x', 0), ("foo", 5)], - [('"f\oo" x', 0), ("foo", 6)], - [(r'"f\\o" x', 0), (r"f\o", 6)], - [(r'"f\\" x', 0), (r"f" + '\\', 5)], - [('"fo\\\"" x', 0), ("fo\"", 6)], - ] - for q, a in tokens: - nose.tools.eq_(http_cookies._read_quoted_string(*q), a) - - -def test_read_pairs(): - vals = [ - [ - "one", - [["one", None]] - ], - [ - "one=two", - [["one", "two"]] - ], - [ - "one=", - [["one", ""]] - ], - [ - 'one="two"', - [["one", "two"]] - ], - [ - 'one="two"; three=four', - [["one", "two"], ["three", "four"]] - ], - [ - 'one="two"; three=four; five', - [["one", "two"], ["three", "four"], ["five", None]] - ], - [ - 'one="\\"two"; three=four', - [["one", '"two'], ["three", "four"]] - ], - ] - for s, lst in vals: - ret, off = http_cookies._read_pairs(s) - nose.tools.eq_(ret, lst) - - -def test_pairs_roundtrips(): - pairs = [ - [ - "", - [] - ], - [ - "one=uno", - [["one", "uno"]] - ], - [ - "one", - [["one", None]] - ], - [ - "one=uno; two=due", - [["one", "uno"], ["two", "due"]] - ], - [ - 'one="uno"; two="\due"', - [["one", "uno"], ["two", "due"]] - ], - [ - 'one="un\\"o"', - [["one", 'un"o']] - ], - [ - 'one="uno,due"', - [["one", 'uno,due']] - ], - [ - "one=uno; two; three=tre", - [["one", "uno"], ["two", None], ["three", "tre"]] - ], - [ - "_lvs2=zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g=; " - "_rcc2=53VdltWl+Ov6ordflA==;", - [ - ["_lvs2", "zHai1+Hq+Tc2vmc2r4GAbdOI5Jopg3EwsdUT9g="], - ["_rcc2", "53VdltWl+Ov6ordflA=="] - ] - ] - ] - for s, lst in pairs: - ret, off = http_cookies._read_pairs(s) - 
nose.tools.eq_(ret, lst) - s2 = http_cookies._format_pairs(lst) - ret, off = http_cookies._read_pairs(s2) - nose.tools.eq_(ret, lst) - - -def test_cookie_roundtrips(): - pairs = [ - [ - "one=uno", - [["one", "uno"]] - ], - [ - "one=uno; two=due", - [["one", "uno"], ["two", "due"]] - ], - ] - for s, lst in pairs: - ret = http_cookies.parse_cookie_header(s) - nose.tools.eq_(ret.lst, lst) - s2 = http_cookies.format_cookie_header(ret) - ret = http_cookies.parse_cookie_header(s2) - nose.tools.eq_(ret.lst, lst) - - -def test_parse_set_cookie_pairs(): - pairs = [ - [ - "one=uno", - [ - ["one", "uno"] - ] - ], - [ - "one=un\x20", - [ - ["one", "un\x20"] - ] - ], - [ - "one=uno; foo", - [ - ["one", "uno"], - ["foo", None] - ] - ], - [ - "mun=1.390.f60; " - "expires=sun, 11-oct-2015 12:38:31 gmt; path=/; " - "domain=b.aol.com", - [ - ["mun", "1.390.f60"], - ["expires", "sun, 11-oct-2015 12:38:31 gmt"], - ["path", "/"], - ["domain", "b.aol.com"] - ] - ], - [ - r'rpb=190%3d1%2616726%3d1%2634832%3d1%2634874%3d1; ' - 'domain=.rubiconproject.com; ' - 'expires=mon, 11-may-2015 21:54:57 gmt; ' - 'path=/', - [ - ['rpb', r'190%3d1%2616726%3d1%2634832%3d1%2634874%3d1'], - ['domain', '.rubiconproject.com'], - ['expires', 'mon, 11-may-2015 21:54:57 gmt'], - ['path', '/'] - ] - ], - ] - for s, lst in pairs: - ret = http_cookies._parse_set_cookie_pairs(s) - nose.tools.eq_(ret, lst) - s2 = http_cookies._format_set_cookie_pairs(ret) - ret2 = http_cookies._parse_set_cookie_pairs(s2) - nose.tools.eq_(ret2, lst) - - -def test_parse_set_cookie_header(): - vals = [ - [ - "", None - ], - [ - ";", None - ], - [ - "one=uno", - ("one", "uno", []) - ], - [ - "one=uno; foo=bar", - ("one", "uno", [["foo", "bar"]]) - ] - ] - for s, expected in vals: - ret = http_cookies.parse_set_cookie_header(s) - if expected: - assert ret[0] == expected[0] - assert ret[1] == expected[1] - nose.tools.eq_(ret[2].lst, expected[2]) - s2 = http_cookies.format_set_cookie_header(*ret) - ret2 = http_cookies.parse_set_cookie_header(s2) - assert ret2[0] == expected[0] - assert ret2[1] == expected[1] - nose.tools.eq_(ret2[2].lst, expected[2]) - else: - assert ret is None diff --git a/test/test_http_uastrings.py b/test/test_http_uastrings.py deleted file mode 100644 index 3fa4f359..00000000 --- a/test/test_http_uastrings.py +++ /dev/null @@ -1,6 +0,0 @@ -from netlib import http_uastrings - - -def test_get_shortcut(): - assert http_uastrings.get_by_shortcut("c")[0] == "chrome" - assert not http_uastrings.get_by_shortcut("_") diff --git a/test/test_websockets.py b/test/test_websockets.py deleted file mode 100644 index ae0a5e33..00000000 --- a/test/test_websockets.py +++ /dev/null @@ -1,261 +0,0 @@ -import os - -from nose.tools import raises - -from netlib import tcp, websockets, http -from . 
import tutils, tservers - - -class WebSocketsEchoHandler(tcp.BaseHandler): - - def __init__(self, connection, address, server): - super(WebSocketsEchoHandler, self).__init__( - connection, address, server - ) - self.protocol = websockets.WebsocketsProtocol() - self.handshake_done = False - - def handle(self): - while True: - if not self.handshake_done: - self.handshake() - else: - self.read_next_message() - - def read_next_message(self): - frame = websockets.Frame.from_file(self.rfile) - self.on_message(frame.payload) - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=False) - frame.to_file(self.wfile) - - def handshake(self): - req = http.read_request(self.rfile) - key = self.protocol.check_client_handshake(req.headers) - - self.wfile.write(http.response_preamble(101) + "\r\n") - headers = self.protocol.server_handshake_headers(key) - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - self.handshake_done = True - - def on_message(self, message): - if message is not None: - self.send_message(message) - - -class WebSocketsClient(tcp.TCPClient): - - def __init__(self, address, source_address=None): - super(WebSocketsClient, self).__init__(address, source_address) - self.protocol = websockets.WebsocketsProtocol() - self.client_nonce = None - - def connect(self): - super(WebSocketsClient, self).connect() - - preamble = http.request_preamble("GET", "/") - self.wfile.write(preamble + "\r\n") - headers = self.protocol.client_handshake_headers() - self.client_nonce = headers.get_first("sec-websocket-key") - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - - resp = http.read_response(self.rfile, "get", None) - server_nonce = self.protocol.check_server_handshake(resp.headers) - - if not server_nonce == self.protocol.create_server_nonce( - self.client_nonce): - self.close() - - def read_next_message(self): - return websockets.Frame.from_file(self.rfile).payload - - def send_message(self, message): - frame = websockets.Frame.default(message, from_client=True) - frame.to_file(self.wfile) - - -class TestWebSockets(tservers.ServerTestBase): - handler = WebSocketsEchoHandler - - def __init__(self): - self.protocol = websockets.WebsocketsProtocol() - - def random_bytes(self, n=100): - return os.urandom(n) - - def echo(self, msg): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message(msg) - response = client.read_next_message() - assert response == msg - - def test_simple_echo(self): - self.echo("hello I'm the client") - - def test_frame_sizes(self): - # length can fit in the the 7 bit payload length - small_msg = self.random_bytes(100) - # 50kb, sligthly larger than can fit in a 7 bit int - medium_msg = self.random_bytes(50000) - # 150kb, slightly larger than can fit in a 16 bit int - large_msg = self.random_bytes(150000) - - self.echo(small_msg) - self.echo(medium_msg) - self.echo(large_msg) - - def test_default_builder(self): - """ - default builder should always generate valid frames - """ - msg = self.random_bytes() - client_frame = websockets.Frame.default(msg, from_client=True) - server_frame = websockets.Frame.default(msg, from_client=False) - - def test_serialization_bijection(self): - """ - Ensure that various frame types can be serialized/deserialized back - and forth between to_bytes() and from_bytes() - """ - for is_client in [True, False]: - for num_bytes in [100, 50000, 150000]: - frame = websockets.Frame.default( - self.random_bytes(num_bytes), is_client - ) - frame2 = 
websockets.Frame.from_bytes( - frame.to_bytes() - ) - assert frame == frame2 - - bytes = b'\x81\x03cba' - assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes - - def test_check_server_handshake(self): - headers = self.protocol.server_handshake_headers("key") - assert self.protocol.check_server_handshake(headers) - headers["Upgrade"] = ["not_websocket"] - assert not self.protocol.check_server_handshake(headers) - - def test_check_client_handshake(self): - headers = self.protocol.client_handshake_headers("key") - assert self.protocol.check_client_handshake(headers) == "key" - headers["Upgrade"] = ["not_websocket"] - assert not self.protocol.check_client_handshake(headers) - - -class BadHandshakeHandler(WebSocketsEchoHandler): - - def handshake(self): - client_hs = http.read_request(self.rfile) - self.protocol.check_client_handshake(client_hs.headers) - - self.wfile.write(http.response_preamble(101) + "\r\n") - headers = self.protocol.server_handshake_headers("malformed key") - self.wfile.write(headers.format() + "\r\n") - self.wfile.flush() - self.handshake_done = True - - -class TestBadHandshake(tservers.ServerTestBase): - - """ - Ensure that the client disconnects if the server handshake is malformed - """ - handler = BadHandshakeHandler - - @raises(tcp.NetLibDisconnect) - def test(self): - client = WebSocketsClient(("127.0.0.1", self.port)) - client.connect() - client.send_message("hello") - - -class TestFrameHeader: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.FrameHeader(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) - assert f == f2 - round() - round(fin=1) - round(rsv1=1) - round(rsv2=1) - round(rsv3=1) - round(payload_length=1) - round(payload_length=100) - round(payload_length=1000) - round(payload_length=10000) - round(opcode=websockets.OPCODE.PING) - round(masking_key="test") - - def test_human_readable(self): - f = websockets.FrameHeader( - masking_key="test", - fin=True, - payload_length=10 - ) - assert f.human_readable() - f = websockets.FrameHeader() - assert f.human_readable() - - def test_funky(self): - f = websockets.FrameHeader(masking_key="test", mask=False) - bytes = f.to_bytes() - f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) - assert not f2.mask - - def test_violations(self): - tutils.raises("opcode", websockets.FrameHeader, opcode=17) - tutils.raises("masking key", websockets.FrameHeader, masking_key="x") - - def test_automask(self): - f = websockets.FrameHeader(mask=True) - assert f.masking_key - - f = websockets.FrameHeader(masking_key="foob") - assert f.mask - - f = websockets.FrameHeader(masking_key="foob", mask=0) - assert not f.mask - assert f.masking_key - - -class TestFrame: - - def test_roundtrip(self): - def round(*args, **kwargs): - f = websockets.Frame(*args, **kwargs) - bytes = f.to_bytes() - f2 = websockets.Frame.from_file(tutils.treader(bytes)) - assert f == f2 - round("test") - round("test", fin=1) - round("test", rsv1=1) - round("test", opcode=websockets.OPCODE.PING) - round("test", masking_key="test") - - def test_human_readable(self): - f = websockets.Frame() - assert f.human_readable() - - -def test_masker(): - tests = [ - ["a"], - ["four"], - ["fourf"], - ["fourfive"], - ["a", "aasdfasdfa", "asdf"], - ["a" * 50, "aasdfasdfa", "asdf"], - ] - for i in tests: - m = websockets.Masker("abcd") - data = "".join([m(t) for t in i]) - data2 = websockets.Masker("abcd")(data) - assert data2 == "".join(i) diff --git a/test/websockets/__init__.py 
b/test/websockets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py new file mode 100644 index 00000000..07ad0452 --- /dev/null +++ b/test/websockets/test_websockets.py @@ -0,0 +1,262 @@ +import os + +from nose.tools import raises + +from netlib import tcp, http, websockets +from netlib.http.exceptions import * +from .. import tutils, tservers + + +class WebSocketsEchoHandler(tcp.BaseHandler): + + def __init__(self, connection, address, server): + super(WebSocketsEchoHandler, self).__init__( + connection, address, server + ) + self.protocol = websockets.WebsocketsProtocol() + self.handshake_done = False + + def handle(self): + while True: + if not self.handshake_done: + self.handshake() + else: + self.read_next_message() + + def read_next_message(self): + frame = websockets.Frame.from_file(self.rfile) + self.on_message(frame.payload) + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client=False) + frame.to_file(self.wfile) + + def handshake(self): + req = http.http1.read_request(self.rfile) + key = self.protocol.check_client_handshake(req.headers) + + self.wfile.write(http.http1.response_preamble(101) + "\r\n") + headers = self.protocol.server_handshake_headers(key) + self.wfile.write(headers.format() + "\r\n") + self.wfile.flush() + self.handshake_done = True + + def on_message(self, message): + if message is not None: + self.send_message(message) + + +class WebSocketsClient(tcp.TCPClient): + + def __init__(self, address, source_address=None): + super(WebSocketsClient, self).__init__(address, source_address) + self.protocol = websockets.WebsocketsProtocol() + self.client_nonce = None + + def connect(self): + super(WebSocketsClient, self).connect() + + preamble = http.http1.protocol.request_preamble("GET", "/") + self.wfile.write(preamble + "\r\n") + headers = self.protocol.client_handshake_headers() + self.client_nonce = headers.get_first("sec-websocket-key") + self.wfile.write(headers.format() + "\r\n") + self.wfile.flush() + + resp = http.http1.protocol.read_response(self.rfile, "get", None) + server_nonce = self.protocol.check_server_handshake(resp.headers) + + if not server_nonce == self.protocol.create_server_nonce( + self.client_nonce): + self.close() + + def read_next_message(self): + return websockets.Frame.from_file(self.rfile).payload + + def send_message(self, message): + frame = websockets.Frame.default(message, from_client=True) + frame.to_file(self.wfile) + + +class TestWebSockets(tservers.ServerTestBase): + handler = WebSocketsEchoHandler + + def __init__(self): + self.protocol = websockets.WebsocketsProtocol() + + def random_bytes(self, n=100): + return os.urandom(n) + + def echo(self, msg): + client = WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message(msg) + response = client.read_next_message() + assert response == msg + + def test_simple_echo(self): + self.echo("hello I'm the client") + + def test_frame_sizes(self): + # length can fit in the the 7 bit payload length + small_msg = self.random_bytes(100) + # 50kb, sligthly larger than can fit in a 7 bit int + medium_msg = self.random_bytes(50000) + # 150kb, slightly larger than can fit in a 16 bit int + large_msg = self.random_bytes(150000) + + self.echo(small_msg) + self.echo(medium_msg) + self.echo(large_msg) + + def test_default_builder(self): + """ + default builder should always generate valid frames + """ + msg = self.random_bytes() + client_frame = 
websockets.Frame.default(msg, from_client=True) + server_frame = websockets.Frame.default(msg, from_client=False) + + def test_serialization_bijection(self): + """ + Ensure that various frame types can be serialized/deserialized back + and forth between to_bytes() and from_bytes() + """ + for is_client in [True, False]: + for num_bytes in [100, 50000, 150000]: + frame = websockets.Frame.default( + self.random_bytes(num_bytes), is_client + ) + frame2 = websockets.Frame.from_bytes( + frame.to_bytes() + ) + assert frame == frame2 + + bytes = b'\x81\x03cba' + assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes + + def test_check_server_handshake(self): + headers = self.protocol.server_handshake_headers("key") + assert self.protocol.check_server_handshake(headers) + headers["Upgrade"] = ["not_websocket"] + assert not self.protocol.check_server_handshake(headers) + + def test_check_client_handshake(self): + headers = self.protocol.client_handshake_headers("key") + assert self.protocol.check_client_handshake(headers) == "key" + headers["Upgrade"] = ["not_websocket"] + assert not self.protocol.check_client_handshake(headers) + + +class BadHandshakeHandler(WebSocketsEchoHandler): + + def handshake(self): + client_hs = http.http1.protocol.read_request(self.rfile) + self.protocol.check_client_handshake(client_hs.headers) + + self.wfile.write(http.http1.protocol.response_preamble(101) + "\r\n") + headers = self.protocol.server_handshake_headers("malformed key") + self.wfile.write(headers.format() + "\r\n") + self.wfile.flush() + self.handshake_done = True + + +class TestBadHandshake(tservers.ServerTestBase): + + """ + Ensure that the client disconnects if the server handshake is malformed + """ + handler = BadHandshakeHandler + + @raises(tcp.NetLibDisconnect) + def test(self): + client = WebSocketsClient(("127.0.0.1", self.port)) + client.connect() + client.send_message("hello") + + +class TestFrameHeader: + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.FrameHeader(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + assert f == f2 + round() + round(fin=1) + round(rsv1=1) + round(rsv2=1) + round(rsv3=1) + round(payload_length=1) + round(payload_length=100) + round(payload_length=1000) + round(payload_length=10000) + round(opcode=websockets.OPCODE.PING) + round(masking_key="test") + + def test_human_readable(self): + f = websockets.FrameHeader( + masking_key="test", + fin=True, + payload_length=10 + ) + assert f.human_readable() + f = websockets.FrameHeader() + assert f.human_readable() + + def test_funky(self): + f = websockets.FrameHeader(masking_key="test", mask=False) + bytes = f.to_bytes() + f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) + assert not f2.mask + + def test_violations(self): + tutils.raises("opcode", websockets.FrameHeader, opcode=17) + tutils.raises("masking key", websockets.FrameHeader, masking_key="x") + + def test_automask(self): + f = websockets.FrameHeader(mask=True) + assert f.masking_key + + f = websockets.FrameHeader(masking_key="foob") + assert f.mask + + f = websockets.FrameHeader(masking_key="foob", mask=0) + assert not f.mask + assert f.masking_key + + +class TestFrame: + + def test_roundtrip(self): + def round(*args, **kwargs): + f = websockets.Frame(*args, **kwargs) + bytes = f.to_bytes() + f2 = websockets.Frame.from_file(tutils.treader(bytes)) + assert f == f2 + round("test") + round("test", fin=1) + round("test", rsv1=1) + round("test", 
opcode=websockets.OPCODE.PING) + round("test", masking_key="test") + + def test_human_readable(self): + f = websockets.Frame() + assert f.human_readable() + + +def test_masker(): + tests = [ + ["a"], + ["four"], + ["fourf"], + ["fourfive"], + ["a", "aasdfasdfa", "asdf"], + ["a" * 50, "aasdfasdfa", "asdf"], + ] + for i in tests: + m = websockets.Masker("abcd") + data = "".join([m(t) for t in i]) + data2 = websockets.Masker("abcd")(data) + assert data2 == "".join(i) -- cgit v1.2.3 From bab6cbff1e5444aea72a188d57812130c375e0f0 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 15 Jul 2015 22:32:14 +0200 Subject: extract authentication methods from protocol --- netlib/http/authentication.py | 22 +++++++++++++++++++++- netlib/http/http1/protocol.py | 39 ++------------------------------------- netlib/http/semantics.py | 14 +++++++++++++- test/http/http1/test_protocol.py | 19 ++++--------------- test/http/test_authentication.py | 21 +++++++++++++++++---- 5 files changed, 57 insertions(+), 58 deletions(-) diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 26e3c2c4..9a227010 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -1,8 +1,28 @@ from __future__ import (absolute_import, print_function, division) from argparse import Action, ArgumentTypeError +import binascii from .. import http +def parse_http_basic_auth(s): + words = s.split() + if len(words) != 2: + return None + scheme = words[0] + try: + user = binascii.a2b_base64(words[1]) + except binascii.Error: + return None + parts = user.split(':') + if len(parts) != 2: + return None + return scheme, parts[0], parts[1] + + +def assemble_http_basic_auth(scheme, username, password): + v = binascii.b2a_base64(username + ":" + password) + return scheme + " " + v + class NullProxyAuth(object): @@ -47,7 +67,7 @@ class BasicProxyAuth(NullProxyAuth): auth_value = headers.get(self.AUTH_HEADER, []) if not auth_value: return False - parts = http.http1.parse_http_basic_auth(auth_value[0]) + parts = parse_http_basic_auth(auth_value[0]) if not parts: return False scheme, username, password = parts diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 0f7a0bd3..97c119a9 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -85,22 +85,9 @@ def read_chunked(fp, limit, is_request): return -def get_header_tokens(headers, key): - """ - Retrieve all tokens for a header key. A number of different headers - follow a pattern where each header line can containe comma-separated - tokens, and headers can be set multiple times. 
- """ - toks = [] - for i in headers[key]: - for j in i.split(","): - toks.append(j.strip()) - return toks - - def has_chunked_encoding(headers): return "chunked" in [ - i.lower() for i in get_header_tokens(headers, "transfer-encoding") + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") ] @@ -123,28 +110,6 @@ def parse_http_protocol(s): return major, minor -def parse_http_basic_auth(s): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - words = s.split() - if len(words) != 2: - return None - scheme = words[0] - try: - user = binascii.a2b_base64(words[1]) - except binascii.Error: - return None - parts = user.split(':') - if len(parts) != 2: - return None - return scheme, parts[0], parts[1] - - -def assemble_http_basic_auth(scheme, username, password): - # TODO: check if this is HTTP/1 only - otherwise move it to netlib.http.semantics - v = binascii.b2a_base64(username + ":" + password) - return scheme + " " + v - - def parse_init(line): try: method, url, protocol = string.split(line) @@ -221,7 +186,7 @@ def connection_close(httpversion, headers): """ # At first, check if we have an explicit Connection header. if "connection" in headers: - toks = get_header_tokens(headers, "connection") + toks = http.get_header_tokens(headers, "connection") if "close" in toks: return True elif "keep-alive" in toks: diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index e7e84fe3..a62c93e3 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -49,7 +49,6 @@ def is_valid_host(host): return True - def parse_url(url): """ Returns a (scheme, host, port, path) tuple, or None on error. @@ -92,3 +91,16 @@ def parse_url(url): if not is_valid_port(port): return None return scheme, host, port, path + + +def get_header_tokens(headers, key): + """ + Retrieve all tokens for a header key. A number of different headers + follow a pattern where each header line can containe comma-separated + tokens, and headers can be set multiple times. 
+ """ + toks = [] + for i in headers[key]: + for j in i.split(","): + toks.append(j.strip()) + return toks diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 05e82831..d0a2ee02 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -71,13 +71,13 @@ def test_connection_close(): def test_get_header_tokens(): h = odict.ODictCaseless() - assert protocol.get_header_tokens(h, "foo") == [] + assert http.get_header_tokens(h, "foo") == [] h["foo"] = ["bar"] - assert protocol.get_header_tokens(h, "foo") == ["bar"] + assert http.get_header_tokens(h, "foo") == ["bar"] h["foo"] = ["bar, voing"] - assert protocol.get_header_tokens(h, "foo") == ["bar", "voing"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing"] h["foo"] = ["bar, voing", "oink"] - assert protocol.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] + assert http.get_header_tokens(h, "foo") == ["bar", "voing", "oink"] def test_read_http_body_request(): @@ -357,17 +357,6 @@ def test_read_response(): assert tst(data, "GET", None, include_body=False).content is None -def test_parse_http_basic_auth(): - vals = ("basic", "foo", "bar") - assert protocol.parse_http_basic_auth( - protocol.assemble_http_basic_auth(*vals) - ) == vals - assert not protocol.parse_http_basic_auth("") - assert not protocol.parse_http_basic_auth("foo bar") - v = "basic " + binascii.b2a_base64("foo") - assert not protocol.parse_http_basic_auth(v) - - def test_get_request_line(): r = cStringIO.StringIO("\nfoo") assert protocol.get_request_line(r) == "foo" diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index c0dae1a2..8f231643 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -1,8 +1,21 @@ +import binascii + from netlib import odict, http from netlib.http import authentication from .. 
import tutils +def test_parse_http_basic_auth(): + vals = ("basic", "foo", "bar") + assert http.authentication.parse_http_basic_auth( + http.authentication.assemble_http_basic_auth(*vals) + ) == vals + assert not http.authentication.parse_http_basic_auth("") + assert not http.authentication.parse_http_basic_auth("foo bar") + v = "basic " + binascii.b2a_base64("foo") + assert not http.authentication.parse_http_basic_auth(v) + + class TestPassManNonAnon: def test_simple(self): @@ -23,7 +36,7 @@ class TestPassManHtpasswd: pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) vals = ("basic", "test", "test") - http.http1.assemble_http_basic_auth(*vals) + authentication.assemble_http_basic_auth(*vals) assert pm.test("test", "test") assert not pm.test("test", "foo") assert not pm.test("foo", "test") @@ -62,7 +75,7 @@ class TestBasicProxyAuth: hdrs = odict.ODictCaseless() vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert ba.authenticate(hdrs) ba.clean(hdrs) @@ -75,12 +88,12 @@ class TestBasicProxyAuth: assert not ba.authenticate(hdrs) vals = ("foo", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") vals = ("basic", "foo", "bar") - hdrs[ba.AUTH_HEADER] = [http.http1.assemble_http_basic_auth(*vals)] + hdrs[ba.AUTH_HEADER] = [authentication.assemble_http_basic_auth(*vals)] assert not ba.authenticate(hdrs) -- cgit v1.2.3 From 230c16122b06f5c6af60e6ddc2d8e2e83cd75273 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 16 Jul 2015 22:50:24 +0200 Subject: change HTTP2 interface to match HTTP1 --- netlib/http/http2/protocol.py | 6 +++--- test/http/http2/test_protocol.py | 20 ++++++++++++-------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 8e5f5429..0d6eac85 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import itertools from hpack.hpack import Encoder, Decoder -from .. import utils +from netlib import http, utils from . 
import frame @@ -186,9 +186,9 @@ class HTTP2Protocol(object): self._create_headers(headers, stream_id, end_stream=(body is None)), self._create_body(body, stream_id))) - def read_response(self): + def read_response(self, *args): stream_id_, headers, body = self._receive_transmission() - return headers[':status'], headers, body + return http.Response("HTTP/2", headers[':status'], "", headers, body) def read_request(self): return self._receive_transmission() diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index f607860e..403a2589 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -251,11 +251,13 @@ class TestReadResponse(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - status, headers, body = protocol.read_response() + resp = protocol.read_response() - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'foobar' + assert resp.httpversion == "HTTP/2" + assert resp.status_code == "200" + assert resp.msg == "" + assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.content == b'foobar' class TestReadEmptyResponse(tservers.ServerTestBase): @@ -274,11 +276,13 @@ class TestReadEmptyResponse(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c) - status, headers, body = protocol.read_response() + resp = protocol.read_response() - assert headers == {':status': '200', 'etag': 'foobar'} - assert status == "200" - assert body == b'' + assert resp.httpversion == "HTTP/2" + assert resp.status_code == "200" + assert resp.msg == "" + assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.content == b'' class TestReadRequest(tservers.ServerTestBase): -- cgit v1.2.3 From 808b294865257fc3f52b33ed2a796009658b126f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 16 Jul 2015 22:56:34 +0200 Subject: refactor HTTP/1 as protocol --- netlib/http/http1/protocol.py | 901 +++++++++++++++++++------------------ test/http/http1/test_protocol.py | 214 ++++----- test/websockets/test_websockets.py | 21 +- 3 files changed, 583 insertions(+), 553 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 97c119a9..401654c1 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -9,475 +9,488 @@ from netlib import odict, utils, tcp, http from .. import status_codes from ..exceptions import * +class HTTP1Protocol(object): + + # TODO: make this a regular class - just like Response + Request = collections.namedtuple( + "Request", + [ + "form_in", + "method", + "scheme", + "host", + "port", + "path", + "httpversion", + "headers", + "content" + ] + ) -def get_request_line(fp): - """ - Get a line, possibly preceded by a blank. - """ - line = fp.readline() - if line == "\r\n" or line == "\n": - # Possible leftover from previous message - line = fp.readline() - return line - -def read_headers(fp): - """ - Read a set of headers from a file pointer. Stop once a blank line is - reached. Return a ODictCaseless object, or None if headers are invalid. - """ - ret = [] - name = '' - while True: - line = fp.readline() - if not line or line == '\r\n' or line == '\n': - break - if line[0] in ' \t': - if not ret: - return None - # continued header - ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() - else: - i = line.find(':') - # We're being liberal in what we accept, here. 
- if i > 0: - name = line[:i] - value = line[i + 1:].strip() - ret.append([name, value]) + def __init__(self, tcp_handler): + self.tcp_handler = tcp_handler + + def get_request_line(self): + """ + Get a line, possibly preceded by a blank. + """ + line = self.tcp_handler.rfile.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = self.tcp_handler.rfile.readline() + return line + + def read_headers(self): + """ + Read a set of headers. + Stop once a blank line is reached. + + Return a ODictCaseless object, or None if headers are invalid. + """ + ret = [] + name = '' + while True: + line = self.tcp_handler.rfile.readline() + if not line or line == '\r\n' or line == '\n': + break + if line[0] in ' \t': + if not ret: + return None + # continued header + ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip() else: - return None - return odict.ODictCaseless(ret) - - -def read_chunked(fp, limit, is_request): - """ - Read a chunked HTTP body. - - May raise HttpError. - """ - # FIXME: Should check if chunked is the final encoding in the headers - # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 - # 3.3 2. - total = 0 - code = 400 if is_request else 502 - while True: - line = fp.readline(128) - if line == "": - raise HttpErrorConnClosed(code, "Connection closed prematurely") - if line != '\r\n' and line != '\n': - try: - length = int(line, 16) - except ValueError: - raise HttpError( - code, - "Invalid chunked encoding length: %s" % line - ) - total += length - if limit is not None and total > limit: - msg = "HTTP Body too large. Limit is %s," \ - " chunked content longer than %s" % (limit, total) - raise HttpError(code, msg) - chunk = fp.read(length) - suffix = fp.readline(5) - if suffix != '\r\n': - raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - -def has_chunked_encoding(headers): - return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") - ] - - -def parse_http_protocol(s): - """ - Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or - None. - """ - if not s.startswith("HTTP/"): - return None - _, version = s.split('/', 1) - if "." not in version: - return None - major, minor = version.split('.', 1) - try: - major = int(major) - minor = int(minor) - except ValueError: - return None - return major, minor - - -def parse_init(line): - try: - method, url, protocol = string.split(line) - except ValueError: - return None - httpversion = parse_http_protocol(protocol) - if not httpversion: - return None - if not utils.isascii(method): - return None - return method, url, httpversion - - -def parse_init_connect(line): - """ - Returns (host, port, httpversion) if line is a valid CONNECT line. 
- http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - if method.upper() != 'CONNECT': - return None - try: - host, port = url.split(":") - except ValueError: - return None - try: - port = int(port) - except ValueError: - return None - if not http.is_valid_port(port): - return None - if not http.is_valid_host(host): - return None - return host, port, httpversion - - -def parse_init_proxy(line): - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - - parts = http.parse_url(url) - if not parts: - return None - scheme, host, port, path = parts - return method, scheme, host, port, path, httpversion - - -def parse_init_http(line): - """ - Returns (method, url, httpversion) - """ - v = parse_init(line) - if not v: - return None - method, url, httpversion = v - if not utils.isascii(url): - return None - if not (url.startswith("/") or url == "*"): - return None - return method, url, httpversion - - -def connection_close(httpversion, headers): - """ - Checks the message to see if the client connection should be closed - according to RFC 2616 Section 8.1 Note that a connection should be - closed as well if the response has been read until end of the stream. - """ - # At first, check if we have an explicit Connection header. - if "connection" in headers: - toks = http.get_header_tokens(headers, "connection") - if "close" in toks: - return True - elif "keep-alive" in toks: - return False - # If we don't have a Connection header, HTTP 1.1 connections are assumed to - # be persistent - if httpversion == (1, 1): - return False - return True - - -def parse_response_line(line): - parts = line.strip().split(" ", 2) - if len(parts) == 2: # handle missing message gracefully - parts.append("") - if len(parts) != 3: - return None - proto, code, msg = parts - try: - code = int(code) - except ValueError: - return None - return (proto, code, msg) - - -def read_http_body(*args, **kwargs): - return "".join( - content for _, content, _ in read_http_body_chunked(*args, **kwargs) - ) + i = line.find(':') + # We're being liberal in what we accept, here. + if i > 0: + name = line[:i] + value = line[i + 1:].strip() + ret.append([name, value]) + else: + return None + return odict.ODictCaseless(ret) + + + def read_chunked(self, limit, is_request): + """ + Read a chunked HTTP body. + + May raise HttpError. + """ + # FIXME: Should check if chunked is the final encoding in the headers + # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 + # 3.3 2. + total = 0 + code = 400 if is_request else 502 + while True: + line = self.tcp_handler.rfile.readline(128) + if line == "": + raise HttpErrorConnClosed(code, "Connection closed prematurely") + if line != '\r\n' and line != '\n': + try: + length = int(line, 16) + except ValueError: + raise HttpError( + code, + "Invalid chunked encoding length: %s" % line + ) + total += length + if limit is not None and total > limit: + msg = "HTTP Body too large. 
Limit is %s," \ + " chunked content longer than %s" % (limit, total) + raise HttpError(code, msg) + chunk = self.tcp_handler.rfile.read(length) + suffix = self.tcp_handler.rfile.readline(5) + if suffix != '\r\n': + raise HttpError(code, "Malformed chunked body") + yield line, chunk, '\r\n' + if length == 0: + return + + + @classmethod + def has_chunked_encoding(self, headers): + return "chunked" in [ + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + ] + + + @classmethod + def parse_http_protocol(self, line): + """ + Parse an HTTP protocol declaration. + Returns a (major, minor) tuple, or None. + """ + if not line.startswith("HTTP/"): + return None + _, version = line.split('/', 1) + if "." not in version: + return None + major, minor = version.split('.', 1) + try: + major = int(major) + minor = int(minor) + except ValueError: + return None + return major, minor -def read_http_body_chunked( - rfile, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None -): - """ - Read an HTTP message body: - - rfile: A file descriptor to read from - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = expected_http_body_size( - headers, is_request, request_method, response_code - ) + @classmethod + def parse_init(self, line): + try: + method, url, protocol = string.split(line) + except ValueError: + return None + httpversion = self.parse_http_protocol(protocol) + if not httpversion: + return None + if not utils.isascii(method): + return None + return method, url, httpversion - if expected_size is None: - if has_chunked_encoding(headers): - # Python 3: yield from - for x in read_chunked(rfile, limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) + @classmethod + def parse_init_connect(self, line): + """ + Returns (host, port, httpversion) if line is a valid CONNECT line. + http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 + """ + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v -def expected_http_body_size(headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. 
- """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if has_chunked_encoding(headers): - return None - if "content-length" in headers: + if method.upper() != 'CONNECT': + return None try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size + host, port = url.split(":") except ValueError: return None - if is_request: - return 0 - return -1 - - -# TODO: make this a regular class - just like Response -Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] -) - - -def read_request(rfile, include_body=True, body_size_limit=None, wfile=None): - """ - Parse an HTTP request from a file stream - - Args: - rfile (file): Input file to read from - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = get_request_line(rfile) - if not request_line: - raise tcp.NetLibDisconnect() - - request_line_parts = parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + try: + port = int(port) + except ValueError: + return None + if not http.is_valid_port(port): + return None + if not http.is_valid_host(host): + return None + return host, port, httpversion + + @classmethod + def parse_init_proxy(self, line): + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v + + parts = http.parse_url(url) + if not parts: + return None + scheme, host, port, path = parts + return method, scheme, host, port, path, httpversion + + @classmethod + def parse_init_http(self, line): + """ + Returns (method, url, httpversion) + """ + v = self.parse_init(line) + if not v: + return None + method, url, httpversion = v + if not utils.isascii(url): + return None + if not (url.startswith("/") or url == "*"): + return None + return method, url, httpversion + + + @classmethod + def connection_close(self, httpversion, headers): + """ + Checks the message to see if the client connection should be closed + according to RFC 2616 Section 8.1 Note that a connection should be + closed as well if the response has been read until end of the stream. + """ + # At first, check if we have an explicit Connection header. 
+ if "connection" in headers: + toks = http.get_header_tokens(headers, "connection") + if "close" in toks: + return True + elif "keep-alive" in toks: + return False + + # If we don't have a Connection header, HTTP 1.1 connections are assumed to + # be persistent + return httpversion != (1, 1) + + + @classmethod + def parse_response_line(self, line): + parts = line.strip().split(" ", 2) + if len(parts) == 2: # handle missing message gracefully + parts.append("") + if len(parts) != 3: + return None + proto, code, msg = parts + try: + code = int(code) + except ValueError: + return None + return (proto, code, msg) + + + def read_http_body(self, *args, **kwargs): + return "".join( + content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) + ) + + + def read_http_body_chunked( + self, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None + ): + """ + Read an HTTP message body: + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = self.expected_http_body_size( + headers, is_request, request_method, response_code ) - method, path, httpversion = request_line_parts - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): + if expected_size is None: + if self.has_chunked_encoding(headers): + # Python 3: yield from + for x in self.read_chunked(limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", self.tcp_handler.rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = self.tcp_handler.rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = self.tcp_handler.rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + + @classmethod + def expected_http_body_size(self, headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. 
+ """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if self.has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + + def read_request(self, include_body=True, body_size_limit=None): + """ + Parse an HTTP request from a file stream + + Args: + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. + """ + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = self.get_request_line() + if not request_line: + raise tcp.NetLibDisconnect() + + request_line_parts = self.parse_init(request_line) + if not request_line_parts: raise HttpError( 400, "Bad HTTP request line: %s" % repr(request_line) ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = self.parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + path = None + else: + form_in = "absolute" + r = self.parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = self.read_headers() + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + self.tcp_handler.wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' ) - host, port, _ = r - path = None - else: - form_in = "absolute" - r = parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) + self.tcp_handler.wfile.flush() + del headers['expect'] + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + method, + None, + True ) - _, scheme, host, port, path, _ = r - headers = read_headers(rfile) - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' + return self.Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content ) - wfile.flush() - del headers['expect'] - if include_body: - content = read_http_body( - rfile, headers, body_size_limit, method, None, True - ) - return Request( - form_in, 
- method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) + def read_response(self, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. it's a + response to a HEAD request) + """ -def read_response(rfile, request_method, body_size_limit, include_body=True): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. it's a - response to a HEAD request) - """ - - line = rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = read_headers(rfile) - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = read_http_body( - rfile, - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + line = self.tcp_handler.rfile.readline() + # Possible leftover from previous message + if line == "\r\n" or line == "\n": + line = self.tcp_handler.rfile.readline() + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = self.parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = self.parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = self.read_headers() + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) -def request_preamble(method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) + @classmethod + def request_preamble(self, method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) -def response_preamble(code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) + @classmethod + def response_preamble(self, code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git 
a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index d0a2ee02..6b8a884c 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -3,70 +3,79 @@ import textwrap import binascii from netlib import http, odict, tcp -from netlib.http.http1 import protocol +from netlib.http.http1 import HTTP1Protocol from ... import tutils, tservers +def mock_protocol(data='', chunked=False): + class TCPHandlerMock(object): + pass + tcp_handler = TCPHandlerMock() + tcp_handler.rfile = cStringIO.StringIO(data) + tcp_handler.wfile = cStringIO.StringIO() + return HTTP1Protocol(tcp_handler) + + + def test_has_chunked_encoding(): h = odict.ODictCaseless() - assert not protocol.has_chunked_encoding(h) + assert not HTTP1Protocol.has_chunked_encoding(h) h["transfer-encoding"] = ["chunked"] - assert protocol.has_chunked_encoding(h) + assert HTTP1Protocol.has_chunked_encoding(h) def test_read_chunked(): - h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = cStringIO.StringIO("1\r\na\r\n0\r\n") + data = "1\r\na\r\n0\r\n" tutils.raises( "malformed chunked body", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") - assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + data = "1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") - assert protocol.read_http_body(s, h, None, "GET", None, True) == "a" + data = "\r\n\r\n1\r\na\r\n0\r\n\r\n" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "a" - s = cStringIO.StringIO("\r\n") + data = "\r\n" tutils.raises( "closed prematurely", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("1\r\nfoo") + data = "1\r\nfoo" tutils.raises( "malformed chunked body", - protocol.read_http_body, - s, h, None, "GET", None, True + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("foo\r\nfoo") + data = "foo\r\nfoo" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", None, True + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", None, True ) - s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") - tutils.raises("too large", protocol.read_http_body, s, h, 2, "GET", None, True) + data = "5\r\naaaaa\r\n0\r\n\r\n" + tutils.raises("too large", mock_protocol(data).read_http_body, h, 2, "GET", None, True) def test_connection_close(): h = odict.ODictCaseless() - assert protocol.connection_close((1, 0), h) - assert not protocol.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 0), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["keep-alive"] - assert not protocol.connection_close((1, 1), h) + assert not HTTP1Protocol.connection_close((1, 1), h) h["connection"] = ["close"] - assert protocol.connection_close((1, 1), h) + assert HTTP1Protocol.connection_close((1, 1), h) def test_get_header_tokens(): @@ -82,119 +91,119 @@ def test_get_header_tokens(): def test_read_http_body_request(): h = odict.ODictCaseless() - r = cStringIO.StringIO("testing") - assert protocol.read_http_body(r, h, None, "GET", None, True) == "" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", None, True) == "" def 
test_read_http_body_response(): h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data, chunked=True).read_http_body(h, None, "GET", 200, False) == "testing" def test_read_http_body(): # test default case h = odict.ODictCaseless() h["content-length"] = [7] - s = cStringIO.StringIO("testing") - assert protocol.read_http_body(s, h, None, "GET", 200, False) == "testing" + data = "testing" + assert mock_protocol(data).read_http_body(h, None, "GET", 200, False) == "testing" # test content length: invalid header h["content-length"] = ["foo"] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: invalid header #2 h["content-length"] = [-1] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, None, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, None, "GET", 200, False ) # test content length: content length > actual content h["content-length"] = [5] - s = cStringIO.StringIO("testing") + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, 4, "GET", 200, False + http.HttpError, + mock_protocol(data).read_http_body, + h, 4, "GET", 200, False ) # test content length: content length < actual content - s = cStringIO.StringIO("testing") - assert len(protocol.read_http_body(s, h, None, "GET", 200, False)) == 5 + data = "testing" + assert len(mock_protocol(data).read_http_body(h, None, "GET", 200, False)) == 5 # test no content length: limit > actual content h = odict.ODictCaseless() - s = tcp.Reader(cStringIO.StringIO("testing")) - assert len(protocol.read_http_body(s, h, 100, "GET", 200, False)) == 7 + data = "testing" + assert len(mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False)) == 7 # test no content length: limit < actual content - s = tcp.Reader(cStringIO.StringIO("testing")) + data = "testing" tutils.raises( - protocol.HttpError, - protocol.read_http_body, - s, h, 4, "GET", 200, False + http.HttpError, + mock_protocol(data, chunked=True).read_http_body, + h, 4, "GET", 200, False ) # test chunked h = odict.ODictCaseless() h["transfer-encoding"] = ["chunked"] - s = tcp.Reader(cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n")) - assert protocol.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" + data = "5\r\naaaaa\r\n0\r\n\r\n" + assert mock_protocol(data, chunked=True).read_http_body(h, 100, "GET", 200, False) == "aaaaa" def test_expected_http_body_size(): # gibber in the content-length field h = odict.ODictCaseless() h["content-length"] = ["foo"] - assert protocol.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # negative number in the content-length field h = odict.ODictCaseless() h["content-length"] = ["-7"] - assert protocol.expected_http_body_size(h, False, "GET", 200) is None + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) is None # explicit length h = odict.ODictCaseless() h["content-length"] = ["5"] - assert protocol.expected_http_body_size(h, False, "GET", 200) == 5 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == 5 # no length h = odict.ODictCaseless() - 
assert protocol.expected_http_body_size(h, False, "GET", 200) == -1 + assert HTTP1Protocol.expected_http_body_size(h, False, "GET", 200) == -1 # no length request h = odict.ODictCaseless() - assert protocol.expected_http_body_size(h, True, "GET", None) == 0 + assert HTTP1Protocol.expected_http_body_size(h, True, "GET", None) == 0 def test_parse_http_protocol(): - assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1) - assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not protocol.parse_http_protocol("HTTP/a.1") - assert not protocol.parse_http_protocol("HTTP/1.a") - assert not protocol.parse_http_protocol("foo/0.0") - assert not protocol.parse_http_protocol("HTTP/x") + assert HTTP1Protocol.parse_http_protocol("HTTP/1.1") == (1, 1) + assert HTTP1Protocol.parse_http_protocol("HTTP/0.0") == (0, 0) + assert not HTTP1Protocol.parse_http_protocol("HTTP/a.1") + assert not HTTP1Protocol.parse_http_protocol("HTTP/1.a") + assert not HTTP1Protocol.parse_http_protocol("foo/0.0") + assert not HTTP1Protocol.parse_http_protocol("HTTP/x") def test_parse_init_connect(): - assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not protocol.parse_init_connect("bogus") - assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + assert HTTP1Protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("bogus") + assert not HTTP1Protocol.parse_init_connect("GET host.com:443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") def test_parse_init_proxy(): u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u) + m, s, h, po, pa, httpversion = HTTP1Protocol.parse_init_proxy(u) assert m == "GET" assert s == "http" assert h == "foo.com" @@ -203,27 +212,27 @@ def test_parse_init_proxy(): assert httpversion == (1, 1) u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not protocol.parse_init_proxy(u) + assert not HTTP1Protocol.parse_init_proxy(u) - assert not protocol.parse_init_proxy("invalid") - assert not protocol.parse_init_proxy("GET invalid HTTP/1.1") - assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + assert not HTTP1Protocol.parse_init_proxy("invalid") + assert not HTTP1Protocol.parse_init_proxy("GET invalid HTTP/1.1") + assert not HTTP1Protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") def test_parse_init_http(): u = "GET /test HTTP/1.1" - m, u, httpversion = protocol.parse_init_http(u) + m, u, httpversion = HTTP1Protocol.parse_init_http(u) assert m == "GET" assert u == "/test" assert httpversion == (1, 1) u = "G\xfeET /test HTTP/1.1" - 
assert not protocol.parse_init_http(u) + assert not HTTP1Protocol.parse_init_http(u) - assert not protocol.parse_init_http("invalid") - assert not protocol.parse_init_http("GET invalid HTTP/1.1") - assert not protocol.parse_init_http("GET /test foo/1.1") - assert not protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + assert not HTTP1Protocol.parse_init_http("invalid") + assert not HTTP1Protocol.parse_init_http("GET invalid HTTP/1.1") + assert not HTTP1Protocol.parse_init_http("GET /test foo/1.1") + assert not HTTP1Protocol.parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: @@ -232,8 +241,7 @@ class TestReadHeaders: if not verbatim: data = textwrap.dedent(data) data = data.strip() - s = cStringIO.StringIO(data) - return protocol.read_headers(s) + return mock_protocol(data).read_headers() def test_read_simple(self): data = """ @@ -287,16 +295,15 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): def test_no_content_length(self): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() - resp = protocol.read_response(c.rfile, "GET", None) + resp = HTTP1Protocol(c).read_response("GET", None) assert resp.content == "bar\r\n\r\n" def test_read_response(): def tst(data, method, limit, include_body=True): data = textwrap.dedent(data) - r = cStringIO.StringIO(data) - return protocol.read_response( - r, method, limit, include_body=include_body + return mock_protocol(data).read_response( + method, limit, include_body=include_body ) tutils.raises("server disconnect", tst, "", "GET", None) @@ -358,16 +365,16 @@ def test_read_response(): def test_get_request_line(): - r = cStringIO.StringIO("\nfoo") - assert protocol.get_request_line(r) == "foo" - assert not protocol.get_request_line(r) + data = "\nfoo" + p = mock_protocol(data) + assert p.get_request_line() == "foo" + assert not p.get_request_line() class TestReadRequest(): def tst(self, data, **kwargs): - r = cStringIO.StringIO(data) - return protocol.read_request(r, **kwargs) + return mock_protocol(data).read_request(**kwargs) def test_invalid(self): tutils.raises( @@ -421,14 +428,15 @@ class TestReadRequest(): assert v.host == "foo.com" def test_expect(self): - w = cStringIO.StringIO() - r = cStringIO.StringIO( + data = "".join( "GET / HTTP/1.1\r\n" "Content-Length: 3\r\n" "Expect: 100-continue\r\n\r\n" - "foobar", + "foobar" ) - v = protocol.read_request(r, wfile=w) - assert w.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" + + p = mock_protocol(data) + v = p.read_request() + assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" assert v.content == "foo" - assert r.read(3) == "bar" + assert p.tcp_handler.rfile.read(3) == "bar" diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 07ad0452..fb7ba39a 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -4,6 +4,7 @@ from nose.tools import raises from netlib import tcp, http, websockets from netlib.http.exceptions import * +from netlib.http.http1 import HTTP1Protocol from .. 
import tutils, tservers @@ -32,10 +33,13 @@ class WebSocketsEchoHandler(tcp.BaseHandler): frame.to_file(self.wfile) def handshake(self): - req = http.http1.read_request(self.rfile) + http1_protocol = HTTP1Protocol(self) + + req = http1_protocol.read_request() key = self.protocol.check_client_handshake(req.headers) - self.wfile.write(http.http1.response_preamble(101) + "\r\n") + preamble = http1_protocol.response_preamble(101) + self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") self.wfile.flush() @@ -56,14 +60,16 @@ class WebSocketsClient(tcp.TCPClient): def connect(self): super(WebSocketsClient, self).connect() - preamble = http.http1.protocol.request_preamble("GET", "/") + http1_protocol = HTTP1Protocol(self) + + preamble = http1_protocol.request_preamble("GET", "/") self.wfile.write(preamble + "\r\n") headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() - resp = http.http1.protocol.read_response(self.rfile, "get", None) + resp = http1_protocol.read_response("get", None) server_nonce = self.protocol.check_server_handshake(resp.headers) if not server_nonce == self.protocol.create_server_nonce( @@ -151,10 +157,13 @@ class TestWebSockets(tservers.ServerTestBase): class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): - client_hs = http.http1.protocol.read_request(self.rfile) + http1_protocol = HTTP1Protocol(self) + + client_hs = http1_protocol.read_request() self.protocol.check_client_handshake(client_hs.headers) - self.wfile.write(http.http1.protocol.response_preamble(101) + "\r\n") + preamble = http1_protocol.response_preamble(101) + self.wfile.write(preamble + "\r\n") headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() -- cgit v1.2.3 From 4617ab8a3a981f3abd8d62b561c80f9ad141e57b Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Fri, 17 Jul 2015 09:37:57 +0200 Subject: add Request class and unify read_request interface --- netlib/http/__init__.py | 1 + netlib/http/http1/protocol.py | 22 +++++----------------- netlib/http/http2/protocol.py | 20 +++++++++++++++++--- netlib/http/semantics.py | 31 +++++++++++++++++++++++++++++++ test/http/http2/test_protocol.py | 9 +++++---- 5 files changed, 59 insertions(+), 24 deletions(-) diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 9b4b0e6b..b01afc6d 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,2 +1,3 @@ +from . import * from exceptions import * from semantics import * diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 401654c1..8d631a13 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -11,25 +11,10 @@ from ..exceptions import * class HTTP1Protocol(object): - # TODO: make this a regular class - just like Response - Request = collections.namedtuple( - "Request", - [ - "form_in", - "method", - "scheme", - "host", - "port", - "path", - "httpversion", - "headers", - "content" - ] - ) - def __init__(self, tcp_handler): self.tcp_handler = tcp_handler + def get_request_line(self): """ Get a line, possibly preceded by a blank. @@ -40,6 +25,7 @@ class HTTP1Protocol(object): line = self.tcp_handler.rfile.readline() return line + def read_headers(self): """ Read a set of headers. 
@@ -175,6 +161,7 @@ class HTTP1Protocol(object): return None return host, port, httpversion + @classmethod def parse_init_proxy(self, line): v = self.parse_init(line) @@ -188,6 +175,7 @@ class HTTP1Protocol(object): scheme, host, port, path = parts return method, scheme, host, port, path, httpversion + @classmethod def parse_init_http(self, line): """ @@ -425,7 +413,7 @@ class HTTP1Protocol(object): True ) - return self.Request( + return http.Request( form_in, method, scheme, diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 0d6eac85..1dfdda21 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -187,11 +187,25 @@ class HTTP2Protocol(object): self._create_body(body, stream_id))) def read_response(self, *args): - stream_id_, headers, body = self._receive_transmission() - return http.Response("HTTP/2", headers[':status'], "", headers, body) + stream_id, headers, body = self._receive_transmission() + + response = http.Response("HTTP/2", headers[':status'], "", headers, body) + response.stream_id = stream_id + return response def read_request(self): - return self._receive_transmission() + stream_id, headers, body = self._receive_transmission() + + form_in = "" + method = headers.get(':method', '') + scheme = headers.get(':scheme', '') + host = headers.get(':host', '') + port = '' # TODO: parse port number? + path = headers.get(':path', '') + + request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) + request.stream_id = stream_id + return request def _receive_transmission(self): body_expected = True diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index a62c93e3..9a010318 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,37 @@ import urlparse from .. 
import utils +class Request(object): + + def __init__( + self, + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content, + ): + self.form_in = form_in + self.method = method + self.scheme = scheme + self.host = host + self.port = port + self.path = path + self.httpversion = httpversion + self.headers = headers + self.content = content + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + + class Response(object): def __init__( diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 403a2589..f41b9565 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -278,6 +278,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): resp = protocol.read_response() + assert resp.stream_id assert resp.httpversion == "HTTP/2" assert resp.status_code == "200" assert resp.msg == "" @@ -303,11 +304,11 @@ class TestReadRequest(tservers.ServerTestBase): c.convert_to_ssl() protocol = http2.HTTP2Protocol(c, is_server=True) - stream_id, headers, body = protocol.read_request() + resp = protocol.read_request() - assert stream_id - assert headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert body == b'foobar' + assert resp.stream_id + assert resp.headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert resp.content == b'foobar' class TestCreateResponse(): -- cgit v1.2.3 From 37a0cb858cda255bac8f06749a81859c82c5177f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 17:52:10 +0200 Subject: introduce ConnectRequest class --- netlib/http/http1/protocol.py | 2 +- netlib/http/semantics.py | 24 +++++++++++++++++++----- netlib/odict.py | 2 ++ 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 8d631a13..257efb19 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -380,7 +380,7 @@ class HTTP1Protocol(object): "Bad HTTP request line: %s" % repr(request_line) ) host, port, _ = r - path = None + return http.ConnectRequest(host, port) else: form_in = "absolute" r = self.parse_init_proxy(request_line) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9a010318..664f9def 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -19,7 +19,7 @@ class Request(object): path, httpversion, headers, - content, + body, ): self.form_in = form_in self.method = method @@ -29,7 +29,7 @@ class Request(object): self.path = path self.httpversion = httpversion self.headers = headers - self.content = content + self.body = body def __eq__(self, other): return self.__dict__ == other.__dict__ @@ -38,6 +38,21 @@ class Request(object): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) +class ConnectRequest(Request): + def __init__(self, host, port): + super(ConnectRequest, self).__init__( + form_in="authority", + method="CONNECT", + scheme="", + host=host, + port=port, + path="", + httpversion="", + headers="", + body="", + ) + + class Response(object): def __init__( @@ -46,14 +61,14 @@ class Response(object): status_code, msg, headers, - content, + body, sslinfo=None, ): self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers - self.content = content + self.body = body self.sslinfo = sslinfo def __eq__(self, other): @@ -63,7 +78,6 @@ class Response(object): return "Response(%s - 
%s)" % (self.status_code, self.msg) - def is_valid_port(port): if not 0 <= port <= 65535: return False diff --git a/netlib/odict.py b/netlib/odict.py index f52acd50..ee1e6938 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -20,6 +20,8 @@ class ODict(object): """ def __init__(self, lst=None): + if isinstance(lst, ODict): + lst = lst.items() self.lst = lst or [] def _kconv(self, s): -- cgit v1.2.3 From d62dbee0f6cd47b4cad1ee7cc731b413600c0add Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 18:17:30 +0200 Subject: rename content -> body --- netlib/wsgi.py | 6 +++--- test/http/http1/test_protocol.py | 10 +++++----- test/http/http2/test_protocol.py | 6 +++--- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/netlib/wsgi.py b/netlib/wsgi.py index ad43dc19..99afe00e 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -21,9 +21,9 @@ class Flow(object): class Request(object): - def __init__(self, scheme, method, path, headers, content): + def __init__(self, scheme, method, path, headers, body): self.scheme, self.method, self.path = scheme, method, path - self.headers, self.content = headers, content + self.headers, self.body = headers, body def date_time_string(): @@ -58,7 +58,7 @@ class WSGIAdaptor(object): environ = { 'wsgi.version': (1, 0), 'wsgi.url_scheme': flow.request.scheme, - 'wsgi.input': cStringIO.StringIO(flow.request.content), + 'wsgi.input': cStringIO.StringIO(flow.request.body or ""), 'wsgi.errors': errsoc, 'wsgi.multithread': True, 'wsgi.multiprocess': False, diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 6b8a884c..936fe20d 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -296,7 +296,7 @@ class TestReadResponseNoContentLength(tservers.ServerTestBase): c = tcp.TCPClient(("127.0.0.1", self.port)) c.connect() resp = HTTP1Protocol(c).read_response("GET", None) - assert resp.content == "bar\r\n\r\n" + assert resp.body == "bar\r\n\r\n" def test_read_response(): @@ -344,8 +344,8 @@ def test_read_response(): foo """ - assert tst(data, "GET", None).content == 'foo' - assert tst(data, "HEAD", None).content == '' + assert tst(data, "GET", None).body == 'foo' + assert tst(data, "HEAD", None).body == '' data = """ HTTP/1.1 200 OK @@ -361,7 +361,7 @@ def test_read_response(): foo """ - assert tst(data, "GET", None, include_body=False).content is None + assert tst(data, "GET", None, include_body=False).body is None def test_get_request_line(): @@ -438,5 +438,5 @@ class TestReadRequest(): p = mock_protocol(data) v = p.read_request() assert p.tcp_handler.wfile.getvalue() == "HTTP/1.1 100 Continue\r\n\r\n" - assert v.content == "foo" + assert v.body == "foo" assert p.tcp_handler.rfile.read(3) == "bar" diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index f41b9565..34e4ef50 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -257,7 +257,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.status_code == "200" assert resp.msg == "" assert resp.headers == {':status': '200', 'etag': 'foobar'} - assert resp.content == b'foobar' + assert resp.body == b'foobar' class TestReadEmptyResponse(tservers.ServerTestBase): @@ -283,7 +283,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.status_code == "200" assert resp.msg == "" assert resp.headers == {':status': '200', 'etag': 'foobar'} - assert resp.content == b'' + assert resp.body == b'' class TestReadRequest(tservers.ServerTestBase): @@ 
-308,7 +308,7 @@ class TestReadRequest(tservers.ServerTestBase): assert resp.stream_id assert resp.headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} - assert resp.content == b'foobar' + assert resp.body == b'foobar' class TestCreateResponse(): -- cgit v1.2.3 From 83f013fca13c7395ca4e3da3fac60c8d907172b6 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 20:46:26 +0200 Subject: introduce EmptyRequest class --- netlib/http/http1/protocol.py | 7 +++++-- netlib/http/semantics.py | 14 ++++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index 257efb19..d2a77399 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -333,7 +333,7 @@ class HTTP1Protocol(object): return -1 - def read_request(self, include_body=True, body_size_limit=None): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): """ Parse an HTTP request from a file stream @@ -354,7 +354,10 @@ class HTTP1Protocol(object): request_line = self.get_request_line() if not request_line: - raise tcp.NetLibDisconnect() + if allow_empty: + return http.EmptyRequest() + else: + raise tcp.NetLibDisconnect() request_line_parts = self.parse_init(request_line) if not request_line_parts: diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 664f9def..355906dd 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -38,6 +38,20 @@ class Request(object): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) +class EmptyRequest(Request): + def __init__(self): + super(EmptyRequest, self).__init__( + form_in="", + method="", + scheme="", + host="", + port="", + path="", + httpversion="", + headers="", + body="", + ) + class ConnectRequest(Request): def __init__(self, host, port): super(ConnectRequest, self).__init__( -- cgit v1.2.3 From ecc7ffe9282ae9d1b652a88946d6edc550dc9633 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sun, 19 Jul 2015 23:25:15 +0200 Subject: reduce public interface MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit use private indicator pattern “_methodname” --- netlib/http/http1/protocol.py | 569 ++++++++++++++++++++------------------- test/http/http1/test_protocol.py | 56 ++-- 2 files changed, 313 insertions(+), 312 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index d2a77399..e7727e00 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -15,15 +15,144 @@ class HTTP1Protocol(object): self.tcp_handler = tcp_handler - def get_request_line(self): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): """ - Get a line, possibly preceded by a blank. + Parse an HTTP request from a file stream + + Args: + include_body (bool): Read response body as well + body_size_limit (bool): Maximum body size + wfile (file): If specified, HTTP Expect headers are handled + automatically, by writing a HTTP 100 CONTINUE response to the stream. + + Returns: + Request: The HTTP request + + Raises: + HttpError: If the input is invalid. 
""" + httpversion, host, port, scheme, method, path, headers, content = ( + None, None, None, None, None, None, None, None) + + request_line = self._get_request_line() + if not request_line: + if allow_empty: + return http.EmptyRequest() + else: + raise tcp.NetLibDisconnect() + + request_line_parts = self._parse_init(request_line) + if not request_line_parts: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + method, path, httpversion = request_line_parts + + if path == '*' or path.startswith("/"): + form_in = "relative" + if not utils.isascii(path): + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + elif method.upper() == 'CONNECT': + form_in = "authority" + r = self._parse_init_connect(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + host, port, _ = r + return http.ConnectRequest(host, port) + else: + form_in = "absolute" + r = self._parse_init_proxy(request_line) + if not r: + raise HttpError( + 400, + "Bad HTTP request line: %s" % repr(request_line) + ) + _, scheme, host, port, path, _ = r + + headers = self.read_headers() + if headers is None: + raise HttpError(400, "Invalid headers") + + expect_header = headers.get_first("expect", "").lower() + if expect_header == "100-continue" and httpversion >= (1, 1): + self.tcp_handler.wfile.write( + 'HTTP/1.1 100 Continue\r\n' + '\r\n' + ) + self.tcp_handler.wfile.flush() + del headers['expect'] + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + method, + None, + True + ) + + return http.Request( + form_in, + method, + scheme, + host, + port, + path, + httpversion, + headers, + content + ) + + + def read_response(self, request_method, body_size_limit, include_body=True): + """ + Returns an http.Response + + By default, both response header and body are read. + If include_body=False is specified, content may be one of the + following: + - None, if the response is technically allowed to have a response body + - "", if the response must not have a response body (e.g. 
it's a + response to a HEAD request) + """ + line = self.tcp_handler.rfile.readline() + # Possible leftover from previous message if line == "\r\n" or line == "\n": - # Possible leftover from previous message line = self.tcp_handler.rfile.readline() - return line + if not line: + raise HttpErrorConnClosed(502, "Server disconnect.") + parts = self.parse_response_line(line) + if not parts: + raise HttpError(502, "Invalid server response: %s" % repr(line)) + proto, code, msg = parts + httpversion = self._parse_http_protocol(proto) + if httpversion is None: + raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) + headers = self.read_headers() + if headers is None: + raise HttpError(502, "Invalid headers.") + + if include_body: + content = self.read_http_body( + headers, + body_size_limit, + request_method, + code, + False + ) + else: + # if include_body==False then a None content means the body should be + # read separately + content = None + return http.Response(httpversion, code, msg, headers, content) def read_headers(self): @@ -56,7 +185,146 @@ class HTTP1Protocol(object): return odict.ODictCaseless(ret) - def read_chunked(self, limit, is_request): + def read_http_body(self, *args, **kwargs): + return "".join( + content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) + ) + + + def read_http_body_chunked( + self, + headers, + limit, + request_method, + response_code, + is_request, + max_chunk_size=None + ): + """ + Read an HTTP message body: + headers: An ODictCaseless object + limit: Size limit. + is_request: True if the body to read belongs to a request, False + otherwise + """ + if max_chunk_size is None: + max_chunk_size = limit or sys.maxsize + + expected_size = self.expected_http_body_size( + headers, is_request, request_method, response_code + ) + + if expected_size is None: + if self.has_chunked_encoding(headers): + # Python 3: yield from + for x in self._read_chunked(limit, is_request): + yield x + else: # pragma: nocover + raise HttpError( + 400 if is_request else 502, + "Content-Length unknown but no chunked encoding" + ) + elif expected_size >= 0: + if limit is not None and expected_size > limit: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s, content-length was %s" % ( + limit, expected_size + ) + ) + bytes_left = expected_size + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + yield "", self.tcp_handler.rfile.read(chunk_size), "" + bytes_left -= chunk_size + else: + bytes_left = limit or -1 + while bytes_left: + chunk_size = min(bytes_left, max_chunk_size) + content = self.tcp_handler.rfile.read(chunk_size) + if not content: + return + yield "", content, "" + bytes_left -= chunk_size + not_done = self.tcp_handler.rfile.read(1) + if not_done: + raise HttpError( + 400 if is_request else 509, + "HTTP Body too large. Limit is %s," % limit + ) + + + @classmethod + def expected_http_body_size(self, headers, is_request, request_method, response_code): + """ + Returns the expected body length: + - a positive integer, if the size is known in advance + - None, if the size in unknown in advance (chunked encoding or invalid + data) + - -1, if all data should be read until end of stream. + + May raise HttpError. 
+ """ + # Determine response size according to + # http://tools.ietf.org/html/rfc7230#section-3.3 + if request_method: + request_method = request_method.upper() + + if (not is_request and ( + request_method == "HEAD" or + (request_method == "CONNECT" and response_code == 200) or + response_code in [204, 304] or + 100 <= response_code <= 199)): + return 0 + if self.has_chunked_encoding(headers): + return None + if "content-length" in headers: + try: + size = int(headers["content-length"][0]) + if size < 0: + raise ValueError() + return size + except ValueError: + return None + if is_request: + return 0 + return -1 + + + @classmethod + def request_preamble(self, method, resource, http_major="1", http_minor="1"): + return '%s %s HTTP/%s.%s' % ( + method, resource, http_major, http_minor + ) + + + @classmethod + def response_preamble(self, code, message=None, http_major="1", http_minor="1"): + if message is None: + message = status_codes.RESPONSES.get(code) + return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) + + + @classmethod + def has_chunked_encoding(self, headers): + return "chunked" in [ + i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") + ] + + + def _get_request_line(self): + """ + Get a line, possibly preceded by a blank. + """ + line = self.tcp_handler.rfile.readline() + if line == "\r\n" or line == "\n": + # Possible leftover from previous message + line = self.tcp_handler.rfile.readline() + return line + + + + def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -88,20 +356,13 @@ class HTTP1Protocol(object): suffix = self.tcp_handler.rfile.readline(5) if suffix != '\r\n': raise HttpError(code, "Malformed chunked body") - yield line, chunk, '\r\n' - if length == 0: - return - - - @classmethod - def has_chunked_encoding(self, headers): - return "chunked" in [ - i.lower() for i in http.get_header_tokens(headers, "transfer-encoding") - ] + yield line, chunk, '\r\n' + if length == 0: + return @classmethod - def parse_http_protocol(self, line): + def _parse_http_protocol(self, line): """ Parse an HTTP protocol declaration. Returns a (major, minor) tuple, or None. @@ -121,12 +382,12 @@ class HTTP1Protocol(object): @classmethod - def parse_init(self, line): + def _parse_init(self, line): try: method, url, protocol = string.split(line) except ValueError: return None - httpversion = self.parse_http_protocol(protocol) + httpversion = self._parse_http_protocol(protocol) if not httpversion: return None if not utils.isascii(method): @@ -135,12 +396,12 @@ class HTTP1Protocol(object): @classmethod - def parse_init_connect(self, line): + def _parse_init_connect(self, line): """ Returns (host, port, httpversion) if line is a valid CONNECT line. 
http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1 """ - v = self.parse_init(line) + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -163,8 +424,8 @@ class HTTP1Protocol(object): @classmethod - def parse_init_proxy(self, line): - v = self.parse_init(line) + def _parse_init_proxy(self, line): + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -177,11 +438,11 @@ class HTTP1Protocol(object): @classmethod - def parse_init_http(self, line): + def _parse_init_http(self, line): """ Returns (method, url, httpversion) """ - v = self.parse_init(line) + v = self._parse_init(line) if not v: return None method, url, httpversion = v @@ -225,263 +486,3 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) - - - def read_http_body(self, *args, **kwargs): - return "".join( - content for _, content, _ in self.read_http_body_chunked(*args, **kwargs) - ) - - - def read_http_body_chunked( - self, - headers, - limit, - request_method, - response_code, - is_request, - max_chunk_size=None - ): - """ - Read an HTTP message body: - headers: An ODictCaseless object - limit: Size limit. - is_request: True if the body to read belongs to a request, False - otherwise - """ - if max_chunk_size is None: - max_chunk_size = limit or sys.maxsize - - expected_size = self.expected_http_body_size( - headers, is_request, request_method, response_code - ) - - if expected_size is None: - if self.has_chunked_encoding(headers): - # Python 3: yield from - for x in self.read_chunked(limit, is_request): - yield x - else: # pragma: nocover - raise HttpError( - 400 if is_request else 502, - "Content-Length unknown but no chunked encoding" - ) - elif expected_size >= 0: - if limit is not None and expected_size > limit: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s, content-length was %s" % ( - limit, expected_size - ) - ) - bytes_left = expected_size - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - yield "", self.tcp_handler.rfile.read(chunk_size), "" - bytes_left -= chunk_size - else: - bytes_left = limit or -1 - while bytes_left: - chunk_size = min(bytes_left, max_chunk_size) - content = self.tcp_handler.rfile.read(chunk_size) - if not content: - return - yield "", content, "" - bytes_left -= chunk_size - not_done = self.tcp_handler.rfile.read(1) - if not_done: - raise HttpError( - 400 if is_request else 509, - "HTTP Body too large. Limit is %s," % limit - ) - - - @classmethod - def expected_http_body_size(self, headers, is_request, request_method, response_code): - """ - Returns the expected body length: - - a positive integer, if the size is known in advance - - None, if the size in unknown in advance (chunked encoding or invalid - data) - - -1, if all data should be read until end of stream. - - May raise HttpError. 
- """ - # Determine response size according to - # http://tools.ietf.org/html/rfc7230#section-3.3 - if request_method: - request_method = request_method.upper() - - if (not is_request and ( - request_method == "HEAD" or - (request_method == "CONNECT" and response_code == 200) or - response_code in [204, 304] or - 100 <= response_code <= 199)): - return 0 - if self.has_chunked_encoding(headers): - return None - if "content-length" in headers: - try: - size = int(headers["content-length"][0]) - if size < 0: - raise ValueError() - return size - except ValueError: - return None - if is_request: - return 0 - return -1 - - - def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): - """ - Parse an HTTP request from a file stream - - Args: - include_body (bool): Read response body as well - body_size_limit (bool): Maximum body size - wfile (file): If specified, HTTP Expect headers are handled - automatically, by writing a HTTP 100 CONTINUE response to the stream. - - Returns: - Request: The HTTP request - - Raises: - HttpError: If the input is invalid. - """ - httpversion, host, port, scheme, method, path, headers, content = ( - None, None, None, None, None, None, None, None) - - request_line = self.get_request_line() - if not request_line: - if allow_empty: - return http.EmptyRequest() - else: - raise tcp.NetLibDisconnect() - - request_line_parts = self.parse_init(request_line) - if not request_line_parts: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - method, path, httpversion = request_line_parts - - if path == '*' or path.startswith("/"): - form_in = "relative" - if not utils.isascii(path): - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - elif method.upper() == 'CONNECT': - form_in = "authority" - r = self.parse_init_connect(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - host, port, _ = r - return http.ConnectRequest(host, port) - else: - form_in = "absolute" - r = self.parse_init_proxy(request_line) - if not r: - raise HttpError( - 400, - "Bad HTTP request line: %s" % repr(request_line) - ) - _, scheme, host, port, path, _ = r - - headers = self.read_headers() - if headers is None: - raise HttpError(400, "Invalid headers") - - expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): - self.tcp_handler.wfile.write( - 'HTTP/1.1 100 Continue\r\n' - '\r\n' - ) - self.tcp_handler.wfile.flush() - del headers['expect'] - - if include_body: - content = self.read_http_body( - headers, - body_size_limit, - method, - None, - True - ) - - return http.Request( - form_in, - method, - scheme, - host, - port, - path, - httpversion, - headers, - content - ) - - - def read_response(self, request_method, body_size_limit, include_body=True): - """ - Returns an http.Response - - By default, both response header and body are read. - If include_body=False is specified, content may be one of the - following: - - None, if the response is technically allowed to have a response body - - "", if the response must not have a response body (e.g. 
it's a - response to a HEAD request) - """ - - line = self.tcp_handler.rfile.readline() - # Possible leftover from previous message - if line == "\r\n" or line == "\n": - line = self.tcp_handler.rfile.readline() - if not line: - raise HttpErrorConnClosed(502, "Server disconnect.") - parts = self.parse_response_line(line) - if not parts: - raise HttpError(502, "Invalid server response: %s" % repr(line)) - proto, code, msg = parts - httpversion = self.parse_http_protocol(proto) - if httpversion is None: - raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto)) - headers = self.read_headers() - if headers is None: - raise HttpError(502, "Invalid headers.") - - if include_body: - content = self.read_http_body( - headers, - body_size_limit, - request_method, - code, - False - ) - else: - # if include_body==False then a None content means the body should be - # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) - - - @classmethod - def request_preamble(self, method, resource, http_major="1", http_minor="1"): - return '%s %s HTTP/%s.%s' % ( - method, resource, http_major, http_minor - ) - - - @classmethod - def response_preamble(self, code, message=None, http_major="1", http_minor="1"): - if message is None: - message = status_codes.RESPONSES.get(code) - return 'HTTP/%s.%s %s %s' % (http_major, http_minor, code, message) diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 936fe20d..8d05b31f 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -181,29 +181,29 @@ def test_expected_http_body_size(): def test_parse_http_protocol(): - assert HTTP1Protocol.parse_http_protocol("HTTP/1.1") == (1, 1) - assert HTTP1Protocol.parse_http_protocol("HTTP/0.0") == (0, 0) - assert not HTTP1Protocol.parse_http_protocol("HTTP/a.1") - assert not HTTP1Protocol.parse_http_protocol("HTTP/1.a") - assert not HTTP1Protocol.parse_http_protocol("foo/0.0") - assert not HTTP1Protocol.parse_http_protocol("HTTP/x") + assert HTTP1Protocol._parse_http_protocol("HTTP/1.1") == (1, 1) + assert HTTP1Protocol._parse_http_protocol("HTTP/0.0") == (0, 0) + assert not HTTP1Protocol._parse_http_protocol("HTTP/a.1") + assert not HTTP1Protocol._parse_http_protocol("HTTP/1.a") + assert not HTTP1Protocol._parse_http_protocol("foo/0.0") + assert not HTTP1Protocol._parse_http_protocol("HTTP/x") def test_parse_init_connect(): - assert HTTP1Protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:444444 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("bogus") - assert not HTTP1Protocol.parse_init_connect("GET host.com:443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:443 foo/1.0") - assert not HTTP1Protocol.parse_init_connect("CONNECT host.com:foo HTTP/1.0") + assert HTTP1Protocol._parse_init_connect("CONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("C\xfeONNECT host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT \0host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:444444 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("bogus") + assert not 
HTTP1Protocol._parse_init_connect("GET host.com:443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com443 HTTP/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:443 foo/1.0") + assert not HTTP1Protocol._parse_init_connect("CONNECT host.com:foo HTTP/1.0") def test_parse_init_proxy(): u = "GET http://foo.com:8888/test HTTP/1.1" - m, s, h, po, pa, httpversion = HTTP1Protocol.parse_init_proxy(u) + m, s, h, po, pa, httpversion = HTTP1Protocol._parse_init_proxy(u) assert m == "GET" assert s == "http" assert h == "foo.com" @@ -212,27 +212,27 @@ def test_parse_init_proxy(): assert httpversion == (1, 1) u = "G\xfeET http://foo.com:8888/test HTTP/1.1" - assert not HTTP1Protocol.parse_init_proxy(u) + assert not HTTP1Protocol._parse_init_proxy(u) - assert not HTTP1Protocol.parse_init_proxy("invalid") - assert not HTTP1Protocol.parse_init_proxy("GET invalid HTTP/1.1") - assert not HTTP1Protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1") + assert not HTTP1Protocol._parse_init_proxy("invalid") + assert not HTTP1Protocol._parse_init_proxy("GET invalid HTTP/1.1") + assert not HTTP1Protocol._parse_init_proxy("GET http://foo.com:8888/test foo/1.1") def test_parse_init_http(): u = "GET /test HTTP/1.1" - m, u, httpversion = HTTP1Protocol.parse_init_http(u) + m, u, httpversion = HTTP1Protocol._parse_init_http(u) assert m == "GET" assert u == "/test" assert httpversion == (1, 1) u = "G\xfeET /test HTTP/1.1" - assert not HTTP1Protocol.parse_init_http(u) + assert not HTTP1Protocol._parse_init_http(u) - assert not HTTP1Protocol.parse_init_http("invalid") - assert not HTTP1Protocol.parse_init_http("GET invalid HTTP/1.1") - assert not HTTP1Protocol.parse_init_http("GET /test foo/1.1") - assert not HTTP1Protocol.parse_init_http("GET /test\xc0 HTTP/1.1") + assert not HTTP1Protocol._parse_init_http("invalid") + assert not HTTP1Protocol._parse_init_http("GET invalid HTTP/1.1") + assert not HTTP1Protocol._parse_init_http("GET /test foo/1.1") + assert not HTTP1Protocol._parse_init_http("GET /test\xc0 HTTP/1.1") class TestReadHeaders: @@ -367,8 +367,8 @@ def test_read_response(): def test_get_request_line(): data = "\nfoo" p = mock_protocol(data) - assert p.get_request_line() == "foo" - assert not p.get_request_line() + assert p._get_request_line() == "foo" + assert not p._get_request_line() class TestReadRequest(): -- cgit v1.2.3 From faf17d3d60e658d0cd1df30a10be4f11035502f8 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 20 Jul 2015 16:33:00 +0200 Subject: http2: make proper use of odict --- netlib/http/http2/protocol.py | 19 +++++++++++-------- netlib/odict.py | 2 -- test/http/http2/test_protocol.py | 8 ++++---- 3 files changed, 15 insertions(+), 14 deletions(-) diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 1dfdda21..55b5ca76 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -2,7 +2,7 @@ from __future__ import (absolute_import, print_function, division) import itertools from hpack.hpack import Encoder, Decoder -from netlib import http, utils +from netlib import http, utils, odict from . 
import frame @@ -189,7 +189,8 @@ class HTTP2Protocol(object): def read_response(self, *args): stream_id, headers, body = self._receive_transmission() - response = http.Response("HTTP/2", headers[':status'], "", headers, body) + status = headers[':status'][0] + response = http.Response("HTTP/2", status, "", headers, body) response.stream_id = stream_id return response @@ -197,11 +198,11 @@ class HTTP2Protocol(object): stream_id, headers, body = self._receive_transmission() form_in = "" - method = headers.get(':method', '') - scheme = headers.get(':scheme', '') - host = headers.get(':host', '') + method = headers.get(':method', [''])[0] + scheme = headers.get(':scheme', [''])[0] + host = headers.get(':host', [''])[0] port = '' # TODO: parse port number? - path = headers.get(':path', '') + path = headers.get(':path', [''])[0] request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) request.stream_id = stream_id @@ -233,15 +234,17 @@ class HTTP2Protocol(object): break # TODO: implement window update & flow - headers = {} + headers = odict.ODictCaseless() for header, value in self.decoder.decode(header_block_fragment): - headers[header] = value + headers.add(header, value) return stream_id, headers, body def create_response(self, code, stream_id=None, headers=None, body=None): if headers is None: headers = [] + if isinstance(headers, odict.ODict): + headers = headers.items() headers = [(b':status', bytes(str(code)))] + headers diff --git a/netlib/odict.py b/netlib/odict.py index ee1e6938..f52acd50 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -20,8 +20,6 @@ class ODict(object): """ def __init__(self, lst=None): - if isinstance(lst, ODict): - lst = lst.items() self.lst = lst or [] def _kconv(self, s): diff --git a/test/http/http2/test_protocol.py b/test/http/http2/test_protocol.py index 34e4ef50..d3040266 100644 --- a/test/http/http2/test_protocol.py +++ b/test/http/http2/test_protocol.py @@ -1,6 +1,6 @@ import OpenSSL -from netlib import tcp +from netlib import tcp, odict from netlib.http import http2 from netlib.http.http2.frame import * from ... 
import tutils, tservers @@ -256,7 +256,7 @@ class TestReadResponse(tservers.ServerTestBase): assert resp.httpversion == "HTTP/2" assert resp.status_code == "200" assert resp.msg == "" - assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'foobar' @@ -282,7 +282,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase): assert resp.httpversion == "HTTP/2" assert resp.status_code == "200" assert resp.msg == "" - assert resp.headers == {':status': '200', 'etag': 'foobar'} + assert resp.headers.lst == [[':status', '200'], ['etag', 'foobar']] assert resp.body == b'' @@ -307,7 +307,7 @@ class TestReadRequest(tservers.ServerTestBase): resp = protocol.read_request() assert resp.stream_id - assert resp.headers == {':method': 'GET', ':path': '/', ':scheme': 'https'} + assert resp.headers.lst == [[u':method', u'GET'], [u':path', u'/'], [u':scheme', u'https']] assert resp.body == b'foobar' -- cgit v1.2.3 From 657973eca3b091cdf07a65f8363affd3d36f0d0f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 22 Jul 2015 13:01:24 +0200 Subject: fix bugs --- netlib/http/http1/protocol.py | 26 +++++++++++++++++--------- netlib/http/semantics.py | 28 +++++++++++----------------- test/http/http1/test_protocol.py | 9 +++------ 3 files changed, 31 insertions(+), 32 deletions(-) diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e7727e00..e46ad7ab 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -9,10 +9,18 @@ from netlib import odict, utils, tcp, http from .. import status_codes from ..exceptions import * +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + class HTTP1Protocol(object): - def __init__(self, tcp_handler): - self.tcp_handler = tcp_handler + def __init__(self, tcp_handler=None, rfile=None, wfile=None): + if tcp_handler: + self.tcp_handler = tcp_handler + else: + self.tcp_handler = TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -31,7 +39,7 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. 
""" - httpversion, host, port, scheme, method, path, headers, content = ( + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) request_line = self._get_request_line() @@ -56,7 +64,7 @@ class HTTP1Protocol(object): 400, "Bad HTTP request line: %s" % repr(request_line) ) - elif method.upper() == 'CONNECT': + elif method == 'CONNECT': form_in = "authority" r = self._parse_init_connect(request_line) if not r: @@ -64,8 +72,8 @@ class HTTP1Protocol(object): 400, "Bad HTTP request line: %s" % repr(request_line) ) - host, port, _ = r - return http.ConnectRequest(host, port) + host, port, httpversion = r + path = None else: form_in = "absolute" r = self._parse_init_proxy(request_line) @@ -81,7 +89,7 @@ class HTTP1Protocol(object): raise HttpError(400, "Invalid headers") expect_header = headers.get_first("expect", "").lower() - if expect_header == "100-continue" and httpversion >= (1, 1): + if expect_header == "100-continue" and httpversion == (1, 1): self.tcp_handler.wfile.write( 'HTTP/1.1 100 Continue\r\n' '\r\n' @@ -90,7 +98,7 @@ class HTTP1Protocol(object): del headers['expect'] if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, method, @@ -107,7 +115,7 @@ class HTTP1Protocol(object): path, httpversion, headers, - content + body ) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 355906dd..9e13edaa 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -5,7 +5,7 @@ import string import sys import urlparse -from .. import utils +from .. import utils, odict class Request(object): @@ -37,6 +37,10 @@ class Request(object): def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) + @property + def content(self): + return self.body + class EmptyRequest(Request): def __init__(self): @@ -47,22 +51,8 @@ class EmptyRequest(Request): host="", port="", path="", - httpversion="", - headers="", - body="", - ) - -class ConnectRequest(Request): - def __init__(self, host, port): - super(ConnectRequest, self).__init__( - form_in="authority", - method="CONNECT", - scheme="", - host=host, - port=port, - path="", - httpversion="", - headers="", + httpversion=(0, 0), + headers=odict.ODictCaseless(), body="", ) @@ -91,6 +81,10 @@ class Response(object): def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) + @property + def content(self): + return self.body + def is_valid_port(port): if not 0 <= port <= 65535: diff --git a/test/http/http1/test_protocol.py b/test/http/http1/test_protocol.py index 8d05b31f..dcebbd5e 100644 --- a/test/http/http1/test_protocol.py +++ b/test/http/http1/test_protocol.py @@ -8,12 +8,9 @@ from ... import tutils, tservers def mock_protocol(data='', chunked=False): - class TCPHandlerMock(object): - pass - tcp_handler = TCPHandlerMock() - tcp_handler.rfile = cStringIO.StringIO(data) - tcp_handler.wfile = cStringIO.StringIO() - return HTTP1Protocol(tcp_handler) + rfile = cStringIO.StringIO(data) + wfile = cStringIO.StringIO() + return HTTP1Protocol(rfile=rfile, wfile=wfile) -- cgit v1.2.3