aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2015-07-29 11:27:43 +0200
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2015-07-29 11:27:43 +0200
commitc7fcc2cca5ff85641febbb908d11d22336bbd81c (patch)
tree5344d505a5a4c771702c6ef5689a04268bb6b30d /netlib
parent827fe824d97d96779512c8a4032d9b30d516d63f (diff)
downloadmitmproxy-c7fcc2cca5ff85641febbb908d11d22336bbd81c.tar.gz
mitmproxy-c7fcc2cca5ff85641febbb908d11d22336bbd81c.tar.bz2
mitmproxy-c7fcc2cca5ff85641febbb908d11d22336bbd81c.zip
add on-the-wire representation methods
Diffstat (limited to 'netlib')
-rw-r--r--netlib/http/http1/protocol.py101
-rw-r--r--netlib/http/http2/protocol.py261
-rw-r--r--netlib/http/semantics.py46
-rw-r--r--netlib/utils.py10
4 files changed, 279 insertions, 139 deletions
diff --git a/netlib/http/http1/protocol.py b/netlib/http/http1/protocol.py
index af9882e8..b098110a 100644
--- a/netlib/http/http1/protocol.py
+++ b/netlib/http/http1/protocol.py
@@ -7,6 +7,7 @@ import urlparse
import time
from netlib import odict, utils, tcp, http
+from netlib.http import semantics
from .. import status_codes
from ..exceptions import *
@@ -15,7 +16,7 @@ class TCPHandler(object):
self.rfile = rfile
self.wfile = wfile
-class HTTP1Protocol(object):
+class HTTP1Protocol(semantics.ProtocolMixin):
def __init__(self, tcp_handler=None, rfile=None, wfile=None):
self.tcp_handler = tcp_handler or TCPHandler(rfile, wfile)
@@ -195,6 +196,32 @@ class HTTP1Protocol(object):
)
+ def assemble_request(self, request):
+ assert isinstance(request, semantics.Request)
+
+ if request.body == semantics.CONTENT_MISSING:
+ raise http.HttpError(
+ 502,
+ "Cannot assemble flow with CONTENT_MISSING"
+ )
+ first_line = self._assemble_request_first_line(request)
+ headers = self._assemble_request_headers(request)
+ return "%s\r\n%s\r\n%s" % (first_line, headers, request.body)
+
+
+ def assemble_response(self, response):
+ assert isinstance(response, semantics.Response)
+
+ if response.body == semantics.CONTENT_MISSING:
+ raise http.HttpError(
+ 502,
+ "Cannot assemble flow with CONTENT_MISSING"
+ )
+ first_line = self._assemble_response_first_line(response)
+ headers = self._assemble_response_headers(response)
+ return "%s\r\n%s\r\n%s" % (first_line, headers, response.body)
+
+
def read_headers(self):
"""
Read a set of headers.
@@ -363,7 +390,6 @@ class HTTP1Protocol(object):
return line
-
def _read_chunked(self, limit, is_request):
"""
Read a chunked HTTP body.
@@ -526,3 +552,74 @@ class HTTP1Protocol(object):
except ValueError:
return None
return (proto, code, msg)
+
+
+ @classmethod
+ def _assemble_request_first_line(self, request):
+ if request.form_in == "relative":
+ request_line = '%s %s HTTP/%s.%s' % (
+ request.method,
+ request.path,
+ request.httpversion[0],
+ request.httpversion[1],
+ )
+ elif request.form_in == "authority":
+ request_line = '%s %s:%s HTTP/%s.%s' % (
+ request.method,
+ request.host,
+ request.port,
+ request.httpversion[0],
+ request.httpversion[1],
+ )
+ elif request.form_in == "absolute":
+ request_line = '%s %s://%s:%s%s HTTP/%s.%s' % (
+ request.method,
+ request.scheme,
+ request.host,
+ request.port,
+ request.path,
+ request.httpversion[0],
+ request.httpversion[1],
+ )
+ else:
+ raise http.HttpError(400, "Invalid request form")
+ return request_line
+
+ def _assemble_request_headers(self, request):
+ headers = request.headers.copy()
+ for k in request._headers_to_strip_off:
+ del headers[k]
+ if 'host' not in headers and request.scheme and request.host and request.port:
+ headers["Host"] = [utils.hostport(request.scheme,
+ request.host,
+ request.port)]
+
+ # If content is defined (i.e. not None or CONTENT_MISSING), we always
+ # add a content-length header.
+ if request.body or request.body == "":
+ headers["Content-Length"] = [str(len(request.body))]
+
+ return headers.format()
+
+
+ def _assemble_response_first_line(self, response):
+ return 'HTTP/%s.%s %s %s' % (
+ response.httpversion[0],
+ response.httpversion[1],
+ response.status_code,
+ response.msg,
+ )
+
+ def _assemble_response_headers(self, response, preserve_transfer_encoding=False):
+ headers = response.headers.copy()
+ for k in response._headers_to_strip_off:
+ del headers[k]
+ if not preserve_transfer_encoding:
+ del headers['Transfer-Encoding']
+
+ # If body is defined (i.e. not None or CONTENT_MISSING), we always
+ # add a content-length header.
+ if response.body or response.body == "":
+ headers["Content-Length"] = [str(len(response.body))]
+
+ return headers.format()
diff --git a/netlib/http/http2/protocol.py b/netlib/http/http2/protocol.py
index 41321fdc..618476e2 100644
--- a/netlib/http/http2/protocol.py
+++ b/netlib/http/http2/protocol.py
@@ -4,6 +4,7 @@ import time
from hpack.hpack import Encoder, Decoder
from netlib import http, utils, odict
+from netlib.http import semantics
from . import frame
@@ -13,7 +14,7 @@ class TCPHandler(object):
self.wfile = wfile
-class HTTP2Protocol(object):
+class HTTP2Protocol(semantics.ProtocolMixin):
ERROR_CODES = utils.BiDi(
NO_ERROR=0x0,
@@ -59,26 +60,104 @@ class HTTP2Protocol(object):
self.current_stream_id = None
self.connection_preface_performed = False
- def check_alpn(self):
- alp = self.tcp_handler.get_alpn_proto_negotiated()
- if alp != self.ALPN_PROTO_H2:
- raise NotImplementedError(
- "HTTP2Protocol can not handle unknown ALP: %s" % alp)
- return True
+ def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
+ timestamp_start = time.time()
+ if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
+ self.tcp_handler.rfile.reset_timestamps()
- def _receive_settings(self, hide=False):
- while True:
- frm = self.read_frame(hide)
- if isinstance(frm, frame.SettingsFrame):
- break
+ stream_id, headers, body = self._receive_transmission(include_body)
- def _read_settings_ack(self, hide=False): # pragma no cover
- while True:
- frm = self.read_frame(hide)
- if isinstance(frm, frame.SettingsFrame):
- assert frm.flags & frame.Frame.FLAG_ACK
- assert len(frm.settings) == 0
- break
+ if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
+ # more accurate timestamp_start
+ timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
+
+ timestamp_end = time.time()
+
+ port = '' # TODO: parse port number?
+
+ request = http.Request(
+ "",
+ headers.get_first(':method', ['']),
+ headers.get_first(':scheme', ['']),
+ headers.get_first(':host', ['']),
+ port,
+ headers.get_first(':path', ['']),
+ (2, 0),
+ headers,
+ body,
+ timestamp_start,
+ timestamp_end,
+ )
+ request.stream_id = stream_id
+
+ return request
+
+ def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
+ timestamp_start = time.time()
+ if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
+ self.tcp_handler.rfile.reset_timestamps()
+
+ stream_id, headers, body = self._receive_transmission(include_body)
+
+ if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
+ # more accurate timestamp_start
+ timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
+
+ if include_body:
+ timestamp_end = time.time()
+ else:
+ timestamp_end = None
+
+ response = http.Response(
+ (2, 0),
+ headers[':status'][0],
+ "",
+ headers,
+ body,
+ timestamp_start=timestamp_start,
+ timestamp_end=timestamp_end,
+ )
+ response.stream_id = stream_id
+
+ return response
+
+ def assemble_request(self, request):
+ assert isinstance(request, semantics.Request)
+
+ authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
+ if self.tcp_handler.address.port != 443:
+ authority += ":%d" % self.tcp_handler.address.port
+
+ headers = [
+ (b':method', bytes(request.method)),
+ (b':path', bytes(request.path)),
+ (b':scheme', b'https'),
+ (b':authority', authority),
+ ] + request.headers.items()
+
+ if hasattr(request, 'stream_id'):
+ stream_id = request.stream_id
+ else:
+ stream_id = self._next_stream_id()
+
+ return list(itertools.chain(
+ self._create_headers(headers, stream_id, end_stream=(request.body is None)),
+ self._create_body(request.body, stream_id)))
+
+ def assemble_response(self, response):
+ assert isinstance(response, semantics.Response)
+
+ headers = [(b':status', bytes(str(response.status_code)))] + response.headers.items()
+
+ if hasattr(response, 'stream_id'):
+ stream_id = response.stream_id
+ else:
+ stream_id = self._next_stream_id()
+
+ return list(itertools.chain(
+ self._create_headers(headers, stream_id, end_stream=(response.body is None)),
+ self._create_body(response.body, stream_id),
+ ))
def perform_server_connection_preface(self, force=False):
if force or not self.connection_preface_performed:
@@ -100,18 +179,6 @@ class HTTP2Protocol(object):
self.send_frame(frame.SettingsFrame(state=self), hide=True)
self._receive_settings(hide=True)
- def next_stream_id(self):
- if self.current_stream_id is None:
- if self.is_server:
- # servers must use even stream ids
- self.current_stream_id = 2
- else:
- # clients must use odd stream ids
- self.current_stream_id = 1
- else:
- self.current_stream_id += 2
- return self.current_stream_id
-
def send_frame(self, frm, hide=False):
raw_bytes = frm.to_bytes()
self.tcp_handler.wfile.write(raw_bytes)
@@ -128,6 +195,39 @@ class HTTP2Protocol(object):
return frm
+ def check_alpn(self):
+ alp = self.tcp_handler.get_alpn_proto_negotiated()
+ if alp != self.ALPN_PROTO_H2:
+ raise NotImplementedError(
+ "HTTP2Protocol can not handle unknown ALP: %s" % alp)
+ return True
+
+ def _receive_settings(self, hide=False):
+ while True:
+ frm = self.read_frame(hide)
+ if isinstance(frm, frame.SettingsFrame):
+ break
+
+ def _read_settings_ack(self, hide=False): # pragma no cover
+ while True:
+ frm = self.read_frame(hide)
+ if isinstance(frm, frame.SettingsFrame):
+ assert frm.flags & frame.Frame.FLAG_ACK
+ assert len(frm.settings) == 0
+ break
+
+ def _next_stream_id(self):
+ if self.current_stream_id is None:
+ if self.is_server:
+ # servers must use even stream ids
+ self.current_stream_id = 2
+ else:
+ # clients must use odd stream ids
+ self.current_stream_id = 1
+ else:
+ self.current_stream_id += 2
+ return self.current_stream_id
+
def _apply_settings(self, settings, hide=False):
for setting, value in settings.items():
old_value = self.http2_settings[setting]
@@ -181,89 +281,6 @@ class HTTP2Protocol(object):
return [frm.to_bytes()]
-
- def create_request(self, method, path, headers=None, body=None):
- if headers is None:
- headers = []
-
- authority = self.tcp_handler.sni if self.tcp_handler.sni else self.tcp_handler.address.host
- if self.tcp_handler.address.port != 443:
- authority += ":%d" % self.tcp_handler.address.port
-
- headers = [
- (b':method', bytes(method)),
- (b':path', bytes(path)),
- (b':scheme', b'https'),
- (b':authority', authority),
- ] + headers
-
- stream_id = self.next_stream_id()
-
- return list(itertools.chain(
- self._create_headers(headers, stream_id, end_stream=(body is None)),
- self._create_body(body, stream_id)))
-
- def read_response(self, request_method_='', body_size_limit_=None, include_body=True):
- timestamp_start = time.time()
- if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
- self.tcp_handler.rfile.reset_timestamps()
-
- stream_id, headers, body = self._receive_transmission(include_body)
-
- if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
- # more accurate timestamp_start
- timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
-
- if include_body:
- timestamp_end = time.time()
- else:
- timestamp_end = None
-
- response = http.Response(
- (2, 0),
- headers[':status'][0],
- "",
- headers,
- body,
- timestamp_start=timestamp_start,
- timestamp_end=timestamp_end,
- )
- response.stream_id = stream_id
-
- return response
-
- def read_request(self, include_body=True, body_size_limit_=None, allow_empty_=False):
- timestamp_start = time.time()
- if hasattr(self.tcp_handler.rfile, "reset_timestamps"):
- self.tcp_handler.rfile.reset_timestamps()
-
- stream_id, headers, body = self._receive_transmission(include_body)
-
- if hasattr(self.tcp_handler.rfile, "first_byte_timestamp"):
- # more accurate timestamp_start
- timestamp_start = self.tcp_handler.rfile.first_byte_timestamp
-
- timestamp_end = time.time()
-
- port = '' # TODO: parse port number?
-
- request = http.Request(
- "",
- headers.get_first(':method', ['']),
- headers.get_first(':scheme', ['']),
- headers.get_first(':host', ['']),
- port,
- headers.get_first(':path', ['']),
- (2, 0),
- headers,
- body,
- timestamp_start,
- timestamp_end,
- )
- request.stream_id = stream_id
-
- return request
-
def _receive_transmission(self, include_body=True):
body_expected = True
@@ -295,19 +312,3 @@ class HTTP2Protocol(object):
headers.add(header, value)
return stream_id, headers, body
-
- def create_response(self, code, stream_id=None, headers=None, body=None):
- if headers is None:
- headers = []
- if isinstance(headers, odict.ODict):
- headers = headers.items()
-
- headers = [(b':status', bytes(str(code)))] + headers
-
- if not stream_id:
- stream_id = self.next_stream_id()
-
- return list(itertools.chain(
- self._create_headers(headers, stream_id, end_stream=(body is None)),
- self._create_body(body, stream_id),
- ))
diff --git a/netlib/http/semantics.py b/netlib/http/semantics.py
index 63b6beb9..54bf83d2 100644
--- a/netlib/http/semantics.py
+++ b/netlib/http/semantics.py
@@ -7,6 +7,32 @@ import urlparse
from .. import utils, odict
+CONTENT_MISSING = 0
+
+
+class ProtocolMixin(object):
+
+ def read_request(self):
+ raise NotImplemented
+
+ def read_response(self):
+ raise NotImplemented
+
+ def assemble(self, message):
+ if isinstance(message, Request):
+ return self.assemble_request(message)
+ elif isinstance(message, Response):
+ return self.assemble_response(message)
+ else:
+ raise ValueError("HTTP message not supported.")
+
+ def assemble_request(self, request):
+ raise NotImplemented
+
+ def assemble_response(self, response):
+ raise NotImplemented
+
+
class Request(object):
def __init__(
@@ -18,12 +44,14 @@ class Request(object):
port,
path,
httpversion,
- headers,
- body,
+ headers=None,
+ body=None,
timestamp_start=None,
timestamp_end=None,
):
- assert isinstance(headers, odict.ODictCaseless) or not headers
+ if not headers:
+ headers = odict.ODictCaseless()
+ assert isinstance(headers, odict.ODictCaseless)
self.form_in = form_in
self.method = method
@@ -37,6 +65,7 @@ class Request(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
+
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
@@ -80,14 +109,16 @@ class Response(object):
self,
httpversion,
status_code,
- msg,
- headers,
- body,
+ msg=None,
+ headers=None,
+ body=None,
sslinfo=None,
timestamp_start=None,
timestamp_end=None,
):
- assert isinstance(headers, odict.ODictCaseless) or not headers
+ if not headers:
+ headers = odict.ODictCaseless()
+ assert isinstance(headers, odict.ODictCaseless)
self.httpversion = httpversion
self.status_code = status_code
@@ -98,6 +129,7 @@ class Response(object):
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
+
def __eq__(self, other):
try:
self_d = [self.__dict__[k] for k in self.__dict__ if k not in ('timestamp_start', 'timestamp_end')]
diff --git a/netlib/utils.py b/netlib/utils.py
index bee412f9..86e33f33 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -129,3 +129,13 @@ class Data(object):
if not os.path.exists(fullpath):
raise ValueError("dataPath: %s does not exist." % fullpath)
return fullpath
+
+
+def hostport(scheme, host, port):
+ """
+ Returns the host component, with a port specifcation if needed.
+ """
+ if (port, scheme) in [(80, "http"), (443, "https")]:
+ return host
+ else:
+ return "%s:%s" % (host, port)