aboutsummaryrefslogtreecommitdiffstats
path: root/test/test_websockets.py
diff options
context:
space:
mode:
Diffstat (limited to 'test/test_websockets.py')
-rw-r--r--test/test_websockets.py31
1 files changed, 18 insertions, 13 deletions
diff --git a/test/test_websockets.py b/test/test_websockets.py
index 9956543b..ae0a5e33 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -12,6 +12,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
super(WebSocketsEchoHandler, self).__init__(
connection, address, server
)
+ self.protocol = websockets.WebsocketsProtocol()
self.handshake_done = False
def handle(self):
@@ -31,10 +32,10 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
def handshake(self):
req = http.read_request(self.rfile)
- key = websockets.check_client_handshake(req.headers)
+ key = self.protocol.check_client_handshake(req.headers)
self.wfile.write(http.response_preamble(101) + "\r\n")
- headers = websockets.server_handshake_headers(key)
+ headers = self.protocol.server_handshake_headers(key)
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
self.handshake_done = True
@@ -48,6 +49,7 @@ class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
+ self.protocol = websockets.WebsocketsProtocol()
self.client_nonce = None
def connect(self):
@@ -55,15 +57,15 @@ class WebSocketsClient(tcp.TCPClient):
preamble = http.request_preamble("GET", "/")
self.wfile.write(preamble + "\r\n")
- headers = websockets.client_handshake_headers()
+ headers = self.protocol.client_handshake_headers()
self.client_nonce = headers.get_first("sec-websocket-key")
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
resp = http.read_response(self.rfile, "get", None)
- server_nonce = websockets.check_server_handshake(resp.headers)
+ server_nonce = self.protocol.check_server_handshake(resp.headers)
- if not server_nonce == websockets.create_server_nonce(
+ if not server_nonce == self.protocol.create_server_nonce(
self.client_nonce):
self.close()
@@ -78,6 +80,9 @@ class WebSocketsClient(tcp.TCPClient):
class TestWebSockets(tservers.ServerTestBase):
handler = WebSocketsEchoHandler
+ def __init__(self):
+ self.protocol = websockets.WebsocketsProtocol()
+
def random_bytes(self, n=100):
return os.urandom(n)
@@ -130,26 +135,26 @@ class TestWebSockets(tservers.ServerTestBase):
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
def test_check_server_handshake(self):
- headers = websockets.server_handshake_headers("key")
- assert websockets.check_server_handshake(headers)
+ headers = self.protocol.server_handshake_headers("key")
+ assert self.protocol.check_server_handshake(headers)
headers["Upgrade"] = ["not_websocket"]
- assert not websockets.check_server_handshake(headers)
+ assert not self.protocol.check_server_handshake(headers)
def test_check_client_handshake(self):
- headers = websockets.client_handshake_headers("key")
- assert websockets.check_client_handshake(headers) == "key"
+ headers = self.protocol.client_handshake_headers("key")
+ assert self.protocol.check_client_handshake(headers) == "key"
headers["Upgrade"] = ["not_websocket"]
- assert not websockets.check_client_handshake(headers)
+ assert not self.protocol.check_client_handshake(headers)
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
client_hs = http.read_request(self.rfile)
- websockets.check_client_handshake(client_hs.headers)
+ self.protocol.check_client_handshake(client_hs.headers)
self.wfile.write(http.response_preamble(101) + "\r\n")
- headers = websockets.server_handshake_headers("malformed key")
+ headers = self.protocol.server_handshake_headers("malformed key")
self.wfile.write(headers.format() + "\r\n")
self.wfile.flush()
self.handshake_done = True