From bdb763d9cff75eec4bb44d23bfc2ef6fa4871bcc Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Mon, 8 Feb 2016 04:19:25 +0100 Subject: make stateobject simpler and stricter --- libmproxy/models/connections.py | 43 +++---------------------- libmproxy/models/flow.py | 10 ++++-- libmproxy/models/http.py | 71 ++++++++++++++++++----------------------- 3 files changed, 43 insertions(+), 81 deletions(-) (limited to 'libmproxy/models') diff --git a/libmproxy/models/connections.py b/libmproxy/models/connections.py index 1d7c980e..d5920256 100644 --- a/libmproxy/models/connections.py +++ b/libmproxy/models/connections.py @@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): return self.ssl_established _stateobject_attributes = dict( + address=tcp.Address, + clientcert=certutils.SSLCert, ssl_established=bool, timestamp_start=float, timestamp_end=float, timestamp_ssl_setup=float ) - def get_state(self): - d = super(ClientConnection, self).get_state() - d.update( - address=({ - "address": self.address(), - "use_ipv6": self.address.use_ipv6} if self.address else {}), - clientcert=self.cert.to_pem() if self.clientcert else None) - return d - - def load_state(self, state): - super(ClientConnection, self).load_state(state) - self.address = tcp.Address( - **state["address"]) if state["address"] else None - self.clientcert = certutils.SSLCert.from_pem( - state["clientcert"]) if state["clientcert"] else None - def copy(self): return copy.copy(self) @@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject): @classmethod def from_state(cls, state): f = cls(None, tuple(), None) - f.load_state(state) + f.set_state(state) return f def convert_to_ssl(self, *args, **kwargs): @@ -131,31 +117,10 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject): sni=str ) - def get_state(self): - d = super(ServerConnection, self).get_state() - d.update( - address=({"address": self.address(), - "use_ipv6": self.address.use_ipv6} if self.address else {}), - source_address=({"address": self.source_address(), - "use_ipv6": self.source_address.use_ipv6} if self.source_address else None), - cert=self.cert.to_pem() if self.cert else None - ) - return d - - def load_state(self, state): - super(ServerConnection, self).load_state(state) - - self.address = tcp.Address( - **state["address"]) if state["address"] else None - self.source_address = tcp.Address( - **state["source_address"]) if state["source_address"] else None - self.cert = certutils.SSLCert.from_pem( - state["cert"]) if state["cert"] else None - @classmethod def from_state(cls, state): f = cls(tuple()) - f.load_state(state) + f.set_state(state) return f def copy(self): diff --git a/libmproxy/models/flow.py b/libmproxy/models/flow.py index 0ba374c9..10255dad 100644 --- a/libmproxy/models/flow.py +++ b/libmproxy/models/flow.py @@ -45,7 +45,7 @@ class Error(stateobject.StateObject): # the default implementation assumes an empty constructor. Override # accordingly. f = cls(None) - f.load_state(state) + f.set_state(state) return f def copy(self): @@ -93,6 +93,12 @@ class Flow(stateobject.StateObject): d.update(backup=self._backup) return d + def set_state(self, state): + state.pop("version") + if "backup" in state: + self._backup = state.pop("backup") + super(Flow, self).set_state(state) + def __eq__(self, other): return self is other @@ -130,7 +136,7 @@ class Flow(stateobject.StateObject): Revert to the last backed up state. """ if self._backup: - self.load_state(self._backup) + self.set_state(self._backup) self._backup = None def kill(self, master): diff --git a/libmproxy/models/http.py b/libmproxy/models/http.py index 730b007d..3c024e76 100644 --- a/libmproxy/models/http.py +++ b/libmproxy/models/http.py @@ -1,6 +1,7 @@ from __future__ import (absolute_import, print_function, division) import Cookie import copy +import warnings from email.utils import parsedate_tz, formatdate, mktime_tz import time @@ -8,28 +9,12 @@ from libmproxy import utils from netlib import encoding from netlib.http import status_codes, Headers, Request, Response, decoded from netlib.tcp import Address -from .. import version, stateobject +from .. import version from .flow import Flow -class MessageMixin(stateobject.StateObject): - - def get_state(self): - state = vars(self.data).copy() - state["headers"] = state["headers"].get_state() - return state - - def load_state(self, state): - for k, v in state.items(): - if k == "headers": - v = Headers.from_state(v) - setattr(self.data, k, v) - - @classmethod - def from_state(cls, state): - state["headers"] = Headers.from_state(state["headers"]) - return cls(**state) +class MessageMixin(object): def get_decoded_content(self): """ @@ -136,6 +121,8 @@ class HTTPRequest(MessageMixin, Request): timestamp_end=None, form_out=None, is_replay=False, + stickycookie=False, + stickyauth=False, ): Request.__init__( self, @@ -154,21 +141,26 @@ class HTTPRequest(MessageMixin, Request): self.form_out = form_out or first_line_format # FIXME remove # Have this request's cookies been modified by sticky cookies or auth? - self.stickycookie = False - self.stickyauth = False + self.stickycookie = stickycookie + self.stickyauth = stickyauth # Is this request replayed? self.is_replay = is_replay - @classmethod - def from_protocol( - self, - protocol, - *args, - **kwargs - ): - req = protocol.read_request(*args, **kwargs) - return self.wrap(req) + def get_state(self): + state = super(HTTPRequest, self).get_state() + state.update( + stickycookie = self.stickycookie, + stickyauth = self.stickyauth, + is_replay = self.is_replay, + ) + return state + + def set_state(self, state): + self.stickycookie = state.pop("stickycookie") + self.stickyauth = state.pop("stickyauth") + self.is_replay = state.pop("is_replay") + super(HTTPRequest, self).set_state(state) @classmethod def wrap(self, request): @@ -188,6 +180,15 @@ class HTTPRequest(MessageMixin, Request): ) return req + @property + def form_out(self): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + return self.first_line_format + + @form_out.setter + def form_out(self, value): + warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning) + def __hash__(self): return id(self) @@ -257,16 +258,6 @@ class HTTPResponse(MessageMixin, Response): self.is_replay = is_replay self.stream = False - @classmethod - def from_protocol( - self, - protocol, - *args, - **kwargs - ): - resp = protocol.read_response(*args, **kwargs) - return self.wrap(resp) - @classmethod def wrap(self, response): resp = HTTPResponse( @@ -377,7 +368,7 @@ class HTTPFlow(Flow): @classmethod def from_state(cls, state): f = cls(None, None) - f.load_state(state) + f.set_state(state) return f def __repr__(self): -- cgit v1.2.3