aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <Kriechi@users.noreply.github.com>2016-02-08 11:41:30 +0100
committerThomas Kriechbaumer <Kriechi@users.noreply.github.com>2016-02-08 11:41:30 +0100
commitec087a1960bdcfff5d8207a8090f35223b02fd49 (patch)
treeea0f24bb4d8d655a9f982ff5e5cd2c36d5c10f4e
parentcd744592f6dfebf9ba00ce8a35828b49fec1af5c (diff)
parentbdb763d9cff75eec4bb44d23bfc2ef6fa4871bcc (diff)
downloadmitmproxy-ec087a1960bdcfff5d8207a8090f35223b02fd49.tar.gz
mitmproxy-ec087a1960bdcfff5d8207a8090f35223b02fd49.tar.bz2
mitmproxy-ec087a1960bdcfff5d8207a8090f35223b02fd49.zip
Merge pull request #921 from mitmproxy/model-cleanup
Model Cleanup
-rw-r--r--libmproxy/flow_format_compat.py14
-rw-r--r--libmproxy/models/connections.py44
-rw-r--r--libmproxy/models/flow.py19
-rw-r--r--libmproxy/models/http.py132
-rw-r--r--libmproxy/stateobject.py55
-rw-r--r--libmproxy/version.py2
-rw-r--r--libmproxy/web/__init__.py5
-rw-r--r--libmproxy/web/app.py35
-rw-r--r--test/test_flow.py6
-rw-r--r--test/tutils.py16
10 files changed, 146 insertions, 182 deletions
diff --git a/libmproxy/flow_format_compat.py b/libmproxy/flow_format_compat.py
index 2b99b805..5af9b762 100644
--- a/libmproxy/flow_format_compat.py
+++ b/libmproxy/flow_format_compat.py
@@ -21,9 +21,23 @@ def convert_014_015(data):
return data
+def convert_015_016(data):
+ for m in ("request", "response"):
+ if "body" in data[m]:
+ data[m]["content"] = data[m].pop("body")
+ if "httpversion" in data[m]:
+ data[m]["http_version"] = data[m].pop("httpversion")
+ if "msg" in data["response"]:
+ data["response"]["reason"] = data["response"].pop("msg")
+ data["request"].pop("form_out", None)
+ data["version"] = (0, 16)
+ return data
+
+
converters = {
(0, 13): convert_013_014,
(0, 14): convert_014_015,
+ (0, 15): convert_015_016,
}
diff --git a/libmproxy/models/connections.py b/libmproxy/models/connections.py
index a45e1629..d5920256 100644
--- a/libmproxy/models/connections.py
+++ b/libmproxy/models/connections.py
@@ -42,28 +42,14 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
return self.ssl_established
_stateobject_attributes = dict(
+ address=tcp.Address,
+ clientcert=certutils.SSLCert,
ssl_established=bool,
timestamp_start=float,
timestamp_end=float,
timestamp_ssl_setup=float
)
- def get_state(self, short=False):
- d = super(ClientConnection, self).get_state(short)
- d.update(
- address=({
- "address": self.address(),
- "use_ipv6": self.address.use_ipv6} if self.address else {}),
- clientcert=self.cert.to_pem() if self.clientcert else None)
- return d
-
- def load_state(self, state):
- super(ClientConnection, self).load_state(state)
- self.address = tcp.Address(
- **state["address"]) if state["address"] else None
- self.clientcert = certutils.SSLCert.from_pem(
- state["clientcert"]) if state["clientcert"] else None
-
def copy(self):
return copy.copy(self)
@@ -76,7 +62,7 @@ class ClientConnection(tcp.BaseHandler, stateobject.StateObject):
@classmethod
def from_state(cls, state):
f = cls(None, tuple(), None)
- f.load_state(state)
+ f.set_state(state)
return f
def convert_to_ssl(self, *args, **kwargs):
@@ -130,33 +116,11 @@ class ServerConnection(tcp.TCPClient, stateobject.StateObject):
ssl_established=bool,
sni=str
)
- _stateobject_long_attributes = {"cert"}
-
- def get_state(self, short=False):
- d = super(ServerConnection, self).get_state(short)
- d.update(
- address=({"address": self.address(),
- "use_ipv6": self.address.use_ipv6} if self.address else {}),
- source_address=({"address": self.source_address(),
- "use_ipv6": self.source_address.use_ipv6} if self.source_address else None),
- cert=self.cert.to_pem() if self.cert else None
- )
- return d
-
- def load_state(self, state):
- super(ServerConnection, self).load_state(state)
-
- self.address = tcp.Address(
- **state["address"]) if state["address"] else None
- self.source_address = tcp.Address(
- **state["source_address"]) if state["source_address"] else None
- self.cert = certutils.SSLCert.from_pem(
- state["cert"]) if state["cert"] else None
@classmethod
def from_state(cls, state):
f = cls(tuple())
- f.load_state(state)
+ f.set_state(state)
return f
def copy(self):
diff --git a/libmproxy/models/flow.py b/libmproxy/models/flow.py
index b4e8cb88..10255dad 100644
--- a/libmproxy/models/flow.py
+++ b/libmproxy/models/flow.py
@@ -45,7 +45,7 @@ class Error(stateobject.StateObject):
# the default implementation assumes an empty constructor. Override
# accordingly.
f = cls(None)
- f.load_state(state)
+ f.set_state(state)
return f
def copy(self):
@@ -86,16 +86,19 @@ class Flow(stateobject.StateObject):
intercepted=bool
)
- def get_state(self, short=False):
- d = super(Flow, self).get_state(short)
+ def get_state(self):
+ d = super(Flow, self).get_state()
d.update(version=version.IVERSION)
if self._backup and self._backup != d:
- if short:
- d.update(modified=True)
- else:
- d.update(backup=self._backup)
+ d.update(backup=self._backup)
return d
+ def set_state(self, state):
+ state.pop("version")
+ if "backup" in state:
+ self._backup = state.pop("backup")
+ super(Flow, self).set_state(state)
+
def __eq__(self, other):
return self is other
@@ -133,7 +136,7 @@ class Flow(stateobject.StateObject):
Revert to the last backed up state.
"""
if self._backup:
- self.load_state(self._backup)
+ self.set_state(self._backup)
self._backup = None
def kill(self, master):
diff --git a/libmproxy/models/http.py b/libmproxy/models/http.py
index d3919adf..3c024e76 100644
--- a/libmproxy/models/http.py
+++ b/libmproxy/models/http.py
@@ -1,41 +1,20 @@
from __future__ import (absolute_import, print_function, division)
import Cookie
import copy
+import warnings
from email.utils import parsedate_tz, formatdate, mktime_tz
import time
from libmproxy import utils
from netlib import encoding
-from netlib.http import status_codes, Headers, Request, Response, CONTENT_MISSING, decoded
+from netlib.http import status_codes, Headers, Request, Response, decoded
from netlib.tcp import Address
-from .. import version, stateobject
+from .. import version
from .flow import Flow
-from collections import OrderedDict
-
-class MessageMixin(stateobject.StateObject):
- # The restoration order is important currently, e.g. because
- # of .content setting .headers["content-length"] automatically.
- # Using OrderedDict is the short term fix, restoring state should
- # be implemented without side-effects again.
- _stateobject_attributes = OrderedDict(
- http_version=bytes,
- headers=Headers,
- timestamp_start=float,
- timestamp_end=float
- )
- _stateobject_long_attributes = {"body"}
-
- def get_state(self, short=False):
- ret = super(MessageMixin, self).get_state(short)
- if short:
- if self.content:
- ret["contentLength"] = len(self.content)
- elif self.content == CONTENT_MISSING:
- ret["contentLength"] = None
- else:
- ret["contentLength"] = 0
- return ret
+
+
+class MessageMixin(object):
def get_decoded_content(self):
"""
@@ -141,6 +120,9 @@ class HTTPRequest(MessageMixin, Request):
timestamp_start=None,
timestamp_end=None,
form_out=None,
+ is_replay=False,
+ stickycookie=False,
+ stickyauth=False,
):
Request.__init__(
self,
@@ -159,51 +141,26 @@ class HTTPRequest(MessageMixin, Request):
self.form_out = form_out or first_line_format # FIXME remove
# Have this request's cookies been modified by sticky cookies or auth?
- self.stickycookie = False
- self.stickyauth = False
+ self.stickycookie = stickycookie
+ self.stickyauth = stickyauth
# Is this request replayed?
- self.is_replay = False
-
- _stateobject_attributes = MessageMixin._stateobject_attributes.copy()
- _stateobject_attributes.update(
- content=bytes,
- first_line_format=str,
- method=bytes,
- scheme=bytes,
- host=bytes,
- port=int,
- path=bytes,
- form_out=str,
- is_replay=bool
- )
-
- @classmethod
- def from_state(cls, state):
- f = cls(
- None,
- b"",
- None,
- None,
- None,
- None,
- None,
- None,
- None,
- None,
- None)
- f.load_state(state)
- return f
+ self.is_replay = is_replay
+
+ def get_state(self):
+ state = super(HTTPRequest, self).get_state()
+ state.update(
+ stickycookie = self.stickycookie,
+ stickyauth = self.stickyauth,
+ is_replay = self.is_replay,
+ )
+ return state
- @classmethod
- def from_protocol(
- self,
- protocol,
- *args,
- **kwargs
- ):
- req = protocol.read_request(*args, **kwargs)
- return self.wrap(req)
+ def set_state(self, state):
+ self.stickycookie = state.pop("stickycookie")
+ self.stickyauth = state.pop("stickyauth")
+ self.is_replay = state.pop("is_replay")
+ super(HTTPRequest, self).set_state(state)
@classmethod
def wrap(self, request):
@@ -223,6 +180,15 @@ class HTTPRequest(MessageMixin, Request):
)
return req
+ @property
+ def form_out(self):
+ warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
+ return self.first_line_format
+
+ @form_out.setter
+ def form_out(self, value):
+ warnings.warn(".form_out is deprecated, use .first_line_format instead.", DeprecationWarning)
+
def __hash__(self):
return id(self)
@@ -275,6 +241,7 @@ class HTTPResponse(MessageMixin, Response):
content,
timestamp_start=None,
timestamp_end=None,
+ is_replay = False
):
Response.__init__(
self,
@@ -288,32 +255,9 @@ class HTTPResponse(MessageMixin, Response):
)
# Is this request replayed?
- self.is_replay = False
+ self.is_replay = is_replay
self.stream = False
- _stateobject_attributes = MessageMixin._stateobject_attributes.copy()
- _stateobject_attributes.update(
- body=bytes,
- status_code=int,
- msg=bytes
- )
-
- @classmethod
- def from_state(cls, state):
- f = cls(None, None, None, None, None)
- f.load_state(state)
- return f
-
- @classmethod
- def from_protocol(
- self,
- protocol,
- *args,
- **kwargs
- ):
- resp = protocol.read_response(*args, **kwargs)
- return self.wrap(resp)
-
@classmethod
def wrap(self, response):
resp = HTTPResponse(
@@ -424,7 +368,7 @@ class HTTPFlow(Flow):
@classmethod
def from_state(cls, state):
f = cls(None, None)
- f.load_state(state)
+ f.set_state(state)
return f
def __repr__(self):
diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py
index 52a8347f..9600ab09 100644
--- a/libmproxy/stateobject.py
+++ b/libmproxy/stateobject.py
@@ -1,52 +1,51 @@
from __future__ import absolute_import
+from netlib.utils import Serializable
-class StateObject(object):
-
+class StateObject(Serializable):
"""
- An object with serializable state.
+ An object with serializable state.
- State attributes can either be serializable types(str, tuple, bool, ...)
- or StateObject instances themselves.
+ State attributes can either be serializable types(str, tuple, bool, ...)
+ or StateObject instances themselves.
"""
- # An attribute-name -> class-or-type dict containing all attributes that
- # should be serialized. If the attribute is a class, it must implement the
- # StateObject protocol.
- _stateobject_attributes = None
- # A set() of attributes that should be ignored for short state
- _stateobject_long_attributes = frozenset([])
- def from_state(self, state):
- raise NotImplementedError()
+ _stateobject_attributes = None
+ """
+ An attribute-name -> class-or-type dict containing all attributes that
+ should be serialized. If the attribute is a class, it must implement the
+ Serializable protocol.
+ """
- def get_state(self, short=False):
+ def get_state(self):
"""
- Retrieve object state. If short is true, return an abbreviated
- format with long data elided.
+ Retrieve object state.
"""
state = {}
for attr, cls in self._stateobject_attributes.iteritems():
- if short and attr in self._stateobject_long_attributes:
- continue
val = getattr(self, attr)
if hasattr(val, "get_state"):
- state[attr] = val.get_state(short)
+ state[attr] = val.get_state()
else:
state[attr] = val
return state
- def load_state(self, state):
+ def set_state(self, state):
"""
- Load object state from data returned by a get_state call.
+ Load object state from data returned by a get_state call.
"""
+ state = state.copy()
for attr, cls in self._stateobject_attributes.iteritems():
- if state.get(attr, None) is None:
- setattr(self, attr, None)
+ if state.get(attr) is None:
+ setattr(self, attr, state.pop(attr))
else:
curr = getattr(self, attr)
- if hasattr(curr, "load_state"):
- curr.load_state(state[attr])
+ if hasattr(curr, "set_state"):
+ curr.set_state(state.pop(attr))
elif hasattr(cls, "from_state"):
- setattr(self, attr, cls.from_state(state[attr]))
- else:
- setattr(self, attr, cls(state[attr]))
+ obj = cls.from_state(state.pop(attr))
+ setattr(self, attr, obj)
+ else: # primitive types such as int, str, ...
+ setattr(self, attr, cls(state.pop(attr)))
+ if state:
+ raise RuntimeWarning("Unexpected State in __setstate__: {}".format(state))
diff --git a/libmproxy/version.py b/libmproxy/version.py
index 03c2f256..25c56706 100644
--- a/libmproxy/version.py
+++ b/libmproxy/version.py
@@ -1,6 +1,6 @@
from __future__ import (absolute_import, print_function, division)
-IVERSION = (0, 15)
+IVERSION = (0, 16)
VERSION = ".".join(str(i) for i in IVERSION)
MINORVERSION = ".".join(str(i) for i in IVERSION[:2])
NAME = "mitmproxy"
diff --git a/libmproxy/web/__init__.py b/libmproxy/web/__init__.py
index 43fc993d..c48b3d09 100644
--- a/libmproxy/web/__init__.py
+++ b/libmproxy/web/__init__.py
@@ -2,6 +2,7 @@ from __future__ import absolute_import, print_function
import collections
import tornado.ioloop
import tornado.httpserver
+
from .. import controller, flow
from . import app
@@ -20,7 +21,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast(
type="flows",
cmd="add",
- data=f.get_state(short=True)
+ data=app._strip_content(f.get_state())
)
def _update(self, f):
@@ -28,7 +29,7 @@ class WebFlowView(flow.FlowView):
app.ClientConnection.broadcast(
type="flows",
cmd="update",
- data=f.get_state(short=True)
+ data=app._strip_content(f.get_state())
)
def _remove(self, f):
diff --git a/libmproxy/web/app.py b/libmproxy/web/app.py
index 79f76013..958b8669 100644
--- a/libmproxy/web/app.py
+++ b/libmproxy/web/app.py
@@ -4,9 +4,38 @@ import tornado.web
import tornado.websocket
import logging
import json
+
+from netlib.http import CONTENT_MISSING
from .. import version, filt
+def _strip_content(flow_state):
+ """
+ Remove flow message content and cert to save transmission space.
+
+ Args:
+ flow_state: The original flow state. Will be left unmodified
+ """
+ for attr in ("request", "response"):
+ if attr in flow_state:
+ message = flow_state[attr]
+ if message["content"]:
+ message["contentLength"] = len(message["content"])
+ elif message["content"] == CONTENT_MISSING:
+ message["contentLength"] = None
+ else:
+ message["contentLength"] = 0
+ del message["content"]
+
+ if "backup" in flow_state:
+ del flow_state["backup"]
+ flow_state["modified"] = True
+
+ flow_state.get("server_conn", {}).pop("cert", None)
+
+ return flow_state
+
+
class APIError(tornado.web.HTTPError):
pass
@@ -100,7 +129,7 @@ class Flows(RequestHandler):
def get(self):
self.write(dict(
- data=[f.get_state(short=True) for f in self.state.flows]
+ data=[_strip_content(f.get_state()) for f in self.state.flows]
))
@@ -141,7 +170,7 @@ class FlowHandler(RequestHandler):
elif k == "port":
request.port = int(v)
elif k == "headers":
- request.headers.load_state(v)
+ request.headers.set_state(v)
else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
@@ -155,7 +184,7 @@ class FlowHandler(RequestHandler):
elif k == "http_version":
response.http_version = str(v)
elif k == "headers":
- response.headers.load_state(v)
+ response.headers.set_state(v)
else:
print "Warning: Unknown update {}.{}: {}".format(a, k, v)
else:
diff --git a/test/test_flow.py b/test/test_flow.py
index 68316f2a..51b88fff 100644
--- a/test/test_flow.py
+++ b/test/test_flow.py
@@ -422,7 +422,7 @@ class TestFlow(object):
assert not f == f2
f2.error = Error("e2")
assert not f == f2
- f.load_state(f2.get_state())
+ f.set_state(f2.get_state())
assert f.get_state() == f2.get_state()
def test_kill(self):
@@ -1204,7 +1204,7 @@ class TestError:
e2 = Error("bar")
assert not e == e2
- e.load_state(e2.get_state())
+ e.set_state(e2.get_state())
assert e.get_state() == e2.get_state()
e3 = e.copy()
@@ -1224,7 +1224,7 @@ class TestClientConnection:
assert not c == c2
c2.timestamp_start = 42
- c.load_state(c2.get_state())
+ c.set_state(c2.get_state())
assert c.timestamp_start == 42
c3 = c.copy()
diff --git a/test/tutils.py b/test/tutils.py
index 5bd91307..2ce0884d 100644
--- a/test/tutils.py
+++ b/test/tutils.py
@@ -76,7 +76,11 @@ def tclient_conn():
"""
c = ClientConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
- clientcert=None
+ clientcert=None,
+ ssl_established=False,
+ timestamp_start=1,
+ timestamp_ssl_setup=2,
+ timestamp_end=3,
))
c.reply = controller.DummyReply()
return c
@@ -88,9 +92,15 @@ def tserver_conn():
"""
c = ServerConnection.from_state(dict(
address=dict(address=("address", 22), use_ipv6=True),
- state=[],
source_address=dict(address=("address", 22), use_ipv6=True),
- cert=None
+ cert=None,
+ timestamp_start=1,
+ timestamp_tcp_setup=2,
+ timestamp_ssl_setup=3,
+ timestamp_end=4,
+ ssl_established=False,
+ sni="address",
+ via=None
))
c.reply = controller.DummyReply()
return c