From 0299bb5b2e4870363ba0c402c6cf15722ca0ee0f Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Thu, 9 Feb 2017 11:56:38 +0100 Subject: eventsequence: coverage++ --- test/mitmproxy/data/addonscripts/recorder.py | 4 +- test/mitmproxy/test_eventsequence.py | 138 +++++++++++---------------- 2 files changed, 59 insertions(+), 83 deletions(-) (limited to 'test') diff --git a/test/mitmproxy/data/addonscripts/recorder.py b/test/mitmproxy/data/addonscripts/recorder.py index 5be88e5c..6b9b6ea8 100644 --- a/test/mitmproxy/data/addonscripts/recorder.py +++ b/test/mitmproxy/data/addonscripts/recorder.py @@ -1,5 +1,5 @@ from mitmproxy import controller -from mitmproxy import events +from mitmproxy import eventsequence from mitmproxy import ctx import sys @@ -11,7 +11,7 @@ class CallLogger: self.name = name def __getattr__(self, attr): - if attr in events.Events: + if attr in eventsequence.Events: def prox(*args, **kwargs): lg = (self.name, attr, args, kwargs) if attr != "log": diff --git a/test/mitmproxy/test_eventsequence.py b/test/mitmproxy/test_eventsequence.py index 262df4b0..6e254225 100644 --- a/test/mitmproxy/test_eventsequence.py +++ b/test/mitmproxy/test_eventsequence.py @@ -1,81 +1,57 @@ -from mitmproxy import events -import contextlib -from . import tservers - - -class Eventer: - def __init__(self, **handlers): - self.failure = None - self.called = [] - self.handlers = handlers - for i in events.Events - {"tick"}: - def mkprox(): - evt = i - - def prox(*args, **kwargs): - self.called.append(evt) - if evt in self.handlers: - try: - handlers[evt](*args, **kwargs) - except AssertionError as e: - self.failure = e - return prox - setattr(self, i, mkprox()) - - def fail(self): - pass - - -class SequenceTester: - @contextlib.contextmanager - def addon(self, addon): - self.master.addons.add(addon) - yield - self.master.addons.remove(addon) - if addon.failure: - raise addon.failure - - -class TestBasic(tservers.HTTPProxyTest, SequenceTester): - ssl = True - - def test_requestheaders(self): - - def hdrs(f): - assert f.request.headers - assert not f.request.content - - def req(f): - assert f.request.headers - assert f.request.content - - with self.addon(Eventer(requestheaders=hdrs, request=req)): - p = self.pathoc() - with p.connect(): - assert p.request("get:'/p/200':b@10").status_code == 200 - - def test_100_continue_fail(self): - e = Eventer() - with self.addon(e): - p = self.pathoc() - with p.connect(): - p.request( - """ - get:'/p/200' - h'expect'='100-continue' - h'content-length'='1000' - da - """ - ) - assert "requestheaders" in e.called - assert "responseheaders" not in e.called - - def test_connect(self): - e = Eventer() - with self.addon(e): - p = self.pathoc() - with p.connect(): - p.request("get:'/p/200:b@1'") - assert "http_connect" in e.called - assert e.called.count("requestheaders") == 1 - assert e.called.count("request") == 1 +import pytest + +from mitmproxy import eventsequence +from mitmproxy.test import tflow + + +@pytest.mark.parametrize("resp, err", [ + (False, False), + (True, False), + (False, True), + (True, True), +]) +def test_http_flow(resp, err): + f = tflow.tflow(resp=resp, err=err) + i = eventsequence.iterate(f) + assert next(i) == ("requestheaders", f) + assert next(i) == ("request", f) + if resp: + assert next(i) == ("responseheaders", f) + assert next(i) == ("response", f) + if err: + assert next(i) == ("error", f) + + +@pytest.mark.parametrize("err", [False, True]) +def test_websocket_flow(err): + f = tflow.twebsocketflow(err=err) + i = eventsequence.iterate(f) + assert next(i) == ("websocket_start", f) + assert len(f.messages) == 0 + assert next(i) == ("websocket_message", f) + assert len(f.messages) == 1 + assert next(i) == ("websocket_message", f) + assert len(f.messages) == 2 + if err: + assert next(i) == ("websocket_error", f) + assert next(i) == ("websocket_end", f) + + +@pytest.mark.parametrize("err", [False, True]) +def test_tcp_flow(err): + f = tflow.ttcpflow(err=err) + i = eventsequence.iterate(f) + assert next(i) == ("tcp_start", f) + assert len(f.messages) == 0 + assert next(i) == ("tcp_message", f) + assert len(f.messages) == 1 + assert next(i) == ("tcp_message", f) + assert len(f.messages) == 2 + if err: + assert next(i) == ("tcp_error", f) + assert next(i) == ("tcp_end", f) + + +def test_invalid(): + with pytest.raises(ValueError): + next(eventsequence.iterate(42)) -- cgit v1.2.3