From 947f79eb6c173d445c57203d8f301a033176272b Mon Sep 17 00:00:00 2001
From: Thomas Kriechbaumer
Date: Sat, 16 Jan 2016 11:31:43 +0100
Subject: improved zombie detection

---
 libmproxy/protocol/http.py | 132 ++++++++++++++++++++++++---------------------
 1 file changed, 71 insertions(+), 61 deletions(-)

diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py
index 75ca520d..ea069bcb 100644
--- a/libmproxy/protocol/http.py
+++ b/libmproxy/protocol/http.py
@@ -157,7 +157,7 @@ class SafeH2Connection(H2Connection):
         with self.lock:
             try:
                 self.reset_stream(stream_id, error_code)
-            except StreamClosedError:
+            except h2.exceptions.ProtocolError:
                 # stream is already closed - good
                 pass
             self.conn.send(self.data_to_send())
@@ -172,30 +172,33 @@ class SafeH2Connection(H2Connection):
             self.update_settings(new_settings)
             self.conn.send(self.data_to_send())
 
-    def safe_send_headers(self, stream_id, headers):
+    def safe_send_headers(self, is_zombie, stream_id, headers):
         with self.lock:
+            if is_zombie(self, stream_id):
+                return
             self.send_headers(stream_id, headers)
             self.conn.send(self.data_to_send())
 
-    def safe_send_body(self, stream_id, chunks):
-        # TODO: this assumes the MAX_FRAME_SIZE does not change in the middle
-        # of a transfer - it could though. Then we need to re-chunk everything.
+    def safe_send_body(self, is_zombie, stream_id, chunks):
         for chunk in chunks:
-            max_outbound_frame_size = self.max_outbound_frame_size
-            for i in xrange(0, len(chunk), max_outbound_frame_size):
-                frame_chunk = chunk[i:i+max_outbound_frame_size]
-
+            position = 0
+            while position < len(chunk):
                 self.lock.acquire()
-                while True:
-                    if self.local_flow_control_window(stream_id) < len(frame_chunk):
-                        self.lock.release()
-                        time.sleep(0)
-                    else:
-                        break
+                max_outbound_frame_size = self.max_outbound_frame_size
+                frame_chunk = chunk[position:position+max_outbound_frame_size]
+                if self.local_flow_control_window(stream_id) < len(frame_chunk):
+                    self.lock.release()
+                    time.sleep(0)
+                    continue
+                if is_zombie(self, stream_id):
+                    return
                 self.send_data(stream_id, frame_chunk)
                 self.conn.send(self.data_to_send())
                 self.lock.release()
+                position += max_outbound_frame_size
         with self.lock:
+            if is_zombie(self, stream_id):
+                return
             self.end_stream(stream_id)
             self.conn.send(self.data_to_send())
 
@@ -254,45 +257,45 @@ class Http2Layer(Layer):
                     events = source_conn.h2.receive_data(raw_frame)
                     source_conn.send(source_conn.h2.data_to_send())
 
-                for event in events:
-                    if hasattr(event, 'stream_id'):
-                        if is_server:
-                            eid = self.server_to_client_stream_ids[event.stream_id]
-                        else:
-                            eid = event.stream_id
-
-                    if isinstance(event, RequestReceived):
-                        headers = Headers([[str(k), str(v)] for k, v in event.headers])
-                        self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
-                        self.streams[eid].start()
-                    elif isinstance(event, ResponseReceived):
-                        headers = Headers([[str(k), str(v)] for k, v in event.headers])
-                        self.streams[eid].response_headers = headers
-                        self.streams[eid].response_arrived.set()
-                    elif isinstance(event, DataReceived):
-                        self.streams[eid].data_queue.put(event.data)
-                        source_conn.h2.safe_increment_flow_control(event.stream_id, len(event.data))
-                    elif isinstance(event, StreamEnded):
-                        self.streams[eid].data_finished.set()
-                    elif isinstance(event, StreamReset):
-                        self.streams[eid].zombie = time.time()
-                        if eid in self.streams and event.error_code == 0x8:
+                    for event in events:
+                        if hasattr(event, 'stream_id'):
                             if is_server:
-                                other_stream_id = self.streams[eid].client_stream_id
+                                eid = self.server_to_client_stream_ids[event.stream_id]
                             else:
-                                other_stream_id = self.streams[eid].server_stream_id
-                            other_conn.h2.safe_reset_stream(other_stream_id, event.error_code)
-                    elif isinstance(event, RemoteSettingsChanged):
-                        source_conn.h2.safe_acknowledge_settings(event)
-                        new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()])
-                        other_conn.h2.safe_update_settings(new_settings)
-                    elif isinstance(event, ConnectionTerminated):
-                        other_conn.h2.safe_close_connection(event.error_code)
-                        return
-                    elif isinstance(event, TrailersReceived):
-                        raise NotImplementedError()
-                    elif isinstance(event, PushedStreamReceived):
-                        raise NotImplementedError()
+                                eid = event.stream_id
+
+                        if isinstance(event, RequestReceived):
+                            headers = Headers([[str(k), str(v)] for k, v in event.headers])
+                            self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
+                            self.streams[eid].start()
+                        elif isinstance(event, ResponseReceived):
+                            headers = Headers([[str(k), str(v)] for k, v in event.headers])
+                            self.streams[eid].response_headers = headers
+                            self.streams[eid].response_arrived.set()
+                        elif isinstance(event, DataReceived):
+                            self.streams[eid].data_queue.put(event.data)
+                            source_conn.h2.safe_increment_flow_control(event.stream_id, len(event.data))
+                        elif isinstance(event, StreamEnded):
+                            self.streams[eid].data_finished.set()
+                        elif isinstance(event, StreamReset):
+                            self.streams[eid].zombie = time.time()
+                            if eid in self.streams and event.error_code == 0x8:
+                                if is_server:
+                                    other_stream_id = self.streams[eid].client_stream_id
+                                else:
+                                    other_stream_id = self.streams[eid].server_stream_id
+                                other_conn.h2.safe_reset_stream(other_stream_id, event.error_code)
+                        elif isinstance(event, RemoteSettingsChanged):
+                            source_conn.h2.safe_acknowledge_settings(event)
+                            new_settings = dict([(id, cs.new_value) for (id, cs) in event.changed_settings.iteritems()])
+                            other_conn.h2.safe_update_settings(new_settings)
+                        elif isinstance(event, ConnectionTerminated):
+                            other_conn.h2.safe_close_connection(event.error_code)
+                            return
+                        elif isinstance(event, TrailersReceived):
+                            raise NotImplementedError()
+                        elif isinstance(event, PushedStreamReceived):
+                            raise NotImplementedError()
 
                 death_time = time.time() - 10
                 for stream_id in self.streams.keys():
@@ -314,6 +317,18 @@ class Http2SingleStreamLayer(_HttpLayer, threading.Thread):
         self.response_arrived = threading.Event()
         self.data_finished = threading.Event()
 
+    def is_zombie(self, h2_conn, stream_id):
+        if self.zombie:
+            return True
+
+        try:
+            h2_conn._get_stream_by_id(stream_id)
+        except Exception as e:
+            if isinstance(e, h2.exceptions.StreamClosedError):
+                return True
+
+        return False
+
     def read_request(self):
         self.data_finished.wait()
         self.data_finished.clear()
@@ -364,18 +379,17 @@ class Http2SingleStreamLayer(_HttpLayer, threading.Thread):
         )
 
     def send_request(self, message):
-        if self.zombie:
-            return
-
         with self.server_conn.h2.lock:
             self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
             self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
 
         self.server_conn.h2.safe_send_headers(
+            self.is_zombie,
             self.server_stream_id,
             message.headers
         )
         self.server_conn.h2.safe_send_body(
+            self.is_zombie,
             self.server_stream_id,
             message.body
         )
@@ -409,19 +423,15 @@ class Http2SingleStreamLayer(_HttpLayer, threading.Thread):
                 return
 
     def send_response_headers(self, response):
-        if self.zombie:
-            return
-
         self.client_conn.h2.safe_send_headers(
+            self.is_zombie,
             self.client_stream_id,
             response.headers
         )
 
     def send_response_body(self, _response, chunks):
-        if self.zombie:
-            return
-
         self.client_conn.h2.safe_send_body(
+            self.is_zombie,
             self.client_stream_id,
             chunks
         )
--
cgit v1.2.3
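
Note on the pattern above: the patch threads an is_zombie callback into every SafeH2Connection.safe_send_* call and re-checks it while the connection lock is held, so no frame is ever written for a stream that has already been reset or closed mid-transfer. The standalone sketch below models that guard outside of mitmproxy; DummyStreamConnection, stream_is_zombie, the fixed flow-control window, and the frame size are illustrative stand-ins, not mitmproxy or hyper-h2 APIs.

import threading
import time


class DummyStreamConnection(object):
    """Toy stand-in for SafeH2Connection: a lock, a fixed flow-control
    window, and a record of which streams were already reset/closed."""

    def __init__(self, window=5):
        self.lock = threading.RLock()
        self.window = window        # bytes of flow-control credit (toy model)
        self.closed_streams = set()
        self.sent = []              # (stream_id, frame) tuples that went out


def stream_is_zombie(conn, stream_id):
    # A stream counts as dead once it shows up in the closed set.
    return stream_id in conn.closed_streams


def safe_send_body(conn, is_zombie, stream_id, chunks, max_frame_size=3):
    # Same shape as the patched safe_send_body: slice each chunk into
    # frames, wait for flow-control credit, and stop quietly as soon as
    # the zombie check says the stream is gone.
    for chunk in chunks:
        position = 0
        while position < len(chunk):
            with conn.lock:
                if is_zombie(conn, stream_id):
                    return False    # stream died mid-transfer
                frame = chunk[position:position + max_frame_size]
                if conn.window >= len(frame):
                    conn.sent.append((stream_id, frame))
                    position += len(frame)
                    continue
            time.sleep(0)           # no credit yet: yield outside the lock, retry
    return True


if __name__ == "__main__":
    conn = DummyStreamConnection()
    print(safe_send_body(conn, stream_is_zombie, 1, [b"hello world"]))  # True
    conn.closed_streams.add(3)
    print(safe_send_body(conn, stream_is_zombie, 3, [b"dropped"]))      # False

Returning early instead of raising keeps a single dead stream from tearing down the whole multiplexed connection, which is the same trade-off the patch makes for the real SafeH2Connection.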