aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/__init__.py0
-rw-r--r--netlib/basethread.py14
-rw-r--r--netlib/basetypes.py32
-rw-r--r--netlib/certutils.py481
-rw-r--r--netlib/debug.py120
-rw-r--r--netlib/encoding.py175
-rw-r--r--netlib/exceptions.py59
-rw-r--r--netlib/http/__init__.py15
-rw-r--r--netlib/http/authentication.py176
-rw-r--r--netlib/http/cookies.py384
-rw-r--r--netlib/http/headers.py221
-rw-r--r--netlib/http/http1/__init__.py24
-rw-r--r--netlib/http/http1/assemble.py100
-rw-r--r--netlib/http/http1/read.py377
-rw-r--r--netlib/http/http2/__init__.py8
-rw-r--r--netlib/http/http2/framereader.py25
-rw-r--r--netlib/http/http2/utils.py37
-rw-r--r--netlib/http/message.py298
-rw-r--r--netlib/http/multipart.py32
-rw-r--r--netlib/http/request.py405
-rw-r--r--netlib/http/response.py192
-rw-r--r--netlib/http/status_codes.py104
-rw-r--r--netlib/http/url.py127
-rw-r--r--netlib/http/user_agents.py50
-rw-r--r--netlib/human.py64
-rw-r--r--netlib/multidict.py298
-rw-r--r--netlib/socks.py232
-rw-r--r--netlib/strutils.py142
-rw-r--r--netlib/tcp.py989
-rw-r--r--netlib/tutils.py130
-rw-r--r--netlib/utils.py97
-rw-r--r--netlib/version.py4
-rw-r--r--netlib/version_check.py43
-rw-r--r--netlib/websockets/__init__.py35
-rw-r--r--netlib/websockets/frame.py273
-rw-r--r--netlib/websockets/masker.py25
-rw-r--r--netlib/websockets/utils.py89
-rw-r--r--netlib/wsgi.py164
38 files changed, 0 insertions, 6041 deletions
diff --git a/netlib/__init__.py b/netlib/__init__.py
deleted file mode 100644
index e69de29b..00000000
--- a/netlib/__init__.py
+++ /dev/null
diff --git a/netlib/basethread.py b/netlib/basethread.py
deleted file mode 100644
index a3c81d19..00000000
--- a/netlib/basethread.py
+++ /dev/null
@@ -1,14 +0,0 @@
-import time
-import threading
-
-
-class BaseThread(threading.Thread):
- def __init__(self, name, *args, **kwargs):
- super().__init__(name=name, *args, **kwargs)
- self._thread_started = time.time()
-
- def _threadinfo(self):
- return "%s - age: %is" % (
- self.name,
- int(time.time() - self._thread_started)
- )
diff --git a/netlib/basetypes.py b/netlib/basetypes.py
deleted file mode 100644
index 49892ffc..00000000
--- a/netlib/basetypes.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import abc
-
-
-class Serializable(metaclass=abc.ABCMeta):
- """
- Abstract Base Class that defines an API to save an object's state and restore it later on.
- """
-
- @classmethod
- @abc.abstractmethod
- def from_state(cls, state):
- """
- Create a new object from the given state.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def get_state(self):
- """
- Retrieve object state.
- """
- raise NotImplementedError()
-
- @abc.abstractmethod
- def set_state(self, state):
- """
- Set object state to the given state.
- """
- raise NotImplementedError()
-
- def copy(self):
- return self.from_state(self.get_state())
diff --git a/netlib/certutils.py b/netlib/certutils.py
deleted file mode 100644
index 6a97f99e..00000000
--- a/netlib/certutils.py
+++ /dev/null
@@ -1,481 +0,0 @@
-import os
-import ssl
-import time
-import datetime
-import ipaddress
-
-import sys
-from pyasn1.type import univ, constraint, char, namedtype, tag
-from pyasn1.codec.der.decoder import decode
-from pyasn1.error import PyAsn1Error
-import OpenSSL
-
-from netlib import basetypes
-
-# Default expiry must not be too long: https://github.com/mitmproxy/mitmproxy/issues/815
-
-DEFAULT_EXP = 94608000 # = 24 * 60 * 60 * 365 * 3
-# Generated with "openssl dhparam". It's too slow to generate this on startup.
-DEFAULT_DHPARAM = b"""
------BEGIN DH PARAMETERS-----
-MIICCAKCAgEAyT6LzpwVFS3gryIo29J5icvgxCnCebcdSe/NHMkD8dKJf8suFCg3
-O2+dguLakSVif/t6dhImxInJk230HmfC8q93hdcg/j8rLGJYDKu3ik6H//BAHKIv
-j5O9yjU3rXCfmVJQic2Nne39sg3CreAepEts2TvYHhVv3TEAzEqCtOuTjgDv0ntJ
-Gwpj+BJBRQGG9NvprX1YGJ7WOFBP/hWU7d6tgvE6Xa7T/u9QIKpYHMIkcN/l3ZFB
-chZEqVlyrcngtSXCROTPcDOQ6Q8QzhaBJS+Z6rcsd7X+haiQqvoFcmaJ08Ks6LQC
-ZIL2EtYJw8V8z7C0igVEBIADZBI6OTbuuhDwRw//zU1uq52Oc48CIZlGxTYG/Evq
-o9EWAXUYVzWkDSTeBH1r4z/qLPE2cnhtMxbFxuvK53jGB0emy2y1Ei6IhKshJ5qX
-IB/aE7SSHyQ3MDHHkCmQJCsOd4Mo26YX61NZ+n501XjqpCBQ2+DfZCBh8Va2wDyv
-A2Ryg9SUz8j0AXViRNMJgJrr446yro/FuJZwnQcO3WQnXeqSBnURqKjmqkeFP+d8
-6mk2tqJaY507lRNqtGlLnj7f5RNoBFJDCLBNurVgfvq9TCVWKDIFD4vZRjCrnl6I
-rD693XKIHUCWOjMh1if6omGXKHH40QuME2gNa50+YPn1iYDl88uDbbMCAQI=
------END DH PARAMETERS-----
-"""
-
-
-def create_ca(o, cn, exp):
- key = OpenSSL.crypto.PKey()
- key.generate_key(OpenSSL.crypto.TYPE_RSA, 2048)
- cert = OpenSSL.crypto.X509()
- cert.set_serial_number(int(time.time() * 10000))
- cert.set_version(2)
- cert.get_subject().CN = cn
- cert.get_subject().O = o
- cert.gmtime_adj_notBefore(-3600 * 48)
- cert.gmtime_adj_notAfter(exp)
- cert.set_issuer(cert.get_subject())
- cert.set_pubkey(key)
- cert.add_extensions([
- OpenSSL.crypto.X509Extension(
- b"basicConstraints",
- True,
- b"CA:TRUE"
- ),
- OpenSSL.crypto.X509Extension(
- b"nsCertType",
- False,
- b"sslCA"
- ),
- OpenSSL.crypto.X509Extension(
- b"extendedKeyUsage",
- False,
- b"serverAuth,clientAuth,emailProtection,timeStamping,msCodeInd,msCodeCom,msCTLSign,msSGC,msEFS,nsSGC"
- ),
- OpenSSL.crypto.X509Extension(
- b"keyUsage",
- True,
- b"keyCertSign, cRLSign"
- ),
- OpenSSL.crypto.X509Extension(
- b"subjectKeyIdentifier",
- False,
- b"hash",
- subject=cert
- ),
- ])
- cert.sign(key, "sha256")
- return key, cert
-
-
-def dummy_cert(privkey, cacert, commonname, sans):
- """
- Generates a dummy certificate.
-
- privkey: CA private key
- cacert: CA certificate
- commonname: Common name for the generated certificate.
- sans: A list of Subject Alternate Names.
-
- Returns cert if operation succeeded, None if not.
- """
- ss = []
- for i in sans:
- try:
- ipaddress.ip_address(i.decode("ascii"))
- except ValueError:
- ss.append(b"DNS: %s" % i)
- else:
- ss.append(b"IP: %s" % i)
- ss = b", ".join(ss)
-
- cert = OpenSSL.crypto.X509()
- cert.gmtime_adj_notBefore(-3600 * 48)
- cert.gmtime_adj_notAfter(DEFAULT_EXP)
- cert.set_issuer(cacert.get_subject())
- if commonname is not None:
- cert.get_subject().CN = commonname
- cert.set_serial_number(int(time.time() * 10000))
- if ss:
- cert.set_version(2)
- cert.add_extensions(
- [OpenSSL.crypto.X509Extension(b"subjectAltName", False, ss)])
- cert.set_pubkey(cacert.get_pubkey())
- cert.sign(privkey, "sha256")
- return SSLCert(cert)
-
-
-# DNTree did not pass TestCertStore.test_sans_change and is temporarily replaced by a simple dict.
-#
-# class _Node(UserDict.UserDict):
-# def __init__(self):
-# UserDict.UserDict.__init__(self)
-# self.value = None
-#
-#
-# class DNTree:
-# """
-# Domain store that knows about wildcards. DNS wildcards are very
-# restricted - the only valid variety is an asterisk on the left-most
-# domain component, i.e.:
-#
-# *.foo.com
-# """
-# def __init__(self):
-# self.d = _Node()
-#
-# def add(self, dn, cert):
-# parts = dn.split(".")
-# parts.reverse()
-# current = self.d
-# for i in parts:
-# current = current.setdefault(i, _Node())
-# current.value = cert
-#
-# def get(self, dn):
-# parts = dn.split(".")
-# current = self.d
-# for i in reversed(parts):
-# if i in current:
-# current = current[i]
-# elif "*" in current:
-# return current["*"].value
-# else:
-# return None
-# return current.value
-
-
-class CertStoreEntry:
-
- def __init__(self, cert, privatekey, chain_file):
- self.cert = cert
- self.privatekey = privatekey
- self.chain_file = chain_file
-
-
-class CertStore:
-
- """
- Implements an in-memory certificate store.
- """
- STORE_CAP = 100
-
- def __init__(
- self,
- default_privatekey,
- default_ca,
- default_chain_file,
- dhparams):
- self.default_privatekey = default_privatekey
- self.default_ca = default_ca
- self.default_chain_file = default_chain_file
- self.dhparams = dhparams
- self.certs = dict()
- self.expire_queue = []
-
- def expire(self, entry):
- self.expire_queue.append(entry)
- if len(self.expire_queue) > self.STORE_CAP:
- d = self.expire_queue.pop(0)
- for k, v in list(self.certs.items()):
- if v == d:
- del self.certs[k]
-
- @staticmethod
- def load_dhparam(path):
-
- # netlib<=0.10 doesn't generate a dhparam file.
- # Create it now if neccessary.
- if not os.path.exists(path):
- with open(path, "wb") as f:
- f.write(DEFAULT_DHPARAM)
-
- bio = OpenSSL.SSL._lib.BIO_new_file(path.encode(sys.getfilesystemencoding()), b"r")
- if bio != OpenSSL.SSL._ffi.NULL:
- bio = OpenSSL.SSL._ffi.gc(bio, OpenSSL.SSL._lib.BIO_free)
- dh = OpenSSL.SSL._lib.PEM_read_bio_DHparams(
- bio,
- OpenSSL.SSL._ffi.NULL,
- OpenSSL.SSL._ffi.NULL,
- OpenSSL.SSL._ffi.NULL)
- dh = OpenSSL.SSL._ffi.gc(dh, OpenSSL.SSL._lib.DH_free)
- return dh
-
- @classmethod
- def from_store(cls, path, basename):
- ca_path = os.path.join(path, basename + "-ca.pem")
- if not os.path.exists(ca_path):
- key, ca = cls.create_store(path, basename)
- else:
- with open(ca_path, "rb") as f:
- raw = f.read()
- ca = OpenSSL.crypto.load_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- raw)
- key = OpenSSL.crypto.load_privatekey(
- OpenSSL.crypto.FILETYPE_PEM,
- raw)
- dh_path = os.path.join(path, basename + "-dhparam.pem")
- dh = cls.load_dhparam(dh_path)
- return cls(key, ca, ca_path, dh)
-
- @staticmethod
- def create_store(path, basename, o=None, cn=None, expiry=DEFAULT_EXP):
- if not os.path.exists(path):
- os.makedirs(path)
-
- o = o or basename
- cn = cn or basename
-
- key, ca = create_ca(o=o, cn=cn, exp=expiry)
- # Dump the CA plus private key
- with open(os.path.join(path, basename + "-ca.pem"), "wb") as f:
- f.write(
- OpenSSL.crypto.dump_privatekey(
- OpenSSL.crypto.FILETYPE_PEM,
- key))
- f.write(
- OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- ca))
-
- # Dump the certificate in PEM format
- with open(os.path.join(path, basename + "-ca-cert.pem"), "wb") as f:
- f.write(
- OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- ca))
-
- # Create a .cer file with the same contents for Android
- with open(os.path.join(path, basename + "-ca-cert.cer"), "wb") as f:
- f.write(
- OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- ca))
-
- # Dump the certificate in PKCS12 format for Windows devices
- with open(os.path.join(path, basename + "-ca-cert.p12"), "wb") as f:
- p12 = OpenSSL.crypto.PKCS12()
- p12.set_certificate(ca)
- p12.set_privatekey(key)
- f.write(p12.export())
-
- with open(os.path.join(path, basename + "-dhparam.pem"), "wb") as f:
- f.write(DEFAULT_DHPARAM)
-
- return key, ca
-
- def add_cert_file(self, spec, path):
- with open(path, "rb") as f:
- raw = f.read()
- cert = SSLCert(
- OpenSSL.crypto.load_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- raw))
- try:
- privatekey = OpenSSL.crypto.load_privatekey(
- OpenSSL.crypto.FILETYPE_PEM,
- raw)
- except Exception:
- privatekey = self.default_privatekey
- self.add_cert(
- CertStoreEntry(cert, privatekey, path),
- spec
- )
-
- def add_cert(self, entry, *names):
- """
- Adds a cert to the certstore. We register the CN in the cert plus
- any SANs, and also the list of names provided as an argument.
- """
- if entry.cert.cn:
- self.certs[entry.cert.cn] = entry
- for i in entry.cert.altnames:
- self.certs[i] = entry
- for i in names:
- self.certs[i] = entry
-
- @staticmethod
- def asterisk_forms(dn):
- if dn is None:
- return []
- parts = dn.split(b".")
- parts.reverse()
- curr_dn = b""
- dn_forms = [b"*"]
- for part in parts[:-1]:
- curr_dn = b"." + part + curr_dn # .example.com
- dn_forms.append(b"*" + curr_dn) # *.example.com
- if parts[-1] != b"*":
- dn_forms.append(parts[-1] + curr_dn)
- return dn_forms
-
- def get_cert(self, commonname, sans):
- """
- Returns an (cert, privkey, cert_chain) tuple.
-
- commonname: Common name for the generated certificate. Must be a
- valid, plain-ASCII, IDNA-encoded domain name.
-
- sans: A list of Subject Alternate Names.
- """
-
- potential_keys = self.asterisk_forms(commonname)
- for s in sans:
- potential_keys.extend(self.asterisk_forms(s))
- potential_keys.append((commonname, tuple(sans)))
-
- name = next(
- filter(lambda key: key in self.certs, potential_keys),
- None
- )
- if name:
- entry = self.certs[name]
- else:
- entry = CertStoreEntry(
- cert=dummy_cert(
- self.default_privatekey,
- self.default_ca,
- commonname,
- sans),
- privatekey=self.default_privatekey,
- chain_file=self.default_chain_file)
- self.certs[(commonname, tuple(sans))] = entry
- self.expire(entry)
-
- return entry.cert, entry.privatekey, entry.chain_file
-
-
-class _GeneralName(univ.Choice):
- # We are only interested in dNSNames. We use a default handler to ignore
- # other types.
- # TODO: We should also handle iPAddresses.
- componentType = namedtype.NamedTypes(
- namedtype.NamedType('dNSName', char.IA5String().subtype(
- implicitTag=tag.Tag(tag.tagClassContext, tag.tagFormatSimple, 2)
- )
- ),
- )
-
-
-class _GeneralNames(univ.SequenceOf):
- componentType = _GeneralName()
- sizeSpec = univ.SequenceOf.sizeSpec + \
- constraint.ValueSizeConstraint(1, 1024)
-
-
-class SSLCert(basetypes.Serializable):
-
- def __init__(self, cert):
- """
- Returns a (common name, [subject alternative names]) tuple.
- """
- self.x509 = cert
-
- def __eq__(self, other):
- return self.digest("sha256") == other.digest("sha256")
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def get_state(self):
- return self.to_pem()
-
- def set_state(self, state):
- self.x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, state)
-
- @classmethod
- def from_state(cls, state):
- return cls.from_pem(state)
-
- @classmethod
- def from_pem(cls, txt):
- x509 = OpenSSL.crypto.load_certificate(OpenSSL.crypto.FILETYPE_PEM, txt)
- return cls(x509)
-
- @classmethod
- def from_der(cls, der):
- pem = ssl.DER_cert_to_PEM_cert(der)
- return cls.from_pem(pem)
-
- def to_pem(self):
- return OpenSSL.crypto.dump_certificate(
- OpenSSL.crypto.FILETYPE_PEM,
- self.x509)
-
- def digest(self, name):
- return self.x509.digest(name)
-
- @property
- def issuer(self):
- return self.x509.get_issuer().get_components()
-
- @property
- def notbefore(self):
- t = self.x509.get_notBefore()
- return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
-
- @property
- def notafter(self):
- t = self.x509.get_notAfter()
- return datetime.datetime.strptime(t.decode("ascii"), "%Y%m%d%H%M%SZ")
-
- @property
- def has_expired(self):
- return self.x509.has_expired()
-
- @property
- def subject(self):
- return self.x509.get_subject().get_components()
-
- @property
- def serial(self):
- return self.x509.get_serial_number()
-
- @property
- def keyinfo(self):
- pk = self.x509.get_pubkey()
- types = {
- OpenSSL.crypto.TYPE_RSA: "RSA",
- OpenSSL.crypto.TYPE_DSA: "DSA",
- }
- return (
- types.get(pk.type(), "UNKNOWN"),
- pk.bits()
- )
-
- @property
- def cn(self):
- c = None
- for i in self.subject:
- if i[0] == b"CN":
- c = i[1]
- return c
-
- @property
- def altnames(self):
- """
- Returns:
- All DNS altnames.
- """
- # tcp.TCPClient.convert_to_ssl assumes that this property only contains DNS altnames for hostname verification.
- altnames = []
- for i in range(self.x509.get_extension_count()):
- ext = self.x509.get_extension(i)
- if ext.get_short_name() == b"subjectAltName":
- try:
- dec = decode(ext.get_data(), asn1Spec=_GeneralNames())
- except PyAsn1Error:
- continue
- for i in dec[0]:
- altnames.append(i[0].asOctets())
- return altnames
diff --git a/netlib/debug.py b/netlib/debug.py
deleted file mode 100644
index f1b3d792..00000000
--- a/netlib/debug.py
+++ /dev/null
@@ -1,120 +0,0 @@
-import gc
-import os
-import sys
-import threading
-import signal
-import platform
-import traceback
-
-from netlib import version
-
-from OpenSSL import SSL
-
-
-def sysinfo():
- data = [
- "Mitmproxy version: %s" % version.VERSION,
- "Python version: %s" % platform.python_version(),
- "Platform: %s" % platform.platform(),
- "SSL version: %s" % SSL.SSLeay_version(SSL.SSLEAY_VERSION).decode(),
- ]
- d = platform.linux_distribution()
- t = "Linux distro: %s %s %s" % d
- if d[0]: # pragma: no-cover
- data.append(t)
-
- d = platform.mac_ver()
- t = "Mac version: %s %s %s" % d
- if d[0]: # pragma: no-cover
- data.append(t)
-
- d = platform.win32_ver()
- t = "Windows version: %s %s %s %s" % d
- if d[0]: # pragma: no-cover
- data.append(t)
-
- return "\n".join(data)
-
-
-def dump_info(signal=None, frame=None, file=sys.stdout, testing=False): # pragma: no cover
- print("****************************************************", file=file)
- print("Summary", file=file)
- print("=======", file=file)
-
- try:
- import psutil
- except:
- print("(psutil not installed, skipping some debug info)", file=file)
- else:
- p = psutil.Process()
- print("num threads: ", p.num_threads(), file=file)
- if hasattr(p, "num_fds"):
- print("num fds: ", p.num_fds(), file=file)
- print("memory: ", p.memory_info(), file=file)
-
- print(file=file)
- print("Files", file=file)
- print("=====", file=file)
- for i in p.open_files():
- print(i, file=file)
-
- print(file=file)
- print("Connections", file=file)
- print("===========", file=file)
- for i in p.connections():
- print(i, file=file)
-
- print(file=file)
- print("Threads", file=file)
- print("=======", file=file)
- bthreads = []
- for i in threading.enumerate():
- if hasattr(i, "_threadinfo"):
- bthreads.append(i)
- else:
- print(i.name, file=file)
- bthreads.sort(key=lambda x: x._thread_started)
- for i in bthreads:
- print(i._threadinfo(), file=file)
-
- print(file=file)
- print("Memory", file=file)
- print("=======", file=file)
- gc.collect()
- d = {}
- for i in gc.get_objects():
- t = str(type(i))
- if "mitmproxy" in t or "netlib" in t:
- d[t] = d.setdefault(t, 0) + 1
- itms = list(d.items())
- itms.sort(key=lambda x: x[1])
- for i in itms[-20:]:
- print(i[1], i[0], file=file)
- print("****************************************************", file=file)
-
- if not testing:
- sys.exit(1)
-
-
-def dump_stacks(signal=None, frame=None, file=sys.stdout, testing=False):
- id2name = dict([(th.ident, th.name) for th in threading.enumerate()])
- code = []
- for threadId, stack in sys._current_frames().items():
- code.append(
- "\n# Thread: %s(%d)" % (
- id2name.get(threadId, ""), threadId
- )
- )
- for filename, lineno, name, line in traceback.extract_stack(stack):
- code.append('File: "%s", line %d, in %s' % (filename, lineno, name))
- if line:
- code.append(" %s" % (line.strip()))
- print("\n".join(code), file=file)
- if not testing:
- sys.exit(1)
-
-
-def register_info_dumpers():
- if os.name != "nt":
- signal.signal(signal.SIGUSR1, dump_info)
- signal.signal(signal.SIGUSR2, dump_stacks)
diff --git a/netlib/encoding.py b/netlib/encoding.py
deleted file mode 100644
index e123a033..00000000
--- a/netlib/encoding.py
+++ /dev/null
@@ -1,175 +0,0 @@
-"""
-Utility functions for decoding response bodies.
-"""
-
-import codecs
-import collections
-from io import BytesIO
-
-import gzip
-import zlib
-import brotli
-
-from typing import Union
-
-
-# We have a shared single-element cache for encoding and decoding.
-# This is quite useful in practice, e.g.
-# flow.request.content = flow.request.content.replace(b"foo", b"bar")
-# does not require an .encode() call if content does not contain b"foo"
-CachedDecode = collections.namedtuple("CachedDecode", "encoded encoding errors decoded")
-_cache = CachedDecode(None, None, None, None)
-
-
-def decode(encoded: Union[str, bytes], encoding: str, errors: str='strict') -> Union[str, bytes]:
- """
- Decode the given input object
-
- Returns:
- The decoded value
-
- Raises:
- ValueError, if decoding fails.
- """
- if len(encoded) == 0:
- return encoded
-
- global _cache
- cached = (
- isinstance(encoded, bytes) and
- _cache.encoded == encoded and
- _cache.encoding == encoding and
- _cache.errors == errors
- )
- if cached:
- return _cache.decoded
- try:
- try:
- decoded = custom_decode[encoding](encoded)
- except KeyError:
- decoded = codecs.decode(encoded, encoding, errors)
- if encoding in ("gzip", "deflate", "br"):
- _cache = CachedDecode(encoded, encoding, errors, decoded)
- return decoded
- except TypeError:
- raise
- except Exception as e:
- raise ValueError("{} when decoding {} with {}: {}".format(
- type(e).__name__,
- repr(encoded)[:10],
- repr(encoding),
- repr(e),
- ))
-
-
-def encode(decoded: Union[str, bytes], encoding: str, errors: str='strict') -> Union[str, bytes]:
- """
- Encode the given input object
-
- Returns:
- The encoded value
-
- Raises:
- ValueError, if encoding fails.
- """
- if len(decoded) == 0:
- return decoded
-
- global _cache
- cached = (
- isinstance(decoded, bytes) and
- _cache.decoded == decoded and
- _cache.encoding == encoding and
- _cache.errors == errors
- )
- if cached:
- return _cache.encoded
- try:
- try:
- value = decoded
- if isinstance(value, str):
- value = decoded.encode()
- encoded = custom_encode[encoding](value)
- except KeyError:
- encoded = codecs.encode(decoded, encoding, errors)
- if encoding in ("gzip", "deflate", "br"):
- _cache = CachedDecode(encoded, encoding, errors, decoded)
- return encoded
- except TypeError:
- raise
- except Exception as e:
- raise ValueError("{} when encoding {} with {}: {}".format(
- type(e).__name__,
- repr(decoded)[:10],
- repr(encoding),
- repr(e),
- ))
-
-
-def identity(content):
- """
- Returns content unchanged. Identity is the default value of
- Accept-Encoding headers.
- """
- return content
-
-
-def decode_gzip(content):
- gfile = gzip.GzipFile(fileobj=BytesIO(content))
- return gfile.read()
-
-
-def encode_gzip(content):
- s = BytesIO()
- gf = gzip.GzipFile(fileobj=s, mode='wb')
- gf.write(content)
- gf.close()
- return s.getvalue()
-
-
-def decode_brotli(content):
- return brotli.decompress(content)
-
-
-def encode_brotli(content):
- return brotli.compress(content)
-
-
-def decode_deflate(content):
- """
- Returns decompressed data for DEFLATE. Some servers may respond with
- compressed data without a zlib header or checksum. An undocumented
- feature of zlib permits the lenient decompression of data missing both
- values.
-
- http://bugs.python.org/issue5784
- """
- try:
- return zlib.decompress(content)
- except zlib.error:
- return zlib.decompress(content, -15)
-
-
-def encode_deflate(content):
- """
- Returns compressed content, always including zlib header and checksum.
- """
- return zlib.compress(content)
-
-
-custom_decode = {
- "none": identity,
- "identity": identity,
- "gzip": decode_gzip,
- "deflate": decode_deflate,
- "br": decode_brotli,
-}
-custom_encode = {
- "none": identity,
- "identity": identity,
- "gzip": encode_gzip,
- "deflate": encode_deflate,
- "br": encode_brotli,
-}
-
-__all__ = ["encode", "decode"]
diff --git a/netlib/exceptions.py b/netlib/exceptions.py
deleted file mode 100644
index d0b15d27..00000000
--- a/netlib/exceptions.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""
-We try to be very hygienic regarding the exceptions we throw:
-Every Exception netlib raises shall be a subclass of NetlibException.
-
-
-See also: http://lucumr.pocoo.org/2014/10/16/on-error-handling/
-"""
-
-
-class NetlibException(Exception):
- """
- Base class for all exceptions thrown by netlib.
- """
- def __init__(self, message=None):
- super().__init__(message)
-
-
-class Disconnect:
- """Immediate EOF"""
-
-
-class HttpException(NetlibException):
- pass
-
-
-class HttpReadDisconnect(HttpException, Disconnect):
- pass
-
-
-class HttpSyntaxException(HttpException):
- pass
-
-
-class TcpException(NetlibException):
- pass
-
-
-class TcpDisconnect(TcpException, Disconnect):
- pass
-
-
-class TcpReadIncomplete(TcpException):
- pass
-
-
-class TcpTimeout(TcpException):
- pass
-
-
-class TlsException(NetlibException):
- pass
-
-
-class InvalidCertificateException(TlsException):
- pass
-
-
-class Timeout(TcpException):
- pass
diff --git a/netlib/http/__init__.py b/netlib/http/__init__.py
deleted file mode 100644
index 315f61ac..00000000
--- a/netlib/http/__init__.py
+++ /dev/null
@@ -1,15 +0,0 @@
-from netlib.http.request import Request
-from netlib.http.response import Response
-from netlib.http.message import Message
-from netlib.http.headers import Headers, parse_content_type
-from netlib.http.message import decoded
-from netlib.http import http1, http2, status_codes, multipart
-
-__all__ = [
- "Request",
- "Response",
- "Message",
- "Headers", "parse_content_type",
- "decoded",
- "http1", "http2", "status_codes", "multipart",
-]
diff --git a/netlib/http/authentication.py b/netlib/http/authentication.py
deleted file mode 100644
index a65279e4..00000000
--- a/netlib/http/authentication.py
+++ /dev/null
@@ -1,176 +0,0 @@
-import argparse
-import binascii
-
-
-def parse_http_basic_auth(s):
- words = s.split()
- if len(words) != 2:
- return None
- scheme = words[0]
- try:
- user = binascii.a2b_base64(words[1]).decode("utf8", "replace")
- except binascii.Error:
- return None
- parts = user.split(':')
- if len(parts) != 2:
- return None
- return scheme, parts[0], parts[1]
-
-
-def assemble_http_basic_auth(scheme, username, password):
- v = binascii.b2a_base64((username + ":" + password).encode("utf8")).decode("ascii")
- return scheme + " " + v
-
-
-class NullProxyAuth:
-
- """
- No proxy auth at all (returns empty challange headers)
- """
-
- def __init__(self, password_manager):
- self.password_manager = password_manager
-
- def clean(self, headers_):
- """
- Clean up authentication headers, so they're not passed upstream.
- """
-
- def authenticate(self, headers_):
- """
- Tests that the user is allowed to use the proxy
- """
- return True
-
- def auth_challenge_headers(self):
- """
- Returns a dictionary containing the headers require to challenge the user
- """
- return {}
-
-
-class BasicAuth(NullProxyAuth):
- CHALLENGE_HEADER = None
- AUTH_HEADER = None
-
- def __init__(self, password_manager, realm):
- NullProxyAuth.__init__(self, password_manager)
- self.realm = realm
-
- def clean(self, headers):
- del headers[self.AUTH_HEADER]
-
- def authenticate(self, headers):
- auth_value = headers.get(self.AUTH_HEADER)
- if not auth_value:
- return False
- parts = parse_http_basic_auth(auth_value)
- if not parts:
- return False
- scheme, username, password = parts
- if scheme.lower() != 'basic':
- return False
- if not self.password_manager.test(username, password):
- return False
- self.username = username
- return True
-
- def auth_challenge_headers(self):
- return {self.CHALLENGE_HEADER: 'Basic realm="%s"' % self.realm}
-
-
-class BasicWebsiteAuth(BasicAuth):
- CHALLENGE_HEADER = 'WWW-Authenticate'
- AUTH_HEADER = 'Authorization'
-
-
-class BasicProxyAuth(BasicAuth):
- CHALLENGE_HEADER = 'Proxy-Authenticate'
- AUTH_HEADER = 'Proxy-Authorization'
-
-
-class PassMan:
-
- def test(self, username_, password_token_):
- return False
-
-
-class PassManNonAnon(PassMan):
-
- """
- Ensure the user specifies a username, accept any password.
- """
-
- def test(self, username, password_token_):
- if username:
- return True
- return False
-
-
-class PassManHtpasswd(PassMan):
-
- """
- Read usernames and passwords from an htpasswd file
- """
-
- def __init__(self, path):
- """
- Raises ValueError if htpasswd file is invalid.
- """
- import passlib.apache
- self.htpasswd = passlib.apache.HtpasswdFile(path)
-
- def test(self, username, password_token):
- return bool(self.htpasswd.check_password(username, password_token))
-
-
-class PassManSingleUser(PassMan):
-
- def __init__(self, username, password):
- self.username, self.password = username, password
-
- def test(self, username, password_token):
- return self.username == username and self.password == password_token
-
-
-class AuthAction(argparse.Action):
-
- """
- Helper class to allow seamless integration int argparse. Example usage:
- parser.add_argument(
- "--nonanonymous",
- action=NonanonymousAuthAction, nargs=0,
- help="Allow access to any user long as a credentials are specified."
- )
- """
-
- def __call__(self, parser, namespace, values, option_string=None):
- passman = self.getPasswordManager(values)
- authenticator = BasicProxyAuth(passman, "mitmproxy")
- setattr(namespace, self.dest, authenticator)
-
- def getPasswordManager(self, s): # pragma: no cover
- raise NotImplementedError()
-
-
-class SingleuserAuthAction(AuthAction):
-
- def getPasswordManager(self, s):
- if len(s.split(':')) != 2:
- raise argparse.ArgumentTypeError(
- "Invalid single-user specification. Please use the format username:password"
- )
- username, password = s.split(':')
- return PassManSingleUser(username, password)
-
-
-class NonanonymousAuthAction(AuthAction):
-
- def getPasswordManager(self, s):
- return PassManNonAnon()
-
-
-class HtpasswdAuthAction(AuthAction):
-
- def getPasswordManager(self, s):
- return PassManHtpasswd(s)
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
deleted file mode 100644
index cb816ca0..00000000
--- a/netlib/http/cookies.py
+++ /dev/null
@@ -1,384 +0,0 @@
-import collections
-import email.utils
-import re
-import time
-
-from netlib import multidict
-
-"""
-A flexible module for cookie parsing and manipulation.
-
-This module differs from usual standards-compliant cookie modules in a number
-of ways. We try to be as permissive as possible, and to retain even mal-formed
-information. Duplicate cookies are preserved in parsing, and can be set in
-formatting. We do attempt to escape and quote values where needed, but will not
-reject data that violate the specs.
-
-Parsing accepts the formats in RFC6265 and partially RFC2109 and RFC2965. We
-also parse the comma-separated variant of Set-Cookie that allows multiple
-cookies to be set in a single header. Serialization follows RFC6265.
-
- http://tools.ietf.org/html/rfc6265
- http://tools.ietf.org/html/rfc2109
- http://tools.ietf.org/html/rfc2965
-"""
-
-_cookie_params = set((
- 'expires', 'path', 'comment', 'max-age',
- 'secure', 'httponly', 'version',
-))
-
-ESCAPE = re.compile(r"([\"\\])")
-
-
-class CookieAttrs(multidict.ImmutableMultiDict):
- @staticmethod
- def _kconv(key):
- return key.lower()
-
- @staticmethod
- def _reduce_values(values):
- # See the StickyCookieTest for a weird cookie that only makes sense
- # if we take the last part.
- return values[-1]
-
-SetCookie = collections.namedtuple("SetCookie", ["value", "attrs"])
-
-
-def _read_until(s, start, term):
- """
- Read until one of the characters in term is reached.
- """
- if start == len(s):
- return "", start + 1
- for i in range(start, len(s)):
- if s[i] in term:
- return s[start:i], i
- return s[start:i + 1], i + 1
-
-
-def _read_quoted_string(s, start):
- """
- start: offset to the first quote of the string to be read
-
- A sort of loose super-set of the various quoted string specifications.
-
- RFC6265 disallows backslashes or double quotes within quoted strings.
- Prior RFCs use backslashes to escape. This leaves us free to apply
- backslash escaping by default and be compatible with everything.
- """
- escaping = False
- ret = []
- # Skip the first quote
- i = start # initialize in case the loop doesn't run.
- for i in range(start + 1, len(s)):
- if escaping:
- ret.append(s[i])
- escaping = False
- elif s[i] == '"':
- break
- elif s[i] == "\\":
- escaping = True
- else:
- ret.append(s[i])
- return "".join(ret), i + 1
-
-
-def _read_key(s, start, delims=";="):
- """
- Read a key - the LHS of a token/value pair in a cookie.
- """
- return _read_until(s, start, delims)
-
-
-def _read_value(s, start, delims):
- """
- Reads a value - the RHS of a token/value pair in a cookie.
- """
- if start >= len(s):
- return "", start
- elif s[start] == '"':
- return _read_quoted_string(s, start)
- else:
- return _read_until(s, start, delims)
-
-
-def _read_cookie_pairs(s, off=0):
- """
- Read pairs of lhs=rhs values from Cookie headers.
-
- off: start offset
- """
- pairs = []
-
- while True:
- lhs, off = _read_key(s, off)
- lhs = lhs.lstrip()
-
- if lhs:
- rhs = None
- if off < len(s) and s[off] == "=":
- rhs, off = _read_value(s, off + 1, ";")
-
- pairs.append([lhs, rhs])
-
- off += 1
-
- if not off < len(s):
- break
-
- return pairs, off
-
-
-def _read_set_cookie_pairs(s, off=0):
- """
- Read pairs of lhs=rhs values from SetCookie headers while handling multiple cookies.
-
- off: start offset
- specials: attributes that are treated specially
- """
- cookies = []
- pairs = []
-
- while True:
- lhs, off = _read_key(s, off, ";=,")
- lhs = lhs.lstrip()
-
- if lhs:
- rhs = None
- if off < len(s) and s[off] == "=":
- rhs, off = _read_value(s, off + 1, ";,")
-
- # Special handliing of attributes
- if lhs.lower() == "expires":
- # 'expires' values can contain commas in them so they need to
- # be handled separately.
-
- # We actually bank on the fact that the expires value WILL
- # contain a comma. Things will fail, if they don't.
-
- # '3' is just a heuristic we use to determine whether we've
- # only read a part of the expires value and we should read more.
- if len(rhs) <= 3:
- trail, off = _read_value(s, off + 1, ";,")
- rhs = rhs + "," + trail
-
- pairs.append([lhs, rhs])
-
- # comma marks the beginning of a new cookie
- if off < len(s) and s[off] == ",":
- cookies.append(pairs)
- pairs = []
-
- off += 1
-
- if not off < len(s):
- break
-
- if pairs or not cookies:
- cookies.append(pairs)
-
- return cookies, off
-
-
-def _has_special(s):
- for i in s:
- if i in '",;\\':
- return True
- o = ord(i)
- if o < 0x21 or o > 0x7e:
- return True
- return False
-
-
-def _format_pairs(pairs, specials=(), sep="; "):
- """
- specials: A lower-cased list of keys that will not be quoted.
- """
- vals = []
- for k, v in pairs:
- if v is None:
- vals.append(k)
- else:
- if k.lower() not in specials and _has_special(v):
- v = ESCAPE.sub(r"\\\1", v)
- v = '"%s"' % v
- vals.append("%s=%s" % (k, v))
- return sep.join(vals)
-
-
-def _format_set_cookie_pairs(lst):
- return _format_pairs(
- lst,
- specials=("expires", "path")
- )
-
-
-def parse_cookie_header(line):
- """
- Parse a Cookie header value.
- Returns a list of (lhs, rhs) tuples.
- """
- pairs, off_ = _read_cookie_pairs(line)
- return pairs
-
-
-def parse_cookie_headers(cookie_headers):
- cookie_list = []
- for header in cookie_headers:
- cookie_list.extend(parse_cookie_header(header))
- return cookie_list
-
-
-def format_cookie_header(lst):
- """
- Formats a Cookie header value.
- """
- return _format_pairs(lst)
-
-
-def parse_set_cookie_header(line):
- """
- Parse a Set-Cookie header value
-
- Returns a list of (name, value, attrs) tuples, where attrs is a
- CookieAttrs dict of attributes. No attempt is made to parse attribute
- values - they are treated purely as strings.
- """
- cookie_pairs, off = _read_set_cookie_pairs(line)
- cookies = [
- (pairs[0][0], pairs[0][1], CookieAttrs(tuple(x) for x in pairs[1:]))
- for pairs in cookie_pairs if pairs
- ]
- return cookies
-
-
-def parse_set_cookie_headers(headers):
- rv = []
- for header in headers:
- cookies = parse_set_cookie_header(header)
- if cookies:
- for name, value, attrs in cookies:
- rv.append((name, SetCookie(value, attrs)))
- return rv
-
-
-def format_set_cookie_header(set_cookies):
- """
- Formats a Set-Cookie header value.
- """
-
- rv = []
-
- for set_cookie in set_cookies:
- name, value, attrs = set_cookie
-
- pairs = [(name, value)]
- pairs.extend(
- attrs.fields if hasattr(attrs, "fields") else attrs
- )
-
- rv.append(_format_set_cookie_pairs(pairs))
-
- return ", ".join(rv)
-
-
-def refresh_set_cookie_header(c, delta):
- """
- Args:
- c: A Set-Cookie string
- delta: Time delta in seconds
- Returns:
- A refreshed Set-Cookie string
- """
-
- name, value, attrs = parse_set_cookie_header(c)[0]
- if not name or not value:
- raise ValueError("Invalid Cookie")
-
- if "expires" in attrs:
- e = email.utils.parsedate_tz(attrs["expires"])
- if e:
- f = email.utils.mktime_tz(e) + delta
- attrs = attrs.with_set_all("expires", [email.utils.formatdate(f)])
- else:
- # This can happen when the expires tag is invalid.
- # reddit.com sends a an expires tag like this: "Thu, 31 Dec
- # 2037 23:59:59 GMT", which is valid RFC 1123, but not
- # strictly correct according to the cookie spec. Browsers
- # appear to parse this tolerantly - maybe we should too.
- # For now, we just ignore this.
- attrs = attrs.with_delitem("expires")
-
- rv = format_set_cookie_header([(name, value, attrs)])
- if not rv:
- raise ValueError("Invalid Cookie")
- return rv
-
-
-def get_expiration_ts(cookie_attrs):
- """
- Determines the time when the cookie will be expired.
-
- Considering both 'expires' and 'max-age' parameters.
-
- Returns: timestamp of when the cookie will expire.
- None, if no expiration time is set.
- """
- if 'expires' in cookie_attrs:
- e = email.utils.parsedate_tz(cookie_attrs["expires"])
- if e:
- return email.utils.mktime_tz(e)
-
- elif 'max-age' in cookie_attrs:
- try:
- max_age = int(cookie_attrs['Max-Age'])
- except ValueError:
- pass
- else:
- now_ts = time.time()
- return now_ts + max_age
-
- return None
-
-
-def is_expired(cookie_attrs):
- """
- Determines whether a cookie has expired.
-
- Returns: boolean
- """
-
- exp_ts = get_expiration_ts(cookie_attrs)
- now_ts = time.time()
-
- # If no expiration information was provided with the cookie
- if exp_ts is None:
- return False
- else:
- return exp_ts <= now_ts
-
-
-def group_cookies(pairs):
- """
- Converts a list of pairs to a (name, value, attrs) for each cookie.
- """
-
- if not pairs:
- return []
-
- cookie_list = []
-
- # First pair is always a new cookie
- name, value = pairs[0]
- attrs = []
-
- for k, v in pairs[1:]:
- if k.lower() in _cookie_params:
- attrs.append((k, v))
- else:
- cookie_list.append((name, value, CookieAttrs(attrs)))
- name, value, attrs = k, v, []
-
- cookie_list.append((name, value, CookieAttrs(attrs)))
- return cookie_list
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
deleted file mode 100644
index 39673f1a..00000000
--- a/netlib/http/headers.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import re
-
-import collections
-from netlib import multidict
-from netlib import strutils
-
-# See also: http://lucumr.pocoo.org/2013/7/2/the-updated-guide-to-unicode/
-
-
-# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
-def _native(x):
- return x.decode("utf-8", "surrogateescape")
-
-
-def _always_bytes(x):
- return strutils.always_bytes(x, "utf-8", "surrogateescape")
-
-
-class Headers(multidict.MultiDict):
- """
- Header class which allows both convenient access to individual headers as well as
- direct access to the underlying raw data. Provides a full dictionary interface.
-
- Example:
-
- .. code-block:: python
-
- # Create headers with keyword arguments
- >>> h = Headers(host="example.com", content_type="application/xml")
-
- # Headers mostly behave like a normal dict.
- >>> h["Host"]
- "example.com"
-
- # HTTP Headers are case insensitive
- >>> h["host"]
- "example.com"
-
- # Headers can also be created from a list of raw (header_name, header_value) byte tuples
- >>> h = Headers([
- (b"Host",b"example.com"),
- (b"Accept",b"text/html"),
- (b"accept",b"application/xml")
- ])
-
- # Multiple headers are folded into a single header as per RFC7230
- >>> h["Accept"]
- "text/html, application/xml"
-
- # Setting a header removes all existing headers with the same name.
- >>> h["Accept"] = "application/text"
- >>> h["Accept"]
- "application/text"
-
- # bytes(h) returns a HTTP1 header block.
- >>> print(bytes(h))
- Host: example.com
- Accept: application/text
-
- # For full control, the raw header fields can be accessed
- >>> h.fields
-
- Caveats:
- For use with the "Set-Cookie" header, see :py:meth:`get_all`.
- """
-
- def __init__(self, fields=(), **headers):
- """
- Args:
- fields: (optional) list of ``(name, value)`` header byte tuples,
- e.g. ``[(b"Host", b"example.com")]``. All names and values must be bytes.
- **headers: Additional headers to set. Will overwrite existing values from `fields`.
- For convenience, underscores in header names will be transformed to dashes -
- this behaviour does not extend to other methods.
- If ``**headers`` contains multiple keys that have equal ``.lower()`` s,
- the behavior is undefined.
- """
- super().__init__(fields)
-
- for key, value in self.fields:
- if not isinstance(key, bytes) or not isinstance(value, bytes):
- raise TypeError("Header fields must be bytes.")
-
- # content_type -> content-type
- headers = {
- _always_bytes(name).replace(b"_", b"-"): _always_bytes(value)
- for name, value in headers.items()
- }
- self.update(headers)
-
- @staticmethod
- def _reduce_values(values):
- # Headers can be folded
- return ", ".join(values)
-
- @staticmethod
- def _kconv(key):
- # Headers are case-insensitive
- return key.lower()
-
- def __bytes__(self):
- if self.fields:
- return b"\r\n".join(b": ".join(field) for field in self.fields) + b"\r\n"
- else:
- return b""
-
- def __delitem__(self, key):
- key = _always_bytes(key)
- super().__delitem__(key)
-
- def __iter__(self):
- for x in super().__iter__():
- yield _native(x)
-
- def get_all(self, name):
- """
- Like :py:meth:`get`, but does not fold multiple headers into a single one.
- This is useful for Set-Cookie headers, which do not support folding.
- See also: https://tools.ietf.org/html/rfc7230#section-3.2.2
- """
- name = _always_bytes(name)
- return [
- _native(x) for x in
- super().get_all(name)
- ]
-
- def set_all(self, name, values):
- """
- Explicitly set multiple headers for the given key.
- See: :py:meth:`get_all`
- """
- name = _always_bytes(name)
- values = [_always_bytes(x) for x in values]
- return super().set_all(name, values)
-
- def insert(self, index, key, value):
- key = _always_bytes(key)
- value = _always_bytes(value)
- super().insert(index, key, value)
-
- def items(self, multi=False):
- if multi:
- return (
- (_native(k), _native(v))
- for k, v in self.fields
- )
- else:
- return super().items()
-
- def replace(self, pattern, repl, flags=0, count=0):
- """
- Replaces a regular expression pattern with repl in each "name: value"
- header line.
-
- Returns:
- The number of replacements made.
- """
- if isinstance(pattern, str):
- pattern = strutils.escaped_str_to_bytes(pattern)
- if isinstance(repl, str):
- repl = strutils.escaped_str_to_bytes(repl)
- pattern = re.compile(pattern, flags)
- replacements = 0
- flag_count = count > 0
- fields = []
- for name, value in self.fields:
- line, n = pattern.subn(repl, name + b": " + value, count=count)
- try:
- name, value = line.split(b": ", 1)
- except ValueError:
- # We get a ValueError if the replacement removed the ": "
- # There's not much we can do about this, so we just keep the header as-is.
- pass
- else:
- replacements += n
- if flag_count:
- count -= n
- if count == 0:
- break
- fields.append((name, value))
- self.fields = tuple(fields)
- return replacements
-
-
-def parse_content_type(c):
- """
- A simple parser for content-type values. Returns a (type, subtype,
- parameters) tuple, where type and subtype are strings, and parameters
- is a dict. If the string could not be parsed, return None.
-
- E.g. the following string:
-
- text/html; charset=UTF-8
-
- Returns:
-
- ("text", "html", {"charset": "UTF-8"})
- """
- parts = c.split(";", 1)
- ts = parts[0].split("/", 1)
- if len(ts) != 2:
- return None
- d = collections.OrderedDict()
- if len(parts) == 2:
- for i in parts[1].split(";"):
- clause = i.split("=", 1)
- if len(clause) == 2:
- d[clause[0].strip()] = clause[1].strip()
- return ts[0].lower(), ts[1].lower(), d
-
-
-def assemble_content_type(type, subtype, parameters):
- if not parameters:
- return "{}/{}".format(type, subtype)
- params = "; ".join(
- "{}={}".format(k, v)
- for k, v in parameters.items()
- )
- return "{}/{}; {}".format(
- type, subtype, params
- )
diff --git a/netlib/http/http1/__init__.py b/netlib/http/http1/__init__.py
deleted file mode 100644
index e4bf01c5..00000000
--- a/netlib/http/http1/__init__.py
+++ /dev/null
@@ -1,24 +0,0 @@
-from .read import (
- read_request, read_request_head,
- read_response, read_response_head,
- read_body,
- connection_close,
- expected_http_body_size,
-)
-from .assemble import (
- assemble_request, assemble_request_head,
- assemble_response, assemble_response_head,
- assemble_body,
-)
-
-
-__all__ = [
- "read_request", "read_request_head",
- "read_response", "read_response_head",
- "read_body",
- "connection_close",
- "expected_http_body_size",
- "assemble_request", "assemble_request_head",
- "assemble_response", "assemble_response_head",
- "assemble_body",
-]
diff --git a/netlib/http/http1/assemble.py b/netlib/http/http1/assemble.py
deleted file mode 100644
index 3d65da34..00000000
--- a/netlib/http/http1/assemble.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import netlib.http.url
-from netlib import exceptions
-
-
-def assemble_request(request):
- if request.data.content is None:
- raise exceptions.HttpException("Cannot assemble flow with missing content")
- head = assemble_request_head(request)
- body = b"".join(assemble_body(request.data.headers, [request.data.content]))
- return head + body
-
-
-def assemble_request_head(request):
- first_line = _assemble_request_line(request.data)
- headers = _assemble_request_headers(request.data)
- return b"%s\r\n%s\r\n" % (first_line, headers)
-
-
-def assemble_response(response):
- if response.data.content is None:
- raise exceptions.HttpException("Cannot assemble flow with missing content")
- head = assemble_response_head(response)
- body = b"".join(assemble_body(response.data.headers, [response.data.content]))
- return head + body
-
-
-def assemble_response_head(response):
- first_line = _assemble_response_line(response.data)
- headers = _assemble_response_headers(response.data)
- return b"%s\r\n%s\r\n" % (first_line, headers)
-
-
-def assemble_body(headers, body_chunks):
- if "chunked" in headers.get("transfer-encoding", "").lower():
- for chunk in body_chunks:
- if chunk:
- yield b"%x\r\n%s\r\n" % (len(chunk), chunk)
- yield b"0\r\n\r\n"
- else:
- for chunk in body_chunks:
- yield chunk
-
-
-def _assemble_request_line(request_data):
- """
- Args:
- request_data (netlib.http.request.RequestData)
- """
- form = request_data.first_line_format
- if form == "relative":
- return b"%s %s %s" % (
- request_data.method,
- request_data.path,
- request_data.http_version
- )
- elif form == "authority":
- return b"%s %s:%d %s" % (
- request_data.method,
- request_data.host,
- request_data.port,
- request_data.http_version
- )
- elif form == "absolute":
- return b"%s %s://%s:%d%s %s" % (
- request_data.method,
- request_data.scheme,
- request_data.host,
- request_data.port,
- request_data.path,
- request_data.http_version
- )
- else:
- raise RuntimeError("Invalid request form")
-
-
-def _assemble_request_headers(request_data):
- """
- Args:
- request_data (netlib.http.request.RequestData)
- """
- headers = request_data.headers.copy()
- if "host" not in headers and request_data.scheme and request_data.host and request_data.port:
- headers["host"] = netlib.http.url.hostport(
- request_data.scheme,
- request_data.host,
- request_data.port
- )
- return bytes(headers)
-
-
-def _assemble_response_line(response_data):
- return b"%s %d %s" % (
- response_data.http_version,
- response_data.status_code,
- response_data.reason,
- )
-
-
-def _assemble_response_headers(response):
- return bytes(response.headers)
diff --git a/netlib/http/http1/read.py b/netlib/http/http1/read.py
deleted file mode 100644
index 4c00a96a..00000000
--- a/netlib/http/http1/read.py
+++ /dev/null
@@ -1,377 +0,0 @@
-import time
-import sys
-import re
-
-from netlib.http import request
-from netlib.http import response
-from netlib.http import headers
-from netlib.http import url
-from netlib import utils
-from netlib import exceptions
-
-
-def get_header_tokens(headers, key):
- """
- Retrieve all tokens for a header key. A number of different headers
- follow a pattern where each header line can containe comma-separated
- tokens, and headers can be set multiple times.
- """
- if key not in headers:
- return []
- tokens = headers[key].split(",")
- return [token.strip() for token in tokens]
-
-
-def read_request(rfile, body_size_limit=None):
- request = read_request_head(rfile)
- expected_body_size = expected_http_body_size(request)
- request.data.content = b"".join(read_body(rfile, expected_body_size, limit=body_size_limit))
- request.timestamp_end = time.time()
- return request
-
-
-def read_request_head(rfile):
- """
- Parse an HTTP request head (request line + headers) from an input stream
-
- Args:
- rfile: The input stream
-
- Returns:
- The HTTP request object (without body)
-
- Raises:
- exceptions.HttpReadDisconnect: No bytes can be read from rfile.
- exceptions.HttpSyntaxException: The input is malformed HTTP.
- exceptions.HttpException: Any other error occured.
- """
- timestamp_start = time.time()
- if hasattr(rfile, "reset_timestamps"):
- rfile.reset_timestamps()
-
- form, method, scheme, host, port, path, http_version = _read_request_line(rfile)
- headers = _read_headers(rfile)
-
- if hasattr(rfile, "first_byte_timestamp"):
- # more accurate timestamp_start
- timestamp_start = rfile.first_byte_timestamp
-
- return request.Request(
- form, method, scheme, host, port, path, http_version, headers, None, timestamp_start
- )
-
-
-def read_response(rfile, request, body_size_limit=None):
- response = read_response_head(rfile)
- expected_body_size = expected_http_body_size(request, response)
- response.data.content = b"".join(read_body(rfile, expected_body_size, body_size_limit))
- response.timestamp_end = time.time()
- return response
-
-
-def read_response_head(rfile):
- """
- Parse an HTTP response head (response line + headers) from an input stream
-
- Args:
- rfile: The input stream
-
- Returns:
- The HTTP request object (without body)
-
- Raises:
- exceptions.HttpReadDisconnect: No bytes can be read from rfile.
- exceptions.HttpSyntaxException: The input is malformed HTTP.
- exceptions.HttpException: Any other error occured.
- """
-
- timestamp_start = time.time()
- if hasattr(rfile, "reset_timestamps"):
- rfile.reset_timestamps()
-
- http_version, status_code, message = _read_response_line(rfile)
- headers = _read_headers(rfile)
-
- if hasattr(rfile, "first_byte_timestamp"):
- # more accurate timestamp_start
- timestamp_start = rfile.first_byte_timestamp
-
- return response.Response(http_version, status_code, message, headers, None, timestamp_start)
-
-
-def read_body(rfile, expected_size, limit=None, max_chunk_size=4096):
- """
- Read an HTTP message body
-
- Args:
- rfile: The input stream
- expected_size: The expected body size (see :py:meth:`expected_body_size`)
- limit: Maximum body size
- max_chunk_size: Maximium chunk size that gets yielded
-
- Returns:
- A generator that yields byte chunks of the content.
-
- Raises:
- exceptions.HttpException, if an error occurs
-
- Caveats:
- max_chunk_size is not considered if the transfer encoding is chunked.
- """
- if not limit or limit < 0:
- limit = sys.maxsize
- if not max_chunk_size:
- max_chunk_size = limit
-
- if expected_size is None:
- for x in _read_chunked(rfile, limit):
- yield x
- elif expected_size >= 0:
- if limit is not None and expected_size > limit:
- raise exceptions.HttpException(
- "HTTP Body too large. "
- "Limit is {}, content length was advertised as {}".format(limit, expected_size)
- )
- bytes_left = expected_size
- while bytes_left:
- chunk_size = min(bytes_left, max_chunk_size)
- content = rfile.read(chunk_size)
- if len(content) < chunk_size:
- raise exceptions.HttpException("Unexpected EOF")
- yield content
- bytes_left -= chunk_size
- else:
- bytes_left = limit
- while bytes_left:
- chunk_size = min(bytes_left, max_chunk_size)
- content = rfile.read(chunk_size)
- if not content:
- return
- yield content
- bytes_left -= chunk_size
- not_done = rfile.read(1)
- if not_done:
- raise exceptions.HttpException("HTTP body too large. Limit is {}.".format(limit))
-
-
-def connection_close(http_version, headers):
- """
- Checks the message to see if the client connection should be closed
- according to RFC 2616 Section 8.1.
- """
- # At first, check if we have an explicit Connection header.
- if "connection" in headers:
- tokens = get_header_tokens(headers, "connection")
- if "close" in tokens:
- return True
- elif "keep-alive" in tokens:
- return False
-
- # If we don't have a Connection header, HTTP 1.1 connections are assumed to
- # be persistent
- return http_version != "HTTP/1.1" and http_version != b"HTTP/1.1" # FIXME: Remove one case.
-
-
-def expected_http_body_size(request, response=None):
- """
- Returns:
- The expected body length:
- - a positive integer, if the size is known in advance
- - None, if the size in unknown in advance (chunked encoding)
- - -1, if all data should be read until end of stream.
-
- Raises:
- exceptions.HttpSyntaxException, if the content length header is invalid
- """
- # Determine response size according to
- # http://tools.ietf.org/html/rfc7230#section-3.3
- if not response:
- headers = request.headers
- response_code = None
- is_request = True
- else:
- headers = response.headers
- response_code = response.status_code
- is_request = False
-
- if is_request:
- if headers.get("expect", "").lower() == "100-continue":
- return 0
- else:
- if request.method.upper() == "HEAD":
- return 0
- if 100 <= response_code <= 199:
- return 0
- if response_code == 200 and request.method.upper() == "CONNECT":
- return 0
- if response_code in (204, 304):
- return 0
-
- if "chunked" in headers.get("transfer-encoding", "").lower():
- return None
- if "content-length" in headers:
- try:
- size = int(headers["content-length"])
- if size < 0:
- raise ValueError()
- return size
- except ValueError:
- raise exceptions.HttpSyntaxException("Unparseable Content Length")
- if is_request:
- return 0
- return -1
-
-
-def _get_first_line(rfile):
- try:
- line = rfile.readline()
- if line == b"\r\n" or line == b"\n":
- # Possible leftover from previous message
- line = rfile.readline()
- except exceptions.TcpDisconnect:
- raise exceptions.HttpReadDisconnect("Remote disconnected")
- if not line:
- raise exceptions.HttpReadDisconnect("Remote disconnected")
- return line.strip()
-
-
-def _read_request_line(rfile):
- try:
- line = _get_first_line(rfile)
- except exceptions.HttpReadDisconnect:
- # We want to provide a better error message.
- raise exceptions.HttpReadDisconnect("Client disconnected")
-
- try:
- method, path, http_version = line.split()
-
- if path == b"*" or path.startswith(b"/"):
- form = "relative"
- scheme, host, port = None, None, None
- elif method == b"CONNECT":
- form = "authority"
- host, port = _parse_authority_form(path)
- scheme, path = None, None
- else:
- form = "absolute"
- scheme, host, port, path = url.parse(path)
-
- _check_http_version(http_version)
- except ValueError:
- raise exceptions.HttpSyntaxException("Bad HTTP request line: {}".format(line))
-
- return form, method, scheme, host, port, path, http_version
-
-
-def _parse_authority_form(hostport):
- """
- Returns (host, port) if hostport is a valid authority-form host specification.
- http://tools.ietf.org/html/draft-luotonen-web-proxy-tunneling-01 section 3.1
-
- Raises:
- ValueError, if the input is malformed
- """
- try:
- host, port = hostport.split(b":")
- port = int(port)
- if not utils.is_valid_host(host) or not utils.is_valid_port(port):
- raise ValueError()
- except ValueError:
- raise exceptions.HttpSyntaxException("Invalid host specification: {}".format(hostport))
-
- return host, port
-
-
-def _read_response_line(rfile):
- try:
- line = _get_first_line(rfile)
- except exceptions.HttpReadDisconnect:
- # We want to provide a better error message.
- raise exceptions.HttpReadDisconnect("Server disconnected")
-
- try:
- parts = line.split(None, 2)
- if len(parts) == 2: # handle missing message gracefully
- parts.append(b"")
-
- http_version, status_code, message = parts
- status_code = int(status_code)
- _check_http_version(http_version)
-
- except ValueError:
- raise exceptions.HttpSyntaxException("Bad HTTP response line: {}".format(line))
-
- return http_version, status_code, message
-
-
-def _check_http_version(http_version):
- if not re.match(br"^HTTP/\d\.\d$", http_version):
- raise exceptions.HttpSyntaxException("Unknown HTTP version: {}".format(http_version))
-
-
-def _read_headers(rfile):
- """
- Read a set of headers.
- Stop once a blank line is reached.
-
- Returns:
- A headers object
-
- Raises:
- exceptions.HttpSyntaxException
- """
- ret = []
- while True:
- line = rfile.readline()
- if not line or line == b"\r\n" or line == b"\n":
- break
- if line[0] in b" \t":
- if not ret:
- raise exceptions.HttpSyntaxException("Invalid headers")
- # continued header
- ret[-1] = (ret[-1][0], ret[-1][1] + b'\r\n ' + line.strip())
- else:
- try:
- name, value = line.split(b":", 1)
- value = value.strip()
- if not name:
- raise ValueError()
- ret.append((name, value))
- except ValueError:
- raise exceptions.HttpSyntaxException(
- "Invalid header line: %s" % repr(line)
- )
- return headers.Headers(ret)
-
-
-def _read_chunked(rfile, limit=sys.maxsize):
- """
- Read a HTTP body with chunked transfer encoding.
-
- Args:
- rfile: the input file
- limit: A positive integer
- """
- total = 0
- while True:
- line = rfile.readline(128)
- if line == b"":
- raise exceptions.HttpException("Connection closed prematurely")
- if line != b"\r\n" and line != b"\n":
- try:
- length = int(line, 16)
- except ValueError:
- raise exceptions.HttpSyntaxException("Invalid chunked encoding length: {}".format(line))
- total += length
- if total > limit:
- raise exceptions.HttpException(
- "HTTP Body too large. Limit is {}, "
- "chunked content longer than {}".format(limit, total)
- )
- chunk = rfile.read(length)
- suffix = rfile.readline(5)
- if suffix != b"\r\n":
- raise exceptions.HttpSyntaxException("Malformed chunked body")
- if length == 0:
- return
- yield chunk
diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py
deleted file mode 100644
index 20cc63a0..00000000
--- a/netlib/http/http2/__init__.py
+++ /dev/null
@@ -1,8 +0,0 @@
-from netlib.http.http2.framereader import read_raw_frame, parse_frame
-from netlib.http.http2.utils import parse_headers
-
-__all__ = [
- "read_raw_frame",
- "parse_frame",
- "parse_headers",
-]
diff --git a/netlib/http/http2/framereader.py b/netlib/http/http2/framereader.py
deleted file mode 100644
index 8b7cfffb..00000000
--- a/netlib/http/http2/framereader.py
+++ /dev/null
@@ -1,25 +0,0 @@
-import codecs
-
-import hyperframe
-from ...exceptions import HttpException
-
-
-def read_raw_frame(rfile):
- header = rfile.safe_read(9)
- length = int(codecs.encode(header[:3], 'hex_codec'), 16)
-
- if length == 4740180:
- raise HttpException("Length field looks more like HTTP/1.1:\n{}".format(rfile.read(-1)))
-
- body = rfile.safe_read(length)
- return [header, body]
-
-
-def parse_frame(header, body=None):
- if body is None:
- body = header[9:]
- header = header[:9]
-
- frame, length = hyperframe.frame.Frame.parse_frame_header(header)
- frame.parse_body(memoryview(body))
- return frame
diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py
deleted file mode 100644
index 164bacc8..00000000
--- a/netlib/http/http2/utils.py
+++ /dev/null
@@ -1,37 +0,0 @@
-from netlib.http import url
-
-
-def parse_headers(headers):
- authority = headers.get(':authority', '').encode()
- method = headers.get(':method', 'GET').encode()
- scheme = headers.get(':scheme', 'https').encode()
- path = headers.get(':path', '/').encode()
-
- headers.pop(":method", None)
- headers.pop(":scheme", None)
- headers.pop(":path", None)
-
- host = None
- port = None
-
- if path == b'*' or path.startswith(b"/"):
- first_line_format = "relative"
- elif method == b'CONNECT': # pragma: no cover
- raise NotImplementedError("CONNECT over HTTP/2 is not implemented.")
- else: # pragma: no cover
- first_line_format = "absolute"
- # FIXME: verify if path or :host contains what we need
- scheme, host, port, _ = url.parse(path)
-
- if authority:
- host, _, port = authority.partition(b':')
-
- if not host:
- host = b'localhost'
-
- if not port:
- port = 443 if scheme == b'https' else 80
-
- port = int(port)
-
- return first_line_format, method, scheme, host, port, path
diff --git a/netlib/http/message.py b/netlib/http/message.py
deleted file mode 100644
index 1980b0ab..00000000
--- a/netlib/http/message.py
+++ /dev/null
@@ -1,298 +0,0 @@
-import re
-import warnings
-from typing import Optional
-
-from netlib import encoding, strutils, basetypes
-from netlib.http import headers
-
-
-# While headers _should_ be ASCII, it's not uncommon for certain headers to be utf-8 encoded.
-def _native(x):
- return x.decode("utf-8", "surrogateescape")
-
-
-def _always_bytes(x):
- return strutils.always_bytes(x, "utf-8", "surrogateescape")
-
-
-class MessageData(basetypes.Serializable):
- def __eq__(self, other):
- if isinstance(other, MessageData):
- return self.__dict__ == other.__dict__
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def set_state(self, state):
- for k, v in state.items():
- if k == "headers":
- v = headers.Headers.from_state(v)
- setattr(self, k, v)
-
- def get_state(self):
- state = vars(self).copy()
- state["headers"] = state["headers"].get_state()
- return state
-
- @classmethod
- def from_state(cls, state):
- state["headers"] = headers.Headers.from_state(state["headers"])
- return cls(**state)
-
-
-class Message(basetypes.Serializable):
- def __eq__(self, other):
- if isinstance(other, Message):
- return self.data == other.data
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def get_state(self):
- return self.data.get_state()
-
- def set_state(self, state):
- self.data.set_state(state)
-
- @classmethod
- def from_state(cls, state):
- state["headers"] = headers.Headers.from_state(state["headers"])
- return cls(**state)
-
- @property
- def headers(self):
- """
- Message headers object
-
- Returns:
- netlib.http.Headers
- """
- return self.data.headers
-
- @headers.setter
- def headers(self, h):
- self.data.headers = h
-
- @property
- def raw_content(self) -> bytes:
- """
- The raw (encoded) HTTP message body
-
- See also: :py:attr:`content`, :py:class:`text`
- """
- return self.data.content
-
- @raw_content.setter
- def raw_content(self, content):
- self.data.content = content
-
- def get_content(self, strict: bool=True) -> bytes:
- """
- The HTTP message body decoded with the content-encoding header (e.g. gzip)
-
- Raises:
- ValueError, when the content-encoding is invalid and strict is True.
-
- See also: :py:class:`raw_content`, :py:attr:`text`
- """
- if self.raw_content is None:
- return None
- ce = self.headers.get("content-encoding")
- if ce:
- try:
- return encoding.decode(self.raw_content, ce)
- except ValueError:
- if strict:
- raise
- return self.raw_content
- else:
- return self.raw_content
-
- def set_content(self, value):
- if value is None:
- self.raw_content = None
- return
- if not isinstance(value, bytes):
- raise TypeError(
- "Message content must be bytes, not {}. "
- "Please use .text if you want to assign a str."
- .format(type(value).__name__)
- )
- ce = self.headers.get("content-encoding")
- try:
- self.raw_content = encoding.encode(value, ce or "identity")
- except ValueError:
- # So we have an invalid content-encoding?
- # Let's remove it!
- del self.headers["content-encoding"]
- self.raw_content = value
- self.headers["content-length"] = str(len(self.raw_content))
-
- content = property(get_content, set_content)
-
- @property
- def http_version(self):
- """
- Version string, e.g. "HTTP/1.1"
- """
- return _native(self.data.http_version)
-
- @http_version.setter
- def http_version(self, http_version):
- self.data.http_version = _always_bytes(http_version)
-
- @property
- def timestamp_start(self):
- """
- First byte timestamp
- """
- return self.data.timestamp_start
-
- @timestamp_start.setter
- def timestamp_start(self, timestamp_start):
- self.data.timestamp_start = timestamp_start
-
- @property
- def timestamp_end(self):
- """
- Last byte timestamp
- """
- return self.data.timestamp_end
-
- @timestamp_end.setter
- def timestamp_end(self, timestamp_end):
- self.data.timestamp_end = timestamp_end
-
- def _get_content_type_charset(self) -> Optional[str]:
- ct = headers.parse_content_type(self.headers.get("content-type", ""))
- if ct:
- return ct[2].get("charset")
-
- def _guess_encoding(self) -> str:
- enc = self._get_content_type_charset()
- if enc:
- return enc
-
- if "json" in self.headers.get("content-type", ""):
- return "utf8"
- else:
- # We may also want to check for HTML meta tags here at some point.
- return "latin-1"
-
- def get_text(self, strict: bool=True) -> str:
- """
- The HTTP message body decoded with both content-encoding header (e.g. gzip)
- and content-type header charset.
-
- Raises:
- ValueError, when either content-encoding or charset is invalid and strict is True.
-
- See also: :py:attr:`content`, :py:class:`raw_content`
- """
- if self.raw_content is None:
- return None
- enc = self._guess_encoding()
-
- content = self.get_content(strict)
- try:
- return encoding.decode(content, enc)
- except ValueError:
- if strict:
- raise
- return content.decode("utf8", "surrogateescape")
-
- def set_text(self, text):
- if text is None:
- self.content = None
- return
- enc = self._guess_encoding()
-
- try:
- self.content = encoding.encode(text, enc)
- except ValueError:
- # Fall back to UTF-8 and update the content-type header.
- ct = headers.parse_content_type(self.headers.get("content-type", "")) or ("text", "plain", {})
- ct[2]["charset"] = "utf-8"
- self.headers["content-type"] = headers.assemble_content_type(*ct)
- enc = "utf8"
- self.content = text.encode(enc, "surrogateescape")
-
- text = property(get_text, set_text)
-
- def decode(self, strict=True):
- """
- Decodes body based on the current Content-Encoding header, then
- removes the header. If there is no Content-Encoding header, no
- action is taken.
-
- Raises:
- ValueError, when the content-encoding is invalid and strict is True.
- """
- self.raw_content = self.get_content(strict)
- self.headers.pop("content-encoding", None)
-
- def encode(self, e):
- """
- Encodes body with the encoding e, where e is "gzip", "deflate", "identity", or "br".
- Any existing content-encodings are overwritten,
- the content is not decoded beforehand.
-
- Raises:
- ValueError, when the specified content-encoding is invalid.
- """
- self.headers["content-encoding"] = e
- self.content = self.raw_content
- if "content-encoding" not in self.headers:
- raise ValueError("Invalid content encoding {}".format(repr(e)))
-
- def replace(self, pattern, repl, flags=0, count=0):
- """
- Replaces a regular expression pattern with repl in both the headers
- and the body of the message. Encoded body will be decoded
- before replacement, and re-encoded afterwards.
-
- Returns:
- The number of replacements made.
- """
- if isinstance(pattern, str):
- pattern = strutils.escaped_str_to_bytes(pattern)
- if isinstance(repl, str):
- repl = strutils.escaped_str_to_bytes(repl)
- replacements = 0
- if self.content:
- self.content, replacements = re.subn(
- pattern, repl, self.content, flags=flags, count=count
- )
- replacements += self.headers.replace(pattern, repl, flags=flags, count=count)
- return replacements
-
- # Legacy
-
- @property
- def body(self): # pragma: no cover
- warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning)
- return self.content
-
- @body.setter
- def body(self, body): # pragma: no cover
- warnings.warn(".body is deprecated, use .content instead.", DeprecationWarning)
- self.content = body
-
-
-class decoded:
- """
- Deprecated: You can now directly use :py:attr:`content`.
- :py:attr:`raw_content` has the encoded content.
- """
-
- def __init__(self, message): # pragma no cover
- warnings.warn("decoded() is deprecated, you can now directly use .content instead. "
- ".raw_content has the encoded content.", DeprecationWarning)
-
- def __enter__(self): # pragma no cover
- pass
-
- def __exit__(self, type, value, tb): # pragma no cover
- pass
diff --git a/netlib/http/multipart.py b/netlib/http/multipart.py
deleted file mode 100644
index 536b2809..00000000
--- a/netlib/http/multipart.py
+++ /dev/null
@@ -1,32 +0,0 @@
-import re
-
-from netlib.http import headers
-
-
-def decode(hdrs, content):
- """
- Takes a multipart boundary encoded string and returns list of (key, value) tuples.
- """
- v = hdrs.get("content-type")
- if v:
- v = headers.parse_content_type(v)
- if not v:
- return []
- try:
- boundary = v[2]["boundary"].encode("ascii")
- except (KeyError, UnicodeError):
- return []
-
- rx = re.compile(br'\bname="([^"]+)"')
- r = []
-
- for i in content.split(b"--" + boundary):
- parts = i.splitlines()
- if len(parts) > 1 and parts[0][0:2] != b"--":
- match = rx.search(parts[1])
- if match:
- key = match.group(1)
- value = b"".join(parts[3 + parts[2:].index(b""):])
- r.append((key, value))
- return r
- return []
diff --git a/netlib/http/request.py b/netlib/http/request.py
deleted file mode 100644
index dd6f4164..00000000
--- a/netlib/http/request.py
+++ /dev/null
@@ -1,405 +0,0 @@
-import re
-import urllib
-
-from netlib import multidict
-from netlib import strutils
-from netlib.http import multipart
-from netlib.http import cookies
-from netlib.http import headers as nheaders
-from netlib.http import message
-import netlib.http.url
-
-# This regex extracts & splits the host header into host and port.
-# Handles the edge case of IPv6 addresses containing colons.
-# https://bugzilla.mozilla.org/show_bug.cgi?id=45891
-host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
-
-
-class RequestData(message.MessageData):
- def __init__(
- self,
- first_line_format,
- method,
- scheme,
- host,
- port,
- path,
- http_version,
- headers=(),
- content=None,
- timestamp_start=None,
- timestamp_end=None
- ):
- if isinstance(method, str):
- method = method.encode("ascii", "strict")
- if isinstance(scheme, str):
- scheme = scheme.encode("ascii", "strict")
- if isinstance(host, str):
- host = host.encode("idna", "strict")
- if isinstance(path, str):
- path = path.encode("ascii", "strict")
- if isinstance(http_version, str):
- http_version = http_version.encode("ascii", "strict")
- if not isinstance(headers, nheaders.Headers):
- headers = nheaders.Headers(headers)
- if isinstance(content, str):
- raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
-
- self.first_line_format = first_line_format
- self.method = method
- self.scheme = scheme
- self.host = host
- self.port = port
- self.path = path
- self.http_version = http_version
- self.headers = headers
- self.content = content
- self.timestamp_start = timestamp_start
- self.timestamp_end = timestamp_end
-
-
-class Request(message.Message):
- """
- An HTTP request.
- """
- def __init__(self, *args, **kwargs):
- super().__init__()
- self.data = RequestData(*args, **kwargs)
-
- def __repr__(self):
- if self.host and self.port:
- hostport = "{}:{}".format(self.host, self.port)
- else:
- hostport = ""
- path = self.path or ""
- return "Request({} {}{})".format(
- self.method, hostport, path
- )
-
- def replace(self, pattern, repl, flags=0, count=0):
- """
- Replaces a regular expression pattern with repl in the headers, the
- request path and the body of the request. Encoded content will be
- decoded before replacement, and re-encoded afterwards.
-
- Returns:
- The number of replacements made.
- """
- if isinstance(pattern, str):
- pattern = strutils.escaped_str_to_bytes(pattern)
- if isinstance(repl, str):
- repl = strutils.escaped_str_to_bytes(repl)
-
- c = super().replace(pattern, repl, flags, count)
- self.path, pc = re.subn(
- pattern, repl, self.data.path, flags=flags, count=count
- )
- c += pc
- return c
-
- @property
- def first_line_format(self):
- """
- HTTP request form as defined in `RFC7230 <https://tools.ietf.org/html/rfc7230#section-5.3>`_.
-
- origin-form and asterisk-form are subsumed as "relative".
- """
- return self.data.first_line_format
-
- @first_line_format.setter
- def first_line_format(self, first_line_format):
- self.data.first_line_format = first_line_format
-
- @property
- def method(self):
- """
- HTTP request method, e.g. "GET".
- """
- return message._native(self.data.method).upper()
-
- @method.setter
- def method(self, method):
- self.data.method = message._always_bytes(method)
-
- @property
- def scheme(self):
- """
- HTTP request scheme, which should be "http" or "https".
- """
- if not self.data.scheme:
- return self.data.scheme
- return message._native(self.data.scheme)
-
- @scheme.setter
- def scheme(self, scheme):
- self.data.scheme = message._always_bytes(scheme)
-
- @property
- def host(self):
- """
- Target host. This may be parsed from the raw request
- (e.g. from a ``GET http://example.com/ HTTP/1.1`` request line)
- or inferred from the proxy mode (e.g. an IP in transparent mode).
-
- Setting the host attribute also updates the host header, if present.
- """
- if not self.data.host:
- return self.data.host
- try:
- return self.data.host.decode("idna")
- except UnicodeError:
- return self.data.host.decode("utf8", "surrogateescape")
-
- @host.setter
- def host(self, host):
- if isinstance(host, str):
- try:
- # There's no non-strict mode for IDNA encoding.
- # We don't want this operation to fail though, so we try
- # utf8 as a last resort.
- host = host.encode("idna", "strict")
- except UnicodeError:
- host = host.encode("utf8", "surrogateescape")
-
- self.data.host = host
-
- # Update host header
- if "host" in self.headers:
- if host:
- self.headers["host"] = host
- else:
- self.headers.pop("host")
-
- @property
- def port(self):
- """
- Target port
- """
- return self.data.port
-
- @port.setter
- def port(self, port):
- self.data.port = port
-
- @property
- def path(self):
- """
- HTTP request path, e.g. "/index.html".
- Guaranteed to start with a slash, except for OPTIONS requests, which may just be "*".
- """
- if self.data.path is None:
- return None
- else:
- return message._native(self.data.path)
-
- @path.setter
- def path(self, path):
- self.data.path = message._always_bytes(path)
-
- @property
- def url(self):
- """
- The URL string, constructed from the request's URL components
- """
- if self.first_line_format == "authority":
- return "%s:%d" % (self.host, self.port)
- return netlib.http.url.unparse(self.scheme, self.host, self.port, self.path)
-
- @url.setter
- def url(self, url):
- self.scheme, self.host, self.port, self.path = netlib.http.url.parse(url)
-
- def _parse_host_header(self):
- """Extract the host and port from Host header"""
- if "host" not in self.headers:
- return None, None
- host, port = self.headers["host"], None
- m = host_header_re.match(host)
- if m:
- host = m.group("host").strip("[]")
- if m.group("port"):
- port = int(m.group("port"))
- return host, port
-
- @property
- def pretty_host(self):
- """
- Similar to :py:attr:`host`, but using the Host headers as an additional preferred data source.
- This is useful in transparent mode where :py:attr:`host` is only an IP address,
- but may not reflect the actual destination as the Host header could be spoofed.
- """
- host, port = self._parse_host_header()
- if not host:
- return self.host
- if not port:
- port = 443 if self.scheme == 'https' else 80
- # Prefer the original address if host header has an unexpected form
- return host if port == self.port else self.host
-
- @property
- def pretty_url(self):
- """
- Like :py:attr:`url`, but using :py:attr:`pretty_host` instead of :py:attr:`host`.
- """
- if self.first_line_format == "authority":
- return "%s:%d" % (self.pretty_host, self.port)
- return netlib.http.url.unparse(self.scheme, self.pretty_host, self.port, self.path)
-
- @property
- def query(self) -> multidict.MultiDictView:
- """
- The request query string as an :py:class:`~netlib.multidict.MultiDictView` object.
- """
- return multidict.MultiDictView(
- self._get_query,
- self._set_query
- )
-
- def _get_query(self):
- query = urllib.parse.urlparse(self.url).query
- return tuple(netlib.http.url.decode(query))
-
- def _set_query(self, query_data):
- query = netlib.http.url.encode(query_data)
- _, _, path, params, _, fragment = urllib.parse.urlparse(self.url)
- self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
-
- @query.setter
- def query(self, value):
- self._set_query(value)
-
- @property
- def cookies(self) -> multidict.MultiDictView:
- """
- The request cookies.
-
- An empty :py:class:`~netlib.multidict.MultiDictView` object if the cookie monster ate them all.
- """
- return multidict.MultiDictView(
- self._get_cookies,
- self._set_cookies
- )
-
- def _get_cookies(self):
- h = self.headers.get_all("Cookie")
- return tuple(cookies.parse_cookie_headers(h))
-
- def _set_cookies(self, value):
- self.headers["cookie"] = cookies.format_cookie_header(value)
-
- @cookies.setter
- def cookies(self, value):
- self._set_cookies(value)
-
- @property
- def path_components(self):
- """
- The URL's path components as a tuple of strings.
- Components are unquoted.
- """
- path = urllib.parse.urlparse(self.url).path
- # This needs to be a tuple so that it's immutable.
- # Otherwise, this would fail silently:
- # request.path_components.append("foo")
- return tuple(netlib.http.url.unquote(i) for i in path.split("/") if i)
-
- @path_components.setter
- def path_components(self, components):
- components = map(lambda x: netlib.http.url.quote(x, safe=""), components)
- path = "/" + "/".join(components)
- _, _, _, params, query, fragment = urllib.parse.urlparse(self.url)
- self.path = urllib.parse.urlunparse(["", "", path, params, query, fragment])
-
- def anticache(self):
- """
- Modifies this request to remove headers that might produce a cached
- response. That is, we remove ETags and If-Modified-Since headers.
- """
- delheaders = [
- "if-modified-since",
- "if-none-match",
- ]
- for i in delheaders:
- self.headers.pop(i, None)
-
- def anticomp(self):
- """
- Modifies this request to remove headers that will compress the
- resource's data.
- """
- self.headers["accept-encoding"] = "identity"
-
- def constrain_encoding(self):
- """
- Limits the permissible Accept-Encoding values, based on what we can
- decode appropriately.
- """
- accept_encoding = self.headers.get("accept-encoding")
- if accept_encoding:
- self.headers["accept-encoding"] = (
- ', '.join(
- e
- for e in {"gzip", "identity", "deflate", "br"}
- if e in accept_encoding
- )
- )
-
- @property
- def urlencoded_form(self):
- """
- The URL-encoded form data as an :py:class:`~netlib.multidict.MultiDictView` object.
- An empty multidict.MultiDictView if the content-type indicates non-form data
- or the content could not be parsed.
- """
- return multidict.MultiDictView(
- self._get_urlencoded_form,
- self._set_urlencoded_form
- )
-
- def _get_urlencoded_form(self):
- is_valid_content_type = "application/x-www-form-urlencoded" in self.headers.get("content-type", "").lower()
- if is_valid_content_type:
- try:
- return tuple(netlib.http.url.decode(self.content))
- except ValueError:
- pass
- return ()
-
- def _set_urlencoded_form(self, form_data):
- """
- Sets the body to the URL-encoded form data, and adds the appropriate content-type header.
- This will overwrite the existing content if there is one.
- """
- self.headers["content-type"] = "application/x-www-form-urlencoded"
- self.content = netlib.http.url.encode(form_data).encode()
-
- @urlencoded_form.setter
- def urlencoded_form(self, value):
- self._set_urlencoded_form(value)
-
- @property
- def multipart_form(self):
- """
- The multipart form data as an :py:class:`~netlib.multidict.MultiDictView` object.
- None if the content-type indicates non-form data.
- """
- return multidict.MultiDictView(
- self._get_multipart_form,
- self._set_multipart_form
- )
-
- def _get_multipart_form(self):
- is_valid_content_type = "multipart/form-data" in self.headers.get("content-type", "").lower()
- if is_valid_content_type:
- try:
- return multipart.decode(self.headers, self.content)
- except ValueError:
- pass
- return ()
-
- def _set_multipart_form(self, value):
- raise NotImplementedError()
-
- @multipart_form.setter
- def multipart_form(self, value):
- self._set_multipart_form(value)
diff --git a/netlib/http/response.py b/netlib/http/response.py
deleted file mode 100644
index a8b48be0..00000000
--- a/netlib/http/response.py
+++ /dev/null
@@ -1,192 +0,0 @@
-import time
-from email.utils import parsedate_tz, formatdate, mktime_tz
-from netlib import human
-from netlib import multidict
-from netlib.http import cookies
-from netlib.http import headers as nheaders
-from netlib.http import message
-from netlib.http import status_codes
-from typing import AnyStr
-from typing import Dict
-from typing import Iterable
-from typing import Tuple
-from typing import Union
-
-
-class ResponseData(message.MessageData):
- def __init__(
- self,
- http_version,
- status_code,
- reason=None,
- headers=(),
- content=None,
- timestamp_start=None,
- timestamp_end=None
- ):
- if isinstance(http_version, str):
- http_version = http_version.encode("ascii", "strict")
- if isinstance(reason, str):
- reason = reason.encode("ascii", "strict")
- if not isinstance(headers, nheaders.Headers):
- headers = nheaders.Headers(headers)
- if isinstance(content, str):
- raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
-
- self.http_version = http_version
- self.status_code = status_code
- self.reason = reason
- self.headers = headers
- self.content = content
- self.timestamp_start = timestamp_start
- self.timestamp_end = timestamp_end
-
-
-class Response(message.Message):
- """
- An HTTP response.
- """
- def __init__(self, *args, **kwargs):
- super().__init__()
- self.data = ResponseData(*args, **kwargs)
-
- def __repr__(self):
- if self.raw_content:
- details = "{}, {}".format(
- self.headers.get("content-type", "unknown content type"),
- human.pretty_size(len(self.raw_content))
- )
- else:
- details = "no content"
- return "Response({status_code} {reason}, {details})".format(
- status_code=self.status_code,
- reason=self.reason,
- details=details
- )
-
- @classmethod
- def make(
- cls,
- status_code: int=200,
- content: AnyStr=b"",
- headers: Union[Dict[AnyStr, AnyStr], Iterable[Tuple[bytes, bytes]]]=()
- ):
- """
- Simplified API for creating response objects.
- """
- resp = cls(
- b"HTTP/1.1",
- status_code,
- status_codes.RESPONSES.get(status_code, "").encode(),
- (),
- None
- )
-
- # Headers can be list or dict, we differentiate here.
- if isinstance(headers, dict):
- resp.headers = nheaders.Headers(**headers)
- elif isinstance(headers, Iterable):
- resp.headers = nheaders.Headers(headers)
- else:
- raise TypeError("Expected headers to be an iterable or dict, but is {}.".format(
- type(headers).__name__
- ))
-
- # Assign this manually to update the content-length header.
- if isinstance(content, bytes):
- resp.content = content
- elif isinstance(content, str):
- resp.text = content
- else:
- raise TypeError("Expected content to be str or bytes, but is {}.".format(
- type(content).__name__
- ))
-
- return resp
-
- @property
- def status_code(self):
- """
- HTTP Status Code, e.g. ``200``.
- """
- return self.data.status_code
-
- @status_code.setter
- def status_code(self, status_code):
- self.data.status_code = status_code
-
- @property
- def reason(self):
- """
- HTTP Reason Phrase, e.g. "Not Found".
- This is always :py:obj:`None` for HTTP2 requests, because HTTP2 responses do not contain a reason phrase.
- """
- return message._native(self.data.reason)
-
- @reason.setter
- def reason(self, reason):
- self.data.reason = message._always_bytes(reason)
-
- @property
- def cookies(self) -> multidict.MultiDictView:
- """
- The response cookies. A possibly empty
- :py:class:`~netlib.multidict.MultiDictView`, where the keys are cookie
- name strings, and values are (value, attr) tuples. Value is a string,
- and attr is an MultiDictView containing cookie attributes. Within
- attrs, unary attributes (e.g. HTTPOnly) are indicated by a Null value.
-
- Caveats:
- Updating the attr
- """
- return multidict.MultiDictView(
- self._get_cookies,
- self._set_cookies
- )
-
- def _get_cookies(self):
- h = self.headers.get_all("set-cookie")
- return tuple(cookies.parse_set_cookie_headers(h))
-
- def _set_cookies(self, value):
- cookie_headers = []
- for k, v in value:
- header = cookies.format_set_cookie_header([(k, v[0], v[1])])
- cookie_headers.append(header)
- self.headers.set_all("set-cookie", cookie_headers)
-
- @cookies.setter
- def cookies(self, value):
- self._set_cookies(value)
-
- def refresh(self, now=None):
- """
- This fairly complex and heuristic function refreshes a server
- response for replay.
-
- - It adjusts date, expires and last-modified headers.
- - It adjusts cookie expiration.
- """
- if not now:
- now = time.time()
- delta = now - self.timestamp_start
- refresh_headers = [
- "date",
- "expires",
- "last-modified",
- ]
- for i in refresh_headers:
- if i in self.headers:
- d = parsedate_tz(self.headers[i])
- if d:
- new = mktime_tz(d) + delta
- self.headers[i] = formatdate(new)
- c = []
- for set_cookie_header in self.headers.get_all("set-cookie"):
- try:
- refreshed = cookies.refresh_set_cookie_header(set_cookie_header, delta)
- except ValueError:
- refreshed = set_cookie_header
- c.append(refreshed)
- if c:
- self.headers.set_all("set-cookie", c)
diff --git a/netlib/http/status_codes.py b/netlib/http/status_codes.py
deleted file mode 100644
index 5a83cd73..00000000
--- a/netlib/http/status_codes.py
+++ /dev/null
@@ -1,104 +0,0 @@
-CONTINUE = 100
-SWITCHING = 101
-OK = 200
-CREATED = 201
-ACCEPTED = 202
-NON_AUTHORITATIVE_INFORMATION = 203
-NO_CONTENT = 204
-RESET_CONTENT = 205
-PARTIAL_CONTENT = 206
-MULTI_STATUS = 207
-
-MULTIPLE_CHOICE = 300
-MOVED_PERMANENTLY = 301
-FOUND = 302
-SEE_OTHER = 303
-NOT_MODIFIED = 304
-USE_PROXY = 305
-TEMPORARY_REDIRECT = 307
-
-BAD_REQUEST = 400
-UNAUTHORIZED = 401
-PAYMENT_REQUIRED = 402
-FORBIDDEN = 403
-NOT_FOUND = 404
-NOT_ALLOWED = 405
-NOT_ACCEPTABLE = 406
-PROXY_AUTH_REQUIRED = 407
-REQUEST_TIMEOUT = 408
-CONFLICT = 409
-GONE = 410
-LENGTH_REQUIRED = 411
-PRECONDITION_FAILED = 412
-REQUEST_ENTITY_TOO_LARGE = 413
-REQUEST_URI_TOO_LONG = 414
-UNSUPPORTED_MEDIA_TYPE = 415
-REQUESTED_RANGE_NOT_SATISFIABLE = 416
-EXPECTATION_FAILED = 417
-IM_A_TEAPOT = 418
-
-INTERNAL_SERVER_ERROR = 500
-NOT_IMPLEMENTED = 501
-BAD_GATEWAY = 502
-SERVICE_UNAVAILABLE = 503
-GATEWAY_TIMEOUT = 504
-HTTP_VERSION_NOT_SUPPORTED = 505
-INSUFFICIENT_STORAGE_SPACE = 507
-NOT_EXTENDED = 510
-
-RESPONSES = {
- # 100
- CONTINUE: "Continue",
- SWITCHING: "Switching Protocols",
-
- # 200
- OK: "OK",
- CREATED: "Created",
- ACCEPTED: "Accepted",
- NON_AUTHORITATIVE_INFORMATION: "Non-Authoritative Information",
- NO_CONTENT: "No Content",
- RESET_CONTENT: "Reset Content.",
- PARTIAL_CONTENT: "Partial Content",
- MULTI_STATUS: "Multi-Status",
-
- # 300
- MULTIPLE_CHOICE: "Multiple Choices",
- MOVED_PERMANENTLY: "Moved Permanently",
- FOUND: "Found",
- SEE_OTHER: "See Other",
- NOT_MODIFIED: "Not Modified",
- USE_PROXY: "Use Proxy",
- # 306 not defined??
- TEMPORARY_REDIRECT: "Temporary Redirect",
-
- # 400
- BAD_REQUEST: "Bad Request",
- UNAUTHORIZED: "Unauthorized",
- PAYMENT_REQUIRED: "Payment Required",
- FORBIDDEN: "Forbidden",
- NOT_FOUND: "Not Found",
- NOT_ALLOWED: "Method Not Allowed",
- NOT_ACCEPTABLE: "Not Acceptable",
- PROXY_AUTH_REQUIRED: "Proxy Authentication Required",
- REQUEST_TIMEOUT: "Request Time-out",
- CONFLICT: "Conflict",
- GONE: "Gone",
- LENGTH_REQUIRED: "Length Required",
- PRECONDITION_FAILED: "Precondition Failed",
- REQUEST_ENTITY_TOO_LARGE: "Request Entity Too Large",
- REQUEST_URI_TOO_LONG: "Request-URI Too Long",
- UNSUPPORTED_MEDIA_TYPE: "Unsupported Media Type",
- REQUESTED_RANGE_NOT_SATISFIABLE: "Requested Range not satisfiable",
- EXPECTATION_FAILED: "Expectation Failed",
- IM_A_TEAPOT: "I'm a teapot",
-
- # 500
- INTERNAL_SERVER_ERROR: "Internal Server Error",
- NOT_IMPLEMENTED: "Not Implemented",
- BAD_GATEWAY: "Bad Gateway",
- SERVICE_UNAVAILABLE: "Service Unavailable",
- GATEWAY_TIMEOUT: "Gateway Time-out",
- HTTP_VERSION_NOT_SUPPORTED: "HTTP Version not supported",
- INSUFFICIENT_STORAGE_SPACE: "Insufficient Storage Space",
- NOT_EXTENDED: "Not Extended"
-}
diff --git a/netlib/http/url.py b/netlib/http/url.py
deleted file mode 100644
index 67e22efa..00000000
--- a/netlib/http/url.py
+++ /dev/null
@@ -1,127 +0,0 @@
-import urllib
-from typing import Sequence
-from typing import Tuple
-
-from netlib import utils
-
-
-# PY2 workaround
-def decode_parse_result(result, enc):
- if hasattr(result, "decode"):
- return result.decode(enc)
- else:
- return urllib.parse.ParseResult(*[x.decode(enc) for x in result])
-
-
-# PY2 workaround
-def encode_parse_result(result, enc):
- if hasattr(result, "encode"):
- return result.encode(enc)
- else:
- return urllib.parse.ParseResult(*[x.encode(enc) for x in result])
-
-
-def parse(url):
- """
- URL-parsing function that checks that
- - port is an integer 0-65535
- - host is a valid IDNA-encoded hostname with no null-bytes
- - path is valid ASCII
-
- Args:
- A URL (as bytes or as unicode)
-
- Returns:
- A (scheme, host, port, path) tuple
-
- Raises:
- ValueError, if the URL is not properly formatted.
- """
- parsed = urllib.parse.urlparse(url)
-
- if not parsed.hostname:
- raise ValueError("No hostname given")
-
- if isinstance(url, bytes):
- host = parsed.hostname
-
- # this should not raise a ValueError,
- # but we try to be very forgiving here and accept just everything.
- # decode_parse_result(parsed, "ascii")
- else:
- host = parsed.hostname.encode("idna")
- parsed = encode_parse_result(parsed, "ascii")
-
- port = parsed.port
- if not port:
- port = 443 if parsed.scheme == b"https" else 80
-
- full_path = urllib.parse.urlunparse(
- (b"", b"", parsed.path, parsed.params, parsed.query, parsed.fragment)
- )
- if not full_path.startswith(b"/"):
- full_path = b"/" + full_path
-
- if not utils.is_valid_host(host):
- raise ValueError("Invalid Host")
- if not utils.is_valid_port(port):
- raise ValueError("Invalid Port")
-
- return parsed.scheme, host, port, full_path
-
-
-def unparse(scheme, host, port, path=""):
- """
- Returns a URL string, constructed from the specified components.
-
- Args:
- All args must be str.
- """
- if path == "*":
- path = ""
- return "%s://%s%s" % (scheme, hostport(scheme, host, port), path)
-
-
-def encode(s: Sequence[Tuple[str, str]]) -> str:
- """
- Takes a list of (key, value) tuples and returns a urlencoded string.
- """
- return urllib.parse.urlencode(s, False, errors="surrogateescape")
-
-
-def decode(s):
- """
- Takes a urlencoded string and returns a list of surrogate-escaped (key, value) tuples.
- """
- return urllib.parse.parse_qsl(s, keep_blank_values=True, errors='surrogateescape')
-
-
-def quote(b: str, safe: str="/") -> str:
- """
- Returns:
- An ascii-encodable str.
- """
- return urllib.parse.quote(b, safe=safe, errors="surrogateescape")
-
-
-def unquote(s: str) -> str:
- """
- Args:
- s: A surrogate-escaped str
- Returns:
- A surrogate-escaped str
- """
- return urllib.parse.unquote(s, errors="surrogateescape")
-
-
-def hostport(scheme, host, port):
- """
- Returns the host component, with a port specifcation if needed.
- """
- if (port, scheme) in [(80, "http"), (443, "https"), (80, b"http"), (443, b"https")]:
- return host
- else:
- if isinstance(host, bytes):
- return b"%s:%d" % (host, port)
- else:
- return "%s:%d" % (host, port)
diff --git a/netlib/http/user_agents.py b/netlib/http/user_agents.py
deleted file mode 100644
index d0ca2f21..00000000
--- a/netlib/http/user_agents.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""
- A small collection of useful user-agent header strings. These should be
- kept reasonably current to reflect common usage.
-"""
-
-# pylint: disable=line-too-long
-
-# A collection of (name, shortcut, string) tuples.
-
-UASTRINGS = [
- ("android",
- "a",
- "Mozilla/5.0 (Linux; U; Android 4.1.1; en-gb; Nexus 7 Build/JRO03D) AFL/01.04.02"), # noqa
- ("blackberry",
- "l",
- "Mozilla/5.0 (BlackBerry; U; BlackBerry 9900; en) AppleWebKit/534.11+ (KHTML, like Gecko) Version/7.1.0.346 Mobile Safari/534.11+"), # noqa
- ("bingbot",
- "b",
- "Mozilla/5.0 (compatible; bingbot/2.0; +http://www.bing.com/bingbot.htm)"), # noqa
- ("chrome",
- "c",
- "Mozilla/5.0 (Windows NT 6.1; WOW64) AppleWebKit/537.1 (KHTML, like Gecko) Chrome/22.0.1207.1 Safari/537.1"), # noqa
- ("firefox",
- "f",
- "Mozilla/5.0 (Windows NT 6.1; Win64; x64; rv:14.0) Gecko/20120405 Firefox/14.0a1"), # noqa
- ("googlebot",
- "g",
- "Googlebot/2.1 (+http://www.googlebot.com/bot.html)"), # noqa
- ("ie9",
- "i",
- "Mozilla/5.0 (Windows; U; MSIE 9.0; WIndows NT 9.0; en-US)"), # noqa
- ("ipad",
- "p",
- "Mozilla/5.0 (iPad; CPU OS 5_1 like Mac OS X) AppleWebKit/534.46 (KHTML, like Gecko) Version/5.1 Mobile/9B176 Safari/7534.48.3"), # noqa
- ("iphone",
- "h",
- "Mozilla/5.0 (iPhone; CPU iPhone OS 4_2_1 like Mac OS X) AppleWebKit/533.17.9 (KHTML, like Gecko) Version/5.0.2 Mobile/8C148a Safari/6533.18.5"), # noqa
- ("safari",
- "s",
- "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_7_3) AppleWebKit/534.55.3 (KHTML, like Gecko) Version/5.1.3 Safari/534.53.10"), # noqa
-]
-
-
-def get_by_shortcut(s):
- """
- Retrieve a user agent entry by shortcut.
- """
- for i in UASTRINGS:
- if s == i[1]:
- return i
diff --git a/netlib/human.py b/netlib/human.py
deleted file mode 100644
index 72e96d30..00000000
--- a/netlib/human.py
+++ /dev/null
@@ -1,64 +0,0 @@
-import datetime
-import time
-
-
-SIZE_TABLE = [
- ("b", 1024 ** 0),
- ("k", 1024 ** 1),
- ("m", 1024 ** 2),
- ("g", 1024 ** 3),
- ("t", 1024 ** 4),
-]
-
-SIZE_UNITS = dict(SIZE_TABLE)
-
-
-def pretty_size(size):
- for bottom, top in zip(SIZE_TABLE, SIZE_TABLE[1:]):
- if bottom[1] <= size < top[1]:
- suf = bottom[0]
- lim = bottom[1]
- x = round(size / lim, 2)
- if x == int(x):
- x = int(x)
- return str(x) + suf
- return "%s%s" % (size, SIZE_TABLE[0][0])
-
-
-def parse_size(s):
- try:
- return int(s)
- except ValueError:
- pass
- for i in SIZE_UNITS.keys():
- if s.endswith(i):
- try:
- return int(s[:-1]) * SIZE_UNITS[i]
- except ValueError:
- break
- raise ValueError("Invalid size specification.")
-
-
-def pretty_duration(secs):
- formatters = [
- (100, "{:.0f}s"),
- (10, "{:2.1f}s"),
- (1, "{:1.2f}s"),
- ]
-
- for limit, formatter in formatters:
- if secs >= limit:
- return formatter.format(secs)
- # less than 1 sec
- return "{:.0f}ms".format(secs * 1000)
-
-
-def format_timestamp(s):
- s = time.localtime(s)
- d = datetime.datetime.fromtimestamp(time.mktime(s))
- return d.strftime("%Y-%m-%d %H:%M:%S")
-
-
-def format_timestamp_with_milli(s):
- d = datetime.datetime.fromtimestamp(s)
- return d.strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
diff --git a/netlib/multidict.py b/netlib/multidict.py
deleted file mode 100644
index 191d1cc6..00000000
--- a/netlib/multidict.py
+++ /dev/null
@@ -1,298 +0,0 @@
-from abc import ABCMeta, abstractmethod
-
-
-try:
- from collections.abc import MutableMapping
-except ImportError: # pragma: no cover
- from collections import MutableMapping # Workaround for Python < 3.3
-
-from netlib import basetypes
-
-
-class _MultiDict(MutableMapping, basetypes.Serializable, metaclass=ABCMeta):
- def __repr__(self):
- fields = (
- repr(field)
- for field in self.fields
- )
- return "{cls}[{fields}]".format(
- cls=type(self).__name__,
- fields=", ".join(fields)
- )
-
- @staticmethod
- @abstractmethod
- def _reduce_values(values):
- """
- If a user accesses multidict["foo"], this method
- reduces all values for "foo" to a single value that is returned.
- For example, HTTP headers are folded, whereas we will just take
- the first cookie we found with that name.
- """
-
- @staticmethod
- @abstractmethod
- def _kconv(key):
- """
- This method converts a key to its canonical representation.
- For example, HTTP headers are case-insensitive, so this method returns key.lower().
- """
-
- def __getitem__(self, key):
- values = self.get_all(key)
- if not values:
- raise KeyError(key)
- return self._reduce_values(values)
-
- def __setitem__(self, key, value):
- self.set_all(key, [value])
-
- def __delitem__(self, key):
- if key not in self:
- raise KeyError(key)
- key = self._kconv(key)
- self.fields = tuple(
- field for field in self.fields
- if key != self._kconv(field[0])
- )
-
- def __iter__(self):
- seen = set()
- for key, _ in self.fields:
- key_kconv = self._kconv(key)
- if key_kconv not in seen:
- seen.add(key_kconv)
- yield key
-
- def __len__(self):
- return len(set(self._kconv(key) for key, _ in self.fields))
-
- def __eq__(self, other):
- if isinstance(other, MultiDict):
- return self.fields == other.fields
- return False
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def get_all(self, key):
- """
- Return the list of all values for a given key.
- If that key is not in the MultiDict, the return value will be an empty list.
- """
- key = self._kconv(key)
- return [
- value
- for k, value in self.fields
- if self._kconv(k) == key
- ]
-
- def set_all(self, key, values):
- """
- Remove the old values for a key and add new ones.
- """
- key_kconv = self._kconv(key)
-
- new_fields = []
- for field in self.fields:
- if self._kconv(field[0]) == key_kconv:
- if values:
- new_fields.append(
- (field[0], values.pop(0))
- )
- else:
- new_fields.append(field)
- while values:
- new_fields.append(
- (key, values.pop(0))
- )
- self.fields = tuple(new_fields)
-
- def add(self, key, value):
- """
- Add an additional value for the given key at the bottom.
- """
- self.insert(len(self.fields), key, value)
-
- def insert(self, index, key, value):
- """
- Insert an additional value for the given key at the specified position.
- """
- item = (key, value)
- self.fields = self.fields[:index] + (item,) + self.fields[index:]
-
- def keys(self, multi=False):
- """
- Get all keys.
-
- Args:
- multi(bool):
- If True, one key per value will be returned.
- If False, duplicate keys will only be returned once.
- """
- return (
- k
- for k, _ in self.items(multi)
- )
-
- def values(self, multi=False):
- """
- Get all values.
-
- Args:
- multi(bool):
- If True, all values will be returned.
- If False, only the first value per key will be returned.
- """
- return (
- v
- for _, v in self.items(multi)
- )
-
- def items(self, multi=False):
- """
- Get all (key, value) tuples.
-
- Args:
- multi(bool):
- If True, all (key, value) pairs will be returned
- If False, only the first (key, value) pair per unique key will be returned.
- """
- if multi:
- return self.fields
- else:
- return super().items()
-
- def collect(self):
- """
- Returns a list of (key, value) tuples, where values are either
- singular if there is only one matching item for a key, or a list
- if there are more than one. The order of the keys matches the order
- in the underlying fields list.
- """
- coll = []
- for key in self:
- values = self.get_all(key)
- if len(values) == 1:
- coll.append([key, values[0]])
- else:
- coll.append([key, values])
- return coll
-
- def to_dict(self):
- """
- Get the MultiDict as a plain Python dict.
- Keys with multiple values are returned as lists.
-
- Example:
-
- .. code-block:: python
-
- # Simple dict with duplicate values.
- >>> d = MultiDict([("name", "value"), ("a", False), ("a", 42)])
- >>> d.to_dict()
- {
- "name": "value",
- "a": [False, 42]
- }
- """
- return {
- k: v for k, v in self.collect()
- }
-
- def get_state(self):
- return self.fields
-
- def set_state(self, state):
- self.fields = tuple(tuple(x) for x in state)
-
- @classmethod
- def from_state(cls, state):
- return cls(state)
-
-
-class MultiDict(_MultiDict):
- def __init__(self, fields=()):
- super().__init__()
- self.fields = tuple(
- tuple(i) for i in fields
- )
-
- @staticmethod
- def _reduce_values(values):
- return values[0]
-
- @staticmethod
- def _kconv(key):
- return key
-
-
-class ImmutableMultiDict(MultiDict, metaclass=ABCMeta):
- def _immutable(self, *_):
- raise TypeError('{} objects are immutable'.format(self.__class__.__name__))
-
- __delitem__ = set_all = insert = _immutable
-
- def __hash__(self):
- return hash(self.fields)
-
- def with_delitem(self, key):
- """
- Returns:
- An updated ImmutableMultiDict. The original object will not be modified.
- """
- ret = self.copy()
- # FIXME: This is filthy...
- super(ImmutableMultiDict, ret).__delitem__(key)
- return ret
-
- def with_set_all(self, key, values):
- """
- Returns:
- An updated ImmutableMultiDict. The original object will not be modified.
- """
- ret = self.copy()
- # FIXME: This is filthy...
- super(ImmutableMultiDict, ret).set_all(key, values)
- return ret
-
- def with_insert(self, index, key, value):
- """
- Returns:
- An updated ImmutableMultiDict. The original object will not be modified.
- """
- ret = self.copy()
- # FIXME: This is filthy...
- super(ImmutableMultiDict, ret).insert(index, key, value)
- return ret
-
-
-class MultiDictView(_MultiDict):
- """
- The MultiDictView provides the MultiDict interface over calculated data.
- The view itself contains no state - data is retrieved from the parent on
- request, and stored back to the parent on change.
- """
- def __init__(self, getter, setter):
- self._getter = getter
- self._setter = setter
- super().__init__()
-
- @staticmethod
- def _kconv(key):
- # All request-attributes are case-sensitive.
- return key
-
- @staticmethod
- def _reduce_values(values):
- # We just return the first element if
- # multiple elements exist with the same key.
- return values[0]
-
- @property
- def fields(self):
- return self._getter()
-
- @fields.setter
- def fields(self, value):
- self._setter(value)
diff --git a/netlib/socks.py b/netlib/socks.py
deleted file mode 100644
index 9f1adb98..00000000
--- a/netlib/socks.py
+++ /dev/null
@@ -1,232 +0,0 @@
-import struct
-import array
-import ipaddress
-
-from netlib import tcp, utils
-
-
-class SocksError(Exception):
- def __init__(self, code, message):
- super().__init__(message)
- self.code = code
-
-VERSION = utils.BiDi(
- SOCKS4=0x04,
- SOCKS5=0x05
-)
-
-CMD = utils.BiDi(
- CONNECT=0x01,
- BIND=0x02,
- UDP_ASSOCIATE=0x03
-)
-
-ATYP = utils.BiDi(
- IPV4_ADDRESS=0x01,
- DOMAINNAME=0x03,
- IPV6_ADDRESS=0x04
-)
-
-REP = utils.BiDi(
- SUCCEEDED=0x00,
- GENERAL_SOCKS_SERVER_FAILURE=0x01,
- CONNECTION_NOT_ALLOWED_BY_RULESET=0x02,
- NETWORK_UNREACHABLE=0x03,
- HOST_UNREACHABLE=0x04,
- CONNECTION_REFUSED=0x05,
- TTL_EXPIRED=0x06,
- COMMAND_NOT_SUPPORTED=0x07,
- ADDRESS_TYPE_NOT_SUPPORTED=0x08,
-)
-
-METHOD = utils.BiDi(
- NO_AUTHENTICATION_REQUIRED=0x00,
- GSSAPI=0x01,
- USERNAME_PASSWORD=0x02,
- NO_ACCEPTABLE_METHODS=0xFF
-)
-
-USERNAME_PASSWORD_VERSION = utils.BiDi(
- DEFAULT=0x01
-)
-
-
-class ClientGreeting:
- __slots__ = ("ver", "methods")
-
- def __init__(self, ver, methods):
- self.ver = ver
- self.methods = array.array("B")
- self.methods.extend(methods)
-
- def assert_socks5(self):
- if self.ver != VERSION.SOCKS5:
- if self.ver == ord("G") and len(self.methods) == ord("E"):
- guess = "Probably not a SOCKS request but a regular HTTP request. "
- else:
- guess = ""
-
- raise SocksError(
- REP.GENERAL_SOCKS_SERVER_FAILURE,
- guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver
- )
-
- @classmethod
- def from_file(cls, f, fail_early=False):
- """
- :param fail_early: If true, a SocksError will be raised if the first byte does not indicate socks5.
- """
- ver, nmethods = struct.unpack("!BB", f.safe_read(2))
- client_greeting = cls(ver, [])
- if fail_early:
- client_greeting.assert_socks5()
- client_greeting.methods.fromstring(f.safe_read(nmethods))
- return client_greeting
-
- def to_file(self, f):
- f.write(struct.pack("!BB", self.ver, len(self.methods)))
- f.write(self.methods.tostring())
-
-
-class ServerGreeting:
- __slots__ = ("ver", "method")
-
- def __init__(self, ver, method):
- self.ver = ver
- self.method = method
-
- def assert_socks5(self):
- if self.ver != VERSION.SOCKS5:
- if self.ver == ord("H") and self.method == ord("T"):
- guess = "Probably not a SOCKS request but a regular HTTP response. "
- else:
- guess = ""
-
- raise SocksError(
- REP.GENERAL_SOCKS_SERVER_FAILURE,
- guess + "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver
- )
-
- @classmethod
- def from_file(cls, f):
- ver, method = struct.unpack("!BB", f.safe_read(2))
- return cls(ver, method)
-
- def to_file(self, f):
- f.write(struct.pack("!BB", self.ver, self.method))
-
-
-class UsernamePasswordAuth:
- __slots__ = ("ver", "username", "password")
-
- def __init__(self, ver, username, password):
- self.ver = ver
- self.username = username
- self.password = password
-
- def assert_authver1(self):
- if self.ver != USERNAME_PASSWORD_VERSION.DEFAULT:
- raise SocksError(
- 0,
- "Invalid auth version. Expected 0x01, got 0x%x" % self.ver
- )
-
- @classmethod
- def from_file(cls, f):
- ver, ulen = struct.unpack("!BB", f.safe_read(2))
- username = f.safe_read(ulen)
- plen, = struct.unpack("!B", f.safe_read(1))
- password = f.safe_read(plen)
- return cls(ver, username.decode(), password.decode())
-
- def to_file(self, f):
- f.write(struct.pack("!BB", self.ver, len(self.username)))
- f.write(self.username.encode())
- f.write(struct.pack("!B", len(self.password)))
- f.write(self.password.encode())
-
-
-class UsernamePasswordAuthResponse:
- __slots__ = ("ver", "status")
-
- def __init__(self, ver, status):
- self.ver = ver
- self.status = status
-
- def assert_authver1(self):
- if self.ver != USERNAME_PASSWORD_VERSION.DEFAULT:
- raise SocksError(
- 0,
- "Invalid auth version. Expected 0x01, got 0x%x" % self.ver
- )
-
- @classmethod
- def from_file(cls, f):
- ver, status = struct.unpack("!BB", f.safe_read(2))
- return cls(ver, status)
-
- def to_file(self, f):
- f.write(struct.pack("!BB", self.ver, self.status))
-
-
-class Message:
- __slots__ = ("ver", "msg", "atyp", "addr")
-
- def __init__(self, ver, msg, atyp, addr):
- self.ver = ver
- self.msg = msg
- self.atyp = atyp
- self.addr = tcp.Address.wrap(addr)
-
- def assert_socks5(self):
- if self.ver != VERSION.SOCKS5:
- raise SocksError(
- REP.GENERAL_SOCKS_SERVER_FAILURE,
- "Invalid SOCKS version. Expected 0x05, got 0x%x" % self.ver
- )
-
- @classmethod
- def from_file(cls, f):
- ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4))
- if rsv != 0x00:
- raise SocksError(
- REP.GENERAL_SOCKS_SERVER_FAILURE,
- "Socks Request: Invalid reserved byte: %s" % rsv
- )
- if atyp == ATYP.IPV4_ADDRESS:
-            # We use ntoa here as ntop is not commonly available on Windows.
- host = ipaddress.IPv4Address(f.safe_read(4)).compressed
- use_ipv6 = False
- elif atyp == ATYP.IPV6_ADDRESS:
- host = ipaddress.IPv6Address(f.safe_read(16)).compressed
- use_ipv6 = True
- elif atyp == ATYP.DOMAINNAME:
- length, = struct.unpack("!B", f.safe_read(1))
- host = f.safe_read(length)
- if not utils.is_valid_host(host):
- raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, "Invalid hostname: %s" % host)
- host = host.decode("idna")
- use_ipv6 = False
- else:
- raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,
- "Socks Request: Unknown ATYP: %s" % atyp)
-
- port, = struct.unpack("!H", f.safe_read(2))
- addr = tcp.Address((host, port), use_ipv6=use_ipv6)
- return cls(ver, msg, atyp, addr)
-
- def to_file(self, f):
- f.write(struct.pack("!BBBB", self.ver, self.msg, 0x00, self.atyp))
- if self.atyp == ATYP.IPV4_ADDRESS:
- f.write(ipaddress.IPv4Address(self.addr.host).packed)
- elif self.atyp == ATYP.IPV6_ADDRESS:
- f.write(ipaddress.IPv6Address(self.addr.host).packed)
- elif self.atyp == ATYP.DOMAINNAME:
- f.write(struct.pack("!B", len(self.addr.host)))
- f.write(self.addr.host.encode("idna"))
- else:
- raise SocksError(
- REP.ADDRESS_TYPE_NOT_SUPPORTED,
- "Unknown ATYP: %s" % self.atyp
- )
- f.write(struct.pack("!H", self.addr.port))
diff --git a/netlib/strutils.py b/netlib/strutils.py
deleted file mode 100644
index 57cfbc79..00000000
--- a/netlib/strutils.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import re
-import codecs
-
-
-def always_bytes(unicode_or_bytes, *encode_args):
- if isinstance(unicode_or_bytes, str):
- return unicode_or_bytes.encode(*encode_args)
- elif isinstance(unicode_or_bytes, bytes) or unicode_or_bytes is None:
- return unicode_or_bytes
- else:
- raise TypeError("Expected str or bytes, but got {}.".format(type(unicode_or_bytes).__name__))
-
-
-def native(s, *encoding_opts):
- """
- Convert :py:class:`bytes` or :py:class:`unicode` to the native
- :py:class:`str` type, using latin1 encoding if conversion is necessary.
-
- https://www.python.org/dev/peps/pep-3333/#a-note-on-string-types
- """
- if not isinstance(s, (bytes, str)):
- raise TypeError("%r is neither bytes nor unicode" % s)
- if isinstance(s, bytes):
- return s.decode(*encoding_opts)
- return s
-
-
-# Translate control characters to "safe" characters. This implementation initially
-# replaced them with the matching control pictures (http://unicode.org/charts/PDF/U2400.pdf),
-# but that turned out to render badly with monospace fonts. We are back to "." therefore.
-_control_char_trans = {
- x: ord(".") # x + 0x2400 for unicode control group pictures
- for x in range(32)
-}
-_control_char_trans[127] = ord(".") # 0x2421
-_control_char_trans_newline = _control_char_trans.copy()
-for x in ("\r", "\n", "\t"):
- del _control_char_trans_newline[ord(x)]
-
-
-_control_char_trans = str.maketrans(_control_char_trans)
-_control_char_trans_newline = str.maketrans(_control_char_trans_newline)
-
-
-def escape_control_characters(text: str, keep_spacing=True) -> str:
- """
-    Replace all unicode C0 control characters (and DEL) in the given text with a single "."
-
- Args:
- keep_spacing: If True, tabs and newlines will not be replaced.
- """
- if not isinstance(text, str):
- raise ValueError("text type must be unicode but is {}".format(type(text).__name__))
-
- trans = _control_char_trans_newline if keep_spacing else _control_char_trans
- return text.translate(trans)
-
-
-def bytes_to_escaped_str(data, keep_spacing=False, escape_single_quotes=False):
- """
- Take bytes and return a safe string that can be displayed to the user.
-
- Single quotes are always escaped, double quotes are never escaped:
- "'" + bytes_to_escaped_str(...) + "'"
- gives a valid Python string.
-
- Args:
- keep_spacing: If True, tabs and newlines will not be escaped.
- """
-
- if not isinstance(data, bytes):
- raise ValueError("data must be bytes, but is {}".format(data.__class__.__name__))
- # We always insert a double-quote here so that we get a single-quoted string back
- # https://stackoverflow.com/questions/29019340/why-does-python-use-different-quotes-for-representing-strings-depending-on-their
- ret = repr(b'"' + data).lstrip("b")[2:-1]
- if not escape_single_quotes:
- ret = re.sub(r"(?<!\\)(\\\\)*\\'", lambda m: (m.group(1) or "") + "'", ret)
- if keep_spacing:
- ret = re.sub(
- r"(?<!\\)(\\\\)*\\([nrt])",
- lambda m: (m.group(1) or "") + dict(n="\n", r="\r", t="\t")[m.group(2)],
- ret
- )
- return ret
-
-
-def escaped_str_to_bytes(data):
- """
- Take an escaped string and return the unescaped bytes equivalent.
-
- Raises:
- ValueError, if the escape sequence is invalid.
- """
- if not isinstance(data, str):
- raise ValueError("data must be str, but is {}".format(data.__class__.__name__))
-
- # This one is difficult - we use an undocumented Python API here
- # as per http://stackoverflow.com/a/23151714/934719
- return codecs.escape_decode(data)[0]
-
-
-def is_mostly_bin(s: bytes) -> bool:
- if not s or len(s) == 0:
- return False
-
- return sum(
- i < 9 or 13 < i < 32 or 126 < i
- for i in s[:100]
- ) / len(s[:100]) > 0.3
-
-
-def is_xml(s: bytes) -> bool:
- return s.strip().startswith(b"<")
-
-
-def clean_hanging_newline(t):
- """
- Many editors will silently add a newline to the final line of a
- document (I'm looking at you, Vim). This function fixes this common
- problem at the risk of removing a hanging newline in the rare cases
- where the user actually intends it.
- """
- if t and t[-1] == "\n":
- return t[:-1]
- return t
-
-
-def hexdump(s):
- """
- Returns:
- A generator of (offset, hex, str) tuples
- """
- for i in range(0, len(s), 16):
- offset = "{:0=10x}".format(i)
- part = s[i:i + 16]
- x = " ".join("{:0=2x}".format(i) for i in part)
- x = x.ljust(47) # 16*2 + 15
- part_repr = native(escape_control_characters(
- part.decode("ascii", "replace").replace(u"\ufffd", u"."),
- False
- ))
- yield (offset, x, part_repr)
diff --git a/netlib/tcp.py b/netlib/tcp.py
deleted file mode 100644
index aeb1d447..00000000
--- a/netlib/tcp.py
+++ /dev/null
@@ -1,989 +0,0 @@
-import os
-import select
-import socket
-import sys
-import threading
-import time
-import traceback
-
-import binascii
-
-from typing import Optional # noqa
-
-from netlib import strutils
-
-import certifi
-from backports import ssl_match_hostname
-import OpenSSL
-from OpenSSL import SSL
-
-from netlib import certutils
-from netlib import version_check
-from netlib import basetypes
-from netlib import exceptions
-from netlib import basethread
-
-# This is a rather hackish way to make sure that
-# the latest version of pyOpenSSL is actually installed.
-version_check.check_pyopenssl_version()
-
-socket_fileobject = socket.SocketIO
-
-EINTR = 4
-if os.environ.get("NO_ALPN"):
- HAS_ALPN = False
-else:
- HAS_ALPN = SSL._lib.Cryptography_HAS_ALPN
-
-# To enable all SSL methods use: SSLv23
-# then add options to disable certain methods
-# https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
-SSL_BASIC_OPTIONS = (
- SSL.OP_CIPHER_SERVER_PREFERENCE
-)
-if hasattr(SSL, "OP_NO_COMPRESSION"):
- SSL_BASIC_OPTIONS |= SSL.OP_NO_COMPRESSION
-
-SSL_DEFAULT_METHOD = SSL.SSLv23_METHOD
-SSL_DEFAULT_OPTIONS = (
- SSL.OP_NO_SSLv2 |
- SSL.OP_NO_SSLv3 |
- SSL_BASIC_OPTIONS
-)
-if hasattr(SSL, "OP_NO_COMPRESSION"):
- SSL_DEFAULT_OPTIONS |= SSL.OP_NO_COMPRESSION
-
-"""
-Map a reasonable SSL version specification into the format OpenSSL expects.
-Don't ask...
-https://bugs.launchpad.net/pyopenssl/+bug/1020632/comments/3
-"""
-sslversion_choices = {
- "all": (SSL.SSLv23_METHOD, SSL_BASIC_OPTIONS),
- # SSLv23_METHOD + NO_SSLv2 + NO_SSLv3 == TLS 1.0+
- # TLSv1_METHOD would be TLS 1.0 only
- "secure": (SSL.SSLv23_METHOD, (SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3 | SSL_BASIC_OPTIONS)),
- "SSLv2": (SSL.SSLv2_METHOD, SSL_BASIC_OPTIONS),
- "SSLv3": (SSL.SSLv3_METHOD, SSL_BASIC_OPTIONS),
- "TLSv1": (SSL.TLSv1_METHOD, SSL_BASIC_OPTIONS),
- "TLSv1_1": (SSL.TLSv1_1_METHOD, SSL_BASIC_OPTIONS),
- "TLSv1_2": (SSL.TLSv1_2_METHOD, SSL_BASIC_OPTIONS),
-}
-
-
-class SSLKeyLogger:
-
- def __init__(self, filename):
- self.filename = filename
- self.f = None
- self.lock = threading.Lock()
-
- # required for functools.wraps, which pyOpenSSL uses.
- __name__ = "SSLKeyLogger"
-
- def __call__(self, connection, where, ret):
- if where == SSL.SSL_CB_HANDSHAKE_DONE and ret == 1:
- with self.lock:
- if not self.f:
- d = os.path.dirname(self.filename)
- if not os.path.isdir(d):
- os.makedirs(d)
- self.f = open(self.filename, "ab")
- self.f.write(b"\r\n")
- client_random = binascii.hexlify(connection.client_random())
- masterkey = binascii.hexlify(connection.master_key())
- self.f.write(b"CLIENT_RANDOM %s %s\r\n" % (client_random, masterkey))
- self.f.flush()
-
- def close(self):
- with self.lock:
- if self.f:
- self.f.close()
-
- @staticmethod
- def create_logfun(filename):
- if filename:
- return SSLKeyLogger(filename)
- return False
-
-log_ssl_key = SSLKeyLogger.create_logfun(
- os.getenv("MITMPROXY_SSLKEYLOGFILE") or os.getenv("SSLKEYLOGFILE"))
-
-
-class _FileLike:
- BLOCKSIZE = 1024 * 32
-
- def __init__(self, o):
- self.o = o
- self._log = None
- self.first_byte_timestamp = None
-
- def set_descriptor(self, o):
- self.o = o
-
- def __getattr__(self, attr):
- return getattr(self.o, attr)
-
- def start_log(self):
- """
- Starts or resets the log.
-
- This will store all bytes read or written.
- """
- self._log = []
-
- def stop_log(self):
- """
- Stops the log.
- """
- self._log = None
-
- def is_logging(self):
- return self._log is not None
-
- def get_log(self):
- """
- Returns the log as a string.
- """
- if not self.is_logging():
- raise ValueError("Not logging!")
- return b"".join(self._log)
-
- def add_log(self, v):
- if self.is_logging():
- self._log.append(v)
-
- def reset_timestamps(self):
- self.first_byte_timestamp = None
-
-
-class Writer(_FileLike):
-
- def flush(self):
- """
- May raise exceptions.TcpDisconnect
- """
- if hasattr(self.o, "flush"):
- try:
- self.o.flush()
- except (socket.error, IOError) as v:
- raise exceptions.TcpDisconnect(str(v))
-
- def write(self, v):
- """
- May raise exceptions.TcpDisconnect
- """
- if v:
- self.first_byte_timestamp = self.first_byte_timestamp or time.time()
- try:
- if hasattr(self.o, "sendall"):
- self.add_log(v)
- return self.o.sendall(v)
- else:
- r = self.o.write(v)
- self.add_log(v[:r])
- return r
- except (SSL.Error, socket.error) as e:
- raise exceptions.TcpDisconnect(str(e))
-
-
-class Reader(_FileLike):
-
- def read(self, length):
- """
- If length is -1, we read until connection closes.
- """
- result = b''
- start = time.time()
- while length == -1 or length > 0:
- if length == -1 or length > self.BLOCKSIZE:
- rlen = self.BLOCKSIZE
- else:
- rlen = length
- try:
- data = self.o.read(rlen)
- except SSL.ZeroReturnError:
- # TLS connection was shut down cleanly
- break
- except (SSL.WantWriteError, SSL.WantReadError):
- # From the OpenSSL docs:
- # If the underlying BIO is non-blocking, SSL_read() will also return when the
- # underlying BIO could not satisfy the needs of SSL_read() to continue the
- # operation. In this case a call to SSL_get_error with the return value of
- # SSL_read() will yield SSL_ERROR_WANT_READ or SSL_ERROR_WANT_WRITE.
- if (time.time() - start) < self.o.gettimeout():
- time.sleep(0.1)
- continue
- else:
- raise exceptions.TcpTimeout()
- except socket.timeout:
- raise exceptions.TcpTimeout()
- except socket.error as e:
- raise exceptions.TcpDisconnect(str(e))
- except SSL.SysCallError as e:
- if e.args == (-1, 'Unexpected EOF'):
- break
- raise exceptions.TlsException(str(e))
- except SSL.Error as e:
- raise exceptions.TlsException(str(e))
- self.first_byte_timestamp = self.first_byte_timestamp or time.time()
- if not data:
- break
- result += data
- if length != -1:
- length -= len(data)
- self.add_log(result)
- return result
-
- def readline(self, size=None):
- result = b''
- bytes_read = 0
- while True:
- if size is not None and bytes_read >= size:
- break
- ch = self.read(1)
- bytes_read += 1
- if not ch:
- break
- else:
- result += ch
- if ch == b'\n':
- break
- return result
-
- def safe_read(self, length):
- """
- Like .read, but is guaranteed to either return length bytes, or
- raise an exception.
- """
- result = self.read(length)
- if length != -1 and len(result) != length:
- if not result:
- raise exceptions.TcpDisconnect()
- else:
- raise exceptions.TcpReadIncomplete(
- "Expected %s bytes, got %s" % (length, len(result))
- )
- return result
-
- def peek(self, length):
- """
- Tries to peek into the underlying file object.
-
- Returns:
- Up to the next N bytes if peeking is successful.
-
- Raises:
- exceptions.TcpException if there was an error with the socket
- TlsException if there was an error with pyOpenSSL.
- NotImplementedError if the underlying file object is not a [pyOpenSSL] socket
- """
- if isinstance(self.o, socket_fileobject):
- try:
- return self.o._sock.recv(length, socket.MSG_PEEK)
- except socket.error as e:
- raise exceptions.TcpException(repr(e))
- elif isinstance(self.o, SSL.Connection):
- try:
- return self.o.recv(length, socket.MSG_PEEK)
- except SSL.Error as e:
- raise exceptions.TlsException(str(e))
- else:
- raise NotImplementedError("Can only peek into (pyOpenSSL) sockets")
-
-
-class Address(basetypes.Serializable):
-
- """
- This class wraps an IPv4/IPv6 tuple to provide named attributes and
- ipv6 information.
- """
-
- def __init__(self, address, use_ipv6=False):
- self.address = tuple(address)
- self.use_ipv6 = use_ipv6
-
- def get_state(self):
- return {
- "address": self.address,
- "use_ipv6": self.use_ipv6
- }
-
- def set_state(self, state):
- self.address = state["address"]
- self.use_ipv6 = state["use_ipv6"]
-
- @classmethod
- def from_state(cls, state):
- return Address(**state)
-
- @classmethod
- def wrap(cls, t):
- if isinstance(t, cls):
- return t
- else:
- return cls(t)
-
- def __call__(self):
- return self.address
-
- @property
- def host(self):
- return self.address[0]
-
- @property
- def port(self):
- return self.address[1]
-
- @property
- def use_ipv6(self):
- return self.family == socket.AF_INET6
-
- @use_ipv6.setter
- def use_ipv6(self, b):
- self.family = socket.AF_INET6 if b else socket.AF_INET
-
- def __repr__(self):
- return "{}:{}".format(self.host, self.port)
-
- def __eq__(self, other):
- if not other:
- return False
- other = Address.wrap(other)
- return (self.address, self.family) == (other.address, other.family)
-
- def __ne__(self, other):
- return not self.__eq__(other)
-
- def __hash__(self):
- return hash(self.address) ^ 42 # different hash than the tuple alone.
-
-
-def ssl_read_select(rlist, timeout):
- """
- This is a wrapper around select.select() which also works for SSL.Connections
- by taking ssl_connection.pending() into account.
-
- Caveats:
- If .pending() > 0 for any of the connections in rlist, we avoid the select syscall
- and **will not include any other connections which may or may not be ready**.
-
- Args:
- rlist: wait until ready for reading
-
- Returns:
- subset of rlist which is ready for reading.
- """
- return [
- conn for conn in rlist
- if isinstance(conn, SSL.Connection) and conn.pending() > 0
- ] or select.select(rlist, (), (), timeout)[0]
-
-
-def close_socket(sock):
- """
- Does a hard close of a socket, without emitting a RST.
- """
- try:
- # We already indicate that we close our end.
- # may raise "Transport endpoint is not connected" on Linux
- sock.shutdown(socket.SHUT_WR)
-
- # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending
- # readable data could lead to an immediate RST being sent (which is the
- # case on Windows).
- # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html
- #
- # This in turn results in the following issue: If we send an error page
- # to the client and then close the socket, the RST may be received by
- # the client before the error page and the users sees a connection
- # error rather than the error page. Thus, we try to empty the read
- # buffer on Windows first. (see
- # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)
- #
-
- if os.name == "nt": # pragma: no cover
- # We cannot rely on the shutdown()-followed-by-read()-eof technique
- # proposed by the page above: Some remote machines just don't send
- # a TCP FIN, which would leave us in the unfortunate situation that
- # recv() would block infinitely. As a workaround, we set a timeout
- # here even if we are in blocking mode.
- sock.settimeout(sock.gettimeout() or 20)
-
- # limit at a megabyte so that we don't read infinitely
- for _ in range(1024 ** 3 // 4096):
- # may raise a timeout/disconnect exception.
- if not sock.recv(4096):
- break
-
- # Now we can close the other half as well.
- sock.shutdown(socket.SHUT_RD)
-
- except socket.error:
- pass
-
- sock.close()
-
-
-class _Connection:
-
- rbufsize = -1
- wbufsize = -1
-
- def _makefile(self):
- """
- Set up .rfile and .wfile attributes from .connection
- """
- # Ideally, we would use the Buffered IO in Python 3 by default.
- # Unfortunately, the implementation of .peek() is broken for n>1 bytes,
- # as it may just return what's left in the buffer and not all the bytes we want.
- # As a workaround, we just use unbuffered sockets directly.
- # https://mail.python.org/pipermail/python-dev/2009-June/089986.html
- self.rfile = Reader(socket.SocketIO(self.connection, "rb"))
- self.wfile = Writer(socket.SocketIO(self.connection, "wb"))
-
- def __init__(self, connection):
- if connection:
- self.connection = connection
- self.ip_address = Address(connection.getpeername())
- self._makefile()
- else:
- self.connection = None
- self.ip_address = None
- self.rfile = None
- self.wfile = None
-
- self.ssl_established = False
- self.finished = False
-
- def get_current_cipher(self):
- if not self.ssl_established:
- return None
-
- name = self.connection.get_cipher_name()
- bits = self.connection.get_cipher_bits()
- version = self.connection.get_cipher_version()
- return name, bits, version
-
- def finish(self):
- self.finished = True
- # If we have an SSL connection, wfile.close == connection.close
- # (We call _FileLike.set_descriptor(conn))
- # Closing the socket is not our task, therefore we don't call close
- # then.
- if not isinstance(self.connection, SSL.Connection):
- if not getattr(self.wfile, "closed", False):
- try:
- self.wfile.flush()
- self.wfile.close()
- except exceptions.TcpDisconnect:
- pass
-
- self.rfile.close()
- else:
- try:
- self.connection.shutdown()
- except SSL.Error:
- pass
-
- def _create_ssl_context(self,
- method=SSL_DEFAULT_METHOD,
- options=SSL_DEFAULT_OPTIONS,
- verify_options=SSL.VERIFY_NONE,
- ca_path=None,
- ca_pemfile=None,
- cipher_list=None,
- alpn_protos=None,
- alpn_select=None,
- alpn_select_callback=None,
- sni=None,
- ):
- """
- Creates an SSL Context.
-
- :param method: One of SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD, TLSv1_1_METHOD, or TLSv1_2_METHOD
- :param options: A bit field consisting of OpenSSL.SSL.OP_* values
- :param verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values
- :param ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool
- :param ca_pemfile: Path to a PEM formatted trusted CA certificate
- :param cipher_list: A textual OpenSSL cipher list, see https://www.openssl.org/docs/apps/ciphers.html
- :rtype : SSL.Context
- """
- context = SSL.Context(method)
- # Options (NO_SSLv2/3)
- if options is not None:
- context.set_options(options)
-
- # Verify Options (NONE/PEER and trusted CAs)
- if verify_options is not None:
- def verify_cert(conn, x509, errno, err_depth, is_cert_verified):
- if not is_cert_verified:
- self.ssl_verification_error = exceptions.InvalidCertificateException(
- "Certificate Verification Error for {}: {} (errno: {}, depth: {})".format(
- sni,
- strutils.native(SSL._ffi.string(SSL._lib.X509_verify_cert_error_string(errno)), "utf8"),
- errno,
- err_depth
- )
- )
- return is_cert_verified
-
- context.set_verify(verify_options, verify_cert)
- if ca_path is None and ca_pemfile is None:
- ca_pemfile = certifi.where()
- context.load_verify_locations(ca_pemfile, ca_path)
-
- # Workaround for
- # https://github.com/pyca/pyopenssl/issues/190
- # https://github.com/mitmproxy/mitmproxy/issues/472
- # Options already set before are not cleared.
- context.set_mode(SSL._lib.SSL_MODE_AUTO_RETRY)
-
- # Cipher List
- if cipher_list:
- try:
- context.set_cipher_list(cipher_list)
-
- # TODO: maybe change this to with newer pyOpenSSL APIs
- context.set_tmp_ecdh(OpenSSL.crypto.get_elliptic_curve('prime256v1'))
- except SSL.Error as v:
- raise exceptions.TlsException("SSL cipher specification error: %s" % str(v))
-
- # SSLKEYLOGFILE
- if log_ssl_key:
- context.set_info_callback(log_ssl_key)
-
- if HAS_ALPN:
- if alpn_protos is not None:
- # advertise application layer protocols
- context.set_alpn_protos(alpn_protos)
- elif alpn_select is not None and alpn_select_callback is None:
- # select application layer protocol
- def alpn_select_callback(conn_, options):
- if alpn_select in options:
- return bytes(alpn_select)
- else: # pragma no cover
- return options[0]
- context.set_alpn_select_callback(alpn_select_callback)
- elif alpn_select_callback is not None and alpn_select is None:
- context.set_alpn_select_callback(alpn_select_callback)
- elif alpn_select_callback is not None and alpn_select is not None:
- raise exceptions.TlsException("ALPN error: only define alpn_select (string) OR alpn_select_callback (method).")
-
- return context
-
-
-class ConnectionCloser:
- def __init__(self, conn):
- self.conn = conn
- self._canceled = False
-
- def pop(self):
- """
- Cancel the current closer, and return a fresh one.
- """
- self._canceled = True
- return ConnectionCloser(self.conn)
-
- def __enter__(self):
- return self
-
- def __exit__(self, *args):
- if not self._canceled:
- self.conn.close()
-
-
-class TCPClient(_Connection):
-
- def __init__(self, address, source_address=None, spoof_source_address=None):
- super().__init__(None)
- self.address = address
- self.source_address = source_address
- self.cert = None
- self.server_certs = []
- self.ssl_verification_error = None # type: Optional[exceptions.InvalidCertificateException]
- self.sni = None
- self.spoof_source_address = spoof_source_address
-
- @property
- def address(self):
- return self.__address
-
- @address.setter
- def address(self, address):
- if address:
- self.__address = Address.wrap(address)
- else:
- self.__address = None
-
- @property
- def source_address(self):
- return self.__source_address
-
- @source_address.setter
- def source_address(self, source_address):
- if source_address:
- self.__source_address = Address.wrap(source_address)
- else:
- self.__source_address = None
-
- def close(self):
- # Make sure to close the real socket, not the SSL proxy.
- # OpenSSL is really good at screwing up, i.e. when trying to recv from a failed connection,
- # it tries to renegotiate...
- if isinstance(self.connection, SSL.Connection):
- close_socket(self.connection._socket)
- else:
- close_socket(self.connection)
-
- def create_ssl_context(self, cert=None, alpn_protos=None, **sslctx_kwargs):
- context = self._create_ssl_context(
- alpn_protos=alpn_protos,
- **sslctx_kwargs)
- # Client Certs
- if cert:
- try:
- context.use_privatekey_file(cert)
- context.use_certificate_file(cert)
- except SSL.Error as v:
- raise exceptions.TlsException("SSL client certificate error: %s" % str(v))
- return context
-
- def convert_to_ssl(self, sni=None, alpn_protos=None, **sslctx_kwargs):
- """
- cert: Path to a file containing both client cert and private key.
-
- options: A bit field consisting of OpenSSL.SSL.OP_* values
- verify_options: A bit field consisting of OpenSSL.SSL.VERIFY_* values
- ca_path: Path to a directory of trusted CA certificates prepared using the c_rehash tool
- ca_pemfile: Path to a PEM formatted trusted CA certificate
- """
- verification_mode = sslctx_kwargs.get('verify_options', None)
- if verification_mode == SSL.VERIFY_PEER and not sni:
- raise exceptions.TlsException("Cannot validate certificate hostname without SNI")
-
- context = self.create_ssl_context(
- alpn_protos=alpn_protos,
- sni=sni,
- **sslctx_kwargs
- )
- self.connection = SSL.Connection(context, self.connection)
- if sni:
- self.sni = sni
- self.connection.set_tlsext_host_name(sni.encode("idna"))
- self.connection.set_connect_state()
- try:
- self.connection.do_handshake()
- except SSL.Error as v:
- if self.ssl_verification_error:
- raise self.ssl_verification_error
- else:
- raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
- else:
- # Fix for pre v1.0 OpenSSL, which doesn't throw an exception on
- # certificate validation failure
- if verification_mode == SSL.VERIFY_PEER and self.ssl_verification_error:
- raise self.ssl_verification_error
-
- self.cert = certutils.SSLCert(self.connection.get_peer_certificate())
-
- # Keep all server certificates in a list
- for i in self.connection.get_peer_cert_chain():
- self.server_certs.append(certutils.SSLCert(i))
-
- # Validate TLS Hostname
- try:
- crt = dict(
- subjectAltName=[("DNS", x.decode("ascii", "strict")) for x in self.cert.altnames]
- )
- if self.cert.cn:
- crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]]
- if sni:
- hostname = sni
- else:
- hostname = "no-hostname"
- ssl_match_hostname.match_hostname(crt, hostname)
- except (ValueError, ssl_match_hostname.CertificateError) as e:
- self.ssl_verification_error = exceptions.InvalidCertificateException(
- "Certificate Verification Error for {}: {}".format(
- sni or repr(self.address),
- str(e)
- )
- )
- if verification_mode == SSL.VERIFY_PEER:
- raise self.ssl_verification_error
-
- self.ssl_established = True
- self.rfile.set_descriptor(self.connection)
- self.wfile.set_descriptor(self.connection)
-
- def makesocket(self):
- # some parties (cuckoo sandbox) need to hook this
- return socket.socket(self.address.family, socket.SOCK_STREAM)
-
- def connect(self):
- try:
- connection = self.makesocket()
-
- if self.spoof_source_address:
- try:
- # 19 is `IP_TRANSPARENT`, which is only available on Python 3.3+ on some OSes
- if not connection.getsockopt(socket.SOL_IP, 19):
- connection.setsockopt(socket.SOL_IP, 19, 1)
- except socket.error as e:
- raise exceptions.TcpException(
- "Failed to spoof the source address: " + e.strerror
- )
- if self.source_address:
- connection.bind(self.source_address())
- connection.connect(self.address())
- self.source_address = Address(connection.getsockname())
- except (socket.error, IOError) as err:
- raise exceptions.TcpException(
- 'Error connecting to "%s": %s' %
- (self.address.host, err)
- )
- self.connection = connection
- self.ip_address = Address(connection.getpeername())
- self._makefile()
- return ConnectionCloser(self)
-
- def settimeout(self, n):
- self.connection.settimeout(n)
-
- def gettimeout(self):
- return self.connection.gettimeout()
-
- def get_alpn_proto_negotiated(self):
- if HAS_ALPN and self.ssl_established:
- return self.connection.get_alpn_proto_negotiated()
- else:
- return b""
-
-
-class BaseHandler(_Connection):
-
- """
- The instantiator is expected to call the handle() and finish() methods.
- """
-
- def __init__(self, connection, address, server):
- super().__init__(connection)
- self.address = Address.wrap(address)
- self.server = server
- self.clientcert = None
-
- def create_ssl_context(self,
- cert, key,
- handle_sni=None,
- request_client_cert=None,
- chain_file=None,
- dhparams=None,
- extra_chain_certs=None,
- **sslctx_kwargs):
- """
- cert: A certutils.SSLCert object or the path to a certificate
- chain file.
-
- handle_sni: SNI handler, should take a connection object. Server
- name can be retrieved like this:
-
- connection.get_servername()
-
- And you can specify the connection keys as follows:
-
- new_context = Context(TLSv1_METHOD)
- new_context.use_privatekey(key)
- new_context.use_certificate(cert)
- connection.set_context(new_context)
-
- The request_client_cert argument requires some explanation. We're
- supposed to be able to do this with no negative effects - if the
- client has no cert to present, we're notified and proceed as usual.
- Unfortunately, Android seems to have a bug (tested on 4.2.2) - when
- an Android client is asked to present a certificate it does not
- have, it hangs up, which is frankly bogus. Some time down the track
- we may be able to make the proper behaviour the default again, but
- until then we're conservative.
- """
-
- context = self._create_ssl_context(ca_pemfile=chain_file, **sslctx_kwargs)
-
- context.use_privatekey(key)
- if isinstance(cert, certutils.SSLCert):
- context.use_certificate(cert.x509)
- else:
- context.use_certificate_chain_file(cert)
-
- if extra_chain_certs:
- for i in extra_chain_certs:
- context.add_extra_chain_cert(i.x509)
-
- if handle_sni:
- # SNI callback happens during do_handshake()
- context.set_tlsext_servername_callback(handle_sni)
-
- if request_client_cert:
- def save_cert(conn_, cert, errno_, depth_, preverify_ok_):
- self.clientcert = certutils.SSLCert(cert)
- # Return true to prevent cert verification error
- return True
- context.set_verify(SSL.VERIFY_PEER, save_cert)
-
- if dhparams:
- SSL._lib.SSL_CTX_set_tmp_dh(context._context, dhparams)
-
- return context
-
- def convert_to_ssl(self, cert, key, **sslctx_kwargs):
- """
- Convert connection to SSL.
- For a list of parameters, see BaseHandler._create_ssl_context(...)
- """
-
- context = self.create_ssl_context(
- cert,
- key,
- **sslctx_kwargs)
- self.connection = SSL.Connection(context, self.connection)
- self.connection.set_accept_state()
- try:
- self.connection.do_handshake()
- except SSL.Error as v:
- raise exceptions.TlsException("SSL handshake error: %s" % repr(v))
- self.ssl_established = True
- self.rfile.set_descriptor(self.connection)
- self.wfile.set_descriptor(self.connection)
-
- def handle(self): # pragma: no cover
- raise NotImplementedError
-
- def settimeout(self, n):
- self.connection.settimeout(n)
-
- def get_alpn_proto_negotiated(self):
- if HAS_ALPN and self.ssl_established:
- return self.connection.get_alpn_proto_negotiated()
- else:
- return b""
-
-
-class Counter:
- def __init__(self):
- self._count = 0
- self._lock = threading.Lock()
-
- @property
- def count(self):
- with self._lock:
- return self._count
-
- def __enter__(self):
- with self._lock:
- self._count += 1
-
- def __exit__(self, *args):
- with self._lock:
- self._count -= 1
-
-
-class TCPServer:
- request_queue_size = 20
-
- def __init__(self, address):
- self.address = Address.wrap(address)
- self.__is_shut_down = threading.Event()
- self.__shutdown_request = False
- self.socket = socket.socket(self.address.family, socket.SOCK_STREAM)
- self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.socket.bind(self.address())
- self.address = Address.wrap(self.socket.getsockname())
- self.socket.listen(self.request_queue_size)
- self.handler_counter = Counter()
-
- def connection_thread(self, connection, client_address):
- with self.handler_counter:
- client_address = Address(client_address)
- try:
- self.handle_client_connection(connection, client_address)
- except:
- self.handle_error(connection, client_address)
- finally:
- close_socket(connection)
-
- def serve_forever(self, poll_interval=0.1):
- self.__is_shut_down.clear()
- try:
- while not self.__shutdown_request:
- try:
- r, w_, e_ = select.select(
- [self.socket], [], [], poll_interval)
- except select.error as ex: # pragma: no cover
- if ex[0] == EINTR:
- continue
- else:
- raise
- if self.socket in r:
- connection, client_address = self.socket.accept()
- t = basethread.BaseThread(
- "TCPConnectionHandler (%s: %s:%s -> %s:%s)" % (
- self.__class__.__name__,
- client_address[0],
- client_address[1],
- self.address.host,
- self.address.port
- ),
- target=self.connection_thread,
- args=(connection, client_address),
- )
- t.setDaemon(1)
- try:
- t.start()
- except threading.ThreadError:
- self.handle_error(connection, Address(client_address))
- connection.close()
- finally:
- self.__shutdown_request = False
- self.__is_shut_down.set()
-
- def shutdown(self):
- self.__shutdown_request = True
- self.__is_shut_down.wait()
- self.socket.close()
- self.handle_shutdown()
-
- def handle_error(self, connection_, client_address, fp=sys.stderr):
- """
- Called when handle_client_connection raises an exception.
- """
- # If a thread has persisted after interpreter exit, the module might be
- # none.
- if traceback:
- exc = str(traceback.format_exc())
- print(u'-' * 40, file=fp)
- print(
- u"Error in processing of request from %s" % repr(client_address), file=fp)
- print(exc, file=fp)
- print(u'-' * 40, file=fp)
-
- def handle_client_connection(self, conn, client_address): # pragma: no cover
- """
- Called after client connection.
- """
- raise NotImplementedError
-
- def handle_shutdown(self):
- """
- Called after server shutdown.
- """
-
- def wait_for_silence(self, timeout=5):
- start = time.time()
- while 1:
- if time.time() - start >= timeout:
- raise exceptions.Timeout(
- "%s service threads still alive" %
- self.handler_counter.count
- )
- if self.handler_counter.count == 0:
- return
diff --git a/netlib/tutils.py b/netlib/tutils.py
deleted file mode 100644
index d22fdd1c..00000000
--- a/netlib/tutils.py
+++ /dev/null
@@ -1,130 +0,0 @@
-from io import BytesIO
-import tempfile
-import os
-import time
-import shutil
-from contextlib import contextmanager
-import sys
-
-from netlib import utils, tcp, http
-
-
-def treader(bytes):
- """
- Construct a tcp.Read object from bytes.
- """
- fp = BytesIO(bytes)
- return tcp.Reader(fp)
-
-
-@contextmanager
-def tmpdir(*args, **kwargs):
- orig_workdir = os.getcwd()
- temp_workdir = tempfile.mkdtemp(*args, **kwargs)
- os.chdir(temp_workdir)
-
- yield temp_workdir
-
- os.chdir(orig_workdir)
- shutil.rmtree(temp_workdir)
-
-
-def _check_exception(expected, actual, exc_tb):
- if isinstance(expected, str):
- if expected.lower() not in str(actual).lower():
- raise AssertionError(
- "Expected %s, but caught %s" % (
- repr(expected), repr(actual)
- )
- )
- else:
- if not isinstance(actual, expected):
- raise AssertionError(
- "Expected %s, but caught %s %s" % (
- expected.__name__, actual.__class__.__name__, repr(actual)
- )
- )
-
-
-def raises(expected_exception, obj=None, *args, **kwargs):
- """
- Assert that a callable raises a specified exception.
-
- :exc An exception class or a string. If a class, assert that an
- exception of this type is raised. If a string, assert that the string
- occurs in the string representation of the exception, based on a
- case-insenstivie match.
-
- :obj A callable object.
-
- :args Arguments to be passsed to the callable.
-
- :kwargs Arguments to be passed to the callable.
- """
- if obj is None:
- return RaisesContext(expected_exception)
- else:
- try:
- ret = obj(*args, **kwargs)
- except Exception as actual:
- _check_exception(expected_exception, actual, sys.exc_info()[2])
- else:
- raise AssertionError("No exception raised. Return value: {}".format(ret))
-
-
-class RaisesContext:
- def __init__(self, expected_exception):
- self.expected_exception = expected_exception
-
- def __enter__(self):
- return
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- if not exc_type:
- raise AssertionError("No exception raised.")
- else:
- _check_exception(self.expected_exception, exc_val, exc_tb)
- return True
-
-
-test_data = utils.Data(__name__)
-# FIXME: Temporary workaround during repo merge.
-test_data.dirname = os.path.join(test_data.dirname, "..", "test", "netlib")
-
-
-def treq(**kwargs):
- """
- Returns:
- netlib.http.Request
- """
- default = dict(
- first_line_format="relative",
- method=b"GET",
- scheme=b"http",
- host=b"address",
- port=22,
- path=b"/path",
- http_version=b"HTTP/1.1",
- headers=http.Headers(((b"header", b"qvalue"), (b"content-length", b"7"))),
- content=b"content"
- )
- default.update(kwargs)
- return http.Request(**default)
-
-
-def tresp(**kwargs):
- """
- Returns:
- netlib.http.Response
- """
- default = dict(
- http_version=b"HTTP/1.1",
- status_code=200,
- reason=b"OK",
- headers=http.Headers(((b"header-response", b"svalue"), (b"content-length", b"7"))),
- content=b"message",
- timestamp_start=time.time(),
- timestamp_end=time.time(),
- )
- default.update(kwargs)
- return http.Response(**default)
diff --git a/netlib/utils.py b/netlib/utils.py
deleted file mode 100644
index 8cd9ba6e..00000000
--- a/netlib/utils.py
+++ /dev/null
@@ -1,97 +0,0 @@
-import os.path
-import re
-import importlib
-import inspect
-
-
-def setbit(byte, offset, value):
- """
- Set a bit in a byte to 1 if value is truthy, 0 if not.
- """
- if value:
- return byte | (1 << offset)
- else:
- return byte & ~(1 << offset)
-
-
-def getbit(byte, offset):
- mask = 1 << offset
- return bool(byte & mask)
-
-
-class BiDi:
-
- """
- A wee utility class for keeping bi-directional mappings, like field
- constants in protocols. Names are attributes on the object, dict-like
- access maps values to names:
-
- CONST = BiDi(a=1, b=2)
- assert CONST.a == 1
- assert CONST.get_name(1) == "a"
- """
-
- def __init__(self, **kwargs):
- self.names = kwargs
- self.values = {}
- for k, v in kwargs.items():
- self.values[v] = k
- if len(self.names) != len(self.values):
- raise ValueError("Duplicate values not allowed.")
-
- def __getattr__(self, k):
- if k in self.names:
- return self.names[k]
- raise AttributeError("No such attribute: %s", k)
-
- def get_name(self, n, default=None):
- return self.values.get(n, default)
-
-
-class Data:
-
- def __init__(self, name):
- m = importlib.import_module(name)
- dirname = os.path.dirname(inspect.getsourcefile(m))
- self.dirname = os.path.abspath(dirname)
-
- def push(self, subpath):
- """
- Change the data object to a path relative to the module.
- """
- self.dirname = os.path.join(self.dirname, subpath)
- return self
-
- def path(self, path):
- """
- Returns a path to the package data housed at 'path' under this
- module.Path can be a path to a file, or to a directory.
-
- This function will raise ValueError if the path does not exist.
- """
- fullpath = os.path.join(self.dirname, path)
- if not os.path.exists(fullpath):
- raise ValueError("dataPath: %s does not exist." % fullpath)
- return fullpath
-
-
-_label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
-
-
-def is_valid_host(host: bytes) -> bool:
- """
- Checks if a hostname is valid.
- """
- try:
- host.decode("idna")
- except ValueError:
- return False
- if len(host) > 255:
- return False
- if host and host[-1:] == b".":
- host = host[:-1]
- return all(_label_valid.match(x) for x in host.split(b"."))
-
-
-def is_valid_port(port):
- return 0 <= port <= 65535
diff --git a/netlib/version.py b/netlib/version.py
deleted file mode 100644
index cb670642..00000000
--- a/netlib/version.py
+++ /dev/null
@@ -1,4 +0,0 @@
-IVERSION = (0, 19)
-VERSION = ".".join(str(i) for i in IVERSION)
-PATHOD = "pathod " + VERSION
-MITMPROXY = "mitmproxy " + VERSION
diff --git a/netlib/version_check.py b/netlib/version_check.py
deleted file mode 100644
index 547c031c..00000000
--- a/netlib/version_check.py
+++ /dev/null
@@ -1,43 +0,0 @@
-"""
-Having installed a wrong version of pyOpenSSL or netlib is unfortunately a
-very common source of error. Check before every start that both versions
-are somewhat okay.
-"""
-import sys
-import inspect
-import os.path
-
-import OpenSSL
-
-PYOPENSSL_MIN_VERSION = (0, 15)
-
-
-def check_pyopenssl_version(min_version=PYOPENSSL_MIN_VERSION, fp=sys.stderr):
- min_version_str = u".".join(str(x) for x in min_version)
- try:
- v = tuple(int(x) for x in OpenSSL.__version__.split(".")[:2])
- except ValueError:
- print(
- u"Cannot parse pyOpenSSL version: {}"
- u"mitmproxy requires pyOpenSSL {} or greater.".format(
- OpenSSL.__version__, min_version_str
- ),
- file=fp
- )
- return
- if v < min_version:
- print(
- u"You are using an outdated version of pyOpenSSL: "
- u"mitmproxy requires pyOpenSSL {} or greater.".format(min_version_str),
- file=fp
- )
- # Some users apparently have multiple versions of pyOpenSSL installed.
- # Report which one we got.
- pyopenssl_path = os.path.dirname(inspect.getfile(OpenSSL))
- print(
- u"Your pyOpenSSL {} installation is located at {}".format(
- OpenSSL.__version__, pyopenssl_path
- ),
- file=fp
- )
- sys.exit(1)
diff --git a/netlib/websockets/__init__.py b/netlib/websockets/__init__.py
deleted file mode 100644
index 2d6f0a0c..00000000
--- a/netlib/websockets/__init__.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from .frame import FrameHeader
-from .frame import Frame
-from .frame import OPCODE
-from .frame import CLOSE_REASON
-from .masker import Masker
-from .utils import MAGIC
-from .utils import VERSION
-from .utils import client_handshake_headers
-from .utils import server_handshake_headers
-from .utils import check_handshake
-from .utils import check_client_version
-from .utils import create_server_nonce
-from .utils import get_extensions
-from .utils import get_protocol
-from .utils import get_client_key
-from .utils import get_server_accept
-
-__all__ = [
- "FrameHeader",
- "Frame",
- "OPCODE",
- "CLOSE_REASON",
- "Masker",
- "MAGIC",
- "VERSION",
- "client_handshake_headers",
- "server_handshake_headers",
- "check_handshake",
- "check_client_version",
- "create_server_nonce",
- "get_extensions",
- "get_protocol",
- "get_client_key",
- "get_server_accept",
-]
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py
deleted file mode 100644
index b58fa289..00000000
--- a/netlib/websockets/frame.py
+++ /dev/null
@@ -1,273 +0,0 @@
-import os
-import struct
-import io
-
-from netlib import tcp
-from netlib import strutils
-from netlib import utils
-from netlib import human
-from .masker import Masker
-
-
-MAX_16_BIT_INT = (1 << 16)
-MAX_64_BIT_INT = (1 << 64)
-
-DEFAULT = object()
-
-# RFC 6455, Section 5.2 - Base Framing Protocol
-OPCODE = utils.BiDi(
- CONTINUE=0x00,
- TEXT=0x01,
- BINARY=0x02,
- CLOSE=0x08,
- PING=0x09,
- PONG=0x0a
-)
-
-# RFC 6455, Section 7.4.1 - Defined Status Codes
-CLOSE_REASON = utils.BiDi(
- NORMAL_CLOSURE=1000,
- GOING_AWAY=1001,
- PROTOCOL_ERROR=1002,
- UNSUPPORTED_DATA=1003,
- RESERVED=1004,
- RESERVED_NO_STATUS=1005,
- RESERVED_ABNORMAL_CLOSURE=1006,
- INVALID_PAYLOAD_DATA=1007,
- POLICY_VIOLATION=1008,
- MESSAGE_TOO_BIG=1009,
- MANDATORY_EXTENSION=1010,
- INTERNAL_ERROR=1011,
- RESERVED_TLS_HANDHSAKE_FAILED=1015,
-)
-
-
-class FrameHeader:
-
- def __init__(
- self,
- opcode=OPCODE.TEXT,
- payload_length=0,
- fin=False,
- rsv1=False,
- rsv2=False,
- rsv3=False,
- masking_key=DEFAULT,
- mask=DEFAULT,
- length_code=DEFAULT
- ):
- if not 0 <= opcode < 2 ** 4:
- raise ValueError("opcode must be 0-16")
- self.opcode = opcode
- self.payload_length = payload_length
- self.fin = fin
- self.rsv1 = rsv1
- self.rsv2 = rsv2
- self.rsv3 = rsv3
-
- if length_code is DEFAULT:
- self.length_code = self._make_length_code(self.payload_length)
- else:
- self.length_code = length_code
-
- if mask is DEFAULT and masking_key is DEFAULT:
- self.mask = False
- self.masking_key = b""
- elif mask is DEFAULT:
- self.mask = 1
- self.masking_key = masking_key
- elif masking_key is DEFAULT:
- self.mask = mask
- self.masking_key = os.urandom(4)
- else:
- self.mask = mask
- self.masking_key = masking_key
-
- if self.masking_key and len(self.masking_key) != 4:
- raise ValueError("Masking key must be 4 bytes.")
-
- @classmethod
- def _make_length_code(self, length):
- """
- A websockets frame contains an initial length_code, and an optional
- extended length code to represent the actual length if length code is
- larger than 125
- """
- if length <= 125:
- return length
- elif length >= 126 and length <= 65535:
- return 126
- else:
- return 127
-
- def __repr__(self):
- vals = [
- "ws frame:",
- OPCODE.get_name(self.opcode, hex(self.opcode)).lower()
- ]
- flags = []
- for i in ["fin", "rsv1", "rsv2", "rsv3", "mask"]:
- if getattr(self, i):
- flags.append(i)
- if flags:
- vals.extend([":", "|".join(flags)])
- if self.masking_key:
- vals.append(":key=%s" % repr(self.masking_key))
- if self.payload_length:
- vals.append(" %s" % human.pretty_size(self.payload_length))
- return "".join(vals)
-
- def __bytes__(self):
- first_byte = utils.setbit(0, 7, self.fin)
- first_byte = utils.setbit(first_byte, 6, self.rsv1)
- first_byte = utils.setbit(first_byte, 5, self.rsv2)
- first_byte = utils.setbit(first_byte, 4, self.rsv3)
- first_byte = first_byte | self.opcode
-
- second_byte = utils.setbit(self.length_code, 7, self.mask)
-
- b = bytes([first_byte, second_byte])
-
- if self.payload_length < 126:
- pass
- elif self.payload_length < MAX_16_BIT_INT:
- # '!H' pack as 16 bit unsigned short
- # add 2 byte extended payload length
- b += struct.pack('!H', self.payload_length)
- elif self.payload_length < MAX_64_BIT_INT:
- # '!Q' = pack as 64 bit unsigned long long
- # add 8 bytes extended payload length
- b += struct.pack('!Q', self.payload_length)
- else:
- raise ValueError("Payload length exceeds 64bit integer")
-
- if self.masking_key:
- b += self.masking_key
- return b
-
- @classmethod
- def from_file(cls, fp):
- """
- read a websockets frame header
- """
- first_byte, second_byte = fp.safe_read(2)
- fin = utils.getbit(first_byte, 7)
- rsv1 = utils.getbit(first_byte, 6)
- rsv2 = utils.getbit(first_byte, 5)
- rsv3 = utils.getbit(first_byte, 4)
- opcode = first_byte & 0xF
- mask_bit = utils.getbit(second_byte, 7)
- length_code = second_byte & 0x7F
-
- # payload_length > 125 indicates you need to read more bytes
- # to get the actual payload length
- if length_code <= 125:
- payload_length = length_code
- elif length_code == 126:
- payload_length, = struct.unpack("!H", fp.safe_read(2))
- else: # length_code == 127:
- payload_length, = struct.unpack("!Q", fp.safe_read(8))
-
- # masking key only present if mask bit set
- if mask_bit == 1:
- masking_key = fp.safe_read(4)
- else:
- masking_key = None
-
- return cls(
- fin=fin,
- rsv1=rsv1,
- rsv2=rsv2,
- rsv3=rsv3,
- opcode=opcode,
- mask=mask_bit,
- length_code=length_code,
- payload_length=payload_length,
- masking_key=masking_key,
- )
-
- def __eq__(self, other):
- if isinstance(other, FrameHeader):
- return bytes(self) == bytes(other)
- return False
-
-
-class Frame:
- """
- Represents a single WebSockets frame.
- Constructor takes human readable forms of the frame components.
- from_bytes() reads from a file-like object to create a new Frame.
-
- WebSockets Frame as defined in RFC6455
-
- 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
- +-+-+-+-+-------+-+-------------+-------------------------------+
- |F|R|R|R| opcode|M| Payload len | Extended payload length |
- |I|S|S|S| (4) |A| (7) | (16/64) |
- |N|V|V|V| |S| | (if payload len==126/127) |
- | |1|2|3| |K| | |
- +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - +
- | Extended payload length continued, if payload len == 127 |
- + - - - - - - - - - - - - - - - +-------------------------------+
- | |Masking-key, if MASK set to 1 |
- +-------------------------------+-------------------------------+
- | Masking-key (continued) | Payload Data |
- +-------------------------------- - - - - - - - - - - - - - - - +
- : Payload Data continued ... :
- + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - +
- | Payload Data continued ... |
- +---------------------------------------------------------------+
- """
-
- def __init__(self, payload=b"", **kwargs):
- self.payload = payload
- kwargs["payload_length"] = kwargs.get("payload_length", len(payload))
- self.header = FrameHeader(**kwargs)
-
- @classmethod
- def from_bytes(cls, bytestring):
- """
- Construct a websocket frame from an in-memory bytestring
- to construct a frame from a stream of bytes, use from_file() directly
- """
- return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))
-
- def __repr__(self):
- ret = repr(self.header)
- if self.payload:
- ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
- return ret
-
- def __bytes__(self):
- """
- Serialize the frame to wire format. Returns a string.
- """
- b = bytes(self.header)
- if self.header.masking_key:
- b += Masker(self.header.masking_key)(self.payload)
- else:
- b += self.payload
- return b
-
- @classmethod
- def from_file(cls, fp):
- """
- read a websockets frame sent by a server or client
-
- fp is a "file like" object that could be backed by a network
- stream or a disk or an in memory stream reader
- """
- header = FrameHeader.from_file(fp)
- payload = fp.safe_read(header.payload_length)
-
- if header.mask == 1 and header.masking_key:
- payload = Masker(header.masking_key)(payload)
-
- frame = cls(payload)
- frame.header = header
- return frame
-
- def __eq__(self, other):
- if isinstance(other, Frame):
- return bytes(self) == bytes(other)
- return False
diff --git a/netlib/websockets/masker.py b/netlib/websockets/masker.py
deleted file mode 100644
index 47b1a688..00000000
--- a/netlib/websockets/masker.py
+++ /dev/null
@@ -1,25 +0,0 @@
-class Masker:
- """
- Data sent from the server must be masked to prevent malicious clients
- from sending data over the wire in predictable patterns.
-
- Servers do not have to mask data they send to the client.
- https://tools.ietf.org/html/rfc6455#section-5.3
- """
-
- def __init__(self, key):
- self.key = key
- self.offset = 0
-
- def mask(self, offset, data):
- result = bytearray(data)
- for i in range(len(data)):
- result[i] ^= self.key[offset % 4]
- offset += 1
- result = bytes(result)
- return result
-
- def __call__(self, data):
- ret = self.mask(self.offset, data)
- self.offset += len(ret)
- return ret
diff --git a/netlib/websockets/utils.py b/netlib/websockets/utils.py
deleted file mode 100644
index fdec074e..00000000
--- a/netlib/websockets/utils.py
+++ /dev/null
@@ -1,89 +0,0 @@
-"""
-Collection of WebSockets Protocol utility functions (RFC6455)
-Spec: https://tools.ietf.org/html/rfc6455
-"""
-
-
-import base64
-import hashlib
-import os
-
-from netlib import http, strutils
-
-MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
-VERSION = "13"
-
-
-def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
- """
- Create the headers for a valid HTTP upgrade request. If Key is not
- specified, it is generated, and can be found in sec-websocket-key in
- the returned header set.
-
- Returns an instance of http.Headers
- """
- if version is None:
- version = VERSION
- if key is None:
- key = base64.b64encode(os.urandom(16)).decode('ascii')
- h = http.Headers(
- connection="upgrade",
- upgrade="websocket",
- sec_websocket_version=version,
- sec_websocket_key=key,
- )
- if protocol is not None:
- h['sec-websocket-protocol'] = protocol
- if extensions is not None:
- h['sec-websocket-extensions'] = extensions
- return h
-
-
-def server_handshake_headers(client_key, protocol=None, extensions=None):
- """
- The server response is a valid HTTP 101 response.
-
- Returns an instance of http.Headers
- """
- h = http.Headers(
- connection="upgrade",
- upgrade="websocket",
- sec_websocket_accept=create_server_nonce(client_key),
- )
- if protocol is not None:
- h['sec-websocket-protocol'] = protocol
- if extensions is not None:
- h['sec-websocket-extensions'] = extensions
- return h
-
-
-def check_handshake(headers):
- return (
- "upgrade" in headers.get("connection", "").lower() and
- headers.get("upgrade", "").lower() == "websocket" and
- (headers.get("sec-websocket-key") is not None or headers.get("sec-websocket-accept") is not None)
- )
-
-
-def create_server_nonce(client_nonce):
- return base64.b64encode(hashlib.sha1(strutils.always_bytes(client_nonce) + MAGIC).digest())
-
-
-def check_client_version(headers):
- return headers.get("sec-websocket-version", "") == VERSION
-
-
-def get_extensions(headers):
- return headers.get("sec-websocket-extensions", None)
-
-
-def get_protocol(headers):
- return headers.get("sec-websocket-protocol", None)
-
-
-def get_client_key(headers):
- return headers.get("sec-websocket-key", None)
-
-
-def get_server_accept(headers):
- return headers.get("sec-websocket-accept", None)
diff --git a/netlib/wsgi.py b/netlib/wsgi.py
deleted file mode 100644
index 11e4aba9..00000000
--- a/netlib/wsgi.py
+++ /dev/null
@@ -1,164 +0,0 @@
-import time
-import traceback
-import urllib
-import io
-
-from netlib import http, tcp, strutils
-
-
-class ClientConn:
-
- def __init__(self, address):
- self.address = tcp.Address.wrap(address)
-
-
-class Flow:
-
- def __init__(self, address, request):
- self.client_conn = ClientConn(address)
- self.request = request
-
-
-class Request:
-
- def __init__(self, scheme, method, path, http_version, headers, content):
- self.scheme, self.method, self.path = scheme, method, path
- self.headers, self.content = headers, content
- self.http_version = http_version
-
-
-def date_time_string():
- """Return the current date and time formatted for a message header."""
- WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
- MONTHS = [
- None,
- 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
- 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'
- ]
- now = time.time()
- year, month, day, hh, mm, ss, wd, y_, z_ = time.gmtime(now)
- s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
- WEEKS[wd],
- day, MONTHS[month], year,
- hh, mm, ss
- )
- return s
-
-
-class WSGIAdaptor:
-
- def __init__(self, app, domain, port, sversion):
- self.app, self.domain, self.port, self.sversion = app, domain, port, sversion
-
- def make_environ(self, flow, errsoc, **extra):
- """
- Raises:
- ValueError, if the content-encoding is invalid.
- """
- path = strutils.native(flow.request.path, "latin-1")
- if '?' in path:
- path_info, query = strutils.native(path, "latin-1").split('?', 1)
- else:
- path_info = path
- query = ''
- environ = {
- 'wsgi.version': (1, 0),
- 'wsgi.url_scheme': strutils.native(flow.request.scheme, "latin-1"),
- 'wsgi.input': io.BytesIO(flow.request.content or b""),
- 'wsgi.errors': errsoc,
- 'wsgi.multithread': True,
- 'wsgi.multiprocess': False,
- 'wsgi.run_once': False,
- 'SERVER_SOFTWARE': self.sversion,
- 'REQUEST_METHOD': strutils.native(flow.request.method, "latin-1"),
- 'SCRIPT_NAME': '',
- 'PATH_INFO': urllib.parse.unquote(path_info),
- 'QUERY_STRING': query,
- 'CONTENT_TYPE': strutils.native(flow.request.headers.get('Content-Type', ''), "latin-1"),
- 'CONTENT_LENGTH': strutils.native(flow.request.headers.get('Content-Length', ''), "latin-1"),
- 'SERVER_NAME': self.domain,
- 'SERVER_PORT': str(self.port),
- 'SERVER_PROTOCOL': strutils.native(flow.request.http_version, "latin-1"),
- }
- environ.update(extra)
- if flow.client_conn.address:
- environ["REMOTE_ADDR"] = strutils.native(flow.client_conn.address.host, "latin-1")
- environ["REMOTE_PORT"] = flow.client_conn.address.port
-
- for key, value in flow.request.headers.items():
- key = 'HTTP_' + strutils.native(key, "latin-1").upper().replace('-', '_')
- if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'):
- environ[key] = value
- return environ
-
- def error_page(self, soc, headers_sent, s):
- """
- Make a best-effort attempt to write an error page. If headers are
- already sent, we just bung the error into the page.
- """
- c = """
- <html>
- <h1>Internal Server Error</h1>
- <pre>{err}"</pre>
- </html>
- """.format(err=s).strip().encode()
-
- if not headers_sent:
- soc.write(b"HTTP/1.1 500 Internal Server Error\r\n")
- soc.write(b"Content-Type: text/html\r\n")
- soc.write("Content-Length: {length}\r\n".format(length=len(c)).encode())
- soc.write(b"\r\n")
- soc.write(c)
-
- def serve(self, request, soc, **env):
- state = dict(
- response_started=False,
- headers_sent=False,
- status=None,
- headers=None
- )
-
- def write(data):
- if not state["headers_sent"]:
- soc.write("HTTP/1.1 {status}\r\n".format(status=state["status"]).encode())
- headers = state["headers"]
- if 'server' not in headers:
- headers["Server"] = self.sversion
- if 'date' not in headers:
- headers["Date"] = date_time_string()
- soc.write(bytes(headers))
- soc.write(b"\r\n")
- state["headers_sent"] = True
- if data:
- soc.write(data)
- soc.flush()
-
- def start_response(status, headers, exc_info=None):
- if exc_info:
- if state["headers_sent"]:
- raise exc_info[1]
- elif state["status"]:
- raise AssertionError('Response already started')
- state["status"] = status
- state["headers"] = http.Headers([[strutils.always_bytes(k), strutils.always_bytes(v)] for k, v in headers])
- if exc_info:
- self.error_page(soc, state["headers_sent"], traceback.format_tb(exc_info[2]))
- state["headers_sent"] = True
-
- errs = io.BytesIO()
- try:
- dataiter = self.app(
- self.make_environ(request, errs, **env), start_response
- )
- for i in dataiter:
- write(i)
- if not state["headers_sent"]:
- write(b"")
- except Exception:
- try:
- s = traceback.format_exc()
- errs.write(s.encode("utf-8", "replace"))
- self.error_page(soc, state["headers_sent"], s)
- except Exception: # pragma: no cover
- pass
- return errs.getvalue()