aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/debug.py45
-rw-r--r--netlib/http/cookies.py29
-rw-r--r--netlib/http/headers.py9
-rw-r--r--netlib/http/http2/__init__.py2
-rw-r--r--netlib/http/http2/utils.py37
-rw-r--r--netlib/http/message.py2
-rw-r--r--netlib/http/request.py12
-rw-r--r--netlib/http/response.py7
-rw-r--r--netlib/multidict.py24
-rw-r--r--netlib/strutils.py116
-rw-r--r--netlib/tcp.py4
-rw-r--r--netlib/utils.py11
-rw-r--r--netlib/websockets/frame.py2
13 files changed, 205 insertions, 95 deletions
diff --git a/netlib/debug.py b/netlib/debug.py
index a395afcb..29c7f655 100644
--- a/netlib/debug.py
+++ b/netlib/debug.py
@@ -7,8 +7,6 @@ import signal
import platform
import traceback
-import psutil
-
from netlib import version
from OpenSSL import SSL
@@ -19,7 +17,7 @@ def sysinfo():
"Mitmproxy version: %s" % version.VERSION,
"Python version: %s" % platform.python_version(),
"Platform: %s" % platform.platform(),
- "SSL version: %s" % SSL.SSLeay_version(SSL.SSLEAY_VERSION),
+ "SSL version: %s" % SSL.SSLeay_version(SSL.SSLEAY_VERSION).decode(),
]
d = platform.linux_distribution()
t = "Linux distro: %s %s %s" % d
@@ -40,15 +38,32 @@ def sysinfo():
def dump_info(sig, frm, file=sys.stdout): # pragma: no cover
- p = psutil.Process()
-
print("****************************************************", file=file)
print("Summary", file=file)
print("=======", file=file)
- print("num threads: ", p.num_threads(), file=file)
- if hasattr(p, "num_fds"):
- print("num fds: ", p.num_fds(), file=file)
- print("memory: ", p.memory_info(), file=file)
+
+ try:
+ import psutil
+ except:
+ print("(psutil not installed, skipping some debug info)", file=file)
+ else:
+ p = psutil.Process()
+ print("num threads: ", p.num_threads(), file=file)
+ if hasattr(p, "num_fds"):
+ print("num fds: ", p.num_fds(), file=file)
+ print("memory: ", p.memory_info(), file=file)
+
+ print(file=file)
+ print("Files", file=file)
+ print("=====", file=file)
+ for i in p.open_files():
+ print(i, file=file)
+
+ print(file=file)
+ print("Connections", file=file)
+ print("===========", file=file)
+ for i in p.connections():
+ print(i, file=file)
print(file=file)
print("Threads", file=file)
@@ -63,18 +78,6 @@ def dump_info(sig, frm, file=sys.stdout): # pragma: no cover
for i in bthreads:
print(i._threadinfo(), file=file)
- print(file=file)
- print("Files", file=file)
- print("=====", file=file)
- for i in p.open_files():
- print(i, file=file)
-
- print(file=file)
- print("Connections", file=file)
- print("===========", file=file)
- for i in p.connections():
- print(i, file=file)
-
print("****************************************************", file=file)
diff --git a/netlib/http/cookies.py b/netlib/http/cookies.py
index 768a85df..dd0af99c 100644
--- a/netlib/http/cookies.py
+++ b/netlib/http/cookies.py
@@ -1,7 +1,8 @@
import collections
+import email.utils
import re
+import time
-import email.utils
from netlib import multidict
"""
@@ -260,3 +261,29 @@ def refresh_set_cookie_header(c, delta):
if not ret:
raise ValueError("Invalid Cookie")
return ret
+
+
+def is_expired(cookie_attrs):
+ """
+ Determines whether a cookie has expired.
+
+ Returns: boolean
+ """
+
+ # See if 'expires' time is in the past
+ expires = False
+ if 'expires' in cookie_attrs:
+ e = email.utils.parsedate_tz(cookie_attrs["expires"])
+ if e:
+ exp_ts = email.utils.mktime_tz(e)
+ now_ts = time.time()
+ expires = exp_ts < now_ts
+
+ # or if Max-Age is 0
+ max_age = False
+ try:
+ max_age = int(cookie_attrs.get('Max-Age', 1)) == 0
+ except ValueError:
+ pass
+
+ return expires or max_age
diff --git a/netlib/http/headers.py b/netlib/http/headers.py
index 13a8c98f..b8aa212a 100644
--- a/netlib/http/headers.py
+++ b/netlib/http/headers.py
@@ -148,6 +148,15 @@ class Headers(multidict.MultiDict):
value = _always_bytes(value)
super(Headers, self).insert(index, key, value)
+ def items(self, multi=False):
+ if multi:
+ return (
+ (_native(k), _native(v))
+ for k, v in self.fields
+ )
+ else:
+ return super(Headers, self).items()
+
def replace(self, pattern, repl, flags=0):
"""
Replaces a regular expression pattern with repl in each "name: value"
diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py
index 6a979a0d..60064190 100644
--- a/netlib/http/http2/__init__.py
+++ b/netlib/http/http2/__init__.py
@@ -1,6 +1,8 @@
from __future__ import absolute_import, print_function, division
from netlib.http.http2 import framereader
+from netlib.http.http2.utils import parse_headers
__all__ = [
"framereader",
+ "parse_headers",
]
diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py
new file mode 100644
index 00000000..164bacc8
--- /dev/null
+++ b/netlib/http/http2/utils.py
@@ -0,0 +1,37 @@
+from netlib.http import url
+
+
+def parse_headers(headers):
+ authority = headers.get(':authority', '').encode()
+ method = headers.get(':method', 'GET').encode()
+ scheme = headers.get(':scheme', 'https').encode()
+ path = headers.get(':path', '/').encode()
+
+ headers.pop(":method", None)
+ headers.pop(":scheme", None)
+ headers.pop(":path", None)
+
+ host = None
+ port = None
+
+ if path == b'*' or path.startswith(b"/"):
+ first_line_format = "relative"
+ elif method == b'CONNECT': # pragma: no cover
+ raise NotImplementedError("CONNECT over HTTP/2 is not implemented.")
+ else: # pragma: no cover
+ first_line_format = "absolute"
+ # FIXME: verify if path or :host contains what we need
+ scheme, host, port, _ = url.parse(path)
+
+ if authority:
+ host, _, port = authority.partition(b':')
+
+ if not host:
+ host = b'localhost'
+
+ if not port:
+ port = 443 if scheme == b'https' else 80
+
+ port = int(port)
+
+ return first_line_format, method, scheme, host, port, path
diff --git a/netlib/http/message.py b/netlib/http/message.py
index 1252ed25..34709f0a 100644
--- a/netlib/http/message.py
+++ b/netlib/http/message.py
@@ -263,7 +263,7 @@ class Message(basetypes.Serializable):
if strict:
raise
is_strict = False
- decoded = self.content.decode(enc, "replace" if six.PY2 else "surrogateescape")
+ decoded = self.content.decode("utf8", "replace" if six.PY2 else "surrogateescape")
self._text_cache = CachedDecode(content, enc, is_strict, decoded)
return self._text_cache.decoded
diff --git a/netlib/http/request.py b/netlib/http/request.py
index a8ec6238..ecaa9b79 100644
--- a/netlib/http/request.py
+++ b/netlib/http/request.py
@@ -22,8 +22,20 @@ host_header_re = re.compile(r"^(?P<host>[^:]+|\[.+\])(?::(?P<port>\d+))?$")
class RequestData(message.MessageData):
def __init__(self, first_line_format, method, scheme, host, port, path, http_version, headers=(), content=None,
timestamp_start=None, timestamp_end=None):
+ if isinstance(method, six.text_type):
+ method = method.encode("ascii", "strict")
+ if isinstance(scheme, six.text_type):
+ scheme = scheme.encode("ascii", "strict")
+ if isinstance(host, six.text_type):
+ host = host.encode("idna", "strict")
+ if isinstance(path, six.text_type):
+ path = path.encode("ascii", "strict")
+ if isinstance(http_version, six.text_type):
+ http_version = http_version.encode("ascii", "strict")
if not isinstance(headers, nheaders.Headers):
headers = nheaders.Headers(headers)
+ if isinstance(content, six.text_type):
+ raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
self.first_line_format = first_line_format
self.method = method
diff --git a/netlib/http/response.py b/netlib/http/response.py
index d2273edd..85f54940 100644
--- a/netlib/http/response.py
+++ b/netlib/http/response.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function, division
from email.utils import parsedate_tz, formatdate, mktime_tz
import time
+import six
from netlib.http import cookies
from netlib.http import headers as nheaders
@@ -13,8 +14,14 @@ from netlib import human
class ResponseData(message.MessageData):
def __init__(self, http_version, status_code, reason=None, headers=(), content=None,
timestamp_start=None, timestamp_end=None):
+ if isinstance(http_version, six.text_type):
+ http_version = http_version.encode("ascii", "strict")
+ if isinstance(reason, six.text_type):
+ reason = reason.encode("ascii", "strict")
if not isinstance(headers, nheaders.Headers):
headers = nheaders.Headers(headers)
+ if isinstance(content, six.text_type):
+ raise ValueError("Content must be bytes, not {}".format(type(content).__name__))
self.http_version = http_version
self.status_code = status_code
diff --git a/netlib/multidict.py b/netlib/multidict.py
index 50c879d9..51053ff6 100644
--- a/netlib/multidict.py
+++ b/netlib/multidict.py
@@ -170,18 +170,10 @@ class _MultiDict(MutableMapping, basetypes.Serializable):
else:
return super(_MultiDict, self).items()
- def clear(self, key):
- """
- Removes all items with the specified key, and does not raise an
- exception if the key does not exist.
- """
- if key in self:
- del self[key]
-
def collect(self):
"""
Returns a list of (key, value) tuples, where values are either
- singular if threre is only one matching item for a key, or a list
+ singular if there is only one matching item for a key, or a list
if there are more than one. The order of the keys matches the order
in the underlying fields list.
"""
@@ -204,18 +196,16 @@ class _MultiDict(MutableMapping, basetypes.Serializable):
.. code-block:: python
# Simple dict with duplicate values.
- >>> d
- MultiDictView[("name", "value"), ("a", "false"), ("a", "42")]
+ >>> d = MultiDict([("name", "value"), ("a", False), ("a", 42)])
>>> d.to_dict()
{
"name": "value",
- "a": ["false", "42"]
+ "a": [False, 42]
}
"""
- d = {}
- for k, v in self.collect():
- d[k] = v
- return d
+ return {
+ k: v for k, v in self.collect()
+ }
def get_state(self):
return self.fields
@@ -307,4 +297,4 @@ class MultiDictView(_MultiDict):
@fields.setter
def fields(self, value):
- return self._setter(value)
+ self._setter(value)
diff --git a/netlib/strutils.py b/netlib/strutils.py
index 414b2e57..32e77927 100644
--- a/netlib/strutils.py
+++ b/netlib/strutils.py
@@ -1,4 +1,5 @@
-import unicodedata
+from __future__ import absolute_import, print_function, division
+import re
import codecs
import six
@@ -19,60 +20,80 @@ def native(s, *encoding_opts):
"""
if not isinstance(s, (six.binary_type, six.text_type)):
raise TypeError("%r is neither bytes nor unicode" % s)
- if six.PY3:
- if isinstance(s, six.binary_type):
- return s.decode(*encoding_opts)
- else:
+ if six.PY2:
if isinstance(s, six.text_type):
return s.encode(*encoding_opts)
+ else:
+ if isinstance(s, six.binary_type):
+ return s.decode(*encoding_opts)
return s
-def clean_bin(s, keep_spacing=True):
- # type: (Union[bytes, six.text_type], bool) -> six.text_type
+# Translate control characters to "safe" characters. This implementation initially
+# replaced them with the matching control pictures (http://unicode.org/charts/PDF/U2400.pdf),
+# but that turned out to render badly with monospace fonts. We are back to "." therefore.
+_control_char_trans = {
+ x: ord(".") # x + 0x2400 for unicode control group pictures
+ for x in range(32)
+}
+_control_char_trans[127] = ord(".") # 0x2421
+_control_char_trans_newline = _control_char_trans.copy()
+for x in ("\r", "\n", "\t"):
+ del _control_char_trans_newline[ord(x)]
+
+
+if six.PY2:
+ pass
+else:
+ _control_char_trans = str.maketrans(_control_char_trans)
+ _control_char_trans_newline = str.maketrans(_control_char_trans_newline)
+
+
+def escape_control_characters(text, keep_spacing=True):
"""
- Cleans binary data to make it safe to display.
+ Replace all unicode C1 control characters from the given text with their respective control pictures.
+ For example, a null byte is replaced with the unicode character "\u2400".
- Args:
- keep_spacing: If False, tabs and newlines will also be replaced.
+ Args:
+ keep_spacing: If True, tabs and newlines will not be replaced.
"""
- if isinstance(s, six.text_type):
- if keep_spacing:
- keep = u" \n\r\t"
- else:
- keep = u" "
+ # type: (six.string_types) -> six.text_type
+ if not isinstance(text, six.string_types):
+ raise ValueError("text type must be unicode but is {}".format(type(text).__name__))
+
+ trans = _control_char_trans_newline if keep_spacing else _control_char_trans
+ if six.PY2:
return u"".join(
- ch if (unicodedata.category(ch)[0] not in "CZ" or ch in keep) else u"."
- for ch in s
- )
- else:
- if keep_spacing:
- keep = (9, 10, 13) # \t, \n, \r,
- else:
- keep = ()
- return "".join(
- chr(ch) if (31 < ch < 127 or ch in keep) else "."
- for ch in six.iterbytes(s)
+ six.unichr(trans.get(ord(ch), ord(ch)))
+ for ch in text
)
+ return text.translate(trans)
-def bytes_to_escaped_str(data):
+def bytes_to_escaped_str(data, keep_spacing=False):
"""
Take bytes and return a safe string that can be displayed to the user.
Single quotes are always escaped, double quotes are never escaped:
"'" + bytes_to_escaped_str(...) + "'"
gives a valid Python string.
+
+ Args:
+ keep_spacing: If True, tabs and newlines will not be escaped.
"""
- # TODO: We may want to support multi-byte characters without escaping them.
- # One way to do would be calling .decode("utf8", "backslashreplace") first
- # and then escaping UTF8 control chars (see clean_bin).
if not isinstance(data, bytes):
raise ValueError("data must be bytes, but is {}".format(data.__class__.__name__))
# We always insert a double-quote here so that we get a single-quoted string back
# https://stackoverflow.com/questions/29019340/why-does-python-use-different-quotes-for-representing-strings-depending-on-their
- return repr(b'"' + data).lstrip("b")[2:-1]
+ ret = repr(b'"' + data).lstrip("b")[2:-1]
+ if keep_spacing:
+ ret = re.sub(
+ r"(?<!\\)(\\\\)*\\([nrt])",
+ lambda m: (m.group(1) or "") + dict(n="\n", r="\r", t="\t")[m.group(2)],
+ ret
+ )
+ return ret
def escaped_str_to_bytes(data):
@@ -94,24 +115,17 @@ def escaped_str_to_bytes(data):
return codecs.escape_decode(data)[0]
-def isBin(s):
- """
- Does this string have any non-ASCII characters?
- """
- for i in s:
- i = ord(i)
- if i < 9 or 13 < i < 32 or 126 < i:
- return True
- return False
-
-
-def isMostlyBin(s):
- s = s[:100]
- return sum(isBin(ch) for ch in s) / len(s) > 0.3
+def is_mostly_bin(s):
+ # type: (bytes) -> bool
+ return sum(
+ i < 9 or 13 < i < 32 or 126 < i
+ for i in six.iterbytes(s[:100])
+ ) / len(s[:100]) > 0.3
-def isXML(s):
- return s.strip().startswith("<")
+def is_xml(s):
+ # type: (bytes) -> bool
+ return s.strip().startswith(b"<")
def clean_hanging_newline(t):
@@ -132,8 +146,12 @@ def hexdump(s):
A generator of (offset, hex, str) tuples
"""
for i in range(0, len(s), 16):
- offset = "{:0=10x}".format(i).encode()
+ offset = "{:0=10x}".format(i)
part = s[i:i + 16]
x = " ".join("{:0=2x}".format(i) for i in six.iterbytes(part))
x = x.ljust(47) # 16*2 + 15
- yield (offset, x, clean_bin(part, False))
+ part_repr = native(escape_control_characters(
+ part.decode("ascii", "replace").replace(u"\ufffd", u"."),
+ False
+ ))
+ yield (offset, x, part_repr)
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 69dafc1f..cf099edd 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -676,7 +676,7 @@ class TCPClient(_Connection):
self.connection = SSL.Connection(context, self.connection)
if sni:
self.sni = sni
- self.connection.set_tlsext_host_name(sni)
+ self.connection.set_tlsext_host_name(sni.encode("idna"))
self.connection.set_connect_state()
try:
self.connection.do_handshake()
@@ -705,7 +705,7 @@ class TCPClient(_Connection):
if self.cert.cn:
crt["subject"] = [[["commonName", self.cert.cn.decode("ascii", "strict")]]]
if sni:
- hostname = sni.decode("ascii", "strict")
+ hostname = sni
else:
hostname = "no-hostname"
ssl_match_hostname.match_hostname(crt, hostname)
diff --git a/netlib/utils.py b/netlib/utils.py
index 79340cbd..9eebf22c 100644
--- a/netlib/utils.py
+++ b/netlib/utils.py
@@ -56,6 +56,13 @@ class Data(object):
dirname = os.path.dirname(inspect.getsourcefile(m))
self.dirname = os.path.abspath(dirname)
+ def push(self, subpath):
+ """
+ Change the data object to a path relative to the module.
+ """
+ self.dirname = os.path.join(self.dirname, subpath)
+ return self
+
def path(self, path):
"""
Returns a path to the package data housed at 'path' under this
@@ -73,11 +80,9 @@ _label_valid = re.compile(b"(?!-)[A-Z\d-]{1,63}(?<!-)$", re.IGNORECASE)
def is_valid_host(host):
+ # type: (bytes) -> bool
"""
Checks if a hostname is valid.
-
- Args:
- host (bytes): The hostname
"""
try:
host.decode("idna")
diff --git a/netlib/websockets/frame.py b/netlib/websockets/frame.py
index 671e1605..7d355699 100644
--- a/netlib/websockets/frame.py
+++ b/netlib/websockets/frame.py
@@ -255,7 +255,7 @@ class Frame(object):
def __repr__(self):
ret = repr(self.header)
if self.payload:
- ret = ret + "\nPayload:\n" + strutils.clean_bin(self.payload)
+ ret = ret + "\nPayload:\n" + strutils.bytes_to_escaped_str(self.payload)
return ret
def human_readable(self):