diff options
-rw-r--r-- | libmproxy/controller.py | 10 | ||||
-rw-r--r-- | libmproxy/flow.py | 45 | ||||
-rw-r--r-- | libmproxy/proxy.py | 46 |
3 files changed, 46 insertions, 55 deletions
diff --git a/libmproxy/controller.py b/libmproxy/controller.py index c36bb9df..849d998b 100644 --- a/libmproxy/controller.py +++ b/libmproxy/controller.py @@ -56,13 +56,14 @@ class Channel: def ask(self, m): """ - Send a message to the master, and wait for a response. + Decorate a message with a reply attribute, and send it to the + master. then wait for a response. """ m.reply = Reply(m) self.q.put(m) while not should_exit: try: - # The timeout is here so we can handle a should_exit event. + # The timeout is here so we can handle a should_exit event. g = m.reply.q.get(timeout=0.5) except Queue.Empty: continue @@ -70,9 +71,10 @@ class Channel: def tell(self, m): """ - Send a message to the master, and keep going. + Decorate a message with a dummy reply attribute, send it to the + master, then return immediately. """ - m.reply = None + m.reply = DummyReply() self.q.put(m) diff --git a/libmproxy/flow.py b/libmproxy/flow.py index 0f5fb563..883c7235 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -196,7 +196,15 @@ class decoded(object): self.o.encode(self.ce) -class HTTPMsg: +class StateObject: + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: + return False + + +class HTTPMsg(StateObject): def get_decoded_content(self): """ Returns the decoded content based on the current Content-Encoding header. @@ -388,13 +396,7 @@ class Request(HTTPMsg): def __hash__(self): return id(self) - def __eq__(self, other): - return self._get_state() == other._get_state() - def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) c.headers = self.headers.copy() return c @@ -698,13 +700,7 @@ class Response(HTTPMsg): state["timestamp_end"], ) - def __eq__(self, other): - return self._get_state() == other._get_state() - def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) c.headers = self.headers.copy() return c @@ -782,7 +778,7 @@ class ClientDisconnect: self.client_conn = client_conn -class ClientConnect: +class ClientConnect(StateObject): """ A single client connection. Each connection can result in multiple HTTP Requests. @@ -804,9 +800,6 @@ class ClientConnect: self.requestcount = 0 self.error = None - def __eq__(self, other): - return self._get_state() == other._get_state() - def _load_state(self, state): self.close = True self.error = state["error"] @@ -829,14 +822,10 @@ class ClientConnect: return None def copy(self): - """ - Returns a copy of this object. - """ - c = copy.copy(self) - return c + return copy.copy(self) -class Error: +class Error(StateObject): """ An Error. @@ -860,9 +849,6 @@ class Error: self.timestamp = state["timestamp"] def copy(self): - """ - Returns a copy of this object. - """ c = copy.copy(self) return c @@ -880,9 +866,6 @@ class Error: state["timestamp"], ) - def __eq__(self, other): - return self._get_state() == other._get_state() - def replace(self, pattern, repl, *args, **kwargs): """ Replaces a regular expression pattern with repl in both the headers @@ -1174,9 +1157,9 @@ class Flow: self.error = Error(self.request, "Connection killed") self.error.reply = controller.DummyReply() if self.request and not self.request.reply.acked: - self.request.reply(None) + self.request.reply(proxy.KILL) elif self.response and not self.response.reply.acked: - self.response.reply(None) + self.response.reply(proxy.KILL) master.handle_error(self.error) self.intercepting = False diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 1fbb6d58..6d476c7b 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -20,6 +20,8 @@ from netlib import odict, tcp, http, wsgi, certutils, http_status import utils, flow, version, platform, controller import authentication +KILL = 0 + class ProxyError(Exception): def __init__(self, code, msg, headers=None): @@ -149,7 +151,7 @@ class ProxyHandler(tcp.BaseHandler): [ "handled %s requests"%cc.requestcount] ) - self.channel.ask(cd) + self.channel.tell(cd) def handle_request(self, cc): try: @@ -166,15 +168,15 @@ class ProxyHandler(tcp.BaseHandler): self.log(cc, "Error in wsgi app.", err.split("\n")) return else: - request = self.channel.ask(request) - if request is None: + request_reply = self.channel.ask(request) + if request_reply == KILL: return - - if isinstance(request, flow.Response): - response = request + elif isinstance(request_reply, flow.Response): request = False - response = self.channel.ask(response) + response = request_reply + response_reply = self.channel.ask(response) else: + request = request_reply if self.config.reverse_proxy: scheme, host, port = self.config.reverse_proxy else: @@ -191,20 +193,24 @@ class ProxyHandler(tcp.BaseHandler): request, httpversion, code, msg, headers, content, sc.cert, sc.rfile.first_byte_timestamp, utils.timestamp() ) - response = self.channel.ask(response) - if response is None: + response_reply = self.channel.ask(response) + # Not replying to the server invalidates the server connection, so we terminate. + if response_reply == KILL: sc.terminate() - if response is None: - return - self.send_response(response) - if request and http.request_connection_close(request.httpversion, request.headers): - return - # We could keep the client connection when the server - # connection needs to go away. However, we want to mimic - # behaviour as closely as possible to the client, so we - # disconnect. - if http.response_connection_close(response.httpversion, response.headers): + + if response_reply == KILL: return + else: + response = response_reply + self.send_response(response) + if request and http.request_connection_close(request.httpversion, request.headers): + return + # We could keep the client connection when the server + # connection needs to go away. However, we want to mimic + # behaviour as closely as possible to the client, so we + # disconnect. + if http.response_connection_close(response.httpversion, response.headers): + return except (IOError, ProxyError, http.HttpError, tcp.NetLibDisconnect), e: if hasattr(e, "code"): cc.error = "%s: %s"%(e.code, e.msg) @@ -234,7 +240,7 @@ class ProxyHandler(tcp.BaseHandler): msg.append(" -> "+i) msg = "\n".join(msg) l = Log(msg) - self.channel.ask(l) + self.channel.tell(l) def find_cert(self, host, port, sni): if self.config.certfile: |