diff options
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r-- | test/test_websockets.py | 31 |
1 files changed, 18 insertions, 13 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py index 9956543b..ae0a5e33 100644 --- a/test/test_websockets.py +++ b/test/test_websockets.py @@ -12,6 +12,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler): super(WebSocketsEchoHandler, self).__init__( connection, address, server ) + self.protocol = websockets.WebsocketsProtocol() self.handshake_done = False def handle(self): @@ -31,10 +32,10 @@ class WebSocketsEchoHandler(tcp.BaseHandler): def handshake(self): req = http.read_request(self.rfile) - key = websockets.check_client_handshake(req.headers) + key = self.protocol.check_client_handshake(req.headers) self.wfile.write(http.response_preamble(101) + "\r\n") - headers = websockets.server_handshake_headers(key) + headers = self.protocol.server_handshake_headers(key) self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True @@ -48,6 +49,7 @@ class WebSocketsClient(tcp.TCPClient): def __init__(self, address, source_address=None): super(WebSocketsClient, self).__init__(address, source_address) + self.protocol = websockets.WebsocketsProtocol() self.client_nonce = None def connect(self): @@ -55,15 +57,15 @@ class WebSocketsClient(tcp.TCPClient): preamble = http.request_preamble("GET", "/") self.wfile.write(preamble + "\r\n") - headers = websockets.client_handshake_headers() + headers = self.protocol.client_handshake_headers() self.client_nonce = headers.get_first("sec-websocket-key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() resp = http.read_response(self.rfile, "get", None) - server_nonce = websockets.check_server_handshake(resp.headers) + server_nonce = self.protocol.check_server_handshake(resp.headers) - if not server_nonce == websockets.create_server_nonce( + if not server_nonce == self.protocol.create_server_nonce( self.client_nonce): self.close() @@ -78,6 +80,9 @@ class WebSocketsClient(tcp.TCPClient): class TestWebSockets(tservers.ServerTestBase): handler = WebSocketsEchoHandler + def __init__(self): + self.protocol = websockets.WebsocketsProtocol() + def random_bytes(self, n=100): return os.urandom(n) @@ -130,26 +135,26 @@ class TestWebSockets(tservers.ServerTestBase): assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes def test_check_server_handshake(self): - headers = websockets.server_handshake_headers("key") - assert websockets.check_server_handshake(headers) + headers = self.protocol.server_handshake_headers("key") + assert self.protocol.check_server_handshake(headers) headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_server_handshake(headers) + assert not self.protocol.check_server_handshake(headers) def test_check_client_handshake(self): - headers = websockets.client_handshake_headers("key") - assert websockets.check_client_handshake(headers) == "key" + headers = self.protocol.client_handshake_headers("key") + assert self.protocol.check_client_handshake(headers) == "key" headers["Upgrade"] = ["not_websocket"] - assert not websockets.check_client_handshake(headers) + assert not self.protocol.check_client_handshake(headers) class BadHandshakeHandler(WebSocketsEchoHandler): def handshake(self): client_hs = http.read_request(self.rfile) - websockets.check_client_handshake(client_hs.headers) + self.protocol.check_client_handshake(client_hs.headers) self.wfile.write(http.response_preamble(101) + "\r\n") - headers = websockets.server_handshake_headers("malformed key") + headers = self.protocol.server_handshake_headers("malformed key") self.wfile.write(headers.format() + "\r\n") self.wfile.flush() self.handshake_done = True |