aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libmproxy/console/__init__.py6
-rw-r--r--libmproxy/console/common.py4
-rw-r--r--libmproxy/controller.py85
-rw-r--r--libmproxy/dump.py22
-rw-r--r--libmproxy/flow.py50
-rw-r--r--libmproxy/proxy.py43
-rw-r--r--test/test_dump.py6
-rw-r--r--test/test_flow.py32
-rw-r--r--test/tservers.py2
-rw-r--r--test/tutils.py15
10 files changed, 158 insertions, 107 deletions
diff --git a/libmproxy/console/__init__.py b/libmproxy/console/__init__.py
index d6c7f5a2..a16cc4dc 100644
--- a/libmproxy/console/__init__.py
+++ b/libmproxy/console/__init__.py
@@ -580,7 +580,7 @@ class ConsoleMaster(flow.FlowMaster):
self.view_flowlist()
- self.server.start_slave(controller.Slave, self.masterq)
+ self.server.start_slave(controller.Slave, controller.Channel(self.masterq))
if self.options.rfile:
ret = self.load_flows(self.options.rfile)
@@ -1002,7 +1002,7 @@ class ConsoleMaster(flow.FlowMaster):
if self.state.intercept and f.match(self.state.intercept) and not f.request.is_replay():
f.intercept()
else:
- r._ack()
+ r.reply()
self.sync_list_view()
self.refresh_flow(f)
@@ -1023,7 +1023,7 @@ class ConsoleMaster(flow.FlowMaster):
# Handlers
def handle_log(self, l):
self.add_event(l.msg)
- l._ack()
+ l.reply()
def handle_error(self, r):
f = flow.FlowMaster.handle_error(self, r)
diff --git a/libmproxy/console/common.py b/libmproxy/console/common.py
index 2da7f802..1cc0b5b9 100644
--- a/libmproxy/console/common.py
+++ b/libmproxy/console/common.py
@@ -184,7 +184,7 @@ def format_flow(f, focus, extended=False, padding=2):
req_timestamp = f.request.timestamp_start,
req_is_replay = f.request.is_replay(),
req_method = f.request.method,
- req_acked = f.request.acked,
+ req_acked = f.request.reply.acked,
req_url = f.request.get_url(),
err_msg = f.error.msg if f.error else None,
@@ -200,7 +200,7 @@ def format_flow(f, focus, extended=False, padding=2):
d.update(dict(
resp_code = f.response.code,
resp_is_replay = f.response.is_replay(),
- resp_acked = f.response.acked,
+ resp_acked = f.response.reply.acked,
resp_clen = contentdesc
))
t = f.response.headers["content-type"]
diff --git a/libmproxy/controller.py b/libmproxy/controller.py
index f38d1edb..c36bb9df 100644
--- a/libmproxy/controller.py
+++ b/libmproxy/controller.py
@@ -17,37 +17,73 @@ import Queue, threading
should_exit = False
-class Msg:
+
+class DummyReply:
+ """
+ A reply object that does nothing. Useful when we need an object to seem
+ like it has a channel, and during testing.
+ """
def __init__(self):
+ self.acked = False
+
+ def __call__(self, msg=False):
+ self.acked = True
+
+
+class Reply:
+ """
+ Messages sent through a channel are decorated with a "reply" attribute.
+ This object is used to respond to the message through the return
+ channel.
+ """
+ def __init__(self, obj):
+ self.obj = obj
self.q = Queue.Queue()
self.acked = False
- def _ack(self, data=False):
+ def __call__(self, msg=False):
if not self.acked:
self.acked = True
- if data is None:
- self.q.put(data)
+ if msg is None:
+ self.q.put(msg)
else:
- self.q.put(data or self)
+ self.q.put(msg or self.obj)
- def _send(self, masterq):
- self.acked = False
- try:
- masterq.put(self, timeout=3)
- while not should_exit: # pragma: no cover
- try:
- g = self.q.get(timeout=0.5)
- except Queue.Empty:
- continue
- return g
- except (Queue.Empty, Queue.Full): # pragma: no cover
- return None
+
+class Channel:
+ def __init__(self, q):
+ self.q = q
+
+ def ask(self, m):
+ """
+ Send a message to the master, and wait for a response.
+ """
+ m.reply = Reply(m)
+ self.q.put(m)
+ while not should_exit:
+ try:
+ # The timeout is here so we can handle a should_exit event.
+ g = m.reply.q.get(timeout=0.5)
+ except Queue.Empty:
+ continue
+ return g
+
+ def tell(self, m):
+ """
+ Send a message to the master, and keep going.
+ """
+ m.reply = None
+ self.q.put(m)
class Slave(threading.Thread):
- def __init__(self, masterq, server):
- self.masterq, self.server = masterq, server
- self.server.set_mqueue(masterq)
+ """
+ Slaves get a channel end-point through which they can send messages to
+ the master.
+ """
+ def __init__(self, channel, server):
+ self.channel, self.server = channel, server
+ self.server.set_channel(channel)
threading.Thread.__init__(self)
def run(self):
@@ -55,6 +91,9 @@ class Slave(threading.Thread):
class Master:
+ """
+ Masters get and respond to messages from slaves.
+ """
def __init__(self, server):
"""
server may be None if no server is needed.
@@ -81,18 +120,18 @@ class Master:
def run(self):
global should_exit
should_exit = False
- self.server.start_slave(Slave, self.masterq)
+ self.server.start_slave(Slave, Channel(self.masterq))
while not should_exit:
self.tick(self.masterq)
self.shutdown()
- def handle(self, msg): # pragma: no cover
+ def handle(self, msg):
c = "handle_" + msg.__class__.__name__.lower()
m = getattr(self, c, None)
if m:
m(msg)
else:
- msg._ack()
+ msg.reply()
def shutdown(self):
global should_exit
diff --git a/libmproxy/dump.py b/libmproxy/dump.py
index 170c701d..3c7eee71 100644
--- a/libmproxy/dump.py
+++ b/libmproxy/dump.py
@@ -150,16 +150,6 @@ class DumpMaster(flow.FlowMaster):
print >> self.outfile, e
self.outfile.flush()
- def handle_log(self, l):
- self.add_event(l.msg)
- l._ack()
-
- def handle_request(self, r):
- f = flow.FlowMaster.handle_request(self, r)
- if f:
- r._ack()
- return f
-
def indent(self, n, t):
l = str(t).strip().split("\n")
return "\n".join(" "*n + i for i in l)
@@ -210,10 +200,20 @@ class DumpMaster(flow.FlowMaster):
self.outfile.flush()
self.state.delete_flow(f)
+ def handle_log(self, l):
+ self.add_event(l.msg)
+ l.reply()
+
+ def handle_request(self, r):
+ f = flow.FlowMaster.handle_request(self, r)
+ if f:
+ r.reply()
+ return f
+
def handle_response(self, msg):
f = flow.FlowMaster.handle_response(self, msg)
if f:
- msg._ack()
+ msg.reply()
self._process_flow(f)
return f
diff --git a/libmproxy/flow.py b/libmproxy/flow.py
index 9238cfbf..0f5fb563 100644
--- a/libmproxy/flow.py
+++ b/libmproxy/flow.py
@@ -196,7 +196,7 @@ class decoded(object):
self.o.encode(self.ce)
-class HTTPMsg(controller.Msg):
+class HTTPMsg:
def get_decoded_content(self):
"""
Returns the decoded content based on the current Content-Encoding header.
@@ -252,6 +252,7 @@ class HTTPMsg(controller.Msg):
return 0
return len(self.content)
+
class Request(HTTPMsg):
"""
An HTTP request.
@@ -289,7 +290,6 @@ class Request(HTTPMsg):
self.timestamp_start = timestamp_start or utils.timestamp()
self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start)
self.close = False
- controller.Msg.__init__(self)
# Have this request's cookies been modified by sticky cookies or auth?
self.stickycookie = False
@@ -396,7 +396,6 @@ class Request(HTTPMsg):
Returns a copy of this object.
"""
c = copy.copy(self)
- c.acked = True
c.headers = self.headers.copy()
return c
@@ -603,7 +602,6 @@ class Response(HTTPMsg):
self.cert = cert
self.timestamp_start = timestamp_start or utils.timestamp()
self.timestamp_end = max(timestamp_end or utils.timestamp(), timestamp_start)
- controller.Msg.__init__(self)
self.replay = False
def _refresh_cookie(self, c, delta):
@@ -708,7 +706,6 @@ class Response(HTTPMsg):
Returns a copy of this object.
"""
c = copy.copy(self)
- c.acked = True
c.headers = self.headers.copy()
return c
@@ -773,7 +770,7 @@ class Response(HTTPMsg):
cookies.append((cookie_name, (cookie_value, cookie_parameters)))
return dict(cookies)
-class ClientDisconnect(controller.Msg):
+class ClientDisconnect:
"""
A client disconnection event.
@@ -782,11 +779,10 @@ class ClientDisconnect(controller.Msg):
client_conn: ClientConnect object.
"""
def __init__(self, client_conn):
- controller.Msg.__init__(self)
self.client_conn = client_conn
-class ClientConnect(controller.Msg):
+class ClientConnect:
"""
A single client connection. Each connection can result in multiple HTTP
Requests.
@@ -807,7 +803,6 @@ class ClientConnect(controller.Msg):
self.close = False
self.requestcount = 0
self.error = None
- controller.Msg.__init__(self)
def __eq__(self, other):
return self._get_state() == other._get_state()
@@ -838,11 +833,10 @@ class ClientConnect(controller.Msg):
Returns a copy of this object.
"""
c = copy.copy(self)
- c.acked = True
return c
-class Error(controller.Msg):
+class Error:
"""
An Error.
@@ -860,7 +854,6 @@ class Error(controller.Msg):
def __init__(self, request, msg, timestamp=None):
self.request, self.msg = request, msg
self.timestamp = timestamp or utils.timestamp()
- controller.Msg.__init__(self)
def _load_state(self, state):
self.msg = state["msg"]
@@ -871,7 +864,6 @@ class Error(controller.Msg):
Returns a copy of this object.
"""
c = copy.copy(self)
- c.acked = True
return c
def _get_state(self):
@@ -1180,10 +1172,11 @@ class Flow:
Kill this request.
"""
self.error = Error(self.request, "Connection killed")
- if self.request and not self.request.acked:
- self.request._ack(None)
- elif self.response and not self.response.acked:
- self.response._ack(None)
+ self.error.reply = controller.DummyReply()
+ if self.request and not self.request.reply.acked:
+ self.request.reply(None)
+ elif self.response and not self.response.reply.acked:
+ self.response.reply(None)
master.handle_error(self.error)
self.intercepting = False
@@ -1199,10 +1192,10 @@ class Flow:
Continue with the flow - called after an intercept().
"""
if self.request:
- if not self.request.acked:
- self.request._ack()
- elif self.response and not self.response.acked:
- self.response._ack()
+ if not self.request.reply.acked:
+ self.request.reply()
+ elif self.response and not self.response.reply.acked:
+ self.response.reply()
self.intercepting = False
def replace(self, pattern, repl, *args, **kwargs):
@@ -1464,7 +1457,7 @@ class FlowMaster(controller.Master):
flow.response = response
if self.refresh_server_playback:
response.refresh()
- flow.request._ack(response)
+ flow.request.reply(response)
if self.server_playback.count() == 0:
self.stop_server_playback()
return True
@@ -1491,10 +1484,13 @@ class FlowMaster(controller.Master):
Loads a flow, and returns a new flow object.
"""
if f.request:
+ f.request.reply = controller.DummyReply()
fr = self.handle_request(f.request)
if f.response:
+ f.response.reply = controller.DummyReply()
self.handle_response(f.response)
if f.error:
+ f.error.reply = controller.DummyReply()
self.handle_error(f.error)
return fr
@@ -1522,7 +1518,7 @@ class FlowMaster(controller.Master):
if self.kill_nonreplay:
f.kill(self)
else:
- f.request._ack()
+ f.request.reply()
def process_new_response(self, f):
if self.stickycookie_state:
@@ -1561,11 +1557,11 @@ class FlowMaster(controller.Master):
def handle_clientconnect(self, cc):
self.run_script_hook("clientconnect", cc)
- cc._ack()
+ cc.reply()
def handle_clientdisconnect(self, r):
self.run_script_hook("clientdisconnect", r)
- r._ack()
+ r.reply()
def handle_error(self, r):
f = self.state.add_error(r)
@@ -1573,7 +1569,7 @@ class FlowMaster(controller.Master):
self.run_script_hook("error", f)
if self.client_playback:
self.client_playback.clear(f)
- r._ack()
+ r.reply()
return f
def handle_request(self, r):
@@ -1596,7 +1592,7 @@ class FlowMaster(controller.Master):
if self.stream:
self.stream.add(f)
else:
- r._ack()
+ r.reply()
return f
def shutdown(self):
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index f2c6db1f..1fbb6d58 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -29,9 +29,8 @@ class ProxyError(Exception):
return "ProxyError(%s, %s)"%(self.code, self.msg)
-class Log(controller.Msg):
+class Log:
def __init__(self, msg):
- controller.Msg.__init__(self)
self.msg = msg
@@ -51,7 +50,7 @@ class ProxyConfig:
class RequestReplayThread(threading.Thread):
def __init__(self, config, flow, masterq):
- self.config, self.flow, self.masterq = config, flow, masterq
+ self.config, self.flow, self.channel = config, flow, controller.Channel(masterq)
threading.Thread.__init__(self)
def run(self):
@@ -66,10 +65,10 @@ class RequestReplayThread(threading.Thread):
response = flow.Response(
self.flow.request, httpversion, code, msg, headers, content, server.cert
)
- response._send(self.masterq)
+ self.channel.ask(response)
except (ProxyError, http.HttpError, tcp.NetLibError), v:
err = flow.Error(self.flow.request, str(v))
- err._send(self.masterq)
+ self.channel.ask(err)
class ServerConnection(tcp.TCPClient):
@@ -128,8 +127,8 @@ class ServerConnectionPool:
class ProxyHandler(tcp.BaseHandler):
- def __init__(self, config, connection, client_address, server, mqueue, server_version):
- self.mqueue, self.server_version = mqueue, server_version
+ def __init__(self, config, connection, client_address, server, channel, server_version):
+ self.channel, self.server_version = channel, server_version
self.config = config
self.server_conn_pool = ServerConnectionPool(config)
self.proxy_connect_state = None
@@ -139,18 +138,18 @@ class ProxyHandler(tcp.BaseHandler):
def handle(self):
cc = flow.ClientConnect(self.client_address)
self.log(cc, "connect")
- cc._send(self.mqueue)
+ self.channel.ask(cc)
while self.handle_request(cc) and not cc.close:
pass
cc.close = True
- cd = flow.ClientDisconnect(cc)
+ cd = flow.ClientDisconnect(cc)
self.log(
cc, "disconnect",
[
"handled %s requests"%cc.requestcount]
)
- cd._send(self.mqueue)
+ self.channel.ask(cd)
def handle_request(self, cc):
try:
@@ -167,14 +166,14 @@ class ProxyHandler(tcp.BaseHandler):
self.log(cc, "Error in wsgi app.", err.split("\n"))
return
else:
- request = request._send(self.mqueue)
+ request = self.channel.ask(request)
if request is None:
return
if isinstance(request, flow.Response):
response = request
request = False
- response = response._send(self.mqueue)
+ response = self.channel.ask(response)
else:
if self.config.reverse_proxy:
scheme, host, port = self.config.reverse_proxy
@@ -192,7 +191,7 @@ class ProxyHandler(tcp.BaseHandler):
request, httpversion, code, msg, headers, content, sc.cert,
sc.rfile.first_byte_timestamp, utils.timestamp()
)
- response = response._send(self.mqueue)
+ response = self.channel.ask(response)
if response is None:
sc.terminate()
if response is None:
@@ -214,7 +213,7 @@ class ProxyHandler(tcp.BaseHandler):
if request:
err = flow.Error(request, cc.error)
- err._send(self.mqueue)
+ self.channel.ask(err)
self.log(
cc, cc.error,
["url: %s"%request.get_url()]
@@ -235,7 +234,7 @@ class ProxyHandler(tcp.BaseHandler):
msg.append(" -> "+i)
msg = "\n".join(msg)
l = Log(msg)
- l._send(self.mqueue)
+ self.channel.ask(l)
def find_cert(self, host, port, sni):
if self.config.certfile:
@@ -438,18 +437,18 @@ class ProxyServer(tcp.TCPServer):
tcp.TCPServer.__init__(self, (address, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
- self.masterq = None
+ self.channel = None
self.apps = AppRegistry()
- def start_slave(self, klass, masterq):
- slave = klass(masterq, self)
+ def start_slave(self, klass, channel):
+ slave = klass(channel, self)
slave.start()
- def set_mqueue(self, q):
- self.masterq = q
+ def set_channel(self, channel):
+ self.channel = channel
def handle_connection(self, request, client_address):
- h = ProxyHandler(self.config, request, client_address, self, self.masterq, self.server_version)
+ h = ProxyHandler(self.config, request, client_address, self, self.channel, self.server_version)
h.handle()
try:
h.finish()
@@ -487,7 +486,7 @@ class DummyServer:
def __init__(self, config):
self.config = config
- def start_slave(self, klass, masterq):
+ def start_slave(self, klass, channel):
pass
def shutdown(self):
diff --git a/test/test_dump.py b/test/test_dump.py
index e1241e29..5d3f9133 100644
--- a/test/test_dump.py
+++ b/test/test_dump.py
@@ -3,6 +3,7 @@ from cStringIO import StringIO
import libpry
from libmproxy import dump, flow, proxy
import tutils
+import mock
def test_strfuncs():
t = tutils.tresp()
@@ -21,6 +22,7 @@ class TestDumpMaster:
req = tutils.treq()
req.content = content
l = proxy.Log("connect")
+ l.reply = mock.MagicMock()
m.handle_log(l)
cc = req.client_conn
cc.connection_error = "error"
@@ -29,7 +31,9 @@ class TestDumpMaster:
m.handle_clientconnect(cc)
m.handle_request(req)
f = m.handle_response(resp)
- m.handle_clientdisconnect(flow.ClientDisconnect(cc))
+ cd = flow.ClientDisconnect(cc)
+ cd.reply = mock.MagicMock()
+ m.handle_clientdisconnect(cd)
return f
def _dummy_cycle(self, n, filt, content, **options):
diff --git a/test/test_flow.py b/test/test_flow.py
index da5b095e..6aa898ad 100644
--- a/test/test_flow.py
+++ b/test/test_flow.py
@@ -223,16 +223,16 @@ class TestFlow:
f = tutils.tflow()
f.request = tutils.treq()
f.intercept()
- assert not f.request.acked
+ assert not f.request.reply.acked
f.kill(fm)
- assert f.request.acked
+ assert f.request.reply.acked
f.intercept()
f.response = tutils.tresp()
f.request = f.response.request
- f.request._ack()
- assert not f.response.acked
+ f.request.reply()
+ assert not f.response.reply.acked
f.kill(fm)
- assert f.response.acked
+ assert f.response.reply.acked
def test_killall(self):
s = flow.State()
@@ -245,25 +245,25 @@ class TestFlow:
fm.handle_request(r)
for i in s.view:
- assert not i.request.acked
+ assert not i.request.reply.acked
s.killall(fm)
for i in s.view:
- assert i.request.acked
+ assert i.request.reply.acked
def test_accept_intercept(self):
f = tutils.tflow()
f.request = tutils.treq()
f.intercept()
- assert not f.request.acked
+ assert not f.request.reply.acked
f.accept_intercept()
- assert f.request.acked
+ assert f.request.reply.acked
f.response = tutils.tresp()
f.request = f.response.request
f.intercept()
- f.request._ack()
- assert not f.response.acked
+ f.request.reply()
+ assert not f.response.reply.acked
f.accept_intercept()
- assert f.response.acked
+ assert f.response.reply.acked
def test_serialization(self):
f = flow.Flow(None)
@@ -562,9 +562,11 @@ class TestFlowMaster:
fm.handle_response(resp)
assert fm.script.ns["log"][-1] == "response"
dc = flow.ClientDisconnect(req.client_conn)
+ dc.reply = controller.DummyReply()
fm.handle_clientdisconnect(dc)
assert fm.script.ns["log"][-1] == "clientdisconnect"
err = flow.Error(f.request, "msg")
+ err.reply = controller.DummyReply()
fm.handle_error(err)
assert fm.script.ns["log"][-1] == "error"
@@ -598,10 +600,12 @@ class TestFlowMaster:
assert not fm.handle_response(rx)
dc = flow.ClientDisconnect(req.client_conn)
+ dc.reply = controller.DummyReply()
req.client_conn.requestcount = 1
fm.handle_clientdisconnect(dc)
err = flow.Error(f.request, "msg")
+ err.reply = controller.DummyReply()
fm.handle_error(err)
fm.load_script(tutils.test_data.path("scripts/a.py"))
@@ -621,7 +625,9 @@ class TestFlowMaster:
fm.tick(q)
assert fm.state.flow_count()
- fm.handle_error(flow.Error(f.request, "error"))
+ err = flow.Error(f.request, "error")
+ err.reply = controller.DummyReply()
+ fm.handle_error(err)
def test_server_playback(self):
controller.should_exit = False
diff --git a/test/tservers.py b/test/tservers.py
index 2966a436..4cbdc942 100644
--- a/test/tservers.py
+++ b/test/tservers.py
@@ -31,7 +31,7 @@ class TestMaster(flow.FlowMaster):
def handle(self, m):
flow.FlowMaster.handle(self, m)
- m._ack()
+ m.reply()
class ProxyThread(threading.Thread):
diff --git a/test/tutils.py b/test/tutils.py
index d5497bae..1a1c8724 100644
--- a/test/tutils.py
+++ b/test/tutils.py
@@ -1,15 +1,18 @@
import os, shutil, tempfile
from contextlib import contextmanager
-from libmproxy import flow, utils
+from libmproxy import flow, utils, controller
from netlib import certutils
-
+import mock
def treq(conn=None):
if not conn:
conn = flow.ClientConnect(("address", 22))
+ conn.reply = controller.DummyReply()
headers = flow.ODictCaseless()
headers["header"] = ["qvalue"]
- return flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content")
+ r = flow.Request(conn, (1, 1), "host", 80, "http", "GET", "/path", headers, "content")
+ r.reply = controller.DummyReply()
+ return r
def tresp(req=None):
@@ -18,7 +21,9 @@ def tresp(req=None):
headers = flow.ODictCaseless()
headers["header_response"] = ["svalue"]
cert = certutils.SSLCert.from_der(file(test_data.path("data/dercert")).read())
- return flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert)
+ resp = flow.Response(req, (1, 1), 200, "message", headers, "content_response", cert)
+ resp.reply = controller.DummyReply()
+ return resp
def tflow():
@@ -37,9 +42,11 @@ def tflow_err():
r = treq()
f = flow.Flow(r)
f.error = flow.Error(r, "error")
+ f.error.reply = controller.DummyReply()
return f
+
@contextmanager
def tmpdir(*args, **kwargs):
orig_workdir = os.getcwd()