aboutsummaryrefslogtreecommitdiffstats
path: root/netlib/websockets/utils.py
blob: 980436622cf5a88bca5e30980ec1c45282295c47 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
"""
Collection of WebSockets Protocol utility functions (RFC6455)
Spec: https://tools.ietf.org/html/rfc6455
"""


import base64
import hashlib
import os

from netlib import http
from mitmproxy.utils import strutils

# Fixed GUID appended to the client's key before hashing in
# create_server_nonce (the constant given by RFC 6455 for computing
# Sec-WebSocket-Accept).
MAGIC = b'258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
# Protocol version sent by default in client_handshake_headers and required
# by check_client_version. Kept as a string because it is compared against
# header values.
VERSION = "13"


def client_handshake_headers(version=None, key=None, protocol=None, extensions=None):
    """
        Build the header set for an HTTP upgrade request to WebSockets.

        version defaults to the module-level VERSION. When key is omitted, a
        fresh one is produced from 16 random bytes (base64-encoded) and can be
        read back from sec-websocket-key in the result. protocol and
        extensions are only added to the headers when they are not None.

        Returns an instance of http.Headers
    """
    version = VERSION if version is None else version
    if key is None:
        # 16 random bytes, base64-encoded and converted to text.
        random_bytes = os.urandom(16)
        key = base64.b64encode(random_bytes).decode('ascii')

    headers = http.Headers(
        connection="upgrade",
        upgrade="websocket",
        sec_websocket_version=version,
        sec_websocket_key=key,
    )

    # Optional negotiation headers, set in the same order as before.
    optional_fields = (
        ('sec-websocket-protocol', protocol),
        ('sec-websocket-extensions', extensions),
    )
    for name, value in optional_fields:
        if value is not None:
            headers[name] = value
    return headers


def server_handshake_headers(client_key, protocol=None, extensions=None):
    """
      Build the header set for the server's reply to an upgrade request
      (intended for an HTTP 101 response).

      sec-websocket-accept is derived from client_key via
      create_server_nonce. protocol and extensions are only added to the
      headers when they are not None.

      Returns an instance of http.Headers
    """
    headers = http.Headers(
        connection="upgrade",
        upgrade="websocket",
        sec_websocket_accept=create_server_nonce(client_key),
    )

    # Optional negotiation headers, set in the same order as before.
    optional_fields = (
        ('sec-websocket-protocol', protocol),
        ('sec-websocket-extensions', extensions),
    )
    for name, value in optional_fields:
        if value is not None:
            headers[name] = value
    return headers


def check_handshake(headers):
    """
        Return True if headers look like part of a WebSocket handshake.

        All of the following must hold:
          - the connection header contains "upgrade" (case-insensitive),
          - the upgrade header equals "websocket" (case-insensitive),
          - sec-websocket-key or sec-websocket-accept is present.
    """
    connection = headers.get("connection", "")
    if "upgrade" not in connection.lower():
        return False

    upgrade = headers.get("upgrade", "")
    if upgrade.lower() != "websocket":
        return False

    # A request carries the key, a response carries the accept value;
    # either one is enough.
    if headers.get("sec-websocket-key") is not None:
        return True
    return headers.get("sec-websocket-accept") is not None


def create_server_nonce(client_nonce):
    """
        Compute the accept value a server sends back for client_nonce.

        The nonce is passed through strutils.always_bytes, concatenated with
        MAGIC, hashed with SHA-1, and the raw digest is base64-encoded.

        Returns the base64-encoded digest as bytes.
    """
    nonce_bytes = strutils.always_bytes(client_nonce)
    digest = hashlib.sha1(nonce_bytes + MAGIC).digest()
    return base64.b64encode(digest)


def check_client_version(headers):
    """
        Return True if sec-websocket-version in headers equals VERSION.

        A missing header is treated as an empty string before comparing.
    """
    requested_version = headers.get("sec-websocket-version", "")
    return requested_version == VERSION


def get_extensions(headers):
    """
        Return the sec-websocket-extensions header value, or None if the
        header is not present.
    """
    extensions = headers.get("sec-websocket-extensions", None)
    return extensions


def get_protocol(headers):
    """
        Return the sec-websocket-protocol header value, or None if the
        header is not present.
    """
    protocol = headers.get("sec-websocket-protocol", None)
    return protocol


def get_client_key(headers):
    """
        Return the sec-websocket-key header value, or None if the header is
        not present.
    """
    client_key = headers.get("sec-websocket-key", None)
    return client_key


def get_server_accept(headers):
    """
        Return the sec-websocket-accept header value, or None if the header
        is not present.
    """
    server_accept = headers.get("sec-websocket-accept", None)
    return server_accept