diff options
| -rw-r--r-- | libpathod/app.py | 15 | ||||
| -rw-r--r-- | libpathod/language.py | 154 | ||||
| -rw-r--r-- | libpathod/pathoc.py | 18 | ||||
| -rw-r--r-- | libpathod/pathod.py | 20 | ||||
| -rw-r--r-- | libpathod/templates/docs_lang.html | 38 | ||||
| -rw-r--r-- | libpathod/templates/request_previewform.html | 6 | ||||
| -rw-r--r-- | libpathod/templates/response_previewform.html | 8 | ||||
| -rw-r--r-- | test/test_language.py | 60 | ||||
| -rw-r--r-- | test/test_pathoc.py | 2 | ||||
| -rw-r--r-- | test/test_pathod.py | 2 | 
10 files changed, 243 insertions, 80 deletions
| diff --git a/libpathod/app.py b/libpathod/app.py index 858c1d86..c3ce9991 100644 --- a/libpathod/app.py +++ b/libpathod/app.py @@ -1,6 +1,7 @@  import logging  import pprint  import cStringIO +import copy  from flask import Flask, jsonify, render_template, request, abort, make_response  import version, language, utils  from netlib import http_uastrings @@ -10,6 +11,7 @@ logging.basicConfig(level="DEBUG")  def make_app(noapi):      app = Flask(__name__) +    # app.debug = True      if not noapi:          @app.route('/api/info') @@ -144,20 +146,17 @@ def make_app(noapi):          c = app.config["pathod"].check_policy(              safe, -            app.config["pathod"].request_settings +            app.config["pathod"].settings          )          if c:              args["error"] = c              return render(template, False, **args)          if is_request: -            language.serve( -                safe, -                s, -                app.config["pathod"].request_settings, -                request_host = "example.com" -            ) +            set = copy.copy(app.config["pathod"].settings) +            set.request_host = "example.com" +            language.serve(safe, s, set)          else: -            language.serve(safe, s, app.config["pathod"].request_settings) +            language.serve(safe, s, app.config["pathod"].settings)          args["output"] = utils.escape_unprintables(s.getvalue())          return render(template, False, **args) diff --git a/libpathod/language.py b/libpathod/language.py index 29d2ade8..28976a29 100644 --- a/libpathod/language.py +++ b/libpathod/language.py @@ -15,6 +15,20 @@ BLOCKSIZE = 1024  TRUNCATE = 1024 +class Settings: +    def __init__( +        self, +        staticdir = None, +        unconstrained_file_access = False, +        request_host = None, +        websocket_key = None +    ): +        self.staticdir = staticdir +        self.unconstrained_file_access = unconstrained_file_access +        self.request_host = request_host +        self.websocket_key = websocket_key + +  def quote(s):      quotechar = s[0]      s = s[1:-1] @@ -22,7 +36,11 @@ def quote(s):      return quotechar + s + quotechar -class FileAccessDenied(Exception): +class RenderError(Exception): +    pass + + +class FileAccessDenied(RenderError):      pass @@ -97,7 +115,7 @@ def write_values(fp, vals, actions, sofar=0, blocksize=BLOCKSIZE):          return True -def serve(msg, fp, settings, **kwargs): +def serve(msg, fp, settings):      """          fp: The file pointer to write to. @@ -107,7 +125,7 @@ def serve(msg, fp, settings, **kwargs):          Calling this function may modify the object.      """ -    msg = msg.resolve(settings, **kwargs) +    msg = msg.resolve(settings)      started = time.time()      vals = msg.values(settings) @@ -351,15 +369,16 @@ class ValueFile(_Token):          return self      def get_generator(self, settings): -        uf = settings.get("unconstrained_file_access") -        sd = settings.get("staticdir") -        if not sd: +        if not settings.staticdir:              raise FileAccessDenied("File access disabled.") -        sd = os.path.normpath(os.path.abspath(sd)) +        sd = os.path.normpath(os.path.abspath(settings.staticdir))          s = os.path.expanduser(self.path) -        s = os.path.normpath(os.path.abspath(os.path.join(sd, s))) -        if not uf and not s.startswith(sd): +        s = os.path.normpath( +            os.path.abspath(os.path.join(settings.staticdir, s)) +        ) +        uf = settings.unconstrained_file_access +        if not uf and not s.startswith(settings.staticdir):              raise FileAccessDenied(                  "File access outside of configured directory"              ) @@ -594,9 +613,28 @@ class Path(_Component):          return Path(self.value.freeze(settings)) +class WS(_Component): +    def __init__(self, value): +        self.value = value + +    @classmethod +    def expr(klass): +        spec = pp.Literal("ws") +        spec = spec.setParseAction(lambda x: klass(*x)) +        return spec + +    def values(self, settings): +        return "ws" + +    def spec(self): +        return "ws" + +    def freeze(self, settings): +        return self + +  class Method(_Component):      methods = [ -        "ws",          "get",          "head",          "post", @@ -797,29 +835,35 @@ class _Message(object):      def __init__(self, tokens):          self.tokens = tokens -    def _get_tokens(self, klass): +    def toks(self, klass): +        """ +            Fetch all tokens that are instances of klass +        """          return [i for i in self.tokens if isinstance(i, klass)] -    def _get_token(self, klass): -        l = self._get_tokens(klass) +    def tok(self, klass): +        """ +            Fetch first token that is an instance of klass +        """ +        l = self.toks(klass)          if l:              return l[0]      @property      def raw(self): -        return bool(self._get_token(Raw)) +        return bool(self.tok(Raw))      @property      def actions(self): -        return self._get_tokens(_Action) +        return self.toks(_Action)      @property      def body(self): -        return self._get_token(Body) +        return self.tok(Body)      @property      def headers(self): -        return self._get_tokens(_Header) +        return self.toks(_Header)      def length(self, settings):          """ @@ -883,8 +927,8 @@ class _Message(object):              vals.append(self.body.value.get_generator(settings))          return vals -    def freeze(self, settings, **kwargs): -        r = self.resolve(settings, **kwargs) +    def freeze(self, settings): +        r = self.resolve(settings)          return self.__class__([i.freeze(settings) for i in r.tokens])      def __repr__(self): @@ -909,16 +953,25 @@ class Response(_Message):      logattrs = ["code", "reason", "version", "body"]      @property +    def ws(self): +        return self.tok(WS) + +    @property      def code(self): -        return self._get_token(Code) +        return self.tok(Code)      @property      def reason(self): -        return self._get_token(Reason) +        return self.tok(Reason)      def preamble(self, settings):          l = [self.version, " "] -        l.extend(self.code.values(settings)) +        if self.code: +            l.extend(self.code.values(settings)) +            code = int(self.code.code) +        elif self.ws: +            l.extend(Code(101).values(settings)) +            code = 101          l.append(" ")          if self.reason:              l.extend(self.reason.values(settings)) @@ -926,7 +979,7 @@ class Response(_Message):              l.append(                  LiteralGenerator(                      http_status.RESPONSES.get( -                        int(self.code.code), +                        code,                          "Unknown code"                      )                  ) @@ -935,6 +988,22 @@ class Response(_Message):      def resolve(self, settings):          tokens = self.tokens[:] +        if self.ws: +            if not settings.websocket_key: +                raise RenderError( +                    "No websocket key - have we seen a client handshake?" +                ) +            if not self.code: +                tokens.insert( +                    1, +                    Code(101) +                ) +            hdrs = websockets.server_handshake_headers(settings.websocket_key) +            for i in hdrs.lst: +                if not utils.get_header(i[0], self.headers): +                    tokens.append( +                        Header(ValueLiteral(i[0]), ValueLiteral(i[1])) +                    )          if not self.raw:              if not utils.get_header("Content-Length", self.headers):                  if not self.body: @@ -958,7 +1027,12 @@ class Response(_Message):          atom = pp.MatchFirst(parts)          resp = pp.And(              [ -                Code.expr(), +                pp.MatchFirst( +                    [ +                        WS.expr() + pp.Optional(Sep + Code.expr()), +                        Code.expr(), +                    ] +                ),                  pp.ZeroOrMore(Sep + atom)              ]          ) @@ -983,16 +1057,20 @@ class Request(_Message):      logattrs = ["method", "path", "body"]      @property +    def ws(self): +        return self.tok(WS) + +    @property      def method(self): -        return self._get_token(Method) +        return self.tok(Method)      @property      def path(self): -        return self._get_token(Path) +        return self.tok(Path)      @property      def pathodspec(self): -        return self._get_token(PathodSpec) +        return self.tok(PathodSpec)      def preamble(self, settings):          v = self.method.values(settings) @@ -1004,10 +1082,14 @@ class Request(_Message):          v.append(self.version)          return v -    def resolve(self, settings, **kwargs): +    def resolve(self, settings):          tokens = self.tokens[:] -        if self.method.string().lower() == "ws": -            tokens[0] = Method("get") +        if self.ws: +            if not self.method: +                tokens.insert( +                    1, +                    Method("get") +                )              for i in websockets.client_handshake_headers().lst:                  if not utils.get_header(i[0], self.headers):                      tokens.append( @@ -1023,13 +1105,12 @@ class Request(_Message):                              ValueLiteral(str(length)),                          )                      ) -            request_host = kwargs.get("request_host") -            if request_host: +            if settings.request_host:                  if not utils.get_header("Host", self.headers):                      tokens.append(                          Header(                              ValueLiteral("Host"), -                            ValueLiteral(request_host) +                            ValueLiteral(settings.request_host)                          )                      )          intermediate = self.__class__(tokens) @@ -1043,7 +1124,12 @@ class Request(_Message):          atom = pp.MatchFirst(parts)          resp = pp.And(              [ -                Method.expr(), +                pp.MatchFirst( +                    [ +                        WS.expr() + pp.Optional(Sep + Method.expr()), +                        Method.expr(), +                    ] +                ),                  Sep,                  Path.expr(),                  pp.ZeroOrMore(Sep + atom) diff --git a/libpathod/pathoc.py b/libpathod/pathoc.py index 0d8ec8f9..cf9be5b9 100644 --- a/libpathod/pathoc.py +++ b/libpathod/pathoc.py @@ -108,9 +108,10 @@ class Pathoc(tcp.TCPClient):              ignorecodes: Sequence of return codes to ignore          """          tcp.TCPClient.__init__(self, address) -        self.settings = dict( +        self.settings = language.Settings(              staticdir = os.getcwd(),              unconstrained_file_access = True, +            request_host = self.address.host          )          self.ssl, self.sni = ssl, sni          self.clientcert = clientcert @@ -201,15 +202,14 @@ class Pathoc(tcp.TCPClient):          if self.showresp:              self.rfile.start_log()          try: -            req = language.serve( -                r, -                self.wfile, -                self.settings, -                request_host = self.address.host -            ) +            req = language.serve(r, self.wfile, self.settings)              self.wfile.flush()              resp = list( -                http.read_response(self.rfile, r.method.string(), None) +                http.read_response( +                    self.rfile, +                    req["method"], +                    None +                )              )              resp.append(self.sslinfo)              resp = Response(*resp) @@ -290,7 +290,7 @@ def main(args): # pragma: nocover              )              if args.explain or args.memo:                  playlist = [ -                    i.freeze(p.settings, request_host=p.address.host) for i in playlist +                    i.freeze(p.settings) for i in playlist                  ]              if args.memo:                  newlist = [] diff --git a/libpathod/pathod.py b/libpathod/pathod.py index 1c23baae..0c626777 100644 --- a/libpathod/pathod.py +++ b/libpathod/pathod.py @@ -66,10 +66,10 @@ class PathodHandler(tcp.BaseHandler):          self.sni = connection.get_servername()      def serve_crafted(self, crafted): -        c = self.server.check_policy(crafted, self.server.request_settings) +        c = self.server.check_policy(crafted, self.server.settings)          if c:              err = language.make_error_response(c) -            language.serve(err, self.wfile, self.server.request_settings) +            language.serve(err, self.wfile, self.server.settings)              log = dict(                  type="error",                  msg=c @@ -77,12 +77,12 @@ class PathodHandler(tcp.BaseHandler):              return False, log          if self.server.explain and not isinstance(crafted, language.PathodErrorResponse): -            crafted = crafted.freeze(self.server.request_settings) +            crafted = crafted.freeze(self.server.settings)              self.info(">> Spec: %s" % crafted.spec())          response_log = language.serve(              crafted,              self.wfile, -            self.server.request_settings +            self.server.settings          )          if response_log["disconnect"]:              return False, response_log @@ -199,7 +199,7 @@ class PathodHandler(tcp.BaseHandler):              return again, retlog          elif self.server.noweb:              crafted = language.make_error_response("Access Denied") -            language.serve(crafted, self.wfile, self.server.request_settings) +            language.serve(crafted, self.wfile, self.server.settings)              return False, dict(                  type="error",                  msg="Access denied: web interface disabled" @@ -323,6 +323,10 @@ class Pathod(tcp.TCPServer):          self.logid = 0          self.anchors = anchors +        self.settings = language.Settings( +            staticdir = self.staticdir +        ) +      def check_policy(self, req, settings):          """              A policy check that verifies the request size is withing limits. @@ -337,12 +341,6 @@ class Pathod(tcp.TCPServer):              return "Pauses have been disabled."          return False -    @property -    def request_settings(self): -        return dict( -            staticdir=self.staticdir -        ) -      def handle_client_connection(self, request, client_address):          h = PathodHandler(request, client_address, self)          try: diff --git a/libpathod/templates/docs_lang.html b/libpathod/templates/docs_lang.html index 4ed7f151..e67b13c5 100644 --- a/libpathod/templates/docs_lang.html +++ b/libpathod/templates/docs_lang.html @@ -11,6 +11,7 @@  <ul class="nav nav-tabs">    <li class="active"><a href="#specifying_responses" data-toggle="tab">Responses</a></li>    <li><a href="#specifying_requests" data-toggle="tab">Requests</a></li> +  <li><a href="#websockets" data-toggle="tab">Websockets</a></li>  </ul>  <div class="tab-content"> @@ -199,6 +200,43 @@          </table>      </div> +    <div class="tab-pane" id="websockets"> + +        <p>Requests and responses can be decorated with the <b>ws</b> prefix to +        create a websockets client or server handshake. Since the websocket +        specifier implies a request method (GET) and a response code (102), +        these can optionally be omitted. All other request and response +        features can be applied, and websocket-specific headers can be +        over-ridden explicitly.</p> + +        <h2>Request</h2> + +        <pre class="example">ws:[method:]path:[colon-separated list of features]</pre></p> + +        <p>This will generate a wsocket client handshake with a GET method:</p> + +        <pre class="example">ws:/</pre></p> + +        <p>This will do the same, but using the (invalid) PUT method:</p> + +        <pre class="example">ws:put:/</pre></p> + + +        <h2>Response</h2> + +        <pre class="example">ws[:code:][colon-separated list of features]</pre></p> + +        <p>This will generate a simple protocol acceptance with a 101 response +        code:</p> + +        <pre class="example">ws</pre></p> + +        <p>This will do the same, but using the (invalid) 202 code:</p> + +        <pre class="example">ws:202</pre></p> + +    </div> +  </div> diff --git a/libpathod/templates/request_previewform.html b/libpathod/templates/request_previewform.html index 607bfefd..d3083735 100644 --- a/libpathod/templates/request_previewform.html +++ b/libpathod/templates/request_previewform.html @@ -1,5 +1,5 @@  <form style="margin-bottom: 0" class="form-inline" method="GET" action="/request_preview"> -    <input  +    <input          style="width: 18em"          id="spec"          name="spec" @@ -46,6 +46,10 @@                  <td><a href="/request_preview?spec=get:/:b@100,ascii:ir,@1">get:/:b@100,ascii:ir,@1</a></td>                  <td>100 ASCII bytes as the body, and randomly inject a random byte</td>              </tr> +            <tr> +                <td><a href="/request_preview?spec=ws:/">ws:/</a></td> +                <td>Initiate a websocket handshake.</td> +            </tr>          </tbody>      </table>  </div> diff --git a/libpathod/templates/response_previewform.html b/libpathod/templates/response_previewform.html index fbc3de5a..28551015 100644 --- a/libpathod/templates/response_previewform.html +++ b/libpathod/templates/response_previewform.html @@ -1,5 +1,5 @@  <form style="margin-bottom: 0" class="form-inline" method="GET" action="/response_preview"> -    <input  +    <input          style="width: 18em"          id="spec"          name="spec" @@ -68,6 +68,12 @@                  </td>                  <td>100 ASCII bytes as the body, randomly generated 100k header name, with the value 'foo'.</td>              </tr> +            <tr> +                <td> +                    <a href="/response_preview?spec=ws">ws</a> +                </td> +                <td>A websocket connection acceptance handshake.</td> +            </tr>          </tbody>      </table>  </div> diff --git a/test/test_language.py b/test/test_language.py index 4dd3d8ac..28e26e10 100644 --- a/test/test_language.py +++ b/test/test_language.py @@ -101,7 +101,7 @@ class TestValueGenerate:      def test_freeze(self):          v = language.ValueGenerate(100, "b", "ascii") -        f = v.freeze({}) +        f = v.freeze(language.Settings())          assert len(f.val) == 100 @@ -121,16 +121,26 @@ class TestValueFile:              with open(p, "wb") as f:                  f.write("x" * 10000) -            assert v.get_generator(dict(staticdir=t)) +            assert v.get_generator(language.Settings(staticdir=t))              v = language.Value.parseString("<path2")[0]              tutils.raises( -                language.FileAccessDenied, v.get_generator, dict(staticdir=t) +                language.FileAccessDenied, +                v.get_generator, +                language.Settings(staticdir=t) +            ) +            tutils.raises( +                "access disabled", +                v.get_generator, +                language.Settings()              ) -            tutils.raises("access disabled", v.get_generator, dict())              v = language.Value.parseString("</outside")[0] -            tutils.raises("outside", v.get_generator, dict(staticdir=t)) +            tutils.raises( +                "outside", +                v.get_generator, +                language.Settings(staticdir=t) +            )      def test_spec(self):          v = language.Value.parseString("<'one two'")[0] @@ -556,7 +566,12 @@ class TestRequest:      def test_render(self):          s = cStringIO.StringIO()          r = parse_request("GET:'/foo'") -        assert language.serve(r, s, {}, request_host = "foo.com") +        assert language.serve( +            r, +            s, +            language.Settings(request_host = "foo.com") +        ) +      def test_multiline(self):          l = """ @@ -593,17 +608,28 @@ class TestRequest:          rt("get:/foo:da")      def test_freeze(self): -        r = parse_request("GET:/:b@100").freeze({}) +        r = parse_request("GET:/:b@100").freeze(language.Settings())          assert len(r.spec()) > 100      def test_path_generator(self): -        r = parse_request("GET:@100").freeze({}) +        r = parse_request("GET:@100").freeze(language.Settings())          assert len(r.spec()) > 100      def test_websocket(self): -        r = parse_request('ws:"/foo"') -        res = r.resolve({}) -        assert utils.get_header("upgrade", res.headers) +        r = parse_request('ws:/path/') +        res = r.resolve(language.Settings()) +        assert res.method.string().lower() == "get" +        assert res.tok(language.Path).value.val == "/path/" +        assert res.tok(language.Method).value.val.lower() == "get" +        assert utils.get_header("Upgrade", res.headers).value.val == "websocket" + +        r = parse_request('ws:put:/path/') +        res = r.resolve(language.Settings()) +        assert r.method.string().lower() == "put" +        assert res.tok(language.Path).value.val == "/path/" +        assert res.tok(language.Method).value.val.lower() == "put" +        assert utils.get_header("Upgrade", res.headers).value.val == "websocket" +  class TestWriteValues: @@ -725,13 +751,13 @@ class TestResponse:          r = language.parse_response("400:b'foo':r")          language.serve(r, s, {})          v = s.getvalue() -        assert not "Content-Length" in v +        assert "Content-Length" not in v      def test_length(self):          def testlen(x):              s = cStringIO.StringIO() -            language.serve(x, s, {}) -            assert x.length({}) == len(s.getvalue()) +            language.serve(x, s, language.Settings()) +            assert x.length(language.Settings()) == len(s.getvalue())          testlen(language.parse_response("400:m'msg':r"))          testlen(language.parse_response("400:m'msg':h'foo'='bar':r"))          testlen(language.parse_response("400:m'msg':h'foo'='bar':b@100b:r")) @@ -797,6 +823,12 @@ class TestResponse:          rt("400")          rt("400:da") +    def test_websockets(self): +        r = language.parse_response("ws") +        tutils.raises("no websocket key", r.resolve, language.Settings()) +        res = r.resolve(language.Settings(websocket_key="foo")) +        assert res.code.string() == "101" +  def test_read_file():      tutils.raises(language.FileAccessDenied, language.read_file, {}, "=/foo") diff --git a/test/test_pathoc.py b/test/test_pathoc.py index 52075231..e14450b9 100644 --- a/test/test_pathoc.py +++ b/test/test_pathoc.py @@ -74,7 +74,7 @@ class _TestDaemon:          for i in requests:              r = language.parse_requests(i)[0]              if explain: -                r = r.freeze({}) +                r = r.freeze(language.Settings())              try:                  c.request(r)              except (http.HttpError, tcp.NetLibError), v: diff --git a/test/test_pathod.py b/test/test_pathod.py index c32f6e84..00634e27 100644 --- a/test/test_pathod.py +++ b/test/test_pathod.py @@ -13,7 +13,7 @@ class TestPathod(object):          p.clear_log()          assert len(p.get_log()) == 0 -        for i in range(p.LOGBUF + 1): +        for _ in range(p.LOGBUF + 1):              p.add_log(dict(s="foo"))          assert len(p.get_log()) <= p.LOGBUF | 
