aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libpathod/language.py46
-rw-r--r--test/test_language.py30
-rw-r--r--test/test_pathod.py3
3 files changed, 52 insertions, 27 deletions
diff --git a/libpathod/language.py b/libpathod/language.py
index 28976a29..e66f987f 100644
--- a/libpathod/language.py
+++ b/libpathod/language.py
@@ -619,7 +619,7 @@ class WS(_Component):
@classmethod
def expr(klass):
- spec = pp.Literal("ws")
+ spec = pp.CaselessLiteral("ws")
spec = spec.setParseAction(lambda x: klass(*x))
return spec
@@ -829,7 +829,6 @@ class InjectAt(_Action):
class _Message(object):
__metaclass__ = abc.ABCMeta
- version = "HTTP/1.1"
logattrs = []
def __init__(self, tokens):
@@ -917,16 +916,6 @@ class _Message(object):
ret["spec"] = self.spec()
return ret
- def values(self, settings):
- vals = self.preamble(settings)
- vals.append("\r\n")
- for h in self.headers:
- vals.extend(h.values(settings))
- vals.append("\r\n")
- if self.body:
- vals.append(self.body.value.get_generator(settings))
- return vals
-
def freeze(self, settings):
r = self.resolve(settings)
return self.__class__([i.freeze(settings) for i in r.tokens])
@@ -938,7 +927,21 @@ class _Message(object):
Sep = pp.Optional(pp.Literal(":")).suppress()
-class Response(_Message):
+class _HTTPMessage(_Message):
+ version = "HTTP/1.1"
+
+ def values(self, settings):
+ vals = self.preamble(settings)
+ vals.append("\r\n")
+ for h in self.headers:
+ vals.extend(h.values(settings))
+ vals.append("\r\n")
+ if self.body:
+ vals.append(self.body.value.get_generator(settings))
+ return vals
+
+
+class Response(_HTTPMessage):
comps = (
Body,
Header,
@@ -966,12 +969,8 @@ class Response(_Message):
def preamble(self, settings):
l = [self.version, " "]
- if self.code:
- l.extend(self.code.values(settings))
- code = int(self.code.code)
- elif self.ws:
- l.extend(Code(101).values(settings))
- code = 101
+ l.extend(self.code.values(settings))
+ code = int(self.code.code)
l.append(" ")
if self.reason:
l.extend(self.reason.values(settings))
@@ -1042,7 +1041,7 @@ class Response(_Message):
return ":".join([i.spec() for i in self.tokens])
-class Request(_Message):
+class Request(_HTTPMessage):
comps = (
Body,
Header,
@@ -1222,7 +1221,12 @@ def parse_requests(s):
try:
parts = pp.OneOrMore(
pp.Group(
- Request.expr()
+ pp.Or(
+ [
+ Request.expr(),
+ WebsocketFrame.expr(),
+ ]
+ )
)
).parseString(s, parseAll=True)
return [Request(i) for i in parts]
diff --git a/test/test_language.py b/test/test_language.py
index 28e26e10..919f5f65 100644
--- a/test/test_language.py
+++ b/test/test_language.py
@@ -14,6 +14,13 @@ def parse_request(s):
return language.parse_requests(s)[0]
+class TestWS:
+ def test_expr(self):
+ v = language.WS("foo")
+ assert v.expr()
+ assert v.values(language.Settings())
+
+
class TestValueNakedLiteral:
def test_expr(self):
v = language.ValueNakedLiteral("foo")
@@ -572,7 +579,6 @@ class TestRequest:
language.Settings(request_host = "foo.com")
)
-
def test_multiline(self):
l = """
GET
@@ -632,10 +638,18 @@ class TestRequest:
+class TestWebsocketFrame:
+
+ def test_spec(self):
+ e = language.WebsocketFrame.expr()
+ assert e.parseString("wf:foo")
+
+
class TestWriteValues:
+
def test_send_chunk(self):
v = "foobarfoobar"
- for bs in range(1, len(v)+2):
+ for bs in range(1, len(v) + 2):
s = cStringIO.StringIO()
language.send_chunk(s, v, bs, 0, len(v))
assert s.getvalue() == v
@@ -662,7 +676,7 @@ class TestWriteValues:
def test_write_values_disconnects(self):
s = cStringIO.StringIO()
- tst = "foo"*100
+ tst = "foo" * 100
language.write_values(s, [tst], [(0, "disconnect")], blocksize=5)
assert not s.getvalue()
@@ -675,14 +689,18 @@ class TestWriteValues:
for bs in range(1, len(tst) + 2):
for off in range(len(tst)):
s = cStringIO.StringIO()
- language.write_values(s, [tst], [(off, "disconnect")], blocksize=bs)
+ language.write_values(
+ s, [tst], [(off, "disconnect")], blocksize=bs
+ )
assert s.getvalue() == tst[:off]
def test_write_values_pauses(self):
tst = "".join(str(i) for i in range(10))
for i in range(2, 10):
s = cStringIO.StringIO()
- language.write_values(s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i)
+ language.write_values(
+ s, [tst], [(2, "pause", 0), (1, "pause", 0)], blocksize=i
+ )
assert s.getvalue() == tst
for i in range(2, 10):
@@ -690,7 +708,7 @@ class TestWriteValues:
language.write_values(s, [tst], [(1, "pause", 0)], blocksize=i)
assert s.getvalue() == tst
- tst = ["".join(str(i) for i in range(10))]*5
+ tst = ["".join(str(i) for i in range(10))] * 5
for i in range(2, 10):
s = cStringIO.StringIO()
language.write_values(s, tst[:], [(1, "pause", 0)], blocksize=i)
diff --git a/test/test_pathod.py b/test/test_pathod.py
index 18b546e4..1a10d2c2 100644
--- a/test/test_pathod.py
+++ b/test/test_pathod.py
@@ -188,6 +188,9 @@ class CommonTests(tutils.DaemonTests):
r = self.pathoc("ws:/p/")
assert r.status_code == 101
+ r = self.pathoc("ws:/p/ws")
+ assert r.status_code == 101
+
class TestDaemon(CommonTests):
ssl = False