aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--.gitignore3
-rw-r--r--libmproxy/flow.py157
-rw-r--r--libmproxy/netlib.py182
-rw-r--r--libmproxy/protocol.py220
-rw-r--r--libmproxy/proxy.py29
-rw-r--r--libmproxy/utils.py5
-rw-r--r--setup.py2
-rw-r--r--test/test_netlib.py93
-rw-r--r--test/test_protocol.py163
9 files changed, 23 insertions, 831 deletions
diff --git a/.gitignore b/.gitignore
index 78b1cdb5..b88b179b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,5 @@ MANIFEST
*.swo
mitmproxyc
mitmdumpc
-mitmplaybackc
-mitmrecordc
+netlib
.coverage
diff --git a/libmproxy/flow.py b/libmproxy/flow.py
index a737057e..f9a9a75d 100644
--- a/libmproxy/flow.py
+++ b/libmproxy/flow.py
@@ -21,11 +21,15 @@ import hashlib, Cookie, cookielib, copy, re, urlparse
import time
import tnetstring, filt, script, utils, encoding, proxy
from email.utils import parsedate_tz, formatdate, mktime_tz
-import controller, version, certutils, protocol
+from netlib import odict, protocol
+import controller, version, certutils
HDR_FORM_URLENCODED = "application/x-www-form-urlencoded"
CONTENT_MISSING = 0
+ODict = odict.ODict
+ODictCaseless = odict.ODictCaseless
+
class ReplaceHooks:
def __init__(self):
@@ -117,157 +121,6 @@ class ScriptContext:
self._master.replay_request(f)
-class ODict:
- """
- A dictionary-like object for managing ordered (key, value) data.
- """
- def __init__(self, lst=None):
- self.lst = lst or []
-
- def _kconv(self, s):
- return s
-
- def __eq__(self, other):
- return self.lst == other.lst
-
- def __getitem__(self, k):
- """
- Returns a list of values matching key.
- """
- ret = []
- k = self._kconv(k)
- for i in self.lst:
- if self._kconv(i[0]) == k:
- ret.append(i[1])
- return ret
-
- def _filter_lst(self, k, lst):
- k = self._kconv(k)
- new = []
- for i in lst:
- if self._kconv(i[0]) != k:
- new.append(i)
- return new
-
- def __len__(self):
- """
- Total number of (key, value) pairs.
- """
- return len(self.lst)
-
- def __setitem__(self, k, valuelist):
- """
- Sets the values for key k. If there are existing values for this
- key, they are cleared.
- """
- if isinstance(valuelist, basestring):
- raise ValueError("ODict valuelist should be lists.")
- new = self._filter_lst(k, self.lst)
- for i in valuelist:
- new.append([k, i])
- self.lst = new
-
- def __delitem__(self, k):
- """
- Delete all items matching k.
- """
- self.lst = self._filter_lst(k, self.lst)
-
- def __contains__(self, k):
- for i in self.lst:
- if self._kconv(i[0]) == self._kconv(k):
- return True
- return False
-
- def add(self, key, value):
- self.lst.append([key, str(value)])
-
- def get(self, k, d=None):
- if k in self:
- return self[k]
- else:
- return d
-
- def items(self):
- return self.lst[:]
-
- def _get_state(self):
- return [tuple(i) for i in self.lst]
-
- @classmethod
- def _from_state(klass, state):
- return klass([list(i) for i in state])
-
- def copy(self):
- """
- Returns a copy of this object.
- """
- lst = copy.deepcopy(self.lst)
- return self.__class__(lst)
-
- def __repr__(self):
- elements = []
- for itm in self.lst:
- elements.append(itm[0] + ": " + itm[1])
- elements.append("")
- return "\r\n".join(elements)
-
- def in_any(self, key, value, caseless=False):
- """
- Do any of the values matching key contain value?
-
- If caseless is true, value comparison is case-insensitive.
- """
- if caseless:
- value = value.lower()
- for i in self[key]:
- if caseless:
- i = i.lower()
- if value in i:
- return True
- return False
-
- def match_re(self, expr):
- """
- Match the regular expression against each (key, value) pair. For
- each pair a string of the following format is matched against:
-
- "key: value"
- """
- for k, v in self.lst:
- s = "%s: %s"%(k, v)
- if re.search(expr, s):
- return True
- return False
-
- def replace(self, pattern, repl, *args, **kwargs):
- """
- Replaces a regular expression pattern with repl in both keys and
- values. Encoded content will be decoded before replacement, and
- re-encoded afterwards.
-
- Returns the number of replacements made.
- """
- nlst, count = [], 0
- for i in self.lst:
- k, c = utils.safe_subn(pattern, repl, i[0], *args, **kwargs)
- count += c
- v, c = utils.safe_subn(pattern, repl, i[1], *args, **kwargs)
- count += c
- nlst.append([k, v])
- self.lst = nlst
- return count
-
-
-class ODictCaseless(ODict):
- """
- A variant of ODict with "caseless" keys. This version _preserves_ key
- case, but does not consider case when setting or getting items.
- """
- def _kconv(self, s):
- return s.lower()
-
-
class decoded(object):
"""
diff --git a/libmproxy/netlib.py b/libmproxy/netlib.py
deleted file mode 100644
index 08ccba09..00000000
--- a/libmproxy/netlib.py
+++ /dev/null
@@ -1,182 +0,0 @@
-import select, socket, threading, traceback, sys
-from OpenSSL import SSL
-
-
-class NetLibError(Exception): pass
-
-
-class FileLike:
- def __init__(self, o):
- self.o = o
-
- def __getattr__(self, attr):
- return getattr(self.o, attr)
-
- def flush(self):
- pass
-
- def read(self, length):
- result = ''
- while len(result) < length:
- try:
- data = self.o.read(length)
- except SSL.ZeroReturnError:
- break
- if not data:
- break
- result += data
- return result
-
- def write(self, v):
- self.o.sendall(v)
-
- def readline(self, size = None):
- result = ''
- bytes_read = 0
- while True:
- if size is not None and bytes_read >= size:
- break
- ch = self.read(1)
- bytes_read += 1
- if not ch:
- break
- else:
- result += ch
- if ch == '\n':
- break
- return result
-
-
-class TCPClient:
- def __init__(self, ssl, host, port, clientcert):
- self.ssl, self.host, self.port, self.clientcert = ssl, host, port, clientcert
- self.connection, self.rfile, self.wfile = None, None, None
- self.cert = None
- self.connect()
-
- def connect(self):
- try:
- addr = socket.gethostbyname(self.host)
- server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- if self.ssl:
- context = SSL.Context(SSL.SSLv23_METHOD)
- if self.clientcert:
- context.use_certificate_file(self.clientcert)
- server = SSL.Connection(context, server)
- server.connect((addr, self.port))
- if self.ssl:
- self.cert = server.get_peer_certificate()
- self.rfile, self.wfile = FileLike(server), FileLike(server)
- else:
- self.rfile, self.wfile = server.makefile('rb'), server.makefile('wb')
- except socket.error, err:
- raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
- self.connection = server
-
-
-class BaseHandler:
- rbufsize = -1
- wbufsize = 0
- def __init__(self, connection, client_address, server):
- self.connection = connection
- self.rfile = self.connection.makefile('rb', self.rbufsize)
- self.wfile = self.connection.makefile('wb', self.wbufsize)
-
- self.client_address = client_address
- self.server = server
- self.handle()
- self.finish()
-
- def convert_to_ssl(self, cert, key):
- ctx = SSL.Context(SSL.SSLv23_METHOD)
- ctx.use_privatekey_file(key)
- ctx.use_certificate_file(cert)
- self.connection = SSL.Connection(ctx, self.connection)
- self.connection.set_accept_state()
- self.rfile = FileLike(self.connection)
- self.wfile = FileLike(self.connection)
-
- def finish(self):
- try:
- if not getattr(self.wfile, "closed", False):
- self.wfile.flush()
- self.connection.close()
- self.wfile.close()
- self.rfile.close()
- except IOError: # pragma: no cover
- pass
-
- def handle(self): # pragma: no cover
- raise NotImplementedError
-
-
-class TCPServer:
- request_queue_size = 20
- def __init__(self, server_address):
- self.server_address = server_address
- self.__is_shut_down = threading.Event()
- self.__shutdown_request = False
- self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
- self.socket.bind(self.server_address)
- self.server_address = self.socket.getsockname()
- self.socket.listen(self.request_queue_size)
- self.port = self.socket.getsockname()[1]
-
- def request_thread(self, request, client_address):
- try:
- self.handle_connection(request, client_address)
- request.close()
- except:
- self.handle_error(request, client_address)
- request.close()
-
- def serve_forever(self, poll_interval=0.5):
- self.__is_shut_down.clear()
- try:
- while not self.__shutdown_request:
- r, w, e = select.select([self.socket], [], [], poll_interval)
- if self.socket in r:
- try:
- request, client_address = self.socket.accept()
- except socket.error:
- return
- try:
- t = threading.Thread(
- target = self.request_thread,
- args = (request, client_address)
- )
- t.setDaemon(1)
- t.start()
- except:
- self.handle_error(request, client_address)
- request.close()
- finally:
- self.__shutdown_request = False
- self.__is_shut_down.set()
-
- def shutdown(self):
- self.__shutdown_request = True
- self.__is_shut_down.wait()
- self.handle_shutdown()
-
- def handle_error(self, request, client_address, fp=sys.stderr):
- """
- Called when handle_connection raises an exception.
- """
- print >> fp, '-'*40
- print >> fp, "Error processing of request from %s:%s"%client_address
- print >> fp, traceback.format_exc()
- print >> fp, '-'*40
-
- def handle_connection(self, request, client_address): # pragma: no cover
- """
- Called after client connection.
- """
- raise NotImplementedError
-
- def handle_shutdown(self):
- """
- Called after server shutdown.
- """
- pass
diff --git a/libmproxy/protocol.py b/libmproxy/protocol.py
deleted file mode 100644
index 547bff9e..00000000
--- a/libmproxy/protocol.py
+++ /dev/null
@@ -1,220 +0,0 @@
-import string, urlparse
-
-class ProtocolError(Exception):
- def __init__(self, code, msg):
- self.code, self.msg = code, msg
-
- def __str__(self):
- return "ProtocolError(%s, %s)"%(self.code, self.msg)
-
-
-def parse_url(url):
- """
- Returns a (scheme, host, port, path) tuple, or None on error.
- """
- scheme, netloc, path, params, query, fragment = urlparse.urlparse(url)
- if not scheme:
- return None
- if ':' in netloc:
- host, port = string.rsplit(netloc, ':', maxsplit=1)
- try:
- port = int(port)
- except ValueError:
- return None
- else:
- host = netloc
- if scheme == "https":
- port = 443
- else:
- port = 80
- path = urlparse.urlunparse(('', '', path, params, query, fragment))
- if not path.startswith("/"):
- path = "/" + path
- return scheme, host, port, path
-
-
-def read_headers(fp):
- """
- Read a set of headers from a file pointer. Stop once a blank line
- is reached. Return a ODictCaseless object.
- """
- ret = []
- name = ''
- while 1:
- line = fp.readline()
- if not line or line == '\r\n' or line == '\n':
- break
- if line[0] in ' \t':
- # continued header
- ret[-1][1] = ret[-1][1] + '\r\n ' + line.strip()
- else:
- i = line.find(':')
- # We're being liberal in what we accept, here.
- if i > 0:
- name = line[:i]
- value = line[i+1:].strip()
- ret.append([name, value])
- return ret
-
-
-def read_chunked(fp, limit):
- content = ""
- total = 0
- while 1:
- line = fp.readline(128)
- if line == "":
- raise IOError("Connection closed")
- if line == '\r\n' or line == '\n':
- continue
- try:
- length = int(line,16)
- except ValueError:
- # FIXME: Not strictly correct - this could be from the server, in which
- # case we should send a 502.
- raise ProtocolError(400, "Invalid chunked encoding length: %s"%line)
- if not length:
- break
- total += length
- if limit is not None and total > limit:
- msg = "HTTP Body too large."\
- " Limit is %s, chunked content length was at least %s"%(limit, total)
- raise ProtocolError(509, msg)
- content += fp.read(length)
- line = fp.readline(5)
- if line != '\r\n':
- raise IOError("Malformed chunked body")
- while 1:
- line = fp.readline()
- if line == "":
- raise IOError("Connection closed")
- if line == '\r\n' or line == '\n':
- break
- return content
-
-
-def has_chunked_encoding(headers):
- for i in headers["transfer-encoding"]:
- for j in i.split(","):
- if j.lower() == "chunked":
- return True
- return False
-
-
-def read_http_body(rfile, headers, all, limit):
- if has_chunked_encoding(headers):
- content = read_chunked(rfile, limit)
- elif "content-length" in headers:
- try:
- l = int(headers["content-length"][0])
- except ValueError:
- # FIXME: Not strictly correct - this could be from the server, in which
- # case we should send a 502.
- raise ProtocolError(400, "Invalid content-length header: %s"%headers["content-length"])
- if limit is not None and l > limit:
- raise ProtocolError(509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))
- content = rfile.read(l)
- elif all:
- content = rfile.read(limit if limit else None)
- else:
- content = ""
- return content
-
-
-def parse_http_protocol(s):
- if not s.startswith("HTTP/"):
- return None
- major, minor = s.split('/')[1].split('.')
- major = int(major)
- minor = int(minor)
- return major, minor
-
-
-def parse_init_connect(line):
- try:
- method, url, protocol = string.split(line)
- except ValueError:
- return None
- if method != 'CONNECT':
- return None
- try:
- host, port = url.split(":")
- except ValueError:
- return None
- port = int(port)
- httpversion = parse_http_protocol(protocol)
- if not httpversion:
- return None
- return host, port, httpversion
-
-
-def parse_init_proxy(line):
- try:
- method, url, protocol = string.split(line)
- except ValueError:
- return None
- parts = parse_url(url)
- if not parts:
- return None
- scheme, host, port, path = parts
- httpversion = parse_http_protocol(protocol)
- if not httpversion:
- return None
- return method, scheme, host, port, path, httpversion
-
-
-def parse_init_http(line):
- """
- Returns (method, url, httpversion)
- """
- try:
- method, url, protocol = string.split(line)
- except ValueError:
- return None
- if not (url.startswith("/") or url == "*"):
- return None
- httpversion = parse_http_protocol(protocol)
- if not httpversion:
- return None
- return method, url, httpversion
-
-
-def request_connection_close(httpversion, headers):
- """
- Checks the request to see if the client connection should be closed.
- """
- if "connection" in headers:
- for value in ",".join(headers['connection']).split(","):
- value = value.strip()
- if value == "close":
- return True
- elif value == "keep-alive":
- return False
- # HTTP 1.1 connections are assumed to be persistent
- if httpversion == (1, 1):
- return False
- return True
-
-
-def response_connection_close(httpversion, headers):
- """
- Checks the response to see if the client connection should be closed.
- """
- if request_connection_close(httpversion, headers):
- return True
- elif not has_chunked_encoding(headers) and "content-length" in headers:
- return True
- return False
-
-
-def read_http_body_request(rfile, wfile, headers, httpversion, limit):
- if "expect" in headers:
- # FIXME: Should be forwarded upstream
- expect = ",".join(headers['expect'])
- if expect == "100-continue" and httpversion >= (1, 1):
- wfile.write('HTTP/1.1 100 Continue\r\n')
- wfile.write('Proxy-agent: %s\r\n'%version.NAMEVERSION)
- wfile.write('\r\n')
- del headers['expect']
- return read_http_body(rfile, headers, False, limit)
-
-
diff --git a/libmproxy/proxy.py b/libmproxy/proxy.py
index 58ab7a58..04734fcb 100644
--- a/libmproxy/proxy.py
+++ b/libmproxy/proxy.py
@@ -15,8 +15,9 @@
import sys, os, string, socket, time
import shutil, tempfile, threading
import optparse, SocketServer
-import utils, flow, certutils, version, wsgi, netlib, protocol
from OpenSSL import SSL
+from netlib import odict, tcp, protocol
+import utils, flow, certutils, version, wsgi
class ProxyError(Exception):
@@ -56,18 +57,18 @@ class RequestReplayThread(threading.Thread):
except (ProxyError, protocol.ProtocolError), v:
err = flow.Error(self.flow.request, v.msg)
err._send(self.masterq)
- except netlib.NetLibError, v:
+ except tcp.NetLibError, v:
raise ProxyError(502, v)
-class ServerConnection(netlib.TCPClient):
+class ServerConnection(tcp.TCPClient):
def __init__(self, config, scheme, host, port):
clientcert = None
if config.clientcerts:
path = os.path.join(config.clientcerts, self.host) + ".pem"
if os.path.exists(clientcert):
clientcert = path
- netlib.TCPClient.__init__(
+ tcp.TCPClient.__init__(
self,
True if scheme == "https" else False,
host,
@@ -107,7 +108,7 @@ class ServerConnection(netlib.TCPClient):
code = int(code)
except ValueError:
raise ProxyError(502, "Invalid server response: %s."%line)
- headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
+ headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
if code >= 100 and code <= 199:
return self.read_response()
if request.method == "HEAD" or code == 204 or code == 304:
@@ -125,13 +126,13 @@ class ServerConnection(netlib.TCPClient):
pass
-class ProxyHandler(netlib.BaseHandler):
+class ProxyHandler(tcp.BaseHandler):
def __init__(self, config, connection, client_address, server, q):
self.mqueue = q
self.config = config
self.server_conn = None
self.proxy_connect_state = None
- netlib.BaseHandler.__init__(self, connection, client_address, server)
+ tcp.BaseHandler.__init__(self, connection, client_address, server)
def handle(self):
cc = flow.ClientConnect(self.client_address)
@@ -150,7 +151,7 @@ class ProxyHandler(netlib.BaseHandler):
if not self.server_conn:
try:
self.server_conn = ServerConnection(self.config, scheme, host, port)
- except netlib.NetLibError, v:
+ except tcp.NetLibError, v:
raise ProxyError(502, v)
def handle_request(self, cc):
@@ -243,7 +244,7 @@ class ProxyHandler(netlib.BaseHandler):
else:
scheme = "http"
method, path, httpversion = protocol.parse_init_http(line)
- headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
+ headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
content = protocol.read_http_body_request(
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
)
@@ -251,7 +252,7 @@ class ProxyHandler(netlib.BaseHandler):
elif self.config.reverse_proxy:
scheme, host, port = self.config.reverse_proxy
method, path, httpversion = protocol.parse_init_http(line)
- headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
+ headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
content = protocol.read_http_body_request(
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
)
@@ -278,14 +279,14 @@ class ProxyHandler(netlib.BaseHandler):
if self.proxy_connect_state:
host, port, httpversion = self.proxy_connect_state
method, path, httpversion = protocol.parse_init_http(line)
- headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
+ headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
content = protocol.read_http_body_request(
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
)
return flow.Request(client_conn, httpversion, host, port, "https", method, path, headers, content)
else:
method, scheme, host, port, path, httpversion = protocol.parse_init_proxy(line)
- headers = flow.ODictCaseless(protocol.read_headers(self.rfile))
+ headers = odict.ODictCaseless(protocol.read_headers(self.rfile))
content = protocol.read_http_body_request(
self.rfile, self.wfile, headers, httpversion, self.config.body_size_limit
)
@@ -317,7 +318,7 @@ class ProxyHandler(netlib.BaseHandler):
class ProxyServerError(Exception): pass
-class ProxyServer(netlib.TCPServer):
+class ProxyServer(tcp.TCPServer):
allow_reuse_address = True
bound = True
def __init__(self, config, port, address=''):
@@ -326,7 +327,7 @@ class ProxyServer(netlib.TCPServer):
"""
self.config, self.port, self.address = config, port, address
try:
- netlib.TCPServer.__init__(self, (address, port))
+ tcp.TCPServer.__init__(self, (address, port))
except socket.error, v:
raise ProxyServerError('Error starting proxy server: ' + v.strerror)
self.masterq = None
diff --git a/libmproxy/utils.py b/libmproxy/utils.py
index 989bb695..35c7a878 100644
--- a/libmproxy/utils.py
+++ b/libmproxy/utils.py
@@ -15,7 +15,7 @@
import os, datetime, urlparse, string, urllib, re
import time, functools, cgi
import json
-import protocol
+from netlib import protocol
def timestamp():
"""
@@ -294,6 +294,3 @@ def safe_subn(pattern, repl, target, *args, **kwargs):
need a better solution that is aware of the actual content ecoding.
"""
return re.subn(str(pattern), str(repl), target, *args, **kwargs)
-
-
-
diff --git a/setup.py b/setup.py
index 4070eb1b..88a39f38 100644
--- a/setup.py
+++ b/setup.py
@@ -92,5 +92,5 @@ setup(
"Topic :: Internet :: Proxy Servers",
"Topic :: Software Development :: Testing"
],
- install_requires=['urwid>=1.0', 'pyasn1>0.1.2', 'pyopenssl>=0.12', "PIL", "lxml"],
+ install_requires=["netlib", "urwid>=1.0", "pyasn1>0.1.2", "pyopenssl>=0.12", "PIL", "lxml"],
)
diff --git a/test/test_netlib.py b/test/test_netlib.py
deleted file mode 100644
index 19902d17..00000000
--- a/test/test_netlib.py
+++ /dev/null
@@ -1,93 +0,0 @@
-import cStringIO, threading, Queue
-from libmproxy import netlib
-import tutils
-
-class ServerThread(threading.Thread):
- def __init__(self, server):
- self.server = server
- threading.Thread.__init__(self)
-
- def run(self):
- self.server.serve_forever()
-
- def shutdown(self):
- self.server.shutdown()
-
-
-class ServerTestBase:
- @classmethod
- def setupAll(cls):
- cls.server = ServerThread(cls.makeserver())
- cls.server.start()
-
- @classmethod
- def teardownAll(cls):
- cls.server.shutdown()
-
-
-class THandler(netlib.BaseHandler):
- def handle(self):
- v = self.rfile.readline()
- if v.startswith("echo"):
- self.wfile.write(v)
- elif v.startswith("error"):
- raise ValueError("Testing an error.")
- self.wfile.flush()
-
-
-class TServer(netlib.TCPServer):
- def __init__(self, addr, q):
- netlib.TCPServer.__init__(self, addr)
- self.q = q
-
- def handle_connection(self, request, client_address):
- THandler(request, client_address, self)
-
- def handle_error(self, request, client_address):
- s = cStringIO.StringIO()
- netlib.TCPServer.handle_error(self, request, client_address, s)
- self.q.put(s.getvalue())
-
-
-class TestServer(ServerTestBase):
- @classmethod
- def makeserver(cls):
- cls.q = Queue.Queue()
- s = TServer(("127.0.0.1", 0), cls.q)
- cls.port = s.port
- return s
-
- def test_echo(self):
- testval = "echo!\n"
- c = netlib.TCPClient(False, "127.0.0.1", self.port, None)
- c.wfile.write(testval)
- c.wfile.flush()
- assert c.rfile.readline() == testval
-
- def test_error(self):
- testval = "error!\n"
- c = netlib.TCPClient(False, "127.0.0.1", self.port, None)
- c.wfile.write(testval)
- c.wfile.flush()
- assert "Testing an error" in self.q.get()
-
-
-class TestTCPClient:
- def test_conerr(self):
- tutils.raises(netlib.NetLibError, netlib.TCPClient, False, "127.0.0.1", 0, None)
-
-
-class TestFileLike:
- def test_wrap(self):
- s = cStringIO.StringIO("foobar\nfoobar")
- s = netlib.FileLike(s)
- s.flush()
- assert s.readline() == "foobar\n"
- assert s.readline() == "foobar"
- # Test __getattr__
- assert s.isatty
-
- def test_limit(self):
- s = cStringIO.StringIO("foobar\nfoobar")
- s = netlib.FileLike(s)
- assert s.readline(3) == "foo"
diff --git a/test/test_protocol.py b/test/test_protocol.py
deleted file mode 100644
index 81b5fefb..00000000
--- a/test/test_protocol.py
+++ /dev/null
@@ -1,163 +0,0 @@
-import cStringIO, textwrap
-from libmproxy import protocol, flow
-import tutils
-
-def test_has_chunked_encoding():
- h = flow.ODictCaseless()
- assert not protocol.has_chunked_encoding(h)
- h["transfer-encoding"] = ["chunked"]
- assert protocol.has_chunked_encoding(h)
-
-
-def test_read_chunked():
- s = cStringIO.StringIO("1\r\na\r\n0\r\n")
- tutils.raises(IOError, protocol.read_chunked, s, None)
-
- s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n")
- assert protocol.read_chunked(s, None) == "a"
-
- s = cStringIO.StringIO("\r\n")
- tutils.raises(IOError, protocol.read_chunked, s, None)
-
- s = cStringIO.StringIO("1\r\nfoo")
- tutils.raises(IOError, protocol.read_chunked, s, None)
-
- s = cStringIO.StringIO("foo\r\nfoo")
- tutils.raises(protocol.ProtocolError, protocol.read_chunked, s, None)
-
-
-def test_request_connection_close():
- h = flow.ODictCaseless()
- assert protocol.request_connection_close((1, 0), h)
- assert not protocol.request_connection_close((1, 1), h)
-
- h["connection"] = ["keep-alive"]
- assert not protocol.request_connection_close((1, 1), h)
-
-
-def test_read_http_body():
- h = flow.ODict()
- s = cStringIO.StringIO("testing")
- assert protocol.read_http_body(s, h, False, None) == ""
-
- h["content-length"] = ["foo"]
- s = cStringIO.StringIO("testing")
- tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, None)
-
- h["content-length"] = [5]
- s = cStringIO.StringIO("testing")
- assert len(protocol.read_http_body(s, h, False, None)) == 5
- s = cStringIO.StringIO("testing")
- tutils.raises(protocol.ProtocolError, protocol.read_http_body, s, h, False, 4)
-
- h = flow.ODict()
- s = cStringIO.StringIO("testing")
- assert len(protocol.read_http_body(s, h, True, 4)) == 4
- s = cStringIO.StringIO("testing")
- assert len(protocol.read_http_body(s, h, True, 100)) == 7
-
-def test_parse_http_protocol():
- assert protocol.parse_http_protocol("HTTP/1.1") == (1, 1)
- assert protocol.parse_http_protocol("HTTP/0.0") == (0, 0)
- assert not protocol.parse_http_protocol("foo/0.0")
-
-
-def test_parse_init_connect():
- assert protocol.parse_init_connect("CONNECT host.com:443 HTTP/1.0")
- assert not protocol.parse_init_connect("bogus")
- assert not protocol.parse_init_connect("GET host.com:443 HTTP/1.0")
- assert not protocol.parse_init_connect("CONNECT host.com443 HTTP/1.0")
- assert not protocol.parse_init_connect("CONNECT host.com:443 foo/1.0")
-
-
-def test_prase_init_proxy():
- u = "GET http://foo.com:8888/test HTTP/1.1"
- m, s, h, po, pa, httpversion = protocol.parse_init_proxy(u)
- assert m == "GET"
- assert s == "http"
- assert h == "foo.com"
- assert po == 8888
- assert pa == "/test"
- assert httpversion == (1, 1)
-
- assert not protocol.parse_init_proxy("invalid")
- assert not protocol.parse_init_proxy("GET invalid HTTP/1.1")
- assert not protocol.parse_init_proxy("GET http://foo.com:8888/test foo/1.1")
-
-
-def test_parse_init_http():
- u = "GET /test HTTP/1.1"
- m, u, httpversion= protocol.parse_init_http(u)
- assert m == "GET"
- assert u == "/test"
- assert httpversion == (1, 1)
-
- assert not protocol.parse_init_http("invalid")
- assert not protocol.parse_init_http("GET invalid HTTP/1.1")
- assert not protocol.parse_init_http("GET /test foo/1.1")
-
-
-class TestReadHeaders:
- def test_read_simple(self):
- data = """
- Header: one
- Header2: two
- \r\n
- """
- data = textwrap.dedent(data)
- data = data.strip()
- s = cStringIO.StringIO(data)
- h = protocol.read_headers(s)
- assert h == [["Header", "one"], ["Header2", "two"]]
-
- def test_read_multi(self):
- data = """
- Header: one
- Header: two
- \r\n
- """
- data = textwrap.dedent(data)
- data = data.strip()
- s = cStringIO.StringIO(data)
- h = protocol.read_headers(s)
- assert h == [["Header", "one"], ["Header", "two"]]
-
- def test_read_continued(self):
- data = """
- Header: one
- \ttwo
- Header2: three
- \r\n
- """
- data = textwrap.dedent(data)
- data = data.strip()
- s = cStringIO.StringIO(data)
- h = protocol.read_headers(s)
- assert h == [["Header", "one\r\n two"], ["Header2", "three"]]
-
-
-def test_parse_url():
- assert not protocol.parse_url("")
-
- u = "http://foo.com:8888/test"
- s, h, po, pa = protocol.parse_url(u)
- assert s == "http"
- assert h == "foo.com"
- assert po == 8888
- assert pa == "/test"
-
- s, h, po, pa = protocol.parse_url("http://foo/bar")
- assert s == "http"
- assert h == "foo"
- assert po == 80
- assert pa == "/bar"
-
- s, h, po, pa = protocol.parse_url("http://foo")
- assert pa == "/"
-
- s, h, po, pa = protocol.parse_url("https://foo")
- assert po == 443
-
- assert not protocol.parse_url("https://foo:bar")
- assert not protocol.parse_url("https://foo:")
-