diff options
| -rw-r--r-- | mitmproxy/protocol/http2.py | 42 | 
1 files changed, 30 insertions, 12 deletions
| diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index f6261b6b..d848affa 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -5,7 +5,6 @@ import time  import traceback  import h2.exceptions -import hyperframe  import six  from h2 import connection  from h2 import events @@ -55,12 +54,12 @@ class SafeH2Connection(connection.H2Connection):              self.update_settings(new_settings)              self.conn.send(self.data_to_send()) -    def safe_send_headers(self, is_zombie, stream_id, headers): -        # make sure to have a lock -        if is_zombie():  # pragma: no cover -            raise exceptions.Http2ProtocolException("Zombie Stream") -        self.send_headers(stream_id, headers.fields) -        self.conn.send(self.data_to_send()) +    def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs): +        with self.lock: +            if is_zombie():  # pragma: no cover +                raise exceptions.Http2ProtocolException("Zombie Stream") +            self.send_headers(stream_id, headers.fields, **kwargs) +            self.conn.send(self.data_to_send())      def safe_send_body(self, is_zombie, stream_id, chunks):          for chunk in chunks: @@ -141,6 +140,10 @@ class Http2Layer(base.Layer):              headers = netlib.http.Headers([[k, v] for k, v in event.headers])              self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)              self.streams[eid].timestamp_start = time.time() +            if event.priority_updated is not None: +                self.streams[eid].priority_weight = event.priority_updated.weight +                self.streams[eid].priority_depends_on = event.priority_updated.depends_on +                self.streams[eid].priority_exclusive = event.priority_updated.exclusive              self.streams[eid].start()          elif isinstance(event, events.ResponseReceived):              headers = netlib.http.Headers([[k, v] for k, v in event.headers]) @@ -184,7 +187,6 @@ class Http2Layer(base.Layer):                  self.client_conn.send(self.client_conn.h2.data_to_send())                  self._kill_all_streams()                  return False -          elif isinstance(event, events.PushedStreamReceived):              # pushed stream ids should be unique and not dependent on race conditions              # only the parent stream id must be looked up first @@ -210,9 +212,18 @@ class Http2Layer(base.Layer):              if depends_on in self.streams.keys() and self.streams[depends_on].server_stream_id:                  depends_on = self.streams[depends_on].server_stream_id -            # weight is between 1 and 256 (inclusive), but represented as uint8 (0 to 255) -            frame = hyperframe.frame.PriorityFrame(stream_id, depends_on, event.weight - 1, event.exclusive) -            self.server_conn.send(frame.serialize()) +            self.streams[eid].priority_weight = event.weight +            self.streams[eid].priority_depends_on = event.depends_on +            self.streams[eid].priority_exclusive = event.exclusive + +            with self.server_conn.h2.lock: +                self.server_conn.h2.prioritize( +                    stream_id, +                    weight=event.weight, +                    depends_on=depends_on, +                    exclusive=event.exclusive +                ) +                self.server_conn.send(self.server_conn.h2.data_to_send())          elif isinstance(event, events.TrailersReceived):              raise NotImplementedError() @@ -267,7 +278,7 @@ class Http2Layer(base.Layer):                                  self._kill_all_streams()                                  return -                self._cleanup_streams() +                    self._cleanup_streams()          except Exception as e:              self.log(repr(e), "info")              self.log(traceback.format_exc(), "debug") @@ -296,6 +307,10 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)          self.response_queued_data_length = 0          self.response_data_finished = threading.Event() +        self.priority_weight = None +        self.priority_depends_on = None +        self.priority_exclusive = None +      @property      def data_queue(self):          if self.response_arrived.is_set(): @@ -394,6 +409,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)                  self.is_zombie,                  self.server_stream_id,                  headers, +                priority_weight=self.priority_weight, +                priority_depends_on=self.priority_depends_on, +                priority_exclusive=self.priority_exclusive,              )          except Exception as e:              raise e | 
