aboutsummaryrefslogtreecommitdiffstats
path: root/test
diff options
context:
space:
mode:
Diffstat (limited to 'test')
-rw-r--r--test/mitmproxy/test_protocol_http2.py74
-rw-r--r--test/netlib/tservers.py12
2 files changed, 80 insertions, 6 deletions
diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py
index 932c8df2..89bb16c6 100644
--- a/test/mitmproxy/test_protocol_http2.py
+++ b/test/mitmproxy/test_protocol_http2.py
@@ -3,9 +3,10 @@
from __future__ import (absolute_import, print_function, division)
import pytest
-import traceback
import os
+import traceback
import tempfile
+
import h2
from mitmproxy.proxy.config import ProxyConfig
@@ -46,6 +47,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
+ if 'h2_server_settings' in self.kwargs:
+ h2_conn.update_settings(self.kwargs['h2_server_settings'])
+ self.wfile.write(h2_conn.data_to_send())
+ self.wfile.flush()
+
done = False
while not done:
try:
@@ -508,3 +514,69 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase):
if len(self.master.state.flows) == 1:
assert self.master.state.flows[0].response is None
+
+
+@requires_alpn
+class TestMaxConcurrentStreams(_Http2TestBase, _Http2ServerBase):
+
+ @classmethod
+ def setup_class(self):
+ _Http2TestBase.setup_class()
+ _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2})
+
+ @classmethod
+ def teardown_class(self):
+ _Http2TestBase.teardown_class()
+ _Http2ServerBase.teardown_class()
+
+ @classmethod
+ def handle_server_event(self, event, h2_conn, rfile, wfile):
+ if isinstance(event, h2.events.ConnectionTerminated):
+ return False
+ elif isinstance(event, h2.events.RequestReceived):
+ h2_conn.send_headers(event.stream_id, [
+ (':status', '200'),
+ ('X-Stream-ID', str(event.stream_id)),
+ ])
+ h2_conn.send_data(event.stream_id, b'Stream-ID {}'.format(event.stream_id))
+ h2_conn.end_stream(event.stream_id)
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+ return True
+
+ def test_max_concurrent_streams(self):
+ client, h2_conn = self._setup_connection()
+ new_streams = [1, 3, 5, 7, 9, 11]
+ for id in new_streams:
+ # this will exceed MAX_CONCURRENT_STREAMS on the server connection
+ # and cause mitmproxy to throttle stream creation to the server
+ self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
+ (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/'),
+ ('X-Stream-ID', str(id)),
+ ])
+
+ ended_streams = 0
+ while ended_streams != len(new_streams):
+ try:
+ header, body = framereader.http2_read_raw_frame(client.rfile)
+ events = h2_conn.receive_data(b''.join([header, body]))
+ except:
+ break
+ client.wfile.write(h2_conn.data_to_send())
+ client.wfile.flush()
+
+ for event in events:
+ if isinstance(event, h2.events.StreamEnded):
+ ended_streams += 1
+
+ h2_conn.close_connection()
+ client.wfile.write(h2_conn.data_to_send())
+ client.wfile.flush()
+
+ assert len(self.master.state.flows) == len(new_streams)
+ for flow in self.master.state.flows:
+ assert flow.response.status_code == 200
+ assert "Stream-ID" in flow.response.body
diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py
index 803aaa72..666f97ac 100644
--- a/test/netlib/tservers.py
+++ b/test/netlib/tservers.py
@@ -24,7 +24,7 @@ class _ServerThread(threading.Thread):
class _TServer(tcp.TCPServer):
- def __init__(self, ssl, q, handler_klass, addr):
+ def __init__(self, ssl, q, handler_klass, addr, **kwargs):
"""
ssl: A dictionary of SSL parameters:
@@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer):
self.q = q
self.handler_klass = handler_klass
+ if self.handler_klass is not None:
+ self.handler_klass.kwargs = kwargs
self.last_handler = None
def handle_client_connection(self, request, client_address):
@@ -89,16 +91,16 @@ class ServerTestBase(object):
addr = ("localhost", 0)
@classmethod
- def setup_class(cls):
+ def setup_class(cls, **kwargs):
cls.q = queue.Queue()
- s = cls.makeserver()
+ s = cls.makeserver(**kwargs)
cls.port = s.address.port
cls.server = _ServerThread(s)
cls.server.start()
@classmethod
- def makeserver(cls):
- return _TServer(cls.ssl, cls.q, cls.handler, cls.addr)
+ def makeserver(cls, **kwargs):
+ return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs)
@classmethod
def teardown_class(cls):