aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mitmproxy/protocol/http2.py76
-rw-r--r--test/mitmproxy/test_protocol_http2.py6
2 files changed, 43 insertions, 39 deletions
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py
index eb5586cb..0e42d619 100644
--- a/mitmproxy/protocol/http2.py
+++ b/mitmproxy/protocol/http2.py
@@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division
import threading
import time
import traceback
+import functools
import h2.exceptions
import six
@@ -54,21 +55,18 @@ class SafeH2Connection(connection.H2Connection):
self.update_settings(new_settings)
self.conn.send(self.data_to_send())
- def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs):
+ def safe_send_headers(self, raise_zombie, stream_id, headers, **kwargs):
with self.lock:
- if is_zombie(): # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ raise_zombie()
self.send_headers(stream_id, headers.fields, **kwargs)
self.conn.send(self.data_to_send())
- def safe_send_body(self, is_zombie, stream_id, chunks):
+ def safe_send_body(self, raise_zombie, stream_id, chunks):
for chunk in chunks:
position = 0
while position < len(chunk):
self.lock.acquire()
- if is_zombie(): # pragma: no cover
- self.lock.release()
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ raise_zombie(self.lock.release)
max_outbound_frame_size = self.max_outbound_frame_size
frame_chunk = chunk[position:position + max_outbound_frame_size]
if self.local_flow_control_window(stream_id) < len(frame_chunk):
@@ -84,8 +82,7 @@ class SafeH2Connection(connection.H2Connection):
self.lock.release()
position += max_outbound_frame_size
with self.lock:
- if is_zombie(): # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ raise_zombie()
self.end_stream(stream_id)
self.conn.send(self.data_to_send())
@@ -344,6 +341,17 @@ class Http2Layer(base.Layer):
self._kill_all_streams()
+def detect_zombie_stream(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ self.raise_zombie()
+ result = func(self, *args, **kwargs)
+ self.raise_zombie()
+ return result
+
+ return wrapper
+
+
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
def __init__(self, ctx, stream_id, request_headers):
@@ -412,15 +420,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def queued_data_length(self, v):
self.request_queued_data_length = v
- def is_zombie(self):
- return self.zombie is not None
+ def raise_zombie(self, pre_command=None):
+ if self.zombie is not None:
+ if pre_command is not None:
+ pre_command()
+ raise exceptions.Http2ProtocolException("Zombie Stream")
+ @detect_zombie_stream
def read_request(self):
self.request_data_finished.wait()
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
-
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
@@ -445,15 +454,14 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def read_request_body(self, request): # pragma: no cover
raise NotImplementedError()
+ @detect_zombie_stream
def send_request(self, message):
if self.pushed:
# nothing to do here
return
while True:
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
-
+ self.raise_zombie()
self.server_conn.h2.lock.acquire()
max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams
@@ -467,8 +475,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
break
# We must not assign a stream id if we are already a zombie.
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
@@ -490,7 +497,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
try:
self.server_conn.h2.safe_send_headers(
- self.is_zombie,
+ self.raise_zombie,
self.server_stream_id,
headers,
end_stream=self.no_body,
@@ -505,19 +512,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
if not self.no_body:
self.server_conn.h2.safe_send_body(
- self.is_zombie,
+ self.raise_zombie,
self.server_stream_id,
[message.body]
)
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
-
+ @detect_zombie_stream
def read_response_headers(self):
self.response_arrived.wait()
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
@@ -533,6 +537,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
timestamp_end=self.timestamp_end,
)
+ @detect_zombie_stream
def read_response_body(self, request, response):
while True:
try:
@@ -540,14 +545,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except queue.Empty: # pragma: no cover
pass
if self.response_data_finished.is_set():
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
while self.response_data_queue.qsize() > 0:
yield self.response_data_queue.get()
break
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
+ @detect_zombie_stream
def send_response_headers(self, response):
headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code))
@@ -556,21 +560,21 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
del headers[forbidden_header]
with self.client_conn.h2.lock:
self.client_conn.h2.safe_send_headers(
- self.is_zombie,
+ self.raise_zombie,
self.client_stream_id,
headers
)
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ @detect_zombie_stream
def send_response_body(self, _response, chunks):
self.client_conn.h2.safe_send_body(
- self.is_zombie,
+ self.raise_zombie,
self.client_stream_id,
chunks
)
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+
+ def __call__(self):
+ raise EnvironmentError('Http2SingleStreamLayer must be run as thread')
def run(self):
layer = http.HttpLayer(self, self.mode)
diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py
index f0fa9a40..873c89c3 100644
--- a/test/mitmproxy/test_protocol_http2.py
+++ b/test/mitmproxy/test_protocol_http2.py
@@ -849,15 +849,15 @@ class TestMaxConcurrentStreams(_Http2Test):
def test_max_concurrent_streams(self):
client, h2_conn = self._setup_connection()
new_streams = [1, 3, 5, 7, 9, 11]
- for id in new_streams:
+ for stream_id in new_streams:
# this will exceed MAX_CONCURRENT_STREAMS on the server connection
# and cause mitmproxy to throttle stream creation to the server
- self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
+ self._send_request(client.wfile, h2_conn, stream_id=stream_id, headers=[
(':authority', "127.0.0.1:{}".format(self.server.server.address.port)),
(':method', 'GET'),
(':scheme', 'https'),
(':path', '/'),
- ('X-Stream-ID', str(id)),
+ ('X-Stream-ID', str(stream_id)),
])
ended_streams = 0