diff options
| -rw-r--r-- | netlib/http/__init__.py | 6 | ||||
| -rw-r--r-- | netlib/http/authentication.py | 10 | ||||
| -rw-r--r-- | netlib/http/headers.py | 205 | ||||
| -rw-r--r-- | netlib/http/http1/assemble.py | 6 | ||||
| -rw-r--r-- | netlib/http/http1/read.py | 14 | ||||
| -rw-r--r-- | netlib/http/models.py | 215 | ||||
| -rw-r--r-- | netlib/utils.py | 17 | ||||
| -rw-r--r-- | netlib/websockets/protocol.py | 14 | ||||
| -rw-r--r-- | test/http/http1/test_assemble.py | 6 | ||||
| -rw-r--r-- | test/http/http1/test_read.py | 22 | ||||
| -rw-r--r-- | test/http/test_authentication.py | 12 | ||||
| -rw-r--r-- | test/http/test_headers.py | 149 | ||||
| -rw-r--r-- | test/http/test_models.py | 152 | ||||
| -rw-r--r-- | test/test_utils.py | 20 | ||||
| -rw-r--r-- | test/websockets/test_websockets.py | 13 | 
15 files changed, 443 insertions, 418 deletions
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py index d72884b3..0ccf6b32 100644 --- a/netlib/http/__init__.py +++ b/netlib/http/__init__.py @@ -1,11 +1,13 @@  from __future__ import absolute_import, print_function, division -from .models import Request, Response, Headers +from .headers import Headers +from .models import Request, Response  from .models import ALPN_PROTO_HTTP1, ALPN_PROTO_H2  from .models import HDR_FORM_MULTIPART, HDR_FORM_URLENCODED, CONTENT_MISSING  from . import http1, http2  __all__ = [ -    "Request", "Response", "Headers", +    "Headers", +    "Request", "Response",      "ALPN_PROTO_HTTP1", "ALPN_PROTO_H2",      "HDR_FORM_MULTIPART", "HDR_FORM_URLENCODED", "CONTENT_MISSING",      "http1", "http2", diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py index 5831660b..d769abe5 100644 --- a/netlib/http/authentication.py +++ b/netlib/http/authentication.py @@ -9,18 +9,18 @@ def parse_http_basic_auth(s):          return None      scheme = words[0]      try: -        user = binascii.a2b_base64(words[1]) +        user = binascii.a2b_base64(words[1]).decode("utf8", "replace")      except binascii.Error:          return None -    parts = user.split(b':') +    parts = user.split(':')      if len(parts) != 2:          return None      return scheme, parts[0], parts[1]  def assemble_http_basic_auth(scheme, username, password): -    v = binascii.b2a_base64(username + b":" + password) -    return scheme + b" " + v +    v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii") +    return scheme + " " + v  class NullProxyAuth(object): @@ -69,7 +69,7 @@ class BasicProxyAuth(NullProxyAuth):          if not parts:              return False          scheme, username, password = parts -        if scheme.lower() != b'basic': +        if scheme.lower() != 'basic':              return False          if not self.password_manager.test(username, password):              return False diff --git a/netlib/http/headers.py b/netlib/http/headers.py new file mode 100644 index 00000000..1511ea2d --- /dev/null +++ b/netlib/http/headers.py @@ -0,0 +1,205 @@ +""" + +Unicode Handling +---------------- +See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/ +""" +from __future__ import absolute_import, print_function, division +import copy +try: +    from collections.abc import MutableMapping +except ImportError:  # Workaround for Python < 3.3 +    from collections import MutableMapping + + +import six + +from netlib.utils import always_byte_args + +if six.PY2: +    _native = lambda x: x +    _asbytes = lambda x: x +    _always_byte_args = lambda x: x +else: +    # While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded. +    _native = lambda x: x.decode("utf-8", "surrogateescape") +    _asbytes = lambda x: x.encode("utf-8", "surrogateescape") +    _always_byte_args = always_byte_args("utf-8", "surrogateescape") + + +class Headers(MutableMapping, object): +    """ +    Header class which allows both convenient access to individual headers as well as +    direct access to the underlying raw data. Provides a full dictionary interface. + +    Example: + +    .. code-block:: python + +        # Create header from a list of (header_name, header_value) tuples +        >>> h = Headers([ +                ["Host","example.com"], +                ["Accept","text/html"], +                ["accept","application/xml"] +            ]) + +        # Headers mostly behave like a normal dict. +        >>> h["Host"] +        "example.com" + +        # HTTP Headers are case insensitive +        >>> h["host"] +        "example.com" + +        # Multiple headers are folded into a single header as per RFC7230 +        >>> h["Accept"] +        "text/html, application/xml" + +        # Setting a header removes all existing headers with the same name. +        >>> h["Accept"] = "application/text" +        >>> h["Accept"] +        "application/text" + +        # str(h) returns a HTTP1 header block. +        >>> print(h) +        Host: example.com +        Accept: application/text + +        # For full control, the raw header fields can be accessed +        >>> h.fields + +        # Headers can also be crated from keyword arguments +        >>> h = Headers(host="example.com", content_type="application/xml") + +    Caveats: +        For use with the "Set-Cookie" header, see :py:meth:`get_all`. +    """ + +    @_always_byte_args +    def __init__(self, fields=None, **headers): +        """ +        Args: +            fields: (optional) list of ``(name, value)`` header tuples, +                e.g. ``[("Host","example.com")]``. All names and values must be bytes. +            **headers: Additional headers to set. Will overwrite existing values from `fields`. +                For convenience, underscores in header names will be transformed to dashes - +                this behaviour does not extend to other methods. +                If ``**headers`` contains multiple keys that have equal ``.lower()`` s, +                the behavior is undefined. +        """ +        self.fields = fields or [] + +        for name, value in self.fields: +            if not isinstance(name, bytes) or not isinstance(value, bytes): +                raise ValueError("Headers passed as fields must be bytes.") + +        # content_type -> content-type +        headers = { +            _asbytes(name).replace(b"_", b"-"): value +            for name, value in six.iteritems(headers) +        } +        self.update(headers) + +    def __bytes__(self): +        if self.fields: +            return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" +        else: +            return b"" + +    if six.PY2: +        __str__ = __bytes__ + +    @_always_byte_args +    def __getitem__(self, name): +        values = self.get_all(name) +        if not values: +            raise KeyError(name) +        return ", ".join(values) + +    @_always_byte_args +    def __setitem__(self, name, value): +        idx = self._index(name) + +        # To please the human eye, we insert at the same position the first existing header occured. +        if idx is not None: +            del self[name] +            self.fields.insert(idx, [name, value]) +        else: +            self.fields.append([name, value]) + +    @_always_byte_args +    def __delitem__(self, name): +        if name not in self: +            raise KeyError(name) +        name = name.lower() +        self.fields = [ +            field for field in self.fields +            if name != field[0].lower() +        ] + +    def __iter__(self): +        seen = set() +        for name, _ in self.fields: +            name_lower = name.lower() +            if name_lower not in seen: +                seen.add(name_lower) +                yield _native(name) + +    def __len__(self): +        return len(set(name.lower() for name, _ in self.fields)) + +    # __hash__ = object.__hash__ + +    def _index(self, name): +        name = name.lower() +        for i, field in enumerate(self.fields): +            if field[0].lower() == name: +                return i +        return None + +    def __eq__(self, other): +        if isinstance(other, Headers): +            return self.fields == other.fields +        return False + +    def __ne__(self, other): +        return not self.__eq__(other) + +    @_always_byte_args +    def get_all(self, name): +        """ +        Like :py:meth:`get`, but does not fold multiple headers into a single one. +        This is useful for Set-Cookie headers, which do not support folding. + +        See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 +        """ +        name_lower = name.lower() +        values = [_native(value) for n, value in self.fields if n.lower() == name_lower] +        return values + +    @_always_byte_args +    def set_all(self, name, values): +        """ +        Explicitly set multiple headers for the given key. +        See: :py:meth:`get_all` +        """ +        values = map(_asbytes, values)  # _always_byte_args does not fix lists +        if name in self: +            del self[name] +        self.fields.extend( +            [name, value] for value in values +        ) + +    def copy(self): +        return Headers(copy.copy(self.fields)) + +    # Implement the StateObject protocol from mitmproxy +    def get_state(self, short=False): +        return tuple(tuple(field) for field in self.fields) + +    def load_state(self, state): +        self.fields = [list(field) for field in state] + +    @classmethod +    def from_state(cls, state): +        return cls([list(field) for field in state])
\ No newline at end of file diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py index c2b60a0f..88aeac05 100644 --- a/netlib/http/http1/assemble.py +++ b/netlib/http/http1/assemble.py @@ -35,7 +35,7 @@ def assemble_response_head(response):  def assemble_body(headers, body_chunks): -    if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): +    if "chunked" in headers.get("transfer-encoding", "").lower():          for chunk in body_chunks:              if chunk:                  yield b"%x\r\n%s\r\n" % (len(chunk), chunk) @@ -76,8 +76,8 @@ def _assemble_request_line(request, form=None):  def _assemble_request_headers(request):      headers = request.headers.copy() -    if b"host" not in headers and request.scheme and request.host and request.port: -        headers[b"Host"] = utils.hostport( +    if "host" not in headers and request.scheme and request.host and request.port: +        headers["host"] = utils.hostport(              request.scheme,              request.host,              request.port diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py index c6760ff3..4c898348 100644 --- a/netlib/http/http1/read.py +++ b/netlib/http/http1/read.py @@ -146,11 +146,11 @@ def connection_close(http_version, headers):          according to RFC 2616 Section 8.1.      """      # At first, check if we have an explicit Connection header. -    if b"connection" in headers: +    if "connection" in headers:          tokens = utils.get_header_tokens(headers, "connection") -        if b"close" in tokens: +        if "close" in tokens:              return True -        elif b"keep-alive" in tokens: +        elif "keep-alive" in tokens:              return False      # If we don't have a Connection header, HTTP 1.1 connections are assumed to @@ -181,7 +181,7 @@ def expected_http_body_size(request, response=None):          is_request = False      if is_request: -        if headers.get(b"expect", b"").lower() == b"100-continue": +        if headers.get("expect", "").lower() == "100-continue":              return 0      else:          if request.method.upper() == b"HEAD": @@ -193,11 +193,11 @@ def expected_http_body_size(request, response=None):          if response_code in (204, 304):              return 0 -    if b"chunked" in headers.get(b"transfer-encoding", b"").lower(): +    if "chunked" in headers.get("transfer-encoding", "").lower():          return None -    if b"content-length" in headers: +    if "content-length" in headers:          try: -            size = int(headers[b"content-length"]) +            size = int(headers["content-length"])              if size < 0:                  raise ValueError()              return size diff --git a/netlib/http/models.py b/netlib/http/models.py index 512a764d..55664533 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -1,201 +1,22 @@ -from __future__ import absolute_import, print_function, division -import copy +  from ..odict import ODict  from .. import utils, encoding -from ..utils import always_bytes, always_byte_args, native +from ..utils import always_bytes, native  from . import cookies +from .headers import Headers -import six  from six.moves import urllib -try: -    from collections import MutableMapping -except ImportError: -    from collections.abc import MutableMapping  # TODO: Move somewhere else?  ALPN_PROTO_HTTP1 = b'http/1.1'  ALPN_PROTO_H2 = b'h2' -HDR_FORM_URLENCODED = b"application/x-www-form-urlencoded" -HDR_FORM_MULTIPART = b"multipart/form-data" +HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" +HDR_FORM_MULTIPART = "multipart/form-data"  CONTENT_MISSING = 0 -class Headers(MutableMapping, object): -    """ -    Header class which allows both convenient access to individual headers as well as -    direct access to the underlying raw data. Provides a full dictionary interface. - -    Example: - -    .. code-block:: python - -        # Create header from a list of (header_name, header_value) tuples -        >>> h = Headers([ -                ["Host","example.com"], -                ["Accept","text/html"], -                ["accept","application/xml"] -            ]) - -        # Headers mostly behave like a normal dict. -        >>> h["Host"] -        "example.com" - -        # HTTP Headers are case insensitive -        >>> h["host"] -        "example.com" - -        # Multiple headers are folded into a single header as per RFC7230 -        >>> h["Accept"] -        "text/html, application/xml" - -        # Setting a header removes all existing headers with the same name. -        >>> h["Accept"] = "application/text" -        >>> h["Accept"] -        "application/text" - -        # str(h) returns a HTTP1 header block. -        >>> print(h) -        Host: example.com -        Accept: application/text - -        # For full control, the raw header fields can be accessed -        >>> h.fields - -        # Headers can also be crated from keyword arguments -        >>> h = Headers(host="example.com", content_type="application/xml") - -    Caveats: -        For use with the "Set-Cookie" header, see :py:meth:`get_all`. -    """ - -    @always_byte_args("ascii") -    def __init__(self, fields=None, **headers): -        """ -        Args: -            fields: (optional) list of ``(name, value)`` header tuples, -                e.g. ``[("Host","example.com")]``. All names and values must be bytes. -            **headers: Additional headers to set. Will overwrite existing values from `fields`. -                For convenience, underscores in header names will be transformed to dashes - -                this behaviour does not extend to other methods. -                If ``**headers`` contains multiple keys that have equal ``.lower()`` s, -                the behavior is undefined. -        """ -        self.fields = fields or [] - -        # content_type -> content-type -        headers = { -            name.encode("ascii").replace(b"_", b"-"): value -            for name, value in six.iteritems(headers) -        } -        self.update(headers) - -    def __bytes__(self): -        if self.fields: -            return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n" -        else: -            return b"" - -    if six.PY2: -        __str__ = __bytes__ - -    @always_byte_args("ascii") -    def __getitem__(self, name): -        values = self.get_all(name) -        if not values: -            raise KeyError(name) -        return b", ".join(values) - -    @always_byte_args("ascii") -    def __setitem__(self, name, value): -        idx = self._index(name) - -        # To please the human eye, we insert at the same position the first existing header occured. -        if idx is not None: -            del self[name] -            self.fields.insert(idx, [name, value]) -        else: -            self.fields.append([name, value]) - -    @always_byte_args("ascii") -    def __delitem__(self, name): -        if name not in self: -            raise KeyError(name) -        name = name.lower() -        self.fields = [ -            field for field in self.fields -            if name != field[0].lower() -        ] - -    def __iter__(self): -        seen = set() -        for name, _ in self.fields: -            name_lower = name.lower() -            if name_lower not in seen: -                seen.add(name_lower) -                yield name - -    def __len__(self): -        return len(set(name.lower() for name, _ in self.fields)) - -    # __hash__ = object.__hash__ - -    def _index(self, name): -        name = name.lower() -        for i, field in enumerate(self.fields): -            if field[0].lower() == name: -                return i -        return None - -    def __eq__(self, other): -        if isinstance(other, Headers): -            return self.fields == other.fields -        return False - -    def __ne__(self, other): -        return not self.__eq__(other) - -    @always_byte_args("ascii") -    def get_all(self, name): -        """ -        Like :py:meth:`get`, but does not fold multiple headers into a single one. -        This is useful for Set-Cookie headers, which do not support folding. - -        See also: https://tools.ietf.org/html/rfc7230#section-3.2.2 -        """ -        name_lower = name.lower() -        values = [value for n, value in self.fields if n.lower() == name_lower] -        return values - -    def set_all(self, name, values): -        """ -        Explicitly set multiple headers for the given key. -        See: :py:meth:`get_all` -        """ -        name = always_bytes(name, "ascii") -        values = (always_bytes(value, "ascii") for value in values) -        if name in self: -            del self[name] -        self.fields.extend( -            [name, value] for value in values -        ) - -    def copy(self): -        return Headers(copy.copy(self.fields)) - -    # Implement the StateObject protocol from mitmproxy -    def get_state(self, short=False): -        return tuple(tuple(field) for field in self.fields) - -    def load_state(self, state): -        self.fields = [list(field) for field in state] - -    @classmethod -    def from_state(cls, state): -        return cls([list(field) for field in state]) - -  class Message(object):      def __init__(self, http_version, headers, body, timestamp_start, timestamp_end):          self.http_version = http_version @@ -216,7 +37,7 @@ class Message(object):      def body(self, body):          self._body = body          if isinstance(body, bytes): -            self.headers[b"content-length"] = str(len(body)).encode() +            self.headers["content-length"] = str(len(body)).encode()      content = body @@ -268,8 +89,8 @@ class Request(Message):              response. That is, we remove ETags and If-Modified-Since headers.          """          delheaders = [ -            b"if-modified-since", -            b"if-none-match", +            "if-modified-since", +            "if-none-match",          ]          for i in delheaders:              self.headers.pop(i, None) @@ -279,14 +100,14 @@ class Request(Message):              Modifies this request to remove headers that will compress the              resource's data.          """ -        self.headers["accept-encoding"] = b"identity" +        self.headers["accept-encoding"] = "identity"      def constrain_encoding(self):          """              Limits the permissible Accept-Encoding values, based on what we can              decode appropriately.          """ -        accept_encoding = native(self.headers.get("accept-encoding"), "ascii") +        accept_encoding = self.headers.get("accept-encoding")          if accept_encoding:              self.headers["accept-encoding"] = (                  ', '.join( @@ -309,9 +130,9 @@ class Request(Message):              indicates non-form data.          """          if self.body: -            if HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): +            if HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower():                  return self.get_form_urlencoded() -            elif HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): +            elif HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower():                  return self.get_form_multipart()          return ODict([]) @@ -321,12 +142,12 @@ class Request(Message):              Returns an empty ODict if there is no data or the content-type              indicates non-form data.          """ -        if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", b"").lower(): +        if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type", "").lower():              return ODict(utils.urldecode(self.body))          return ODict([])      def get_form_multipart(self): -        if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", b"").lower(): +        if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type", "").lower():              return ODict(                  utils.multipartdecode(                      self.headers, @@ -341,7 +162,7 @@ class Request(Message):          """          # FIXME: If there's an existing content-type header indicating a          # url-encoded form, leave it alone. -        self.headers[b"content-type"] = HDR_FORM_URLENCODED +        self.headers["content-type"] = HDR_FORM_URLENCODED          self.body = utils.urlencode(odict.lst)      def get_path_components(self): @@ -400,7 +221,7 @@ class Request(Message):          """          if hostheader and "host" in self.headers:              try: -                return self.headers["host"].decode("idna") +                return self.headers["host"]              except ValueError:                  pass          if self.host: @@ -420,7 +241,7 @@ class Request(Message):          """          ret = ODict()          for i in self.headers.get_all("Cookie"): -            ret.extend(cookies.parse_cookie_header(native(i,"ascii"))) +            ret.extend(cookies.parse_cookie_header(i))          return ret      def set_cookies(self, odict): @@ -499,7 +320,7 @@ class Response(Message):          """          ret = []          for header in self.headers.get_all("set-cookie"): -            v = cookies.parse_set_cookie_header(native(header, "ascii")) +            v = cookies.parse_set_cookie_header(header)              if v:                  name, value, attrs = v                  ret.append([name, [value, attrs]]) diff --git a/netlib/utils.py b/netlib/utils.py index b9848038..d5b30128 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -269,7 +269,7 @@ def get_header_tokens(headers, key):      """      if key not in headers:          return [] -    tokens = headers[key].split(b",") +    tokens = headers[key].split(",")      return [token.strip() for token in tokens] @@ -320,14 +320,14 @@ def parse_content_type(c):              ("text", "html", {"charset": "UTF-8"})      """ -    parts = c.split(b";", 1) -    ts = parts[0].split(b"/", 1) +    parts = c.split(";", 1) +    ts = parts[0].split("/", 1)      if len(ts) != 2:          return None      d = {}      if len(parts) == 2: -        for i in parts[1].split(b";"): -            clause = i.split(b"=", 1) +        for i in parts[1].split(";"): +            clause = i.split("=", 1)              if len(clause) == 2:                  d[clause[0].strip()] = clause[1].strip()      return ts[0].lower(), ts[1].lower(), d @@ -337,13 +337,14 @@ def multipartdecode(headers, content):      """          Takes a multipart boundary encoded string and returns list of (key, value) tuples.      """ -    v = headers.get(b"Content-Type") +    v = headers.get("Content-Type")      if v:          v = parse_content_type(v)          if not v:              return [] -        boundary = v[2].get(b"boundary") -        if not boundary: +        try: +            boundary = v[2]["boundary"].encode("ascii") +        except (KeyError, UnicodeError):              return []          rx = re.compile(br'\bname="([^"]+)"') diff --git a/netlib/websockets/protocol.py b/netlib/websockets/protocol.py index 778fe7e7..e62f8df6 100644 --- a/netlib/websockets/protocol.py +++ b/netlib/websockets/protocol.py @@ -80,7 +80,7 @@ class WebsocketsProtocol(object):              Returns an instance of Headers          """          if not key: -            key = base64.b64encode(os.urandom(16)).decode('utf-8') +            key = base64.b64encode(os.urandom(16)).decode('ascii')          return Headers(**{              HEADER_WEBSOCKET_KEY: key,              HEADER_WEBSOCKET_VERSION: version, @@ -95,27 +95,25 @@ class WebsocketsProtocol(object):          """          return Headers(**{              HEADER_WEBSOCKET_ACCEPT: self.create_server_nonce(key), -            "Connection": "Upgrade", -            "Upgrade": "websocket", +            "connection": "Upgrade", +            "upgrade": "websocket",          })      @classmethod      def check_client_handshake(self, headers): -        if headers.get("upgrade") != b"websocket": +        if headers.get("upgrade") != "websocket":              return          return headers.get(HEADER_WEBSOCKET_KEY)      @classmethod      def check_server_handshake(self, headers): -        if headers.get("upgrade") != b"websocket": +        if headers.get("upgrade") != "websocket":              return          return headers.get(HEADER_WEBSOCKET_ACCEPT)      @classmethod      def create_server_nonce(self, client_nonce): -        return base64.b64encode( -            binascii.unhexlify(hashlib.sha1(client_nonce + websockets_magic).hexdigest()) -        ) +        return base64.b64encode(hashlib.sha1(client_nonce + websockets_magic).digest()) diff --git a/test/http/http1/test_assemble.py b/test/http/http1/test_assemble.py index 2d250909..963e7549 100644 --- a/test/http/http1/test_assemble.py +++ b/test/http/http1/test_assemble.py @@ -77,16 +77,16 @@ def test_assemble_request_line():  def test_assemble_request_headers():      # https://github.com/mitmproxy/mitmproxy/issues/186      r = treq(body=b"") -    r.headers[b"Transfer-Encoding"] = b"chunked" +    r.headers["Transfer-Encoding"] = "chunked"      c = _assemble_request_headers(r)      assert b"Transfer-Encoding" in c -    assert b"Host" in _assemble_request_headers(treq(headers=Headers())) +    assert b"host" in _assemble_request_headers(treq(headers=Headers()))  def test_assemble_response_headers():      # https://github.com/mitmproxy/mitmproxy/issues/186      r = tresp(body=b"") -    r.headers["Transfer-Encoding"] = b"chunked" +    r.headers["Transfer-Encoding"] = "chunked"      c = _assemble_response_headers(r)      assert b"Transfer-Encoding" in c diff --git a/test/http/http1/test_read.py b/test/http/http1/test_read.py index 55def2a5..9eb02a24 100644 --- a/test/http/http1/test_read.py +++ b/test/http/http1/test_read.py @@ -1,9 +1,7 @@  from __future__ import absolute_import, print_function, division  from io import BytesIO  import textwrap -  from mock import Mock -  from netlib.exceptions import HttpException, HttpSyntaxException, HttpReadDisconnect  from netlib.http import Headers  from netlib.http.http1.read import ( @@ -35,7 +33,7 @@ def test_read_request_head():      rfile.first_byte_timestamp = 42      r = read_request_head(rfile)      assert r.method == b"GET" -    assert r.headers["Content-Length"] == b"4" +    assert r.headers["Content-Length"] == "4"      assert r.body is None      assert rfile.reset_timestamps.called      assert r.timestamp_start == 42 @@ -62,7 +60,7 @@ def test_read_response_head():      rfile.first_byte_timestamp = 42      r = read_response_head(rfile)      assert r.status_code == 418 -    assert r.headers["Content-Length"] == b"4" +    assert r.headers["Content-Length"] == "4"      assert r.body is None      assert rfile.reset_timestamps.called      assert r.timestamp_start == 42 @@ -76,14 +74,12 @@ class TestReadBody(object):          assert body == b"foo"          assert rfile.read() == b"bar" -      def test_known_size(self):          rfile = BytesIO(b"foobar")          body = b"".join(read_body(rfile, 3))          assert body == b"foo"          assert rfile.read() == b"bar" -      def test_known_size_limit(self):          rfile = BytesIO(b"foobar")          with raises(HttpException): @@ -99,7 +95,6 @@ class TestReadBody(object):          body = b"".join(read_body(rfile, -1))          assert body == b"foobar" -      def test_unknown_size_limit(self):          rfile = BytesIO(b"foobar")          with raises(HttpException): @@ -121,13 +116,13 @@ def test_connection_close():  def test_expected_http_body_size():      # Expect: 100-continue      assert expected_http_body_size( -        treq(headers=Headers(expect=b"100-continue", content_length=b"42")) +        treq(headers=Headers(expect="100-continue", content_length="42"))      ) == 0      # http://tools.ietf.org/html/rfc7230#section-3.3      assert expected_http_body_size(          treq(method=b"HEAD"), -        tresp(headers=Headers(content_length=b"42")) +        tresp(headers=Headers(content_length="42"))      ) == 0      assert expected_http_body_size(          treq(method=b"CONNECT"), @@ -141,17 +136,17 @@ def test_expected_http_body_size():      # chunked      assert expected_http_body_size( -        treq(headers=Headers(transfer_encoding=b"chunked")), +        treq(headers=Headers(transfer_encoding="chunked")),      ) is None      # explicit length -    for l in (b"foo", b"-7"): +    for val in (b"foo", b"-7"):          with raises(HttpSyntaxException):              expected_http_body_size( -                treq(headers=Headers(content_length=l)) +                treq(headers=Headers(content_length=val))              )      assert expected_http_body_size( -        treq(headers=Headers(content_length=b"42")) +        treq(headers=Headers(content_length="42"))      ) == 42      # no length @@ -286,6 +281,7 @@ class TestReadHeaders(object):          with raises(HttpSyntaxException):              self._read(data) +  def test_read_chunked():      req = treq(body=None)      req.headers["Transfer-Encoding"] = "chunked" diff --git a/test/http/test_authentication.py b/test/http/test_authentication.py index a2aa774a..1df7cd9c 100644 --- a/test/http/test_authentication.py +++ b/test/http/test_authentication.py @@ -5,13 +5,13 @@ from netlib.http import authentication, Headers  def test_parse_http_basic_auth(): -    vals = (b"basic", b"foo", b"bar") +    vals = ("basic", "foo", "bar")      assert authentication.parse_http_basic_auth(          authentication.assemble_http_basic_auth(*vals)      ) == vals      assert not authentication.parse_http_basic_auth("")      assert not authentication.parse_http_basic_auth("foo bar") -    v = b"basic " + binascii.b2a_base64(b"foo") +    v = "basic " + binascii.b2a_base64(b"foo").decode("ascii")      assert not authentication.parse_http_basic_auth(v) @@ -34,7 +34,7 @@ class TestPassManHtpasswd:      def test_simple(self):          pm = authentication.PassManHtpasswd(tutils.test_data.path("data/htpasswd")) -        vals = (b"basic", b"test", b"test") +        vals = ("basic", "test", "test")          authentication.assemble_http_basic_auth(*vals)          assert pm.test("test", "test")          assert not pm.test("test", "foo") @@ -73,7 +73,7 @@ class TestBasicProxyAuth:          ba = authentication.BasicProxyAuth(authentication.PassManNonAnon(), "test")          headers = Headers() -        vals = (b"basic", b"foo", b"bar") +        vals = ("basic", "foo", "bar")          headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals)          assert ba.authenticate(headers) @@ -86,12 +86,12 @@ class TestBasicProxyAuth:          headers[ba.AUTH_HEADER] = "foo"          assert not ba.authenticate(headers) -        vals = (b"foo", b"foo", b"bar") +        vals = ("foo", "foo", "bar")          headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals)          assert not ba.authenticate(headers)          ba = authentication.BasicProxyAuth(authentication.PassMan(), "test") -        vals = (b"basic", b"foo", b"bar") +        vals = ("basic", "foo", "bar")          headers[ba.AUTH_HEADER] = authentication.assemble_http_basic_auth(*vals)          assert not ba.authenticate(headers) diff --git a/test/http/test_headers.py b/test/http/test_headers.py new file mode 100644 index 00000000..f1af1feb --- /dev/null +++ b/test/http/test_headers.py @@ -0,0 +1,149 @@ +from netlib.http import Headers +from netlib.tutils import raises + + +class TestHeaders(object): +    def _2host(self): +        return Headers( +            [ +                [b"Host", b"example.com"], +                [b"host", b"example.org"] +            ] +        ) + +    def test_init(self): +        headers = Headers() +        assert len(headers) == 0 + +        headers = Headers([[b"Host", b"example.com"]]) +        assert len(headers) == 1 +        assert headers["Host"] == "example.com" + +        headers = Headers(Host="example.com") +        assert len(headers) == 1 +        assert headers["Host"] == "example.com" + +        headers = Headers( +            [[b"Host", b"invalid"]], +            Host="example.com" +        ) +        assert len(headers) == 1 +        assert headers["Host"] == "example.com" + +        headers = Headers( +            [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], +            Host="example.com" +        ) +        assert len(headers) == 2 +        assert headers["Host"] == "example.com" +        assert headers["Accept"] == "text/plain" + +    def test_getitem(self): +        headers = Headers(Host="example.com") +        assert headers["Host"] == "example.com" +        assert headers["host"] == "example.com" +        with raises(KeyError): +            _ = headers["Accept"] + +        headers = self._2host() +        assert headers["Host"] == "example.com, example.org" + +    def test_str(self): +        headers = Headers(Host="example.com") +        assert bytes(headers) == b"Host: example.com\r\n" + +        headers = Headers([ +            [b"Host", b"example.com"], +            [b"Accept", b"text/plain"] +        ]) +        assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" + +        headers = Headers() +        assert bytes(headers) == b"" + +    def test_setitem(self): +        headers = Headers() +        headers["Host"] = "example.com" +        assert "Host" in headers +        assert "host" in headers +        assert headers["Host"] == "example.com" + +        headers["host"] = "example.org" +        assert "Host" in headers +        assert "host" in headers +        assert headers["Host"] == "example.org" + +        headers["accept"] = "text/plain" +        assert len(headers) == 2 +        assert "Accept" in headers +        assert "Host" in headers + +        headers = self._2host() +        assert len(headers.fields) == 2 +        headers["Host"] = "example.com" +        assert len(headers.fields) == 1 +        assert "Host" in headers + +    def test_delitem(self): +        headers = Headers(Host="example.com") +        assert len(headers) == 1 +        del headers["host"] +        assert len(headers) == 0 +        try: +            del headers["host"] +        except KeyError: +            assert True +        else: +            assert False + +        headers = self._2host() +        del headers["Host"] +        assert len(headers) == 0 + +    def test_keys(self): +        headers = Headers(Host="example.com") +        assert list(headers.keys()) == ["Host"] + +        headers = self._2host() +        assert list(headers.keys()) == ["Host"] + +    def test_eq_ne(self): +        headers1 = Headers(Host="example.com") +        headers2 = Headers(host="example.com") +        assert not (headers1 == headers2) +        assert headers1 != headers2 + +        headers1 = Headers(Host="example.com") +        headers2 = Headers(Host="example.com") +        assert headers1 == headers2 +        assert not (headers1 != headers2) + +        assert headers1 != 42 + +    def test_get_all(self): +        headers = self._2host() +        assert headers.get_all("host") == ["example.com", "example.org"] +        assert headers.get_all("accept") == [] + +    def test_set_all(self): +        headers = Headers(Host="example.com") +        headers.set_all("Accept", ["text/plain"]) +        assert len(headers) == 2 +        assert "accept" in headers + +        headers = self._2host() +        headers.set_all("Host", ["example.org"]) +        assert headers["host"] == "example.org" + +        headers.set_all("Host", ["example.org", "example.net"]) +        assert headers["host"] == "example.org, example.net" + +    def test_state(self): +        headers = self._2host() +        assert len(headers.get_state()) == 2 +        assert headers == Headers.from_state(headers.get_state()) + +        headers2 = Headers() +        assert headers != headers2 +        headers2.load_state(headers.get_state()) +        assert headers == headers2 diff --git a/test/http/test_models.py b/test/http/test_models.py index d420b22b..10e0795a 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -58,20 +58,20 @@ class TestRequest(object):          req = tutils.treq()          req.headers["Accept-Encoding"] = "foobar"          req.anticomp() -        assert req.headers["Accept-Encoding"] == b"identity" +        assert req.headers["Accept-Encoding"] == "identity"      def test_constrain_encoding(self):          req = tutils.treq()          req.headers["Accept-Encoding"] = "identity, gzip, foo"          req.constrain_encoding() -        assert b"foo" not in req.headers["Accept-Encoding"] +        assert "foo" not in req.headers["Accept-Encoding"]      def test_update_host(self):          req = tutils.treq()          req.headers["Host"] = ""          req.host = "foobar"          req.update_host_header() -        assert req.headers["Host"] == b"foobar" +        assert req.headers["Host"] == "foobar"      def test_get_form(self):          req = tutils.treq() @@ -393,149 +393,3 @@ class TestResponse(object):          v = resp.get_cookies()          assert len(v) == 1          assert v["foo"] == [["bar", ODictCaseless()]] - - -class TestHeaders(object): -    def _2host(self): -        return Headers( -            [ -                [b"Host", b"example.com"], -                [b"host", b"example.org"] -            ] -        ) - -    def test_init(self): -        headers = Headers() -        assert len(headers) == 0 - -        headers = Headers([[b"Host", b"example.com"]]) -        assert len(headers) == 1 -        assert headers["Host"] == b"example.com" - -        headers = Headers(Host="example.com") -        assert len(headers) == 1 -        assert headers["Host"] == b"example.com" - -        headers = Headers( -            [[b"Host", b"invalid"]], -            Host="example.com" -        ) -        assert len(headers) == 1 -        assert headers["Host"] == b"example.com" - -        headers = Headers( -            [[b"Host", b"invalid"], [b"Accept", b"text/plain"]], -            Host="example.com" -        ) -        assert len(headers) == 2 -        assert headers["Host"] == b"example.com" -        assert headers["Accept"] == b"text/plain" - -    def test_getitem(self): -        headers = Headers(Host="example.com") -        assert headers["Host"] == b"example.com" -        assert headers["host"] == b"example.com" -        tutils.raises(KeyError, headers.__getitem__, "Accept") - -        headers = self._2host() -        assert headers["Host"] == b"example.com, example.org" - -    def test_str(self): -        headers = Headers(Host="example.com") -        assert bytes(headers) == b"Host: example.com\r\n" - -        headers = Headers([ -            [b"Host", b"example.com"], -            [b"Accept", b"text/plain"] -        ]) -        assert bytes(headers) == b"Host: example.com\r\nAccept: text/plain\r\n" - -        headers = Headers() -        assert bytes(headers) == b"" - -    def test_setitem(self): -        headers = Headers() -        headers["Host"] = "example.com" -        assert "Host" in headers -        assert "host" in headers -        assert headers["Host"] == b"example.com" - -        headers["host"] = "example.org" -        assert "Host" in headers -        assert "host" in headers -        assert headers["Host"] == b"example.org" - -        headers["accept"] = "text/plain" -        assert len(headers) == 2 -        assert "Accept" in headers -        assert "Host" in headers - -        headers = self._2host() -        assert len(headers.fields) == 2 -        headers["Host"] = "example.com" -        assert len(headers.fields) == 1 -        assert "Host" in headers - -    def test_delitem(self): -        headers = Headers(Host="example.com") -        assert len(headers) == 1 -        del headers["host"] -        assert len(headers) == 0 -        try: -            del headers["host"] -        except KeyError: -            assert True -        else: -            assert False - -        headers = self._2host() -        del headers["Host"] -        assert len(headers) == 0 - -    def test_keys(self): -        headers = Headers(Host="example.com") -        assert list(headers.keys()) == [b"Host"] - -        headers = self._2host() -        assert list(headers.keys()) == [b"Host"] - -    def test_eq_ne(self): -        headers1 = Headers(Host="example.com") -        headers2 = Headers(host="example.com") -        assert not (headers1 == headers2) -        assert headers1 != headers2 - -        headers1 = Headers(Host="example.com") -        headers2 = Headers(Host="example.com") -        assert headers1 == headers2 -        assert not (headers1 != headers2) - -        assert headers1 != 42 - -    def test_get_all(self): -        headers = self._2host() -        assert headers.get_all("host") == [b"example.com", b"example.org"] -        assert headers.get_all("accept") == [] - -    def test_set_all(self): -        headers = Headers(Host="example.com") -        headers.set_all("Accept", ["text/plain"]) -        assert len(headers) == 2 -        assert "accept" in headers - -        headers = self._2host() -        headers.set_all("Host", ["example.org"]) -        assert headers["host"] == b"example.org" - -        headers.set_all("Host", ["example.org", "example.net"]) -        assert headers["host"] == b"example.org, example.net" - -    def test_state(self): -        headers = self._2host() -        assert len(headers.get_state()) == 2 -        assert headers == Headers.from_state(headers.get_state()) - -        headers2 = Headers() -        assert headers != headers2 -        headers2.load_state(headers.get_state()) -        assert headers == headers2 diff --git a/test/test_utils.py b/test/test_utils.py index 8f4b4059..17636cc4 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -103,17 +103,17 @@ def test_get_header_tokens():      headers = Headers()      assert utils.get_header_tokens(headers, "foo") == []      headers["foo"] = "bar" -    assert utils.get_header_tokens(headers, "foo") == [b"bar"] +    assert utils.get_header_tokens(headers, "foo") == ["bar"]      headers["foo"] = "bar, voing" -    assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing"] +    assert utils.get_header_tokens(headers, "foo") == ["bar", "voing"]      headers.set_all("foo", ["bar, voing", "oink"]) -    assert utils.get_header_tokens(headers, "foo") == [b"bar", b"voing", b"oink"] +    assert utils.get_header_tokens(headers, "foo") == ["bar", "voing", "oink"]  def test_multipartdecode(): -    boundary = b'somefancyboundary' +    boundary = 'somefancyboundary'      headers = Headers( -        content_type=b'multipart/form-data; boundary=' + boundary +        content_type='multipart/form-data; boundary=' + boundary      )      content = (          "--{0}\n" @@ -122,7 +122,7 @@ def test_multipartdecode():          "--{0}\n"          "Content-Disposition: form-data; name=\"field2\"\n\n"          "value2\n" -        "--{0}--".format(boundary.decode()).encode() +        "--{0}--".format(boundary).encode()      )      form = utils.multipartdecode(headers, content) @@ -134,8 +134,8 @@ def test_multipartdecode():  def test_parse_content_type():      p = utils.parse_content_type -    assert p(b"text/html") == (b"text", b"html", {}) -    assert p(b"text") is None +    assert p("text/html") == ("text", "html", {}) +    assert p("text") is None -    v = p(b"text/html; charset=UTF-8") -    assert v == (b'text', b'html', {b'charset': b'UTF-8'}) +    v = p("text/html; charset=UTF-8") +    assert v == ('text', 'html', {'charset': 'UTF-8'}) diff --git a/test/websockets/test_websockets.py b/test/websockets/test_websockets.py index 48acc2d6..4ae4cf45 100644 --- a/test/websockets/test_websockets.py +++ b/test/websockets/test_websockets.py @@ -64,15 +64,14 @@ class WebSocketsClient(tcp.TCPClient):          preamble = b'GET / HTTP/1.1'          self.wfile.write(preamble + b"\r\n")          headers = self.protocol.client_handshake_headers() -        self.client_nonce = headers["sec-websocket-key"] +        self.client_nonce = headers["sec-websocket-key"].encode("ascii")          self.wfile.write(bytes(headers) + b"\r\n")          self.wfile.flush()          resp = read_response(self.rfile, treq(method="GET"))          server_nonce = self.protocol.check_server_handshake(resp.headers) -        if not server_nonce == self.protocol.create_server_nonce( -                self.client_nonce): +        if not server_nonce == self.protocol.create_server_nonce(self.client_nonce):              self.close()      def read_next_message(self): @@ -207,14 +206,14 @@ class TestFrameHeader:              fin=True,              payload_length=10          ) -        assert f.human_readable() +        assert repr(f)          f = websockets.FrameHeader() -        assert f.human_readable() +        assert repr(f)      def test_funky(self):          f = websockets.FrameHeader(masking_key=b"test", mask=False) -        bytes = f.to_bytes() -        f2 = websockets.FrameHeader.from_file(tutils.treader(bytes)) +        raw = bytes(f) +        f2 = websockets.FrameHeader.from_file(tutils.treader(raw))          assert not f2.mask      def test_violations(self):  | 
