diff options
| author | Aldo Cortesi <aldo@nullcube.com> | 2012-06-19 10:42:25 +1200 | 
|---|---|---|
| committer | Aldo Cortesi <aldo@nullcube.com> | 2012-06-19 10:42:25 +1200 | 
| commit | c7e9051cbbee1e76abb24518268d30a24df3a16a (patch) | |
| tree | 3c65e2d95309027d0f6d4ce1f4cfef443a95e294 | |
| parent | b558997fd9db8406b2a24a1831d06e283dbf35a6 (diff) | |
| download | mitmproxy-c7e9051cbbee1e76abb24518268d30a24df3a16a.tar.gz mitmproxy-c7e9051cbbee1e76abb24518268d30a24df3a16a.tar.bz2 mitmproxy-c7e9051cbbee1e76abb24518268d30a24df3a16a.zip | |
Import wsgi.
| -rw-r--r-- | netlib/wsgi.py | 125 | ||||
| -rw-r--r-- | test/test_wsgi.py | 98 | 
2 files changed, 223 insertions, 0 deletions
| diff --git a/netlib/wsgi.py b/netlib/wsgi.py new file mode 100644 index 00000000..0608245c --- /dev/null +++ b/netlib/wsgi.py @@ -0,0 +1,125 @@ +import cStringIO, urllib, time, sys, traceback +import odict + +def date_time_string(): +    """Return the current date and time formatted for a message header.""" +    WEEKS = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'] +    MONTHS = [None, +                 'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', +                 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec'] +    now = time.time() +    year, month, day, hh, mm, ss, wd, y, z = time.gmtime(now) +    s = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % ( +            WEEKS[wd], +            day, MONTHS[month], year, +            hh, mm, ss) +    return s + + +class WSGIAdaptor: +    def __init__(self, app, domain, port, sversion): +        self.app, self.domain, self.port, self.sversion = app, domain, port, sversion + +    def make_environ(self, request, errsoc): +        if '?' in request.path: +            path_info, query = request.path.split('?', 1) +        else: +            path_info = request.path +            query = '' +        environ = { +            'wsgi.version':         (1, 0), +            'wsgi.url_scheme':      request.scheme, +            'wsgi.input':           cStringIO.StringIO(request.content), +            'wsgi.errors':          errsoc, +            'wsgi.multithread':     True, +            'wsgi.multiprocess':    False, +            'wsgi.run_once':        False, +            'SERVER_SOFTWARE':      self.sversion, +            'REQUEST_METHOD':       request.method, +            'SCRIPT_NAME':          '', +            'PATH_INFO':            urllib.unquote(path_info), +            'QUERY_STRING':         query, +            'CONTENT_TYPE':         request.headers.get('Content-Type', [''])[0], +            'CONTENT_LENGTH':       request.headers.get('Content-Length', [''])[0], +            'SERVER_NAME':          self.domain, +            'SERVER_PORT':          self.port, +            # FIXME: We need to pick up the protocol read from the request. +            'SERVER_PROTOCOL':      "HTTP/1.1", +        } +        if request.client_conn.address: +            environ["REMOTE_ADDR"], environ["REMOTE_PORT"] = request.client_conn.address + +        for key, value in request.headers.items(): +            key = 'HTTP_' + key.upper().replace('-', '_') +            if key not in ('HTTP_CONTENT_TYPE', 'HTTP_CONTENT_LENGTH'): +                environ[key] = value +        return environ + +    def error_page(self, soc, headers_sent, s): +        """ +            Make a best-effort attempt to write an error page. If headers are +            already sent, we just bung the error into the page. +        """ +        c = """ +            <html> +                <h1>Internal Server Error</h1> +                <pre>%s"</pre> +            </html> +        """%s +        if not headers_sent: +            soc.write("HTTP/1.1 500 Internal Server Error\r\n") +            soc.write("Content-Type: text/html\r\n") +            soc.write("Content-Length: %s\r\n"%len(c)) +            soc.write("\r\n") +        soc.write(c) + +    def serve(self, request, soc): +        state = dict( +            response_started = False, +            headers_sent = False, +            status = None, +            headers = None +        ) +        def write(data): +            if not state["headers_sent"]: +                soc.write("HTTP/1.1 %s\r\n"%state["status"]) +                h = state["headers"] +                if 'server' not in h: +                    h["Server"] = [version.NAMEVERSION] +                if 'date' not in h: +                    h["Date"] = [date_time_string()] +                soc.write(str(h)) +                soc.write("\r\n") +                state["headers_sent"] = True +            soc.write(data) +            soc.flush() + +        def start_response(status, headers, exc_info=None): +            if exc_info: +                try: +                    if state["headers_sent"]: +                        raise exc_info[0], exc_info[1], exc_info[2] +                finally: +                    exc_info = None +            elif state["status"]: +                raise AssertionError('Response already started') +            state["status"] = status +            state["headers"] = odict.ODictCaseless(headers) +            return write + +        errs = cStringIO.StringIO() +        try: +            dataiter = self.app(self.make_environ(request, errs), start_response) +            for i in dataiter: +                write(i) +            if not state["headers_sent"]: +                write("") +        except Exception, v: +            try: +                s = traceback.format_exc() +                self.error_page(soc, state["headers_sent"], s) +            except Exception, v:    # pragma: no cover +                pass                # pragma: no cover +        return errs.getvalue() + + diff --git a/test/test_wsgi.py b/test/test_wsgi.py new file mode 100644 index 00000000..c55ab1d8 --- /dev/null +++ b/test/test_wsgi.py @@ -0,0 +1,98 @@ +import cStringIO, sys +import libpry +from netlib import wsgi +import tutils + + +class TestApp: +    def __init__(self): +        self.called = False + +    def __call__(self, environ, start_response): +        self.called = True +        status = '200 OK' +        response_headers = [('Content-type', 'text/plain')] +        start_response(status, response_headers) +        return ['Hello', ' world!\n'] + + +class uWSGIAdaptor(libpry.AutoTree): +    def test_make_environ(self): +        w = wsgi.WSGIAdaptor(None, "foo", 80) +        tr = tutils.treq() +        assert w.make_environ(tr, None) + +        tr.path = "/foo?bar=voing" +        r = w.make_environ(tr, None) +        assert r["QUERY_STRING"] == "bar=voing" + +    def test_serve(self): +        ta = TestApp() +        w = wsgi.WSGIAdaptor(ta, "foo", 80) +        r = tutils.treq() +        r.host = "foo" +        r.port = 80 + +        wfile = cStringIO.StringIO() +        err = w.serve(r, wfile) +        assert ta.called +        assert not err + +        val = wfile.getvalue() +        assert "Hello world" in val +        assert "Server:" in val + +    def _serve(self, app): +        w = wsgi.WSGIAdaptor(app, "foo", 80) +        r = tutils.treq() +        r.host = "foo" +        r.port = 80 +        wfile = cStringIO.StringIO() +        err = w.serve(r, wfile) +        return wfile.getvalue() + +    def test_serve_empty_body(self): +        def app(environ, start_response): +            status = '200 OK' +            response_headers = [('Foo', 'bar')] +            start_response(status, response_headers) +            return [] +        assert self._serve(app) + +    def test_serve_double_start(self): +        def app(environ, start_response): +            try: +                raise ValueError("foo") +            except: +                ei = sys.exc_info() +            status = '200 OK' +            response_headers = [('Content-type', 'text/plain')] +            start_response(status, response_headers) +            start_response(status, response_headers) +        assert "Internal Server Error" in self._serve(app) + +    def test_serve_single_err(self): +        def app(environ, start_response): +            try: +                raise ValueError("foo") +            except: +                ei = sys.exc_info() +            status = '200 OK' +            response_headers = [('Content-type', 'text/plain')] +            start_response(status, response_headers, ei) +        assert "Internal Server Error" in self._serve(app) + +    def test_serve_double_err(self): +        def app(environ, start_response): +            try: +                raise ValueError("foo") +            except: +                ei = sys.exc_info() +            status = '200 OK' +            response_headers = [('Content-type', 'text/plain')] +            start_response(status, response_headers) +            yield "aaa" +            start_response(status, response_headers, ei) +            yield "bbb" +        assert "Internal Server Error" in self._serve(app) + | 
