From d998790c2f12594e6d0edc5a98e908677b60b31f Mon Sep 17 00:00:00 2001 From: Aldo Cortesi Date: Wed, 17 Sep 2014 11:35:14 +1200 Subject: Clean up and clarify StateObject - Flatten the class hierarchy - get_state, load_state, from_state are public - Simplify code - Remove __eq__ and __ne__. This fundamentally changes the semantics of inherited objects in a way that's not part of the core function of the class --- libmproxy/flow.py | 8 ++-- libmproxy/protocol/http.py | 16 +++---- libmproxy/protocol/primitives.py | 18 ++++---- libmproxy/protocol/tcp.py | 33 +++++++++++--- libmproxy/proxy/connection.py | 30 ++++++------- libmproxy/stateobject.py | 96 ++++++++++++---------------------------- libmproxy/web/__init__.py | 4 +- 7 files changed, 92 insertions(+), 113 deletions(-) (limited to 'libmproxy') diff --git a/libmproxy/flow.py b/libmproxy/flow.py index b9095a02..4dc3a272 100644 --- a/libmproxy/flow.py +++ b/libmproxy/flow.py @@ -561,7 +561,7 @@ class FlowMaster(controller.Master): rflow = self.server_playback.next_flow(flow) if not rflow: return None - response = http.HTTPResponse._from_state(rflow.response._get_state()) + response = http.HTTPResponse.from_state(rflow.response.get_state()) response.is_replay = True if self.refresh_server_playback: response.refresh() @@ -740,7 +740,7 @@ class FlowWriter: self.fo = fo def add(self, flow): - d = flow._get_state() + d = flow.get_state() tnetstring.dump(d, self.fo) @@ -766,7 +766,7 @@ class FlowReader: v = ".".join(str(i) for i in data["version"]) raise FlowReadError("Incompatible serialized data version: %s"%v) off = self.fo.tell() - yield handle.protocols[data["conntype"]]["flow"]._from_state(data) + yield handle.protocols[data["conntype"]]["flow"].from_state(data) except ValueError, v: # Error is due to EOF if self.fo.tell() == off and self.fo.read() == '': @@ -782,5 +782,5 @@ class FilteredFlowWriter: def add(self, f): if self.filt and not f.match(self.filt): return - d = f._get_state() + d = f.get_state() 
tnetstring.dump(d, self.fo) diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 46da7a2b..1f3d6fdf 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -85,7 +85,7 @@ class decoded(object): self.o.encode(self.ce) -class HTTPMessage(stateobject.SimpleStateObject): +class HTTPMessage(stateobject.StateObject): """ Base class for HTTPRequest and HTTPResponse """ @@ -275,9 +275,9 @@ class HTTPRequest(HTTPMessage): ) @classmethod - def _from_state(cls, state): + def from_state(cls, state): f = cls(None, None, None, None, None, None, None, None, None, None, None) - f._load_state(state) + f.load_state(state) return f def __repr__(self): @@ -626,9 +626,9 @@ class HTTPResponse(HTTPMessage): ) @classmethod - def _from_state(cls, state): + def from_state(cls, state): f = cls(None, None, None, None, None) - f._load_state(state) + f.load_state(state) return f def __repr__(self): @@ -814,9 +814,9 @@ class HTTPFlow(Flow): ) @classmethod - def _from_state(cls, state): + def from_state(cls, state): f = cls(None, None) - f._load_state(state) + f.load_state(state) return f def __repr__(self): @@ -1311,4 +1311,4 @@ class RequestReplayThread(threading.Thread): self.flow.error = Error(repr(v)) self.channel.ask("error", self.flow) finally: - r.form_out = form_out_backup \ No newline at end of file + r.form_out = form_out_backup diff --git a/libmproxy/protocol/primitives.py b/libmproxy/protocol/primitives.py index a8c5856c..77dc936d 100644 --- a/libmproxy/protocol/primitives.py +++ b/libmproxy/protocol/primitives.py @@ -8,7 +8,7 @@ from ..proxy.connection import ClientConnection, ServerConnection KILL = 0 # const for killed requests -class Error(stateobject.SimpleStateObject): +class Error(stateobject.StateObject): """ An Error. @@ -41,11 +41,11 @@ class Error(stateobject.SimpleStateObject): return self.msg @classmethod - def _from_state(cls, state): + def from_state(cls, state): # the default implementation assumes an empty constructor. 
Override # accordingly. f = cls(None) - f._load_state(state) + f.load_state(state) return f def copy(self): @@ -53,7 +53,7 @@ class Error(stateobject.SimpleStateObject): return c -class Flow(stateobject.SimpleStateObject): +class Flow(stateobject.StateObject): """ A Flow is a collection of objects representing a single transaction. This class is usually subclassed for each protocol, e.g. HTTPFlow. @@ -78,8 +78,8 @@ class Flow(stateobject.SimpleStateObject): conntype=str ) - def _get_state(self): - d = super(Flow, self)._get_state() + def get_state(self): + d = super(Flow, self).get_state() d.update(version=version.IVERSION) return d @@ -101,7 +101,7 @@ class Flow(stateobject.SimpleStateObject): Has this Flow been modified? """ if self._backup: - return self._backup != self._get_state() + return self._backup != self.get_state() else: return False @@ -111,14 +111,14 @@ class Flow(stateobject.SimpleStateObject): call to .revert(). """ if not self._backup: - self._backup = self._get_state() + self._backup = self.get_state() def revert(self): """ Revert to the last backed up state. """ if self._backup: - self._load_state(self._backup) + self.load_state(self._backup) self._backup = None diff --git a/libmproxy/protocol/tcp.py b/libmproxy/protocol/tcp.py index 990c502a..a56bf07b 100644 --- a/libmproxy/protocol/tcp.py +++ b/libmproxy/protocol/tcp.py @@ -1,8 +1,10 @@ from __future__ import absolute_import -import select, socket +import select +import socket from .primitives import ProtocolHandler from netlib.utils import cleanBin + class TCPHandler(ProtocolHandler): """ TCPHandler acts as a generic TCP forwarder. @@ -34,7 +36,9 @@ class TCPHandler(ProtocolHandler): closed = False if src.ssl_established: # Unfortunately, pyOpenSSL lacks a recv_into function. 
- contents = src.rfile.read(1) # We need to read a single byte before .pending() becomes usable + # We need to read a single byte before .pending() + # becomes usable + contents = src.rfile.read(1) contents += src.rfile.read(src.connection.pending()) if not contents: closed = True @@ -56,15 +60,30 @@ class TCPHandler(ProtocolHandler): continue if src.ssl_established or dst.ssl_established: - # if one of the peers is over SSL, we need to send bytes/strings - if not src.ssl_established: # only ssl to dst, i.e. we revc'd into buf but need bytes/string now. + # if one of the peers is over SSL, we need to send + # bytes/strings + if not src.ssl_established: + # only ssl to dst, i.e. we revc'd into buf but need + # bytes/string now. contents = buf[:size].tobytes() - self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(contents)), "debug") + self.c.log( + "%s %s\r\n%s" % ( + direction, dst_str, cleanBin(contents) + ), + "debug" + ) dst.connection.send(contents) else: # socket.socket.send supports raw bytearrays/memoryviews - self.c.log("%s %s\r\n%s" % (direction, dst_str, cleanBin(buf.tobytes())), "debug") + self.c.log( + "%s %s\r\n%s" % ( + direction, + dst_str, + cleanBin(buf.tobytes()) + ), + "debug" + ) dst.connection.send(buf[:size]) except socket.error as e: self.c.log("TCP connection closed unexpectedly.", "debug") - return \ No newline at end of file + return diff --git a/libmproxy/proxy/connection.py b/libmproxy/proxy/connection.py index de8e20d8..e0e94a2b 100644 --- a/libmproxy/proxy/connection.py +++ b/libmproxy/proxy/connection.py @@ -5,7 +5,7 @@ from netlib import tcp, certutils from .. import stateobject, utils -class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): +class ClientConnection(tcp.BaseHandler, stateobject.StateObject): def __init__(self, client_connection, address, server): if client_connection: # Eventually, this object is restored from state. We don't have a connection then. 
tcp.BaseHandler.__init__(self, client_connection, address, server) @@ -36,16 +36,16 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): timestamp_ssl_setup=float ) - def _get_state(self): - d = super(ClientConnection, self)._get_state() + def get_state(self): + d = super(ClientConnection, self).get_state() d.update( address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, clientcert=self.cert.to_pem() if self.clientcert else None ) return d - def _load_state(self, state): - super(ClientConnection, self)._load_state(state) + def load_state(self, state): + super(ClientConnection, self).load_state(state) self.address = tcp.Address(**state["address"]) if state["address"] else None self.clientcert = certutils.SSLCert.from_pem(state["clientcert"]) if state["clientcert"] else None @@ -57,9 +57,9 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): self.wfile.flush() @classmethod - def _from_state(cls, state): + def from_state(cls, state): f = cls(None, tuple(), None) - f._load_state(state) + f.load_state(state) return f def convert_to_ssl(self, *args, **kwargs): @@ -71,7 +71,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): self.timestamp_end = utils.timestamp() -class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): +class ServerConnection(tcp.TCPClient, stateobject.StateObject): def __init__(self, address): tcp.TCPClient.__init__(self, address) @@ -107,8 +107,8 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): sni=str ) - def _get_state(self): - d = super(ServerConnection, self)._get_state() + def get_state(self): + d = super(ServerConnection, self).get_state() d.update( address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, @@ -118,17 +118,17 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): ) return d - def _load_state(self, state): - super(ServerConnection, self)._load_state(state) + def load_state(self, 
state): + super(ServerConnection, self).load_state(state) self.address = tcp.Address(**state["address"]) if state["address"] else None self.source_address = tcp.Address(**state["source_address"]) if state["source_address"] else None self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None @classmethod - def _from_state(cls, state): + def from_state(cls, state): f = cls(tuple()) - f._load_state(state) + f.load_state(state) return f def copy(self): @@ -154,4 +154,4 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): def finish(self): tcp.TCPClient.finish(self) - self.timestamp_end = utils.timestamp() \ No newline at end of file + self.timestamp_end = utils.timestamp() diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py index 9e9d6088..37b72c7e 100644 --- a/libmproxy/stateobject.py +++ b/libmproxy/stateobject.py @@ -2,82 +2,42 @@ from __future__ import absolute_import class StateObject(object): - def _get_state(self): - raise NotImplementedError # pragma: nocover - - def _load_state(self, state): - raise NotImplementedError # pragma: nocover - - @classmethod - def _from_state(cls, state): - raise NotImplementedError # pragma: nocover - # Usually, this function roughly equals to the following code: - # f = cls() - # f._load_state(state) - # return f - - def __eq__(self, other): - try: - return self._get_state() == other._get_state() - except AttributeError: - # we may compare with something that's not a StateObject - return False - - def __ne__(self, other): - return not self.__eq__(other) - - -class SimpleStateObject(StateObject): """ - A StateObject with opionated conventions that tries to keep everything DRY. + An object with serializable state. - Simply put, you agree on a list of attributes and their type. Attributes can - either be primitive types(str, tuple, bool, ...) or StateObject instances - themselves. 
SimpleStateObject uses this information for the default - _get_state(), _from_state(s) and _load_state(s) methods. Overriding - _get_state or _load_state to add custom adjustments is always possible. + State attributes can either be serializable types(str, tuple, bool, ...) + or StateObject instances themselves. """ - - _stateobject_attributes = None # none by default to raise an exception if definition was forgotten - """ - An attribute-name -> class-or-type dict containing all attributes that - should be serialized If the attribute is a class, this class must be a - subclass of StateObject. - """ - - def _get_state(self): - return {attr: self._get_state_attr(attr, cls) - for attr, cls in self._stateobject_attributes.iteritems()} + # An attribute-name -> class-or-type dict containing all attributes that + # should be serialized. If the attribute is a class, it must be a subclass + # of StateObject. + _stateobject_attributes = None def _get_state_attr(self, attr, cls): - """ - helper for _get_state. - returns the value of the given attribute - """ val = getattr(self, attr) - if hasattr(val, "_get_state"): - return val._get_state() + if hasattr(val, "get_state"): + return val.get_state() else: return val - def _load_state(self, state): - for attr, cls in self._stateobject_attributes.iteritems(): - self._load_state_attr(attr, cls, state) - - def _load_state_attr(self, attr, cls, state): - """ - helper for _load_state. - loads the given attribute from the state. 
- """ - if state.get(attr, None) is None: - setattr(self, attr, None) - return + def from_state(self): + raise NotImplementedError - curr = getattr(self, attr) - if hasattr(curr, "_load_state"): - curr._load_state(state[attr]) - elif hasattr(cls, "_from_state"): - setattr(self, attr, cls._from_state(state[attr])) - else: - setattr(self, attr, cls(state[attr])) + def get_state(self): + state = {} + for attr, cls in self._stateobject_attributes.iteritems(): + state[attr] = self._get_state_attr(attr, cls) + return state + def load_state(self, state): + for attr, cls in self._stateobject_attributes.iteritems(): + if state.get(attr, None) is None: + setattr(self, attr, None) + else: + curr = getattr(self, attr) + if hasattr(curr, "load_state"): + curr.load_state(state[attr]) + elif hasattr(cls, "from_state"): + setattr(self, attr, cls.from_state(state[attr])) + else: + setattr(self, attr, cls(state[attr])) diff --git a/libmproxy/web/__init__.py b/libmproxy/web/__init__.py index 044cb0cd..c2597861 100644 --- a/libmproxy/web/__init__.py +++ b/libmproxy/web/__init__.py @@ -1,8 +1,8 @@ - import tornado.ioloop import tornado.httpserver from .. import controller, utils, flow, script, proxy import app +import pprint class Stop(Exception): @@ -81,7 +81,7 @@ class WebMaster(flow.FlowMaster): self.shutdown() def handle_request(self, f): - print f + pprint.pprint(f.get_state()) flow.FlowMaster.handle_request(self, f) if f: f.reply() -- cgit v1.2.3