aboutsummaryrefslogtreecommitdiffstats
path: root/mitmproxy/proxy/protocol/websocket.py
blob: 15d9a288d2f22d17bc26f9dbd6db70f999169b11 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import os
import socket
import struct
from OpenSSL import SSL

from mitmproxy import exceptions
from mitmproxy import flow
from mitmproxy.proxy.protocol import base
from mitmproxy.net import tcp
from mitmproxy.net import websockets
from mitmproxy.websocket import WebSocketFlow, WebSocketBinaryMessage, WebSocketTextMessage


class WebSocketLayer(base.Layer):
    """
        WebSocket layer to intercept, modify, and forward WebSocket messages.

        Only version 13 is supported (as specified in RFC6455).
        Only HTTP/1.1-initiated connections are supported.

        The client starts by sending an Upgrade-request.
        In order to determine the handshake and negotiate the correct protocol
        and extensions, the Upgrade-request is forwarded to the server.
        The response from the server is then parsed and negotiated settings are extracted.
        Finally the handshake is completed by forwarding the server-response to the client.
        After that, only WebSocket frames are exchanged.

        Data frames are buffered until a message is complete. The message is then
        stored in the WebSocketFlow, offered to the "websocket_message" hook, and
        re-framed (first frame carries the message opcode, the rest are
        continuation frames) before being sent to the other endpoint.

        PING/PONG frames pass through and must be answered by the other endpoint.

        CLOSE frames are forwarded before this WebSocketLayer terminates.

        This layer is transparent to any negotiated extensions: the RSV1 bit of
        a message's first frame is preserved when re-framing.
        This layer is transparent to any negotiated subprotocols.

        WebSocket messages are stored in a WebSocketFlow.
    """

    # Maximum payload size (in bytes) of each frame emitted when a complete
    # message is re-framed for forwarding.
    frame_chunk_size = 10240  # 10kB

    def __init__(self, ctx, handshake_flow):
        super().__init__(ctx)
        self.handshake_flow = handshake_flow
        self.flow = None  # type: WebSocketFlow

        # Fragments of the message currently being received from each side.
        self.client_frame_buffer = []
        self.server_frame_buffer = []

    def _handle_frame(self, frame, source_conn, other_conn, is_server):
        """
            Dispatch a frame to the matching handler by opcode.

            Returns False when the layer should terminate (after a CLOSE frame),
            True otherwise.
        """
        opcode = frame.header.opcode
        if opcode & 0x8 == 0:
            # opcodes 0x0-0x7 are data (incl. continuation) frames
            return self._handle_data_frame(frame, source_conn, other_conn, is_server)
        elif opcode in (websockets.OPCODE.PING, websockets.OPCODE.PONG):
            return self._handle_ping_pong(frame, source_conn, other_conn, is_server)
        elif opcode == websockets.OPCODE.CLOSE:
            return self._handle_close(frame, source_conn, other_conn, is_server)
        else:
            return self._handle_unknown_frame(frame, source_conn, other_conn, is_server)

    def _handle_data_frame(self, frame, source_conn, other_conn, is_server):
        """
            Buffer a data frame; once the message is complete (FIN set), record
            it in the flow, run the "websocket_message" hook, and forward the
            (possibly modified) message to the other endpoint.
        """
        fb = self.server_frame_buffer if is_server else self.client_frame_buffer
        fb.append(frame)

        if not frame.header.fin:
            # more fragments of this message are still to come
            return True

        # Only the first fragment carries the message opcode (TEXT/BINARY) and the
        # per-message RSV1 bit; later fragments are continuation frames (opcode 0x0).
        first_header = fb[0].header
        message_opcode = first_header.opcode
        rsv1 = first_header.rsv1
        payload = b''.join(f.payload for f in fb)
        fb.clear()

        if message_opcode == websockets.OPCODE.TEXT:
            t = WebSocketTextMessage
        else:
            t = WebSocketBinaryMessage

        websocket_message = t(self.flow, not is_server, payload)
        self.flow.messages.append(websocket_message)
        self.channel.ask("websocket_message", self.flow)

        # the hook may have modified the content, so re-frame from scratch
        frames = self._build_frames(websocket_message.content, message_opcode, rsv1, is_server)
        for frm in frames:
            other_conn.send(bytes(frm))

        return True

    def _build_frames(self, payload, opcode, rsv1, is_server):
        """
            Split a message payload into frames of at most frame_chunk_size bytes.

            The first frame carries ``opcode`` and ``rsv1``; all following frames
            are continuation frames, and only the last one has FIN set. An empty
            payload yields a single empty FIN frame. Frames that go to the server
            (``is_server`` is False, i.e. the message came from the client) are
            masked, each with a fresh random key; frames to the client are not.
        """
        size = self.frame_chunk_size
        chunks = [payload[i:i + size] for i in range(0, len(payload), size)]
        if not chunks:
            # an empty message must still be sent as one (empty) frame
            chunks = [payload[:0]]

        mask = not is_server
        last = len(chunks) - 1
        return [
            websockets.Frame(
                payload=chunk,
                opcode=(opcode if index == 0 else websockets.OPCODE.CONTINUE),
                fin=(index == last),
                rsv1=(rsv1 if index == 0 else False),
                mask=mask,
                masking_key=(os.urandom(4) if mask else b''))
            for index, chunk in enumerate(chunks)
        ]

    def _handle_ping_pong(self, frame, source_conn, other_conn, is_server):
        """Forward a PING/PONG frame unchanged; the other endpoint answers it."""
        other_conn.send(bytes(frame))
        return True

    def _handle_close(self, frame, source_conn, other_conn, is_server):
        """
            Record close details on the flow, forward the CLOSE frame, and
            return False so that the layer terminates.
        """
        self.flow.close_sender = "server" if is_server else "client"
        if len(frame.payload) >= 2:
            # the first two payload bytes are the big-endian status code
            code, = struct.unpack('!H', frame.payload[:2])
            self.flow.close_code = code
            self.flow.close_message = websockets.CLOSE_REASON.get_name(code, default='unknown status code')
        if len(frame.payload) > 2:
            self.flow.close_reason = frame.payload[2:]

        other_conn.send(bytes(frame))

        # close the connection
        return False

    def _handle_unknown_frame(self, frame, source_conn, other_conn, is_server):
        """Forward a frame with an unrecognized control opcode unchanged and log it."""
        other_conn.send(bytes(frame))

        sender = "server" if is_server else "client"
        self.log("Unknown WebSocket frame received from {}".format(sender), "info", [repr(frame)])

        return True

    def __call__(self):
        """
            Run the layer: create and announce the WebSocketFlow, then relay
            frames in both directions until a CLOSE frame is handled, the channel
            signals should_exit, or the connection fails. "websocket_end" is
            always sent when the loop ends.
        """
        self.flow = WebSocketFlow(self.client_conn, self.server_conn, self.handshake_flow, self)
        self.flow.metadata['websocket_handshake'] = self.handshake_flow
        self.handshake_flow.metadata['websocket_flow'] = self.flow
        self.channel.ask("websocket_start", self.flow)

        client = self.client_conn.connection
        server = self.server_conn.connection
        conns = [client, server]

        try:
            while not self.channel.should_exit.is_set():
                # short select timeout so should_exit is re-checked periodically
                r = tcp.ssl_read_select(conns, 1)
                for conn in r:
                    is_server = (conn == server)
                    if is_server:
                        source_conn, other_conn = self.server_conn, self.client_conn
                    else:
                        source_conn, other_conn = self.client_conn, self.server_conn

                    frame = websockets.Frame.from_file(source_conn.rfile)

                    if not self._handle_frame(frame, source_conn, other_conn, is_server):
                        return
        except (socket.error, exceptions.TcpException, SSL.Error) as e:
            self.flow.error = flow.Error("WebSocket connection closed unexpectedly: {}".format(repr(e)))
            self.channel.tell("websocket_error", self.flow)
        finally:
            self.channel.tell("websocket_end", self.flow)