diff options
| -rw-r--r-- | libmproxy/protocol/http.py | 65 | ||||
| -rw-r--r-- | test/test_server.py | 9 | 
2 files changed, 28 insertions, 46 deletions
| diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 711cb06c..31dd39f5 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -293,7 +293,8 @@ class HTTPRequest(HTTPMessage):              raise http.HttpError(400, "Invalid headers")          if include_body: -            content = http.read_http_body(rfile, headers, body_size_limit, True) +            content = http.read_http_body(rfile, headers, body_size_limit, +                                          method, None, True)              timestamp_end = utils.timestamp()          return HTTPRequest(form_in, method, scheme, host, port, path, httpversion, headers, @@ -305,7 +306,7 @@ class HTTPRequest(HTTPMessage):          if form == "relative":              path = self.path if self.method != "OPTIONS" else "*"              request_line = '%s %s HTTP/%s.%s' % \ -                (self.method, path, self.httpversion[0], self.httpversion[1]) +                           (self.method, path, self.httpversion[0], self.httpversion[1])          elif form == "authority":              request_line = '%s %s:%s HTTP/%s.%s' % (self.method, self.host, self.port,                                                      self.httpversion[0], self.httpversion[1]) @@ -634,9 +635,9 @@ class HTTPResponse(HTTPMessage):      def _assemble_headers(self, preserve_transfer_encoding=False):          headers = self.headers.copy() -        utils.del_all(headers,['Proxy-Connection']) +        utils.del_all(headers, ['Proxy-Connection'])          if not preserve_transfer_encoding: -            utils.del_all(headers,['Transfer-Encoding']) +            utils.del_all(headers, ['Transfer-Encoding'])          if self.content:              headers["Content-Length"] = [str(len(self.content))] @@ -646,7 +647,8 @@ class HTTPResponse(HTTPMessage):          return str(headers)      def _assemble_head(self, preserve_transfer_encoding=False): -        return '%s\r\n%s\r\n' % (self._assemble_first_line(), self._assemble_headers(preserve_transfer_encoding=preserve_transfer_encoding)) +        return '%s\r\n%s\r\n' % ( +            self._assemble_first_line(), self._assemble_headers(preserve_transfer_encoding=preserve_transfer_encoding))      def _assemble(self):          """ @@ -862,7 +864,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):              pass          self.c.close = True -    def get_response_from_server(self, request, stream=False): +    def get_response_from_server(self, request, include_body=True):          self.c.establish_server_connection()          request_raw = request._assemble() @@ -870,7 +872,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):              try:                  self.c.server_conn.send(request_raw)                  res = HTTPResponse.from_stream(self.c.server_conn.rfile, request.method, -                                                body_size_limit=self.c.config.body_size_limit, include_body=(not stream)) +                                               body_size_limit=self.c.config.body_size_limit, include_body=include_body)                  return res              except (tcp.NetLibDisconnect, http.HttpErrorConnClosed), v:                  self.c.log("error in server communication: %s" % str(v), level="debug") @@ -915,7 +917,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):              else:                  # read initially in "stream" mode, so we can get the headers separately -                flow.response = self.get_response_from_server(flow.request, stream=True) +                flow.response = self.get_response_from_server(flow.request, include_body=False)                  flow.response.stream = False                  # call the appropriate script hook - this is an opportunity for an inline script to set flow.stream = True @@ -923,7 +925,9 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):                  # now get the rest of the request body, if body still needs to be read but not streaming this response                  if not flow.response.stream and flow.response.content is None: -                    flow.response.content = http.read_http_body(self.c.server_conn.rfile, flow.response.headers, self.c.config.body_size_limit, False) +                    flow.response.content = http.read_http_body(self.c.server_conn.rfile, flow.response.headers, +                                                                self.c.config.body_size_limit, +                                                                flow.request.method, flow.response.code, False)              flow.server_conn = self.c.server_conn  # no further manipulation of self.c.server_conn beyond this point              # we can safely set it as the final attribute value here. @@ -933,8 +937,6 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):              if response_reply is None or response_reply == KILL:                  return False -            disconnected_while_streaming = False -              if flow.response.content is not None:                  # if not streaming or there is no body to be read, we'll already have the body, just send it                  self.c.client_conn.send(flow.response._assemble()) @@ -946,38 +948,19 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):                  h = flow.response._assemble_head(preserve_transfer_encoding=True)                  self.c.client_conn.send(h) -                # if chunked then we send back each chunk -                if http.has_chunked_encoding(flow.response.headers): -                    while 1: -                        content = http.read_next_chunk(self.c.server_conn.rfile, flow.response.headers, False) -                        if not http.write_chunk(self.c.client_conn.wfile, content): -                            break -                        self.c.client_conn.wfile.flush() +                for chunk in http.read_http_body_chunked(self.c.server_conn.rfile, +                                                         flow.response.headers, +                                                         self.c.config.body_size_limit, "GET", 200, False, 4096): +                    for part in chunk: +                        self.c.client_conn.wfile.write(part)                      self.c.client_conn.wfile.flush() -                else: # not chunked, we send back 4k at a time -                    clen = http.expected_http_body_size(flow.response.headers, False) -                    clen = clen if clen >= 0 else (64 * 1024 * 1024 * 1024) # arbitrary max of 64G if no length set -                    rcount = 0 -                    blocksize = 4096 -                    while 1: -                        bytes_to_read = min(blocksize, clen - rcount) -                        if bytes_to_read == 0: -                            break -                        content = self.c.server_conn.rfile.read(bytes_to_read) -                        if content == "": # check for EOF -                            disconnected_while_streaming = True -                            break -                        rcount += len(content) -                        self.c.client_conn.wfile.write(content) -                        self.c.client_conn.wfile.flush() -                        if rcount >= clen: # check for having read up to clen -                            break -              flow.timestamp_end = utils.timestamp() -            if (disconnected_while_streaming or http.connection_close(flow.request.httpversion, flow.request.headers) or -                    http.connection_close(flow.response.httpversion, flow.response.headers)): +            if (http.connection_close(flow.request.httpversion, flow.request.headers) or +                    http.connection_close(flow.response.httpversion, flow.response.headers) or +                        http.expected_http_body_size(flow.response.headers, False, flow.request.method, +                                                     flow.response.code) == -1):                  return False              if flow.request.form_in == "authority": @@ -1009,7 +992,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):              if flow.request and not flow.response:                  self.c.channel.ask("error", flow.error)          else: -            pass  #  FIXME: Do we want to persist errors without flows? +            pass  # FIXME: Do we want to persist errors without flows?          try:              self.send_error(code, message, headers) @@ -1109,7 +1092,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):              return True          raise http.HttpError(400, "Invalid HTTP request form (expected: %s, got: %s)" % -                                  (self.expected_form_in, request.form_in)) +                             (self.expected_form_in, request.form_in))      def authenticate(self, request):          if self.c.config.authenticator: diff --git a/test/test_server.py b/test/test_server.py index 71f00d96..d7318849 100644 --- a/test/test_server.py +++ b/test/test_server.py @@ -427,16 +427,15 @@ class TestStreamRequest(tservers.HTTPProxTest):          fconn = connection.makefile()          spec = '200:h"Transfer-Encoding"="chunked":r:b"4\\r\\nthis\\r\\n7\\r\\nisatest\\r\\n0\\r\\n\\r\\n"'          connection.send("GET %s/p/%s HTTP/1.1\r\n"%(self.server.urlbase, spec)) -        connection.send("\r\n"); +        connection.send("\r\n") -        httpversion, code, msg, headers, content = http.read_response(fconn, "GET", 100000, include_body=False) +        httpversion, code, msg, headers, content = http.read_response(fconn, "GET", None, include_body=False)          assert headers["Transfer-Encoding"][0] == 'chunked'          assert code == 200 -        assert http.read_next_chunk(fconn, headers, False) == "this" -        assert http.read_next_chunk(fconn, headers, False) == "isatest" -        assert http.read_next_chunk(fconn, headers, False) == None +        chunks = list(content for _, content, _ in http.read_http_body_chunked(fconn, headers, None, "GET", 200, False)) +        assert chunks == ["this", "isatest", ""]          connection.close() | 
