aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2016-01-25 21:14:58 +0100
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2016-02-04 09:52:27 +0100
commit41f4197a0dd73a2b00ea8485608ba9b05a605dd4 (patch)
treed17711f3e4ba15a8149d993487e76f5a0aad438b
parent97c2530f90e8ddf0b36539408372f21f5964e9bb (diff)
downloadmitmproxy-41f4197a0dd73a2b00ea8485608ba9b05a605dd4.tar.gz
mitmproxy-41f4197a0dd73a2b00ea8485608ba9b05a605dd4.tar.bz2
mitmproxy-41f4197a0dd73a2b00ea8485608ba9b05a605dd4.zip
test PushPromise support
-rw-r--r--libmproxy/protocol/http2.py11
-rw-r--r--test/test_protocol_http2.py91
2 files changed, 97 insertions, 5 deletions
diff --git a/libmproxy/protocol/http2.py b/libmproxy/protocol/http2.py
index 54e7572e..71423bf7 100644
--- a/libmproxy/protocol/http2.py
+++ b/libmproxy/protocol/http2.py
@@ -17,7 +17,7 @@ from .base import Layer
from .http import _HttpTransmissionLayer, HttpLayer
from .. import utils
from ..models import HTTPRequest, HTTPResponse
-
+from ..exceptions import HttpProtocolException, ProtocolException
class SafeH2Connection(H2Connection):
def __init__(self, conn, *args, **kwargs):
@@ -207,7 +207,14 @@ class Http2Layer(Layer):
is_server = (conn == self.server_conn.connection)
with source_conn.h2.lock:
- events = source_conn.h2.receive_data(utils.http2_read_frame(source_conn.rfile))
+ try:
+ raw_frame = utils.http2_read_frame(source_conn.rfile)
+ except:
+ for stream in self.streams.values():
+ stream.zombie = time.time()
+ return
+
+ events = source_conn.h2.receive_data(raw_frame)
source_conn.send(source_conn.h2.data_to_send())
for event in events:
diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py
index e72113c4..4fa2c701 100644
--- a/test/test_protocol_http2.py
+++ b/test/test_protocol_http2.py
@@ -28,9 +28,7 @@ requires_alpn = pytest.mark.skipif(
class SimpleHttp2Server(netlib_tservers.ServerTestBase):
- ssl = dict(
- alpn_select=b'h2',
- )
+ ssl = dict(alpn_select=b'h2')
class handler(netlib.tcp.BaseHandler):
def handle(self):
@@ -61,6 +59,59 @@ class SimpleHttp2Server(netlib_tservers.ServerTestBase):
return
+class PushHttp2Server(netlib_tservers.ServerTestBase):
+ ssl = dict(alpn_select=b'h2')
+
+ class handler(netlib.tcp.BaseHandler):
+ def handle(self):
+ h2_conn = h2.connection.H2Connection(client_side=False)
+
+ preamble = self.rfile.read(24)
+ h2_conn.initiate_connection()
+ h2_conn.receive_data(preamble)
+ self.wfile.write(h2_conn.data_to_send())
+ self.wfile.flush()
+
+ while True:
+ events = h2_conn.receive_data(utils.http2_read_frame(self.rfile))
+ self.wfile.write(h2_conn.data_to_send())
+ self.wfile.flush()
+
+ for event in events:
+ if isinstance(event, h2.events.RequestReceived):
+ h2_conn.send_headers(1, [(':status', '200')])
+ h2_conn.push_stream(1, 2, [
+ (':authority', "127.0.0.1:%s" % self.address.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/pushed_stream_foo'),
+ ('foo', 'bar')
+ ])
+ h2_conn.push_stream(1, 4, [
+ (':authority', "127.0.0.1:%s" % self.address.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/pushed_stream_bar'),
+ ('foo', 'bar')
+ ])
+ self.wfile.write(h2_conn.data_to_send())
+ self.wfile.flush()
+
+ h2_conn.send_headers(2, [(':status', '202')])
+ h2_conn.send_headers(4, [(':status', '204')])
+ h2_conn.send_data(1, b'regular_stream')
+ h2_conn.send_data(2, b'pushed_stream_foo')
+ h2_conn.send_data(4, b'pushed_stream_bar')
+ h2_conn.end_stream(1)
+ h2_conn.end_stream(2)
+ h2_conn.end_stream(4)
+ self.wfile.write(h2_conn.data_to_send())
+ self.wfile.flush()
+ print("HERE")
+ elif isinstance(event, h2.events.ConnectionTerminated):
+ return
+
+
@requires_alpn
class TestHttp2(tservers.ProxTestBase):
def _setup_connection(self):
@@ -132,3 +183,37 @@ class TestHttp2(tservers.ProxTestBase):
assert self.master.state.flows[0].response.status_code == 200
assert self.master.state.flows[0].response.headers['foo'] == 'bar'
assert self.master.state.flows[0].response.body == b'foobar'
+
+ def test_pushed_streams(self):
+ self.server = PushHttp2Server()
+ self.server.setup_class()
+
+ client, h2_conn = self._setup_connection()
+
+ self._send_request(client.wfile, h2_conn, headers=[
+ (':authority', "127.0.0.1:%s" % self.server.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/'),
+ ('foo', 'bar')
+ ])
+
+ ended_streams = 0
+ while ended_streams != 3:
+ try:
+ events = h2_conn.receive_data(utils.http2_read_frame(client.rfile))
+ except:
+ break
+ client.wfile.write(h2_conn.data_to_send())
+ client.wfile.flush()
+
+ for event in events:
+ if isinstance(event, h2.events.StreamEnded):
+ ended_streams += 1
+
+ self.server.teardown_class()
+
+ assert len(self.master.state.flows) == 3
+ assert self.master.state.flows[0].response.body == b'regular_stream'
+ assert self.master.state.flows[1].response.body == b'pushed_stream_foo'
+ assert self.master.state.flows[2].response.body == b'pushed_stream_bar'