diff options
author | Maximilian Hils <git@maximilianhils.com> | 2015-08-29 14:28:11 +0200 |
---|---|---|
committer | Maximilian Hils <git@maximilianhils.com> | 2015-08-29 14:28:11 +0200 |
commit | 63844df34367bf7147c2d43a9e4061515f6430c9 (patch) | |
tree | 72d668c4af9bbfc95d1195f75bf61bcb4219dc9c /libmproxy | |
parent | 2dfba2105b4b5ad094ee364124c0552d2e4a4947 (diff) | |
download | mitmproxy-63844df34367bf7147c2d43a9e4061515f6430c9.tar.gz mitmproxy-63844df34367bf7147c2d43a9e4061515f6430c9.tar.bz2 mitmproxy-63844df34367bf7147c2d43a9e4061515f6430c9.zip |
fix streaming
Diffstat (limited to 'libmproxy')
-rw-r--r-- | libmproxy/protocol2/http.py | 192 |
1 files changed, 122 insertions, 70 deletions
diff --git a/libmproxy/protocol2/http.py b/libmproxy/protocol2/http.py index 792cf266..0fde9fb1 100644 --- a/libmproxy/protocol2/http.py +++ b/libmproxy/protocol2/http.py @@ -25,32 +25,101 @@ from netlib.http.http2 import HTTP2Protocol # TODO: The HTTP2 layer is missing multiplexing, which requires a major rewrite. -class Http1Layer(Layer): +class _HttpLayer(Layer): + supports_streaming = False + + def read_request(self): + raise NotImplementedError() + + def send_request(self, request): + raise NotImplementedError() + + def read_response(self, request_method): + raise NotImplementedError() + + def send_response(self, response): + raise NotImplementedError() + +class _StreamingHttpLayer(_HttpLayer): + supports_streaming = True + + def read_response_headers(self): + raise NotImplementedError + + def read_response_body(self, headers, request_method, response_code, max_chunk_size=None): + raise NotImplementedError() + yield "this is a generator" + + def send_response_headers(self, response): + raise NotImplementedError + + def send_response_body(self, response, chunks): + raise NotImplementedError() + + +class Http1Layer(_StreamingHttpLayer): + def __init__(self, ctx, mode): super(Http1Layer, self).__init__(ctx) self.mode = mode self.client_protocol = HTTP1Protocol(self.client_conn) self.server_protocol = HTTP1Protocol(self.server_conn) - def read_from_client(self): + def read_request(self): return HTTPRequest.from_protocol( self.client_protocol, body_size_limit=self.config.body_size_limit ) - def read_from_server(self, request_method): + def send_request(self, request): + self.server_conn.send(self.server_protocol.assemble(request)) + + def read_response(self, request_method): return HTTPResponse.from_protocol( self.server_protocol, - request_method, + request_method=request_method, body_size_limit=self.config.body_size_limit, - include_body=False, + include_body=True ) - def send_to_client(self, message): - self.client_conn.send(self.client_protocol.assemble(message)) + def send_response(self, response): + self.client_conn.send(self.client_protocol.assemble(response)) - def send_to_server(self, message): - self.server_conn.send(self.server_protocol.assemble(message)) + def read_response_headers(self): + return HTTPResponse.from_protocol( + self.server_protocol, + request_method=None, # does not matter if we don't read the body. + body_size_limit=self.config.body_size_limit, + include_body=False + ) + + def read_response_body(self, headers, request_method, response_code, max_chunk_size=None): + return self.server_protocol.read_http_body_chunked( + headers, + self.config.body_size_limit, + request_method, + response_code, + False, + max_chunk_size + ) + + def send_response_headers(self, response): + h = self.client_protocol._assemble_response_first_line(response) + self.client_conn.wfile.write(h+"\r\n") + h = self.client_protocol._assemble_response_headers( + response, + preserve_transfer_encoding=True + ) + self.client_conn.send(h+"\r\n") + + def send_response_body(self, response, chunks): + if self.client_protocol.has_chunked_encoding(response.headers): + chunks = ( + "%d\r\n%s\r\n" % (len(chunk), chunk) + for chunk in chunks + ) + for chunk in chunks: + self.client_conn.send(chunk) def connect(self): self.ctx.connect() @@ -69,14 +138,14 @@ class Http1Layer(Layer): layer() -class Http2Layer(Layer): +class Http2Layer(_HttpLayer): def __init__(self, ctx, mode): super(Http2Layer, self).__init__(ctx) self.mode = mode self.client_protocol = HTTP2Protocol(self.client_conn, is_server=True, unhandled_frame_cb=self.handle_unexpected_frame) self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) - def read_from_client(self): + def read_request(self): request = HTTPRequest.from_protocol( self.client_protocol, body_size_limit=self.config.body_size_limit @@ -84,23 +153,23 @@ class Http2Layer(Layer): self._stream_id = request.stream_id return request - def read_from_server(self, request_method): + def send_request(self, message): + # TODO: implement flow control and WINDOW_UPDATE frames + self.server_conn.send(self.server_protocol.assemble(message)) + + def read_response(self, request_method): return HTTPResponse.from_protocol( self.server_protocol, - request_method, + request_method=request_method, body_size_limit=self.config.body_size_limit, include_body=True, stream_id=self._stream_id ) - def send_to_client(self, message): + def send_response(self, message): # TODO: implement flow control and WINDOW_UPDATE frames self.client_conn.send(self.client_protocol.assemble(message)) - def send_to_server(self, message): - # TODO: implement flow control and WINDOW_UPDATE frames - self.server_conn.send(self.server_protocol.assemble(message)) - def connect(self): self.ctx.connect() self.server_protocol = HTTP2Protocol(self.server_conn, is_server=False, unhandled_frame_cb=self.handle_unexpected_frame) @@ -122,7 +191,7 @@ class Http2Layer(Layer): layer() def handle_unexpected_frame(self, frm): - print(frm.human_readable()) + self.log("Unexpected HTTP2 Frame: %s" % frm.human_readable(), "info") def make_error_response(status_code, message, headers=None): @@ -204,13 +273,13 @@ class UpstreamConnectLayer(Layer): def connect(self): if not self.server_conn: self.ctx.connect() - self.send_to_server(self.connect_request) + self.send_request(self.connect_request) else: pass # swallow the message def reconnect(self): self.ctx.reconnect() - self.send_to_server(self.connect_request) + self.send_request(self.connect_request) def set_server(self, address, server_tls=None, sni=None, depth=1): if depth == 1: @@ -240,7 +309,7 @@ class HttpLayer(Layer): flow = HTTPFlow(self.client_conn, self.server_conn, live=self) try: - request = self.read_from_client() + request = self.read_request() except tcp.NetLibError: # don't throw an error for disconnects that happen # before/between requests. @@ -280,7 +349,7 @@ class HttpLayer(Layer): except (HttpErrorConnClosed, NetLibError, HttpError, ProtocolException) as e: try: - self.send_to_client(make_error_response( + self.send_response(make_error_response( getattr(e, "code", 502), repr(e) )) @@ -295,7 +364,7 @@ class HttpLayer(Layer): def handle_regular_mode_connect(self, request): self.set_server((request.host, request.port)) - self.send_to_client(make_connect_response(request.httpversion)) + self.send_response(make_connect_response(request.httpversion)) layer = self.ctx.next_layer(self) layer() @@ -334,44 +403,33 @@ class HttpLayer(Layer): return close_connection def send_response_to_client(self, flow): - if not flow.response.stream: + if not (self.supports_streaming and flow.response.stream): # no streaming: # we already received the full response from the server and can # send it to the client straight away. - self.send_to_client(flow.response) + self.send_response(flow.response) else: # streaming: - # First send the headers and then transfer the response - # incrementally: - h = self.client_protocol._assemble_response_first_line(flow.response) - self.send_to_client(h + "\r\n") - h = self.client_protocol._assemble_response_headers(flow.response, preserve_transfer_encoding=True) - self.send_to_client(h + "\r\n") - - chunks = self.client_protocol.read_http_body_chunked( - flow.response.headers, - self.config.body_size_limit, - flow.request.method, - flow.response.code, - False, - 4096 + # First send the headers and then transfer the response incrementally + self.send_response_headers(flow.response) + chunks = self.read_response_body( + flow.response.headers, + flow.request.method, + flow.response.code, + max_chunk_size=4096 ) - if callable(flow.response.stream): chunks = flow.response.stream(chunks) - - for chunk in chunks: - for part in chunk: - # TODO: That's going to fail. - self.send_to_client(part) - self.client_conn.wfile.flush() - + self.send_response_body(flow.response, chunks) flow.response.timestamp_end = utils.timestamp() def get_response_from_server(self, flow): def get_response(): - self.send_to_server(flow.request) - flow.response = self.read_from_server(flow.request.method) + self.send_request(flow.request) + if self.supports_streaming: + flow.response = self.read_response_headers() + else: + flow.response = self.read_response() try: get_response() @@ -400,18 +458,15 @@ class HttpLayer(Layer): if flow is None or flow == KILL: raise Kill() - if isinstance(self.ctx, Http2Layer): - pass # streaming is not implemented for http2 yet. - elif flow.response.stream: - flow.response.content = CONTENT_MISSING - else: - flow.response.content = self.server_protocol.read_http_body( - flow.response.headers, - self.config.body_size_limit, - flow.request.method, - flow.response.code, - False - ) + if self.supports_streaming: + if flow.response.stream: + flow.response.content = CONTENT_MISSING + else: + flow.response.content = "".join(self.read_response_body( + flow.response.headers, + flow.request.method, + flow.response.code + )) flow.response.timestamp_end = utils.timestamp() # no further manipulation of self.server_conn beyond this point @@ -480,14 +535,14 @@ class HttpLayer(Layer): if self.server_conn.tls_established: self.reconnect() - self.send_to_server(make_connect_request(address)) + self.send_request(make_connect_request(address)) tls_layer = TlsLayer(self, False, True) tls_layer._establish_tls_with_server() """ def validate_request(self, request): if request.form_in == "absolute" and request.scheme != "http": - self.send_to_client(make_error_response(400, "Invalid request scheme: %s" % request.scheme)) + self.send_response(make_error_response(400, "Invalid request scheme: %s" % request.scheme)) raise HttpException("Invalid request scheme: %s" % request.scheme) expected_request_forms = { @@ -501,7 +556,7 @@ class HttpLayer(Layer): err_message = "Invalid HTTP request form (expected: %s, got: %s)" % ( " or ".join(allowed_request_forms), request.form_in ) - self.send_to_client(make_error_response(400, err_message)) + self.send_response(make_error_response(400, err_message)) raise HttpException(err_message) if self.mode == "regular": @@ -512,7 +567,7 @@ class HttpLayer(Layer): if self.config.authenticator.authenticate(request.headers): self.config.authenticator.clean(request.headers) else: - self.send_to_client(make_error_response( + self.send_response(make_error_response( 407, "Proxy Authentication Required", odict.ODictCaseless([[k,v] for k, v in self.config.authenticator.auth_challenge_headers().items()]) @@ -552,10 +607,7 @@ class RequestReplayThread(threading.Thread): if not self.flow.response: # In all modes, we directly connect to the server displayed if self.config.mode == "upstream": - # FIXME - server_address = self.config.mode.get_upstream_server( - self.flow.client_conn - )[2:] + server_address = self.config.upstream_server.address server = ServerConnection(server_address) server.connect() protocol = HTTP1Protocol(server) |