From 6ce1470631e843bef926d9fee5c2ad1f359dc0ac Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Fri, 31 Jan 2014 01:06:35 +0100 Subject: move StateObject back into libmproxy --- libmproxy/protocol/__init__.py | 2 +- libmproxy/protocol/http.py | 4 +-- libmproxy/proxy.py | 40 ++++++++++++++++++---- libmproxy/stateobject.py | 75 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 112 insertions(+), 9 deletions(-) create mode 100644 libmproxy/stateobject.py (limited to 'libmproxy') diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index 5e11e750..ae0d99a6 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -26,7 +26,7 @@ class ProtocolHandler(object): This method gets called should there be an uncaught exception during the connection. This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode. """ - raise NotImplementedError + raise error from . import http, tcp diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index a33962a6..11735ec0 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -1,10 +1,10 @@ import Cookie, urllib, urlparse, time, copy from email.utils import parsedate_tz, formatdate, mktime_tz import netlib.utils -from netlib import http, tcp, http_status, stateobject, odict +from netlib import http, tcp, http_status, odict from netlib.odict import ODict, ODictCaseless from . import ProtocolHandler, ConnectionTypeChange, KILL -from .. import encoding, utils, version, filt, controller +from .. import encoding, utils, version, filt, controller, stateobject from ..proxy import ProxyError, ClientConnection, ServerConnection HDR_FORM_URLENCODED = "application/x-www-form-urlencoded" diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index 6d5dd236..82d7ecef 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -1,7 +1,7 @@ import os, socket, time, threading from OpenSSL import SSL -from netlib import tcp, http, certutils, http_auth, stateobject -import utils, version, platform, controller +from netlib import tcp, http, certutils, http_auth +import utils, version, platform, controller, stateobject TRANSPARENT_SSL_PORTS = [443, 8443] @@ -36,7 +36,7 @@ class ProxyConfig: class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): def __init__(self, client_connection, address, server): - if client_connection: # Eventually, this object is restored from state + if client_connection: # Eventually, this object is restored from state. We don't have a connection then. tcp.BaseHandler.__init__(self, client_connection, address, server) else: self.address = None @@ -49,11 +49,22 @@ class ClientConnection(tcp.BaseHandler, stateobject.SimpleStateObject): _stateobject_attributes = dict( timestamp_start=float, timestamp_end=float, - timestamp_ssl_setup=float, - address=tcp.Address, - clientcert=certutils.SSLCert + timestamp_ssl_setup=float ) + def _get_state(self): + d = super(ClientConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + clientcert=self.cert.to_pem() if self.clientcert else None + ) + return d + + def _load_state(self, state): + super(ClientConnection, self)._load_state(state) + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.clientcert = certutils.SSLCert.from_pem(state["clientcert"]) if state["clientcert"] else None + @classmethod def _from_state(cls, state): f = cls(None, None, None) @@ -90,6 +101,23 @@ class ServerConnection(tcp.TCPClient, stateobject.SimpleStateObject): cert=certutils.SSLCert ) + def _get_state(self): + d = super(ServerConnection, self)._get_state() + d.update( + address={"address": self.address(), "use_ipv6": self.address.use_ipv6}, + source_address= {"address": self.source_address(), + "use_ipv6": self.source_address.use_ipv6} if self.source_address else None, + cert=self.cert.to_pem() if self.cert else None + ) + return d + + def _load_state(self, state): + super(ServerConnection, self)._load_state(state) + + self.address = tcp.Address(**state["address"]) if state["address"] else None + self.source_address = tcp.Address(**state["source_address"]) if state["source_address"] else None + self.cert = certutils.SSLCert.from_pem(state["cert"]) if state["cert"] else None + @classmethod def _from_state(cls, state): f = cls(None) diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py new file mode 100644 index 00000000..ef8879b8 --- /dev/null +++ b/libmproxy/stateobject.py @@ -0,0 +1,75 @@ +class StateObject: + def _get_state(self): + raise NotImplementedError + + def _load_state(self, state): + raise NotImplementedError + + @classmethod + def _from_state(cls, state): + raise NotImplementedError + + def __eq__(self, other): + try: + return self._get_state() == other._get_state() + except AttributeError: # we may compare with something that's not a StateObject + return False + + +class SimpleStateObject(StateObject): + """ + A StateObject with opionated conventions that tries to keep everything DRY. + + Simply put, you agree on a list of attributes and their type. + Attributes can either be primitive types(str, tuple, bool, ...) or StateObject instances themselves. + SimpleStateObject uses this information for the default _get_state(), _from_state(s) and _load_state(s) methods. + Overriding _get_state or _load_state to add custom adjustments is always possible. + """ + + _stateobject_attributes = None # none by default to raise an exception if definition was forgotten + """ + An attribute-name -> class-or-type dict containing all attributes that should be serialized + If the attribute is a class, this class must be a subclass of StateObject. + """ + + def _get_state(self): + return {attr: self._get_state_attr(attr, cls) + for attr, cls in self._stateobject_attributes.iteritems()} + + def _get_state_attr(self, attr, cls): + """ + helper for _get_state. + returns the value of the given attribute + """ + val = getattr(self, attr) + if hasattr(val, "_get_state"): + return val._get_state() + else: + return val + + def _load_state(self, state): + for attr, cls in self._stateobject_attributes.iteritems(): + self._load_state_attr(attr, cls, state) + + def _load_state_attr(self, attr, cls, state): + """ + helper for _load_state. + loads the given attribute from the state. + """ + if state[attr] is None: + setattr(self, attr, None) + return + + curr = getattr(self, attr) + if hasattr(curr, "_load_state"): + curr._load_state(state[attr]) + elif hasattr(cls, "_from_state"): + setattr(self, attr, cls._from_state(state[attr])) + else: + setattr(self, attr, cls(state[attr])) + + @classmethod + def _from_state(cls, state): + f = cls() # the default implementation assumes an empty constructor. Override accordingly. + f._load_state(state) + return f \ No newline at end of file -- cgit v1.2.3