aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.appveyor.yml2
-rw-r--r--mitmproxy/protocol/http2.py99
-rw-r--r--netlib/http/http2/__init__.py2
-rw-r--r--netlib/http/http2/utils.py37
-rw-r--r--pathod/protocols/http2.py56
-rw-r--r--setup.py2
-rw-r--r--test/mitmproxy/test_protocol_http2.py125
-rw-r--r--test/netlib/tservers.py12
-rw-r--r--test/pathod/test_protocols_http2.py31
9 files changed, 238 insertions, 128 deletions
diff --git a/.appveyor.yml b/.appveyor.yml
index 339342ae..8a83b478 100644
--- a/.appveyor.yml
+++ b/.appveyor.yml
@@ -25,7 +25,7 @@ install:
- "pip install -U tox"
test_script:
- - ps: "tox -- --cov netlib --cov mitmproxy --cov pathod | Select-String -NotMatch Cryptography_locking_cb"
+ - ps: "tox --recreate -- --cov netlib --cov mitmproxy --cov pathod | Select-String -NotMatch Cryptography_locking_cb"
deploy_script:
ps: |
diff --git a/mitmproxy/protocol/http2.py b/mitmproxy/protocol/http2.py
index b9a30c7e..9515eef9 100644
--- a/mitmproxy/protocol/http2.py
+++ b/mitmproxy/protocol/http2.py
@@ -5,7 +5,6 @@ import time
import traceback
import h2.exceptions
-import hyperframe
import six
from h2 import connection
from h2 import events
@@ -55,12 +54,12 @@ class SafeH2Connection(connection.H2Connection):
self.update_settings(new_settings)
self.conn.send(self.data_to_send())
- def safe_send_headers(self, is_zombie, stream_id, headers):
- # make sure to have a lock
- if is_zombie(): # pragma: no cover
- raise exceptions.Http2ProtocolException("Zombie Stream")
- self.send_headers(stream_id, headers.fields)
- self.conn.send(self.data_to_send())
+ def safe_send_headers(self, is_zombie, stream_id, headers, **kwargs):
+ with self.lock:
+ if is_zombie(): # pragma: no cover
+ raise exceptions.Http2ProtocolException("Zombie Stream")
+ self.send_headers(stream_id, headers.fields, **kwargs)
+ self.conn.send(self.data_to_send())
def safe_send_body(self, is_zombie, stream_id, chunks):
for chunk in chunks:
@@ -141,6 +140,12 @@ class Http2Layer(base.Layer):
headers = netlib.http.Headers([[k, v] for k, v in event.headers])
self.streams[eid] = Http2SingleStreamLayer(self, eid, headers)
self.streams[eid].timestamp_start = time.time()
+ self.streams[eid].no_body = (event.stream_ended is not None)
+ if event.priority_updated is not None:
+ self.streams[eid].priority_weight = event.priority_updated.weight
+ self.streams[eid].priority_depends_on = event.priority_updated.depends_on
+ self.streams[eid].priority_exclusive = event.priority_updated.exclusive
+ self.streams[eid].handled_priority_event = event.priority_updated
self.streams[eid].start()
elif isinstance(event, events.ResponseReceived):
headers = netlib.http.Headers([[k, v] for k, v in event.headers])
@@ -184,7 +189,6 @@ class Http2Layer(base.Layer):
self.client_conn.send(self.client_conn.h2.data_to_send())
self._kill_all_streams()
return False
-
elif isinstance(event, events.PushedStreamReceived):
# pushed stream ids should be unique and not dependent on race conditions
# only the parent stream id must be looked up first
@@ -202,6 +206,16 @@ class Http2Layer(base.Layer):
self.streams[event.pushed_stream_id].request_data_finished.set()
self.streams[event.pushed_stream_id].start()
elif isinstance(event, events.PriorityUpdated):
+ if eid in self.streams:
+ if self.streams[eid].handled_priority_event is event:
+ # This event was already handled during stream creation
+ # HeadersFrame + Priority information as RequestReceived
+ return True
+ if eid in self.streams:
+ self.streams[eid].priority_weight = event.weight
+ self.streams[eid].priority_depends_on = event.depends_on
+ self.streams[eid].priority_exclusive = event.exclusive
+
stream_id = event.stream_id
if stream_id in self.streams.keys() and self.streams[stream_id].server_stream_id:
stream_id = self.streams[stream_id].server_stream_id
@@ -210,9 +224,14 @@ class Http2Layer(base.Layer):
if depends_on in self.streams.keys() and self.streams[depends_on].server_stream_id:
depends_on = self.streams[depends_on].server_stream_id
- # weight is between 1 and 256 (inclusive), but represented as uint8 (0 to 255)
- frame = hyperframe.frame.PriorityFrame(stream_id, depends_on, event.weight - 1, event.exclusive)
- self.server_conn.send(frame.serialize())
+ with self.server_conn.h2.lock:
+ self.server_conn.h2.prioritize(
+ stream_id,
+ weight=event.weight,
+ depends_on=depends_on,
+ exclusive=event.exclusive
+ )
+ self.server_conn.send(self.server_conn.h2.data_to_send())
elif isinstance(event, events.TrailersReceived):
raise NotImplementedError()
@@ -267,7 +286,7 @@ class Http2Layer(base.Layer):
self._kill_all_streams()
return
- self._cleanup_streams()
+ self._cleanup_streams()
except Exception as e:
self.log(repr(e), "info")
self.log(traceback.format_exc(), "debug")
@@ -296,6 +315,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.response_queued_data_length = 0
self.response_data_finished = threading.Event()
+ self.no_body = False
+
+ self.priority_weight = None
+ self.priority_depends_on = None
+ self.priority_exclusive = None
+ self.handled_priority_event = None
+
@property
def data_queue(self):
if self.response_arrived.is_set():
@@ -330,39 +356,13 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
- authority = self.request_headers.get(':authority', '')
- method = self.request_headers.get(':method', 'GET')
- scheme = self.request_headers.get(':scheme', 'https')
- path = self.request_headers.get(':path', '/')
- self.request_headers.clear(":method")
- self.request_headers.clear(":scheme")
- self.request_headers.clear(":path")
- host = None
- port = None
-
- if path == '*' or path.startswith("/"):
- first_line_format = "relative"
- elif method == 'CONNECT': # pragma: no cover
- raise NotImplementedError("CONNECT over HTTP/2 is not implemented.")
- else: # pragma: no cover
- first_line_format = "absolute"
- # FIXME: verify if path or :host contains what we need
- scheme, host, port, _ = netlib.http.url.parse(path)
-
- if authority:
- host, _, port = authority.partition(':')
-
- if not host:
- host = 'localhost'
- if not port:
- port = 443 if scheme == 'https' else 80
- port = int(port)
-
data = []
while self.request_data_queue.qsize() > 0:
data.append(self.request_data_queue.get())
data = b"".join(data)
+ first_line_format, method, scheme, host, port, path = http2.parse_headers(self.request_headers)
+
return models.HTTPRequest(
first_line_format,
method,
@@ -420,17 +420,23 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
self.is_zombie,
self.server_stream_id,
headers,
+ end_stream=self.no_body,
+ priority_weight=self.priority_weight,
+ priority_depends_on=self.priority_depends_on,
+ priority_exclusive=self.priority_exclusive,
)
except Exception as e:
raise e
finally:
self.server_conn.h2.lock.release()
- self.server_conn.h2.safe_send_body(
- self.is_zombie,
- self.server_stream_id,
- message.body
- )
+ if not self.no_body:
+ self.server_conn.h2.safe_send_body(
+ self.is_zombie,
+ self.server_stream_id,
+ message.body
+ )
+
if self.zombie: # pragma: no cover
raise exceptions.Http2ProtocolException("Zombie Stream")
@@ -472,6 +478,9 @@ class Http2SingleStreamLayer(http._HttpTransmissionLayer, basethread.BaseThread)
def send_response_headers(self, response):
headers = response.headers.copy()
headers.insert(0, ":status", str(response.status_code))
+ for forbidden_header in h2.utilities.CONNECTION_HEADERS:
+ if forbidden_header in headers:
+ del headers[forbidden_header]
with self.client_conn.h2.lock:
self.client_conn.h2.safe_send_headers(
self.is_zombie,
diff --git a/netlib/http/http2/__init__.py b/netlib/http/http2/__init__.py
index 6a979a0d..60064190 100644
--- a/netlib/http/http2/__init__.py
+++ b/netlib/http/http2/__init__.py
@@ -1,6 +1,8 @@
from __future__ import absolute_import, print_function, division
from netlib.http.http2 import framereader
+from netlib.http.http2.utils import parse_headers
__all__ = [
"framereader",
+ "parse_headers",
]
diff --git a/netlib/http/http2/utils.py b/netlib/http/http2/utils.py
new file mode 100644
index 00000000..4c01952d
--- /dev/null
+++ b/netlib/http/http2/utils.py
@@ -0,0 +1,37 @@
+from netlib.http import url
+
+
+def parse_headers(headers):
+ authority = headers.get(':authority', '').encode()
+ method = headers.get(':method', 'GET').encode()
+ scheme = headers.get(':scheme', 'https').encode()
+ path = headers.get(':path', '/').encode()
+
+ headers.clear(":method")
+ headers.clear(":scheme")
+ headers.clear(":path")
+
+ host = None
+ port = None
+
+ if path == b'*' or path.startswith(b"/"):
+ first_line_format = "relative"
+ elif method == b'CONNECT': # pragma: no cover
+ raise NotImplementedError("CONNECT over HTTP/2 is not implemented.")
+ else: # pragma: no cover
+ first_line_format = "absolute"
+ # FIXME: verify if path or :host contains what we need
+ scheme, host, port, _ = url.parse(path)
+
+ if authority:
+ host, _, port = authority.partition(b':')
+
+ if not host:
+ host = b'localhost'
+
+ if not port:
+ port = 443 if scheme == b'https' else 80
+
+ port = int(port)
+
+ return first_line_format, method, scheme, host, port, path
diff --git a/pathod/protocols/http2.py b/pathod/protocols/http2.py
index c8728940..5ad120de 100644
--- a/pathod/protocols/http2.py
+++ b/pathod/protocols/http2.py
@@ -7,8 +7,7 @@ import hyperframe.frame
from hpack.hpack import Encoder, Decoder
from netlib import utils, strutils
-from netlib.http import url
-from netlib.http.http2 import framereader
+from netlib.http import http2
import netlib.http.headers
import netlib.http.response
import netlib.http.request
@@ -101,46 +100,15 @@ class HTTP2StateProtocol(object):
timestamp_end = time.time()
- authority = headers.get(':authority', b'')
- method = headers.get(':method', 'GET')
- scheme = headers.get(':scheme', 'https')
- path = headers.get(':path', '/')
-
- headers.clear(":method")
- headers.clear(":scheme")
- headers.clear(":path")
-
- host = None
- port = None
-
- if path == '*' or path.startswith("/"):
- first_line_format = "relative"
- elif method == 'CONNECT':
- first_line_format = "authority"
- if ":" in authority:
- host, port = authority.split(":", 1)
- else:
- host = authority
- else:
- first_line_format = "absolute"
- # FIXME: verify if path or :host contains what we need
- scheme, host, port, _ = url.parse(path)
- scheme = scheme.decode('ascii')
- host = host.decode('ascii')
-
- if host is None:
- host = 'localhost'
- if port is None:
- port = 80 if scheme == 'http' else 443
- port = int(port)
+ first_line_format, method, scheme, host, port, path = http2.parse_headers(headers)
request = netlib.http.request.Request(
first_line_format,
- method.encode('ascii'),
- scheme.encode('ascii'),
- host.encode('ascii'),
+ method,
+ scheme,
+ host,
port,
- path.encode('ascii'),
+ path,
b"HTTP/2.0",
headers,
body,
@@ -213,10 +181,10 @@ class HTTP2StateProtocol(object):
headers = request.headers.copy()
if ':authority' not in headers:
- headers.insert(0, b':authority', authority.encode('ascii'))
- headers.insert(0, b':scheme', request.scheme.encode('ascii'))
- headers.insert(0, b':path', request.path.encode('ascii'))
- headers.insert(0, b':method', request.method.encode('ascii'))
+ headers.insert(0, ':authority', authority)
+ headers.insert(0, ':scheme', request.scheme)
+ headers.insert(0, ':path', request.path)
+ headers.insert(0, ':method', request.method)
if hasattr(request, 'stream_id'):
stream_id = request.stream_id
@@ -286,7 +254,7 @@ class HTTP2StateProtocol(object):
def read_frame(self, hide=False):
while True:
- frm = framereader.http2_read_frame(self.tcp_handler.rfile)
+ frm = http2.framereader.http2_read_frame(self.tcp_handler.rfile)
if not hide and self.dump_frames: # pragma no cover
print(frm.human_readable("<<"))
@@ -429,7 +397,7 @@ class HTTP2StateProtocol(object):
self._handle_unexpected_frame(frm)
headers = netlib.http.headers.Headers(
- (k.encode('ascii'), v.encode('ascii')) for k, v in self.decoder.decode(header_blocks)
+ [[k, v] for k, v in self.decoder.decode(header_blocks, raw=True)]
)
return stream_id, headers, body
diff --git a/setup.py b/setup.py
index 564eb4d7..0de4ba32 100644
--- a/setup.py
+++ b/setup.py
@@ -66,7 +66,7 @@ setup(
"construct>=2.5.2, <2.6",
"cryptography>=1.3, <1.5",
"Flask>=0.10.1, <0.12",
- "h2>=2.3.1, <3",
+ "h2>=2.4.0, <3",
"html2text>=2016.1.8, <=2016.5.29",
"hyperframe>=4.0.1, <5",
"lxml>=3.5.0, <3.7",
diff --git a/test/mitmproxy/test_protocol_http2.py b/test/mitmproxy/test_protocol_http2.py
index 932c8df2..2eb0b120 100644
--- a/test/mitmproxy/test_protocol_http2.py
+++ b/test/mitmproxy/test_protocol_http2.py
@@ -3,9 +3,10 @@
from __future__ import (absolute_import, print_function, division)
import pytest
-import traceback
import os
import tempfile
+import traceback
+
import h2
from mitmproxy.proxy.config import ProxyConfig
@@ -46,6 +47,11 @@ class _Http2ServerBase(netlib_tservers.ServerTestBase):
self.wfile.write(h2_conn.data_to_send())
self.wfile.flush()
+ if 'h2_server_settings' in self.kwargs:
+ h2_conn.update_settings(self.kwargs['h2_server_settings'])
+ self.wfile.write(h2_conn.data_to_send())
+ self.wfile.flush()
+
done = False
while not done:
try:
@@ -508,3 +514,120 @@ class TestConnectionLost(_Http2TestBase, _Http2ServerBase):
if len(self.master.state.flows) == 1:
assert self.master.state.flows[0].response is None
+
+
+@requires_alpn
+class TestMaxConcurrentStreams(_Http2TestBase, _Http2ServerBase):
+
+ @classmethod
+ def setup_class(self):
+ _Http2TestBase.setup_class()
+ _Http2ServerBase.setup_class(h2_server_settings={h2.settings.MAX_CONCURRENT_STREAMS: 2})
+
+ @classmethod
+ def teardown_class(self):
+ _Http2TestBase.teardown_class()
+ _Http2ServerBase.teardown_class()
+
+ @classmethod
+ def handle_server_event(self, event, h2_conn, rfile, wfile):
+ if isinstance(event, h2.events.ConnectionTerminated):
+ return False
+ elif isinstance(event, h2.events.RequestReceived):
+ h2_conn.send_headers(event.stream_id, [
+ (':status', '200'),
+ ('X-Stream-ID', str(event.stream_id)),
+ ])
+ h2_conn.send_data(event.stream_id, b'Stream-ID {}'.format(event.stream_id))
+ h2_conn.end_stream(event.stream_id)
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+ return True
+
+ def test_max_concurrent_streams(self):
+ client, h2_conn = self._setup_connection()
+ new_streams = [1, 3, 5, 7, 9, 11]
+ for id in new_streams:
+ # this will exceed MAX_CONCURRENT_STREAMS on the server connection
+ # and cause mitmproxy to throttle stream creation to the server
+ self._send_request(client.wfile, h2_conn, stream_id=id, headers=[
+ (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/'),
+ ('X-Stream-ID', str(id)),
+ ])
+
+ ended_streams = 0
+ while ended_streams != len(new_streams):
+ try:
+ header, body = framereader.http2_read_raw_frame(client.rfile)
+ events = h2_conn.receive_data(b''.join([header, body]))
+ except:
+ break
+ client.wfile.write(h2_conn.data_to_send())
+ client.wfile.flush()
+
+ for event in events:
+ if isinstance(event, h2.events.StreamEnded):
+ ended_streams += 1
+
+ h2_conn.close_connection()
+ client.wfile.write(h2_conn.data_to_send())
+ client.wfile.flush()
+
+ assert len(self.master.state.flows) == len(new_streams)
+ for flow in self.master.state.flows:
+ assert flow.response.status_code == 200
+ assert "Stream-ID" in flow.response.body
+
+
+@requires_alpn
+class TestConnectionTerminated(_Http2TestBase, _Http2ServerBase):
+
+ @classmethod
+ def setup_class(self):
+ _Http2TestBase.setup_class()
+ _Http2ServerBase.setup_class()
+
+ @classmethod
+ def teardown_class(self):
+ _Http2TestBase.teardown_class()
+ _Http2ServerBase.teardown_class()
+
+ @classmethod
+ def handle_server_event(self, event, h2_conn, rfile, wfile):
+ if isinstance(event, h2.events.RequestReceived):
+ h2_conn.close_connection(error_code=5, last_stream_id=42, additional_data='foobar')
+ wfile.write(h2_conn.data_to_send())
+ wfile.flush()
+ return True
+
+ def test_connection_terminated(self):
+ client, h2_conn = self._setup_connection()
+
+ self._send_request(client.wfile, h2_conn, headers=[
+ (':authority', "127.0.0.1:%s" % self.server.server.address.port),
+ (':method', 'GET'),
+ (':scheme', 'https'),
+ (':path', '/'),
+ ])
+
+ done = False
+ connection_terminated_event = None
+ while not done:
+ try:
+ raw = b''.join(framereader.http2_read_raw_frame(client.rfile))
+ events = h2_conn.receive_data(raw)
+ for event in events:
+ if isinstance(event, h2.events.ConnectionTerminated):
+ connection_terminated_event = event
+ done = True
+ except:
+ break
+
+ assert len(self.master.state.flows) == 1
+ assert connection_terminated_event is not None
+ assert connection_terminated_event.error_code == 5
+ assert connection_terminated_event.last_stream_id == 42
+ assert connection_terminated_event.additional_data == 'foobar'
diff --git a/test/netlib/tservers.py b/test/netlib/tservers.py
index 803aaa72..666f97ac 100644
--- a/test/netlib/tservers.py
+++ b/test/netlib/tservers.py
@@ -24,7 +24,7 @@ class _ServerThread(threading.Thread):
class _TServer(tcp.TCPServer):
- def __init__(self, ssl, q, handler_klass, addr):
+ def __init__(self, ssl, q, handler_klass, addr, **kwargs):
"""
ssl: A dictionary of SSL parameters:
@@ -42,6 +42,8 @@ class _TServer(tcp.TCPServer):
self.q = q
self.handler_klass = handler_klass
+ if self.handler_klass is not None:
+ self.handler_klass.kwargs = kwargs
self.last_handler = None
def handle_client_connection(self, request, client_address):
@@ -89,16 +91,16 @@ class ServerTestBase(object):
addr = ("localhost", 0)
@classmethod
- def setup_class(cls):
+ def setup_class(cls, **kwargs):
cls.q = queue.Queue()
- s = cls.makeserver()
+ s = cls.makeserver(**kwargs)
cls.port = s.address.port
cls.server = _ServerThread(s)
cls.server.start()
@classmethod
- def makeserver(cls):
- return _TServer(cls.ssl, cls.q, cls.handler, cls.addr)
+ def makeserver(cls, **kwargs):
+ return _TServer(cls.ssl, cls.q, cls.handler, cls.addr, **kwargs)
@classmethod
def teardown_class(cls):
diff --git a/test/pathod/test_protocols_http2.py b/test/pathod/test_protocols_http2.py
index e42c2858..8d7efc82 100644
--- a/test/pathod/test_protocols_http2.py
+++ b/test/pathod/test_protocols_http2.py
@@ -367,37 +367,6 @@ class TestReadRequestAbsolute(netlib_tservers.ServerTestBase):
assert req.port == 22
-class TestReadRequestConnect(netlib_tservers.ServerTestBase):
- class handler(tcp.BaseHandler):
- def handle(self):
- self.wfile.write(
- codecs.decode('00001b0105000000014287bdab4e9c17b7ff44871c92585422e08541871c92585422e085', 'hex_codec'))
- self.wfile.write(
- codecs.decode('00001d0105000000014287bdab4e9c17b7ff44882f91d35d055c87a741882f91d35d055c87a7', 'hex_codec'))
- self.wfile.flush()
-
- ssl = True
-
- def test_connect(self):
- c = tcp.TCPClient(("127.0.0.1", self.port))
- with c.connect():
- c.convert_to_ssl()
- protocol = HTTP2StateProtocol(c, is_server=True)
- protocol.connection_preface_performed = True
-
- req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "authority"
- assert req.method == "CONNECT"
- assert req.host == "address"
- assert req.port == 22
-
- req = protocol.read_request(NotImplemented)
- assert req.first_line_format == "authority"
- assert req.method == "CONNECT"
- assert req.host == "example.com"
- assert req.port == 443
-
-
class TestReadResponse(netlib_tservers.ServerTestBase):
class handler(tcp.BaseHandler):
def handle(self):