aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--mitmproxy/controller.py2
-rw-r--r--mitmproxy/flow/master.py4
-rw-r--r--mitmproxy/protocol/__init__.py4
-rw-r--r--mitmproxy/protocol/http.py21
-rw-r--r--mitmproxy/protocol/websockets.py48
-rw-r--r--mitmproxy/proxy/root_context.py19
6 files changed, 39 insertions, 59 deletions
diff --git a/mitmproxy/controller.py b/mitmproxy/controller.py
index c262b192..d886af97 100644
--- a/mitmproxy/controller.py
+++ b/mitmproxy/controller.py
@@ -28,6 +28,8 @@ Events = frozenset([
"response",
"responseheaders",
+ "websockets_handshake",
+
"next_layer",
"error",
diff --git a/mitmproxy/flow/master.py b/mitmproxy/flow/master.py
index 0475ef4e..9cdcc8dd 100644
--- a/mitmproxy/flow/master.py
+++ b/mitmproxy/flow/master.py
@@ -334,6 +334,10 @@ class FlowMaster(controller.Master):
self.client_playback.clear(f)
return f
+ @controller.handler
+ def websockets_handshake(self, f):
+ return f
+
def handle_intercept(self, f):
self.state.update_flow(f)
diff --git a/mitmproxy/protocol/__init__.py b/mitmproxy/protocol/__init__.py
index 510cd195..b99b55bd 100644
--- a/mitmproxy/protocol/__init__.py
+++ b/mitmproxy/protocol/__init__.py
@@ -29,8 +29,10 @@ from __future__ import absolute_import, print_function, division
from .base import Layer, ServerConnectionMixin
from .http import UpstreamConnectLayer
+from .http import HttpLayer
from .http1 import Http1Layer
from .http2 import Http2Layer
+from .websockets import WebSocketsLayer
from .rawtcp import RawTCPLayer
from .tls import TlsClientHello
from .tls import TlsLayer
@@ -40,7 +42,9 @@ __all__ = [
"Layer", "ServerConnectionMixin",
"TlsLayer", "is_tls_record_magic", "TlsClientHello",
"UpstreamConnectLayer",
+ "HttpLayer",
"Http1Layer",
"Http2Layer",
+ "WebSocketsLayer",
"RawTCPLayer",
]
diff --git a/mitmproxy/protocol/http.py b/mitmproxy/protocol/http.py
index fbb52c92..1418d6e9 100644
--- a/mitmproxy/protocol/http.py
+++ b/mitmproxy/protocol/http.py
@@ -10,7 +10,6 @@ import six
from mitmproxy import exceptions
from mitmproxy import models
from mitmproxy.protocol import base
-from .websockets import WebSocketsLayer
import netlib.exceptions
from netlib import http
@@ -192,20 +191,10 @@ class HttpLayer(base.Layer):
self.process_request_hook(flow)
try:
- # WebSockets
- if websockets.check_handshake(request.headers):
- if websockets.check_client_version(request.headers):
- layer = WebSocketsLayer(self, request)
- layer()
- return
- else:
- # we only support RFC6455 with WebSockets version 13
- self.send_response(models.make_error_response(
- 400,
- http.status_codes.RESPONSES.get(400),
- http.Headers(sec_websocket_version="13")
- ))
- return
+ if websockets.check_handshake(request.headers) and websockets.check_client_version(request.headers):
+ # we only support RFC6455 with WebSockets version 13
+ # allow inline scripts to manupulate the client handshake
+ self.channel.ask("websockets_handshake", flow)
if not flow.response:
self.establish_server_connection(
@@ -230,7 +219,7 @@ class HttpLayer(base.Layer):
# It may be useful to pass additional args (such as the upgrade header)
# to next_layer in the future
if flow.response.status_code == 101:
- layer = self.ctx.next_layer(self)
+ layer = self.ctx.next_layer(self, flow)
layer()
return
diff --git a/mitmproxy/protocol/websockets.py b/mitmproxy/protocol/websockets.py
index 05eaa537..f15a38ef 100644
--- a/mitmproxy/protocol/websockets.py
+++ b/mitmproxy/protocol/websockets.py
@@ -6,12 +6,10 @@ import struct
from OpenSSL import SSL
from mitmproxy import exceptions
-from mitmproxy import models
from mitmproxy.protocol import base
import netlib.exceptions
from netlib import tcp
-from netlib import http
from netlib import websockets
@@ -38,44 +36,17 @@ class WebSocketsLayer(base.Layer):
Only raw frames are forwarded to the other endpoint.
"""
- def __init__(self, ctx, request):
+ def __init__(self, ctx, flow):
super(WebSocketsLayer, self).__init__(ctx)
- self._request = request
+ self._flow = flow
- self.client_key = websockets.get_client_key(self._request.headers)
- self.client_protocol = websockets.get_protocol(self._request.headers)
- self.client_extensions = websockets.get_extensions(self._request.headers)
+ self.client_key = websockets.get_client_key(self._flow.request.headers)
+ self.client_protocol = websockets.get_protocol(self._flow.request.headers)
+ self.client_extensions = websockets.get_extensions(self._flow.request.headers)
- self.server_accept = None
- self.server_protocol = None
- self.server_extensions = None
-
- def _initiate_server_conn(self):
- self.establish_server_connection(
- self._request.host,
- self._request.port,
- self._request.scheme,
- )
-
- self.server_conn.send(netlib.http.http1.assemble_request(self._request))
- response = netlib.http.http1.read_response(self.server_conn.rfile, self._request, body_size_limit=None)
-
- if not websockets.check_handshake(response.headers):
- raise exceptions.ProtocolException("Establishing WebSockets connection with server failed: {}".format(response.headers))
-
- self.server_accept = websockets.get_server_accept(response.headers)
- self.server_protocol = websockets.get_protocol(response.headers)
- self.server_extensions = websockets.get_extensions(response.headers)
-
- def _complete_handshake(self):
- headers = websockets.server_handshake_headers(self.client_key, self.server_protocol, self.server_extensions)
- self.send_response(models.HTTPResponse(
- self._request.http_version,
- 101,
- http.status_codes.RESPONSES.get(101),
- headers,
- b"",
- ))
+ self.server_accept = websockets.get_server_accept(self._flow.response.headers)
+ self.server_protocol = websockets.get_protocol(self._flow.response.headers)
+ self.server_extensions = websockets.get_extensions(self._flow.response.headers)
def _handle_frame(self, frame, source_conn, other_conn, is_server):
self.log(
@@ -114,9 +85,6 @@ class WebSocketsLayer(base.Layer):
return True
def __call__(self):
- self._initiate_server_conn()
- self._complete_handshake()
-
client = self.client_conn.connection
server = self.server_conn.connection
conns = [client, server]
diff --git a/mitmproxy/proxy/root_context.py b/mitmproxy/proxy/root_context.py
index 81dd625c..95611362 100644
--- a/mitmproxy/proxy/root_context.py
+++ b/mitmproxy/proxy/root_context.py
@@ -4,6 +4,7 @@ import sys
import six
+from netlib import websockets
import netlib.exceptions
from mitmproxy import exceptions
from mitmproxy import protocol
@@ -32,7 +33,7 @@ class RootContext(object):
self.channel = channel
self.config = config
- def next_layer(self, top_layer):
+ def next_layer(self, top_layer, flow=None):
"""
This function determines the next layer in the protocol stack.
@@ -42,10 +43,22 @@ class RootContext(object):
Returns:
The next layer
"""
- layer = self._next_layer(top_layer)
+ layer = self._next_layer(top_layer, flow)
return self.channel.ask("next_layer", layer)
- def _next_layer(self, top_layer):
+ def _next_layer(self, top_layer, flow):
+ if flow is not None:
+ # We already have a flow, try to derive the next information from it
+
+ # Check for WebSockets handshake
+ is_websockets = (
+ flow and
+ websockets.check_handshake(flow.request.headers) and
+ websockets.check_handshake(flow.response.headers)
+ )
+ if isinstance(top_layer, protocol.HttpLayer) and is_websockets:
+ return protocol.WebSocketsLayer(top_layer, flow)
+
try:
d = top_layer.client_conn.rfile.peek(3)
except netlib.exceptions.TcpException as e: