aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorUjjwal Verma <ujjwalverma1111@gmail.com>2017-06-05 01:52:36 +0530
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2017-07-04 10:52:50 +0200
commitd4f35d7a4a601c11639d2478cae1b00d6c003c98 (patch)
tree14b019665178a863b733befaa4805d6956d4aecc
parent47c9604aed6049d99b8605419a7edc90935a8006 (diff)
downloadmitmproxy-d4f35d7a4a601c11639d2478cae1b00d6c003c98.tar.gz
mitmproxy-d4f35d7a4a601c11639d2478cae1b00d6c003c98.tar.bz2
mitmproxy-d4f35d7a4a601c11639d2478cae1b00d6c003c98.zip
request streaming for HTTP/2
-rw-r--r--mitmproxy/proxy/protocol/http.py2
-rw-r--r--mitmproxy/proxy/protocol/http2.py42
-rw-r--r--test/mitmproxy/proxy/protocol/test_http2.py124
3 files changed, 155 insertions, 13 deletions
diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py
index 9f120767..93865c74 100644
--- a/mitmproxy/proxy/protocol/http.py
+++ b/mitmproxy/proxy/protocol/http.py
@@ -333,6 +333,8 @@ class HttpLayer(base.Layer):
if f.request.stream:
self.send_request_headers(f.request)
chunks = self.read_request_body(f.request)
+ if callable(f.request.stream):
+ chunks = f.request.stream(chunks)
self.send_request_body(f.request, chunks)
else:
self.send_request(f.request)
diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py
index ace7ecde..eab5292f 100644
--- a/mitmproxy/proxy/protocol/http2.py
+++ b/mitmproxy/proxy/protocol/http2.py
@@ -487,14 +487,23 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
@detect_zombie_stream
def read_request_body(self, request):
- self.request_data_finished.wait()
- data = []
- while self.request_data_queue.qsize() > 0:
- data.append(self.request_data_queue.get())
- return data
+ if not request.stream:
+ self.request_data_finished.wait()
+
+ while True:
+ try:
+ yield self.request_data_queue.get(timeout=0.1)
+ except queue.Empty: # pragma: no cover
+ pass
+ if self.request_data_finished.is_set():
+ self.raise_zombie()
+ while self.request_data_queue.qsize() > 0:
+ yield self.request_data_queue.get()
+ break
+ self.raise_zombie()
@detect_zombie_stream
- def send_request(self, message):
+ def send_request_headers(self, request):
if self.pushed:
# nothing to do here
return
@@ -519,10 +528,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.server_stream_id = self.connections[self.server_conn].get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
- headers = message.headers.copy()
- headers.insert(0, ":path", message.path)
- headers.insert(0, ":method", message.method)
- headers.insert(0, ":scheme", message.scheme)
+ headers = request.headers.copy()
+ headers.insert(0, ":path", request.path)
+ headers.insert(0, ":method", request.method)
+ headers.insert(0, ":scheme", request.scheme)
priority_exclusive = None
priority_depends_on = None
@@ -553,14 +562,25 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr
self.raise_zombie()
self.connections[self.server_conn].lock.release()
+ @detect_zombie_stream
+ def send_request_body(self, request, chunks):
+ if self.pushed:
+ # nothing to do here
+ return
+
if not self.no_body:
self.connections[self.server_conn].safe_send_body(
self.raise_zombie,
self.server_stream_id,
- [message.content]
+ chunks
)
@detect_zombie_stream
+ def send_request(self, message):
+ self.send_request_headers(message)
+ self.send_request_body(message, [message.content])
+
+ @detect_zombie_stream
def read_response_headers(self):
self.response_arrived.wait()
diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py
index 261f8415..487d8890 100644
--- a/test/mitmproxy/proxy/protocol/test_http2.py
+++ b/test/mitmproxy/proxy/protocol/test_http2.py
@@ -14,6 +14,7 @@ import mitmproxy.net
from ...net import tservers as net_tservers
from mitmproxy import exceptions
from mitmproxy.net.http import http1, http2
+from pathod.language import generators
from ... import tservers
from ....conftest import requires_alpn
@@ -166,7 +167,8 @@ class _Http2TestBase:
end_stream=None,
priority_exclusive=None,
priority_depends_on=None,
- priority_weight=None):
+ priority_weight=None,
+ streaming=False):
if headers is None:
headers = []
if end_stream is None:
@@ -182,7 +184,8 @@ class _Http2TestBase:
)
if body:
h2_conn.send_data(stream_id, body)
- h2_conn.end_stream(stream_id)
+ if not streaming:
+ h2_conn.end_stream(stream_id)
wfile.write(h2_conn.data_to_send())
wfile.flush()
@@ -862,3 +865,120 @@ class TestConnectionTerminated(_Http2Test):
assert connection_terminated_event.error_code == 5
assert connection_terminated_event.last_stream_id == 42
assert connection_terminated_event.additional_data == b'foobar'
+
+
+@requires_alpn
+class TestRequestStreaming(_Http2Test):
+
+ @classmethod
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
+ if isinstance(event, h2.events.ConnectionTerminated):
+ return False
+ elif isinstance(event, h2.events.DataReceived):
+ data = event.data
+ assert data
+ h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=data)
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+
+ return True
+
+ @pytest.mark.parametrize('streaming', [True, False])
+ def test_request_streaming(self, streaming):
+ class Stream:
+ def requestheaders(self, f):
+ f.request.stream = streaming
+
+ self.master.addons.add(Stream())
+ h2_conn = self.setup_connection()
+ body = generators.RandomGenerator("bytes", 100)[:]
+ self._send_request(
+ self.client.wfile,
+ h2_conn,
+ headers=[
+ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/'),
+
+ ],
+ body=body,
+ streaming=True
+ )
+ done = False
+ connection_terminated_event = None
+ self.client.rfile.o.settimeout(2)
+ while not done:
+ try:
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
+ events = h2_conn.receive_data(raw)
+
+ for event in events:
+ if isinstance(event, h2.events.ConnectionTerminated):
+ connection_terminated_event = event
+ done = True
+ except:
+ break
+
+ if streaming:
+ assert connection_terminated_event.additional_data == body
+ else:
+ assert connection_terminated_event is None
+
+
+@requires_alpn
+class TestResponseStreaming(_Http2Test):
+
+ @classmethod
+ def handle_server_event(cls, event, h2_conn, rfile, wfile):
+ if isinstance(event, h2.events.ConnectionTerminated):
+ return False
+ elif isinstance(event, h2.events.RequestReceived):
+ data = generators.RandomGenerator("bytes", 100)[:]
+ h2_conn.send_headers(event.stream_id, [
+ (':status', '200'),
+ ('content-length', '100')
+ ])
+ h2_conn.send_data(event.stream_id, data)
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+ return True
+
+ @pytest.mark.parametrize('streaming', [True, False])
+ def test_response_streaming(self, streaming):
+ class Stream:
+ def responseheaders(self, f):
+ f.response.stream = streaming
+
+ self.master.addons.add(Stream())
+ h2_conn = self.setup_connection()
+ self._send_request(
+ self.client.wfile,
+ h2_conn,
+ headers=[
+ (':authority', "127.0.0.1:{}".format(self.server.server.address[1])),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/'),
+
+ ]
+ )
+ done = False
+ self.client.rfile.o.settimeout(2)
+ data = None
+ while not done:
+ try:
+ raw = b''.join(http2.read_raw_frame(self.client.rfile))
+ events = h2_conn.receive_data(raw)
+
+ for event in events:
+ if isinstance(event, h2.events.DataReceived):
+ data = event.data
+ done = True
+ except:
+ break
+
+ if streaming:
+ assert data
+ else:
+ assert data is None