From 0cc695407dab288ab3854179a2b87302d40204a2 Mon Sep 17 00:00:00 2001 From: Thomas Kriechbaumer Date: Sat, 13 Aug 2016 20:58:01 +0200 Subject: http2: simplify zombie detection --- mitmproxy/protocol/http2.py | 73 +++++++++++++++++++++++---------------------- 1 file changed, 37 insertions(+), 36 deletions(-) diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index eb5586cb..1b2f5cf5 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -3,6 +3,7 @@ from __future__ import absolute_import, print_function, division import threading import time import traceback +import functools import h2.exceptions import six @@ -54,21 +55,18 @@ class SafeH2Connection(connection.H2Connection): self.update_settings(new_settings) self.conn.send(self.data_to_send()) - def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs): + def safe_send_headers(self, raise_zombie, stream_id, headers, **kwargs): with self.lock: - if is_zombie(): # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + raise_zombie() self.send_headers(stream_id, headers.fields, **kwargs) self.conn.send(self.data_to_send()) - def safe_send_body(self, is_zombie, stream_id, chunks): + def safe_send_body(self, raise_zombie, stream_id, chunks): for chunk in chunks: position = 0 while position < len(chunk): self.lock.acquire() - if is_zombie(): # pragma: no cover - self.lock.release() - raise exceptions.Http2ProtocolException("Zombie Stream") + raise_zombie(self.lock.release) max_outbound_frame_size = self.max_outbound_frame_size frame_chunk = chunk[position:position + max_outbound_frame_size] if self.local_flow_control_window(stream_id) < len(frame_chunk): @@ -84,8 +82,7 @@ class SafeH2Connection(connection.H2Connection): self.lock.release() position += max_outbound_frame_size with self.lock: - if is_zombie(): # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + raise_zombie() self.end_stream(stream_id) self.conn.send(self.data_to_send()) @@ -344,6 +341,17 @@ class Http2Layer(base.Layer): self._kill_all_streams() +def detect_zombie_stream(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + self.raise_zombie() + result = func(self, *args, **kwargs) + self.raise_zombie() + return result + + return wrapper + + class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread): def __init__(self, ctx, stream_id, request_headers): @@ -412,15 +420,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def queued_data_length(self, v): self.request_queued_data_length = v - def is_zombie(self): - return self.zombie is not None + def raise_zombie(self, pre_command=None): + if self.zombie is not None: + if pre_command is not None: + pre_command() + raise exceptions.Http2ProtocolException("Zombie Stream") + @detect_zombie_stream def read_request(self): self.request_data_finished.wait() - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - data = [] while self.request_data_queue.qsize() > 0: data.append(self.request_data_queue.get()) @@ -445,15 +454,14 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def read_request_body(self, request): # pragma: no cover raise NotImplementedError() + @detect_zombie_stream def send_request(self, message): if self.pushed: # nothing to do here return while True: - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - + self.raise_zombie() self.server_conn.h2.lock.acquire() max_streams = self.server_conn.h2.remote_settings.max_concurrent_streams @@ -467,8 +475,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) break # We must not assign a stream id if we are already a zombie. - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() self.server_stream_id = self.server_conn.h2.get_next_available_stream_id() self.server_to_client_stream_ids[self.server_stream_id] = self.client_stream_id @@ -490,7 +497,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) try: self.server_conn.h2.safe_send_headers( - self.is_zombie, + self.raise_zombie, self.server_stream_id, headers, end_stream=self.no_body, @@ -505,19 +512,16 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) if not self.no_body: self.server_conn.h2.safe_send_body( - self.is_zombie, + self.raise_zombie, self.server_stream_id, [message.body] ) - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - + @detect_zombie_stream def read_response_headers(self): self.response_arrived.wait() - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() status_code = int(self.response_headers.get(':status', 502)) headers = self.response_headers.copy() @@ -533,6 +537,7 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) timestamp_end=self.timestamp_end, ) + @detect_zombie_stream def read_response_body(self, request, response): while True: try: @@ -540,14 +545,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) except queue.Empty: # pragma: no cover pass if self.response_data_finished.is_set(): - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() while self.response_data_queue.qsize() > 0: yield self.response_data_queue.get() break - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + self.raise_zombie() + @detect_zombie_stream def send_response_headers(self, response): headers = response.headers.copy() headers.insert(0, ":status", str(response.status_code)) @@ -556,21 +560,18 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) del headers[forbidden_header] with self.client_conn.h2.lock: self.client_conn.h2.safe_send_headers( - self.is_zombie, + self.raise_zombie, self.client_stream_id, headers ) - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") + @detect_zombie_stream def send_response_body(self, _response, chunks): self.client_conn.h2.safe_send_body( - self.is_zombie, + self.raise_zombie, self.client_stream_id, chunks ) - if self.zombie: # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") def run(self): layer = http.HttpLayer(self, self.mode) -- cgit v1.2.3