From 827fe824d97d96779512c8a4032d9b30d516d63f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Mon, 27 Jul 2015 09:36:50 +0200 Subject: move code from mitmproxy to netlib --- netlib/http/http1/protocol.py | 52 +++++++++++++++++++----- netlib/http/http2/protocol.py | 92 ++++++++++++++++++++++++++++++++++--------- netlib/http/semantics.py | 49 ++++++++++++++++++++++- 3 files changed, 163 insertions(+), 30 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index e46ad7ab..af9882e8 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -4,6 +4,7 @@ import collections import string import sys import urlparse +import time from netlib import odict, utils, tcp, http from .. import status_codes @@ -17,10 +18,7 @@ class TCPHandler(object): class HTTP1Protocol(object): def __init__(self, tcp_handler=None, rfile=None, wfile=None): - if tcp_handler: - self.tcp_handler = tcp_handler - else: - self.tcp_handler = TCPHandler(rfile, wfile) + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): @@ -39,6 +37,10 @@ class HTTP1Protocol(object): Raises: HttpError: If the input is invalid. """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + httpversion, host, port, scheme, method, path, headers, body = ( None, None, None, None, None, None, None, None) @@ -106,6 +108,12 @@ class HTTP1Protocol(object): True ) + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + return http.Request( form_in, method, @@ -115,7 +123,9 @@ class HTTP1Protocol(object): path, httpversion, headers, - body + body, + timestamp_start, + timestamp_end, ) @@ -124,12 +134,15 @@ class HTTP1Protocol(object): Returns an http.Response By default, both response header and body are read. - If include_body=False is specified, content may be one of the + If include_body=False is specified, body may be one of the following: - None, if the response is technically allowed to have a response body - "", if the response must not have a response body (e.g. it's a response to a HEAD request) """ + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() line = self.tcp_handler.rfile.readline() # Possible leftover from previous message @@ -149,7 +162,7 @@ class HTTP1Protocol(object): raise HttpError(502, "Invalid headers.") if include_body: - content = self.read_http_body( + body = self.read_http_body( headers, body_size_limit, request_method, @@ -157,10 +170,29 @@ class HTTP1Protocol(object): False ) else: - # if include_body==False then a None content means the body should be + # if include_body==False then a None body means the body should be # read separately - content = None - return http.Response(httpversion, code, msg, headers, content) + body = None + + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + return http.Response( + httpversion, + code, + msg, + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) def read_headers(self): diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 55b5ca76..41321fdc 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -1,11 +1,18 @@ from __future__ import (absolute_import, print_function, division) import itertools +import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict from . import frame +class TCPHandler(object): + def __init__(self, rfile, wfile=None): + self.rfile = rfile + self.wfile = wfile + + class HTTP2Protocol(object): ERROR_CODES = utils.BiDi( @@ -31,16 +38,26 @@ class HTTP2Protocol(object): ALPN_PROTO_H2 = 'h2' - def __init__(self, tcp_handler, is_server=False, dump_frames=False): - self.tcp_handler = tcp_handler + + def __init__( + self, + tcp_handler=None, + rfile=None, + wfile=None, + is_server=False, + dump_frames=False, + encoder=None, + decoder=None, + ): + self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) self.is_server = is_server + self.dump_frames = dump_frames + self.encoder = encoder or Encoder() + self.decoder = decoder or Decoder() self.http2_settings = frame.HTTP2_DEFAULT_SETTINGS.copy() self.current_stream_id = None - self.encoder = Encoder() - self.decoder = Decoder() self.connection_preface_performed = False - self.dump_frames = dump_frames def check_alpn(self): alp = self.tcp_handler.get_alpn_proto_negotiated() @@ -186,29 +203,68 @@ class HTTP2Protocol(object): self._create_headers(headers, stream_id, end_stream=(body is None)), self._create_body(body, stream_id))) - def read_response(self, *args): - stream_id, headers, body = self._receive_transmission() + def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - status = headers[':status'][0] - response = http.Response("HTTP/2", status, "", headers, body) + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + headers[':status'][0], + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) response.stream_id = stream_id + return response - def read_request(self): - stream_id, headers, body = self._receive_transmission() + def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() - form_in = "" - method = headers.get(':method', [''])[0] - scheme = headers.get(':scheme', [''])[0] - host = headers.get(':host', [''])[0] port = '' # TODO: parse port number? - path = headers.get(':path', [''])[0] - request = http.Request(form_in, method, scheme, host, port, path, "HTTP/2", headers, body) + request = http.Request( + "", + headers.get_first(':method', ['']), + headers.get_first(':scheme', ['']), + headers.get_first(':host', ['']), + port, + headers.get_first(':path', ['']), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) request.stream_id = stream_id + return request - def _receive_transmission(self): + def _receive_transmission(self, include_body=True): body_expected = True stream_id = 0 diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 9e13edaa..63b6beb9 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -20,7 +20,11 @@ class Request(object): httpversion, headers, body, + timestamp_start=None, + timestamp_end=None, ): + assert isinstance(headers, odict.ODictCaseless) or not headers + self.form_in = form_in self.method = method self.scheme = scheme @@ -30,17 +34,30 @@ class Request(object): self.httpversion = httpversion self.headers = headers self.body = body + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Request(%s - %s, %s)" % (self.method, self.host, self.path) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + class EmptyRequest(Request): def __init__(self): @@ -67,24 +84,52 @@ class Response(object): headers, body, sslinfo=None, + timestamp_start=None, + timestamp_end=None, ): + assert isinstance(headers, odict.ODictCaseless) or not headers + self.httpversion = httpversion self.status_code = status_code self.msg = msg self.headers = headers self.body = body self.sslinfo = sslinfo + self.timestamp_start = timestamp_start + self.timestamp_end = timestamp_end def __eq__(self, other): - return self.__dict__ == other.__dict__ + try: + self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + other_d = [other.__dict__[k] for k in other.__dict__ if k not in ('timestamp_start', 'timestamp_end')] + return self_d == other_d + except: + return False def __repr__(self): return "Response(%s - %s)" % (self.status_code, self.msg) @property def content(self): + # TODO: remove deprecated getter return self.body + @content.setter + def content(self, content): + # TODO: remove deprecated setter + self.body = content + + @property + def code(self): + # TODO: remove deprecated getter + return self.status_code + + @code.setter + def code(self, code): + # TODO: remove deprecated setter + self.status_code = code + + def is_valid_port(port): if not 0 <= port <= 65535: -- cgit v1.2.3 From c7fcc2cca5ff85641febbb908d11d22336bbd81c Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Wed, 29 Jul 2015 11:27:43 +0200 Subject: add on-the-wire representation methods --- netlib/http/http1/protocol.py | 101 +++++++++++++++- netlib/http/http2/protocol.py | 261 +++++++++++++++++++++--------------------- netlib/http/semantics.py | 46 ++++++-- netlib/utils.py | 10 ++ 4 files changed, 279 insertions(+), 139 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py index af9882e8..b098110a 100644 --- a/netlib/http/http1/protocol.py +++ b/netlib/http/http1/protocol.py @@ -7,6 +7,7 @@ import urlparse import time from netlib import odict, utils, tcp, http +from netlib.http import semantics from .. import status_codes from ..exceptions import * @@ -15,7 +16,7 @@ class TCPHandler(object): self.rfile = rfile self.wfile = wfile -class HTTP1Protocol(object): +class HTTP1Protocol(semantics.ProtocolMixin): def __init__(self, tcp_handler=None, rfile=None, wfile=None): self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile) @@ -195,6 +196,32 @@ class HTTP1Protocol(object): ) + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + if request.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_request_first_line(request) + headers = self._assemble_request_headers(request) + return "%s\r\n%s\r\n%s" % (first_line, headers, request.body) + + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + if response.body == semantics.CONTENT_MISSING: + raise http.HttpError( + 502, + "Cannot assemble flow with CONTENT_MISSING" + ) + first_line = self._assemble_response_first_line(response) + headers = self._assemble_response_headers(response) + return "%s\r\n%s\r\n%s" % (first_line, headers, response.body) + + def read_headers(self): """ Read a set of headers. @@ -363,7 +390,6 @@ class HTTP1Protocol(object): return line - def _read_chunked(self, limit, is_request): """ Read a chunked HTTP body. @@ -526,3 +552,74 @@ class HTTP1Protocol(object): except ValueError: return None return (proto, code, msg) + + + @classmethod + def _assemble_request_first_line(self, request): + if request.form_in == "relative": + request_line = '%s %s HTTP/%s.%s' % ( + request.method, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "authority": + request_line = '%s %s:%s HTTP/%s.%s' % ( + request.method, + request.host, + request.port, + request.httpversion[0], + request.httpversion[1], + ) + elif request.form_in == "absolute": + request_line = '%s %s://%s:%s%s HTTP/%s.%s' % ( + request.method, + request.scheme, + request.host, + request.port, + request.path, + request.httpversion[0], + request.httpversion[1], + ) + else: + raise http.HttpError(400, "Invalid request form") + return request_line + + def _assemble_request_headers(self, request): + headers = request.headers.copy() + for k in request._headers_to_strip_off: + del headers[k] + if 'host' not in headers and request.scheme and request.host and request.port: + headers["Host"] = [utils.hostport(request.scheme, + request.host, + request.port)] + + # If content is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if request.body or request.body == "": + headers["Content-Length"] = [str(len(request.body))] + + return headers.format() + + + def _assemble_response_first_line(self, response): + return 'HTTP/%s.%s %s %s' % ( + response.httpversion[0], + response.httpversion[1], + response.status_code, + response.msg, + ) + + def _assemble_response_headers(self, response, preserve_transfer_encoding=False): + headers = response.headers.copy() + for k in response._headers_to_strip_off: + del headers[k] + if not preserve_transfer_encoding: + del headers['Transfer-Encoding'] + + # If body is defined (i.e. not None or CONTENT_MISSING), we always + # add a content-length header. + if response.body or response.body == "": + headers["Content-Length"] = [str(len(response.body))] + + return headers.format() diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 41321fdc..618476e2 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -4,6 +4,7 @@ import time from hpack.hpack import Encoder, Decoder from netlib import http, utils, odict +from netlib.http import semantics from . import frame @@ -13,7 +14,7 @@ class TCPHandler(object): self.wfile = wfile -class HTTP2Protocol(object): +class HTTP2Protocol(semantics.ProtocolMixin): ERROR_CODES = utils.BiDi( NO_ERROR=0x0, @@ -59,26 +60,104 @@ class HTTP2Protocol(object): self.current_stream_id = None self.connection_preface_performed = False - def check_alpn(self): - alp = self.tcp_handler.get_alpn_proto_negotiated() - if alp != self.ALPN_PROTO_H2: - raise NotImplementedError( - "HTTP2Protocol can not handle unknown ALP: %s" % alp) - return True + def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() - def _receive_settings(self, hide=False): - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - break + stream_id, headers, body = self._receive_transmission(include_body) - def _read_settings_ack(self, hide=False): # pragma no cover - while True: - frm = self.read_frame(hide) - if isinstance(frm, frame.SettingsFrame): - assert frm.flags & frame.Frame.FLAG_ACK - assert len(frm.settings) == 0 - break + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + timestamp_end = time.time() + + port = '' # TODO: parse port number? + + request = http.Request( + "", + headers.get_first(':method', ['']), + headers.get_first(':scheme', ['']), + headers.get_first(':host', ['']), + port, + headers.get_first(':path', ['']), + (2, 0), + headers, + body, + timestamp_start, + timestamp_end, + ) + request.stream_id = stream_id + + return request + + def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + timestamp_start = time.time() + if hasattr(self.tcp_handler.rfile, "reset_timestamps"): + self.tcp_handler.rfile.reset_timestamps() + + stream_id, headers, body = self._receive_transmission(include_body) + + if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): + # more accurate timestamp_start + timestamp_start = self.tcp_handler.rfile.first_byte_timestamp + + if include_body: + timestamp_end = time.time() + else: + timestamp_end = None + + response = http.Response( + (2, 0), + headers[':status'][0], + "", + headers, + body, + timestamp_start=timestamp_start, + timestamp_end=timestamp_end, + ) + response.stream_id = stream_id + + return response + + def assemble_request(self, request): + assert isinstance(request, semantics.Request) + + authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host + if self.tcp_handler.address.port != 443: + authority += ":%d" % self.tcp_handler.address.port + + headers = [ + (b':method', bytes(request.method)), + (b':path', bytes(request.path)), + (b':scheme', b'https'), + (b':authority', authority), + ] + request.headers.items() + + if hasattr(request, 'stream_id'): + stream_id = request.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(request.body is None)), + self._create_body(request.body, stream_id))) + + def assemble_response(self, response): + assert isinstance(response, semantics.Response) + + headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items() + + if hasattr(response, 'stream_id'): + stream_id = response.stream_id + else: + stream_id = self._next_stream_id() + + return list(itertools.chain( + self._create_headers(headers, stream_id, end_stream=(response.body is None)), + self._create_body(response.body, stream_id), + )) def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: @@ -100,18 +179,6 @@ class HTTP2Protocol(object): self.send_frame(frame.SettingsFrame(state=self), hide=True) self._receive_settings(hide=True) - def next_stream_id(self): - if self.current_stream_id is None: - if self.is_server: - # servers must use even stream ids - self.current_stream_id = 2 - else: - # clients must use odd stream ids - self.current_stream_id = 1 - else: - self.current_stream_id += 2 - return self.current_stream_id - def send_frame(self, frm, hide=False): raw_bytes = frm.to_bytes() self.tcp_handler.wfile.write(raw_bytes) @@ -128,6 +195,39 @@ class HTTP2Protocol(object): return frm + def check_alpn(self): + alp = self.tcp_handler.get_alpn_proto_negotiated() + if alp != self.ALPN_PROTO_H2: + raise NotImplementedError( + "HTTP2Protocol can not handle unknown ALP: %s" % alp) + return True + + def _receive_settings(self, hide=False): + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + break + + def _read_settings_ack(self, hide=False): # pragma no cover + while True: + frm = self.read_frame(hide) + if isinstance(frm, frame.SettingsFrame): + assert frm.flags & frame.Frame.FLAG_ACK + assert len(frm.settings) == 0 + break + + def _next_stream_id(self): + if self.current_stream_id is None: + if self.is_server: + # servers must use even stream ids + self.current_stream_id = 2 + else: + # clients must use odd stream ids + self.current_stream_id = 1 + else: + self.current_stream_id += 2 + return self.current_stream_id + def _apply_settings(self, settings, hide=False): for setting, value in settings.items(): old_value = self.http2_settings[setting] @@ -181,89 +281,6 @@ class HTTP2Protocol(object): return [frm.to_bytes()] - - def create_request(self, method, path, headers=None, body=None): - if headers is None: - headers = [] - - authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host - if self.tcp_handler.address.port != 443: - authority += ":%d" % self.tcp_handler.address.port - - headers = [ - (b':method', bytes(method)), - (b':path', bytes(path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + headers - - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id))) - - def read_response(self, request_method_='', body_size_limit_=None, include_body=True): - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission(include_body) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - if include_body: - timestamp_end = time.time() - else: - timestamp_end = None - - response = http.Response( - (2, 0), - headers[':status'][0], - "", - headers, - body, - timestamp_start=timestamp_start, - timestamp_end=timestamp_end, - ) - response.stream_id = stream_id - - return response - - def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): - timestamp_start = time.time() - if hasattr(self.tcp_handler.rfile, "reset_timestamps"): - self.tcp_handler.rfile.reset_timestamps() - - stream_id, headers, body = self._receive_transmission(include_body) - - if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"): - # more accurate timestamp_start - timestamp_start = self.tcp_handler.rfile.first_byte_timestamp - - timestamp_end = time.time() - - port = '' # TODO: parse port number? - - request = http.Request( - "", - headers.get_first(':method', ['']), - headers.get_first(':scheme', ['']), - headers.get_first(':host', ['']), - port, - headers.get_first(':path', ['']), - (2, 0), - headers, - body, - timestamp_start, - timestamp_end, - ) - request.stream_id = stream_id - - return request - def _receive_transmission(self, include_body=True): body_expected = True @@ -295,19 +312,3 @@ class HTTP2Protocol(object): headers.add(header, value) return stream_id, headers, body - - def create_response(self, code, stream_id=None, headers=None, body=None): - if headers is None: - headers = [] - if isinstance(headers, odict.ODict): - headers = headers.items() - - headers = [(b':status', bytes(str(code)))] + headers - - if not stream_id: - stream_id = self.next_stream_id() - - return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(body is None)), - self._create_body(body, stream_id), - )) diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py index 63b6beb9..54bf83d2 100644 --- a/netlib/http/semantics.py +++ b/netlib/http/semantics.py @@ -7,6 +7,32 @@ import urlparse from .. import utils, odict +CONTENT_MISSING = 0 + + +class ProtocolMixin(object): + + def read_request(self): + raise NotImplemented + + def read_response(self): + raise NotImplemented + + def assemble(self, message): + if isinstance(message, Request): + return self.assemble_request(message) + elif isinstance(message, Response): + return self.assemble_response(message) + else: + raise ValueError("HTTP message not supported.") + + def assemble_request(self, request): + raise NotImplemented + + def assemble_response(self, response): + raise NotImplemented + + class Request(object): def __init__( @@ -18,12 +44,14 @@ class Request(object): port, path, httpversion, - headers, - body, + headers=None, + body=None, timestamp_start=None, timestamp_end=None, ): - assert isinstance(headers, odict.ODictCaseless) or not headers + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) self.form_in = form_in self.method = method @@ -37,6 +65,7 @@ class Request(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] @@ -80,14 +109,16 @@ class Response(object): self, httpversion, status_code, - msg, - headers, - body, + msg=None, + headers=None, + body=None, sslinfo=None, timestamp_start=None, timestamp_end=None, ): - assert isinstance(headers, odict.ODictCaseless) or not headers + if not headers: + headers = odict.ODictCaseless() + assert isinstance(headers, odict.ODictCaseless) self.httpversion = httpversion self.status_code = status_code @@ -98,6 +129,7 @@ class Response(object): self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end + def __eq__(self, other): try: self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')] diff --git a/netlib/utils.py b/netlib/utils.py index bee412f9..86e33f33 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -129,3 +129,13 @@ class Data(object): if not os.path.exists(fullpath): raise ValueError("dataPath: %s does not exist." % fullpath) return fullpath + + +def hostport(scheme, host, port): + """ + Returns the host component, with a port specifcation if needed. + """ + if (port, scheme) in [(80, "http"), (443, "https")]: + return host + else: + return "%s:%s" % (host, port) -- cgit v1.2.3 From 7b10817670b30550dd45af48491ed8cf3cacd5e6 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 30 Jul 2015 13:52:13 +0200 Subject: http2: improve protocol --- netlib/http/http2/protocol.py | 61 +++++++++++++++++++++++++++++-------------- netlib/odict.py | 7 +++-- 2 files changed, 46 insertions(+), 22 deletions(-) (limited to 'netlib') diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py index 618476e2..a1ca4a18 100644 --- a/netlib/http/http2/protocol.py +++ b/netlib/http/http2/protocol.py @@ -60,7 +60,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): self.current_stream_id = None self.connection_preface_performed = False - def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False): + def read_request(self, include_body=True, body_size_limit=None, allow_empty=False): + self.perform_connection_preface() + timestamp_start = time.time() if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() @@ -73,15 +75,13 @@ class HTTP2Protocol(semantics.ProtocolMixin): timestamp_end = time.time() - port = '' # TODO: parse port number? - request = http.Request( - "", - headers.get_first(':method', ['']), - headers.get_first(':scheme', ['']), - headers.get_first(':host', ['']), - port, - headers.get_first(':path', ['']), + "relative", # TODO: use the correct value + headers.get_first(':method', 'GET'), + headers.get_first(':scheme', 'https'), + headers.get_first(':host', 'localhost'), + 443, # TODO: parse port number from host? + headers.get_first(':path', '/'), (2, 0), headers, body, @@ -92,7 +92,9 @@ class HTTP2Protocol(semantics.ProtocolMixin): return request - def read_response(self, request_method_='', body_size_limit_=None, include_body=True): + def read_response(self, request_method='', body_size_limit=None, include_body=True): + self.perform_connection_preface() + timestamp_start = time.time() if hasattr(self.tcp_handler.rfile, "reset_timestamps"): self.tcp_handler.rfile.reset_timestamps() @@ -110,7 +112,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): response = http.Response( (2, 0), - headers[':status'][0], + int(headers.get_first(':status')), "", headers, body, @@ -121,6 +123,7 @@ class HTTP2Protocol(semantics.ProtocolMixin): return response + def assemble_request(self, request): assert isinstance(request, semantics.Request) @@ -128,12 +131,18 @@ class HTTP2Protocol(semantics.ProtocolMixin): if self.tcp_handler.address.port != 443: authority += ":%d" % self.tcp_handler.address.port - headers = [ - (b':method', bytes(request.method)), - (b':path', bytes(request.path)), - (b':scheme', b'https'), - (b':authority', authority), - ] + request.headers.items() + headers = request.headers.copy() + + if not ':authority' in headers.keys(): + headers.add(':authority', bytes(authority), prepend=True) + if not ':scheme' in headers.keys(): + headers.add(':scheme', bytes(request.scheme), prepend=True) + if not ':path' in headers.keys(): + headers.add(':path', bytes(request.path), prepend=True) + if not ':method' in headers.keys(): + headers.add(':method', bytes(request.method), prepend=True) + + headers = headers.items() if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -141,13 +150,18 @@ class HTTP2Protocol(semantics.ProtocolMixin): stream_id = self._next_stream_id() return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(request.body is None)), + self._create_headers(headers, stream_id, end_stream=(request.body is None or len(request.body) == 0)), self._create_body(request.body, stream_id))) def assemble_response(self, response): assert isinstance(response, semantics.Response) - headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items() + headers = response.headers.copy() + + if not ':status' in headers.keys(): + headers.add(':status', bytes(str(response.status_code)), prepend=True) + + headers = headers.items() if hasattr(response, 'stream_id'): stream_id = response.stream_id @@ -155,10 +169,17 @@ class HTTP2Protocol(semantics.ProtocolMixin): stream_id = self._next_stream_id() return list(itertools.chain( - self._create_headers(headers, stream_id, end_stream=(response.body is None)), + self._create_headers(headers, stream_id, end_stream=(response.body is None or len(response.body) == 0)), self._create_body(response.body, stream_id), )) + def perform_connection_preface(self, force=False): + if force or not self.connection_preface_performed: + if self.is_server: + self.perform_server_connection_preface(force) + else: + self.perform_client_connection_preface(force) + def perform_server_connection_preface(self, force=False): if force or not self.connection_preface_performed: self.connection_preface_performed = True diff --git a/netlib/odict.py b/netlib/odict.py index f52acd50..d02de08d 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -96,8 +96,11 @@ class ODict(object): return True return False - def add(self, key, value): - self.lst.append([key, value]) + def add(self, key, value, prepend=False): + if prepend: + self.lst.insert(0, [key, value]) + else: + self.lst.append([key, value]) def get(self, k, d=None): if k in self: -- cgit v1.2.3