aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2016-02-02 23:54:35 +0100
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2016-02-04 09:52:29 +0100
commitcf8c063773b70ad37ab0a2125f5ed03c35e17336 (patch)
tree211dd9089b83c66d76d33de7b3c6008767ccd5a9
parentca5cc34d0b70f3306f62004be7ceb3f0c2053da7 (diff)
downloadmitmproxy-cf8c063773b70ad37ab0a2125f5ed03c35e17336.tar.gz
mitmproxy-cf8c063773b70ad37ab0a2125f5ed03c35e17336.tar.bz2
mitmproxy-cf8c063773b70ad37ab0a2125f5ed03c35e17336.zip
fix http2 race condition
-rw-r--r--libmproxy/protocol/http2.py67
-rw-r--r--test/test_protocol_http2.py66
2 files changed, 107 insertions, 26 deletions
diff --git a/libmproxy/protocol/http2.py b/libmproxy/protocol/http2.py
index e617f77c..de068836 100644
--- a/libmproxy/protocol/http2.py
+++ b/libmproxy/protocol/http2.py
@@ -167,7 +167,8 @@ class Http2Layer(Layer):
new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()])
other_conn.h2.safe_update_settings(new_settings)
elif isinstance(event, ConnectionTerminated):
- other_conn.h2.safe_close_connection(event.error_code)
+ # Do not immediately terminate the other connection.
+ # Some streams might be still sending data to the client.
return False
elif isinstance(event, PushedStreamReceived):
# pushed stream ids should be uniq and not dependent on race conditions
@@ -183,7 +184,7 @@ class Http2Layer(Layer):
self.streams[event.pushed_stream_id].pushed = True
self.streams[event.pushed_stream_id].parent_stream_id = parent_eid
self.streams[event.pushed_stream_id].timestamp_end = time.time()
- self.streams[event.pushed_stream_id].data_finished.set()
+ self.streams[event.pushed_stream_id].request_data_finished.set()
self.streams[event.pushed_stream_id].start()
elif isinstance(event, TrailersReceived):
raise NotImplementedError()
@@ -240,18 +241,50 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
self.server_stream_id = None
self.request_headers = request_headers
self.response_headers = None
- self.data_queue = Queue.Queue()
- self.queued_data_length = 0
+ self.pushed = False
+
+ self.request_data_queue = Queue.Queue()
+ self.request_queued_data_length = 0
+ self.request_data_finished = threading.Event()
self.response_arrived = threading.Event()
- self.data_finished = threading.Event()
+ self.response_data_queue = Queue.Queue()
+ self.response_queued_data_length = 0
+ self.response_data_finished = threading.Event()
+
+ @property
+ def data_queue(self):
+ if self.response_arrived.is_set():
+ return self.response_data_queue
+ else:
+ return self.request_data_queue
+
+ @property
+ def queued_data_length(self):
+ if self.response_arrived.is_set():
+ return self.response_queued_data_length
+ else:
+ return self.request_queued_data_length
+
+ @property
+ def data_finished(self):
+ if self.response_arrived.is_set():
+ return self.response_data_finished
+ else:
+ return self.request_data_finished
+
+ @queued_data_length.setter
+ def queued_data_length(self, v):
+ if self.response_arrived.is_set():
+ return self.response_queued_data_length
+ else:
+ return self.request_queued_data_length
def is_zombie(self):
return self.zombie is not None
def read_request(self):
- self.data_finished.wait()
- self.data_finished.clear()
+ self.request_data_finished.wait()
authority = self.request_headers.get(':authority', '')
method = self.request_headers.get(':method', 'GET')
@@ -279,8 +312,8 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
port = int(port)
data = []
- while self.data_queue.qsize() > 0:
- data.append(self.data_queue.get())
+ while self.request_data_queue.qsize() > 0:
+ data.append(self.request_data_queue.get())
data = b"".join(data)
return HTTPRequest(
@@ -298,9 +331,15 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
)
def send_request(self, message):
- if self.zombie:
+ if self.pushed:
+ # nothing to do here
return
+
with self.server_conn.h2.lock:
+ # We must not assign a stream id if we are already a zombie.
+ if self.zombie:
+ return
+
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
@@ -333,12 +372,12 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread):
def read_response_body(self, request, response):
while True:
try:
- yield self.data_queue.get(timeout=1)
+ yield self.response_data_queue.get(timeout=1)
except Queue.Empty:
pass
- if self.data_finished.is_set():
- while self.data_queue.qsize() > 0:
- yield self.data_queue.get()
+ if self.response_data_finished.is_set():
+ while self.response_data_queue.qsize() > 0:
+ yield self.response_data_queue.get()
return
if self.zombie:
return
diff --git a/test/test_protocol_http2.py b/test/test_protocol_http2.py
index cc62f734..38cfdfc3 100644
--- a/test/test_protocol_http2.py
+++ b/test/test_protocol_http2.py
@@ -5,6 +5,7 @@ import pytest
import traceback
import os
import tempfile
+import sys
from libmproxy.proxy.config import ProxyConfig
from libmproxy.proxy.server import ProxyServer
@@ -47,9 +48,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
- while True:
+ done = False
+ while not done:
try:
- events = h2_conn.receive_data(b''.join(http2_read_raw_frame(self.rfile)))
+ raw = b''.join(http2_read_raw_frame(self.rfile))
+ events = h2_conn.receive_data(raw)
except:
break
self.wfile.write(h2_conn.data_to_send())
@@ -58,10 +61,12 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
for event in events:
try:
if not self.server.handle_server_event(event, h2_conn, self.rfile, self.wfile):
+ done = True
break
except Exception as e:
print(repr(e))
print(traceback.format_exc())
+ done = True
break
def handle_server_event(self, h2_conn, rfile, wfile):
@@ -182,7 +187,10 @@ class TestSimple(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
- events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
+ try:
+ events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
+ except:
+ break
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
@@ -248,7 +256,10 @@ class TestWithBodies(_Http2TestBase, _Http2ServerBase):
done = False
while not done:
- events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
+ try:
+ events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
+ except:
+ break
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
@@ -303,14 +314,16 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
wfile.write(h2_conn.data_to_send())
wfile.flush()
- h2_conn.send_headers(2, [(':status', '202')])
- h2_conn.send_headers(4, [(':status', '204')])
+ h2_conn.send_headers(2, [(':status', '200')])
+ h2_conn.send_headers(4, [(':status', '200')])
wfile.write(h2_conn.data_to_send())
wfile.flush()
h2_conn.send_data(1, b'regular_stream')
h2_conn.send_data(2, b'pushed_stream_foo')
h2_conn.send_data(4, b'pushed_stream_bar')
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
h2_conn.end_stream(1)
h2_conn.end_stream(2)
h2_conn.end_stream(4)
@@ -330,11 +343,14 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
('foo', 'bar')
])
+ done = False
ended_streams = 0
pushed_streams = 0
- while ended_streams != 3:
+ responses = 0
+ while not done:
try:
- events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
+ raw = b''.join(http2_read_raw_frame(client.rfile))
+ events = h2_conn.receive_data(raw)
except:
break
client.wfile.write(h2_conn.data_to_send())
@@ -345,7 +361,19 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
ended_streams += 1
elif isinstance(event, h2.events.PushedStreamReceived):
pushed_streams += 1
+ elif isinstance(event, h2.events.ResponseReceived):
+ responses += 1
+ if isinstance(event, h2.events.ConnectionTerminated):
+ done = True
+ if responses == 3 and ended_streams == 3 and pushed_streams == 2:
+ done = True
+
+ h2_conn.close_connection()
+ client.wfile.write(h2_conn.data_to_send())
+ client.wfile.flush()
+
+ assert ended_streams == 3
assert pushed_streams == 2
bodies = [flow.response.body for flow in self.master.state.flows]
@@ -365,8 +393,11 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
('foo', 'bar')
])
- streams = 0
- while streams != 3:
+ done = False
+ ended_streams = 0
+ pushed_streams = 0
+ responses = 0
+ while not done:
try:
events = h2_conn.receive_data(b''.join(http2_read_raw_frame(client.rfile)))
except:
@@ -376,12 +407,23 @@ class TestPushPromise(_Http2TestBase, _Http2ServerBase):
for event in events:
if isinstance(event, h2.events.StreamEnded) and event.stream_id == 1:
- streams += 1
+ ended_streams += 1
elif isinstance(event, h2.events.PushedStreamReceived):
- streams += 1
+ pushed_streams += 1
h2_conn.reset_stream(event.pushed_stream_id, error_code=0x8)
client.wfile.write(h2_conn.data_to_send())
client.wfile.flush()
+ elif isinstance(event, h2.events.ResponseReceived):
+ responses += 1
+ if isinstance(event, h2.events.ConnectionTerminated):
+ done = True
+
+ if responses >= 1 and ended_streams >= 1 and pushed_streams == 2:
+ done = True
+
+ h2_conn.close_connection()
+ client.wfile.write(h2_conn.data_to_send())
+ client.wfile.flush()
bodies = [flow.response.body for flow in self.master.state.flows if flow.response]
assert len(bodies) >= 1