diff options
-rw-r--r-- | libmproxy/protocol/__init__.py | 8 | ||||
-rw-r--r-- | libmproxy/protocol/http.py | 49 | ||||
-rw-r--r-- | libmproxy/stateobject.py | 12 | ||||
-rw-r--r-- | test/test_flow.py | 8 | ||||
-rw-r--r-- | test/test_protocol_http.py | 86 | ||||
-rw-r--r-- | test/test_script.py | 3 | ||||
-rw-r--r-- | test/tutils.py | 12 |
7 files changed, 147 insertions, 31 deletions
diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py index 4c72ad48..580d693c 100644 --- a/libmproxy/protocol/__init__.py +++ b/libmproxy/protocol/__init__.py @@ -15,19 +15,19 @@ class ProtocolHandler(object): self.c = c """@type: libmproxy.proxy.ConnectionHandler""" - def handle_messages(self): # pragma: nocover + def handle_messages(self): """ This method gets called if a client connection has been made. Depending on the proxy settings, a server connection might already exist as well. """ - raise NotImplementedError + raise NotImplementedError # pragma: nocover - def handle_error(self, error): # pragma: nocover + def handle_error(self, error): """ This method gets called should there be an uncaught exception during the connection. This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode. """ - raise error + raise error # pragma: nocover class TemporaryServerChangeMixin(object): diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py index 21624513..8dcc21d7 100644 --- a/libmproxy/protocol/http.py +++ b/libmproxy/protocol/http.py @@ -56,6 +56,7 @@ class HTTPMessage(stateobject.SimpleStateObject): def __init__(self, httpversion, headers, content, timestamp_start=None, timestamp_end=None): self.httpversion = httpversion self.headers = headers + """@type: ODictCaseless""" self.content = content self.timestamp_start = timestamp_start self.timestamp_end = timestamp_end @@ -142,31 +143,31 @@ class HTTPMessage(stateobject.SimpleStateObject): """ Parse an HTTP message from a file stream """ - raise NotImplementedError + raise NotImplementedError # pragma: nocover def _assemble_first_line(self): """ Returns the assembled request/response line """ - raise NotImplementedError + raise NotImplementedError # pragma: nocover def _assemble_headers(self): """ Returns the assembled headers """ - raise NotImplementedError + raise NotImplementedError # pragma: nocover def _assemble_head(self): """ Returns the assembled request/response line plus headers """ - raise NotImplementedError + raise NotImplementedError # pragma: nocover def _assemble(self): """ Returns the assembled request/response """ - raise NotImplementedError + raise NotImplementedError # pragma: nocover class HTTPRequest(HTTPMessage): @@ -253,9 +254,15 @@ class HTTPRequest(HTTPMessage): httpversion, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \ = None, None, None, None, None, None, None, None, None, None - rfile.reset_timestamps() + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() + request_line = get_line(rfile) - timestamp_start = rfile.first_byte_timestamp + + if hasattr(rfile, "first_byte_timestamp"): + timestamp_start = rfile.first_byte_timestamp + else: + timestamp_start = utils.timestamp() request_line_parts = http.parse_init(request_line) if not request_line_parts: @@ -597,14 +604,21 @@ class HTTPResponse(HTTPMessage): Parse an HTTP response from a file stream """ if not include_content: - raise NotImplementedError + raise NotImplementedError # pragma: nocover + + if hasattr(rfile, "reset_timestamps"): + rfile.reset_timestamps() - rfile.reset_timestamps() httpversion, code, msg, headers, content = http.read_response( rfile, request_method, body_size_limit) - timestamp_start = rfile.first_byte_timestamp + + if hasattr(rfile, "first_byte_timestamp"): + timestamp_start = rfile.first_byte_timestamp + else: + timestamp_start = utils.timestamp() + timestamp_end = utils.timestamp() return HTTPResponse(httpversion, code, msg, headers, content, timestamp_start, timestamp_end) @@ -661,7 +675,7 @@ class HTTPResponse(HTTPMessage): # This can happen when the expires tag is invalid. # reddit.com sends a an expires tag like this: "Thu, 31 Dec # 2037 23:59:59 GMT", which is valid RFC 1123, but not - # strictly correct according tot he cookie spec. Browsers + # strictly correct according to the cookie spec. Browsers # appear to parse this tolerantly - maybe we should too. # For now, we just ignore this. del i["expires"] @@ -824,7 +838,7 @@ class HttpAuthenticationError(Exception): def __init__(self, auth_headers=None): self.auth_headers = auth_headers - def __str__(self): # pragma: nocover + def __str__(self): return "HttpAuthenticationError" @@ -907,9 +921,12 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin): def handle_error(self, error, flow=None): code, message, headers = None, None, None if isinstance(error, HttpAuthenticationError): - code, message, headers = 407, "Proxy Authentication Required", error.auth_headers + code = 407 + message = "Proxy Authentication Required" + headers = error.auth_headers elif isinstance(error, (http.HttpError, ProxyError)): - code, message = error.code, error.msg + code = error.code + message = error.msg elif isinstance(error, tcp.NetLibError): code = 502 message = error.message or error.__class__ @@ -917,7 +934,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin): if code: err = "%s: %s" % (code, message) else: - err = message + err = error.__class__ self.c.log("error: %s" % err) @@ -1010,7 +1027,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin): self.ssl_upgrade() # raises ConnectionTypeChange exception if self.c.mode == "regular": - if request.form_in == "authority": + if request.form_in == "authority": # forward mode pass elif request.form_in == "absolute": if request.scheme != "http": diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py index 7d91519d..a752999d 100644 --- a/libmproxy/stateobject.py +++ b/libmproxy/stateobject.py @@ -1,13 +1,13 @@ class StateObject(object): - def _get_state(self): # pragma: nocover - raise NotImplementedError + def _get_state(self): + raise NotImplementedError # pragma: nocover - def _load_state(self, state): # pragma: nocover - raise NotImplementedError + def _load_state(self, state): + raise NotImplementedError # pragma: nocover @classmethod - def _from_state(cls, state): # pragma: nocover - raise NotImplementedError + def _from_state(cls, state): + raise NotImplementedError # pragma: nocover # Usually, this function roughly equals to the following code: # f = cls() # f._load_state(state) diff --git a/test/test_flow.py b/test/test_flow.py index 006dfe51..65e153ea 100644 --- a/test/test_flow.py +++ b/test/test_flow.py @@ -1047,7 +1047,7 @@ class TestResponse: r = tutils.tresp() r.headers["content-encoding"] = ["identity"] r.content = "falafel" - r.decode() + assert r.decode() assert not r.headers["content-encoding"] assert r.content == "falafel" @@ -1064,10 +1064,14 @@ class TestResponse: r.encode("gzip") assert r.headers["content-encoding"] == ["gzip"] assert r.content != "falafel" - r.decode() + assert r.decode() assert not r.headers["content-encoding"] assert r.content == "falafel" + r.headers["content-encoding"] = ["gzip"] + assert not r.decode() + assert r.content == "falafel" + def test_header_size(self): r = tutils.tresp() result = len(r._assemble_headers()) diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py new file mode 100644 index 00000000..9e043049 --- /dev/null +++ b/test/test_protocol_http.py @@ -0,0 +1,86 @@ +from libmproxy import proxy # FIXME: Remove +from libmproxy.protocol.http import * +from cStringIO import StringIO +import tutils + + +def test_HttpAuthenticationError(): + x = HttpAuthenticationError({"foo": "bar"}) + assert str(x) + assert "foo" in x.auth_headers + + +def test_stripped_chunked_encoding_no_content(): + """ + https://github.com/mitmproxy/mitmproxy/issues/186 + """ + r = tutils.tresp(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in r._assemble_headers() + + r = tutils.treq(content="") + r.headers["Transfer-Encoding"] = ["chunked"] + assert "Content-Length" in r._assemble_headers() + + +class TestHTTPRequest: + def test_asterisk_form(self): + s = StringIO("OPTIONS * HTTP/1.1") + f = tutils.tflow_noreq() + f.request = HTTPRequest.from_stream(s) + assert f.request.form_in == "asterisk" + x = f.request._assemble() + assert f.request._assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n" + + def test_origin_form(self): + s = StringIO("GET /foo\xff HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + + def test_authority_form(self): + s = StringIO("CONNECT oops-no-port.com HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + s = StringIO("CONNECT address:22 HTTP/1.1") + r = HTTPRequest.from_stream(s) + assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n" + + + def test_absolute_form(self): + s = StringIO("GET oops-no-protocol.com HTTP/1.1") + tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s) + s = StringIO("GET http://address:22/ HTTP/1.1") + r = HTTPRequest.from_stream(s) + assert r._assemble() == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\n\r\n" + + def test_assemble_unknown_form(self): + r = tutils.treq() + tutils.raises("Invalid request form", r._assemble, "antiauthority") + + + def test_set_url(self): + r = tutils.treq_absolute() + r.set_url("https://otheraddress:42/ORLY") + assert r.scheme == "https" + assert r.host == "otheraddress" + assert r.port == 42 + assert r.path == "/ORLY" + + +class TestHTTPResponse: + def test_read_from_stringio(self): + _s = "HTTP/1.1 200 OK\r\n" \ + "Content-Length: 7\r\n" \ + "\r\n"\ + "content\r\n" \ + "HTTP/1.1 204 OK\r\n" \ + "\r\n" + s = StringIO(_s) + r = HTTPResponse.from_stream(s, "GET") + assert r.code == 200 + assert r.content == "content" + assert HTTPResponse.from_stream(s, "GET").code == 204 + + s = StringIO(_s) + r = HTTPResponse.from_stream(s, "HEAD") + assert r.code == 200 + assert r.content == "" + tutils.raises("Invalid server response: 'content", HTTPResponse.from_stream, s, "GET")
\ No newline at end of file diff --git a/test/test_script.py b/test/test_script.py index 2e48081b..2999a910 100644 --- a/test/test_script.py +++ b/test/test_script.py @@ -75,9 +75,6 @@ class TestScript: # Two instantiations assert m.call_count == 2 assert (time.time() - t_start) < 0.09 - time.sleep(0.3 - (time.time() - t_start)) - # Plus two invocations - assert m.call_count == 4 def test_concurrent2(self): s = flow.State() diff --git a/test/tutils.py b/test/tutils.py index ad2960d9..75fb7c0b 100644 --- a/test/tutils.py +++ b/test/tutils.py @@ -39,6 +39,14 @@ def tserver_conn(): return c +def treq_absolute(conn=None, content="content"): + r = treq(conn, content) + r.form_in = r.form_out = "absolute" + r.host = "address" + r.port = 22 + r.scheme = "http" + return r + def treq(conn=None, content="content"): if not conn: conn = tclient_conn() @@ -78,6 +86,10 @@ def terr(req=None): f.error.reply = controller.DummyReply() return f.error +def tflow_noreq(): + f = tflow() + f.request = None + return f def tflow(req=None): if not req: |