diff options
author | Ujjwal Verma <ujjwalverma1111@gmail.com> | 2017-06-05 01:52:36 +0530 |
---|---|---|
committer | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2017-07-04 10:52:50 +0200 |
commit | d4f35d7a4a601c11639d2478cae1b00d6c003c98 (patch) | |
tree | 14b019665178a863b733befaa4805d6956d4aecc | |
parent | 47c9604aed6049d99b8605419a7edc90935a8006 (diff) | |
download | mitmproxy-d4f35d7a4a601c11639d2478cae1b00d6c003c98.tar.gz mitmproxy-d4f35d7a4a601c11639d2478cae1b00d6c003c98.tar.bz2 mitmproxy-d4f35d7a4a601c11639d2478cae1b00d6c003c98.zip |
request streaming for HTTP/2
-rw-r--r-- | mitmproxy/proxy/protocol/http.py | 2 | ||||
-rw-r--r-- | mitmproxy/proxy/protocol/http2.py | 42 | ||||
-rw-r--r-- | test/mitmproxy/proxy/protocol/test_http2.py | 124 |
3 files changed, 155 insertions, 13 deletions
diff --git a/mitmproxy/proxy/protocol/http.py b/mitmproxy/proxy/protocol/http.py index 9f120767..93865c74 100644 --- a/mitmproxy/proxy/protocol/http.py +++ b/mitmproxy/proxy/protocol/http.py @@ -333,6 +333,8 @@ class HttpLayer(base.Layer): if f.request.stream: self.send_request_headers(f.request) chunks = self.read_request_body(f.request) + if callable(f.request.stream): + chunks = f.request.stream(chunks) self.send_request_body(f.request, chunks) else: self.send_request(f.request) diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index ace7ecde..eab5292f 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -487,14 +487,23 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr @detect_zombie_stream def read_request_body(self, request): - self.request_data_finished.wait() - data = [] - while self.request_data_queue.qsize() > 0: - data.append(self.request_data_queue.get()) - return data + if not request.stream: + self.request_data_finished.wait() + + while True: + try: + yield self.request_data_queue.get(timeout=0.1) + except queue.Empty: # pragma: no cover + pass + if self.request_data_finished.is_set(): + self.raise_zombie() + while self.request_data_queue.qsize() > 0: + yield self.request_data_queue.get() + break + self.raise_zombie() @detect_zombie_stream - def send_request(self, message): + def send_request_headers(self, request): if self.pushed: # nothing to do here return @@ -519,10 +528,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr self.server_stream_id = self.connections[self.server_conn].get_next_available_stream_id() self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id - headers = message.headers.copy() - headers.insert(0, ":path", message.path) - headers.insert(0, ":method", message.method) - headers.insert(0, ":scheme", message.scheme) + headers = request.headers.copy() + headers.insert(0, ":path", request.path) + headers.insert(0, ":method", request.method) + headers.insert(0, ":scheme", request.scheme) priority_exclusive = None priority_depends_on = None @@ -553,14 +562,25 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr self.raise_zombie() self.connections[self.server_conn].lock.release() + @detect_zombie_stream + def send_request_body(self, request, chunks): + if self.pushed: + # nothing to do here + return + if not self.no_body: self.connections[self.server_conn].safe_send_body( self.raise_zombie, self.server_stream_id, - [message.content] + chunks ) @detect_zombie_stream + def send_request(self, message): + self.send_request_headers(message) + self.send_request_body(message, [message.content]) + + @detect_zombie_stream def read_response_headers(self): self.response_arrived.wait() diff --git a/test/mitmproxy/proxy/protocol/test_http2.py b/test/mitmproxy/proxy/protocol/test_http2.py index 261f8415..487d8890 100644 --- a/test/mitmproxy/proxy/protocol/test_http2.py +++ b/test/mitmproxy/proxy/protocol/test_http2.py @@ -14,6 +14,7 @@ import mitmproxy.net from ...net import tservers as net_tservers from mitmproxy import exceptions from mitmproxy.net.http import http1, http2 +from pathod.language import generators from ... import tservers from ....conftest import requires_alpn @@ -166,7 +167,8 @@ class _Http2TestBase: end_stream=None, priority_exclusive=None, priority_depends_on=None, - priority_weight=None): + priority_weight=None, + streaming=False): if headers is None: headers = [] if end_stream is None: @@ -182,7 +184,8 @@ class _Http2TestBase: ) if body: h2_conn.send_data(stream_id, body) - h2_conn.end_stream(stream_id) + if not streaming: + h2_conn.end_stream(stream_id) wfile.write(h2_conn.data_to_send()) wfile.flush() @@ -862,3 +865,120 @@ class TestConnectionTerminated(_Http2Test): assert connection_terminated_event.error_code == 5 assert connection_terminated_event.last_stream_id == 42 assert connection_terminated_event.additional_data == b'foobar' + + +@requires_alpn +class TestRequestStreaming(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.DataReceived): + data = event.data + assert data + h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data=data) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + + return True + + @pytest.mark.parametrize('streaming', [True, False]) + def test_request_streaming(self, streaming): + class Stream: + def requestheaders(self, f): + f.request.stream = streaming + + self.master.addons.add(Stream()) + h2_conn = self.setup_connection() + body = generators.RandomGenerator("bytes", 100)[:] + self._send_request( + self.client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + + ], + body=body, + streaming=True + ) + done = False + connection_terminated_event = None + self.client.rfile.o.settimeout(2) + while not done: + try: + raw = b''.join(http2.read_raw_frame(self.client.rfile)) + events = h2_conn.receive_data(raw) + + for event in events: + if isinstance(event, h2.events.ConnectionTerminated): + connection_terminated_event = event + done = True + except: + break + + if streaming: + assert connection_terminated_event.additional_data == body + else: + assert connection_terminated_event is None + + +@requires_alpn +class TestResponseStreaming(_Http2Test): + + @classmethod + def handle_server_event(cls, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + data = generators.RandomGenerator("bytes", 100)[:] + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('content-length', '100') + ]) + h2_conn.send_data(event.stream_id, data) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + @pytest.mark.parametrize('streaming', [True, False]) + def test_response_streaming(self, streaming): + class Stream: + def responseheaders(self, f): + f.response.stream = streaming + + self.master.addons.add(Stream()) + h2_conn = self.setup_connection() + self._send_request( + self.client.wfile, + h2_conn, + headers=[ + (':authority', "127.0.0.1:{}".format(self.server.server.address[1])), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + + ] + ) + done = False + self.client.rfile.o.settimeout(2) + data = None + while not done: + try: + raw = b''.join(http2.read_raw_frame(self.client.rfile)) + events = h2_conn.receive_data(raw) + + for event in events: + if isinstance(event, h2.events.DataReceived): + data = event.data + done = True + except: + break + + if streaming: + assert data + else: + assert data is None |