aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mitmproxy/protocol/http2.py138
1 files changed, 83 insertions, 55 deletions
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py
index 957b8d64..b9a30c7e 100644
--- a/mitmproxy/protocol/http2.py
+++ b/mitmproxy/protocol/http2.py
@@ -56,11 +56,11 @@ class SafeH2Connection(connection.H2Connection):
self.conn.send(self.data_to_send())
def safe_send_headers(self, is_zombie, stream_id, headers):
- with self.lock:
- if is_zombie(): # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
- self.send_headers(stream_id, headers.fields)
- self.conn.send(self.data_to_send())
+ # make sure to have a lock
+ if is_zombie(): # pragma: no cover
+ raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.send_headers(stream_id, headers.fields)
+ self.conn.send(self.data_to_send())
def safe_send_body(self, is_zombie, stream_id, chunks):
for chunk in chunks:
@@ -77,8 +77,12 @@ class SafeH2Connection(connection.H2Connection):
time.sleep(0.1)
continue
self.send_data(stream_id, frame_chunk)
- self.conn.send(self.data_to_send())
- self.lock.release()
+ try:
+ self.conn.send(self.data_to_send())
+ except Exception as e:
+ raise e
+ finally:
+ self.lock.release()
position += max_outbound_frame_size
with self.lock:
if is_zombie(): # pragma: no cover
@@ -225,6 +229,9 @@ class Http2Layer(base.Layer):
for stream in self.streams.values():
if not stream.zombie:
stream.zombie = time.time()
+ stream.request_data_finished.set()
+ stream.response_arrived.set()
+ stream.data_finished.set()
def __call__(self):
if self.server_conn:
@@ -235,31 +242,36 @@ class Http2Layer(base.Layer):
self.client_conn.h2.receive_data(preamble)
self.client_conn.send(self.client_conn.h2.data_to_send())
- while True:
- r = tcp.ssl_read_select(self.active_conns, 1)
- for conn in r:
- source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
- other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
- is_server = (conn == self.server_conn.connection)
-
- with source_conn.h2.lock:
- try:
- raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile))
- except:
- # read frame failed: connection closed
- self._kill_all_streams()
- return
-
- incoming_events = source_conn.h2.receive_data(raw_frame)
- source_conn.send(source_conn.h2.data_to_send())
-
- for event in incoming_events:
- if not self._handle_event(event, source_conn, other_conn, is_server):
- # connection terminated: GoAway
+ try:
+ while True:
+ r = tcp.ssl_read_select(self.active_conns, 1)
+ for conn in r:
+ source_conn = self.client_conn if conn == self.client_conn.connection else self.server_conn
+ other_conn = self.server_conn if conn == self.client_conn.connection else self.client_conn
+ is_server = (conn == self.server_conn.connection)
+
+ with source_conn.h2.lock:
+ try:
+ raw_frame = b''.join(http2.framereader.http2_read_raw_frame(source_conn.rfile))
+ except:
+ # read frame failed: connection closed
self._kill_all_streams()
return
- self._cleanup_streams()
+ incoming_events = source_conn.h2.receive_data(raw_frame)
+ source_conn.send(source_conn.h2.data_to_send())
+
+ for event in incoming_events:
+ if not self._handle_event(event, source_conn, other_conn, is_server):
+ # connection terminated: GoAway
+ self._kill_all_streams()
+ return
+
+ self._cleanup_streams()
+ except Exception as e:
+ self.log(repr(e), "info")
+ self.log(traceback.format_exc(), "debug")
+ self._kill_all_streams()
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
@@ -315,6 +327,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def read_request(self):
self.request_data_finished.wait()
+ if self.zombie: # pragma: no cover
+ raise exceptions.Http2ProtocolException("Zombie Stream")
+
authority = self.request_headers.get(':authority', '')
method = self.request_headers.get(':method', 'GET')
scheme = self.request_headers.get(':scheme', 'https')
@@ -366,31 +381,32 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
raise NotImplementedError()
def send_request(self, message):
- if not hasattr(self.server_conn, 'h2'):
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ if self.pushed:
+ # nothing to do here
+ return
while True:
+ if self.zombie: # pragma: no cover
+ raise exceptions.Http2ProtocolException("Zombie Stream")
+
self.server_conn.h2.lock.acquire()
+
max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams
if self.server_conn.h2.open_outbound_streams + 1 >= max_streams:
# wait until we get a free slot for a new outgoing stream
self.server_conn.h2.lock.release()
time.sleep(0.1)
- else:
- break
+ continue
- if self.pushed:
- # nothing to do here
- self.server_conn.h2.lock.release()
- return
+ # keep the lock
+ break
- with self.server_conn.h2.lock:
- # We must not assign a stream id if we are already a zombie.
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ # We must not assign a stream id if we are already a zombie.
+ if self.zombie: # pragma: no cover
+ raise exceptions.Http2ProtocolException("Zombie Stream")
- self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
- self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
+ self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
+ self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
headers = message.headers.copy()
headers.insert(0, ":path", message.path)
@@ -398,12 +414,17 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
headers.insert(0, ":scheme", message.scheme)
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
- self.server_conn.h2.safe_send_headers(
- self.is_zombie,
- self.server_stream_id,
- headers,
- )
- self.server_conn.h2.lock.release()
+
+ try:
+ self.server_conn.h2.safe_send_headers(
+ self.is_zombie,
+ self.server_stream_id,
+ headers,
+ )
+ except Exception as e:
+ raise e
+ finally:
+ self.server_conn.h2.lock.release()
self.server_conn.h2.safe_send_body(
self.is_zombie,
@@ -416,6 +437,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def read_response_headers(self):
self.response_arrived.wait()
+ if self.zombie: # pragma: no cover
+ raise exceptions.Http2ProtocolException("Zombie Stream")
+
status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
headers.clear(":status")
@@ -437,6 +461,8 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except queue.Empty:
pass
if self.response_data_finished.is_set():
+ if self.zombie: # pragma: no cover
+ raise exceptions.Http2ProtocolException("Zombie Stream")
while self.response_data_queue.qsize() > 0:
yield self.response_data_queue.get()
break
@@ -446,11 +472,12 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def send_response_headers(self, response):
headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code))
- self.client_conn.h2.safe_send_headers(
- self.is_zombie,
- self.client_stream_id,
- headers
- )
+ with self.client_conn.h2.lock:
+ self.client_conn.h2.safe_send_headers(
+ self.is_zombie,
+ self.client_stream_id,
+ headers
+ )
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
@@ -484,4 +511,5 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.log(repr(e), "info")
self.log(traceback.format_exc(), "debug")
- self.zombie = time.time()
+ if not self.zombie:
+ self.zombie = time.time()