aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libmproxy/protocol/__init__.py8
-rw-r--r--libmproxy/protocol/http.py49
-rw-r--r--libmproxy/stateobject.py12
-rw-r--r--test/test_flow.py8
-rw-r--r--test/test_protocol_http.py86
-rw-r--r--test/test_script.py3
-rw-r--r--test/tutils.py12
7 files changed, 147 insertions, 31 deletions
diff --git a/libmproxy/protocol/__init__.py b/libmproxy/protocol/__init__.py
index 4c72ad48..580d693c 100644
--- a/libmproxy/protocol/__init__.py
+++ b/libmproxy/protocol/__init__.py
@@ -15,19 +15,19 @@ class ProtocolHandler(object):
self.c = c
"""@type: libmproxy.proxy.ConnectionHandler"""
- def handle_messages(self): # pragma: nocover
+ def handle_messages(self):
"""
This method gets called if a client connection has been made. Depending on the proxy settings,
a server connection might already exist as well.
"""
- raise NotImplementedError
+ raise NotImplementedError # pragma: nocover
- def handle_error(self, error): # pragma: nocover
+ def handle_error(self, error):
"""
This method gets called should there be an uncaught exception during the connection.
This might happen outside of handle_messages, e.g. if the initial SSL handshake fails in transparent mode.
"""
- raise error
+ raise error # pragma: nocover
class TemporaryServerChangeMixin(object):
diff --git a/libmproxy/protocol/http.py b/libmproxy/protocol/http.py
index 21624513..8dcc21d7 100644
--- a/libmproxy/protocol/http.py
+++ b/libmproxy/protocol/http.py
@@ -56,6 +56,7 @@ class HTTPMessage(stateobject.SimpleStateObject):
def __init__(self, httpversion, headers, content, timestamp_start=None, timestamp_end=None):
self.httpversion = httpversion
self.headers = headers
+ """@type: ODictCaseless"""
self.content = content
self.timestamp_start = timestamp_start
self.timestamp_end = timestamp_end
@@ -142,31 +143,31 @@ class HTTPMessage(stateobject.SimpleStateObject):
"""
Parse an HTTP message from a file stream
"""
- raise NotImplementedError
+ raise NotImplementedError # pragma: nocover
def _assemble_first_line(self):
"""
Returns the assembled request/response line
"""
- raise NotImplementedError
+ raise NotImplementedError # pragma: nocover
def _assemble_headers(self):
"""
Returns the assembled headers
"""
- raise NotImplementedError
+ raise NotImplementedError # pragma: nocover
def _assemble_head(self):
"""
Returns the assembled request/response line plus headers
"""
- raise NotImplementedError
+ raise NotImplementedError # pragma: nocover
def _assemble(self):
"""
Returns the assembled request/response
"""
- raise NotImplementedError
+ raise NotImplementedError # pragma: nocover
class HTTPRequest(HTTPMessage):
@@ -253,9 +254,15 @@ class HTTPRequest(HTTPMessage):
httpversion, host, port, scheme, method, path, headers, content, timestamp_start, timestamp_end \
= None, None, None, None, None, None, None, None, None, None
- rfile.reset_timestamps()
+ if hasattr(rfile, "reset_timestamps"):
+ rfile.reset_timestamps()
+
request_line = get_line(rfile)
- timestamp_start = rfile.first_byte_timestamp
+
+ if hasattr(rfile, "first_byte_timestamp"):
+ timestamp_start = rfile.first_byte_timestamp
+ else:
+ timestamp_start = utils.timestamp()
request_line_parts = http.parse_init(request_line)
if not request_line_parts:
@@ -597,14 +604,21 @@ class HTTPResponse(HTTPMessage):
Parse an HTTP response from a file stream
"""
if not include_content:
- raise NotImplementedError
+ raise NotImplementedError # pragma: nocover
+
+ if hasattr(rfile, "reset_timestamps"):
+ rfile.reset_timestamps()
- rfile.reset_timestamps()
httpversion, code, msg, headers, content = http.read_response(
rfile,
request_method,
body_size_limit)
- timestamp_start = rfile.first_byte_timestamp
+
+ if hasattr(rfile, "first_byte_timestamp"):
+ timestamp_start = rfile.first_byte_timestamp
+ else:
+ timestamp_start = utils.timestamp()
+
timestamp_end = utils.timestamp()
return HTTPResponse(httpversion, code, msg, headers, content, timestamp_start, timestamp_end)
@@ -661,7 +675,7 @@ class HTTPResponse(HTTPMessage):
# This can happen when the expires tag is invalid.
# reddit.com sends a an expires tag like this: "Thu, 31 Dec
# 2037 23:59:59 GMT", which is valid RFC 1123, but not
- # strictly correct according tot he cookie spec. Browsers
+ # strictly correct according to the cookie spec. Browsers
# appear to parse this tolerantly - maybe we should too.
# For now, we just ignore this.
del i["expires"]
@@ -824,7 +838,7 @@ class HttpAuthenticationError(Exception):
def __init__(self, auth_headers=None):
self.auth_headers = auth_headers
- def __str__(self): # pragma: nocover
+ def __str__(self):
return "HttpAuthenticationError"
@@ -907,9 +921,12 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
def handle_error(self, error, flow=None):
code, message, headers = None, None, None
if isinstance(error, HttpAuthenticationError):
- code, message, headers = 407, "Proxy Authentication Required", error.auth_headers
+ code = 407
+ message = "Proxy Authentication Required"
+ headers = error.auth_headers
elif isinstance(error, (http.HttpError, ProxyError)):
- code, message = error.code, error.msg
+ code = error.code
+ message = error.msg
elif isinstance(error, tcp.NetLibError):
code = 502
message = error.message or error.__class__
@@ -917,7 +934,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
if code:
err = "%s: %s" % (code, message)
else:
- err = message
+ err = error.__class__
self.c.log("error: %s" % err)
@@ -1010,7 +1027,7 @@ class HTTPHandler(ProtocolHandler, TemporaryServerChangeMixin):
self.ssl_upgrade() # raises ConnectionTypeChange exception
if self.c.mode == "regular":
- if request.form_in == "authority":
+ if request.form_in == "authority": # forward mode
pass
elif request.form_in == "absolute":
if request.scheme != "http":
diff --git a/libmproxy/stateobject.py b/libmproxy/stateobject.py
index 7d91519d..a752999d 100644
--- a/libmproxy/stateobject.py
+++ b/libmproxy/stateobject.py
@@ -1,13 +1,13 @@
class StateObject(object):
- def _get_state(self): # pragma: nocover
- raise NotImplementedError
+ def _get_state(self):
+ raise NotImplementedError # pragma: nocover
- def _load_state(self, state): # pragma: nocover
- raise NotImplementedError
+ def _load_state(self, state):
+ raise NotImplementedError # pragma: nocover
@classmethod
- def _from_state(cls, state): # pragma: nocover
- raise NotImplementedError
+ def _from_state(cls, state):
+ raise NotImplementedError # pragma: nocover
# Usually, this function roughly equals to the following code:
# f = cls()
# f._load_state(state)
diff --git a/test/test_flow.py b/test/test_flow.py
index 006dfe51..65e153ea 100644
--- a/test/test_flow.py
+++ b/test/test_flow.py
@@ -1047,7 +1047,7 @@ class TestResponse:
r = tutils.tresp()
r.headers["content-encoding"] = ["identity"]
r.content = "falafel"
- r.decode()
+ assert r.decode()
assert not r.headers["content-encoding"]
assert r.content == "falafel"
@@ -1064,10 +1064,14 @@ class TestResponse:
r.encode("gzip")
assert r.headers["content-encoding"] == ["gzip"]
assert r.content != "falafel"
- r.decode()
+ assert r.decode()
assert not r.headers["content-encoding"]
assert r.content == "falafel"
+ r.headers["content-encoding"] = ["gzip"]
+ assert not r.decode()
+ assert r.content == "falafel"
+
def test_header_size(self):
r = tutils.tresp()
result = len(r._assemble_headers())
diff --git a/test/test_protocol_http.py b/test/test_protocol_http.py
new file mode 100644
index 00000000..9e043049
--- /dev/null
+++ b/test/test_protocol_http.py
@@ -0,0 +1,86 @@
+from libmproxy import proxy # FIXME: Remove
+from libmproxy.protocol.http import *
+from cStringIO import StringIO
+import tutils
+
+
+def test_HttpAuthenticationError():
+ x = HttpAuthenticationError({"foo": "bar"})
+ assert str(x)
+ assert "foo" in x.auth_headers
+
+
+def test_stripped_chunked_encoding_no_content():
+ """
+ https://github.com/mitmproxy/mitmproxy/issues/186
+ """
+ r = tutils.tresp(content="")
+ r.headers["Transfer-Encoding"] = ["chunked"]
+ assert "Content-Length" in r._assemble_headers()
+
+ r = tutils.treq(content="")
+ r.headers["Transfer-Encoding"] = ["chunked"]
+ assert "Content-Length" in r._assemble_headers()
+
+
+class TestHTTPRequest:
+ def test_asterisk_form(self):
+ s = StringIO("OPTIONS * HTTP/1.1")
+ f = tutils.tflow_noreq()
+ f.request = HTTPRequest.from_stream(s)
+ assert f.request.form_in == "asterisk"
+ x = f.request._assemble()
+ assert f.request._assemble() == "OPTIONS * HTTP/1.1\r\nHost: address:22\r\n\r\n"
+
+ def test_origin_form(self):
+ s = StringIO("GET /foo\xff HTTP/1.1")
+ tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s)
+
+ def test_authority_form(self):
+ s = StringIO("CONNECT oops-no-port.com HTTP/1.1")
+ tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s)
+ s = StringIO("CONNECT address:22 HTTP/1.1")
+ r = HTTPRequest.from_stream(s)
+ assert r._assemble() == "CONNECT address:22 HTTP/1.1\r\nHost: address:22\r\n\r\n"
+
+
+ def test_absolute_form(self):
+ s = StringIO("GET oops-no-protocol.com HTTP/1.1")
+ tutils.raises("Bad HTTP request line", HTTPRequest.from_stream, s)
+ s = StringIO("GET http://address:22/ HTTP/1.1")
+ r = HTTPRequest.from_stream(s)
+ assert r._assemble() == "GET http://address:22/ HTTP/1.1\r\nHost: address:22\r\n\r\n"
+
+ def test_assemble_unknown_form(self):
+ r = tutils.treq()
+ tutils.raises("Invalid request form", r._assemble, "antiauthority")
+
+
+ def test_set_url(self):
+ r = tutils.treq_absolute()
+ r.set_url("https://otheraddress:42/ORLY")
+ assert r.scheme == "https"
+ assert r.host == "otheraddress"
+ assert r.port == 42
+ assert r.path == "/ORLY"
+
+
+class TestHTTPResponse:
+ def test_read_from_stringio(self):
+ _s = "HTTP/1.1 200 OK\r\n" \
+ "Content-Length: 7\r\n" \
+ "\r\n"\
+ "content\r\n" \
+ "HTTP/1.1 204 OK\r\n" \
+ "\r\n"
+ s = StringIO(_s)
+ r = HTTPResponse.from_stream(s, "GET")
+ assert r.code == 200
+ assert r.content == "content"
+ assert HTTPResponse.from_stream(s, "GET").code == 204
+
+ s = StringIO(_s)
+ r = HTTPResponse.from_stream(s, "HEAD")
+ assert r.code == 200
+ assert r.content == ""
+ tutils.raises("Invalid server response: 'content", HTTPResponse.from_stream, s, "GET") \ No newline at end of file
diff --git a/test/test_script.py b/test/test_script.py
index 2e48081b..2999a910 100644
--- a/test/test_script.py
+++ b/test/test_script.py
@@ -75,9 +75,6 @@ class TestScript:
# Two instantiations
assert m.call_count == 2
assert (time.time() - t_start) < 0.09
- time.sleep(0.3 - (time.time() - t_start))
- # Plus two invocations
- assert m.call_count == 4
def test_concurrent2(self):
s = flow.State()
diff --git a/test/tutils.py b/test/tutils.py
index ad2960d9..75fb7c0b 100644
--- a/test/tutils.py
+++ b/test/tutils.py
@@ -39,6 +39,14 @@ def tserver_conn():
return c
+def treq_absolute(conn=None, content="content"):
+ r = treq(conn, content)
+ r.form_in = r.form_out = "absolute"
+ r.host = "address"
+ r.port = 22
+ r.scheme = "http"
+ return r
+
def treq(conn=None, content="content"):
if not conn:
conn = tclient_conn()
@@ -78,6 +86,10 @@ def terr(req=None):
f.error.reply = controller.DummyReply()
return f.error
+def tflow_noreq():
+ f = tflow()
+ f.request = None
+ return f
def tflow(req=None):
if not req: