aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_protocol_http2.py
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2016-01-26 13:15:20 +0100
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2016-02-04 09:52:27 +0100
commit276817e40e99dbb2ddc7638839bd74e944fd704e (patch)
tree9599481a3c51a4602c11f32e55446e499eac8cf2 /test/test_protocol_http2.py
parent187691e65bf4a18de3567d6d801d78aa721b9fa5 (diff)
downloadmitmproxy-276817e40e99dbb2ddc7638839bd74e944fd704e.tar.gz
mitmproxy-276817e40e99dbb2ddc7638839bd74e944fd704e.tar.bz2
mitmproxy-276817e40e99dbb2ddc7638839bd74e944fd704e.zip
refactor http2 tests
Diffstat (limited to 'test/test_protocol_http2.py')
-rw-r--r--test/test_protocol_http2.py234
1 files changed, 149 insertions, 85 deletions
diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py
index 17687b45..b42b86cb 100644
--- a/test/test_protocol_http2.py
+++ b/test/test_protocol_http2.py
@@ -4,8 +4,16 @@ import inspect
import socket
import OpenSSL
import pytest
+import traceback
+import os
+import tempfile
+
from io import BytesIO
+from libmproxy.proxy.config import ProxyConfig
+from libmproxy.proxy.server import ProxyServer
+from libmproxy.cmdline import APP_HOST, APP_PORT
+
import logging
logging.getLogger("hyper.packages.hpack.hpack").setLevel(logging.WARNING)
logging.getLogger("requests.packages.urllib3.connectionpool").setLevel(logging.WARNING)
@@ -18,6 +26,7 @@ import netlib
from netlib import tservers as netlib_tservers
import h2
+from hyperframe.frame import Frame
from libmproxy import utils
from . import tservers
@@ -26,8 +35,7 @@ requires_alpn = pytest.mark.skipif(
not OpenSSL._util.lib.Cryptography_HAS_ALPN,
reason="requires OpenSSL with ALPN support")
-
-class SimpleHttp2Server(netlib_tservers.ServerTestBase):
+class _Http2ServerBase(netlib_tservers.ServerTestBase):
ssl = dict(alpn_select=b'h2')
class handler(netlib.tcp.BaseHandler):
@@ -41,78 +49,56 @@ class SimpleHttp2Server(netlib_tservers.ServerTestBase):
self.wfile.flush()
while True:
- events = h2_conn.receive_data(utils.http2_read_frame(self.rfile))
+ raw_frame = utils.http2_read_frame(self.rfile)
+ events = h2_conn.receive_data(raw_frame)
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
for event in events:
- if isinstance(event, h2.events.RequestReceived):
- h2_conn.send_headers(1, [
- (':status', '200'),
- ('foo', 'bar'),
- ])
- h2_conn.send_data(1, b'foobar')
- h2_conn.end_stream(1)
- self.wfile.write(h2_conn.data_to_send())
- self.wfile.flush()
- elif isinstance(event, h2.events.ConnectionTerminated):
- return
-
-
-class PushHttp2Server(netlib_tservers.ServerTestBase):
- ssl = dict(alpn_select=b'h2')
-
- class handler(netlib.tcp.BaseHandler):
- def handle(self):
- h2_conn = h2.connection.H2Connection(client_side=False)
-
- preamble = self.rfile.read(24)
- h2_conn.initiate_connection()
- h2_conn.receive_data(preamble)
- self.wfile.write(h2_conn.data_to_send())
- self.wfile.flush()
-
- while True:
- events = h2_conn.receive_data(utils.http2_read_frame(self.rfile))
- self.wfile.write(h2_conn.data_to_send())
- self.wfile.flush()
-
- for event in events:
- if isinstance(event, h2.events.RequestReceived):
- h2_conn.send_headers(1, [(':status', '200')])
- h2_conn.push_stream(1, 2, [
- (':authority', "127.0.0.1:%s" % self.address.port),
- (':method', 'GET'),
- (':scheme', 'https'),
- (':path', '/pushed_stream_foo'),
- ('foo', 'bar')
- ])
- h2_conn.push_stream(1, 4, [
- (':authority', "127.0.0.1:%s" % self.address.port),
- (':method', 'GET'),
- (':scheme', 'https'),
- (':path', '/pushed_stream_bar'),
- ('foo', 'bar')
- ])
- self.wfile.write(h2_conn.data_to_send())
- self.wfile.flush()
-
- h2_conn.send_headers(2, [(':status', '202')])
- h2_conn.send_headers(4, [(':status', '204')])
- h2_conn.send_data(1, b'regular_stream')
- h2_conn.send_data(2, b'pushed_stream_foo')
- h2_conn.send_data(4, b'pushed_stream_bar')
- h2_conn.end_stream(1)
- h2_conn.end_stream(2)
- h2_conn.end_stream(4)
- self.wfile.write(h2_conn.data_to_send())
- self.wfile.flush()
- elif isinstance(event, h2.events.ConnectionTerminated):
- return
+ try:
+ if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile):
+ break
+ except Exception as e:
+ print(repr(e))
+ print(traceback.format_exc())
+ break
+
+ def handle_server_event(self, h2_conn, rfile, wfile):
+ raise NotImplementedError()
+
+
+class _Http2TestBase(object):
+ @classmethod
+ def setup_class(self):
+ self.config = ProxyConfig(**self.get_proxy_config())
+
+ tmaster = tservers.TestMaster(self.config)
+ tmaster.start_app(APP_HOST, APP_PORT)
+ self.proxy = tservers.ProxyThread(tmaster)
+ self.proxy.start()
+
+ @classmethod
+ def teardown_class(cls):
+ cls.proxy.shutdown()
+
+ @property
+ def master(self):
+ return self.proxy.tmaster
+
+ @classmethod
+ def get_proxy_config(cls):
+ cls.cadir = os.path.join(tempfile.gettempdir(), "mitmproxy")
+ return dict(
+ no_upstream_cert = False,
+ cadir = cls.cadir,
+ authenticator = None,
+ )
+ def setup(self):
+ self.master.clear_log()
+ self.master.state.clear()
+ self.server.server.handle_server_event = self.handle_server_event
-@requires_alpn
-class TestHttp2(tservers.ProxTestBase):
def _setup_connection(self):
self.config.http2 = True
@@ -123,7 +109,7 @@ class TestHttp2(tservers.ProxTestBase):
client.wfile.write(
b"CONNECT localhost:%d HTTP/1.1\r\n"
b"Host: localhost:%d\r\n"
- b"\r\n" % (self.server.port, self.server.port)
+ b"\r\n" % (self.server.server.address.port, self.server.server.address.port)
)
client.wfile.flush()
@@ -149,14 +135,40 @@ class TestHttp2(tservers.ProxTestBase):
wfile.write(h2_conn.data_to_send())
wfile.flush()
- def test_simple(self):
- self.server = SimpleHttp2Server()
- self.server.setup_class()
+@requires_alpn
+class TestSimple(_Http2TestBase, _Http2ServerBase):
+ @classmethod
+ def setup_class(self):
+ _Http2TestBase.setup_class()
+ _Http2ServerBase.setup_class()
+
+ @classmethod
+ def teardown_class(self):
+ _Http2TestBase.teardown_class()
+ _Http2ServerBase.teardown_class()
+
+ @classmethod
+ def handle_server_event(self, event, h2_conn, rfile, wfile):
+ if isinstance(event, h2.events.ConnectionTerminated):
+ return False
+ elif isinstance(event, h2.events.RequestReceived):
+ h2_conn.send_headers(1, [
+ (':status', '200'),
+ ('foo', 'bar'),
+ ])
+ h2_conn.send_data(1, b'foobar')
+ h2_conn.end_stream(1)
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+
+ return True
+
+ def test_simple(self):
client, h2_conn = self._setup_connection()
self._send_request(client.wfile, h2_conn, headers=[
- (':authority', "127.0.0.1:%s" % self.server.port),
+ (':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -176,21 +188,69 @@ class TestHttp2(tservers.ProxTestBase):
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
- self.server.teardown_class()
-
assert len(self.master.state.flows) == 1
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['foo'] == 'bar'
assert self.master.state.flows[0].response.body == b'foobar'
- def test_pushed_streams(self):
- self.server = PushHttp2Server()
- self.server.setup_class()
+@requires_alpn
+class TestPushPromise(_Http2TestBase, _Http2ServerBase):
+ @classmethod
+ def setup_class(self):
+ _Http2TestBase.setup_class()
+ _Http2ServerBase.setup_class()
+
+ @classmethod
+ def teardown_class(self):
+ _Http2TestBase.teardown_class()
+ _Http2ServerBase.teardown_class()
+
+ @classmethod
+ def handle_server_event(self, event, h2_conn, rfile, wfile):
+ if isinstance(event, h2.events.ConnectionTerminated):
+ return False
+ elif isinstance(event, h2.events.RequestReceived):
+ if event.stream_id != 1:
+ # ignore requests initiated by push promises
+ return True
+
+ h2_conn.send_headers(1, [(':status', '200')])
+ h2_conn.push_stream(1, 2, [
+ (':authority', "127.0.0.1:%s" % self.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/pushed_stream_foo'),
+ ('foo', 'bar')
+ ])
+ h2_conn.push_stream(1, 4, [
+ (':authority', "127.0.0.1:%s" % self.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/pushed_stream_bar'),
+ ('foo', 'bar')
+ ])
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+
+ h2_conn.send_headers(2, [(':status', '202')])
+ h2_conn.send_headers(4, [(':status', '204')])
+ h2_conn.send_data(1, b'regular_stream')
+ h2_conn.send_data(2, b'pushed_stream_foo')
+ h2_conn.send_data(4, b'pushed_stream_bar')
+ h2_conn.end_stream(1)
+ h2_conn.end_stream(2)
+ h2_conn.end_stream(4)
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+
+ return True
+
+ def test_push_promise(self):
client, h2_conn = self._setup_connection()
- self._send_request(client.wfile, h2_conn, headers=[
- (':authority', "127.0.0.1:%s" % self.server.port),
+ self._send_request(client.wfile, h2_conn, stream_id=1, headers=[
+ (':authority', "127.0.0.1:%s" % self.server.server.address.port),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
@@ -198,6 +258,7 @@ class TestHttp2(tservers.ProxTestBase):
])
ended_streams = 0
+ pushed_streams = 0
while ended_streams != 3:
try:
events = h2_conn.receive_data(utils.http2_read_frame(client.rfile))
@@ -209,10 +270,13 @@ class TestHttp2(tservers.ProxTestBase):
for event in events:
if isinstance(event, h2.events.StreamEnded):
ended_streams += 1
+ elif isinstance(event, h2.events.PushedStreamReceived):
+ pushed_streams += 1
- self.server.teardown_class()
+ assert pushed_streams == 2
- assert len(self.master.state.flows) == 3
- assert self.master.state.flows[0].response.body == b'regular_stream'
- assert self.master.state.flows[1].response.body == b'pushed_stream_foo'
- assert self.master.state.flows[2].response.body == b'pushed_stream_bar'
+ bodies = [flow.response.body for flow in self.master.state.flows]
+ assert len(bodies) == 3
+ assert b'regular_stream' in bodies
+ assert b'pushed_stream_foo' in bodies
+ assert b'pushed_stream_bar' in bodies