aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorThomas Kriechbaumer <thomas@kriechbaumer.name>2016-08-13 20:58:01 +0200
committerThomas Kriechbaumer <thomas@kriechbaumer.name>2016-08-14 12:43:56 +0200
commit0cc695407dab288ab3854179a2b87302d40204a2 (patch)
treeb0ac4c2958fdd61e8936697dc68e2790ededcc81
parent65677bdd284ef71184185671f4ae0b3713b5a3de (diff)
downloadmitmproxy-0cc695407dab288ab3854179a2b87302d40204a2.tar.gz
mitmproxy-0cc695407dab288ab3854179a2b87302d40204a2.tar.bz2
mitmproxy-0cc695407dab288ab3854179a2b87302d40204a2.zip
http2: simplify zombie detection
-rw-r--r--mitmproxy/protocol/http2.py73
1 files changed, 37 insertions, 36 deletions
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py
index eb5586cb..1b2f5cf5 100644
--- a/mitmproxy/protocol/http2.py
+++ b/mitmproxy/protocol/http2.py
@@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division
import threading
import time
import traceback
+import functools
import h2.exceptions
import six
@@ -54,21 +55,18 @@ class SafeH2Connection(connection.H2Connection):
self.update_settings(new_settings)
self.conn.send(self.data_to_send())
- def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs):
+ def safe_send_headers(self, raise_zombie, stream_id, headers, **kwargs):
with self.lock:
- if is_zombie(): # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ raise_zombie()
self.send_headers(stream_id, headers.fields, **kwargs)
self.conn.send(self.data_to_send())
- def safe_send_body(self, is_zombie, stream_id, chunks):
+ def safe_send_body(self, raise_zombie, stream_id, chunks):
for chunk in chunks:
position = 0
while position < len(chunk):
self.lock.acquire()
- if is_zombie(): # pragma: no cover
- self.lock.release()
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ raise_zombie(self.lock.release)
max_outbound_frame_size = self.max_outbound_frame_size
frame_chunk = chunk[position:position + max_outbound_frame_size]
if self.local_flow_control_window(stream_id) < len(frame_chunk):
@@ -84,8 +82,7 @@ class SafeH2Connection(connection.H2Connection):
self.lock.release()
position += max_outbound_frame_size
with self.lock:
- if is_zombie(): # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ raise_zombie()
self.end_stream(stream_id)
self.conn.send(self.data_to_send())
@@ -344,6 +341,17 @@ class Http2Layer(base.Layer):
self._kill_all_streams()
+def detect_zombie_stream(func):
+ @functools.wraps(func)
+ def wrapper(self, *args, **kwargs):
+ self.raise_zombie()
+ result = func(self, *args, **kwargs)
+ self.raise_zombie()
+ return result
+
+ return wrapper
+
+
class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread):
def __init__(self, ctx, stream_id, request_headers):
@@ -412,15 +420,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def queued_data_length(self, v):
self.request_queued_data_length = v
- def is_zombie(self):
- return self.zombie is not None
+ def raise_zombie(self, pre_command=None):
+ if self.zombie is not None:
+ if pre_command is not None:
+ pre_command()
+ raise exceptions.Http2ProtocolException("Zombie Stream")
+ @detect_zombie_stream
def read_request(self):
self.request_data_finished.wait()
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
-
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
@@ -445,15 +454,14 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def read_request_body(self, request): # pragma: no cover
raise NotImplementedError()
+ @detect_zombie_stream
def send_request(self, message):
if self.pushed:
# nothing to do here
return
while True:
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
-
+ self.raise_zombie()
self.server_conn.h2.lock.acquire()
max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams
@@ -467,8 +475,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
break
# We must not assign a stream id if we are already a zombie.
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
self.server_stream_id = self.server_conn.h2.get_next_available_stream_id()
self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id
@@ -490,7 +497,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
try:
self.server_conn.h2.safe_send_headers(
- self.is_zombie,
+ self.raise_zombie,
self.server_stream_id,
headers,
end_stream=self.no_body,
@@ -505,19 +512,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
if not self.no_body:
self.server_conn.h2.safe_send_body(
- self.is_zombie,
+ self.raise_zombie,
self.server_stream_id,
[message.body]
)
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
-
+ @detect_zombie_stream
def read_response_headers(self):
self.response_arrived.wait()
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
status_code = int(self.response_headers.get(':status', 502))
headers = self.response_headers.copy()
@@ -533,6 +537,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
timestamp_end=self.timestamp_end,
)
+ @detect_zombie_stream
def read_response_body(self, request, response):
while True:
try:
@@ -540,14 +545,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
except queue.Empty: # pragma: no cover
pass
if self.response_data_finished.is_set():
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
while self.response_data_queue.qsize() > 0:
yield self.response_data_queue.get()
break
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.raise_zombie()
+ @detect_zombie_stream
def send_response_headers(self, response):
headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code))
@@ -556,21 +560,18 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
del headers[forbidden_header]
with self.client_conn.h2.lock:
self.client_conn.h2.safe_send_headers(
- self.is_zombie,
+ self.raise_zombie,
self.client_stream_id,
headers
)
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
+ @detect_zombie_stream
def send_response_body(self, _response, chunks):
self.client_conn.h2.safe_send_body(
- self.is_zombie,
+ self.raise_zombie,
self.client_stream_id,
chunks
)
- if self.zombie: # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
def run(self):
layer = http.HttpLayer(self, self.mode)