diff options
| author | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2015-05-27 11:18:54 +0200 | 
|---|---|---|
| committer | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2015-05-27 11:19:11 +0200 | 
| commit | e3d390e036430b9d7cc4b93679229fe118eb583a (patch) | |
| tree | 1f10a2a59c8abef4bd24c8189d14602264b20fad /netlib | |
| parent | f7b75ba8c21c66e38da81adf3b2c573b7dae87f3 (diff) | |
| download | mitmproxy-e3d390e036430b9d7cc4b93679229fe118eb583a.tar.gz mitmproxy-e3d390e036430b9d7cc4b93679229fe118eb583a.tar.bz2 mitmproxy-e3d390e036430b9d7cc4b93679229fe118eb583a.zip | |
cleanup code with autopep8
run the following command:
  $ autopep8 -i -r -a -a .
Diffstat (limited to 'netlib')
| -rw-r--r-- | netlib/certutils.py | 56 | ||||
| -rw-r--r-- | netlib/h2/frame.py | 34 | ||||
| -rw-r--r-- | netlib/h2/h2.py | 30 | ||||
| -rw-r--r-- | netlib/http.py | 13 | ||||
| -rw-r--r-- | netlib/http_auth.py | 22 | ||||
| -rw-r--r-- | netlib/http_cookies.py | 10 | ||||
| -rw-r--r-- | netlib/http_status.py | 84 | ||||
| -rw-r--r-- | netlib/odict.py | 10 | ||||
| -rw-r--r-- | netlib/socks.py | 43 | ||||
| -rw-r--r-- | netlib/tcp.py | 62 | ||||
| -rw-r--r-- | netlib/test.py | 24 | ||||
| -rw-r--r-- | netlib/utils.py | 10 | ||||
| -rw-r--r-- | netlib/websockets.py | 87 | ||||
| -rw-r--r-- | netlib/wsgi.py | 52 | 
14 files changed, 308 insertions, 229 deletions
| diff --git a/netlib/certutils.py b/netlib/certutils.py index f5375c03..da0e3355 100644 --- a/netlib/certutils.py +++ b/netlib/certutils.py @@ -1,12 +1,15 @@  from __future__ import (absolute_import, print_function, division) -import os, ssl, time, datetime +import os +import ssl +import time +import datetime  import itertools  from pyasn1.type import univ, constraint, char, namedtype, tag  from pyasn1.codec.der.decoder import decode  from pyasn1.error import PyAsn1Error  import OpenSSL -DEFAULT_EXP = 157680000 # = 24 * 60 * 60 * 365 * 5 +DEFAULT_EXP = 157680000  # = 24 * 60 * 60 * 365 * 5  # Generated with "openssl dhparam". It's too slow to generate this on startup.  DEFAULT_DHPARAM = """-----BEGIN DH PARAMETERS-----  MIGHAoGBAOdPzMbYgoYfO3YBYauCLRlE8X1XypTiAjoeCFD0qWRx8YUsZ6Sj20W5 @@ -14,31 +17,32 @@ zsfQxlZfKovo3f2MftjkDkbI/C/tDgxoe0ZPbjy5CjdOhkzxn0oTbKTs16Rw8DyK  1LjTR65sQJkJEdgsX8TSi/cicCftJZl9CaZEaObF2bdgSgGK+PezAgEC  -----END DH PARAMETERS-----""" +  def create_ca(o, cn, exp):      key = OpenSSL.crypto.PKey()      key.generate_key(OpenSSL.crypto.TYPE_RSA, 1024)      cert = OpenSSL.crypto.X509() -    cert.set_serial_number(int(time.time()*10000)) +    cert.set_serial_number(int(time.time() * 10000))      cert.set_version(2)      cert.get_subject().CN = cn      cert.get_subject().O = o -    cert.gmtime_adj_notBefore(-3600*48) +    cert.gmtime_adj_notBefore(-3600 * 48)      cert.gmtime_adj_notAfter(exp)      cert.set_issuer(cert.get_subject())      cert.set_pubkey(key)      cert.add_extensions([ -      OpenSSL.crypto.X509Extension("basicConstraints", True, -                                   "CA:TRUE"), -      OpenSSL.crypto.X509Extension("nsCertType", False, -                                   "sslCA"), -      OpenSSL.crypto.X509Extension("extendedKeyUsage", False, -                                    "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" -                                    ), -      OpenSSL.crypto.X509Extension("keyUsage", True, -                                   "keyCertSign, cRLSign"), -      OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", -                                   subject=cert), -      ]) +        OpenSSL.crypto.X509Extension("basicConstraints", True, +                                     "CA:TRUE"), +        OpenSSL.crypto.X509Extension("nsCertType", False, +                                     "sslCA"), +        OpenSSL.crypto.X509Extension("extendedKeyUsage", False, +                                     "serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC" +                                     ), +        OpenSSL.crypto.X509Extension("keyUsage", True, +                                     "keyCertSign, cRLSign"), +        OpenSSL.crypto.X509Extension("subjectKeyIdentifier", False, "hash", +                                     subject=cert), +    ])      cert.sign(key, "sha1")      return key, cert @@ -56,15 +60,15 @@ def dummy_cert(privkey, cacert, commonname, sans):      """      ss = []      for i in sans: -        ss.append("DNS: %s"%i) +        ss.append("DNS: %s" % i)      ss = ", ".join(ss)      cert = OpenSSL.crypto.X509() -    cert.gmtime_adj_notBefore(-3600*48) +    cert.gmtime_adj_notBefore(-3600 * 48)      cert.gmtime_adj_notAfter(DEFAULT_EXP)      cert.set_issuer(cacert.get_subject())      cert.get_subject().CN = commonname -    cert.set_serial_number(int(time.time()*10000)) +    cert.set_serial_number(int(time.time() * 10000))      if ss:          cert.set_version(2)          cert.add_extensions([OpenSSL.crypto.X509Extension("subjectAltName", False, ss)]) @@ -114,6 +118,7 @@ def dummy_cert(privkey, cacert, commonname, sans):  class CertStoreEntry(object): +      def __init__(self, cert, privatekey, chain_file):          self.cert = cert          self.privatekey = privatekey @@ -121,9 +126,11 @@ class CertStoreEntry(object):  class CertStore(object): +      """          Implements an in-memory certificate store.      """ +      def __init__(self, default_privatekey, default_ca, default_chain_file, dhparams=None):          self.default_privatekey = default_privatekey          self.default_ca = default_ca @@ -144,11 +151,11 @@ class CertStore(object):          if bio != OpenSSL.SSL._ffi.NULL:              bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)              dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams( -                    bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL -                ) +                bio, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL, OpenSSL.SSL._ffi.NULL +            )              dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)              return dh -     +      @classmethod      def from_store(cls, path, basename):          ca_path = os.path.join(path, basename + "-ca.pem") @@ -277,8 +284,8 @@ class _GeneralName(univ.Choice):      # other types.      componentType = namedtype.NamedTypes(          namedtype.NamedType('dNSName', char.IA5String().subtype( -                implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) -            ) +            implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2) +        )          ),      ) @@ -289,6 +296,7 @@ class _GeneralNames(univ.SequenceOf):  class SSLCert(object): +      def __init__(self, cert):          """              Returns a (common name, [subject alternative names]) tuple. diff --git a/netlib/h2/frame.py b/netlib/h2/frame.py index 52cc2992..d846b3b9 100644 --- a/netlib/h2/frame.py +++ b/netlib/h2/frame.py @@ -5,8 +5,11 @@ import struct  import io  from .. import utils, odict, tcp +from functools import reduce +  class Frame(object): +      """          Baseclass Frame          contains header @@ -53,6 +56,7 @@ class Frame(object):      def __eq__(self, other):          return self.to_bytes() == other.to_bytes() +  class DataFrame(Frame):      TYPE = 0x0      VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_PADDED] @@ -89,11 +93,13 @@ class DataFrame(Frame):          return b +  class HeadersFrame(Frame):      TYPE = 0x1      VALID_FLAGS = [Frame.FLAG_END_STREAM, Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED, Frame.FLAG_PRIORITY] -    def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', pad_length=0, exclusive=False, stream_dependency=0x0, weight=0): +    def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, header_block_fragment=b'', +                 pad_length=0, exclusive=False, stream_dependency=0x0, weight=0):          super(HeadersFrame, self).__init__(length, flags, stream_id)          self.header_block_fragment = header_block_fragment          self.pad_length = pad_length @@ -137,6 +143,7 @@ class HeadersFrame(Frame):          return b +  class PriorityFrame(Frame):      TYPE = 0x2      VALID_FLAGS = [] @@ -166,6 +173,7 @@ class PriorityFrame(Frame):          return struct.pack('!LB', (int(self.exclusive) << 31) | self.stream_dependency, self.weight) +  class RstStreamFrame(Frame):      TYPE = 0x3      VALID_FLAGS = [] @@ -186,18 +194,19 @@ class RstStreamFrame(Frame):          return struct.pack('!L', self.error_code) +  class SettingsFrame(Frame):      TYPE = 0x4      VALID_FLAGS = [Frame.FLAG_ACK]      SETTINGS = utils.BiDi( -        SETTINGS_HEADER_TABLE_SIZE = 0x1, -        SETTINGS_ENABLE_PUSH = 0x2, -        SETTINGS_MAX_CONCURRENT_STREAMS = 0x3, -        SETTINGS_INITIAL_WINDOW_SIZE = 0x4, -        SETTINGS_MAX_FRAME_SIZE = 0x5, -        SETTINGS_MAX_HEADER_LIST_SIZE = 0x6, -        ) +        SETTINGS_HEADER_TABLE_SIZE=0x1, +        SETTINGS_ENABLE_PUSH=0x2, +        SETTINGS_MAX_CONCURRENT_STREAMS=0x3, +        SETTINGS_INITIAL_WINDOW_SIZE=0x4, +        SETTINGS_MAX_FRAME_SIZE=0x5, +        SETTINGS_MAX_HEADER_LIST_SIZE=0x6, +    )      def __init__(self, length=0, flags=Frame.FLAG_NO_FLAGS, stream_id=0x0, settings={}):          super(SettingsFrame, self).__init__(length, flags, stream_id) @@ -208,7 +217,7 @@ class SettingsFrame(Frame):          f = self(length=length, flags=flags, stream_id=stream_id)          for i in xrange(0, len(payload), 6): -            identifier, value = struct.unpack("!HL", payload[i:i+6]) +            identifier, value = struct.unpack("!HL", payload[i:i + 6])              f.settings[identifier] = value          return f @@ -223,6 +232,7 @@ class SettingsFrame(Frame):          return b +  class PushPromiseFrame(Frame):      TYPE = 0x5      VALID_FLAGS = [Frame.FLAG_END_HEADERS, Frame.FLAG_PADDED] @@ -267,6 +277,7 @@ class PushPromiseFrame(Frame):          return b +  class PingFrame(Frame):      TYPE = 0x6      VALID_FLAGS = [Frame.FLAG_ACK] @@ -289,6 +300,7 @@ class PingFrame(Frame):          b += b'\0' * (8 - len(b))          return b +  class GoAwayFrame(Frame):      TYPE = 0x7      VALID_FLAGS = [] @@ -317,6 +329,7 @@ class GoAwayFrame(Frame):          b += bytes(self.data)          return b +  class WindowUpdateFrame(Frame):      TYPE = 0x8      VALID_FLAGS = [] @@ -335,11 +348,12 @@ class WindowUpdateFrame(Frame):          return f      def payload_bytes(self): -        if self.window_size_increment <= 0 or self.window_size_increment >= 2**31: +        if self.window_size_increment <= 0 or self.window_size_increment >= 2 ** 31:              raise ValueError('Window Szie Increment MUST be greater than 0 and less than 2^31.')          return struct.pack('!L', self.window_size_increment & 0x7FFFFFFF) +  class ContinuationFrame(Frame):      TYPE = 0x9      VALID_FLAGS = [Frame.FLAG_END_HEADERS] diff --git a/netlib/h2/h2.py b/netlib/h2/h2.py index 5d74c1c8..1a39a635 100644 --- a/netlib/h2/h2.py +++ b/netlib/h2/h2.py @@ -8,18 +8,18 @@ import io  CLIENT_CONNECTION_PREFACE = '505249202a20485454502f322e300d0a0d0a534d0d0a0d0a'  ERROR_CODES = utils.BiDi( -    NO_ERROR = 0x0, -    PROTOCOL_ERROR = 0x1, -    INTERNAL_ERROR = 0x2, -    FLOW_CONTROL_ERROR = 0x3, -    SETTINGS_TIMEOUT = 0x4, -    STREAM_CLOSED = 0x5, -    FRAME_SIZE_ERROR = 0x6, -    REFUSED_STREAM = 0x7, -    CANCEL = 0x8, -    COMPRESSION_ERROR = 0x9, -    CONNECT_ERROR = 0xa, -    ENHANCE_YOUR_CALM = 0xb, -    INADEQUATE_SECURITY = 0xc, -    HTTP_1_1_REQUIRED = 0xd -    ) +    NO_ERROR=0x0, +    PROTOCOL_ERROR=0x1, +    INTERNAL_ERROR=0x2, +    FLOW_CONTROL_ERROR=0x3, +    SETTINGS_TIMEOUT=0x4, +    STREAM_CLOSED=0x5, +    FRAME_SIZE_ERROR=0x6, +    REFUSED_STREAM=0x7, +    CANCEL=0x8, +    COMPRESSION_ERROR=0x9, +    CONNECT_ERROR=0xa, +    ENHANCE_YOUR_CALM=0xb, +    INADEQUATE_SECURITY=0xc, +    HTTP_1_1_REQUIRED=0xd +) diff --git a/netlib/http.py b/netlib/http.py index 43155486..47658097 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -8,6 +8,7 @@ from . import odict, utils, tcp, http_status  class HttpError(Exception): +      def __init__(self, code, message):          super(HttpError, self).__init__(message)          self.code = code @@ -95,7 +96,7 @@ def read_headers(fp):      """      ret = []      name = '' -    while 1: +    while True:          line = fp.readline()          if not line or line == '\r\n' or line == '\n':              break @@ -337,7 +338,7 @@ def read_http_body_chunked(              otherwise      """      if max_chunk_size is None: -        max_chunk_size = limit or sys.maxint +        max_chunk_size = limit or sys.maxsize      expected_size = expected_http_body_size(          headers, is_request, request_method, response_code @@ -399,10 +400,10 @@ def expected_http_body_size(headers, is_request, request_method, response_code):          request_method = request_method.upper()      if (not is_request and ( -                    request_method == "HEAD" or -                    (request_method == "CONNECT" and response_code == 200) or -                    response_code in [204, 304] or -                    100 <= response_code <= 199)): +            request_method == "HEAD" or +            (request_method == "CONNECT" and response_code == 200) or +            response_code in [204, 304] or +            100 <= response_code <= 199)):          return 0      if has_chunked_encoding(headers):          return None diff --git a/netlib/http_auth.py b/netlib/http_auth.py index 296e094c..261b6654 100644 --- a/netlib/http_auth.py +++ b/netlib/http_auth.py @@ -4,9 +4,11 @@ from . import http  class NullProxyAuth(object): +      """          No proxy auth at all (returns empty challange headers)      """ +      def __init__(self, password_manager):          self.password_manager = password_manager @@ -48,7 +50,7 @@ class BasicProxyAuth(NullProxyAuth):          if not parts:              return False          scheme, username, password = parts -        if scheme.lower()!='basic': +        if scheme.lower() != 'basic':              return False          if not self.password_manager.test(username, password):              return False @@ -56,18 +58,21 @@ class BasicProxyAuth(NullProxyAuth):          return True      def auth_challenge_headers(self): -        return {self.CHALLENGE_HEADER:'Basic realm="%s"'%self.realm} +        return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm}  class PassMan(object): +      def test(self, username, password_token):          return False  class PassManNonAnon(PassMan): +      """          Ensure the user specifies a username, accept any password.      """ +      def test(self, username, password_token):          if username:              return True @@ -75,9 +80,11 @@ class PassManNonAnon(PassMan):  class PassManHtpasswd(PassMan): +      """          Read usernames and passwords from an htpasswd file      """ +      def __init__(self, path):          """              Raises ValueError if htpasswd file is invalid. @@ -90,14 +97,16 @@ class PassManHtpasswd(PassMan):  class PassManSingleUser(PassMan): +      def __init__(self, username, password):          self.username, self.password = username, password      def test(self, username, password_token): -        return self.username==username and self.password==password_token +        return self.username == username and self.password == password_token  class AuthAction(Action): +      """      Helper class to allow seamless integration int argparse. Example usage:      parser.add_argument( @@ -106,16 +115,18 @@ class AuthAction(Action):          help="Allow access to any user long as a credentials are specified."      )      """ +      def __call__(self, parser, namespace, values, option_string=None):          passman = self.getPasswordManager(values)          authenticator = BasicProxyAuth(passman, "mitmproxy")          setattr(namespace, self.dest, authenticator) -    def getPasswordManager(self, s): # pragma: nocover +    def getPasswordManager(self, s):  # pragma: nocover          raise NotImplementedError()  class SingleuserAuthAction(AuthAction): +      def getPasswordManager(self, s):          if len(s.split(':')) != 2:              raise ArgumentTypeError( @@ -126,11 +137,12 @@ class SingleuserAuthAction(AuthAction):  class NonanonymousAuthAction(AuthAction): +      def getPasswordManager(self, s):          return PassManNonAnon()  class HtpasswdAuthAction(AuthAction): +      def getPasswordManager(self, s):          return PassManHtpasswd(s) - diff --git a/netlib/http_cookies.py b/netlib/http_cookies.py index 8e245891..73e3f589 100644 --- a/netlib/http_cookies.py +++ b/netlib/http_cookies.py @@ -96,7 +96,7 @@ def _read_pairs(s, off=0, specials=()):          specials: a lower-cased list of keys that may contain commas      """      vals = [] -    while 1: +    while True:          lhs, off = _read_token(s, off)          lhs = lhs.lstrip()          if lhs: @@ -135,15 +135,15 @@ def _format_pairs(lst, specials=(), sep="; "):          else:              if k.lower() not in specials and _has_special(v):                  v = ESCAPE.sub(r"\\\1", v) -                v = '"%s"'%v -            vals.append("%s=%s"%(k, v)) +                v = '"%s"' % v +            vals.append("%s=%s" % (k, v))      return sep.join(vals)  def _format_set_cookie_pairs(lst):      return _format_pairs(          lst, -        specials = ("expires", "path") +        specials=("expires", "path")      ) @@ -154,7 +154,7 @@ def _parse_set_cookie_pairs(s):      """      pairs, off = _read_pairs(          s, -        specials = ("expires", "path") +        specials=("expires", "path")      )      return pairs diff --git a/netlib/http_status.py b/netlib/http_status.py index 7dba2d56..dc09f465 100644 --- a/netlib/http_status.py +++ b/netlib/http_status.py @@ -1,51 +1,51 @@  from __future__ import (absolute_import, print_function, division) -CONTINUE                        = 100 -SWITCHING                       = 101 -OK                              = 200 -CREATED                         = 201 -ACCEPTED                        = 202 -NON_AUTHORITATIVE_INFORMATION   = 203 -NO_CONTENT                      = 204 -RESET_CONTENT                   = 205 -PARTIAL_CONTENT                 = 206 -MULTI_STATUS                    = 207 +CONTINUE = 100 +SWITCHING = 101 +OK = 200 +CREATED = 201 +ACCEPTED = 202 +NON_AUTHORITATIVE_INFORMATION = 203 +NO_CONTENT = 204 +RESET_CONTENT = 205 +PARTIAL_CONTENT = 206 +MULTI_STATUS = 207 -MULTIPLE_CHOICE                 = 300 -MOVED_PERMANENTLY               = 301 -FOUND                           = 302 -SEE_OTHER                       = 303 -NOT_MODIFIED                    = 304 -USE_PROXY                       = 305 -TEMPORARY_REDIRECT              = 307 +MULTIPLE_CHOICE = 300 +MOVED_PERMANENTLY = 301 +FOUND = 302 +SEE_OTHER = 303 +NOT_MODIFIED = 304 +USE_PROXY = 305 +TEMPORARY_REDIRECT = 307 -BAD_REQUEST                     = 400 -UNAUTHORIZED                    = 401 -PAYMENT_REQUIRED                = 402 -FORBIDDEN                       = 403 -NOT_FOUND                       = 404 -NOT_ALLOWED                     = 405 -NOT_ACCEPTABLE                  = 406 -PROXY_AUTH_REQUIRED             = 407 -REQUEST_TIMEOUT                 = 408 -CONFLICT                        = 409 -GONE                            = 410 -LENGTH_REQUIRED                 = 411 -PRECONDITION_FAILED             = 412 -REQUEST_ENTITY_TOO_LARGE        = 413 -REQUEST_URI_TOO_LONG            = 414 -UNSUPPORTED_MEDIA_TYPE          = 415 +BAD_REQUEST = 400 +UNAUTHORIZED = 401 +PAYMENT_REQUIRED = 402 +FORBIDDEN = 403 +NOT_FOUND = 404 +NOT_ALLOWED = 405 +NOT_ACCEPTABLE = 406 +PROXY_AUTH_REQUIRED = 407 +REQUEST_TIMEOUT = 408 +CONFLICT = 409 +GONE = 410 +LENGTH_REQUIRED = 411 +PRECONDITION_FAILED = 412 +REQUEST_ENTITY_TOO_LARGE = 413 +REQUEST_URI_TOO_LONG = 414 +UNSUPPORTED_MEDIA_TYPE = 415  REQUESTED_RANGE_NOT_SATISFIABLE = 416 -EXPECTATION_FAILED              = 417 +EXPECTATION_FAILED = 417 -INTERNAL_SERVER_ERROR           = 500 -NOT_IMPLEMENTED                 = 501 -BAD_GATEWAY                     = 502 -SERVICE_UNAVAILABLE             = 503 -GATEWAY_TIMEOUT                 = 504 -HTTP_VERSION_NOT_SUPPORTED      = 505 -INSUFFICIENT_STORAGE_SPACE      = 507 -NOT_EXTENDED                    = 510 +INTERNAL_SERVER_ERROR = 500 +NOT_IMPLEMENTED = 501 +BAD_GATEWAY = 502 +SERVICE_UNAVAILABLE = 503 +GATEWAY_TIMEOUT = 504 +HTTP_VERSION_NOT_SUPPORTED = 505 +INSUFFICIENT_STORAGE_SPACE = 507 +NOT_EXTENDED = 510  RESPONSES = {      # 100 diff --git a/netlib/odict.py b/netlib/odict.py index dd738c55..f52acd50 100644 --- a/netlib/odict.py +++ b/netlib/odict.py @@ -1,5 +1,6 @@  from __future__ import (absolute_import, print_function, division) -import re, copy +import re +import copy  def safe_subn(pattern, repl, target, *args, **kwargs): @@ -12,10 +13,12 @@ def safe_subn(pattern, repl, target, *args, **kwargs):  class ODict(object): +      """          A dictionary-like object for managing ordered (key, value) data. Think          about it as a convenient interface to a list of (key, value) tuples.      """ +      def __init__(self, lst=None):          self.lst = lst or [] @@ -157,7 +160,7 @@ class ODict(object):              "key: value"          """          for k, v in self.lst: -            s = "%s: %s"%(k, v) +            s = "%s: %s" % (k, v)              if re.search(expr, s):                  return True          return False @@ -192,11 +195,12 @@ class ODict(object):          return klass([list(i) for i in state]) -  class ODictCaseless(ODict): +      """          A variant of ODict with "caseless" keys. This version _preserves_ key          case, but does not consider case when setting or getting items.      """ +      def _kconv(self, s):          return s.lower() diff --git a/netlib/socks.py b/netlib/socks.py index 6f9f57bd..5a73c61a 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -6,49 +6,50 @@ from . import tcp, utils  class SocksError(Exception): +      def __init__(self, code, message):          super(SocksError, self).__init__(message)          self.code = code  VERSION = utils.BiDi( -    SOCKS4 = 0x04, -    SOCKS5 = 0x05 +    SOCKS4=0x04, +    SOCKS5=0x05  )  CMD = utils.BiDi( -    CONNECT = 0x01, -    BIND = 0x02, -    UDP_ASSOCIATE = 0x03 +    CONNECT=0x01, +    BIND=0x02, +    UDP_ASSOCIATE=0x03  )  ATYP = utils.BiDi( -    IPV4_ADDRESS = 0x01, -    DOMAINNAME = 0x03, -    IPV6_ADDRESS = 0x04 +    IPV4_ADDRESS=0x01, +    DOMAINNAME=0x03, +    IPV6_ADDRESS=0x04  )  REP = utils.BiDi( -    SUCCEEDED = 0x00, -    GENERAL_SOCKS_SERVER_FAILURE = 0x01, -    CONNECTION_NOT_ALLOWED_BY_RULESET = 0x02, -    NETWORK_UNREACHABLE = 0x03, -    HOST_UNREACHABLE = 0x04, -    CONNECTION_REFUSED = 0x05, -    TTL_EXPIRED = 0x06, -    COMMAND_NOT_SUPPORTED = 0x07, -    ADDRESS_TYPE_NOT_SUPPORTED = 0x08, +    SUCCEEDED=0x00, +    GENERAL_SOCKS_SERVER_FAILURE=0x01, +    CONNECTION_NOT_ALLOWED_BY_RULESET=0x02, +    NETWORK_UNREACHABLE=0x03, +    HOST_UNREACHABLE=0x04, +    CONNECTION_REFUSED=0x05, +    TTL_EXPIRED=0x06, +    COMMAND_NOT_SUPPORTED=0x07, +    ADDRESS_TYPE_NOT_SUPPORTED=0x08,  )  METHOD = utils.BiDi( -    NO_AUTHENTICATION_REQUIRED = 0x00, -    GSSAPI = 0x01, -    USERNAME_PASSWORD = 0x02, -    NO_ACCEPTABLE_METHODS = 0xFF +    NO_AUTHENTICATION_REQUIRED=0x00, +    GSSAPI=0x01, +    USERNAME_PASSWORD=0x02, +    NO_ACCEPTABLE_METHODS=0xFF  ) diff --git a/netlib/tcp.py b/netlib/tcp.py index 399203bb..7c115554 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -22,14 +22,28 @@ OP_NO_SSLv2 = SSL.OP_NO_SSLv2  OP_NO_SSLv3 = SSL.OP_NO_SSLv3 -class NetLibError(Exception): pass -class NetLibDisconnect(NetLibError): pass -class NetLibIncomplete(NetLibError): pass -class NetLibTimeout(NetLibError): pass -class NetLibSSLError(NetLibError): pass +class NetLibError(Exception): +    pass + + +class NetLibDisconnect(NetLibError): +    pass + + +class NetLibIncomplete(NetLibError): +    pass + + +class NetLibTimeout(NetLibError): +    pass + + +class NetLibSSLError(NetLibError): +    pass  class SSLKeyLogger(object): +      def __init__(self, filename):          self.filename = filename          self.f = None @@ -67,6 +81,7 @@ log_ssl_key = SSLKeyLogger.create_logfun(os.getenv("MITMPROXY_SSLKEYLOGFILE") or  class _FileLike(object):      BLOCKSIZE = 1024 * 32 +      def __init__(self, o):          self.o = o          self._log = None @@ -112,6 +127,7 @@ class _FileLike(object):  class Writer(_FileLike): +      def flush(self):          """              May raise NetLibDisconnect @@ -119,7 +135,7 @@ class Writer(_FileLike):          if hasattr(self.o, "flush"):              try:                  self.o.flush() -            except (socket.error, IOError), v: +            except (socket.error, IOError) as v:                  raise NetLibDisconnect(str(v))      def write(self, v): @@ -135,11 +151,12 @@ class Writer(_FileLike):                      r = self.o.write(v)                      self.add_log(v[:r])                      return r -            except (SSL.Error, socket.error) as  e: +            except (SSL.Error, socket.error) as e:                  raise NetLibDisconnect(str(e))  class Reader(_FileLike): +      def read(self, length):          """              If length is -1, we read until connection closes. @@ -180,7 +197,7 @@ class Reader(_FileLike):          self.add_log(result)          return result -    def readline(self, size = None): +    def readline(self, size=None):          result = ''          bytes_read = 0          while True: @@ -204,16 +221,18 @@ class Reader(_FileLike):          result = self.read(length)          if length != -1 and len(result) != length:              raise NetLibIncomplete( -                "Expected %s bytes, got %s"%(length, len(result)) +                "Expected %s bytes, got %s" % (length, len(result))              )          return result  class Address(object): +      """          This class wraps an IPv4/IPv6 tuple to provide named attributes and          ipv6 information.      """ +      def __init__(self, address, use_ipv6=False):          self.address = tuple(address)          self.use_ipv6 = use_ipv6 @@ -304,6 +323,7 @@ def close_socket(sock):  class _Connection(object): +      def get_current_cipher(self):          if not self.ssl_established:              return None @@ -319,7 +339,7 @@ class _Connection(object):          # (We call _FileLike.set_descriptor(conn))          # Closing the socket is not our task, therefore we don't call close          # then. -        if type(self.connection) != SSL.Connection: +        if not isinstance(self.connection, SSL.Connection):              if not getattr(self.wfile, "closed", False):                  try:                      self.wfile.flush() @@ -337,6 +357,7 @@ class _Connection(object):      """      Creates an SSL Context.      """ +      def _create_ssl_context(self,                              method=SSLv23_METHOD,                              options=(OP_NO_SSLv2 | OP_NO_SSLv3), @@ -362,8 +383,8 @@ class _Connection(object):          if cipher_list:              try:                  context.set_cipher_list(cipher_list) -            except SSL.Error, v: -                raise NetLibError("SSL cipher specification error: %s"%str(v)) +            except SSL.Error as v: +                raise NetLibError("SSL cipher specification error: %s" % str(v))          # SSLKEYLOGFILE          if log_ssl_key: @@ -380,7 +401,7 @@ class TCPClient(_Connection):          # Make sure to close the real socket, not the SSL proxy.          # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,          # it tries to renegotiate... -        if type(self.connection) == SSL.Connection: +        if isinstance(self.connection, SSL.Connection):              close_socket(self.connection._socket)          else:              close_socket(self.connection) @@ -400,8 +421,8 @@ class TCPClient(_Connection):              try:                  context.use_privatekey_file(cert)                  context.use_certificate_file(cert) -            except SSL.Error, v: -                raise NetLibError("SSL client certificate error: %s"%str(v)) +            except SSL.Error as v: +                raise NetLibError("SSL client certificate error: %s" % str(v))          return context      def convert_to_ssl(self, sni=None, **sslctx_kwargs): @@ -418,8 +439,8 @@ class TCPClient(_Connection):          self.connection.set_connect_state()          try:              self.connection.do_handshake() -        except SSL.Error, v: -            raise NetLibError("SSL handshake error: %s"%repr(v)) +        except SSL.Error as v: +            raise NetLibError("SSL handshake error: %s" % repr(v))          self.ssl_established = True          self.cert = certutils.SSLCert(self.connection.get_peer_certificate())          self.rfile.set_descriptor(self.connection) @@ -435,7 +456,7 @@ class TCPClient(_Connection):                  self.source_address = Address(connection.getsockname())              self.rfile = Reader(connection.makefile('rb', self.rbufsize))              self.wfile = Writer(connection.makefile('wb', self.wbufsize)) -        except (socket.error, IOError), err: +        except (socket.error, IOError) as err:              raise NetLibError('Error connecting to "%s": %s' % (self.address.host, err))          self.connection = connection @@ -447,6 +468,7 @@ class TCPClient(_Connection):  class BaseHandler(_Connection): +      """          The instantiator is expected to call the handle() and finish() methods. @@ -531,8 +553,8 @@ class BaseHandler(_Connection):          self.connection.set_accept_state()          try:              self.connection.do_handshake() -        except SSL.Error, v: -            raise NetLibError("SSL handshake error: %s"%repr(v)) +        except SSL.Error as v: +            raise NetLibError("SSL handshake error: %s" % repr(v))          self.ssl_established = True          self.rfile.set_descriptor(self.connection)          self.wfile.set_descriptor(self.connection) diff --git a/netlib/test.py b/netlib/test.py index db30c0e6..b6f94273 100644 --- a/netlib/test.py +++ b/netlib/test.py @@ -1,9 +1,13 @@  from __future__ import (absolute_import, print_function, division) -import threading, Queue, cStringIO +import threading +import Queue +import cStringIO  import OpenSSL  from . import tcp, certutils +  class ServerThread(threading.Thread): +      def __init__(self, server):          self.server = server          threading.Thread.__init__(self) @@ -19,6 +23,7 @@ class ServerTestBase(object):      ssl = None      handler = None      addr = ("localhost", 0) +      @classmethod      def setupAll(cls):          cls.q = Queue.Queue() @@ -41,10 +46,11 @@ class ServerTestBase(object):  class TServer(tcp.TCPServer): +      def __init__(self, ssl, q, handler_klass, addr):          """              ssl: A dictionary of SSL parameters: -                 +                      cert, key, request_client_cert, cipher_list,                      dhparams, v3_only          """ @@ -70,13 +76,13 @@ class TServer(tcp.TCPServer):                  options = None              h.convert_to_ssl(                  cert, key, -                method = method, -                options = options, -                handle_sni = getattr(h, "handle_sni", None), -                request_client_cert = self.ssl["request_client_cert"], -                cipher_list = self.ssl.get("cipher_list", None), -                dhparams = self.ssl.get("dhparams", None), -                chain_file = self.ssl.get("chain_file", None) +                method=method, +                options=options, +                handle_sni=getattr(h, "handle_sni", None), +                request_client_cert=self.ssl["request_client_cert"], +                cipher_list=self.ssl.get("cipher_list", None), +                dhparams=self.ssl.get("dhparams", None), +                chain_file=self.ssl.get("chain_file", None)              )          h.handle()          h.finish() diff --git a/netlib/utils.py b/netlib/utils.py index 7e539977..9c5404e6 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -68,6 +68,7 @@ def getbit(byte, offset):  class BiDi: +      """          A wee utility class for keeping bi-directional mappings, like field          constants in protocols. Names are attributes on the object, dict-like @@ -77,6 +78,7 @@ class BiDi:          assert CONST.a == 1          assert CONST.get_name(1) == "a"      """ +      def __init__(self, **kwargs):          self.names = kwargs          self.values = {} @@ -96,15 +98,15 @@ class BiDi:  def pretty_size(size):      suffixes = [ -        ("B", 2**10), -        ("kB", 2**20), -        ("MB", 2**30), +        ("B", 2 ** 10), +        ("kB", 2 ** 20), +        ("MB", 2 ** 30),      ]      for suf, lim in suffixes:          if size >= lim:              continue          else: -            x = round(size/float(lim/2**10), 2) +            x = round(size / float(lim / 2 ** 10), 2)              if x == int(x):                  x = int(x)              return str(x) + suf diff --git a/netlib/websockets.py b/netlib/websockets.py index a2d55c19..63dc03f1 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -26,16 +26,17 @@ MAX_64_BIT_INT = (1 << 64)  OPCODE = utils.BiDi( -    CONTINUE = 0x00, -    TEXT = 0x01, -    BINARY = 0x02, -    CLOSE = 0x08, -    PING = 0x09, -    PONG = 0x0a +    CONTINUE=0x00, +    TEXT=0x01, +    BINARY=0x02, +    CLOSE=0x08, +    PING=0x09, +    PONG=0x0a  )  class Masker: +      """          Data sent from the server must be masked to prevent malicious clients          from sending data over the wire in predictable patterns @@ -43,6 +44,7 @@ class Masker:          Servers do not have to mask data they send to the client.          https://tools.ietf.org/html/rfc6455#section-5.3      """ +      def __init__(self, key):          self.key = key          self.masks = [utils.bytes_to_int(byte) for byte in key] @@ -128,17 +130,18 @@ DEFAULT = object()  class FrameHeader: +      def __init__(          self, -        opcode = OPCODE.TEXT, -        payload_length = 0, -        fin = False, -        rsv1 = False, -        rsv2 = False, -        rsv3 = False, -        masking_key = DEFAULT, -        mask = DEFAULT, -        length_code = DEFAULT +        opcode=OPCODE.TEXT, +        payload_length=0, +        fin=False, +        rsv1=False, +        rsv2=False, +        rsv3=False, +        masking_key=DEFAULT, +        mask=DEFAULT, +        length_code=DEFAULT      ):          if not 0 <= opcode < 2 ** 4:              raise ValueError("opcode must be 0-16") @@ -182,9 +185,9 @@ class FrameHeader:          if flags:              vals.extend([":", "|".join(flags)])          if self.masking_key: -            vals.append(":key=%s"%repr(self.masking_key)) +            vals.append(":key=%s" % repr(self.masking_key))          if self.payload_length: -            vals.append(" %s"%utils.pretty_size(self.payload_length)) +            vals.append(" %s" % utils.pretty_size(self.payload_length))          return "".join(vals)      def to_bytes(self): @@ -246,15 +249,15 @@ class FrameHeader:              masking_key = None          return klass( -            fin = fin, -            rsv1 = rsv1, -            rsv2 = rsv2, -            rsv3 = rsv3, -            opcode = opcode, -            mask = mask_bit, -            length_code = length_code, -            payload_length = payload_length, -            masking_key = masking_key, +            fin=fin, +            rsv1=rsv1, +            rsv2=rsv2, +            rsv3=rsv3, +            opcode=opcode, +            mask=mask_bit, +            length_code=length_code, +            payload_length=payload_length, +            masking_key=masking_key,          )      def __eq__(self, other): @@ -262,6 +265,7 @@ class FrameHeader:  class Frame(object): +      """          Represents one websockets frame.          Constructor takes human readable forms of the frame components @@ -287,13 +291,14 @@ class Frame(object):           |                     Payload Data continued ...                |           +---------------------------------------------------------------+      """ -    def __init__(self, payload = "", **kwargs): + +    def __init__(self, payload="", **kwargs):          self.payload = payload          kwargs["payload_length"] = kwargs.get("payload_length", len(payload))          self.header = FrameHeader(**kwargs)      @classmethod -    def default(cls, message, from_client = False): +    def default(cls, message, from_client=False):          """            Construct a basic websocket frame from some default values.            Creates a non-fragmented text frame. @@ -307,10 +312,10 @@ class Frame(object):          return cls(              message, -            fin = 1, # final frame -            opcode = OPCODE.TEXT, # text -            mask = mask_bit, -            masking_key = masking_key, +            fin=1,  # final frame +            opcode=OPCODE.TEXT,  # text +            mask=mask_bit, +            masking_key=masking_key,          )      @classmethod @@ -356,15 +361,15 @@ class Frame(object):          return cls(              payload, -            fin = header.fin, -            opcode = header.opcode, -            mask = header.mask, -            payload_length = header.payload_length, -            masking_key = header.masking_key, -            rsv1 = header.rsv1, -            rsv2 = header.rsv2, -            rsv3 = header.rsv3, -            length_code = header.length_code +            fin=header.fin, +            opcode=header.opcode, +            mask=header.mask, +            payload_length=header.payload_length, +            masking_key=header.masking_key, +            rsv1=header.rsv1, +            rsv2=header.rsv2, +            rsv3=header.rsv3, +            length_code=header.length_code          )      def __eq__(self, other): diff --git a/netlib/wsgi.py b/netlib/wsgi.py index 1b979608..f393039a 100644 --- a/netlib/wsgi.py +++ b/netlib/wsgi.py @@ -7,17 +7,20 @@ from . import odict, tcp  class ClientConn(object): +      def __init__(self, address):          self.address = tcp.Address.wrap(address)  class Flow(object): +      def __init__(self, address, request):          self.client_conn = ClientConn(address)          self.request = request  class Request(object): +      def __init__(self, scheme, method, path, headers, content):          self.scheme, self.method, self.path = scheme, method, path          self.headers, self.content = headers, content @@ -42,6 +45,7 @@ def date_time_string():  class WSGIAdaptor(object): +      def __init__(self, app, domain, port, sversion):          self.app, self.domain, self.port, self.sversion = app, domain, port, sversion @@ -52,24 +56,24 @@ class WSGIAdaptor(object):              path_info = flow.request.path              query = ''          environ = { -            'wsgi.version':         (1, 0), -            'wsgi.url_scheme':      flow.request.scheme, -            'wsgi.input':           cStringIO.StringIO(flow.request.content), -            'wsgi.errors':          errsoc, -            'wsgi.multithread':     True, -            'wsgi.multiprocess':    False, -            'wsgi.run_once':        False, -            'SERVER_SOFTWARE':      self.sversion, -            'REQUEST_METHOD':       flow.request.method, -            'SCRIPT_NAME':          '', -            'PATH_INFO':            urllib.unquote(path_info), -            'QUERY_STRING':         query, -            'CONTENT_TYPE':         flow.request.headers.get('Content-Type', [''])[0], -            'CONTENT_LENGTH':       flow.request.headers.get('Content-Length', [''])[0], -            'SERVER_NAME':          self.domain, -            'SERVER_PORT':          str(self.port), +            'wsgi.version': (1, 0), +            'wsgi.url_scheme': flow.request.scheme, +            'wsgi.input': cStringIO.StringIO(flow.request.content), +            'wsgi.errors': errsoc, +            'wsgi.multithread': True, +            'wsgi.multiprocess': False, +            'wsgi.run_once': False, +            'SERVER_SOFTWARE': self.sversion, +            'REQUEST_METHOD': flow.request.method, +            'SCRIPT_NAME': '', +            'PATH_INFO': urllib.unquote(path_info), +            'QUERY_STRING': query, +            'CONTENT_TYPE': flow.request.headers.get('Content-Type', [''])[0], +            'CONTENT_LENGTH': flow.request.headers.get('Content-Length', [''])[0], +            'SERVER_NAME': self.domain, +            'SERVER_PORT': str(self.port),              # FIXME: We need to pick up the protocol read from the request. -            'SERVER_PROTOCOL':      "HTTP/1.1", +            'SERVER_PROTOCOL': "HTTP/1.1",          }          environ.update(extra)          if flow.client_conn.address: @@ -91,25 +95,25 @@ class WSGIAdaptor(object):                  <h1>Internal Server Error</h1>                  <pre>%s"</pre>              </html> -        """%s +        """ % s          if not headers_sent:              soc.write("HTTP/1.1 500 Internal Server Error\r\n")              soc.write("Content-Type: text/html\r\n") -            soc.write("Content-Length: %s\r\n"%len(c)) +            soc.write("Content-Length: %s\r\n" % len(c))              soc.write("\r\n")          soc.write(c)      def serve(self, request, soc, **env):          state = dict( -            response_started = False, -            headers_sent = False, -            status = None, -            headers = None +            response_started=False, +            headers_sent=False, +            status=None, +            headers=None          )          def write(data):              if not state["headers_sent"]: -                soc.write("HTTP/1.1 %s\r\n"%state["status"]) +                soc.write("HTTP/1.1 %s\r\n" % state["status"])                  h = state["headers"]                  if 'server' not in h:                      h["Server"] = [self.sversion] | 
