diff options
author | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2016-02-02 23:54:35 +0100 |
---|---|---|
committer | Thomas Kriechbaumer <thomas@kriechbaumer.name> | 2016-02-04 09:52:29 +0100 |
commit | cf8c063773b70ad37ab0a2125f5ed03c35e17336 (patch) | |
tree | 211dd9089b83c66d76d33de7b3c6008767ccd5a9 /libmproxy/protocol | |
parent | ca5cc34d0b70f3306f62004be7ceb3f0c2053da7 (diff) | |
download | mitmproxy-cf8c063773b70ad37ab0a2125f5ed03c35e17336.tar.gz mitmproxy-cf8c063773b70ad37ab0a2125f5ed03c35e17336.tar.bz2 mitmproxy-cf8c063773b70ad37ab0a2125f5ed03c35e17336.zip |
fix http2 race condition
Diffstat (limited to 'libmproxy/protocol')
-rw-r--r-- | libmproxy/protocol/http2.py | 67 |
1 files changed, 53 insertions, 14 deletions
diff --git a/libmproxy/protocol/http2.py b/libmproxy/protocol/http2.py index e617f77c..de068836 100644 --- a/libmproxy/protocol/http2.py +++ b/libmproxy/protocol/http2.py @@ -167,7 +167,8 @@ class Http2Layer(Layer): new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()]) other_conn.h2.safe_update_settings(new_settings) elif isinstance(event, ConnectionTerminated): - other_conn.h2.safe_close_connection(event.error_code) + # Do not immediately terminate the other connection. + # Some streams might be still sending data to the client. return False elif isinstance(event, PushedStreamReceived): # pushed stream ids should be uniq and not dependent on race conditions @@ -183,7 +184,7 @@ class Http2Layer(Layer): self.streams[event.pushed_stream_id].pushed = True self.streams[event.pushed_stream_id].parent_stream_id = parent_eid self.streams[event.pushed_stream_id].timestamp_end = time.time() - self.streams[event.pushed_stream_id].data_finished.set() + self.streams[event.pushed_stream_id].request_data_finished.set() self.streams[event.pushed_stream_id].start() elif isinstance(event, TrailersReceived): raise NotImplementedError() @@ -240,18 +241,50 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): self.server_stream_id = None self.request_headers = request_headers self.response_headers = None - self.data_queue = Queue.Queue() - self.queued_data_length = 0 + self.pushed = False + + self.request_data_queue = Queue.Queue() + self.request_queued_data_length = 0 + self.request_data_finished = threading.Event() self.response_arrived = threading.Event() - self.data_finished = threading.Event() + self.response_data_queue = Queue.Queue() + self.response_queued_data_length = 0 + self.response_data_finished = threading.Event() + + @property + def data_queue(self): + if self.response_arrived.is_set(): + return self.response_data_queue + else: + return self.request_data_queue + + @property + def queued_data_length(self): + if self.response_arrived.is_set(): + return self.response_queued_data_length + else: + return self.request_queued_data_length + + @property + def data_finished(self): + if self.response_arrived.is_set(): + return self.response_data_finished + else: + return self.request_data_finished + + @queued_data_length.setter + def queued_data_length(self, v): + if self.response_arrived.is_set(): + return self.response_queued_data_length + else: + return self.request_queued_data_length def is_zombie(self): return self.zombie is not None def read_request(self): - self.data_finished.wait() - self.data_finished.clear() + self.request_data_finished.wait() authority = self.request_headers.get(':authority', '') method = self.request_headers.get(':method', 'GET') @@ -279,8 +312,8 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): port = int(port) data = [] - while self.data_queue.qsize() > 0: - data.append(self.data_queue.get()) + while self.request_data_queue.qsize() > 0: + data.append(self.request_data_queue.get()) data = b"".join(data) return HTTPRequest( @@ -298,9 +331,15 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): ) def send_request(self, message): - if self.zombie: + if self.pushed: + # nothing to do here return + with self.server_conn.h2.lock: + # We must not assign a stream id if we are already a zombie. + if self.zombie: + return + self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id @@ -333,12 +372,12 @@ class Http2SingleStreamLayer(_HttpTransmissionLayer, threading.Thread): def read_response_body(self, request, response): while True: try: - yield self.data_queue.get(timeout=1) + yield self.response_data_queue.get(timeout=1) except Queue.Empty: pass - if self.data_finished.is_set(): - while self.data_queue.qsize() > 0: - yield self.data_queue.get() + if self.response_data_finished.is_set(): + while self.response_data_queue.qsize() > 0: + yield self.response_data_queue.get() return if self.zombie: return |