aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/http.py8
-rw-r--r--netlib/websockets.py12
-rw-r--r--test/test_http.py6
-rw-r--r--test/test_websockets.py38
4 files changed, 24 insertions, 40 deletions
diff --git a/netlib/http.py b/netlib/http.py
index fe27240a..43155486 100644
--- a/netlib/http.py
+++ b/netlib/http.py
@@ -33,7 +33,7 @@ def _is_valid_host(host):
return True
-def get_line(fp):
+def get_request_line(fp):
"""
Get a line, possibly preceded by a blank.
"""
@@ -41,8 +41,6 @@ def get_line(fp):
if line == "\r\n" or line == "\n":
# Possible leftover from previous message
line = fp.readline()
- if line == "":
- raise tcp.NetLibDisconnect()
return line
@@ -457,7 +455,9 @@ def read_request(rfile, include_body=True, body_size_limit=None, wfile=None):
httpversion, host, port, scheme, method, path, headers, content = (
None, None, None, None, None, None, None, None)
- request_line = get_line(rfile)
+ request_line = get_request_line(rfile)
+ if not request_line:
+ raise tcp.NetLibDisconnect()
request_line_parts = parse_init(request_line)
if not request_line_parts:
diff --git a/netlib/websockets.py b/netlib/websockets.py
index d5c5c2fe..da03768d 100644
--- a/netlib/websockets.py
+++ b/netlib/websockets.py
@@ -350,16 +350,16 @@ def get_payload_length_pair(payload_bytestring):
return (length_code, actual_length)
-def check_client_handshake(req):
- if req.headers.get_first("upgrade", None) != "websocket":
+def check_client_handshake(headers):
+ if headers.get_first("upgrade", None) != "websocket":
return
- return req.headers.get_first('sec-websocket-key')
+ return headers.get_first('sec-websocket-key')
-def check_server_handshake(resp):
- if resp.headers.get_first("upgrade", None) != "websocket":
+def check_server_handshake(headers):
+ if headers.get_first("upgrade", None) != "websocket":
return
- return resp.headers.get_first('sec-websocket-accept')
+ return headers.get_first('sec-websocket-accept')
def create_server_nonce(client_nonce):
diff --git a/test/test_http.py b/test/test_http.py
index 8b99c769..962eb9cb 100644
--- a/test/test_http.py
+++ b/test/test_http.py
@@ -412,10 +412,10 @@ def test_parse_http_basic_auth():
assert not http.parse_http_basic_auth(v)
-def test_get_line():
+def test_get_request_line():
r = cStringIO.StringIO("\nfoo")
- assert http.get_line(r) == "foo"
- tutils.raises(tcp.NetLibDisconnect, http.get_line, r)
+ assert http.get_request_line(r) == "foo"
+ assert not http.get_request_line(r)
class TestReadRequest():
diff --git a/test/test_websockets.py b/test/test_websockets.py
index 9e205e70..6f3b429d 100644
--- a/test/test_websockets.py
+++ b/test/test_websockets.py
@@ -27,7 +27,7 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
def handshake(self):
req = http.read_request(self.rfile)
- key = websockets.check_client_handshake(req)
+ key = websockets.check_client_handshake(req.headers)
self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers(key)
@@ -56,7 +56,7 @@ class WebSocketsClient(tcp.TCPClient):
self.wfile.flush()
resp = http.read_response(self.rfile, "get", None)
- server_nonce = websockets.check_server_handshake(resp)
+ server_nonce = websockets.check_server_handshake(resp.headers)
if not server_nonce == websockets.create_server_nonce(self.client_nonce):
self.close()
@@ -153,38 +153,22 @@ class TestWebSockets(test.ServerTestBase):
assert websockets.Frame.from_bytes(bytes).to_bytes() == bytes
def test_check_server_handshake(self):
- resp = http.Response(
- (1, 1),
- 101,
- "Switching Protocols",
- websockets.server_handshake_headers("key"),
- ""
- )
- assert websockets.check_server_handshake(resp)
- resp.headers["Upgrade"] = ["not_websocket"]
- assert not websockets.check_server_handshake(resp)
+ headers = websockets.server_handshake_headers("key")
+ assert websockets.check_server_handshake(headers)
+ headers["Upgrade"] = ["not_websocket"]
+ assert not websockets.check_server_handshake(headers)
def test_check_client_handshake(self):
- resp = http.Request(
- "relative",
- "get",
- "http",
- "host",
- 22,
- "/",
- (1, 1),
- websockets.client_handshake_headers("key"),
- ""
- )
- assert websockets.check_client_handshake(resp) == "key"
- resp.headers["Upgrade"] = ["not_websocket"]
- assert not websockets.check_client_handshake(resp)
+ headers = websockets.client_handshake_headers("key")
+ assert websockets.check_client_handshake(headers) == "key"
+ headers["Upgrade"] = ["not_websocket"]
+ assert not websockets.check_client_handshake(headers)
class BadHandshakeHandler(WebSocketsEchoHandler):
def handshake(self):
client_hs = http.read_request(self.rfile)
- websockets.check_client_handshake(client_hs)
+ websockets.check_client_handshake(client_hs.headers)
self.wfile.write(http.response_preamble(101) + "\r\n")
headers = websockets.server_handshake_headers("malformed key")