diff options
| -rw-r--r-- | .travis.yml | 6 | ||||
| -rw-r--r-- | netlib/encoding.py | 20 | ||||
| -rw-r--r-- | netlib/http/models.py | 48 | ||||
| -rw-r--r-- | netlib/odict.py | 25 | ||||
| -rw-r--r-- | netlib/tutils.py | 4 | ||||
| -rw-r--r-- | netlib/utils.py | 22 | ||||
| -rw-r--r-- | test/http/test_models.py | 8 | ||||
| -rw-r--r-- | test/test_encoding.py | 10 | ||||
| -rw-r--r-- | test/test_odict.py | 40 | ||||
| -rw-r--r-- | test/test_socks.py | 55 | ||||
| -rw-r--r-- | test/test_utils.py | 10 | 
11 files changed, 105 insertions, 143 deletions
| diff --git a/.travis.yml b/.travis.yml index fa997542..7e18176c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,11 @@ matrix:            packages:              - libssl-dev      - python: 3.5 -      script: "nosetests --with-cov --cov-report term-missing test/http/http1" +      script: +        - nosetests --with-cov --cov-report term-missing test/http/http1 +        - nosetests --with-cov --cov-report term-missing test/test_utils.py +        - nosetests --with-cov --cov-report term-missing test/test_encoding.py +        - nosetests --with-cov --cov-report term-missing test/test_odict.py      - python: pypy      - python: pypy        env: OPENSSL=1.0.2 diff --git a/netlib/encoding.py b/netlib/encoding.py index 06830f2c..8ac59905 100644 --- a/netlib/encoding.py +++ b/netlib/encoding.py @@ -5,28 +5,30 @@ from __future__ import absolute_import  from io import BytesIO  import gzip  import zlib +from .utils import always_byte_args -__ALL__ = ["ENCODINGS"] -ENCODINGS = {"identity", "gzip", "deflate"} +ENCODINGS = {b"identity", b"gzip", b"deflate"} +@always_byte_args("ascii", "ignore")  def decode(e, content):      encoding_map = { -        "identity": identity, -        "gzip": decode_gzip, -        "deflate": decode_deflate, +        b"identity": identity, +        b"gzip": decode_gzip, +        b"deflate": decode_deflate,      }      if e not in encoding_map:          return None      return encoding_map[e](content) +@always_byte_args("ascii", "ignore")  def encode(e, content):      encoding_map = { -        "identity": identity, -        "gzip": encode_gzip, -        "deflate": encode_deflate, +        b"identity": identity, +        b"gzip": encode_gzip, +        b"deflate": encode_deflate,      }      if e not in encoding_map:          return None @@ -80,3 +82,5 @@ def encode_deflate(content):          Returns compressed content, always including zlib header and checksum.      """      return zlib.compress(content) + +__all__ = ["ENCODINGS", "encode", "decode"] diff --git a/netlib/http/models.py b/netlib/http/models.py index 54b8b112..bc681de3 100644 --- a/netlib/http/models.py +++ b/netlib/http/models.py @@ -136,7 +136,7 @@ class Headers(MutableMapping, object):      def __len__(self):          return len(set(name.lower() for name, _ in self.fields)) -    #__hash__ = object.__hash__ +    # __hash__ = object.__hash__      def _index(self, name):          name = name.lower() @@ -227,11 +227,11 @@ class Request(Message):      # This list is adopted legacy code.      # We probably don't need to strip off keep-alive.      _headers_to_strip_off = [ -        b'Proxy-Connection', -        b'Keep-Alive', -        b'Connection', -        b'Transfer-Encoding', -        b'Upgrade', +        'Proxy-Connection', +        'Keep-Alive', +        'Connection', +        'Transfer-Encoding', +        'Upgrade',      ]      def __init__( @@ -275,8 +275,8 @@ class Request(Message):              response. That is, we remove ETags and If-Modified-Since headers.          """          delheaders = [ -            b"if-modified-since", -            b"if-none-match", +            b"If-Modified-Since", +            b"If-None-Match",          ]          for i in delheaders:              self.headers.pop(i, None) @@ -286,16 +286,16 @@ class Request(Message):              Modifies this request to remove headers that will compress the              resource's data.          """ -        self.headers[b"accept-encoding"] = b"identity" +        self.headers["Accept-Encoding"] = b"identity"      def constrain_encoding(self):          """              Limits the permissible Accept-Encoding values, based on what we can              decode appropriately.          """ -        accept_encoding = self.headers.get(b"accept-encoding") +        accept_encoding = self.headers.get(b"Accept-Encoding")          if accept_encoding: -            self.headers[b"accept-encoding"] = ( +            self.headers["Accept-Encoding"] = (                  ', '.join(                      e                      for e in encoding.ENCODINGS @@ -316,9 +316,9 @@ class Request(Message):              indicates non-form data.          """          if self.body: -            if HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): +            if HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower():                  return self.get_form_urlencoded() -            elif HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): +            elif HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower():                  return self.get_form_multipart()          return ODict([]) @@ -328,12 +328,12 @@ class Request(Message):              Returns an empty ODict if there is no data or the content-type              indicates non-form data.          """ -        if self.body and HDR_FORM_URLENCODED in self.headers.get("content-type","").lower(): +        if self.body and HDR_FORM_URLENCODED in self.headers.get("Content-Type", "").lower():              return ODict(utils.urldecode(self.body))          return ODict([])      def get_form_multipart(self): -        if self.body and HDR_FORM_MULTIPART in self.headers.get("content-type","").lower(): +        if self.body and HDR_FORM_MULTIPART in self.headers.get("Content-Type", "").lower():              return ODict(                  utils.multipartdecode(                      self.headers, @@ -405,9 +405,9 @@ class Request(Message):              but not the resolved name. This is disabled by default, as an              attacker may spoof the host header to confuse an analyst.          """ -        if hostheader and b"Host" in self.headers: +        if hostheader and "Host" in self.headers:              try: -                return self.headers[b"Host"].decode("idna") +                return self.headers["Host"].decode("idna")              except ValueError:                  pass          if self.host: @@ -426,7 +426,7 @@ class Request(Message):              Returns a possibly empty netlib.odict.ODict object.          """          ret = ODict() -        for i in self.headers.get_all("cookie"): +        for i in self.headers.get_all("Cookie"):              ret.extend(cookies.parse_cookie_header(i))          return ret @@ -468,9 +468,9 @@ class Request(Message):  class Response(Message):      _headers_to_strip_off = [ -        b'Proxy-Connection', -        b'Alternate-Protocol', -        b'Alt-Svc', +        'Proxy-Connection', +        'Alternate-Protocol', +        'Alt-Svc',      ]      def __init__( @@ -498,7 +498,7 @@ class Response(Message):          return "<Response: {status_code} {msg} ({contenttype}, {size})>".format(              status_code=self.status_code,              msg=self.msg, -            contenttype=self.headers.get("content-type", "unknown content type"), +            contenttype=self.headers.get("Content-Type", "unknown content type"),              size=size)      def get_cookies(self): @@ -511,7 +511,7 @@ class Response(Message):              attributes (e.g. HTTPOnly) are indicated by a Null value.          """          ret = [] -        for header in self.headers.get_all(b"set-cookie"): +        for header in self.headers.get_all("Set-Cookie"):              v = cookies.parse_set_cookie_header(header)              if v:                  name, value, attrs = v @@ -534,4 +534,4 @@ class Response(Message):                      i[1][1]                  )              ) -        self.headers.set_all(b"Set-Cookie", values) +        self.headers.set_all("Set-Cookie", values) diff --git a/netlib/odict.py b/netlib/odict.py index 11d5d52a..1124b23a 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,6 +1,7 @@  from __future__ import (absolute_import, print_function, division)  import re  import copy +import six  def safe_subn(pattern, repl, target, *args, **kwargs): @@ -67,10 +68,10 @@ class ODict(object):              Sets the values for key k. If there are existing values for this              key, they are cleared.          """ -        if isinstance(valuelist, basestring): +        if isinstance(valuelist, six.text_type) or isinstance(valuelist, six.binary_type):              raise ValueError(                  "Expected list of values instead of string. " -                "Example: odict['Host'] = ['www.example.com']" +                "Example: odict[b'Host'] = [b'www.example.com']"              )          kc = self._kconv(k)          new = [] @@ -134,13 +135,6 @@ class ODict(object):      def __repr__(self):          return repr(self.lst) -    def format(self): -        elements = [] -        for itm in self.lst: -            elements.append(itm[0] + ": " + str(itm[1])) -        elements.append("") -        return "\r\n".join(elements) -      def in_any(self, key, value, caseless=False):          """              Do any of the values matching key contain value? @@ -156,19 +150,6 @@ class ODict(object):                  return True          return False -    def match_re(self, expr): -        """ -            Match the regular expression against each (key, value) pair. For -            each pair a string of the following format is matched against: - -            "key: value" -        """ -        for k, v in self.lst: -            s = "%s: %s" % (k, v) -            if re.search(expr, s): -                return True -        return False -      def replace(self, pattern, repl, *args, **kwargs):          """              Replaces a regular expression pattern with repl in both keys and diff --git a/netlib/tutils.py b/netlib/tutils.py index b69495a3..746e1488 100644 --- a/netlib/tutils.py +++ b/netlib/tutils.py @@ -123,9 +123,7 @@ def tresp(**kwargs):          status_code=200,          msg=b"OK",          headers=Headers(header_response=b"svalue"), -        body=b"message", -        timestamp_start=time.time(), -        timestamp_end=time.time() +        body=b"message"      )      default.update(kwargs)      return Response(**default) diff --git a/netlib/utils.py b/netlib/utils.py index 14b428d7..6fed44b6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -246,7 +246,7 @@ def unparse_url(scheme, host, port, path=""):      """          Returns a URL string, constructed from the specified compnents.      """ -    return "%s://%s%s" % (scheme, hostport(scheme, host, port), path) +    return b"%s://%s%s" % (scheme, hostport(scheme, host, port), path)  def urlencode(s): @@ -295,7 +295,7 @@ def multipartdecode(headers, content):      """          Takes a multipart boundary encoded string and returns list of (key, value) tuples.      """ -    v = headers.get("content-type") +    v = headers.get(b"Content-Type")      if v:          v = parse_content_type(v)          if not v: @@ -304,33 +304,33 @@ def multipartdecode(headers, content):          if not boundary:              return [] -        rx = re.compile(r'\bname="([^"]+)"') +        rx = re.compile(br'\bname="([^"]+)"')          r = [] -        for i in content.split("--" + boundary): +        for i in content.split(b"--" + boundary):              parts = i.splitlines() -            if len(parts) > 1 and parts[0][0:2] != "--": +            if len(parts) > 1 and parts[0][0:2] != b"--":                  match = rx.search(parts[1])                  if match:                      key = match.group(1) -                    value = "".join(parts[3 + parts[2:].index(""):]) +                    value = b"".join(parts[3 + parts[2:].index(b""):])                      r.append((key, value))          return r      return [] -def always_bytes(unicode_or_bytes, encoding): +def always_bytes(unicode_or_bytes, *encode_args):      if isinstance(unicode_or_bytes, six.text_type): -        return unicode_or_bytes.encode(encoding) +        return unicode_or_bytes.encode(*encode_args)      return unicode_or_bytes -def always_byte_args(encoding): +def always_byte_args(*encode_args):      """Decorator that transparently encodes all arguments passed as unicode"""      def decorator(fun):          def _fun(*args, **kwargs): -            args = [always_bytes(arg, encoding) for arg in args] -            kwargs = {k: always_bytes(v, encoding) for k, v in six.iteritems(kwargs)} +            args = [always_bytes(arg, *encode_args) for arg in args] +            kwargs = {k: always_bytes(v, *encode_args) for k, v in six.iteritems(kwargs)}              return fun(*args, **kwargs)          return _fun      return decorator diff --git a/test/http/test_models.py b/test/http/test_models.py index 8fce2e9d..c3ab4d0f 100644 --- a/test/http/test_models.py +++ b/test/http/test_models.py @@ -36,8 +36,8 @@ class TestRequest(object):          assert isinstance(req.headers, Headers)      def test_equal(self): -        a = tutils.treq() -        b = tutils.treq() +        a = tutils.treq(timestamp_start=42, timestamp_end=43) +        b = tutils.treq(timestamp_start=42, timestamp_end=43)          assert a == b          assert not a == 'foo' @@ -319,8 +319,8 @@ class TestResponse(object):          assert isinstance(resp.headers, Headers)      def test_equal(self): -        a = tutils.tresp() -        b = tutils.tresp() +        a = tutils.tresp(timestamp_start=42, timestamp_end=43) +        b = tutils.tresp(timestamp_start=42, timestamp_end=43)          assert a == b          assert not a == 'foo' diff --git a/test/test_encoding.py b/test/test_encoding.py index 9da3a38d..90f99338 100644 --- a/test/test_encoding.py +++ b/test/test_encoding.py @@ -2,10 +2,12 @@ from netlib import encoding  def test_identity(): -    assert "string" == encoding.decode("identity", "string") -    assert "string" == encoding.encode("identity", "string") -    assert not encoding.encode("nonexistent", "string") -    assert None == encoding.decode("nonexistent encoding", "string") +    assert b"string" == encoding.decode("identity", b"string") +    assert b"string" == encoding.encode("identity", b"string") +    assert b"string" == encoding.encode(b"identity", b"string") +    assert b"string" == encoding.decode(b"identity", b"string") +    assert not encoding.encode("nonexistent", b"string") +    assert not encoding.decode("nonexistent encoding", b"string")  def test_gzip(): diff --git a/test/test_odict.py b/test/test_odict.py index be3d862d..962c0daa 100644 --- a/test/test_odict.py +++ b/test/test_odict.py @@ -1,7 +1,7 @@  from netlib import odict, tutils -class TestODict: +class TestODict(object):      def setUp(self):          self.od = odict.ODict() @@ -13,21 +13,10 @@ class TestODict:      def test_str_err(self):          h = odict.ODict() -        tutils.raises(ValueError, h.__setitem__, "key", "foo") - -    def test_dictToHeader1(self): -        self.od.add("one", "uno") -        self.od.add("two", "due") -        self.od.add("two", "tre") -        expected = [ -            "one: uno\r\n", -            "two: due\r\n", -            "two: tre\r\n", -            "\r\n" -        ] -        out = self.od.format() -        for i in expected: -            assert out.find(i) >= 0 +        with tutils.raises(ValueError): +            h["key"] = u"foo" +        with tutils.raises(ValueError): +            h["key"] = b"foo"      def test_getset_state(self):          self.od.add("foo", 1) @@ -40,23 +29,6 @@ class TestODict:          b.load_state(state)          assert b == self.od -    def test_dictToHeader2(self): -        self.od["one"] = ["uno"] -        expected1 = "one: uno\r\n" -        expected2 = "\r\n" -        out = self.od.format() -        assert out.find(expected1) >= 0 -        assert out.find(expected2) >= 0 - -    def test_match_re(self): -        h = odict.ODict() -        h.add("one", "uno") -        h.add("two", "due") -        h.add("two", "tre") -        assert h.match_re("uno") -        assert h.match_re("two: due") -        assert not h.match_re("nonono") -      def test_in_any(self):          self.od["one"] = ["atwoa", "athreea"]          assert self.od.in_any("one", "two") @@ -122,7 +94,7 @@ class TestODict:          assert a["a"] == ["b", "b"] -class TestODictCaseless: +class TestODictCaseless(object):      def setUp(self):          self.od = odict.ODictCaseless() diff --git a/test/test_socks.py b/test/test_socks.py index 3d109f42..65a0f0eb 100644 --- a/test/test_socks.py +++ b/test/test_socks.py @@ -1,12 +1,12 @@ -from cStringIO import StringIO +from io import BytesIO  import socket  from nose.plugins.skip import SkipTest  from netlib import socks, tcp, tutils  def test_client_greeting(): -    raw = tutils.treader("\x05\x02\x00\xBE\xEF") -    out = StringIO() +    raw = tutils.treader(b"\x05\x02\x00\xBE\xEF") +    out = BytesIO()      msg = socks.ClientGreeting.from_file(raw)      msg.assert_socks5()      msg.to_file(out) @@ -19,11 +19,11 @@ def test_client_greeting():  def test_client_greeting_assert_socks5(): -    raw = tutils.treader("\x00\x00") +    raw = tutils.treader(b"\x00\x00")      msg = socks.ClientGreeting.from_file(raw)      tutils.raises(socks.SocksError, msg.assert_socks5) -    raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) +    raw = tutils.treader(b"HTTP/1.1 200 OK" + " " * 100)      msg = socks.ClientGreeting.from_file(raw)      try:          msg.assert_socks5() @@ -33,7 +33,7 @@ def test_client_greeting_assert_socks5():      else:          assert False -    raw = tutils.treader("GET / HTTP/1.1" + " " * 100) +    raw = tutils.treader(b"GET / HTTP/1.1" + " " * 100)      msg = socks.ClientGreeting.from_file(raw)      try:          msg.assert_socks5() @@ -43,7 +43,7 @@ def test_client_greeting_assert_socks5():      else:          assert False -    raw = tutils.treader("XX") +    raw = tutils.treader(b"XX")      tutils.raises(          socks.SocksError,          socks.ClientGreeting.from_file, @@ -52,8 +52,8 @@ def test_client_greeting_assert_socks5():  def test_server_greeting(): -    raw = tutils.treader("\x05\x02") -    out = StringIO() +    raw = tutils.treader(b"\x05\x02") +    out = BytesIO()      msg = socks.ServerGreeting.from_file(raw)      msg.assert_socks5()      msg.to_file(out) @@ -64,7 +64,7 @@ def test_server_greeting():  def test_server_greeting_assert_socks5(): -    raw = tutils.treader("HTTP/1.1 200 OK" + " " * 100) +    raw = tutils.treader(b"HTTP/1.1 200 OK" + " " * 100)      msg = socks.ServerGreeting.from_file(raw)      try:          msg.assert_socks5() @@ -74,7 +74,7 @@ def test_server_greeting_assert_socks5():      else:          assert False -    raw = tutils.treader("GET / HTTP/1.1" + " " * 100) +    raw = tutils.treader(b"GET / HTTP/1.1" + " " * 100)      msg = socks.ServerGreeting.from_file(raw)      try:          msg.assert_socks5() @@ -86,36 +86,37 @@ def test_server_greeting_assert_socks5():  def test_message(): -    raw = tutils.treader("\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") -    out = StringIO() +    raw = tutils.treader(b"\x05\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") +    out = BytesIO()      msg = socks.Message.from_file(raw)      msg.assert_socks5() -    assert raw.read(2) == "\xBE\xEF" +    assert raw.read(2) == b"\xBE\xEF"      msg.to_file(out)      assert out.getvalue() == raw.getvalue()[:-2]      assert msg.ver == 5      assert msg.msg == 0x01      assert msg.atyp == 0x03 -    assert msg.addr == ("example.com", 0xDEAD) +    assert msg.addr == (b"example.com", 0xDEAD)  def test_message_assert_socks5(): -    raw = tutils.treader("\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF") +    raw = tutils.treader(b"\xEE\x01\x00\x03\x0bexample.com\xDE\xAD\xBE\xEF")      msg = socks.Message.from_file(raw)      tutils.raises(socks.SocksError, msg.assert_socks5)  def test_message_ipv4():      # Test ATYP=0x01 (IPV4) -    raw = tutils.treader("\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") -    out = StringIO() +    raw = tutils.treader(b"\x05\x01\x00\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") +    out = BytesIO()      msg = socks.Message.from_file(raw) -    assert raw.read(2) == "\xBE\xEF" +    left = raw.read(2) +    assert left == b"\xBE\xEF"      msg.to_file(out)      assert out.getvalue() == raw.getvalue()[:-2] -    assert msg.addr == ("127.0.0.1", 0xDEAD) +    assert msg.addr == (b"127.0.0.1", 0xDEAD)  def test_message_ipv6(): @@ -125,14 +126,14 @@ def test_message_ipv6():      ipv6_addr = "2001:db8:85a3:8d3:1319:8a2e:370:7344"      raw = tutils.treader( -        "\x05\x01\x00\x04" + +        b"\x05\x01\x00\x04" +          socket.inet_pton(              socket.AF_INET6,              ipv6_addr) + -        "\xDE\xAD\xBE\xEF") -    out = StringIO() +        b"\xDE\xAD\xBE\xEF") +    out = BytesIO()      msg = socks.Message.from_file(raw) -    assert raw.read(2) == "\xBE\xEF" +    assert raw.read(2) == b"\xBE\xEF"      msg.to_file(out)      assert out.getvalue() == raw.getvalue()[:-2] @@ -140,13 +141,13 @@ def test_message_ipv6():  def test_message_invalid_rsv(): -    raw = tutils.treader("\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") +    raw = tutils.treader(b"\x05\x01\xFF\x01\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")      tutils.raises(socks.SocksError, socks.Message.from_file, raw)  def test_message_unknown_atyp(): -    raw = tutils.treader("\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF") +    raw = tutils.treader(b"\x05\x02\x00\x02\x7f\x00\x00\x01\xDE\xAD\xBE\xEF")      tutils.raises(socks.SocksError, socks.Message.from_file, raw)      m = socks.Message(5, 1, 0x02, tcp.Address(("example.com", 5050))) -    tutils.raises(socks.SocksError, m.to_file, StringIO()) +    tutils.raises(socks.SocksError, m.to_file, BytesIO()) diff --git a/test/test_utils.py b/test/test_utils.py index 0db75578..ff27486c 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -84,10 +84,10 @@ def test_parse_url():  def test_unparse_url(): -    assert utils.unparse_url("http", "foo.com", 99, "") == "http://foo.com:99" -    assert utils.unparse_url("http", "foo.com", 80, "") == "http://foo.com" -    assert utils.unparse_url("https", "foo.com", 80, "") == "https://foo.com:80" -    assert utils.unparse_url("https", "foo.com", 443, "") == "https://foo.com" +    assert utils.unparse_url(b"http", b"foo.com", 99, b"") == b"http://foo.com:99" +    assert utils.unparse_url(b"http", b"foo.com", 80, b"/bar") == b"http://foo.com/bar" +    assert utils.unparse_url(b"https", b"foo.com", 80, b"") == b"https://foo.com:80" +    assert utils.unparse_url(b"https", b"foo.com", 443, b"") == b"https://foo.com"  def test_urlencode(): @@ -122,7 +122,7 @@ def test_multipartdecode():          "--{0}\n"          "Content-Disposition: form-data; name=\"field2\"\n\n"          "value2\n" -        "--{0}--".format(boundary).encode("ascii") +        "--{0}--".format(boundary.decode()).encode()      )      form = utils.multipartdecode(headers, content) | 
