diff options
| -rw-r--r-- | libmproxy/proxy.py | 46 | ||||
| -rw-r--r-- | test/test_server.py | 59 | ||||
| -rw-r--r-- | test/tservers.py | 6 | 
3 files changed, 76 insertions, 35 deletions
| diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py index a6a72d55..458ea2b5 100644 --- a/libmproxy/proxy.py +++ b/libmproxy/proxy.py @@ -326,11 +326,11 @@ class ProxyHandler(tcp.BaseHandler):          if not self.ssl_established and (port in self.config.transparent_proxy["sslports"]):              scheme = "https"              dummycert = self.find_cert(client_conn, host, port, host) +            sni = HandleSNI( +                self, client_conn, host, port, +                dummycert, self.config.certfile or self.config.cacert +            )              try: -                sni = HandleSNI( -                    self, client_conn, host, port, -                    dummycert, self.config.certfile or self.config.cacert -                )                  self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni)              except tcp.NetLibError, v:                  raise ProxyError(400, str(v)) @@ -356,31 +356,29 @@ class ProxyHandler(tcp.BaseHandler):          line = self.get_line(self.rfile)          if line == "":              return None -        if http.parse_init_connect(line): -            r = http.parse_init_connect(line) -            if not r: -                raise ProxyError(400, "Bad HTTP request line: %s"%repr(line)) -            host, port, httpversion = r - -            headers = self.read_headers(authenticate=True) -            self.wfile.write( -                        'HTTP/1.1 200 Connection established\r\n' + -                        ('Proxy-agent: %s\r\n'%self.server_version) + -                        '\r\n' -                        ) -            self.wfile.flush() -            dummycert = self.find_cert(client_conn, host, port, host) -            try: +        if not self.proxy_connect_state: +            connparts = http.parse_init_connect(line) +            if connparts: +                host, port, httpversion = connparts +                headers = self.read_headers(authenticate=True) +                self.wfile.write( +                            'HTTP/1.1 200 Connection established\r\n' + +                            ('Proxy-agent: %s\r\n'%self.server_version) + +                            '\r\n' +                            ) +                self.wfile.flush() +                dummycert = self.find_cert(client_conn, host, port, host)                  sni = HandleSNI(                      self, client_conn, host, port,                      dummycert, self.config.certfile or self.config.cacert                  ) -                self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) -            except tcp.NetLibError, v: -                raise ProxyError(400, str(v)) -            self.proxy_connect_state = (host, port, httpversion) -            line = self.rfile.readline(line) +                try: +                    self.convert_to_ssl(dummycert, self.config.certfile or self.config.cacert, handle_sni=sni) +                except tcp.NetLibError, v: +                    raise ProxyError(400, str(v)) +                self.proxy_connect_state = (host, port, httpversion) +                line = self.rfile.readline(line)          if self.proxy_connect_state:              r = http.parse_init_http(line) diff --git a/test/test_server.py b/test/test_server.py index 86a75452..3a1b019f 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -45,6 +45,12 @@ class CommonMixin:          assert "host" in l.request.headers          assert l.response.code == 304 +    def test_invalid_http(self): +        t = tcp.TCPClient("127.0.0.1", self.proxy.port) +        t.connect() +        t.wfile.write("invalid\r\n\r\n") +        t.wfile.flush() +        assert "Bad Request" in t.rfile.readline()  class TestHTTP(tservers.HTTPProxTest, CommonMixin): @@ -54,13 +60,6 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin):          assert ret.status_code == 500          assert "ValueError" in ret.content -    def test_invalid_http(self): -        t = tcp.TCPClient("127.0.0.1", self.proxy.port) -        t.connect() -        t.wfile.write("invalid\n\n") -        t.wfile.flush() -        assert "Bad Request" in t.rfile.readline() -      def test_invalid_connect(self):          t = tcp.TCPClient("127.0.0.1", self.proxy.port)          t.connect() @@ -125,6 +124,25 @@ class TestHTTP(tservers.HTTPProxTest, CommonMixin):          ret = p.request("get:'http://localhost:0'")          assert ret.status_code == 502 +    def test_blank_leading_line(self): +        p = self.pathoc() +        req = "get:'%s/p/201':i0,'\r\n'" +        assert p.request(req%self.server.urlbase).status_code == 201 + +    def test_invalid_headers(self): +        p = self.pathoc() +        req = p.request("get:'http://foo':h':foo'='bar'") +        print req + + +class TestHTTPConnectSSLError(tservers.HTTPProxTest): +    certfile = True +    def test_go(self): +        p = self.pathoc() +        req = "connect:'localhost:%s'"%self.proxy.port +        assert p.request(req).status_code == 200 +        assert p.request(req).status_code == 400 +  class TestHTTPS(tservers.HTTPProxTest, CommonMixin):      ssl = True @@ -140,6 +158,11 @@ class TestHTTPS(tservers.HTTPProxTest, CommonMixin):          l = self.server.last_log()          assert self.server.last_log()["request"]["sni"] == "testserver.com" +    def test_error_post_connect(self): +        p = self.pathoc() +        assert p.request("get:/:i0,'invalid\r\n\r\n'").status_code == 400 + +  class TestHTTPSNoUpstream(tservers.HTTPProxTest, CommonMixin):      ssl = True @@ -163,12 +186,10 @@ class TestReverse(tservers.ReverseProxTest, CommonMixin):  class TestTransparent(tservers.TransparentProxTest, CommonMixin): -    transparent = True      ssl = False  class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin): -    transparent = True      ssl = True      def test_sni(self):          f = self.pathod("304", sni="testserver.com") @@ -176,6 +197,10 @@ class TestTransparentSSL(tservers.TransparentProxTest, CommonMixin):          l = self.server.last_log()          assert self.server.last_log()["request"]["sni"] == "testserver.com" +    def test_sslerr(self): +        p = pathoc.Pathoc("localhost", self.proxy.port) +        p.connect() +        assert p.request("get:/").status_code == 400  class TestProxy(tservers.HTTPProxTest): @@ -267,3 +292,19 @@ class TestKillResponse(tservers.HTTPProxTest):          # The server should have seen a request          assert self.server.last_log() + +class EResolver(tservers.TResolver): +    def original_addr(self, sock): +        return None + + +class TestTransparentResolveError(tservers.TransparentProxTest): +    resolver = EResolver +    def test_resolve_error(self): +        assert self.pathod("304").status_code == 502 + + + + + + diff --git a/test/tservers.py b/test/tservers.py index 4efed7e2..7672f34a 100644 --- a/test/tservers.py +++ b/test/tservers.py @@ -131,7 +131,7 @@ class ProxTestBase:  class HTTPProxTest(ProxTestBase):      def pathoc_raw(self):          return libpathod.pathoc.Pathoc("127.0.0.1", self.proxy.port) -     +      def pathoc(self, sni=None):          """              Returns a connected Pathoc instance. @@ -148,6 +148,7 @@ class HTTPProxTest(ProxTestBase):              Constructs a pathod GET request, with the appropriate base and proxy.          """          p = self.pathoc(sni=sni) +        spec = spec.encode("string_escape")          if self.ssl:              q = "get:'/p/%s'"%spec          else: @@ -165,6 +166,7 @@ class TResolver:  class TransparentProxTest(ProxTestBase):      ssl = None +    resolver = TResolver      @classmethod      def get_proxy_config(cls):          d = ProxTestBase.get_proxy_config() @@ -173,7 +175,7 @@ class TransparentProxTest(ProxTestBase):          else:              ports = []          d["transparent_proxy"] = dict( -            resolver = TResolver(cls.server.port), +            resolver = cls.resolver(cls.server.port),              sslports = ports          )          return d | 
