aboutsummaryrefslogtreecommitdiffstats
path: root/netlib
diff options
context:
space:
mode:
Diffstat (limited to 'netlib')
-rw-r--r--netlib/websockets/implementations.py10
-rw-r--r--netlib/websockets/websockets.py35
2 files changed, 22 insertions, 23 deletions
diff --git a/netlib/websockets/implementations.py b/netlib/websockets/implementations.py
index 78ae5be6..ff42ff65 100644
--- a/netlib/websockets/implementations.py
+++ b/netlib/websockets/implementations.py
@@ -26,8 +26,8 @@ class WebSocketsEchoHandler(tcp.BaseHandler):
self.on_message(decoded)
def send_message(self, message):
- frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = False)
- self.wfile.write(frame.to_bytes())
+ frame = ws.WebSocketsFrame.default(message, from_client = False)
+ self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
def handshake(self):
@@ -47,7 +47,7 @@ class WebSocketsClient(tcp.TCPClient):
def __init__(self, address, source_address=None):
super(WebSocketsClient, self).__init__(address, source_address)
self.version = "13"
- self.key = b64encode(os.urandom(16)).decode('utf-8')
+ self.key = ws.generate_client_nounce()
self.resource = "/"
def connect(self):
@@ -76,6 +76,6 @@ class WebSocketsClient(tcp.TCPClient):
self.close()
def send_message(self, message):
- frame = ws.WebSocketsFrame.default_frame_from_message(message, from_client = True)
- self.wfile.write(frame.to_bytes())
+ frame = ws.WebSocketsFrame.default(message, from_client = True)
+ self.wfile.write(frame.safe_to_bytes())
self.wfile.flush()
diff --git a/netlib/websockets/websockets.py b/netlib/websockets/websockets.py
index b796ce39..527d55d6 100644
--- a/netlib/websockets/websockets.py
+++ b/netlib/websockets/websockets.py
@@ -65,7 +65,6 @@ class WebSocketsFrame(object):
payload = None, # bytestring
masking_key = None, # 32 bit byte string
actual_payload_length = None, # any decimal integer
- use_validation = True # indicates whether or not you care if this frame adheres to the spec
):
self.fin = fin
self.rsv1 = rsv1
@@ -78,21 +77,18 @@ class WebSocketsFrame(object):
self.payload = payload
self.decoded_payload = decoded_payload
self.actual_payload_length = actual_payload_length
- self.use_validation = use_validation
-
- if self.use_validation:
- self.validate_frame()
@classmethod
def from_bytes(cls, bytestring):
"""
Construct a websocket frame from an in-memory bytestring
- to construct a frame from a stream of bytes, use read_frame() directly
+ to construct a frame from a stream of bytes, use from_byte_stream() directly
"""
self.from_byte_stream(io.BytesIO(bytestring).read)
+
@classmethod
- def default_frame_from_message(cls, message, from_client = False):
+ def default(cls, message, from_client = False):
"""
Construct a basic websocket frame from some default values.
Creates a non-fragmented text frame.
@@ -119,7 +115,7 @@ class WebSocketsFrame(object):
actual_payload_length = actual_length
)
- def validate_frame(self):
+ def frame_is_valid(self):
"""
Validate websocket frame invariants, call at anytime to ensure the WebSocketsFrame
has not been corrupted.
@@ -141,10 +137,11 @@ class WebSocketsFrame(object):
assert self.actual_payload_length == len(self.payload)
if self.payload is not None and self.masking_key is not None:
- apply_mask(self.payload, self.masking_key) == self.decoded_payload
+ assert apply_mask(self.payload, self.masking_key) == self.decoded_payload
+ return True
except AssertionError:
- raise WebSocketFrameValidationException()
+ return False
def human_readable(self):
return "\n".join([
@@ -161,15 +158,19 @@ class WebSocketsFrame(object):
("actual_payload_length - " + str(self.actual_payload_length)),
("use_validation - " + str(self.use_validation))])
+ def safe_to_bytes(self):
+ try:
+ assert self.frame_is_valid()
+ return self.to_bytes()
+ except:
+ raise WebSocketFrameValidationException()
+
def to_bytes(self):
"""
Serialize the frame back into the wire format, returns a bytestring
+ If you haven't checked is_valid_frame() then there's no guarentees that the
+ serialized bytes will be correct. see safe_to_bytes()
"""
- # validate enforces all the assumptions made by this serializer
- # in the spritit of mitmproxy, it's possible to create and serialize invalid frames
- # by skipping validation.
- if self.use_validation:
- self.validate_frame()
max_16_bit_int = (1 << 16)
max_64_bit_int = (1 << 63)
@@ -198,6 +199,7 @@ class WebSocketsFrame(object):
pass
elif self.actual_payload_length < max_16_bit_int:
+
# '!H' pack as 16 bit unsigned short
bytes += struct.pack('!H', self.actual_payload_length) # add 2 byte extended payload length
@@ -284,9 +286,6 @@ def apply_mask(message, masking_key):
def random_masking_key():
return os.urandom(4)
-def masking_key_list(masking_key):
- return [utils.bytes_to_int(byte) for byte in masking_key]
-
def create_client_handshake(host, port, key, version, resource):
"""
WebSockets connections are intiated by the client with a valid HTTP upgrade request