aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMaximilian Hils <git@maximilianhils.com>2015-08-29 20:53:25 +0200
committerMaximilian Hils <git@maximilianhils.com>2015-08-29 20:53:25 +0200
commita7058e2a3c59cc2b13aaea3d7c767a3ca4a4bc40 (patch)
treecfdfef8dd58f014adc77a81838bbef1c87011621
parent63844df34367bf7147c2d43a9e4061515f6430c9 (diff)
downloadmitmproxy-a7058e2a3c59cc2b13aaea3d7c767a3ca4a4bc40.tar.gz
mitmproxy-a7058e2a3c59cc2b13aaea3d7c767a3ca4a4bc40.tar.bz2
mitmproxy-a7058e2a3c59cc2b13aaea3d7c767a3ca4a4bc40.zip
fix bugs, fix tests
-rw-r--r--libmproxy/console/statusbar.py11
-rw-r--r--libmproxy/protocol2/http.py54
-rw-r--r--test/test_proxy.py9
-rw-r--r--test/test_server.py16
-rw-r--r--test/tservers.py64
5 files changed, 69 insertions, 85 deletions
diff --git a/libmproxy/console/statusbar.py b/libmproxy/console/statusbar.py
index 7eb2131b..ea2dbfa8 100644
--- a/libmproxy/console/statusbar.py
+++ b/libmproxy/console/statusbar.py
@@ -199,11 +199,12 @@ class StatusBar(urwid.WidgetWrap):
r.append("[%s]" % (":".join(opts)))
if self.master.server.config.mode in ["reverse", "upstream"]:
- dst = self.master.server.config.mode.dst
- scheme = "https" if dst[0] else "http"
- if dst[1] != dst[0]:
- scheme += "2https" if dst[1] else "http"
- r.append("[dest:%s]" % utils.unparse_url(scheme, *dst[2:]))
+ dst = self.master.server.config.upstream_server
+ r.append("[dest:%s]" % netlib.utils.unparse_url(
+ dst.scheme,
+ dst.address.host,
+ dst.address.port
+ ))
if self.master.scripts:
r.append("[")
r.append(("heading_key", "s"))
diff --git a/libmproxy/protocol2/http.py b/libmproxy/protocol2/http.py
index 0fde9fb1..a3f32926 100644
--- a/libmproxy/protocol2/http.py
+++ b/libmproxy/protocol2/http.py
@@ -40,6 +40,7 @@ class _HttpLayer(Layer):
def send_response(self, response):
raise NotImplementedError()
+
class _StreamingHttpLayer(_HttpLayer):
supports_streaming = True
@@ -58,7 +59,6 @@ class _StreamingHttpLayer(_HttpLayer):
class Http1Layer(_StreamingHttpLayer):
-
def __init__(self, ctx, mode):
super(Http1Layer, self).__init__(ctx)
self.mode = mode
@@ -105,12 +105,12 @@ class Http1Layer(_StreamingHttpLayer):
def send_response_headers(self, response):
h = self.client_protocol._assemble_response_first_line(response)
- self.client_conn.wfile.write(h+"\r\n")
+ self.client_conn.wfile.write(h + "\r\n")
h = self.client_protocol._assemble_response_headers(
response,
preserve_transfer_encoding=True
)
- self.client_conn.send(h+"\r\n")
+ self.client_conn.send(h + "\r\n")
def send_response_body(self, response, chunks):
if self.client_protocol.has_chunked_encoding(response.headers):
@@ -142,8 +142,10 @@ class Http2Layer(_HttpLayer):
def __init__(self, ctx, mode):
super(Http2Layer, self).__init__(ctx)
self.mode = mode
- self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame)
- self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
+ self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True,
+ unhandled_frame_cb=self.handle_unexpected_frame)
+ self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
+ unhandled_frame_cb=self.handle_unexpected_frame)
def read_request(self):
request = HTTPRequest.from_protocol(
@@ -172,17 +174,20 @@ class Http2Layer(_HttpLayer):
def connect(self):
self.ctx.connect()
- self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
+ self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
+ unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface()
def reconnect(self):
self.ctx.reconnect()
- self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
+ self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
+ unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface()
def set_server(self, *args, **kwargs):
self.ctx.set_server(*args, **kwargs)
- self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame)
+ self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False,
+ unhandled_frame_cb=self.handle_unexpected_frame)
self.server_protocol.perform_connection_preface()
def __call__(self):
@@ -264,7 +269,10 @@ class UpstreamConnectLayer(Layer):
def __init__(self, ctx, connect_request):
super(UpstreamConnectLayer, self).__init__(ctx)
self.connect_request = connect_request
- self.server_conn = ConnectServerConnection((connect_request.host, connect_request.port), self.ctx)
+ self.server_conn = ConnectServerConnection(
+ (connect_request.host, connect_request.port),
+ self.ctx
+ )
def __call__(self):
layer = self.ctx.next_layer(self)
@@ -280,6 +288,9 @@ class UpstreamConnectLayer(Layer):
def reconnect(self):
self.ctx.reconnect()
self.send_request(self.connect_request)
+ resp = self.read_response("CONNECT")
+ if resp.code != 200:
+ raise ProtocolException("Reconnect: Upstream server refuses CONNECT request")
def set_server(self, address, server_tls=None, sni=None, depth=1):
if depth == 1:
@@ -290,7 +301,7 @@ class UpstreamConnectLayer(Layer):
self.connect_request.port = address.port
self.server_conn.address = address
else:
- self.ctx.set_server(address, server_tls, sni, depth-1)
+ self.ctx.set_server(address, server_tls, sni, depth - 1)
class HttpLayer(Layer):
@@ -413,10 +424,10 @@ class HttpLayer(Layer):
# First send the headers and then transfer the response incrementally
self.send_response_headers(flow.response)
chunks = self.read_response_body(
- flow.response.headers,
- flow.request.method,
- flow.response.code,
- max_chunk_size=4096
+ flow.response.headers,
+ flow.request.method,
+ flow.response.code,
+ max_chunk_size=4096
)
if callable(flow.response.stream):
chunks = flow.response.stream(chunks)
@@ -521,7 +532,8 @@ class HttpLayer(Layer):
# If there's not TlsLayer below which could catch the exception,
# TLS will not be established.
if tls and not self.server_conn.tls_established:
- raise ProtocolException("Cannot upgrade to SSL, no TLS layer on the protocol stack.")
+ raise ProtocolException(
+ "Cannot upgrade to SSL, no TLS layer on the protocol stack.")
else:
if not self.server_conn:
self.connect()
@@ -542,7 +554,8 @@ class HttpLayer(Layer):
def validate_request(self, request):
if request.form_in == "absolute" and request.scheme != "http":
- self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme))
+ self.send_response(
+ make_error_response(400, "Invalid request scheme: %s" % request.scheme))
raise HttpException("Invalid request scheme: %s" % request.scheme)
expected_request_forms = {
@@ -570,7 +583,11 @@ class HttpLayer(Layer):
self.send_response(make_error_response(
407,
"Proxy Authentication Required",
- odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()])
+ odict.ODictCaseless(
+ [
+ [k, v] for k, v in
+ self.config.authenticator.auth_challenge_headers().items()
+ ])
))
raise InvalidCredentials("Proxy Authentication Required")
@@ -614,6 +631,9 @@ class RequestReplayThread(threading.Thread):
if r.scheme == "https":
connect_request = make_connect_request((r.host, r.port))
server.send(protocol.assemble(connect_request))
+ resp = protocol.read_response("CONNECT")
+ if resp.code != 200:
+ raise HttpError(502, "Upstream server refuses CONNECT request")
server.establish_ssl(
self.config.clientcerts,
sni=self.flow.server_conn.sni
diff --git a/test/test_proxy.py b/test/test_proxy.py
index 9c01ab63..fac4a4f4 100644
--- a/test/test_proxy.py
+++ b/test/test_proxy.py
@@ -1,9 +1,8 @@
-import argparse
from libmproxy import cmdline
from libmproxy.proxy import ProxyConfig, process_proxy_options
from libmproxy.proxy.connection import ServerConnection
from libmproxy.proxy.primitives import ProxyError
-from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler
+from libmproxy.proxy.server import DummyServer, ProxyServer, ConnectionHandler2
import tutils
from libpathod import test
from netlib import http, tcp
@@ -175,8 +174,10 @@ class TestDummyServer:
class TestConnectionHandler:
def test_fatal_error(self):
config = mock.Mock()
- config.mode.get_upstream_server.side_effect = RuntimeError
- c = ConnectionHandler(
+ root_layer = mock.Mock()
+ root_layer.side_effect = RuntimeError
+ config.mode.return_value = root_layer
+ c = ConnectionHandler2(
config,
mock.MagicMock(),
("127.0.0.1",
diff --git a/test/test_server.py b/test/test_server.py
index 1216a349..7b66c582 100644
--- a/test/test_server.py
+++ b/test/test_server.py
@@ -68,7 +68,7 @@ class CommonMixin:
# SSL with the upstream proxy.
rt = self.master.replay_request(l, block=True)
assert not rt
- if isinstance(self, tservers.HTTPUpstreamProxTest) and not self.ssl:
+ if isinstance(self, tservers.HTTPUpstreamProxTest):
assert l.response.code == 502
else:
assert l.error
@@ -506,7 +506,7 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin, TcpMixin):
p = pathoc.Pathoc(("localhost", self.proxy.port), fp=None)
p.connect()
r = p.request("get:/")
- assert r.status_code == 400
+ assert r.status_code == 502
class TestProxy(tservers.HTTPProxTest):
@@ -724,9 +724,9 @@ class TestStreamRequest(tservers.HTTPProxTest):
assert resp.headers["Transfer-Encoding"][0] == 'chunked'
assert resp.status_code == 200
- chunks = list(
- content for _, content, _ in protocol.read_http_body_chunked(
- resp.headers, None, "GET", 200, False))
+ chunks = list(protocol.read_http_body_chunked(
+ resp.headers, None, "GET", 200, False
+ ))
assert chunks == ["this", "isatest", ""]
connection.close()
@@ -959,6 +959,9 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
p = self.pathoc()
req = p.request("get:'/p/418:b\"content\"'")
+ assert req.content == "content"
+ assert req.status_code == 418
+
assert self.proxy.tmaster.state.flow_count() == 2 # CONNECT and request
# CONNECT, failing request,
assert self.chain[0].tmaster.state.flow_count() == 4
@@ -967,8 +970,7 @@ class TestProxyChainingSSLReconnect(tservers.HTTPUpstreamProxTest):
assert self.chain[1].tmaster.state.flow_count() == 2
# (doesn't store (repeated) CONNECTs from chain[0]
# as it is a regular proxy)
- assert req.content == "content"
- assert req.status_code == 418
+
assert not self.chain[1].tmaster.state.flows[0].response # killed
assert self.chain[1].tmaster.state.flows[1].response
diff --git a/test/tservers.py b/test/tservers.py
index 43ebf2bb..dfd3f627 100644
--- a/test/tservers.py
+++ b/test/tservers.py
@@ -181,22 +181,24 @@ class TResolver:
def original_addr(self, sock):
return ("127.0.0.1", self.port)
-
class TransparentProxTest(ProxTestBase):
ssl = None
resolver = TResolver
@classmethod
- @mock.patch("libmproxy.platform.resolver")
- def setupAll(cls, _):
+ def setupAll(cls):
super(TransparentProxTest, cls).setupAll()
- if cls.ssl:
- ports = [cls.server.port, cls.server2.port]
- else:
- ports = []
- cls.config.mode = TransparentProxyMode(
- cls.resolver(cls.server.port),
- ports)
+
+ cls._resolver = mock.patch(
+ "libmproxy.platform.resolver",
+ new=lambda: cls.resolver(cls.server.port)
+ )
+ cls._resolver.start()
+
+ @classmethod
+ def teardownAll(cls):
+ cls._resolver.stop()
+ super(TransparentProxTest, cls).teardownAll()
@classmethod
def get_proxy_config(cls):
@@ -270,48 +272,6 @@ class SocksModeTest(HTTPProxTest):
d["mode"] = "socks5"
return d
-class SpoofModeTest(ProxTestBase):
- ssl = None
-
- @classmethod
- def get_proxy_config(cls):
- d = ProxTestBase.get_proxy_config()
- d["upstream_server"] = None
- d["mode"] = "spoof"
- return d
-
- def pathoc(self, sni=None):
- """
- Returns a connected Pathoc instance.
- """
- p = libpathod.pathoc.Pathoc(
- ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
- )
- p.connect()
- return p
-
-
-class SSLSpoofModeTest(ProxTestBase):
- ssl = True
-
- @classmethod
- def get_proxy_config(cls):
- d = ProxTestBase.get_proxy_config()
- d["upstream_server"] = None
- d["mode"] = "sslspoof"
- d["spoofed_ssl_port"] = 443
- return d
-
- def pathoc(self, sni=None):
- """
- Returns a connected Pathoc instance.
- """
- p = libpathod.pathoc.Pathoc(
- ("localhost", self.proxy.port), ssl=self.ssl, sni=sni, fp=None
- )
- p.connect()
- return p
-
class ChainProxTest(ProxTestBase):
"""