diff options
| -rw-r--r-- | .appveyor.yml | 2 | ||||
| -rw-r--r-- | mitmproxy/protocol/http2.py | 99 | ||||
| -rw-r--r-- | netlib/http/http2/__init__.py | 2 | ||||
| -rw-r--r-- | netlib/http/http2/utils.py | 37 | ||||
| -rw-r--r-- | pathod/protocols/http2.py | 56 | ||||
| -rw-r--r-- | setup.py | 2 | ||||
| -rw-r--r-- | test/mitmproxy/test_protocol_http2.py | 125 | ||||
| -rw-r--r-- | test/netlib/tservers.py | 12 | ||||
| -rw-r--r-- | test/pathod/test_protocols_http2.py | 31 |
9 files changed, 238 insertions, 128 deletions
diff --git a/.appveyor.yml b/.appveyor.yml index 339342ae..8a83b478 100644 --- a/.appveyor.yml +++ b/.appveyor.yml @@ -25,7 +25,7 @@ install: - "pip install -U tox" test_script: - - ps: "tox -- --cov netlib --cov mitmproxy --cov pathod | Select-String -NotMatch Cryptography_locking_cb" + - ps: "tox --recreate -- --cov netlib --cov mitmproxy --cov pathod | Select-String -NotMatch Cryptography_locking_cb" deploy_script: ps: | diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py index b9a30c7e..9515eef9 100644 --- a/mitmproxy/protocol/http2.py +++ b/mitmproxy/protocol/http2.py @@ -5,7 +5,6 @@ import time import traceback import h2.exceptions -import hyperframe import six from h2 import connection from h2 import events @@ -55,12 +54,12 @@ class SafeH2Connection(connection.H2Connection): self.update_settings(new_settings) self.conn.send(self.data_to_send()) - def safe_send_headers(self, is_zombie, stream_id, headers): - # make sure to have a lock - if is_zombie(): # pragma: no cover - raise exceptions.Http2ProtocolException("Zombie Stream") - self.send_headers(stream_id, headers.fields) - self.conn.send(self.data_to_send()) + def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs): + with self.lock: + if is_zombie(): # pragma: no cover + raise exceptions.Http2ProtocolException("Zombie Stream") + self.send_headers(stream_id, headers.fields, **kwargs) + self.conn.send(self.data_to_send()) def safe_send_body(self, is_zombie, stream_id, chunks): for chunk in chunks: @@ -141,6 +140,12 @@ class Http2Layer(base.Layer): headers = netlib.http.Headers([[k, v] for k, v in event.headers]) self.streams[eid] = Http2SingleStreamLayer(self, eid, headers) self.streams[eid].timestamp_start = time.time() + self.streams[eid].no_body = (event.stream_ended is not None) + if event.priority_updated is not None: + self.streams[eid].priority_weight = event.priority_updated.weight + self.streams[eid].priority_depends_on = event.priority_updated.depends_on + self.streams[eid].priority_exclusive = event.priority_updated.exclusive + self.streams[eid].handled_priority_event = event.priority_updated self.streams[eid].start() elif isinstance(event, events.ResponseReceived): headers = netlib.http.Headers([[k, v] for k, v in event.headers]) @@ -184,7 +189,6 @@ class Http2Layer(base.Layer): self.client_conn.send(self.client_conn.h2.data_to_send()) self._kill_all_streams() return False - elif isinstance(event, events.PushedStreamReceived): # pushed stream ids should be unique and not dependent on race conditions # only the parent stream id must be looked up first @@ -202,6 +206,16 @@ class Http2Layer(base.Layer): self.streams[event.pushed_stream_id].request_data_finished.set() self.streams[event.pushed_stream_id].start() elif isinstance(event, events.PriorityUpdated): + if eid in self.streams: + if self.streams[eid].handled_priority_event is event: + # This event was already handled during stream creation + # HeadersFrame + Priority information as RequestReceived + return True + if eid in self.streams: + self.streams[eid].priority_weight = event.weight + self.streams[eid].priority_depends_on = event.depends_on + self.streams[eid].priority_exclusive = event.exclusive + stream_id = event.stream_id if stream_id in self.streams.keys() and self.streams[stream_id].server_stream_id: stream_id = self.streams[stream_id].server_stream_id @@ -210,9 +224,14 @@ class Http2Layer(base.Layer): if depends_on in self.streams.keys() and self.streams[depends_on].server_stream_id: depends_on = self.streams[depends_on].server_stream_id - # weight is between 1 and 256 (inclusive), but represented as uint8 (0 to 255) - frame = hyperframe.frame.PriorityFrame(stream_id, depends_on, event.weight - 1, event.exclusive) - self.server_conn.send(frame.serialize()) + with self.server_conn.h2.lock: + self.server_conn.h2.prioritize( + stream_id, + weight=event.weight, + depends_on=depends_on, + exclusive=event.exclusive + ) + self.server_conn.send(self.server_conn.h2.data_to_send()) elif isinstance(event, events.TrailersReceived): raise NotImplementedError() @@ -267,7 +286,7 @@ class Http2Layer(base.Layer): self._kill_all_streams() return - self._cleanup_streams() + self._cleanup_streams() except Exception as e: self.log(repr(e), "info") self.log(traceback.format_exc(), "debug") @@ -296,6 +315,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.response_queued_data_length = 0 self.response_data_finished = threading.Event() + self.no_body = False + + self.priority_weight = None + self.priority_depends_on = None + self.priority_exclusive = None + self.handled_priority_event = None + @property def data_queue(self): if self.response_arrived.is_set(): @@ -330,39 +356,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) if self.zombie: # pragma: no cover raise exceptions.Http2ProtocolException("Zombie Stream") - authority = self.request_headers.get(':authority', '') - method = self.request_headers.get(':method', 'GET') - scheme = self.request_headers.get(':scheme', 'https') - path = self.request_headers.get(':path', '/') - self.request_headers.clear(":method") - self.request_headers.clear(":scheme") - self.request_headers.clear(":path") - host = None - port = None - - if path == '*' or path.startswith("/"): - first_line_format = "relative" - elif method == 'CONNECT': # pragma: no cover - raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") - else: # pragma: no cover - first_line_format = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = netlib.http.url.parse(path) - - if authority: - host, _, port = authority.partition(':') - - if not host: - host = 'localhost' - if not port: - port = 443 if scheme == 'https' else 80 - port = int(port) - data = [] while self.request_data_queue.qsize() > 0: data.append(self.request_data_queue.get()) data = b"".join(data) + first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers) + return models.HTTPRequest( first_line_format, method, @@ -420,17 +420,23 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) self.is_zombie, self.server_stream_id, headers, + end_stream=self.no_body, + priority_weight=self.priority_weight, + priority_depends_on=self.priority_depends_on, + priority_exclusive=self.priority_exclusive, ) except Exception as e: raise e finally: self.server_conn.h2.lock.release() - self.server_conn.h2.safe_send_body( - self.is_zombie, - self.server_stream_id, - message.body - ) + if not self.no_body: + self.server_conn.h2.safe_send_body( + self.is_zombie, + self.server_stream_id, + message.body + ) + if self.zombie: # pragma: no cover raise exceptions.Http2ProtocolException("Zombie Stream") @@ -472,6 +478,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread) def send_response_headers(self, response): headers = response.headers.copy() headers.insert(0, ":status", str(response.status_code)) + for forbidden_header in h2.utilities.CONNECTION_HEADERS: + if forbidden_header in headers: + del headers[forbidden_header] with self.client_conn.h2.lock: self.client_conn.h2.safe_send_headers( self.is_zombie, diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py index 6a979a0d..60064190 100644 --- a/netlib/http/http2/__init__.py +++ b/netlib/http/http2/__init__.py @@ -1,6 +1,8 @@ from __future__ import absolute_import, print_function, division from netlib.http.http2 import framereader +from netlib.http.http2.utils import parse_headers __all__ = [ "framereader", + "parse_headers", ] diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py new file mode 100644 index 00000000..4c01952d --- /dev/null +++ b/netlib/http/http2/utils.py @@ -0,0 +1,37 @@ +from netlib.http import url + + +def parse_headers(headers): + authority = headers.get(':authority', '').encode() + method = headers.get(':method', 'GET').encode() + scheme = headers.get(':scheme', 'https').encode() + path = headers.get(':path', '/').encode() + + headers.clear(":method") + headers.clear(":scheme") + headers.clear(":path") + + host = None + port = None + + if path == b'*' or path.startswith(b"/"): + first_line_format = "relative" + elif method == b'CONNECT': # pragma: no cover + raise NotImplementedError("CONNECT over HTTP/2 is not implemented.") + else: # pragma: no cover + first_line_format = "absolute" + # FIXME: verify if path or :host contains what we need + scheme, host, port, _ = url.parse(path) + + if authority: + host, _, port = authority.partition(b':') + + if not host: + host = b'localhost' + + if not port: + port = 443 if scheme == b'https' else 80 + + port = int(port) + + return first_line_format, method, scheme, host, port, path diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py index c8728940..5ad120de 100644 --- a/pathod/protocols/http2.py +++ b/pathod/protocols/http2.py @@ -7,8 +7,7 @@ import hyperframe.frame from hpack.hpack import Encoder, Decoder from netlib import utils, strutils -from netlib.http import url -from netlib.http.http2 import framereader +from netlib.http import http2 import netlib.http.headers import netlib.http.response import netlib.http.request @@ -101,46 +100,15 @@ class HTTP2StateProtocol(object): timestamp_end = time.time() - authority = headers.get(':authority', b'') - method = headers.get(':method', 'GET') - scheme = headers.get(':scheme', 'https') - path = headers.get(':path', '/') - - headers.clear(":method") - headers.clear(":scheme") - headers.clear(":path") - - host = None - port = None - - if path == '*' or path.startswith("/"): - first_line_format = "relative" - elif method == 'CONNECT': - first_line_format = "authority" - if ":" in authority: - host, port = authority.split(":", 1) - else: - host = authority - else: - first_line_format = "absolute" - # FIXME: verify if path or :host contains what we need - scheme, host, port, _ = url.parse(path) - scheme = scheme.decode('ascii') - host = host.decode('ascii') - - if host is None: - host = 'localhost' - if port is None: - port = 80 if scheme == 'http' else 443 - port = int(port) + first_line_format, method, scheme, host, port, path = http2.parse_headers(headers) request = netlib.http.request.Request( first_line_format, - method.encode('ascii'), - scheme.encode('ascii'), - host.encode('ascii'), + method, + scheme, + host, port, - path.encode('ascii'), + path, b"HTTP/2.0", headers, body, @@ -213,10 +181,10 @@ class HTTP2StateProtocol(object): headers = request.headers.copy() if ':authority' not in headers: - headers.insert(0, b':authority', authority.encode('ascii')) - headers.insert(0, b':scheme', request.scheme.encode('ascii')) - headers.insert(0, b':path', request.path.encode('ascii')) - headers.insert(0, b':method', request.method.encode('ascii')) + headers.insert(0, ':authority', authority) + headers.insert(0, ':scheme', request.scheme) + headers.insert(0, ':path', request.path) + headers.insert(0, ':method', request.method) if hasattr(request, 'stream_id'): stream_id = request.stream_id @@ -286,7 +254,7 @@ class HTTP2StateProtocol(object): def read_frame(self, hide=False): while True: - frm = framereader.http2_read_frame(self.tcp_handler.rfile) + frm = http2.framereader.http2_read_frame(self.tcp_handler.rfile) if not hide and self.dump_frames: # pragma no cover print(frm.human_readable("<<")) @@ -429,7 +397,7 @@ class HTTP2StateProtocol(object): self._handle_unexpected_frame(frm) headers = netlib.http.headers.Headers( - (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks) + [[k, v] for k, v in self.decoder.decode(header_blocks, raw=True)] ) return stream_id, headers, body @@ -66,7 +66,7 @@ setup( "construct>=2.5.2, <2.6", "cryptography>=1.3, <1.5", "Flask>=0.10.1, <0.12", - "h2>=2.3.1, <3", + "h2>=2.4.0, <3", "html2text>=2016.1.8, <=2016.5.29", "hyperframe>=4.0.1, <5", "lxml>=3.5.0, <3.7", diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py index 932c8df2..2eb0b120 100644 --- a/test/mitmproxy/test_protocol_http2.py +++ b/test/mitmproxy/test_protocol_http2.py @@ -3,9 +3,10 @@ from __future__ import (absolute_import, print_function, division) import pytest -import traceback import os import tempfile +import traceback + import h2 from mitmproxy.proxy.config import ProxyConfig @@ -46,6 +47,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase): self.wfile.write(h2_conn.data_to_send()) self.wfile.flush() + if 'h2_server_settings' in self.kwargs: + h2_conn.update_settings(self.kwargs['h2_server_settings']) + self.wfile.write(h2_conn.data_to_send()) + self.wfile.flush() + done = False while not done: try: @@ -508,3 +514,120 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase): if len(self.master.state.flows) == 1: assert self.master.state.flows[0].response is None + + +@requires_alpn +class TestMaxConcurrentStreams(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2}) + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.ConnectionTerminated): + return False + elif isinstance(event, h2.events.RequestReceived): + h2_conn.send_headers(event.stream_id, [ + (':status', '200'), + ('X-Stream-ID', str(event.stream_id)), + ]) + h2_conn.send_data(event.stream_id, b'Stream-ID {}'.format(event.stream_id)) + h2_conn.end_stream(event.stream_id) + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_max_concurrent_streams(self): + client, h2_conn = self._setup_connection() + new_streams = [1, 3, 5, 7, 9, 11] + for id in new_streams: + # this will exceed MAX_CONCURRENT_STREAMS on the server connection + # and cause mitmproxy to throttle stream creation to the server + self._send_request(client.wfile, h2_conn, stream_id=id, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ('X-Stream-ID', str(id)), + ]) + + ended_streams = 0 + while ended_streams != len(new_streams): + try: + header, body = framereader.http2_read_raw_frame(client.rfile) + events = h2_conn.receive_data(b''.join([header, body])) + except: + break + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + for event in events: + if isinstance(event, h2.events.StreamEnded): + ended_streams += 1 + + h2_conn.close_connection() + client.wfile.write(h2_conn.data_to_send()) + client.wfile.flush() + + assert len(self.master.state.flows) == len(new_streams) + for flow in self.master.state.flows: + assert flow.response.status_code == 200 + assert "Stream-ID" in flow.response.body + + +@requires_alpn +class TestConnectionTerminated(_Http2TestBase, _Http2ServerBase): + + @classmethod + def setup_class(self): + _Http2TestBase.setup_class() + _Http2ServerBase.setup_class() + + @classmethod + def teardown_class(self): + _Http2TestBase.teardown_class() + _Http2ServerBase.teardown_class() + + @classmethod + def handle_server_event(self, event, h2_conn, rfile, wfile): + if isinstance(event, h2.events.RequestReceived): + h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data='foobar') + wfile.write(h2_conn.data_to_send()) + wfile.flush() + return True + + def test_connection_terminated(self): + client, h2_conn = self._setup_connection() + + self._send_request(client.wfile, h2_conn, headers=[ + (':authority', "127.0.0.1:%s" % self.server.server.address.port), + (':method', 'GET'), + (':scheme', 'https'), + (':path', '/'), + ]) + + done = False + connection_terminated_event = None + while not done: + try: + raw = b''.join(framereader.http2_read_raw_frame(client.rfile)) + events = h2_conn.receive_data(raw) + for event in events: + if isinstance(event, h2.events.ConnectionTerminated): + connection_terminated_event = event + done = True + except: + break + + assert len(self.master.state.flows) == 1 + assert connection_terminated_event is not None + assert connection_terminated_event.error_code == 5 + assert connection_terminated_event.last_stream_id == 42 + assert connection_terminated_event.additional_data == 'foobar' diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py index 803aaa72..666f97ac 100644 --- a/test/netlib/tservers.py +++ b/test/netlib/tservers.py @@ -24,7 +24,7 @@ class _ServerThread(threading.Thread): class _TServer(tcp.TCPServer): - def __init__(self, ssl, q, handler_klass, addr): + def __init__(self, ssl, q, handler_klass, addr, **kwargs): """ ssl: A dictionary of SSL parameters: @@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer): self.q = q self.handler_klass = handler_klass + if self.handler_klass is not None: + self.handler_klass.kwargs = kwargs self.last_handler = None def handle_client_connection(self, request, client_address): @@ -89,16 +91,16 @@ class ServerTestBase(object): addr = ("localhost", 0) @classmethod - def setup_class(cls): + def setup_class(cls, **kwargs): cls.q = queue.Queue() - s = cls.makeserver() + s = cls.makeserver(**kwargs) cls.port = s.address.port cls.server = _ServerThread(s) cls.server.start() @classmethod - def makeserver(cls): - return _TServer(cls.ssl, cls.q, cls.handler, cls.addr) + def makeserver(cls, **kwargs): + return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs) @classmethod def teardown_class(cls): diff --git a/test/pathod/test_protocols_http2.py b/test/pathod/test_protocols_http2.py index e42c2858..8d7efc82 100644 --- a/test/pathod/test_protocols_http2.py +++ b/test/pathod/test_protocols_http2.py @@ -367,37 +367,6 @@ class TestReadRequestAbsolute(netlib_tservers.ServerTestBase): assert req.port == 22 -class TestReadRequestConnect(netlib_tservers.ServerTestBase): - class handler(tcp.BaseHandler): - def handle(self): - self.wfile.write( - codecs.decode('00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085', 'hex_codec')) - self.wfile.write( - codecs.decode('00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7', 'hex_codec')) - self.wfile.flush() - - ssl = True - - def test_connect(self): - c = tcp.TCPClient(("127.0.0.1", self.port)) - with c.connect(): - c.convert_to_ssl() - protocol = HTTP2StateProtocol(c, is_server=True) - protocol.connection_preface_performed = True - - req = protocol.read_request(NotImplemented) - assert req.first_line_format == "authority" - assert req.method == "CONNECT" - assert req.host == "address" - assert req.port == 22 - - req = protocol.read_request(NotImplemented) - assert req.first_line_format == "authority" - assert req.method == "CONNECT" - assert req.host == "example.com" - assert req.port == 443 - - class TestReadResponse(netlib_tservers.ServerTestBase): class handler(tcp.BaseHandler): def handle(self): |
