diff options
Diffstat (limited to 'mitmproxy/protocol/http2.py')
-rw-r--r-- | mitmproxy/protocol/http2.py | 99 |
1 files changed, 60 insertions, 39 deletions
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index 0e42d619..1595fb61 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -96,15 +96,17 @@ class Http2Layer(base.Layer): self.server_to_client_stream_ids = dict([(0, 0)]) self.client_conn.h2 = SafeH2Connection(self.client_conn, client_side=False, header_encoding=False) - # make sure that we only pass actual SSL.Connection objects in here, - # because otherwise ssl_read_select fails! - self.active_conns = [self.client_conn.connection] - def _initiate_server_conn(self): - self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False) - self.server_conn.h2.initiate_connection() - self.server_conn.send(self.server_conn.h2.data_to_send()) - self.active_conns.append(self.server_conn.connection) + if self.server_conn: + self.server_conn.h2 = SafeH2Connection(self.server_conn, client_side=True, header_encoding=False) + self.server_conn.h2.initiate_connection() + self.server_conn.send(self.server_conn.h2.data_to_send()) + + def _complete_handshake(self): + preamble = self.client_conn.rfile.read(24) + self.client_conn.h2.initiate_connection() + self.client_conn.h2.receive_data(preamble) + self.client_conn.send(self.client_conn.h2.data_to_send()) def next_layer(self): # pragma: no cover # WebSockets over HTTP/2? @@ -126,7 +128,7 @@ class Http2Layer(base.Layer): eid = event.stream_id if isinstance(event, events.RequestReceived): - return self._handle_request_received(eid, event) + return self._handle_request_received(eid, event, source_conn.h2) elif isinstance(event, events.ResponseReceived): return self._handle_response_received(eid, event) elif isinstance(event, events.DataReceived): @@ -138,9 +140,9 @@ class Http2Layer(base.Layer): elif isinstance(event, events.RemoteSettingsChanged): return self._handle_remote_settings_changed(event, other_conn) elif isinstance(event, events.ConnectionTerminated): - return self._handle_connection_terminated(event) + return self._handle_connection_terminated(event, is_server) elif isinstance(event, events.PushedStreamReceived): - return self._handle_pushed_stream_received(event) + return self._handle_pushed_stream_received(event, source_conn.h2) elif isinstance(event, events.PriorityUpdated): return self._handle_priority_updated(eid, event) elif isinstance(event, events.TrailersReceived): @@ -149,9 +151,9 @@ class Http2Layer(base.Layer): # fail-safe for unhandled events return True - def _handle_request_received(self, eid, event): + def _handle_request_received(self, eid, event, h2_connection): headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) + self.streams[eid] = Http2SingleStreamLayer(self, h2_connection, eid, headers) self.streams[eid].timestamp_start = time.time() self.streams[eid].no_body = (event.stream_ended is not None) if event.priority_updated is not None: @@ -173,7 +175,7 @@ class Http2Layer(base.Layer): def _handle_data_received(self, eid, event, source_conn): bsl = self.config.options.body_size_limit if bsl and self.streams[eid].queued_data_length > bsl: - self.streams[eid].zombie = time.time() + self.streams[eid].kill() source_conn.h2.safe_reset_stream( event.stream_id, h2.errors.REFUSED_STREAM @@ -194,7 +196,7 @@ class Http2Layer(base.Layer): return True def _handle_stream_reset(self, eid, event, is_server, other_conn): - self.streams[eid].zombie = time.time() + self.streams[eid].kill() if eid in self.streams and event.error_code == h2.errors.CANCEL: if is_server: other_stream_id = self.streams[eid].client_stream_id @@ -209,7 +211,13 @@ class Http2Layer(base.Layer): other_conn.h2.safe_update_settings(new_settings) return True - def _handle_connection_terminated(self, event): + def _handle_connection_terminated(self, event, is_server): + self.log("HTTP/2 connection terminated by {}: error code: {}, last stream id: {}, additional data: {}".format( + "server" if is_server else "client", + event.error_code, + event.last_stream_id, + event.additional_data), "info") + if event.error_code != h2.errors.NO_ERROR: # Something terrible has happened - kill everything! self.client_conn.h2.close_connection( @@ -226,7 +234,7 @@ class Http2Layer(base.Layer): """ return False - def _handle_pushed_stream_received(self, event): + def _handle_pushed_stream_received(self, event, h2_connection): # pushed stream ids should be unique and not dependent on race conditions # only the parent stream id must be looked up first parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] @@ -235,7 +243,7 @@ class Http2Layer(base.Layer): self.client_conn.send(self.client_conn.h2.data_to_send()) headers = netlib.http.Headers([[k, v] for k, v in event.headers]) - self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, event.pushed_stream_id, headers) + self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, h2_connection, event.pushed_stream_id, headers) self.streams[event.pushed_stream_id].timestamp_start = time.time() self.streams[event.pushed_stream_id].pushed = True self.streams[event.pushed_stream_id].parent_stream_id = parent_eid @@ -253,7 +261,7 @@ class Http2Layer(base.Layer): with self.server_conn.h2.lock: mapped_stream_id = event.stream_id if mapped_stream_id in self.streams and self.streams[mapped_stream_id].server_stream_id: - # if the stream is already up and running and was sent to the server + # if the stream is already up and running and was sent to the server, # use the mapped server stream id to update priority information mapped_stream_id = self.streams[mapped_stream_id].server_stream_id @@ -294,37 +302,36 @@ class Http2Layer(base.Layer): def _kill_all_streams(self): for stream in self.streams.values(): - if not stream.zombie: - stream.zombie = time.time() - stream.request_data_finished.set() - stream.response_arrived.set() - stream.data_finished.set() + stream.kill() def __call__(self): - if self.server_conn: - self._initiate_server_conn() + self._initiate_server_conn() + self._complete_handshake() - preamble = self.client_conn.rfile.read(24) - self.client_conn.h2.initiate_connection() - self.client_conn.h2.receive_data(preamble) - self.client_conn.send(self.client_conn.h2.data_to_send()) + client = self.client_conn.connection + server = self.server_conn.connection + conns = [client, server] try: while True: - r = tcp.ssl_read_select(self.active_conns, 1) + r = tcp.ssl_read_select(conns, 1) for conn in r: - source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn - other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + source_conn = self.client_conn if conn == client else self.server_conn + other_conn = self.server_conn if conn == client else self.client_conn is_server = (conn == self.server_conn.connection) with source_conn.h2.lock: try: - raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile)) + raw_frame = b''.join(http2.read_raw_frame(source_conn.rfile)) except: # read frame failed: connection closed self._kill_all_streams() return + if source_conn.h2.state_machine.state == h2.connection.ConnectionState.CLOSED: + self.log("HTTP/2 connection entered closed state already", "debug") + return + incoming_events = source_conn.h2.receive_data(raw_frame) source_conn.send(source_conn.h2.data_to_send()) @@ -354,10 +361,11 @@ def detect_zombie_stream(func): class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread): - def __init__(self, ctx, stream_id, request_headers): + def __init__(self, ctx, h2_connection, stream_id, request_headers): super(Http2SingleStreamLayer, self).__init__( ctx, name="Http2SingleStreamLayer-{}".format(stream_id) ) + self.h2_connection = h2_connection self.zombie = None self.client_stream_id = stream_id self.server_stream_id = None @@ -365,6 +373,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.response_headers = None self.pushed = False + self.timestamp_start = None + self.timestamp_end = None + self.request_data_queue = queue.Queue() self.request_queued_data_length = 0 self.request_data_finished = threading.Event() @@ -381,6 +392,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.priority_weight = None self.handled_priority_event = None + def kill(self): + if not self.zombie: + self.zombie = time.time() + self.request_data_finished.set() + self.response_arrived.set() + self.response_data_finished.set() + def connect(self): # pragma: no cover raise exceptions.Http2ProtocolException("HTTP2 layer should already have a connection.") @@ -421,10 +439,11 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.request_queued_data_length = v def raise_zombie(self, pre_command=None): - if self.zombie is not None: + connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED + if self.zombie is not None or not hasattr(self.server_conn, 'h2') or connection_closed: if pre_command is not None: pre_command() - raise exceptions.Http2ProtocolException("Zombie Stream") + raise exceptions.Http2ZombieException("Connection already dead") @detect_zombie_stream def read_request(self): @@ -508,6 +527,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except Exception as e: # pragma: no cover raise e finally: + self.raise_zombie() self.server_conn.h2.lock.release() if not self.no_body: @@ -581,6 +601,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) try: layer() + except exceptions.Http2ZombieException as e: # pragma: no cover + pass except exceptions.ProtocolException as e: # pragma: no cover self.log(repr(e), "info") self.log(traceback.format_exc(), "debug") @@ -589,5 +611,4 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except exceptions.Kill: self.log("Connection killed", "info") - if not self.zombie: - self.zombie = time.time() + self.kill() |