From f140b1d84fc27d4156ca4b901023b0c81b8851bc Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 5 Nov 2016 18:23:25 +0100 Subject: http2: move h2 connection object --- mitmproxy/proxy/protocol/http2.py | 95 ++++++++++++++++++++------------------- 1 file changed, 48 insertions(+), 47 deletions(-) diff --git a/mitmproxy/proxy/protocol/http2.py b/mitmproxy/proxy/protocol/http2.py index b204f6e8..5ab503f5 100644 --- a/mitmproxy/proxy/protocol/http2.py +++ b/mitmproxy/proxy/protocol/http2.py @@ -87,13 +87,15 @@ class Http2Layer(base.Layer): self.mode = mode self.streams = dict() # type: Dict[int, Http2SingleStreamLayer] self.server_to_client_stream_ids = dict([(0, 0)]) # type: Dict[int, int] + self.connections = {} # type: Dict[object, SafeH2Connection] + config = h2.config.H2Configuration( client_side=False, header_encoding=False, validate_outbound_headers=False, normalize_outbound_headers=False, validate_inbound_headers=False) - self.client_conn.h2 = SafeH2Connection(self.client_conn, config=config) + self.connections[self.client_conn] = SafeH2Connection(self.client_conn, config=config) def _initiate_server_conn(self): if self.server_conn.connected(): @@ -103,15 +105,15 @@ class Http2Layer(base.Layer): validate_outbound_headers=False, normalize_outbound_headers=False, validate_inbound_headers=False) - self.server_conn.h2 = SafeH2Connection(self.server_conn, config=config) - self.server_conn.h2.initiate_connection() - self.server_conn.send(self.server_conn.h2.data_to_send()) + self.connections[self.server_conn] = SafeH2Connection(self.server_conn, config=config) + self.connections[self.server_conn].initiate_connection() + self.server_conn.send(self.connections[self.server_conn].data_to_send()) def _complete_handshake(self): preamble = self.client_conn.rfile.read(24) - self.client_conn.h2.initiate_connection() - self.client_conn.h2.receive_data(preamble) - self.client_conn.send(self.client_conn.h2.data_to_send()) + self.connections[self.client_conn].initiate_connection() + self.connections[self.client_conn].receive_data(preamble) + self.client_conn.send(self.connections[self.client_conn].data_to_send()) def next_layer(self): # pragma: no cover # WebSockets over HTTP/2? @@ -133,7 +135,7 @@ class Http2Layer(base.Layer): eid = event.stream_id if isinstance(event, events.RequestReceived): - return self._handle_request_received(eid, event, source_conn.h2) + return self._handle_request_received(eid, event) elif isinstance(event, events.ResponseReceived): return self._handle_response_received(eid, event) elif isinstance(event, events.DataReceived): @@ -147,7 +149,7 @@ class Http2Layer(base.Layer): elif isinstance(event, events.ConnectionTerminated): return self._handle_connection_terminated(event, is_server) elif isinstance(event, events.PushedStreamReceived): - return self._handle_pushed_stream_received(event, source_conn.h2) + return self._handle_pushed_stream_received(event) elif isinstance(event, events.PriorityUpdated): return self._handle_priority_updated(eid, event) elif isinstance(event, events.TrailersReceived): @@ -156,9 +158,9 @@ class Http2Layer(base.Layer): # fail-safe for unhandled events return True - def _handle_request_received(self, eid, event, h2_connection): + def _handle_request_received(self, eid, event): headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) - self.streams[eid] = Http2SingleStreamLayer(self, h2_connection, eid, headers) + self.streams[eid] = Http2SingleStreamLayer(self, self.connections[self.client_conn], eid, headers) self.streams[eid].timestamp_start = time.time() self.streams[eid].no_body = (event.stream_ended is not None) if event.priority_updated is not None: @@ -182,7 +184,7 @@ class Http2Layer(base.Layer): bsl = self.config.options.body_size_limit if bsl and self.streams[eid].queued_data_length > bsl: self.streams[eid].kill() - source_conn.h2.safe_reset_stream( + self.connections[source_conn].safe_reset_stream( event.stream_id, h2.errors.REFUSED_STREAM ) @@ -190,7 +192,7 @@ class Http2Layer(base.Layer): else: self.streams[eid].data_queue.put(event.data) self.streams[eid].queued_data_length += len(event.data) - source_conn.h2.safe_acknowledge_received_data( + self.connections[source_conn].safe_acknowledge_received_data( event.flow_controlled_length, event.stream_id ) @@ -209,12 +211,12 @@ class Http2Layer(base.Layer): else: other_stream_id = self.streams[eid].server_stream_id if other_stream_id is not None: - other_conn.h2.safe_reset_stream(other_stream_id, event.error_code) + self.connections[other_conn].safe_reset_stream(other_stream_id, event.error_code) return True def _handle_remote_settings_changed(self, event, other_conn): new_settings = dict([(key, cs.new_value) for (key, cs) in event.changed_settings.items()]) - other_conn.h2.safe_update_settings(new_settings) + self.connections[other_conn].safe_update_settings(new_settings) return True def _handle_connection_terminated(self, event, is_server): @@ -226,12 +228,12 @@ class Http2Layer(base.Layer): if event.error_code != h2.errors.NO_ERROR: # Something terrible has happened - kill everything! - self.client_conn.h2.close_connection( + self.connections[self.client_conn].close_connection( error_code=event.error_code, last_stream_id=event.last_stream_id, additional_data=event.additional_data ) - self.client_conn.send(self.client_conn.h2.data_to_send()) + self.client_conn.send(self.connections[self.client_conn].data_to_send()) self._kill_all_streams() else: """ @@ -240,17 +242,18 @@ class Http2Layer(base.Layer): """ return False - def _handle_pushed_stream_received(self, event, h2_connection): + def _handle_pushed_stream_received(self, event): # pushed stream ids should be unique and not dependent on race conditions # only the parent stream id must be looked up first parent_eid = self.server_to_client_stream_ids[event.parent_stream_id] - with self.client_conn.h2.lock: - self.client_conn.h2.push_stream(parent_eid, event.pushed_stream_id, event.headers) - self.client_conn.send(self.client_conn.h2.data_to_send()) + with self.connections[self.client_conn].lock: + self.connections[self.client_conn].push_stream(parent_eid, event.pushed_stream_id, event.headers) + self.client_conn.send(self.connections[self.client_conn].data_to_send()) headers = mitmproxy.net.http.Headers([[k, v] for k, v in event.headers]) - self.streams[event.pushed_stream_id] = Http2SingleStreamLayer(self, h2_connection, event.pushed_stream_id, headers) + layer = Http2SingleStreamLayer(self, self.connections[self.client_conn], event.pushed_stream_id, headers) + self.streams[event.pushed_stream_id] = layer self.streams[event.pushed_stream_id].timestamp_start = time.time() self.streams[event.pushed_stream_id].pushed = True self.streams[event.pushed_stream_id].parent_stream_id = parent_eid @@ -266,7 +269,7 @@ class Http2Layer(base.Layer): # HeadersFrame + Priority information as RequestReceived return True - with self.server_conn.h2.lock: + with self.connections[self.server_conn].lock: mapped_stream_id = event.stream_id if mapped_stream_id in self.streams and self.streams[mapped_stream_id].server_stream_id: # if the stream is already up and running and was sent to the server, @@ -278,13 +281,13 @@ class Http2Layer(base.Layer): self.streams[eid].priority_depends_on = event.depends_on self.streams[eid].priority_weight = event.weight - self.server_conn.h2.prioritize( + self.connections[self.server_conn].prioritize( mapped_stream_id, weight=event.weight, depends_on=self._map_depends_on_stream_id(mapped_stream_id, event.depends_on), exclusive=event.exclusive ) - self.server_conn.send(self.server_conn.h2.data_to_send()) + self.server_conn.send(self.connections[self.server_conn].data_to_send()) return True def _map_depends_on_stream_id(self, stream_id, depends_on): @@ -316,19 +319,17 @@ class Http2Layer(base.Layer): self._initiate_server_conn() self._complete_handshake() - client = self.client_conn.connection - server = self.server_conn.connection - conns = [client, server] + conns = [c.connection for c in self.connections.keys()] try: while True: r = tcp.ssl_read_select(conns, 1) for conn in r: - source_conn = self.client_conn if conn == client else self.server_conn - other_conn = self.server_conn if conn == client else self.client_conn - is_server = (conn == self.server_conn.connection) + source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn + other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn + is_server = (source_conn == self.server_conn) - with source_conn.h2.lock: + with self.connections[source_conn].lock: try: raw_frame = b''.join(http2.read_raw_frame(source_conn.rfile)) except: @@ -336,12 +337,12 @@ class Http2Layer(base.Layer): self._kill_all_streams() return - if source_conn.h2.state_machine.state == h2.connection.ConnectionState.CLOSED: + if self.connections[source_conn].state_machine.state == h2.connection.ConnectionState.CLOSED: self.log("HTTP/2 connection entered closed state already", "debug") return - incoming_events = source_conn.h2.receive_data(raw_frame) - source_conn.send(source_conn.h2.data_to_send()) + incoming_events = self.connections[source_conn].receive_data(raw_frame) + source_conn.send(self.connections[source_conn].data_to_send()) for event in incoming_events: if not self._handle_event(event, source_conn, other_conn, is_server): @@ -450,7 +451,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr def raise_zombie(self, pre_command=None): connection_closed = self.h2_connection.state_machine.state == h2.connection.ConnectionState.CLOSED - if self.zombie is not None or not hasattr(self.server_conn, 'h2') or connection_closed: + if self.zombie is not None or connection_closed: if pre_command is not None: pre_command() raise exceptions.Http2ZombieException("Connection already dead") @@ -494,12 +495,12 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr while True: self.raise_zombie() - self.server_conn.h2.lock.acquire() + self.connections[self.server_conn].lock.acquire() - max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams - if self.server_conn.h2.open_outbound_streams + 1 >= max_streams: + max_streams = self.connections[self.server_conn].remote_settings.max_concurrent_streams + if self.connections[self.server_conn].open_outbound_streams + 1 >= max_streams: # wait until we get a free slot for a new outgoing stream - self.server_conn.h2.lock.release() + self.connections[self.server_conn].lock.release() time.sleep(0.1) continue @@ -509,7 +510,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr # We must not assign a stream id if we are already a zombie. self.raise_zombie() - self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() + self.server_stream_id = self.connections[self.server_conn].get_next_available_stream_id() self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id headers = message.headers.copy() @@ -528,7 +529,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr priority_weight = self.priority_weight try: - self.server_conn.h2.safe_send_headers( + self.connections[self.server_conn].safe_send_headers( self.raise_zombie, self.server_stream_id, headers, @@ -541,10 +542,10 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr raise e finally: self.raise_zombie() - self.server_conn.h2.lock.release() + self.connections[self.server_conn].lock.release() if not self.no_body: - self.server_conn.h2.safe_send_body( + self.connections[self.server_conn].safe_send_body( self.raise_zombie, self.server_stream_id, [message.content] @@ -591,8 +592,8 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr for forbidden_header in h2.utilities.CONNECTION_HEADERS: if forbidden_header in headers: del headers[forbidden_header] - with self.client_conn.h2.lock: - self.client_conn.h2.safe_send_headers( + with self.connections[self.client_conn].lock: + self.connections[self.client_conn].safe_send_headers( self.raise_zombie, self.client_stream_id, headers @@ -600,7 +601,7 @@ class Http2SingleStreamLayer(httpbase._HttpTransmissionLayer, basethread.BaseThr @detect_zombie_stream def send_response_body(self, _response, chunks): - self.client_conn.h2.safe_send_body( + self.connections[self.client_conn].safe_send_body( self.raise_zombie, self.client_stream_id, chunks -- cgit v1.2.3