diff options
| author | Maximilian Hils <git@maximilianhils.com> | 2015-09-17 15:16:12 +0200 | 
|---|---|---|
| committer | Maximilian Hils <git@maximilianhils.com> | 2015-09-17 15:16:12 +0200 | 
| commit | 8d71059d77c2dd1d9858d7971dd0b6b4387ed9f4 (patch) | |
| tree | 831f47cfd19e7d58c0f31b0a924832d421d4eb52 /netlib | |
| parent | a07e43df8b3988f137b48957f978ad570d9dc782 (diff) | |
| download | mitmproxy-8d71059d77c2dd1d9858d7971dd0b6b4387ed9f4.tar.gz mitmproxy-8d71059d77c2dd1d9858d7971dd0b6b4387ed9f4.tar.bz2 mitmproxy-8d71059d77c2dd1d9858d7971dd0b6b4387ed9f4.zip | |
clean up http message models
Diffstat (limited to 'netlib')
| -rw-r--r-- | netlib/http/http1/assemble.py | 8 | ||||
| -rw-r--r-- | netlib/http/models.py | 159 | ||||
| -rw-r--r-- | netlib/tutils.py | 4 | ||||
| -rw-r--r-- | netlib/utils.py | 30 | ||||
| -rw-r--r-- | netlib/websockets/frame.py | 9 | ||||
| -rw-r--r-- | netlib/websockets/protocol.py | 3 | 
6 files changed, 74 insertions, 139 deletions
| diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index 7252c446..b65a6be0 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -50,14 +50,14 @@ def _assemble_request_line(request, form=None):          return b"%s %s %s" % (              request.method,              request.path, -            request.httpversion +            request.http_version          )      elif form == "authority":          return b"%s %s:%d %s" % (              request.method,              request.host,              request.port, -            request.httpversion +            request.http_version          )      elif form == "absolute":          return b"%s %s://%s:%d%s %s" % ( @@ -66,7 +66,7 @@ def _assemble_request_line(request, form=None):              request.host,              request.port,              request.path, -            request.httpversion +            request.http_version          )      else:  # pragma: nocover          raise RuntimeError("Invalid request form") @@ -93,7 +93,7 @@ def _assemble_request_headers(request):  def _assemble_response_line(response):      return b"%s %d %s" % ( -        response.httpversion, +        response.http_version,          response.status_code,          response.msg,      ) diff --git a/netlib/http/models.py b/netlib/http/models.py index b4446ecb..54b8b112 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -193,15 +193,45 @@ class Headers(MutableMapping, object):          return cls([list(field) for field in state]) -class Request(object): +class Message(object): +    def __init__(self, http_version, headers, body, timestamp_start, timestamp_end): +        self.http_version = http_version +        if not headers: +            headers = Headers() +        assert isinstance(headers, Headers) +        self.headers = headers + +        self._body = body +        self.timestamp_start = timestamp_start +        self.timestamp_end = timestamp_end + +    @property +    def body(self): +        return self._body + +    @body.setter +    def body(self, body): +        self._body = body +        if isinstance(body, bytes): +            self.headers[b"Content-Length"] = str(len(body)).encode() + +    content = body + +    def __eq__(self, other): +        if isinstance(other, Message): +            return self.__dict__ == other.__dict__ +        return False + + +class Request(Message):      # This list is adopted legacy code.      # We probably don't need to strip off keep-alive.      _headers_to_strip_off = [ -        'Proxy-Connection', -        'Keep-Alive', -        'Connection', -        'Transfer-Encoding', -        'Upgrade', +        b'Proxy-Connection', +        b'Keep-Alive', +        b'Connection', +        b'Transfer-Encoding', +        b'Upgrade',      ]      def __init__( @@ -212,16 +242,14 @@ class Request(object):              host,              port,              path, -            httpversion, +            http_version,              headers=None,              body=None,              timestamp_start=None,              timestamp_end=None,              form_out=None      ): -        if not headers: -            headers = Headers() -        assert isinstance(headers, Headers) +        super(Request, self).__init__(http_version, headers, body, timestamp_start, timestamp_end)          self.form_in = form_in          self.method = method @@ -229,23 +257,8 @@ class Request(object):          self.host = host          self.port = port          self.path = path -        self.httpversion = httpversion -        self.headers = headers -        self._body = body -        self.timestamp_start = timestamp_start -        self.timestamp_end = timestamp_end          self.form_out = form_out or form_in -    def __eq__(self, other): -        try: -            self_d = [self.__dict__[k] for k in self.__dict__ if -                      k not in ('timestamp_start', 'timestamp_end')] -            other_d = [other.__dict__[k] for k in other.__dict__ if -                       k not in ('timestamp_start', 'timestamp_end')] -            return self_d == other_d -        except: -            return False -      def __repr__(self):          if self.host and self.port:              hostport = "{}:{}".format(self.host, self.port) @@ -262,8 +275,8 @@ class Request(object):              response. That is, we remove ETags and If-Modified-Since headers.          """          delheaders = [ -            "if-modified-since", -            "if-none-match", +            b"if-modified-since", +            b"if-none-match",          ]          for i in delheaders:              self.headers.pop(i, None) @@ -273,16 +286,16 @@ class Request(object):              Modifies this request to remove headers that will compress the              resource's data.          """ -        self.headers["accept-encoding"] = "identity" +        self.headers[b"accept-encoding"] = b"identity"      def constrain_encoding(self):          """              Limits the permissible Accept-Encoding values, based on what we can              decode appropriately.          """ -        accept_encoding = self.headers.get("accept-encoding") +        accept_encoding = self.headers.get(b"accept-encoding")          if accept_encoding: -            self.headers["accept-encoding"] = ( +            self.headers[b"accept-encoding"] = (                  ', '.join(                      e                      for e in encoding.ENCODINGS @@ -335,7 +348,7 @@ class Request(object):          """          # FIXME: If there's an existing content-type header indicating a          # url-encoded form, leave it alone. -        self.headers["Content-Type"] = HDR_FORM_URLENCODED +        self.headers[b"Content-Type"] = HDR_FORM_URLENCODED          self.body = utils.urlencode(odict.lst)      def get_path_components(self): @@ -452,37 +465,17 @@ class Request(object):              raise ValueError("Invalid URL: %s" % url)          self.scheme, self.host, self.port, self.path = parts -    @property -    def body(self): -        return self._body - -    @body.setter -    def body(self, body): -        self._body = body -        if isinstance(body, bytes): -            self.headers["Content-Length"] = str(len(body)).encode() - -    @property -    def content(self):  # pragma: no cover -        # TODO: remove deprecated getter -        return self.body - -    @content.setter -    def content(self, content):  # pragma: no cover -        # TODO: remove deprecated setter -        self.body = content - -class Response(object): +class Response(Message):      _headers_to_strip_off = [ -        'Proxy-Connection', -        'Alternate-Protocol', -        'Alt-Svc', +        b'Proxy-Connection', +        b'Alternate-Protocol', +        b'Alt-Svc',      ]      def __init__(              self, -            httpversion, +            http_version,              status_code,              msg=None,              headers=None, @@ -490,27 +483,9 @@ class Response(object):              timestamp_start=None,              timestamp_end=None,      ): -        if not headers: -            headers = Headers() -        assert isinstance(headers, Headers) - -        self.httpversion = httpversion +        super(Response, self).__init__(http_version, headers, body, timestamp_start, timestamp_end)          self.status_code = status_code          self.msg = msg -        self.headers = headers -        self._body = body -        self.timestamp_start = timestamp_start -        self.timestamp_end = timestamp_end - -    def __eq__(self, other): -        try: -            self_d = [self.__dict__[k] for k in self.__dict__ if -                      k not in ('timestamp_start', 'timestamp_end')] -            other_d = [other.__dict__[k] for k in other.__dict__ if -                       k not in ('timestamp_start', 'timestamp_end')] -            return self_d == other_d -        except: -            return False      def __repr__(self):          # return "Response(%s - %s)" % (self.status_code, self.msg) @@ -536,7 +511,7 @@ class Response(object):              attributes (e.g. HTTPOnly) are indicated by a Null value.          """          ret = [] -        for header in self.headers.get_all("set-cookie"): +        for header in self.headers.get_all(b"set-cookie"):              v = cookies.parse_set_cookie_header(header)              if v:                  name, value, attrs = v @@ -559,34 +534,4 @@ class Response(object):                      i[1][1]                  )              ) -        self.headers.set_all("Set-Cookie", values) - -    @property -    def body(self): -        return self._body - -    @body.setter -    def body(self, body): -        self._body = body -        if isinstance(body, bytes): -            self.headers["Content-Length"] = str(len(body)).encode() - -    @property -    def content(self):  # pragma: no cover -        # TODO: remove deprecated getter -        return self.body - -    @content.setter -    def content(self, content):  # pragma: no cover -        # TODO: remove deprecated setter -        self.body = content - -    @property -    def code(self):  # pragma: no cover -        # TODO: remove deprecated getter -        return self.status_code - -    @code.setter -    def code(self, code):  # pragma: no cover -        # TODO: remove deprecated setter -        self.status_code = code +        self.headers.set_all(b"Set-Cookie", values) diff --git a/netlib/tutils.py b/netlib/tutils.py index 05791c49..b69495a3 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -105,7 +105,7 @@ def treq(**kwargs):          host=b"address",          port=22,          path=b"/path", -        httpversion=b"HTTP/1.1", +        http_version=b"HTTP/1.1",          headers=Headers(header=b"qvalue"),          body=b"content"      ) @@ -119,7 +119,7 @@ def tresp(**kwargs):          netlib.http.Response      """      default = dict( -        httpversion=b"HTTP/1.1", +        http_version=b"HTTP/1.1",          status_code=200,          msg=b"OK",          headers=Headers(header_response=b"svalue"), diff --git a/netlib/utils.py b/netlib/utils.py index a86b8019..14b428d7 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -17,11 +17,6 @@ def isascii(bytes):      return True -# best way to do it in python 2.x -def bytes_to_int(i): -    return int(i.encode('hex'), 16) - -  def clean_bin(s, keep_spacing=True):      """          Cleans binary data to make it safe to display. @@ -51,21 +46,15 @@ def clean_bin(s, keep_spacing=True):  def hexdump(s):      """ -        Returns a set of tuples: -            (offset, hex, str) +        Returns: +            A generator of (offset, hex, str) tuples      """ -    parts = []      for i in range(0, len(s), 16): -        o = "%.10x" % i +        offset = b"%.10x" % i          part = s[i:i + 16] -        x = " ".join("%.2x" % ord(i) for i in part) -        if len(part) < 16: -            x += " " -            x += " ".join("  " for i in range(16 - len(part))) -        parts.append( -            (o, x, clean_bin(part, False)) -        ) -    return parts +        x = b" ".join(b"%.2x" % i for i in six.iterbytes(part)) +        x = x.ljust(47)  # 16*2 + 15 +        yield (offset, x, clean_bin(part, False))  def setbit(byte, offset, value): @@ -80,8 +69,7 @@ def setbit(byte, offset, value):  def getbit(byte, offset):      mask = 1 << offset -    if byte & mask: -        return True +    return bool(byte & mask)  class BiDi(object): @@ -159,7 +147,7 @@ def is_valid_host(host):          return False      if len(host) > 255:          return False -    if host[-1] == ".": +    if host[-1] == b".":          host = host[:-1]      return all(_label_valid.match(x) for x in host.split(b".")) @@ -248,7 +236,7 @@ def hostport(scheme, host, port):      """          Returns the host component, with a port specifcation if needed.      """ -    if (port, scheme) in [(80, "http"), (443, "https")]: +    if (port, scheme) in [(80, b"http"), (443, b"https")]:          return host      else:          return b"%s:%d" % (host, port) diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py index e3ff1405..ceddd273 100644 --- a/netlib/websockets/frame.py +++ b/netlib/websockets/frame.py @@ -2,6 +2,7 @@ from __future__ import absolute_import  import os  import struct  import io +import six  from .protocol import Masker  from netlib import tcp @@ -127,8 +128,8 @@ class FrameHeader(object):          """            read a websockets frame header          """ -        first_byte = utils.bytes_to_int(fp.safe_read(1)) -        second_byte = utils.bytes_to_int(fp.safe_read(1)) +        first_byte = six.byte2int(fp.safe_read(1)) +        second_byte = six.byte2int(fp.safe_read(1))          fin = utils.getbit(first_byte, 7)          rsv1 = utils.getbit(first_byte, 6) @@ -145,9 +146,9 @@ class FrameHeader(object):          if length_code <= 125:              payload_length = length_code          elif length_code == 126: -            payload_length = utils.bytes_to_int(fp.safe_read(2)) +            payload_length, = struct.unpack("!H", fp.safe_read(2))          elif length_code == 127: -            payload_length = utils.bytes_to_int(fp.safe_read(8)) +            payload_length, = struct.unpack("!Q", fp.safe_read(8))          # masking key only present if mask bit set          if mask_bit == 1: diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 46c02875..68d827a5 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -17,6 +17,7 @@ from __future__ import absolute_import  import base64  import hashlib  import os +import six  from ..http import Headers  from .. import utils @@ -40,7 +41,7 @@ class Masker(object):      def __init__(self, key):          self.key = key -        self.masks = [utils.bytes_to_int(byte) for byte in key] +        self.masks = [six.byte2int(byte) for byte in key]          self.offset = 0      def mask(self, offset, data): | 
