diff options
-rw-r--r-- | libpathod/language/base.py | 2 | ||||
-rw-r--r-- | libpathod/language/http.py | 2 | ||||
-rw-r--r-- | libpathod/language/websockets.py | 49 | ||||
-rw-r--r-- | libpathod/pathoc.py | 3 | ||||
-rw-r--r-- | libpathod/templates/docs_lang_websockets.html | 16 | ||||
-rw-r--r-- | test/test_language_websocket.py | 53 | ||||
-rw-r--r-- | test/tutils.py | 1 |
7 files changed, 96 insertions, 30 deletions
diff --git a/libpathod/language/base.py b/libpathod/language/base.py index 41ad639a..3773fde1 100644 --- a/libpathod/language/base.py +++ b/libpathod/language/base.py @@ -10,6 +10,7 @@ from . import generators, exceptions class Settings: def __init__( self, + is_client = False, staticdir = None, unconstrained_file_access = False, request_host = None, @@ -19,6 +20,7 @@ class Settings: self.unconstrained_file_access = unconstrained_file_access self.request_host = request_host self.websocket_key = websocket_key + self.is_client = is_client diff --git a/libpathod/language/http.py b/libpathod/language/http.py index 94de7237..070cc5f4 100644 --- a/libpathod/language/http.py +++ b/libpathod/language/http.py @@ -5,7 +5,7 @@ import pyparsing as pp import netlib.websockets from netlib import http_status, http_uastrings -from . import base, generators, exceptions, actions, message +from . import base, exceptions, actions, message class WS(base.CaselessLiteral): diff --git a/libpathod/language/websockets.py b/libpathod/language/websockets.py index 3abdf9d8..599cdb88 100644 --- a/libpathod/language/websockets.py +++ b/libpathod/language/websockets.py @@ -1,4 +1,4 @@ - +import os import netlib.websockets import pyparsing as pp from . import base, generators, actions, message @@ -115,6 +115,10 @@ class WebsocketFrame(message.Message): def mask(self): return self.tok(Mask) + @property + def key(self): + return self.tok(Key) + @classmethod def expr(klass): parts = [i.expr() for i in klass.comps] @@ -129,8 +133,21 @@ class WebsocketFrame(message.Message): resp = resp.setParseAction(klass) return resp + def resolve(self, settings, msg=None): + tokens = self.tokens[:] + if not self.mask and settings.is_client: + tokens.append( + Mask(True) + ) + if self.mask and self.mask.value and not self.key: + tokens.append( + Key(base.TokValueLiteral(os.urandom(4))) + ) + return self.__class__( + [i.resolve(settings, self) for i in tokens] + ) + def values(self, settings): - vals = [] if self.body: bodygen = self.body.value.get_generator(settings) length = len(self.body.value.get_generator(settings)) @@ -138,29 +155,31 @@ class WebsocketFrame(message.Message): bodygen = None length = 0 frameparts = dict( - mask = True, payload_length = length ) + if self.mask and self.mask.value: + frameparts["mask"] = True + if self.key: + key = self.key.values(settings)[0][:] + frameparts["masking_key"] = key for i in ["opcode", "fin", "rsv1", "rsv2", "rsv3", "mask"]: v = getattr(self, i, None) if v is not None: frameparts[i] = v.value frame = netlib.websockets.FrameHeader(**frameparts) vals = [frame.to_bytes()] - if self.body: - masker = netlib.websockets.Masker(frame.masking_key) - vals.append( - generators.TransformGenerator( - bodygen, - masker.mask + if bodygen: + if frame.masking_key: + masker = netlib.websockets.Masker(frame.masking_key) + vals.append( + generators.TransformGenerator( + bodygen, + masker.mask + ) ) - ) + else: + vals.append(bodygen) return vals - def resolve(self, settings, msg=None): - return self.__class__( - [i.resolve(settings, self) for i in self.tokens] - ) - def spec(self): return ":".join([i.spec() for i in self.tokens]) diff --git a/libpathod/pathoc.py b/libpathod/pathoc.py index 2574da6c..3d61c9e7 100644 --- a/libpathod/pathoc.py +++ b/libpathod/pathoc.py @@ -212,7 +212,8 @@ class Pathoc(tcp.TCPClient): self.settings = language.Settings( staticdir = os.getcwd(), unconstrained_file_access = True, - request_host = self.address.host + request_host = self.address.host, + is_client = True ) self.ssl, self.sni = ssl, sni self.clientcert = clientcert diff --git a/libpathod/templates/docs_lang_websockets.html b/libpathod/templates/docs_lang_websockets.html index c50d081f..9eb1ec25 100644 --- a/libpathod/templates/docs_lang_websockets.html +++ b/libpathod/templates/docs_lang_websockets.html @@ -45,6 +45,14 @@ </tr> <tr> + <td> k<a href="#valuespec">VALUE</a> </td> + <td> + Set the masking key. The resulting value must be exactly 4 + bytes long. + </td> + </tr> + + <tr> <td> [-]mask </td> <td> Set or un-set the <b>mask</b> bit. @@ -62,8 +70,12 @@ <tr> <td> r </td> <td> - Create a "raw" frame - disables auto-generation of the masking - key if the mask bit is on. + Create a "raw" frame: + <ul> + <li> Don't auto-generate the masking key if the mask flag is + set </li> + + <li> Don't set the mask flag if masking key is set. </li> </td> </tr> diff --git a/test/test_language_websocket.py b/test/test_language_websocket.py index 4b384f61..5e2ccb88 100644 --- a/test/test_language_websocket.py +++ b/test/test_language_websocket.py @@ -54,17 +54,48 @@ class TestWebsocketFrame: assert not frm.header.rsv2 assert not frm.header.rsv3 - def test_construction(self): - wf = parse_request("wf:c1") - frm = netlib.websockets.Frame.from_bytes(tutils.render(wf)) - assert wf.opcode.value == 1 == frm.header.opcode + def fr(self, spec, **kwargs): + settings = language.base.Settings(**kwargs) + wf = parse_request(spec) + return netlib.websockets.Frame.from_bytes(tutils.render(wf, settings)) - wf = parse_request("wf:cbinary") - frm = netlib.websockets.Frame.from_bytes(tutils.render(wf)) - assert wf.opcode.value == frm.header.opcode - assert wf.opcode.value == netlib.websockets.OPCODE.BINARY + def test_construction(self): + assert self.fr("wf:c1").header.opcode == 1 + assert self.fr("wf:c0").header.opcode == 0 + assert self.fr("wf:cbinary").header.opcode ==\ + netlib.websockets.OPCODE.BINARY + assert self.fr("wf:ctext").header.opcode ==\ + netlib.websockets.OPCODE.TEXT def test_auto_raw(self): - wf = parse_request("wf:b'foo':mask") - frm = netlib.websockets.Frame.from_bytes(tutils.render(wf)) - print frm.human_readable() + # Simple server frame + frm = self.fr("wf:b'foo'") + assert not frm.header.mask + assert not frm.header.masking_key + + # Simple client frame + frm = self.fr("wf:b'foo'", is_client=True) + assert frm.header.mask + assert frm.header.masking_key + frm = self.fr("wf:b'foo':k'abcd'", is_client=True) + assert frm.header.mask + assert frm.header.masking_key == 'abcd' + + # Server frame, mask explicitly set + frm = self.fr("wf:b'foo':mask") + assert frm.header.mask + assert frm.header.masking_key + frm = self.fr("wf:b'foo':k'abcd'") + assert frm.header.mask + assert frm.header.masking_key == 'abcd' + + # Client frame, mask explicitly unset + frm = self.fr("wf:b'foo':-mask", is_client=True) + assert not frm.header.mask + assert not frm.header.masking_key + + frm = self.fr("wf:b'foo':-mask:k'abcd'", is_client=True) + assert not frm.header.mask + # We're reading back a corrupted frame - the first 3 characters of the + # mask is mis-interpreted as the payload + assert frm.payload == "abc" diff --git a/test/tutils.py b/test/tutils.py index 2387e752..07252b53 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -141,6 +141,7 @@ test_data = utils.Data(__name__) def render(r, settings=language.Settings()): + r = r.resolve(settings) s = cStringIO.StringIO() assert language.serve(r, s, settings) return s.getvalue() |