aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2016-06-11 19:52:24 +1200
committerAldo Cortesi <aldo@nullcube.com>2016-06-11 19:52:24 +1200
commit09edbd9492e59c0c8dcae69b4b1f4b745867abe4 (patch)
treee9cf29c394334c02d908058c2c5e159715d3e3c3
parent5b9f07c81c0dcc8c7b3d7afdeae8f6229ebf8622 (diff)
downloadmitmproxy-09edbd9492e59c0c8dcae69b4b1f4b745867abe4.tar.gz
mitmproxy-09edbd9492e59c0c8dcae69b4b1f4b745867abe4.tar.bz2
mitmproxy-09edbd9492e59c0c8dcae69b4b1f4b745867abe4.zip
Improve debugging of thread and other leaks
- Add basethread.BaseThread that all threads outside of test suites should use - Add a signal handler to mitmproxy, mitmdump and mitmweb that dumps resource information to screen when SIGUSR1 is received. - Improve thread naming throughout to make thread dumps understandable
-rw-r--r--mitmproxy/controller.py9
-rw-r--r--mitmproxy/main.py1
-rw-r--r--mitmproxy/protocol/http2.py7
-rw-r--r--mitmproxy/protocol/http_replay.py8
-rw-r--r--mitmproxy/script/concurrent.py9
-rw-r--r--netlib/basethread.py14
-rw-r--r--netlib/debug.py71
-rw-r--r--netlib/tcp.py18
-rw-r--r--pathod/pathoc.py34
-rw-r--r--pathod/test.py7
-rw-r--r--setup.py2
-rw-r--r--test/netlib/test_debug.py8
12 files changed, 138 insertions, 50 deletions
diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py
index 084702a6..898be3bc 100644
--- a/mitmproxy/controller.py
+++ b/mitmproxy/controller.py
@@ -5,8 +5,10 @@ import threading
from six.moves import queue
+from netlib import basethread
from mitmproxy import exceptions
+
Events = frozenset([
"clientconnect",
"clientdisconnect",
@@ -95,12 +97,13 @@ class Master(object):
self.should_exit.set()
-class ServerThread(threading.Thread):
+class ServerThread(basethread.BaseThread):
def __init__(self, server):
self.server = server
- super(ServerThread, self).__init__()
address = getattr(self.server, "address", None)
- self.name = "ServerThread ({})".format(repr(address))
+ super(ServerThread, self).__init__(
+ "ServerThread ({})".format(repr(address))
+ )
def run(self):
self.server.serve_forever()
diff --git a/mitmproxy/main.py b/mitmproxy/main.py
index 34d4aa6b..53417fe8 100644
--- a/mitmproxy/main.py
+++ b/mitmproxy/main.py
@@ -47,6 +47,7 @@ def process_options(parser, options):
sys.exit(0)
if options.quiet:
options.verbose = 0
+ debug.register_info_dumper()
return config.process_proxy_options(parser, options)
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py
index 9247e657..957b8d64 100644
--- a/mitmproxy/protocol/http2.py
+++ b/mitmproxy/protocol/http2.py
@@ -18,6 +18,7 @@ from mitmproxy.protocol import base
from mitmproxy.protocol import http
import netlib.http
from netlib import tcp
+from netlib import basethread
from netlib.http import http2
@@ -261,10 +262,12 @@ class Http2Layer(base.Layer):
self._cleanup_streams()
-class Http2SingleStreamLayer(http._HttpTransmissionLayer, threading.Thread):
+class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
def __init__(self, ctx, stream_id, request_headers):
- super(Http2SingleStreamLayer, self).__init__(ctx, name="Thread-Http2SingleStreamLayer-{}".format(stream_id))
+ super(Http2SingleStreamLayer, self).__init__(
+ ctx, name="Http2SingleStreamLayer-{}".format(stream_id)
+ )
self.zombie = None
self.client_stream_id = stream_id
self.server_stream_id = None
diff --git a/mitmproxy/protocol/http_replay.py b/mitmproxy/protocol/http_replay.py
index 5928c0af..e804eba9 100644
--- a/mitmproxy/protocol/http_replay.py
+++ b/mitmproxy/protocol/http_replay.py
@@ -1,6 +1,5 @@
from __future__ import absolute_import, print_function, division
-import threading
import traceback
import netlib.exceptions
@@ -8,12 +7,13 @@ from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import models
from netlib.http import http1
+from netlib import basethread
# TODO: Doesn't really belong into mitmproxy.protocol...
-class RequestReplayThread(threading.Thread):
+class RequestReplayThread(basethread.BaseThread):
name = "RequestReplayThread"
def __init__(self, config, flow, event_queue, should_exit):
@@ -26,7 +26,9 @@ class RequestReplayThread(threading.Thread):
self.channel = controller.Channel(event_queue, should_exit)
else:
self.channel = None
- super(RequestReplayThread, self).__init__()
+ super(RequestReplayThread, self).__init__(
+ "RequestReplay (%s)" % flow.request.url
+ )
def run(self):
r = self.flow.request
diff --git a/mitmproxy/script/concurrent.py b/mitmproxy/script/concurrent.py
index 89c835f6..56d39d0b 100644
--- a/mitmproxy/script/concurrent.py
+++ b/mitmproxy/script/concurrent.py
@@ -5,10 +5,10 @@ offload computations from mitmproxy's main master thread.
from __future__ import absolute_import, print_function, division
from mitmproxy import controller
-import threading
+from netlib import basethread
-class ScriptThread(threading.Thread):
+class ScriptThread(basethread.BaseThread):
name = "ScriptThread"
@@ -24,5 +24,8 @@ def concurrent(fn):
if not obj.reply.acked:
obj.reply.ack()
obj.reply.take()
- ScriptThread(target=run).start()
+ ScriptThread(
+ "script.concurrent (%s)" % fn.__name__,
+ target=run
+ ).start()
return _concurrent
diff --git a/netlib/basethread.py b/netlib/basethread.py
new file mode 100644
index 00000000..7963eb7e
--- /dev/null
+++ b/netlib/basethread.py
@@ -0,0 +1,14 @@
+import time
+import threading
+
+
+class BaseThread(threading.Thread):
+ def __init__(self, name, *args, **kwargs):
+ super(BaseThread, self).__init__(name=name, *args, **kwargs)
+ self._thread_started = time.time()
+
+ def _threadinfo(self):
+ return "%s - age: %is" % (
+ self.name,
+ int(time.time() - self._thread_started)
+ )
diff --git a/netlib/debug.py b/netlib/debug.py
index bf446eb0..b48cb122 100644
--- a/netlib/debug.py
+++ b/netlib/debug.py
@@ -1,29 +1,76 @@
+from __future__ import (absolute_import, print_function, division)
+
+import sys
+import threading
+import signal
import platform
+
+import psutil
+
from netlib import version
-"""
- Some utilities to help with debugging.
-"""
def sysinfo():
data = [
- "Mitmproxy verison: %s"%version.VERSION,
- "Python version: %s"%platform.python_version(),
- "Platform: %s"%platform.platform(),
+ "Mitmproxy verison: %s" % version.VERSION,
+ "Python version: %s" % platform.python_version(),
+ "Platform: %s" % platform.platform(),
]
d = platform.linux_distribution()
- t = "Linux distro: %s %s %s"%d
- if d[0]: # pragma: no-cover
+ t = "Linux distro: %s %s %s" % d
+ if d[0]: # pragma: no-cover
data.append(t)
d = platform.mac_ver()
- t = "Mac version: %s %s %s"%d
- if d[0]: # pragma: no-cover
+ t = "Mac version: %s %s %s" % d
+ if d[0]: # pragma: no-cover
data.append(t)
d = platform.win32_ver()
- t = "Windows version: %s %s %s %s"%d
- if d[0]: # pragma: no-cover
+ t = "Windows version: %s %s %s %s" % d
+ if d[0]: # pragma: no-cover
data.append(t)
return "\n".join(data)
+
+
+def dump_info(sig, frm, file=sys.stdout): # pragma: no cover
+ p = psutil.Process()
+
+ print("****************************************************", file=file)
+ print("Summary", file=file)
+ print("=======", file=file)
+ print("num threads: ", p.num_threads(), file=file)
+ print("num fds: ", p.num_fds(), file=file)
+ print("memory: ", p.memory_info(), file=file)
+
+ print(file=file)
+ print("Threads", file=file)
+ print("=======", file=file)
+ bthreads = []
+ for i in threading.enumerate():
+ if hasattr(i, "_threadinfo"):
+ bthreads.append(i)
+ else:
+ print(i.name, file=file)
+ bthreads.sort(key=lambda x: x._thread_started)
+ for i in bthreads:
+ print(i._threadinfo(), file=file)
+
+ print(file=file)
+ print("Files", file=file)
+ print("=====", file=file)
+ for i in p.open_files():
+ print(i, file=file)
+
+ print(file=file)
+ print("Connections", file=file)
+ print("===========", file=file)
+ for i in p.connections():
+ print(i, file=file)
+
+ print("****************************************************", file=file)
+
+
+def register_info_dumper(): # pragma: no cover
+ signal.signal(signal.SIGUSR1, dump_info)
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 0eec326b..acd67cad 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -17,7 +17,11 @@ import six
import OpenSSL
from OpenSSL import SSL
-from netlib import certutils, version_check, basetypes, exceptions
+from netlib import certutils
+from netlib import version_check
+from netlib import basetypes
+from netlib import exceptions
+from netlib import basethread
# This is a rather hackish way to make sure that
# the latest version of pyOpenSSL is actually installed.
@@ -900,12 +904,16 @@ class TCPServer(object):
raise
if self.socket in r:
connection, client_address = self.socket.accept()
- t = threading.Thread(
+ t = basethread.BaseThread(
+ "TCPConnectionHandler (%s: %s:%s -> %s:%s)" % (
+ self.__class__.__name__,
+ client_address[0],
+ client_address[1],
+ self.address.host,
+ self.address.port
+ ),
target=self.connection_thread,
args=(connection, client_address),
- name="ConnectionThread (%s:%s -> %s:%s)" %
- (client_address[0], client_address[1],
- self.address.host, self.address.port)
)
t.setDaemon(1)
try:
diff --git a/pathod/pathoc.py b/pathod/pathoc.py
index def6cfcf..b2563988 100644
--- a/pathod/pathoc.py
+++ b/pathod/pathoc.py
@@ -8,15 +8,15 @@ from six.moves import queue
import random
import select
import time
-import threading
import OpenSSL.crypto
import six
from netlib import tcp, certutils, websockets, socks
-from netlib.exceptions import HttpException, TcpDisconnect, TcpTimeout, TlsException, TcpException, \
- NetlibException
-from netlib.http import http1, http2
+from netlib import exceptions
+from netlib.http import http1
+from netlib.http import http2
+from netlib import basethread
from pathod import log, language
@@ -77,7 +77,7 @@ class SSLInfo(object):
return "\n".join(parts)
-class WebsocketFrameReader(threading.Thread):
+class WebsocketFrameReader(basethread.BaseThread):
def __init__(
self,
@@ -88,7 +88,7 @@ class WebsocketFrameReader(threading.Thread):
ws_read_limit,
timeout
):
- threading.Thread.__init__(self)
+ basethread.BaseThread.__init__(self, "WebsocketFrameReader")
self.timeout = timeout
self.ws_read_limit = ws_read_limit
self.logfp = logfp
@@ -129,7 +129,7 @@ class WebsocketFrameReader(threading.Thread):
with self.logger.ctx() as log:
try:
frm = websockets.Frame.from_file(self.rfile)
- except TcpDisconnect:
+ except exceptions.TcpDisconnect:
return
self.frames_queue.put(frm)
log("<< %s" % frm.header.human_readable())
@@ -241,8 +241,8 @@ class Pathoc(tcp.TCPClient):
try:
resp = self.protocol.read_response(self.rfile, treq(method="CONNECT"))
if resp.status_code != 200:
- raise HttpException("Unexpected status code: %s" % resp.status_code)
- except HttpException as e:
+ raise exceptions.HttpException("Unexpected status code: %s" % resp.status_code)
+ except exceptions.HttpException as e:
six.reraise(PathocError, PathocError(
"Proxy CONNECT failed: %s" % repr(e)
))
@@ -280,7 +280,7 @@ class Pathoc(tcp.TCPClient):
connect_reply.msg,
"SOCKS server error"
)
- except (socks.SocksError, TcpDisconnect) as e:
+ except (socks.SocksError, exceptions.TcpDisconnect) as e:
raise PathocError(str(e))
def connect(self, connect_to=None, showssl=False, fp=sys.stdout):
@@ -310,7 +310,7 @@ class Pathoc(tcp.TCPClient):
cipher_list=self.ciphers,
alpn_protos=alpn_protos
)
- except TlsException as v:
+ except exceptions.TlsException as v:
raise PathocError(str(v))
self.sslinfo = SSLInfo(
@@ -406,7 +406,7 @@ class Pathoc(tcp.TCPClient):
Returns Response if we have a non-ignored response.
- May raise a NetlibException
+ May raise a exceptions.NetlibException
"""
logger = log.ConnectionLogger(
self.fp,
@@ -424,10 +424,10 @@ class Pathoc(tcp.TCPClient):
resp = self.protocol.read_response(self.rfile, treq(method=req["method"].encode()))
resp.sslinfo = self.sslinfo
- except HttpException as v:
+ except exceptions.HttpException as v:
lg("Invalid server response: %s" % v)
raise
- except TcpTimeout:
+ except exceptions.TcpTimeout:
if self.ignoretimeout:
lg("Timeout (ignored)")
return None
@@ -451,7 +451,7 @@ class Pathoc(tcp.TCPClient):
Returns Response if we have a non-ignored response.
- May raise a NetlibException
+ May raise a exceptions.NetlibException
"""
if isinstance(r, basestring):
r = language.parse_pathoc(r, self.use_http2).next()
@@ -530,11 +530,11 @@ def main(args): # pragma: no cover
# We consume the queue when we can, so it doesn't build up.
for i_ in p.wait(timeout=0, finish=False):
pass
- except NetlibException:
+ except exceptions.NetlibException:
break
for i_ in p.wait(timeout=0.01, finish=True):
pass
- except TcpException as v:
+ except exceptions.TcpException as v:
print(str(v), file=sys.stderr)
continue
except PathocError as v:
diff --git a/pathod/test.py b/pathod/test.py
index 11462729..3ba541b1 100644
--- a/pathod/test.py
+++ b/pathod/test.py
@@ -1,10 +1,10 @@
from six.moves import cStringIO as StringIO
-import threading
import time
from six.moves import queue
from . import pathod
+from netlib import basethread
class TimeoutError(Exception):
@@ -95,11 +95,10 @@ class Daemon:
self.thread.join()
-class _PaThread(threading.Thread):
+class _PaThread(basethread.BaseThread):
def __init__(self, iface, q, ssl, daemonargs):
- threading.Thread.__init__(self)
- self.name = "PathodThread"
+ basethread.BaseThread.__init__(self, "PathodThread")
self.iface, self.q, self.ssl = iface, q, ssl
self.daemonargs = daemonargs
self.server = None
diff --git a/setup.py b/setup.py
index 050043b3..cd123044 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,6 @@
from setuptools import setup, find_packages
from codecs import open
import os
-import sys
# Based on https://github.com/pypa/sampleproject/blob/master/setup.py
# and https://python-packaging-user-guide.readthedocs.org/
@@ -73,6 +72,7 @@ setup(
"lxml>=3.5.0, <3.7",
"Pillow>=3.2, <3.3",
"passlib>=1.6.5, <1.7",
+ "psutil>=4.2.0, <4.3",
"pyasn1>=0.1.9, <0.2",
"pyOpenSSL>=16.0, <17.0",
"pyparsing>=2.1.3, <2.2",
diff --git a/test/netlib/test_debug.py b/test/netlib/test_debug.py
index d174bb5f..c39d3752 100644
--- a/test/netlib/test_debug.py
+++ b/test/netlib/test_debug.py
@@ -1,6 +1,14 @@
+from __future__ import (absolute_import, print_function, division)
+from six.moves import cStringIO as StringIO
from netlib import debug
+def test_dump_info():
+ cs = StringIO()
+ debug.dump_info(None, None, file=cs)
+ assert cs.getvalue()
+
+
def test_sysinfo():
assert debug.sysinfo()