diff options
40 files changed, 1015 insertions, 642 deletions
| diff --git a/docs/dev/models.rst b/docs/dev/models.rst index 8c4e6825..f2ddf242 100644 --- a/docs/dev/models.rst +++ b/docs/dev/models.rst @@ -56,6 +56,17 @@ Datastructures          :special-members:          :no-undoc-members: +    .. autoclass:: MultiDictView + +        .. automethod:: get_all +        .. automethod:: set_all +        .. automethod:: add +        .. automethod:: insert +        .. automethod:: keys +        .. automethod:: values +        .. automethod:: items +        .. automethod:: to_dict +      .. autoclass:: decoded  .. automodule:: mitmproxy.models diff --git a/examples/modify_form.py b/examples/modify_form.py index 86188781..3fe0cf96 100644 --- a/examples/modify_form.py +++ b/examples/modify_form.py @@ -1,5 +1,8 @@  def request(context, flow): -    form = flow.request.urlencoded_form -    if form is not None: -        form["mitmproxy"] = ["rocks"] -        flow.request.urlencoded_form = form +    if flow.request.urlencoded_form: +        flow.request.urlencoded_form["mitmproxy"] = "rocks" +    else: +        # This sets the proper content type and overrides the body. +        flow.request.urlencoded_form = [ +            ("foo", "bar") +        ] diff --git a/examples/modify_querystring.py b/examples/modify_querystring.py index d682df69..b89e5c8d 100644 --- a/examples/modify_querystring.py +++ b/examples/modify_querystring.py @@ -1,5 +1,2 @@  def request(context, flow): -    q = flow.request.query -    if q: -        q["mitmproxy"] = ["rocks"] -        flow.request.query = q +    flow.request.query["mitmproxy"] = "rocks" diff --git a/mitmproxy/console/flowview.py b/mitmproxy/console/flowview.py index b2ebe49e..2010cecd 100644 --- a/mitmproxy/console/flowview.py +++ b/mitmproxy/console/flowview.py @@ -6,8 +6,7 @@ import sys  import math  import urwid -from netlib import odict -from netlib.http import Headers +from netlib.http import Headers, status_codes  from . import common, grideditor, signals, searchable, tabs  from . import flowdetailview  from .. import utils, controller, contentviews @@ -187,7 +186,7 @@ class FlowView(tabs.Tabs):                  viewmode,                  message,                  limit, -                (bytes(message.headers), message.content)  # Cache invalidation +                message  # Cache invalidation              )      def _get_content_view(self, viewmode, message, max_lines, _): @@ -316,21 +315,18 @@ class FlowView(tabs.Tabs):              return "Invalid URL."          signals.flow_change.send(self, flow = self.flow) -    def set_resp_code(self, code): -        response = self.flow.response +    def set_resp_status_code(self, status_code):          try: -            response.status_code = int(code) +            status_code = int(status_code)          except ValueError:              return None -        import BaseHTTPServer -        if int(code) in BaseHTTPServer.BaseHTTPRequestHandler.responses: -            response.msg = BaseHTTPServer.BaseHTTPRequestHandler.responses[ -                int(code)][0] +        self.flow.response.status_code = status_code +        if status_code in status_codes.RESPONSES: +            self.flow.response.reason = status_codes.RESPONSES[status_code]          signals.flow_change.send(self, flow = self.flow) -    def set_resp_msg(self, msg): -        response = self.flow.response -        response.msg = msg +    def set_resp_reason(self, reason): +        self.flow.response.reason = reason          signals.flow_change.send(self, flow = self.flow)      def set_headers(self, fields, conn): @@ -338,22 +334,22 @@ class FlowView(tabs.Tabs):          signals.flow_change.send(self, flow = self.flow)      def set_query(self, lst, conn): -        conn.set_query(odict.ODict(lst)) +        conn.query = lst          signals.flow_change.send(self, flow = self.flow)      def set_path_components(self, lst, conn): -        conn.set_path_components(lst) +        conn.path_components = lst          signals.flow_change.send(self, flow = self.flow)      def set_form(self, lst, conn): -        conn.set_form_urlencoded(odict.ODict(lst)) +        conn.urlencoded_form = lst          signals.flow_change.send(self, flow = self.flow)      def edit_form(self, conn):          self.master.view_grideditor(              grideditor.URLEncodedFormEditor(                  self.master, -                conn.get_form_urlencoded().lst, +                conn.urlencoded_form.items(multi=True),                  self.set_form,                  conn              ) @@ -364,7 +360,7 @@ class FlowView(tabs.Tabs):              self.edit_form(conn)      def set_cookies(self, lst, conn): -        conn.cookies = odict.ODict(lst) +        conn.cookies = lst          signals.flow_change.send(self, flow = self.flow)      def set_setcookies(self, data, conn): @@ -388,7 +384,7 @@ class FlowView(tabs.Tabs):              self.master.view_grideditor(                  grideditor.CookieEditor(                      self.master, -                    message.cookies.lst, +                    message.cookies.items(multi=True),                      self.set_cookies,                      message                  ) @@ -397,7 +393,7 @@ class FlowView(tabs.Tabs):              self.master.view_grideditor(                  grideditor.SetCookieEditor(                      self.master, -                    message.cookies, +                    message.cookies.items(multi=True),                      self.set_setcookies,                      message                  ) @@ -413,7 +409,7 @@ class FlowView(tabs.Tabs):                  c = self.master.spawn_editor(message.content or "")                  message.content = c.rstrip("\n")          elif part == "f": -            if not message.get_form_urlencoded() and message.content: +            if not message.urlencoded_form and message.content:                  signals.status_prompt_onekey.send(                      prompt = "Existing body is not a URL-encoded form. Clear and edit?",                      keys = [ @@ -435,7 +431,7 @@ class FlowView(tabs.Tabs):                  )              )          elif part == "p": -            p = message.get_path_components() +            p = message.path_components              self.master.view_grideditor(                  grideditor.PathEditor(                      self.master, @@ -448,7 +444,7 @@ class FlowView(tabs.Tabs):              self.master.view_grideditor(                  grideditor.QueryEditor(                      self.master, -                    message.get_query().lst, +                    message.query.items(multi=True),                      self.set_query, message                  )              ) @@ -458,7 +454,7 @@ class FlowView(tabs.Tabs):                  text = message.url,                  callback = self.set_url              ) -        elif part == "m": +        elif part == "m" and message == self.flow.request:              signals.status_prompt_onekey.send(                  prompt = "Method",                  keys = common.METHOD_OPTIONS, @@ -468,13 +464,13 @@ class FlowView(tabs.Tabs):              signals.status_prompt.send(                  prompt = "Code",                  text = str(message.status_code), -                callback = self.set_resp_code +                callback = self.set_resp_status_code              ) -        elif part == "m": +        elif part == "m" and message == self.flow.response:              signals.status_prompt.send(                  prompt = "Message", -                text = message.msg, -                callback = self.set_resp_msg +                text = message.reason, +                callback = self.set_resp_reason              )          signals.flow_change.send(self, flow = self.flow) diff --git a/mitmproxy/console/grideditor.py b/mitmproxy/console/grideditor.py index 46ff348e..11ce7d02 100644 --- a/mitmproxy/console/grideditor.py +++ b/mitmproxy/console/grideditor.py @@ -700,17 +700,17 @@ class SetCookieEditor(GridEditor):      def data_in(self, data):          flattened = [] -        for k, v in data.items(): -            flattened.append([k, v[0], v[1].lst]) +        for key, (value, attrs) in data: +            flattened.append([key, value, attrs.items(multi=True)])          return flattened      def data_out(self, data):          vals = [] -        for i in data: +        for key, value, attrs in data:              vals.append(                  [ -                    i[0], -                    [i[1], odict.ODictCaseless(i[2])] +                    key, +                    (value, attrs)                  ]              ) -        return odict.ODict(vals) +        return vals diff --git a/mitmproxy/flow.py b/mitmproxy/flow.py index ccedd1d4..a9018e16 100644 --- a/mitmproxy/flow.py +++ b/mitmproxy/flow.py @@ -158,9 +158,9 @@ class SetHeaders:          for _, header, value, cpatt in self.lst:              if cpatt(f):                  if f.response: -                    f.response.headers.fields.append((header, value)) +                    f.response.headers.add(header, value)                  else: -                    f.request.headers.fields.append((header, value)) +                    f.request.headers.add(header, value)  class StreamLargeBodies(object): @@ -265,7 +265,7 @@ class ServerPlaybackState:              form_contents = r.urlencoded_form or r.multipart_form              if self.ignore_payload_params and form_contents:                  key.extend( -                    p for p in form_contents +                    p for p in form_contents.items(multi=True)                      if p[0] not in self.ignore_payload_params                  )              else: @@ -321,10 +321,10 @@ class StickyCookieState:          """          domain = f.request.host          path = "/" -        if attrs["domain"]: -            domain = attrs["domain"][-1] -        if attrs["path"]: -            path = attrs["path"][-1] +        if "domain" in attrs: +            domain = attrs["domain"] +        if "path" in attrs: +            path = attrs["path"]          return (domain, f.request.port, path)      def domain_match(self, a, b): @@ -335,28 +335,26 @@ class StickyCookieState:          return False      def handle_response(self, f): -        for i in f.response.headers.get_all("set-cookie"): +        for name, (value, attrs) in f.response.cookies.items(multi=True):              # FIXME: We now know that Cookie.py screws up some cookies with              # valid RFC 822/1123 datetime specifications for expiry. Sigh. -            name, value, attrs = cookies.parse_set_cookie_header(str(i))              a = self.ckey(attrs, f)              if self.domain_match(f.request.host, a[0]): -                b = attrs.lst -                b.insert(0, [name, value]) -                self.jar[a][name] = odict.ODictCaseless(b) +                b = attrs.with_insert(0, name, value) +                self.jar[a][name] = b      def handle_request(self, f):          l = []          if f.match(self.flt): -            for i in self.jar.keys(): +            for domain, port, path in self.jar.keys():                  match = [ -                    self.domain_match(f.request.host, i[0]), -                    f.request.port == i[1], -                    f.request.path.startswith(i[2]) +                    self.domain_match(f.request.host, domain), +                    f.request.port == port, +                    f.request.path.startswith(path)                  ]                  if all(match): -                    c = self.jar[i] -                    l.extend([cookies.format_cookie_header(c[name]) for name in c.keys()]) +                    c = self.jar[(domain, port, path)] +                    l.extend([cookies.format_cookie_header(c[name].items(multi=True)) for name in c.keys()])          if l:              f.request.stickycookie = True              f.request.headers["cookie"] = "; ".join(l) diff --git a/mitmproxy/flow_export.py b/mitmproxy/flow_export.py index d8e65704..ae282fce 100644 --- a/mitmproxy/flow_export.py +++ b/mitmproxy/flow_export.py @@ -51,7 +51,7 @@ def python_code(flow):      params = ""      if flow.request.query: -        lines = ["    '%s': '%s',\n" % (k, v) for k, v in flow.request.query] +        lines = ["    %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()]          params = "\nparams = {\n%s}\n" % "".join(lines)          args += "\n    params=params," @@ -140,7 +140,7 @@ def locust_code(flow):      params = ""      if flow.request.query: -        lines = ["            '%s': '%s',\n" % (k, v) for k, v in flow.request.query] +        lines = ["            %s: %s,\n" % (repr(k), repr(v)) for k, v in flow.request.query.to_dict().items()]          params = "\n        params = {\n%s        }\n" % "".join(lines)          args += "\n            params=params," diff --git a/mitmproxy/protocol/base.py b/mitmproxy/protocol/base.py index 536f2753..c8e58d1b 100644 --- a/mitmproxy/protocol/base.py +++ b/mitmproxy/protocol/base.py @@ -133,24 +133,15 @@ class ServerConnectionMixin(object):                      "The proxy shall not connect to itself.".format(repr(address))                  ) -    def set_server(self, address, server_tls=None, sni=None): +    def set_server(self, address):          """          Sets a new server address. If there is an existing connection, it will be closed. - -        Raises: -            ~mitmproxy.exceptions.ProtocolException: -                if ``server_tls`` is ``True``, but there was no TLS layer on the -                protocol stack which could have processed this.          """          if self.server_conn:              self.disconnect()          self.log("Set new server address: " + repr(address), "debug")          self.server_conn.address = address          self.__check_self_connect() -        if server_tls: -            raise ProtocolException( -                "Cannot upgrade to TLS, no TLS layer on the protocol stack." -            )      def disconnect(self):          """ diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py index 9cb35176..d9111303 100644 --- a/mitmproxy/protocol/http.py +++ b/mitmproxy/protocol/http.py @@ -120,7 +120,7 @@ class UpstreamConnectLayer(Layer):          if address != self.server_conn.via.address:              self.ctx.set_server(address) -    def set_server(self, address, server_tls=None, sni=None): +    def set_server(self, address):          if self.ctx.server_conn:              self.ctx.disconnect()          address = tcp.Address.wrap(address) @@ -128,11 +128,6 @@ class UpstreamConnectLayer(Layer):          self.connect_request.port = address.port          self.server_conn.address = address -        if server_tls: -            raise ProtocolException( -                "Cannot upgrade to TLS, no TLS layer on the protocol stack." -            ) -  class HttpLayer(Layer): @@ -149,7 +144,7 @@ class HttpLayer(Layer):      def __call__(self):          if self.mode == "transparent": -            self.__initial_server_tls = self._server_tls +            self.__initial_server_tls = self.server_tls              self.__initial_server_conn = self.server_conn          while True:              try: @@ -360,8 +355,9 @@ class HttpLayer(Layer):          if self.mode == "regular" or self.mode == "transparent":              # If there's an existing connection that doesn't match our expectations, kill it. -            if address != self.server_conn.address or tls != self.server_conn.tls_established: -                self.set_server(address, tls, address.host) +            if address != self.server_conn.address or tls != self.server_tls: +                self.set_server(address) +                self.set_server_tls(tls, address.host)              # Establish connection is neccessary.              if not self.server_conn:                  self.connect() diff --git a/mitmproxy/protocol/tls.py b/mitmproxy/protocol/tls.py index 26c3f9d2..74c55ab4 100644 --- a/mitmproxy/protocol/tls.py +++ b/mitmproxy/protocol/tls.py @@ -266,18 +266,22 @@ class TlsClientHello(object):          return self._client_hello      @property -    def client_cipher_suites(self): +    def cipher_suites(self):          return self._client_hello.cipher_suites.cipher_suites      @property -    def client_sni(self): +    def sni(self):          for extension in self._client_hello.extensions: -            if (extension.type == 0x00 and len(extension.server_names) == 1 -                    and extension.server_names[0].type == 0): +            is_valid_sni_extension = ( +                extension.type == 0x00 +                and len(extension.server_names) == 1 +                and extension.server_names[0].type == 0 +            ) +            if is_valid_sni_extension:                  return extension.server_names[0].name      @property -    def client_alpn_protocols(self): +    def alpn_protocols(self):          for extension in self._client_hello.extensions:              if extension.type == 0x10:                  return list(extension.alpn_protocols) @@ -304,55 +308,78 @@ class TlsClientHello(object):      def __repr__(self):          return "TlsClientHello( sni: %s alpn_protocols: %s,  cipher_suites: %s)" % \ -            (self.client_sni, self.client_alpn_protocols, self.client_cipher_suites) +            (self.sni, self.alpn_protocols, self.cipher_suites)  class TlsLayer(Layer): +    """ +    The TLS layer implements transparent TLS connections. -    def __init__(self, ctx, client_tls, server_tls): -        self.client_sni = None -        self.client_alpn_protocols = None -        self.client_ciphers = [] +    It exposes the following API to child layers: + +        - :py:meth:`set_server_tls` to modify TLS settings for the server connection. +        - :py:attr:`server_tls`, :py:attr:`server_sni` as read-only attributes describing the current TLS settings for +          the server connection. +    """ +    def __init__(self, ctx, client_tls, server_tls):          super(TlsLayer, self).__init__(ctx)          self._client_tls = client_tls          self._server_tls = server_tls -        self._sni_from_server_change = None +        self._custom_server_sni = None +        self._client_hello = None  # type: TlsClientHello      def __call__(self):          """ -        The strategy for establishing SSL is as follows: +        The strategy for establishing TLS is as follows:              First, we determine whether we need the server cert to establish ssl with the client.              If so, we first connect to the server and then to the client. -            If not, we only connect to the client and do the server_ssl lazily on a Connect message. - -        An additional complexity is that establish ssl with the server may require a SNI value from -        the client. In an ideal world, we'd do the following: -            1. Start the SSL handshake with the client -            2. Check if the client sends a SNI. -            3. Pause the client handshake, establish SSL with the server. -            4. Finish the client handshake with the certificate from the server. -        There's just one issue: We cannot get a callback from OpenSSL if the client doesn't send a SNI. :( -        Thus, we manually peek into the connection and parse the ClientHello message to obtain both SNI and ALPN values. - -        Further notes: -            - OpenSSL 1.0.2 introduces a callback that would help here: -              https://www.openssl.org/docs/ssl/SSL_CTX_set_cert_cb.html -            - The original mitmproxy issue is https://github.com/mitmproxy/mitmproxy/issues/427 -        """ - -        client_tls_requires_server_cert = ( -            self._client_tls and self._server_tls and not self.config.no_upstream_cert -        ) +            If not, we only connect to the client and do the server handshake lazily. +        An additional complexity is that we need to mirror SNI and ALPN from the client when connecting to the server. +        We manually peek into the connection and parse the ClientHello message to obtain these values. +        """          if self._client_tls: -            self._parse_client_hello() +            # Peek into the connection, read the initial client hello and parse it to obtain SNI and ALPN values. +            try: +                self._client_hello = TlsClientHello.from_client_conn(self.client_conn) +            except TlsProtocolException as e: +                self.log("Cannot parse Client Hello: %s" % repr(e), "error") + +        # Do we need to do a server handshake now? +        # There are two reasons why we would want to establish TLS with the server now: +        #  1. If we already have an existing server connection and server_tls is True, +        #     we need to establish TLS now because .connect() will not be called anymore. +        #  2. We may need information from the server connection for the client handshake. +        # +        # A couple of factors influence (2): +        #  2.1 There actually is (or will be) a TLS-enabled upstream connection +        #  2.2 An upstream connection is not wanted by the user if --no-upstream-cert is passed. +        #  2.3 An upstream connection is implied by add_upstream_certs_to_client_chain +        #  2.4 The client wants to negotiate an alternative protocol in its handshake, we need to find out +        #      what is supported by the server +        #  2.5 The client did not sent a SNI value, we don't know the certificate subject. +        client_tls_requires_server_connection = ( +            self._server_tls +            and not self.config.no_upstream_cert +            and ( +                self.config.add_upstream_certs_to_client_chain +                or self._client_hello.alpn_protocols +                or not self._client_hello.sni +            ) +        ) +        establish_server_tls_now = ( +            (self.server_conn and self._server_tls) +            or client_tls_requires_server_connection +        ) -        if client_tls_requires_server_cert: +        if self._client_tls and establish_server_tls_now:              self._establish_tls_with_client_and_server()          elif self._client_tls:              self._establish_tls_with_client() +        elif establish_server_tls_now: +            self._establish_tls_with_server()          layer = self.ctx.next_layer(self)          layer() @@ -367,47 +394,48 @@ class TlsLayer(Layer):          else:              return "TlsLayer(inactive)" -    def _parse_client_hello(self): -        """ -        Peek into the connection, read the initial client hello and parse it to obtain ALPN values. -        """ -        try: -            parsed = TlsClientHello.from_client_conn(self.client_conn) -            self.client_sni = parsed.client_sni -            self.client_alpn_protocols = parsed.client_alpn_protocols -            self.client_ciphers = parsed.client_cipher_suites -        except TlsProtocolException as e: -            self.log("Cannot parse Client Hello: %s" % repr(e), "error") -      def connect(self):          if not self.server_conn:              self.ctx.connect()          if self._server_tls and not self.server_conn.tls_established:              self._establish_tls_with_server() -    def set_server(self, address, server_tls=None, sni=None): -        if server_tls is not None: -            self._sni_from_server_change = sni -            self._server_tls = server_tls -        self.ctx.set_server(address, None, None) +    def set_server_tls(self, server_tls, sni=None): +        """ +        Set the TLS settings for the next server connection that will be established. +        This function will not alter an existing connection. + +        Args: +            server_tls: Shall we establish TLS with the server? +            sni: ``bytes`` for a custom SNI value, +                ``None`` for the client SNI value, +                ``False`` if no SNI value should be sent. +        """ +        self._server_tls = server_tls +        self._custom_server_sni = sni + +    @property +    def server_tls(self): +        """ +        ``True``, if the next server connection that will be established should be upgraded to TLS. +        """ +        return self._server_tls      @property -    def sni_for_server_connection(self): -        if self._sni_from_server_change is False: +    def server_sni(self): +        """ +        The Server Name Indication we want to send with the next server TLS handshake. +        """ +        if self._custom_server_sni is False:              return None          else: -            return self._sni_from_server_change or self.client_sni +            return self._custom_server_sni or self._client_hello.sni      @property      def alpn_for_client_connection(self):          return self.server_conn.get_alpn_proto_negotiated()      def __alpn_select_callback(self, conn_, options): -        """ -        Once the client signals the alternate protocols it supports, -        we reconnect upstream with the same list and pass the server's choice down to the client. -        """ -          # This gets triggered if we haven't established an upstream connection yet.          default_alpn = b'http/1.1'          # alpn_preference = b'h2' @@ -422,12 +450,12 @@ class TlsLayer(Layer):          return choice      def _establish_tls_with_client_and_server(self): -        # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless -        # to send an error message over TLS.          try:              self.ctx.connect()              self._establish_tls_with_server()          except Exception: +            # If establishing TLS with the server fails, we try to establish TLS with the client nonetheless +            # to send an error message over TLS.              try:                  self._establish_tls_with_client()              except: @@ -466,9 +494,9 @@ class TlsLayer(Layer):                  ClientHandshakeException,                  ClientHandshakeException(                      "Cannot establish TLS with client (sni: {sni}): {e}".format( -                        sni=self.client_sni, e=repr(e) +                        sni=self._client_hello.sni, e=repr(e)                      ), -                    self.client_sni or repr(self.server_conn.address) +                    self._client_hello.sni or repr(self.server_conn.address)                  ),                  sys.exc_info()[2]              ) @@ -480,8 +508,8 @@ class TlsLayer(Layer):              # If the server only supports spdy (next to http/1.1), it may select that              # and mitmproxy would enter TCP passthrough mode, which we want to avoid.              deprecated_http2_variant = lambda x: x.startswith(b"h2-") or x.startswith(b"spdy") -            if self.client_alpn_protocols: -                alpn = [x for x in self.client_alpn_protocols if not deprecated_http2_variant(x)] +            if self._client_hello.alpn_protocols: +                alpn = [x for x in self._client_hello.alpn_protocols if not deprecated_http2_variant(x)]              else:                  alpn = None              if alpn and b"h2" in alpn and not self.config.http2: @@ -490,14 +518,14 @@ class TlsLayer(Layer):              ciphers_server = self.config.ciphers_server              if not ciphers_server:                  ciphers_server = [] -                for id in self.client_ciphers: +                for id in self._client_hello.cipher_suites:                      if id in CIPHER_ID_NAME_MAP.keys():                          ciphers_server.append(CIPHER_ID_NAME_MAP[id])                  ciphers_server = ':'.join(ciphers_server)              self.server_conn.establish_ssl(                  self.config.clientcerts, -                self.sni_for_server_connection, +                self.server_sni,                  method=self.config.openssl_method_server,                  options=self.config.openssl_options_server,                  verify_options=self.config.openssl_verification_mode_server, @@ -524,7 +552,7 @@ class TlsLayer(Layer):                  TlsProtocolException,                  TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format(                      address=repr(self.server_conn.address), -                    sni=self.sni_for_server_connection, +                    sni=self.server_sni,                      e=repr(e),                  )),                  sys.exc_info()[2] @@ -534,7 +562,7 @@ class TlsLayer(Layer):                  TlsProtocolException,                  TlsProtocolException("Cannot establish TLS with {address} (sni: {sni}): {e}".format(                      address=repr(self.server_conn.address), -                    sni=self.sni_for_server_connection, +                    sni=self.server_sni,                      e=repr(e),                  )),                  sys.exc_info()[2] @@ -569,13 +597,13 @@ class TlsLayer(Layer):                  sans.add(host)                  host = upstream_cert.cn.decode("utf8").encode("idna")          # Also add SNI values. -        if self.client_sni: -            sans.add(self.client_sni) -        if self._sni_from_server_change: -            sans.add(self._sni_from_server_change) +        if self._client_hello.sni: +            sans.add(self._client_hello.sni) +        if self._custom_server_sni: +            sans.add(self._custom_server_sni) -        # Some applications don't consider the CN and expect the hostname to be in the SANs. -        # For example, Thunderbird 38 will display a warning if the remote host is only the CN. +        # RFC 2818: If a subjectAltName extension of type dNSName is present, that MUST be used as the identity. +        # In other words, the Common Name is irrelevant then.          if host:              sans.add(host)          return self.config.certstore.get_cert(host, list(sans)) diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py index 9caae02a..c55105ec 100644 --- a/mitmproxy/proxy/root_context.py +++ b/mitmproxy/proxy/root_context.py @@ -63,7 +63,7 @@ class RootContext(object):                  except TlsProtocolException as e:                      self.log("Cannot parse Client Hello: %s" % repr(e), "error")                  else: -                    ignore = self.config.check_ignore((client_hello.client_sni, 443)) +                    ignore = self.config.check_ignore((client_hello.sni, 443))              if ignore:                  return RawTCPLayer(top_layer, logging=False) diff --git a/mitmproxy/utils.py b/mitmproxy/utils.py index 5fd062ea..cda5bba6 100644 --- a/mitmproxy/utils.py +++ b/mitmproxy/utils.py @@ -7,6 +7,9 @@ import json  import importlib  import inspect +import netlib.utils + +  def timestamp():      """          Returns a serializable UTC timestamp. @@ -73,25 +76,7 @@ def pretty_duration(secs):      return "{:.0f}ms".format(secs * 1000) -class Data: - -    def __init__(self, name): -        m = importlib.import_module(name) -        dirname = os.path.dirname(inspect.getsourcefile(m)) -        self.dirname = os.path.abspath(dirname) - -    def path(self, path): -        """ -            Returns a path to the package data housed at 'path' under this -            module.Path can be a path to a file, or to a directory. - -            This function will raise ValueError if the path does not exist. -        """ -        fullpath = os.path.join(self.dirname, path) -        if not os.path.exists(fullpath): -            raise ValueError("dataPath: %s does not exist." % fullpath) -        return fullpath -pkg_data = Data(__name__) +pkg_data = netlib.utils.Data(__name__)  class LRUCache: diff --git a/netlib/encoding.py b/netlib/encoding.py index 14479e00..98502451 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,7 +5,6 @@ from __future__ import absolute_import  from io import BytesIO  import gzip  import zlib -from .utils import always_byte_args  ENCODINGS = {"identity", "gzip", "deflate"} diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index 917080f7..c4eb1d58 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -3,12 +3,12 @@ from .request import Request  from .response import Response  from .headers import Headers  from .message import decoded -from . import http1, http2 +from . import http1, http2, status_codes  __all__ = [      "Request",      "Response",      "Headers",      "decoded", -    "http1", "http2", +    "http1", "http2", "status_codes",  ] diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py index 4451f1da..88c76870 100644 --- a/netlib/http/cookies.py +++ b/netlib/http/cookies.py @@ -1,8 +1,8 @@ -from six.moves import http_cookies as Cookie +import collections  import re -import string  from email.utils import parsedate_tz, formatdate, mktime_tz +from netlib.multidict import ImmutableMultiDict  from .. import odict  """ @@ -157,42 +157,76 @@ def _parse_set_cookie_pairs(s):      return pairs +def parse_set_cookie_headers(headers): +    ret = [] +    for header in headers: +        v = parse_set_cookie_header(header) +        if v: +            name, value, attrs = v +            ret.append((name, SetCookie(value, attrs))) +    return ret + + +class CookieAttrs(ImmutableMultiDict): +    @staticmethod +    def _kconv(key): +        return key.lower() + +    @staticmethod +    def _reduce_values(values): +        # See the StickyCookieTest for a weird cookie that only makes sense +        # if we take the last part. +        return values[-1] + + +SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"]) + +  def parse_set_cookie_header(line):      """          Parse a Set-Cookie header value          Returns a (name, value, attrs) tuple, or None, where attrs is an -        ODictCaseless set of attributes. No attempt is made to parse attribute +        CookieAttrs dict of attributes. No attempt is made to parse attribute          values - they are treated purely as strings.      """      pairs = _parse_set_cookie_pairs(line)      if pairs: -        return pairs[0][0], pairs[0][1], odict.ODictCaseless(pairs[1:]) +        return pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:])  def format_set_cookie_header(name, value, attrs):      """          Formats a Set-Cookie header value.      """ -    pairs = [[name, value]] -    pairs.extend(attrs.lst) +    pairs = [(name, value)] +    pairs.extend( +        attrs.fields if hasattr(attrs, "fields") else attrs +    )      return _format_set_cookie_pairs(pairs) +def parse_cookie_headers(cookie_headers): +    cookie_list = [] +    for header in cookie_headers: +        cookie_list.extend(parse_cookie_header(header)) +    return cookie_list + +  def parse_cookie_header(line):      """          Parse a Cookie header value. -        Returns a (possibly empty) ODict object. +        Returns a list of (lhs, rhs) tuples.      """      pairs, off_ = _read_pairs(line) -    return odict.ODict(pairs) +    return pairs -def format_cookie_header(od): +def format_cookie_header(lst):      """          Formats a Cookie header value.      """ -    return _format_pairs(od.lst) +    return _format_pairs(lst)  def refresh_set_cookie_header(c, delta): @@ -209,10 +243,10 @@ def refresh_set_cookie_header(c, delta):          raise ValueError("Invalid Cookie")      if "expires" in attrs: -        e = parsedate_tz(attrs["expires"][-1]) +        e = parsedate_tz(attrs["expires"])          if e:              f = mktime_tz(e) + delta -            attrs["expires"] = [formatdate(f)] +            attrs = attrs.with_set_all("expires", [formatdate(f)])          else:              # This can happen when the expires tag is invalid.              # reddit.com sends a an expires tag like this: "Thu, 31 Dec @@ -220,7 +254,7 @@ def refresh_set_cookie_header(c, delta):              # strictly correct according to the cookie spec. Browsers              # appear to parse this tolerantly - maybe we should too.              # For now, we just ignore this. -            del attrs["expires"] +            attrs = attrs.with_delitem("expires")      ret = format_set_cookie_header(name, value, attrs)      if not ret: diff --git a/netlib/http/headers.py b/netlib/http/headers.py index 72739f90..60d3f429 100644 --- a/netlib/http/headers.py +++ b/netlib/http/headers.py @@ -1,9 +1,3 @@ -""" - -Unicode Handling ----------------- -See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ -"""  from __future__ import absolute_import, print_function, division  import re @@ -13,23 +7,22 @@ try:  except ImportError:  # pragma: no cover      from collections import MutableMapping  # Workaround for Python < 3.3 -  import six +from ..multidict import MultiDict +from ..utils import always_bytes -from netlib.utils import always_byte_args, always_bytes, Serializable +# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/  if six.PY2:  # pragma: no cover      _native = lambda x: x      _always_bytes = lambda x: x -    _always_byte_args = lambda x: x  else:      # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.      _native = lambda x: x.decode("utf-8", "surrogateescape")      _always_bytes = lambda x: always_bytes(x, "utf-8", "surrogateescape") -    _always_byte_args = always_byte_args("utf-8", "surrogateescape") -class Headers(MutableMapping, Serializable): +class Headers(MultiDict):      """      Header class which allows both convenient access to individual headers as well as      direct access to the underlying raw data. Provides a full dictionary interface. @@ -49,11 +42,11 @@ class Headers(MutableMapping, Serializable):          >>> h["host"]          "example.com" -        # Headers can also be creatd from a list of raw (header_name, header_value) byte tuples +        # Headers can also be created from a list of raw (header_name, header_value) byte tuples          >>> h = Headers([ -            [b"Host",b"example.com"], -            [b"Accept",b"text/html"], -            [b"accept",b"application/xml"] +            (b"Host",b"example.com"), +            (b"Accept",b"text/html"), +            (b"accept",b"application/xml")          ])          # Multiple headers are folded into a single header as per RFC7230 @@ -77,7 +70,6 @@ class Headers(MutableMapping, Serializable):          For use with the "Set-Cookie" header, see :py:meth:`get_all`.      """ -    @_always_byte_args      def __init__(self, fields=None, **headers):          """          Args: @@ -89,19 +81,29 @@ class Headers(MutableMapping, Serializable):                  If ``**headers`` contains multiple keys that have equal ``.lower()`` s,                  the behavior is undefined.          """ -        self.fields = fields or [] +        super(Headers, self).__init__(fields) -        for name, value in self.fields: -            if not isinstance(name, bytes) or not isinstance(value, bytes): -                raise ValueError("Headers passed as fields must be bytes.") +        for key, value in self.fields: +            if not isinstance(key, bytes) or not isinstance(value, bytes): +                raise TypeError("Header fields must be bytes.")          # content_type -> content-type          headers = { -            _always_bytes(name).replace(b"_", b"-"): value +            _always_bytes(name).replace(b"_", b"-"): _always_bytes(value)              for name, value in six.iteritems(headers)              }          self.update(headers) +    @staticmethod +    def _reduce_values(values): +        # Headers can be folded +        return ", ".join(values) + +    @staticmethod +    def _kconv(key): +        # Headers are case-insensitive +        return key.lower() +      def __bytes__(self):          if self.fields:              return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" @@ -111,98 +113,40 @@ class Headers(MutableMapping, Serializable):      if six.PY2:  # pragma: no cover          __str__ = __bytes__ -    @_always_byte_args -    def __getitem__(self, name): -        values = self.get_all(name) -        if not values: -            raise KeyError(name) -        return ", ".join(values) - -    @_always_byte_args -    def __setitem__(self, name, value): -        idx = self._index(name) - -        # To please the human eye, we insert at the same position the first existing header occured. -        if idx is not None: -            del self[name] -            self.fields.insert(idx, [name, value]) -        else: -            self.fields.append([name, value]) - -    @_always_byte_args -    def __delitem__(self, name): -        if name not in self: -            raise KeyError(name) -        name = name.lower() -        self.fields = [ -            field for field in self.fields -            if name != field[0].lower() -        ] +    def __delitem__(self, key): +        key = _always_bytes(key) +        super(Headers, self).__delitem__(key)      def __iter__(self): -        seen = set() -        for name, _ in self.fields: -            name_lower = name.lower() -            if name_lower not in seen: -                seen.add(name_lower) -                yield _native(name) - -    def __len__(self): -        return len(set(name.lower() for name, _ in self.fields)) - -    # __hash__ = object.__hash__ - -    def _index(self, name): -        name = name.lower() -        for i, field in enumerate(self.fields): -            if field[0].lower() == name: -                return i -        return None - -    def __eq__(self, other): -        if isinstance(other, Headers): -            return self.fields == other.fields -        return False - -    def __ne__(self, other): -        return not self.__eq__(other) - -    @_always_byte_args +        for x in super(Headers, self).__iter__(): +            yield _native(x) +      def get_all(self, name):          """          Like :py:meth:`get`, but does not fold multiple headers into a single one.          This is useful for Set-Cookie headers, which do not support folding. -          See also: https://tools.ietf.org/html/rfc7230#section-3.2.2          """ -        name_lower = name.lower() -        values = [_native(value) for n, value in self.fields if n.lower() == name_lower] -        return values +        name = _always_bytes(name) +        return [ +            _native(x) for x in +            super(Headers, self).get_all(name) +        ] -    @_always_byte_args      def set_all(self, name, values):          """          Explicitly set multiple headers for the given key.          See: :py:meth:`get_all`          """ -        values = map(_always_bytes, values)  # _always_byte_args does not fix lists -        if name in self: -            del self[name] -        self.fields.extend( -            [name, value] for value in values -        ) - -    def get_state(self): -        return tuple(tuple(field) for field in self.fields) - -    def set_state(self, state): -        self.fields = [list(field) for field in state] +        name = _always_bytes(name) +        values = [_always_bytes(x) for x in values] +        return super(Headers, self).set_all(name, values) -    @classmethod -    def from_state(cls, state): -        return cls([list(field) for field in state]) +    def insert(self, index, key, value): +        key = _always_bytes(key) +        value = _always_bytes(value) +        super(Headers, self).insert(index, key, value) -    @_always_byte_args      def replace(self, pattern, repl, flags=0):          """          Replaces a regular expression pattern with repl in each "name: value" @@ -211,6 +155,8 @@ class Headers(MutableMapping, Serializable):          Returns:              The number of replacements made.          """ +        pattern = _always_bytes(pattern) +        repl = _always_bytes(repl)          pattern = re.compile(pattern, flags)          replacements = 0 diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index 6e3a1b93..d30976bd 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -316,14 +316,14 @@ def _read_headers(rfile):              if not ret:                  raise HttpSyntaxException("Invalid headers")              # continued header -            ret[-1][1] = ret[-1][1] + b'\r\n ' + line.strip() +            ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())          else:              try:                  name, value = line.split(b":", 1)                  value = value.strip()                  if not name:                      raise ValueError() -                ret.append([name, value]) +                ret.append((name, value))              except ValueError:                  raise HttpSyntaxException("Invalid headers")      return Headers(ret) diff --git a/netlib/http/http2/connections.py b/netlib/http/http2/connections.py index f900b67c..6643b6b9 100644 --- a/netlib/http/http2/connections.py +++ b/netlib/http/http2/connections.py @@ -201,13 +201,13 @@ class HTTP2Protocol(object):          headers = request.headers.copy()          if ':authority' not in headers: -            headers.fields.insert(0, (b':authority', authority.encode('ascii'))) +            headers.insert(0, b':authority', authority.encode('ascii'))          if ':scheme' not in headers: -            headers.fields.insert(0, (b':scheme', request.scheme.encode('ascii'))) +            headers.insert(0, b':scheme', request.scheme.encode('ascii'))          if ':path' not in headers: -            headers.fields.insert(0, (b':path', request.path.encode('ascii'))) +            headers.insert(0, b':path', request.path.encode('ascii'))          if ':method' not in headers: -            headers.fields.insert(0, (b':method', request.method.encode('ascii'))) +            headers.insert(0, b':method', request.method.encode('ascii'))          if hasattr(request, 'stream_id'):              stream_id = request.stream_id @@ -224,7 +224,7 @@ class HTTP2Protocol(object):          headers = response.headers.copy()          if ':status' not in headers: -            headers.fields.insert(0, (b':status', str(response.status_code).encode('ascii'))) +            headers.insert(0, b':status', str(response.status_code).encode('ascii'))          if hasattr(response, 'stream_id'):              stream_id = response.stream_id @@ -420,7 +420,7 @@ class HTTP2Protocol(object):                  self._handle_unexpected_frame(frm)          headers = Headers( -            [[k.encode('ascii'), v.encode('ascii')] for k, v in self.decoder.decode(header_blocks)] +            (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)          )          return stream_id, headers, body diff --git a/netlib/http/message.py b/netlib/http/message.py index da9681a0..028f43a1 100644 --- a/netlib/http/message.py +++ b/netlib/http/message.py @@ -4,6 +4,7 @@ import warnings  import six +from ..multidict import MultiDict  from .headers import Headers  from .. import encoding, utils @@ -25,6 +26,9 @@ class MessageData(utils.Serializable):      def __ne__(self, other):          return not self.__eq__(other) +    def __hash__(self): +        return hash(frozenset(self.__dict__.items())) +      def set_state(self, state):          for k, v in state.items():              if k == "headers": @@ -51,6 +55,9 @@ class Message(utils.Serializable):      def __ne__(self, other):          return not self.__eq__(other) +    def __hash__(self): +        return hash(self.data) ^ 1 +      def get_state(self):          return self.data.get_state() diff --git a/netlib/http/request.py b/netlib/http/request.py index a42150ff..056a2d93 100644 --- a/netlib/http/request.py +++ b/netlib/http/request.py @@ -10,6 +10,7 @@ from netlib import utils  from netlib.http import cookies  from netlib.odict import ODict  from .. import encoding +from ..multidict import MultiDictView  from .headers import Headers  from .message import Message, _native, _always_bytes, MessageData @@ -224,45 +225,64 @@ class Request(Message):      @property      def query(self): +        # type: () -> MultiDictView          """ -        The request query string as an :py:class:`ODict` object. -        None, if there is no query. +        The request query string as an :py:class:`MultiDictView` object.          """ +        return MultiDictView( +            self._get_query, +            self._set_query +        ) + +    def _get_query(self):          _, _, _, _, query, _ = urllib.parse.urlparse(self.url) -        if query: -            return ODict(utils.urldecode(query)) -        return None +        return tuple(utils.urldecode(query)) -    @query.setter -    def query(self, odict): -        query = utils.urlencode(odict.lst) +    def _set_query(self, value): +        query = utils.urlencode(value)          scheme, netloc, path, params, _, fragment = urllib.parse.urlparse(self.url)          _, _, _, self.path = utils.parse_url(                  urllib.parse.urlunparse([scheme, netloc, path, params, query, fragment])) +    @query.setter +    def query(self, value): +        self._set_query(value) +      @property      def cookies(self): +        # type: () -> MultiDictView          """          The request cookies. -        An empty :py:class:`ODict` object if the cookie monster ate them all. + +        An empty :py:class:`MultiDictView` object if the cookie monster ate them all.          """ -        ret = ODict() -        for i in self.headers.get_all("Cookie"): -            ret.extend(cookies.parse_cookie_header(i)) -        return ret +        return MultiDictView( +            self._get_cookies, +            self._set_cookies +        ) + +    def _get_cookies(self): +        h = self.headers.get_all("Cookie") +        return tuple(cookies.parse_cookie_headers(h)) + +    def _set_cookies(self, value): +        self.headers["cookie"] = cookies.format_cookie_header(value)      @cookies.setter -    def cookies(self, odict): -        self.headers["cookie"] = cookies.format_cookie_header(odict) +    def cookies(self, value): +        self._set_cookies(value)      @property      def path_components(self):          """ -        The URL's path components as a list of strings. +        The URL's path components as a tuple of strings.          Components are unquoted.          """          _, _, path, _, _, _ = urllib.parse.urlparse(self.url) -        return [urllib.parse.unquote(i) for i in path.split("/") if i] +        # This needs to be a tuple so that it's immutable. +        # Otherwise, this would fail silently: +        #   request.path_components.append("foo") +        return tuple(urllib.parse.unquote(i) for i in path.split("/") if i)      @path_components.setter      def path_components(self, components): @@ -309,64 +329,53 @@ class Request(Message):      @property      def urlencoded_form(self):          """ -        The URL-encoded form data as an :py:class:`ODict` object. -        None if there is no data or the content-type indicates non-form data. +        The URL-encoded form data as an :py:class:`MultiDictView` object. +        An empty MultiDictView if the content-type indicates non-form data +        or the content could not be parsed.          """ +        return MultiDictView( +            self._get_urlencoded_form, +            self._set_urlencoded_form +        ) + +    def _get_urlencoded_form(self):          is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower() -        if self.content and is_valid_content_type: -            return ODict(utils.urldecode(self.content)) -        return None +        if is_valid_content_type: +            return tuple(utils.urldecode(self.content)) +        return () -    @urlencoded_form.setter -    def urlencoded_form(self, odict): +    def _set_urlencoded_form(self, value):          """          Sets the body to the URL-encoded form data, and adds the appropriate content-type header.          This will overwrite the existing content if there is one.          """          self.headers["content-type"] = "application/x-www-form-urlencoded" -        self.content = utils.urlencode(odict.lst) +        self.content = utils.urlencode(value) + +    @urlencoded_form.setter +    def urlencoded_form(self, value): +        self._set_urlencoded_form(value)      @property      def multipart_form(self):          """ -        The multipart form data as an :py:class:`ODict` object. -        None if there is no data or the content-type indicates non-form data. +        The multipart form data as an :py:class:`MultipartFormDict` object. +        None if the content-type indicates non-form data.          """ +        return MultiDictView( +            self._get_multipart_form, +            self._set_multipart_form +        ) + +    def _get_multipart_form(self):          is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower() -        if self.content and is_valid_content_type: -            return ODict(utils.multipartdecode(self.headers,self.content)) -        return None +        if is_valid_content_type: +            return utils.multipartdecode(self.headers, self.content) +        return () -    @multipart_form.setter -    def multipart_form(self, value): +    def _set_multipart_form(self, value):          raise NotImplementedError() -    # Legacy - -    def get_query(self):  # pragma: no cover -        warnings.warn(".get_query is deprecated, use .query instead.", DeprecationWarning) -        return self.query or ODict([]) - -    def set_query(self, odict):  # pragma: no cover -        warnings.warn(".set_query is deprecated, use .query instead.", DeprecationWarning) -        self.query = odict - -    def get_path_components(self):  # pragma: no cover -        warnings.warn(".get_path_components is deprecated, use .path_components instead.", DeprecationWarning) -        return self.path_components - -    def set_path_components(self, lst):  # pragma: no cover -        warnings.warn(".set_path_components is deprecated, use .path_components instead.", DeprecationWarning) -        self.path_components = lst - -    def get_form_urlencoded(self):  # pragma: no cover -        warnings.warn(".get_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) -        return self.urlencoded_form or ODict([]) - -    def set_form_urlencoded(self, odict):  # pragma: no cover -        warnings.warn(".set_form_urlencoded is deprecated, use .urlencoded_form instead.", DeprecationWarning) -        self.urlencoded_form = odict - -    def get_form_multipart(self):  # pragma: no cover -        warnings.warn(".get_form_multipart is deprecated, use .multipart_form instead.", DeprecationWarning) -        return self.multipart_form or ODict([]) +    @multipart_form.setter +    def multipart_form(self, value): +        self._set_multipart_form(value) diff --git a/netlib/http/response.py b/netlib/http/response.py index 2f06149e..7d272e10 100644 --- a/netlib/http/response.py +++ b/netlib/http/response.py @@ -1,14 +1,13 @@  from __future__ import absolute_import, print_function, division -import warnings  from email.utils import parsedate_tz, formatdate, mktime_tz  import time  from . import cookies  from .headers import Headers  from .message import Message, _native, _always_bytes, MessageData +from ..multidict import MultiDictView  from .. import utils -from ..odict import ODict  class ResponseData(MessageData): @@ -72,29 +71,35 @@ class Response(Message):      @property      def cookies(self): +        # type: () -> MultiDictView          """ -        Get the contents of all Set-Cookie headers. +        The response cookies. A possibly empty :py:class:`MultiDictView`, where the keys are +        cookie name strings, and values are (value, attr) tuples. Value is a string, and attr is +        an ODictCaseless containing cookie attributes. Within attrs, unary attributes (e.g. HTTPOnly) +        are indicated by a Null value. -        A possibly empty :py:class:`ODict`, where keys are cookie name strings, -        and values are [value, attr] lists. Value is a string, and attr is -        an ODictCaseless containing cookie attributes. Within attrs, unary -        attributes (e.g. HTTPOnly) are indicated by a Null value. +        Caveats: +            Updating the attr          """ -        ret = [] -        for header in self.headers.get_all("set-cookie"): -            v = cookies.parse_set_cookie_header(header) -            if v: -                name, value, attrs = v -                ret.append([name, [value, attrs]]) -        return ODict(ret) +        return MultiDictView( +            self._get_cookies, +            self._set_cookies +        ) + +    def _get_cookies(self): +        h = self.headers.get_all("set-cookie") +        return tuple(cookies.parse_set_cookie_headers(h)) + +    def _set_cookies(self, value): +        cookie_headers = [] +        for k, v in value: +            header = cookies.format_set_cookie_header(k, v[0], v[1]) +            cookie_headers.append(header) +        self.headers.set_all("set-cookie", cookie_headers)      @cookies.setter -    def cookies(self, odict): -        values = [] -        for i in odict.lst: -            header = cookies.format_set_cookie_header(i[0], i[1][0], i[1][1]) -            values.append(header) -        self.headers.set_all("set-cookie", values) +    def cookies(self, value): +        self._set_cookies(value)      def refresh(self, now=None):          """ diff --git a/netlib/multidict.py b/netlib/multidict.py new file mode 100644 index 00000000..248acdec --- /dev/null +++ b/netlib/multidict.py @@ -0,0 +1,282 @@ +from __future__ import absolute_import, print_function, division + +from abc import ABCMeta, abstractmethod + +from typing import Tuple, TypeVar + +try: +    from collections.abc import MutableMapping +except ImportError:  # pragma: no cover +    from collections import MutableMapping  # Workaround for Python < 3.3 + +import six + +from .utils import Serializable + + +@six.add_metaclass(ABCMeta) +class _MultiDict(MutableMapping, Serializable): +    def __repr__(self): +        fields = tuple( +            repr(field) +            for field in self.fields +        ) +        return "{cls}[{fields}]".format( +            cls=type(self).__name__, +            fields=", ".join(fields) +        ) + +    @staticmethod +    @abstractmethod +    def _reduce_values(values): +        """ +        If a user accesses multidict["foo"], this method +        reduces all values for "foo" to a single value that is returned. +        For example, HTTP headers are folded, whereas we will just take +        the first cookie we found with that name. +        """ + +    @staticmethod +    @abstractmethod +    def _kconv(key): +        """ +        This method converts a key to its canonical representation. +        For example, HTTP headers are case-insensitive, so this method returns key.lower(). +        """ + +    def __getitem__(self, key): +        values = self.get_all(key) +        if not values: +            raise KeyError(key) +        return self._reduce_values(values) + +    def __setitem__(self, key, value): +        self.set_all(key, [value]) + +    def __delitem__(self, key): +        if key not in self: +            raise KeyError(key) +        key = self._kconv(key) +        self.fields = tuple( +            field for field in self.fields +            if key != self._kconv(field[0]) +        ) + +    def __iter__(self): +        seen = set() +        for key, _ in self.fields: +            key_kconv = self._kconv(key) +            if key_kconv not in seen: +                seen.add(key_kconv) +                yield key + +    def __len__(self): +        return len(set(self._kconv(key) for key, _ in self.fields)) + +    def __eq__(self, other): +        if isinstance(other, MultiDict): +            return self.fields == other.fields +        return False + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def __hash__(self): +        return hash(self.fields) + +    def get_all(self, key): +        """ +        Return the list of all values for a given key. +        If that key is not in the MultiDict, the return value will be an empty list. +        """ +        key = self._kconv(key) +        return [ +            value +            for k, value in self.fields +            if self._kconv(k) == key +        ] + +    def set_all(self, key, values): +        """ +        Remove the old values for a key and add new ones. +        """ +        key_kconv = self._kconv(key) + +        new_fields = [] +        for field in self.fields: +            if self._kconv(field[0]) == key_kconv: +                if values: +                    new_fields.append( +                        (key, values.pop(0)) +                    ) +            else: +                new_fields.append(field) +        while values: +            new_fields.append( +                (key, values.pop(0)) +            ) +        self.fields = tuple(new_fields) + +    def add(self, key, value): +        """ +        Add an additional value for the given key at the bottom. +        """ +        self.insert(len(self.fields), key, value) + +    def insert(self, index, key, value): +        """ +        Insert an additional value for the given key at the specified position. +        """ +        item = (key, value) +        self.fields = self.fields[:index] + (item,) + self.fields[index:] + +    def keys(self, multi=False): +        """ +        Get all keys. + +        Args: +            multi(bool): +                If True, one key per value will be returned. +                If False, duplicate keys will only be returned once. +        """ +        return ( +            k +            for k, _ in self.items(multi) +        ) + +    def values(self, multi=False): +        """ +        Get all values. + +        Args: +            multi(bool): +                If True, all values will be returned. +                If False, only the first value per key will be returned. +        """ +        return ( +            v +            for _, v in self.items(multi) +        ) + +    def items(self, multi=False): +        """ +        Get all (key, value) tuples. + +        Args: +            multi(bool): +                If True, all (key, value) pairs will be returned +                If False, only the first (key, value) pair per unique key will be returned. +        """ +        if multi: +            return self.fields +        else: +            return super(_MultiDict, self).items() + +    def to_dict(self): +        """ +        Get the MultiDict as a plain Python dict. +        Keys with multiple values are returned as lists. + +        Example: + +        .. code-block:: python + +            # Simple dict with duplicate values. +            >>> d +            MultiDictView[("name", "value"), ("a", "false"), ("a", "42")] +            >>> d.to_dict() +            { +                "name": "value", +                "a": ["false", "42"] +            } +        """ +        d = {} +        for key in self: +            values = self.get_all(key) +            if len(values) == 1: +                d[key] = values[0] +            else: +                d[key] = values +        return d + +    def get_state(self): +        return self.fields + +    def set_state(self, state): +        self.fields = tuple(tuple(x) for x in state) + +    @classmethod +    def from_state(cls, state): +        return cls(tuple(x) for x in state) + + +class MultiDict(_MultiDict): +    def __init__(self, fields=None): +        super(MultiDict, self).__init__() +        self.fields = tuple(fields) if fields else tuple()  # type: Tuple[Tuple[bytes, bytes], ...] + + +@six.add_metaclass(ABCMeta) +class ImmutableMultiDict(MultiDict): +    def _immutable(self, *_): +        raise TypeError('{} objects are immutable'.format(self.__class__.__name__)) + +    __delitem__ = set_all = insert = _immutable + +    def with_delitem(self, key): +        """ +        Returns: +            An updated ImmutableMultiDict. The original object will not be modified. +        """ +        ret = self.copy() +        super(ImmutableMultiDict, ret).__delitem__(key) +        return ret + +    def with_set_all(self, key, values): +        """ +        Returns: +            An updated ImmutableMultiDict. The original object will not be modified. +        """ +        ret = self.copy() +        super(ImmutableMultiDict, ret).set_all(key, values) +        return ret + +    def with_insert(self, index, key, value): +        """ +        Returns: +            An updated ImmutableMultiDict. The original object will not be modified. +        """ +        ret = self.copy() +        super(ImmutableMultiDict, ret).insert(index, key, value) +        return ret + + +class MultiDictView(_MultiDict): +    """ +    The MultiDictView provides the MultiDict interface over calculated data. +    The view itself contains no state - data is retrieved from the parent on +    request, and stored back to the parent on change. +    """ +    def __init__(self, getter, setter): +        self._getter = getter +        self._setter = setter +        super(MultiDictView, self).__init__() + +    @staticmethod +    def _kconv(key): +        # All request-attributes are case-sensitive. +        return key + +    @staticmethod +    def _reduce_values(values): +        # We just return the first element if +        # multiple elements exist with the same key. +        return values[0] + +    @property +    def fields(self): +        return self._getter() + +    @fields.setter +    def fields(self, value): +        return self._setter(value) diff --git a/netlib/utils.py b/netlib/utils.py index be2701a0..7499f71f 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -51,17 +51,6 @@ def always_bytes(unicode_or_bytes, *encode_args):      return unicode_or_bytes -def always_byte_args(*encode_args): -    """Decorator that transparently encodes all arguments passed as unicode""" -    def decorator(fun): -        def _fun(*args, **kwargs): -            args = [always_bytes(arg, *encode_args) for arg in args] -            kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)} -            return fun(*args, **kwargs) -        return _fun -    return decorator - -  def native(s, *encoding_opts):      """      Convert :py:class:`bytes` or :py:class:`unicode` to the native diff --git a/pathod/utils.py b/pathod/utils.py index 1e5bd9a4..d1e2dd00 100644 --- a/pathod/utils.py +++ b/pathod/utils.py @@ -1,5 +1,6 @@  import os  import sys +import netlib.utils  SIZE_UNITS = dict( @@ -75,27 +76,7 @@ def escape_unprintables(s):      return s -class Data(object): - -    def __init__(self, name): -        m = __import__(name) -        dirname, _ = os.path.split(m.__file__) -        self.dirname = os.path.abspath(dirname) - -    def path(self, path): -        """ -            Returns a path to the package data housed at 'path' under this -            module.Path can be a path to a file, or to a directory. - -            This function will raise ValueError if the path does not exist. -        """ -        fullpath = os.path.join(self.dirname, path) -        if not os.path.exists(fullpath): -            raise ValueError("dataPath: %s does not exist." % fullpath) -        return fullpath - - -data = Data(__name__) +data = netlib.utils.Data(__name__)  def daemonize(stdin='/dev/null', stdout='/dev/null', stderr='/dev/null'):  # pragma: no cover diff --git a/test/mitmproxy/test_examples.py b/test/mitmproxy/test_examples.py index c401a6b9..c4b06f4b 100644 --- a/test/mitmproxy/test_examples.py +++ b/test/mitmproxy/test_examples.py @@ -5,11 +5,12 @@ from contextlib import contextmanager  from mitmproxy import utils, script  from mitmproxy.proxy import config +import netlib.utils  from netlib import tutils as netutils  from netlib.http import Headers  from . import tservers, tutils -example_dir = utils.Data(__name__).path("../../examples") +example_dir = netlib.utils.Data(__name__).path("../../examples")  class DummyContext(object): @@ -94,14 +95,22 @@ def test_modify_form():      flow = tutils.tflow(req=netutils.treq(headers=form_header))      with example("modify_form.py") as ex:          ex.run("request", flow) -        assert flow.request.urlencoded_form["mitmproxy"] == ["rocks"] +        assert flow.request.urlencoded_form["mitmproxy"] == "rocks" + +        flow.request.headers["content-type"] = "" +        ex.run("request", flow) +        assert list(flow.request.urlencoded_form.items()) == [("foo", "bar")]  def test_modify_querystring():      flow = tutils.tflow(req=netutils.treq(path="/search?q=term"))      with example("modify_querystring.py") as ex:          ex.run("request", flow) -        assert flow.request.query["mitmproxy"] == ["rocks"] +        assert flow.request.query["mitmproxy"] == "rocks" + +        flow.request.path = "/" +        ex.run("request", flow) +        assert flow.request.query["mitmproxy"] == "rocks"  def test_modify_response_body(): diff --git a/test/mitmproxy/test_flow.py b/test/mitmproxy/test_flow.py index b9c6a2f6..bf417423 100644 --- a/test/mitmproxy/test_flow.py +++ b/test/mitmproxy/test_flow.py @@ -1067,60 +1067,6 @@ class TestRequest:          assert r.url == "https://address:22/path"          assert r.pretty_url == "https://foo.com:22/path" -    def test_path_components(self): -        r = HTTPRequest.wrap(netlib.tutils.treq()) -        r.path = "/" -        assert r.get_path_components() == [] -        r.path = "/foo/bar" -        assert r.get_path_components() == ["foo", "bar"] -        q = odict.ODict() -        q["test"] = ["123"] -        r.set_query(q) -        assert r.get_path_components() == ["foo", "bar"] - -        r.set_path_components([]) -        assert r.get_path_components() == [] -        r.set_path_components(["foo"]) -        assert r.get_path_components() == ["foo"] -        r.set_path_components(["/oo"]) -        assert r.get_path_components() == ["/oo"] -        assert "%2F" in r.path - -    def test_getset_form_urlencoded(self): -        d = odict.ODict([("one", "two"), ("three", "four")]) -        r = HTTPRequest.wrap(netlib.tutils.treq(content=netlib.utils.urlencode(d.lst))) -        r.headers["content-type"] = "application/x-www-form-urlencoded" -        assert r.get_form_urlencoded() == d - -        d = odict.ODict([("x", "y")]) -        r.set_form_urlencoded(d) -        assert r.get_form_urlencoded() == d - -        r.headers["content-type"] = "foo" -        assert not r.get_form_urlencoded() - -    def test_getset_query(self): -        r = HTTPRequest.wrap(netlib.tutils.treq()) -        r.path = "/foo?x=y&a=b" -        q = r.get_query() -        assert q.lst == [("x", "y"), ("a", "b")] - -        r.path = "/" -        q = r.get_query() -        assert not q - -        r.path = "/?adsfa" -        q = r.get_query() -        assert q.lst == [("adsfa", "")] - -        r.path = "/foo?x=y&a=b" -        assert r.get_query() -        r.set_query(odict.ODict([])) -        assert not r.get_query() -        qv = odict.ODict([("a", "b"), ("c", "d")]) -        r.set_query(qv) -        assert r.get_query() == qv -      def test_anticache(self):          r = HTTPRequest.wrap(netlib.tutils.treq())          r.headers = Headers() diff --git a/test/mitmproxy/test_flow_export.py b/test/mitmproxy/test_flow_export.py index 035f07b7..c252c5bd 100644 --- a/test/mitmproxy/test_flow_export.py +++ b/test/mitmproxy/test_flow_export.py @@ -21,7 +21,7 @@ def python_equals(testdata, text):      assert clean_blanks(text).rstrip() == clean_blanks(d).rstrip() -req_get = lambda: netlib.tutils.treq(method='GET', content='') +req_get = lambda: netlib.tutils.treq(method='GET', content='', path=b"/path?a=foo&a=bar&b=baz")  req_post = lambda: netlib.tutils.treq(method='POST', headers=None) @@ -31,7 +31,7 @@ req_patch = lambda: netlib.tutils.treq(method='PATCH', path=b"/path?query=param"  class TestExportCurlCommand():      def test_get(self):          flow = tutils.tflow(req=req_get()) -        result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path'""" +        result = """curl -H 'header:qvalue' -H 'content-length:7' 'http://address/path?a=foo&a=bar&b=baz'"""          assert flow_export.curl_command(flow) == result      def test_post(self): @@ -70,7 +70,7 @@ class TestRawRequest():      def test_get(self):          flow = tutils.tflow(req=req_get())          result = dedent(""" -            GET /path HTTP/1.1\r +            GET /path?a=foo&a=bar&b=baz HTTP/1.1\r              header: qvalue\r              content-length: 7\r              host: address:22\r diff --git a/test/mitmproxy/test_flow_export/locust_get.py b/test/mitmproxy/test_flow_export/locust_get.py index 72d5932a..632d5d53 100644 --- a/test/mitmproxy/test_flow_export/locust_get.py +++ b/test/mitmproxy/test_flow_export/locust_get.py @@ -14,10 +14,16 @@ class UserBehavior(TaskSet):              'content-length': '7',          } +        params = { +            'a': ['foo', 'bar'], +            'b': 'baz', +        } +          self.response = self.client.request(              method='GET',              url=url,              headers=headers, +            params=params,          )      ### Additional tasks can go here ### diff --git a/test/mitmproxy/test_flow_export/locust_task_get.py b/test/mitmproxy/test_flow_export/locust_task_get.py index 76f144fa..03821cd8 100644 --- a/test/mitmproxy/test_flow_export/locust_task_get.py +++ b/test/mitmproxy/test_flow_export/locust_task_get.py @@ -7,8 +7,14 @@              'content-length': '7',          } +        params = { +            'a': ['foo', 'bar'], +            'b': 'baz', +        } +          self.response = self.client.request(              method='GET',              url=url,              headers=headers, +            params=params,          ) diff --git a/test/mitmproxy/test_flow_export/python_get.py b/test/mitmproxy/test_flow_export/python_get.py index ee3f48eb..af8f7c81 100644 --- a/test/mitmproxy/test_flow_export/python_get.py +++ b/test/mitmproxy/test_flow_export/python_get.py @@ -7,10 +7,16 @@ headers = {      'content-length': '7',  } +params = { +    'a': ['foo', 'bar'], +    'b': 'baz', +} +  response = requests.request(      method='GET',      url=url,      headers=headers, +    params=params,  )  print(response.text) diff --git a/test/mitmproxy/tutils.py b/test/mitmproxy/tutils.py index d51ac185..2dfd710e 100644 --- a/test/mitmproxy/tutils.py +++ b/test/mitmproxy/tutils.py @@ -8,6 +8,7 @@ from contextlib import contextmanager  from unittest.case import SkipTest +import netlib.utils  import netlib.tutils  from mitmproxy import utils, controller  from mitmproxy.models import ( @@ -163,4 +164,4 @@ def capture_stderr(command, *args, **kwargs):      sys.stderr = out -test_data = utils.Data(__name__) +test_data = netlib.utils.Data(__name__) diff --git a/test/netlib/http/http1/test_read.py b/test/netlib/http/http1/test_read.py index 90234070..d8106904 100644 --- a/test/netlib/http/http1/test_read.py +++ b/test/netlib/http/http1/test_read.py @@ -261,7 +261,7 @@ class TestReadHeaders(object):              b"\r\n"          )          headers = self._read(data) -        assert headers.fields == [[b"Header", b"one"], [b"Header2", b"two"]] +        assert headers.fields == ((b"Header", b"one"), (b"Header2", b"two"))      def test_read_multi(self):          data = ( @@ -270,7 +270,7 @@ class TestReadHeaders(object):              b"\r\n"          )          headers = self._read(data) -        assert headers.fields == [[b"Header", b"one"], [b"Header", b"two"]] +        assert headers.fields == ((b"Header", b"one"), (b"Header", b"two"))      def test_read_continued(self):          data = ( @@ -280,7 +280,7 @@ class TestReadHeaders(object):              b"\r\n"          )          headers = self._read(data) -        assert headers.fields == [[b"Header", b"one\r\n two"], [b"Header2", b"three"]] +        assert headers.fields == ((b"Header", b"one\r\n two"), (b"Header2", b"three"))      def test_read_continued_err(self):          data = b"\tfoo: bar\r\n" @@ -300,7 +300,7 @@ class TestReadHeaders(object):      def test_read_empty_value(self):          data = b"bar:"          headers = self._read(data) -        assert headers.fields == [[b"bar", b""]] +        assert headers.fields == ((b"bar", b""),)  def test_read_chunked():      req = treq(content=None) diff --git a/test/netlib/http/http2/test_connections.py b/test/netlib/http/http2/test_connections.py index 7b003067..7d240c0e 100644 --- a/test/netlib/http/http2/test_connections.py +++ b/test/netlib/http/http2/test_connections.py @@ -312,7 +312,7 @@ class TestReadRequest(tservers.ServerTestBase):          req = protocol.read_request(NotImplemented)          assert req.stream_id -        assert req.headers.fields == [[b':method', b'GET'], [b':path', b'/'], [b':scheme', b'https']] +        assert req.headers.fields == ((b':method', b'GET'), (b':path', b'/'), (b':scheme', b'https'))          assert req.content == b'foobar' @@ -418,7 +418,7 @@ class TestReadResponse(tservers.ServerTestBase):          assert resp.http_version == "HTTP/2.0"          assert resp.status_code == 200          assert resp.reason == '' -        assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] +        assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))          assert resp.content == b'foobar'          assert resp.timestamp_end @@ -445,7 +445,7 @@ class TestReadEmptyResponse(tservers.ServerTestBase):          assert resp.http_version == "HTTP/2.0"          assert resp.status_code == 200          assert resp.reason == '' -        assert resp.headers.fields == [[b':status', b'200'], [b'etag', b'foobar']] +        assert resp.headers.fields == ((b':status', b'200'), (b'etag', b'foobar'))          assert resp.content == b'' diff --git a/test/netlib/http/test_cookies.py b/test/netlib/http/test_cookies.py index da28850f..6f84c4ce 100644 --- a/test/netlib/http/test_cookies.py +++ b/test/netlib/http/test_cookies.py @@ -128,10 +128,10 @@ def test_cookie_roundtrips():      ]      for s, lst in pairs:          ret = cookies.parse_cookie_header(s) -        assert ret.lst == lst +        assert ret == lst          s2 = cookies.format_cookie_header(ret)          ret = cookies.parse_cookie_header(s2) -        assert ret.lst == lst +        assert ret == lst  def test_parse_set_cookie_pairs(): @@ -197,24 +197,28 @@ def test_parse_set_cookie_header():          ],          [              "one=uno", -            ("one", "uno", []) +            ("one", "uno", ())          ],          [              "one=uno; foo=bar", -            ("one", "uno", [["foo", "bar"]]) -        ] +            ("one", "uno", (("foo", "bar"),)) +        ], +        [ +            "one=uno; foo=bar; foo=baz", +            ("one", "uno", (("foo", "bar"), ("foo", "baz"))) +        ],      ]      for s, expected in vals:          ret = cookies.parse_set_cookie_header(s)          if expected:              assert ret[0] == expected[0]              assert ret[1] == expected[1] -            assert ret[2].lst == expected[2] +            assert ret[2].items(multi=True) == expected[2]              s2 = cookies.format_set_cookie_header(*ret)              ret2 = cookies.parse_set_cookie_header(s2)              assert ret2[0] == expected[0]              assert ret2[1] == expected[1] -            assert ret2[2].lst == expected[2] +            assert ret2[2].items(multi=True) == expected[2]          else:              assert ret is None diff --git a/test/netlib/http/test_headers.py b/test/netlib/http/test_headers.py index 8c1db9dc..cd2ca9d1 100644 --- a/test/netlib/http/test_headers.py +++ b/test/netlib/http/test_headers.py @@ -5,10 +5,10 @@ from netlib.tutils import raises  class TestHeaders(object):      def _2host(self):          return Headers( -            [ -                [b"Host", b"example.com"], -                [b"host", b"example.org"] -            ] +            ( +                (b"Host", b"example.com"), +                (b"host", b"example.org") +            )          )      def test_init(self): @@ -38,20 +38,10 @@ class TestHeaders(object):          assert headers["Host"] == "example.com"          assert headers["Accept"] == "text/plain" -        with raises(ValueError): +        with raises(TypeError):              Headers([[b"Host", u"not-bytes"]]) -    def test_getitem(self): -        headers = Headers(Host="example.com") -        assert headers["Host"] == "example.com" -        assert headers["host"] == "example.com" -        with raises(KeyError): -            _ = headers["Accept"] - -        headers = self._2host() -        assert headers["Host"] == "example.com, example.org" - -    def test_str(self): +    def test_bytes(self):          headers = Headers(Host="example.com")          assert bytes(headers) == b"Host: example.com\r\n" @@ -64,93 +54,6 @@ class TestHeaders(object):          headers = Headers()          assert bytes(headers) == b"" -    def test_setitem(self): -        headers = Headers() -        headers["Host"] = "example.com" -        assert "Host" in headers -        assert "host" in headers -        assert headers["Host"] == "example.com" - -        headers["host"] = "example.org" -        assert "Host" in headers -        assert "host" in headers -        assert headers["Host"] == "example.org" - -        headers["accept"] = "text/plain" -        assert len(headers) == 2 -        assert "Accept" in headers -        assert "Host" in headers - -        headers = self._2host() -        assert len(headers.fields) == 2 -        headers["Host"] = "example.com" -        assert len(headers.fields) == 1 -        assert "Host" in headers - -    def test_delitem(self): -        headers = Headers(Host="example.com") -        assert len(headers) == 1 -        del headers["host"] -        assert len(headers) == 0 -        try: -            del headers["host"] -        except KeyError: -            assert True -        else: -            assert False - -        headers = self._2host() -        del headers["Host"] -        assert len(headers) == 0 - -    def test_keys(self): -        headers = Headers(Host="example.com") -        assert list(headers.keys()) == ["Host"] - -        headers = self._2host() -        assert list(headers.keys()) == ["Host"] - -    def test_eq_ne(self): -        headers1 = Headers(Host="example.com") -        headers2 = Headers(host="example.com") -        assert not (headers1 == headers2) -        assert headers1 != headers2 - -        headers1 = Headers(Host="example.com") -        headers2 = Headers(Host="example.com") -        assert headers1 == headers2 -        assert not (headers1 != headers2) - -        assert headers1 != 42 - -    def test_get_all(self): -        headers = self._2host() -        assert headers.get_all("host") == ["example.com", "example.org"] -        assert headers.get_all("accept") == [] - -    def test_set_all(self): -        headers = Headers(Host="example.com") -        headers.set_all("Accept", ["text/plain"]) -        assert len(headers) == 2 -        assert "accept" in headers - -        headers = self._2host() -        headers.set_all("Host", ["example.org"]) -        assert headers["host"] == "example.org" - -        headers.set_all("Host", ["example.org", "example.net"]) -        assert headers["host"] == "example.org, example.net" - -    def test_state(self): -        headers = self._2host() -        assert len(headers.get_state()) == 2 -        assert headers == Headers.from_state(headers.get_state()) - -        headers2 = Headers() -        assert headers != headers2 -        headers2.set_state(headers.get_state()) -        assert headers == headers2 -      def test_replace_simple(self):          headers = Headers(Host="example.com", Accept="text/plain")          replacements = headers.replace("Host: ", "X-Host: ") diff --git a/test/netlib/http/test_request.py b/test/netlib/http/test_request.py index 7ed6bd0f..fae7aefe 100644 --- a/test/netlib/http/test_request.py +++ b/test/netlib/http/test_request.py @@ -3,16 +3,14 @@ from __future__ import absolute_import, print_function, division  import six -from netlib import utils  from netlib.http import Headers -from netlib.odict import ODict  from netlib.tutils import treq, raises  from .test_message import _test_decoded_attr, _test_passthrough_attr  class TestRequestData(object):      def test_init(self): -        with raises(ValueError if six.PY2 else TypeError): +        with raises(ValueError):              treq(headers="foobar")          assert isinstance(treq(headers=None).headers, Headers) @@ -158,16 +156,17 @@ class TestRequestUtils(object):      def test_get_query(self):          request = treq() -        assert request.query is None +        assert not request.query          request.url = "http://localhost:80/foo?bar=42" -        assert request.query.lst == [("bar", "42")] +        assert dict(request.query) == {"bar": "42"}      def test_set_query(self): -        request = treq(host=b"foo", headers = Headers(host=b"bar")) -        request.query = ODict([]) -        assert request.host == "foo" -        assert request.headers["host"] == "bar" +        request = treq() +        assert not request.query +        request.query["foo"] = "bar" +        assert request.query["foo"] == "bar" +        assert request.path == "/path?foo=bar"      def test_get_cookies_none(self):          request = treq() @@ -177,47 +176,50 @@ class TestRequestUtils(object):      def test_get_cookies_single(self):          request = treq()          request.headers = Headers(cookie="cookiename=cookievalue") -        result = request.cookies -        assert len(result) == 1 -        assert result['cookiename'] == ['cookievalue'] +        assert len(request.cookies) == 1 +        assert request.cookies['cookiename'] == 'cookievalue'      def test_get_cookies_double(self):          request = treq()          request.headers = Headers(cookie="cookiename=cookievalue;othercookiename=othercookievalue")          result = request.cookies          assert len(result) == 2 -        assert result['cookiename'] == ['cookievalue'] -        assert result['othercookiename'] == ['othercookievalue'] +        assert result['cookiename'] == 'cookievalue' +        assert result['othercookiename'] == 'othercookievalue'      def test_get_cookies_withequalsign(self):          request = treq()          request.headers = Headers(cookie="cookiename=coo=kievalue;othercookiename=othercookievalue")          result = request.cookies          assert len(result) == 2 -        assert result['cookiename'] == ['coo=kievalue'] -        assert result['othercookiename'] == ['othercookievalue'] +        assert result['cookiename'] == 'coo=kievalue' +        assert result['othercookiename'] == 'othercookievalue'      def test_set_cookies(self):          request = treq()          request.headers = Headers(cookie="cookiename=cookievalue")          result = request.cookies -        result["cookiename"] = ["foo"] -        request.cookies = result -        assert request.cookies["cookiename"] == ["foo"] +        result["cookiename"] = "foo" +        assert request.cookies["cookiename"] == "foo"      def test_get_path_components(self):          request = treq(path=b"/foo/bar") -        assert request.path_components == ["foo", "bar"] +        assert request.path_components == ("foo", "bar")      def test_set_path_components(self): -        request = treq(host=b"foo", headers = Headers(host=b"bar")) +        request = treq()          request.path_components = ["foo", "baz"]          assert request.path == "/foo/baz" +          request.path_components = []          assert request.path == "/" -        request.query = ODict([]) -        assert request.host == "foo" -        assert request.headers["host"] == "bar" + +        request.path_components = ["foo", "baz"] +        request.query["hello"] = "hello" +        assert request.path_components == ("foo", "baz") + +        request.path_components = ["abc"] +        assert request.path == "/abc?hello=hello"      def test_anticache(self):          request = treq() @@ -246,26 +248,21 @@ class TestRequestUtils(object):          assert "gzip" in request.headers["Accept-Encoding"]      def test_get_urlencoded_form(self): -        request = treq(content="foobar") -        assert request.urlencoded_form is None +        request = treq(content="foobar=baz") +        assert not request.urlencoded_form          request.headers["Content-Type"] = "application/x-www-form-urlencoded" -        assert request.urlencoded_form == ODict(utils.urldecode(request.content)) +        assert list(request.urlencoded_form.items()) == [("foobar", "baz")]      def test_set_urlencoded_form(self):          request = treq() -        request.urlencoded_form = ODict([('foo', 'bar'), ('rab', 'oof')]) +        request.urlencoded_form = [('foo', 'bar'), ('rab', 'oof')]          assert request.headers["Content-Type"] == "application/x-www-form-urlencoded"          assert request.content      def test_get_multipart_form(self):          request = treq(content="foobar") -        assert request.multipart_form is None +        assert not request.multipart_form          request.headers["Content-Type"] = "multipart/form-data" -        assert request.multipart_form == ODict( -            utils.multipartdecode( -                request.headers, -                request.content -            ) -        ) +        assert list(request.multipart_form.items()) == [] diff --git a/test/netlib/http/test_response.py b/test/netlib/http/test_response.py index 5440176c..cfd093d4 100644 --- a/test/netlib/http/test_response.py +++ b/test/netlib/http/test_response.py @@ -6,6 +6,7 @@ import six  import time  from netlib.http import Headers +from netlib.http.cookies import CookieAttrs  from netlib.odict import ODict, ODictCaseless  from netlib.tutils import raises, tresp  from .test_message import _test_passthrough_attr, _test_decoded_attr @@ -13,7 +14,7 @@ from .test_message import _test_passthrough_attr, _test_decoded_attr  class TestResponseData(object):      def test_init(self): -        with raises(ValueError if six.PY2 else TypeError): +        with raises(ValueError):              tresp(headers="foobar")          assert isinstance(tresp(headers=None).headers, Headers) @@ -56,7 +57,7 @@ class TestResponseUtils(object):          result = resp.cookies          assert len(result) == 1          assert "cookiename" in result -        assert result["cookiename"][0] == ["cookievalue", ODict()] +        assert result["cookiename"] == ("cookievalue", CookieAttrs())      def test_get_cookies_with_parameters(self):          resp = tresp() @@ -64,13 +65,13 @@ class TestResponseUtils(object):          result = resp.cookies          assert len(result) == 1          assert "cookiename" in result -        assert result["cookiename"][0][0] == "cookievalue" -        attrs = result["cookiename"][0][1] +        assert result["cookiename"][0] == "cookievalue" +        attrs = result["cookiename"][1]          assert len(attrs) == 4 -        assert attrs["domain"] == ["example.com"] -        assert attrs["expires"] == ["Wed Oct  21 16:29:41 2015"] -        assert attrs["path"] == ["/"] -        assert attrs["httponly"] == [None] +        assert attrs["domain"] == "example.com" +        assert attrs["expires"] == "Wed Oct  21 16:29:41 2015" +        assert attrs["path"] == "/" +        assert attrs["httponly"] is None      def test_get_cookies_no_value(self):          resp = tresp() @@ -78,8 +79,8 @@ class TestResponseUtils(object):          result = resp.cookies          assert len(result) == 1          assert "cookiename" in result -        assert result["cookiename"][0][0] == "" -        assert len(result["cookiename"][0][1]) == 2 +        assert result["cookiename"][0] == "" +        assert len(result["cookiename"][1]) == 2      def test_get_cookies_twocookies(self):          resp = tresp() @@ -90,19 +91,16 @@ class TestResponseUtils(object):          result = resp.cookies          assert len(result) == 2          assert "cookiename" in result -        assert result["cookiename"][0] == ["cookievalue", ODict()] +        assert result["cookiename"] == ("cookievalue", CookieAttrs())          assert "othercookie" in result -        assert result["othercookie"][0] == ["othervalue", ODict()] +        assert result["othercookie"] == ("othervalue", CookieAttrs())      def test_set_cookies(self):          resp = tresp() -        v = resp.cookies -        v.add("foo", ["bar", ODictCaseless()]) -        resp.cookies = v +        resp.cookies["foo"] = ("bar", {}) -        v = resp.cookies -        assert len(v) == 1 -        assert v["foo"] == [["bar", ODictCaseless()]] +        assert len(resp.cookies) == 1 +        assert resp.cookies["foo"] == ("bar", CookieAttrs())      def test_refresh(self):          r = tresp() diff --git a/test/netlib/test_multidict.py b/test/netlib/test_multidict.py new file mode 100644 index 00000000..5bb65e3f --- /dev/null +++ b/test/netlib/test_multidict.py @@ -0,0 +1,239 @@ +from netlib import tutils +from netlib.multidict import MultiDict, ImmutableMultiDict, MultiDictView + + +class _TMulti(object): +    @staticmethod +    def _reduce_values(values): +        return values[0] + +    @staticmethod +    def _kconv(key): +        return key.lower() + + +class TMultiDict(_TMulti, MultiDict): +    pass + + +class TImmutableMultiDict(_TMulti, ImmutableMultiDict): +    pass + + +class TestMultiDict(object): +    @staticmethod +    def _multi(): +        return TMultiDict(( +            ("foo", "bar"), +            ("bar", "baz"), +            ("Bar", "bam") +        )) + +    def test_init(self): +        md = TMultiDict() +        assert len(md) == 0 + +        md = TMultiDict([("foo", "bar")]) +        assert len(md) == 1 +        assert md.fields == (("foo", "bar"),) + +    def test_repr(self): +        assert repr(self._multi()) == ( +            "TMultiDict[('foo', 'bar'), ('bar', 'baz'), ('Bar', 'bam')]" +        ) + +    def test_getitem(self): +        md = TMultiDict([("foo", "bar")]) +        assert "foo" in md +        assert "Foo" in md +        assert md["foo"] == "bar" + +        with tutils.raises(KeyError): +            _ = md["bar"] + +        md_multi = TMultiDict( +            [("foo", "a"), ("foo", "b")] +        ) +        assert md_multi["foo"] == "a" + +    def test_setitem(self): +        md = TMultiDict() +        md["foo"] = "bar" +        assert md.fields == (("foo", "bar"),) + +        md["foo"] = "baz" +        assert md.fields == (("foo", "baz"),) + +        md["bar"] = "bam" +        assert md.fields == (("foo", "baz"), ("bar", "bam")) + +    def test_delitem(self): +        md = self._multi() +        del md["foo"] +        assert "foo" not in md +        assert "bar" in md + +        with tutils.raises(KeyError): +            del md["foo"] + +        del md["bar"] +        assert md.fields == () + +    def test_iter(self): +        md = self._multi() +        assert list(md.__iter__()) == ["foo", "bar"] + +    def test_len(self): +        md = TMultiDict() +        assert len(md) == 0 + +        md = self._multi() +        assert len(md) == 2 + +    def test_eq(self): +        assert TMultiDict() == TMultiDict() +        assert not (TMultiDict() == 42) + +        md1 = self._multi() +        md2 = self._multi() +        assert md1 == md2 +        md1.fields = md1.fields[1:] + md1.fields[:1] +        assert not (md1 == md2) + +    def test_ne(self): +        assert not TMultiDict() != TMultiDict() +        assert TMultiDict() != self._multi() +        assert TMultiDict() != 42 + +    def test_get_all(self): +        md = self._multi() +        assert md.get_all("foo") == ["bar"] +        assert md.get_all("bar") == ["baz", "bam"] +        assert md.get_all("baz") == [] + +    def test_set_all(self): +        md = TMultiDict() +        md.set_all("foo", ["bar", "baz"]) +        assert md.fields == (("foo", "bar"), ("foo", "baz")) + +        md = TMultiDict(( +            ("a", "b"), +            ("x", "x"), +            ("c", "d"), +            ("X", "x"), +            ("e", "f"), +        )) +        md.set_all("x", ["1", "2", "3"]) +        assert md.fields == ( +            ("a", "b"), +            ("x", "1"), +            ("c", "d"), +            ("x", "2"), +            ("e", "f"), +            ("x", "3"), +        ) +        md.set_all("x", ["4"]) +        assert md.fields == ( +            ("a", "b"), +            ("x", "4"), +            ("c", "d"), +            ("e", "f"), +        ) + +    def test_add(self): +        md = self._multi() +        md.add("foo", "foo") +        assert md.fields == ( +            ("foo", "bar"), +            ("bar", "baz"), +            ("Bar", "bam"), +            ("foo", "foo") +        ) + +    def test_insert(self): +        md = TMultiDict([("b", "b")]) +        md.insert(0, "a", "a") +        md.insert(2, "c", "c") +        assert md.fields == (("a", "a"), ("b", "b"), ("c", "c")) + +    def test_keys(self): +        md = self._multi() +        assert list(md.keys()) == ["foo", "bar"] +        assert list(md.keys(multi=True)) == ["foo", "bar", "Bar"] + +    def test_values(self): +        md = self._multi() +        assert list(md.values()) == ["bar", "baz"] +        assert list(md.values(multi=True)) == ["bar", "baz", "bam"] + +    def test_items(self): +        md = self._multi() +        assert list(md.items()) == [("foo", "bar"), ("bar", "baz")] +        assert list(md.items(multi=True)) == [("foo", "bar"), ("bar", "baz"), ("Bar", "bam")] + +    def test_to_dict(self): +        md = self._multi() +        assert md.to_dict() == { +            "foo": "bar", +            "bar": ["baz", "bam"] +        } + +    def test_state(self): +        md = self._multi() +        assert len(md.get_state()) == 3 +        assert md == TMultiDict.from_state(md.get_state()) + +        md2 = TMultiDict() +        assert md != md2 +        md2.set_state(md.get_state()) +        assert md == md2 + + +class TestImmutableMultiDict(object): +    def test_modify(self): +        md = TImmutableMultiDict() +        with tutils.raises(TypeError): +            md["foo"] = "bar" + +        with tutils.raises(TypeError): +            del md["foo"] + +        with tutils.raises(TypeError): +            md.add("foo", "bar") + +    def test_with_delitem(self): +        md = TImmutableMultiDict([("foo", "bar")]) +        assert md.with_delitem("foo").fields == () +        assert md.fields == (("foo", "bar"),) + +    def test_with_set_all(self): +        md = TImmutableMultiDict() +        assert md.with_set_all("foo", ["bar"]).fields == (("foo", "bar"),) +        assert md.fields == () + +    def test_with_insert(self): +        md = TImmutableMultiDict() +        assert md.with_insert(0, "foo", "bar").fields == (("foo", "bar"),) + + +class TParent(object): +    def __init__(self): +        self.vals = tuple() + +    def setter(self, vals): +        self.vals = vals + +    def getter(self): +        return self.vals + + +class TestMultiDictView(object): +    def test_modify(self): +        p = TParent() +        tv = MultiDictView(p.getter, p.setter) +        assert len(tv) == 0 +        tv["a"] = "b" +        assert p.vals == (("a", "b"),) +        tv["c"] = "b" +        assert p.vals == (("a", "b"), ("c", "b")) +        assert tv["a"] == "b" diff --git a/test/pathod/test_pathod.py b/test/pathod/test_pathod.py index 10f3b5a3..05a3962e 100644 --- a/test/pathod/test_pathod.py +++ b/test/pathod/test_pathod.py @@ -233,6 +233,7 @@ class CommonTests(tutils.DaemonTests):          # FIXME: Race Condition?          assert "Parse error" in self.d.text_log() +    @pytest.mark.skip(reason="race condition")      def test_websocket_frame_disconnect_error(self):          self.pathoc(["ws:/p/", "wf:b@10:d3"], ws_read_limit=0)          assert self.d.last_log() diff --git a/test/pathod/tutils.py b/test/pathod/tutils.py index 9739afde..f6ed3efb 100644 --- a/test/pathod/tutils.py +++ b/test/pathod/tutils.py @@ -116,7 +116,7 @@ tmpdir = netlib.tutils.tmpdir  raises = netlib.tutils.raises -test_data = utils.Data(__name__) +test_data = netlib.utils.Data(__name__)  def render(r, settings=language.Settings()): | 
