aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--netlib/tcp.py69
-rw-r--r--test/test_tcp.py34
2 files changed, 84 insertions, 19 deletions
diff --git a/netlib/tcp.py b/netlib/tcp.py
index 0fed7380..e1318435 100644
--- a/netlib/tcp.py
+++ b/netlib/tcp.py
@@ -39,10 +39,11 @@ class NetLibDisconnect(Exception): pass
class NetLibTimeout(Exception): pass
-class FileLike:
+class _FileLike:
BLOCKSIZE = 1024 * 32
def __init__(self, o):
self.o = o
+ self._log = None
def set_descriptor(self, o):
self.o = o
@@ -50,6 +51,37 @@ class FileLike:
def __getattr__(self, attr):
return getattr(self.o, attr)
+ def start_log(self):
+ """
+ Starts or resets the log.
+
+ This will store all bytes read or written.
+ """
+ self._log = []
+
+ def stop_log(self):
+ """
+ Stops the log.
+ """
+ self._log = None
+
+ def is_logging(self):
+ return self._log is not None
+
+ def get_log(self):
+ """
+ Returns the log as a string.
+ """
+ if not self.is_logging():
+ raise ValueError("Not logging!")
+ return "".join(self._log)
+
+ def add_log(self, v):
+ if self.is_logging():
+ self._log.append(v)
+
+
+class Writer(_FileLike):
def flush(self):
try:
if hasattr(self.o, "flush"):
@@ -57,6 +89,21 @@ class FileLike:
except socket.error, v:
raise NetLibDisconnect(str(v))
+ def write(self, v):
+ if v:
+ try:
+ if hasattr(self.o, "sendall"):
+ self.add_log(v)
+ return self.o.sendall(v)
+ else:
+ r = self.o.write(v)
+ self.add_log(v[:r])
+ return r
+ except (SSL.Error, socket.error):
+ raise NetLibDisconnect()
+
+
+class Reader(_FileLike):
def read(self, length):
"""
If length is None, we read until connection closes.
@@ -85,19 +132,9 @@ class FileLike:
result += data
if length != -1:
length -= len(data)
+ self.add_log(result)
return result
- def write(self, v):
- if v:
- try:
- if hasattr(self.o, "sendall"):
- return self.o.sendall(v)
- else:
- r = self.o.write(v)
- return r
- except (SSL.Error, socket.error):
- raise NetLibDisconnect()
-
def readline(self, size = None):
result = ''
bytes_read = 0
@@ -151,8 +188,8 @@ class TCPClient:
addr = socket.gethostbyname(self.host)
connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
connection.connect((addr, self.port))
- self.rfile = FileLike(connection.makefile('rb', self.rbufsize))
- self.wfile = FileLike(connection.makefile('wb', self.wbufsize))
+ self.rfile = Reader(connection.makefile('rb', self.rbufsize))
+ self.wfile = Writer(connection.makefile('wb', self.wbufsize))
except socket.error, err:
raise NetLibError('Error connecting to "%s": %s' % (self.host, err))
self.connection = connection
@@ -186,8 +223,8 @@ class BaseHandler:
wbufsize = -1
def __init__(self, connection, client_address, server):
self.connection = connection
- self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize))
- self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize))
+ self.rfile = Reader(self.connection.makefile('rb', self.rbufsize))
+ self.wfile = Writer(self.connection.makefile('wb', self.wbufsize))
self.client_address = client_address
self.server = server
diff --git a/test/test_tcp.py b/test/test_tcp.py
index 67c56a37..9d581939 100644
--- a/test/test_tcp.py
+++ b/test/test_tcp.py
@@ -228,8 +228,8 @@ class TestTCPClient:
class TestFileLike:
def test_wrap(self):
s = cStringIO.StringIO("foobar\nfoobar")
- s = tcp.FileLike(s)
s.flush()
+ s = tcp.Reader(s)
assert s.readline() == "foobar\n"
assert s.readline() == "foobar"
# Test __getattr__
@@ -237,11 +237,39 @@ class TestFileLike:
def test_limit(self):
s = cStringIO.StringIO("foobar\nfoobar")
- s = tcp.FileLike(s)
+ s = tcp.Reader(s)
assert s.readline(3) == "foo"
def test_limitless(self):
s = cStringIO.StringIO("f"*(50*1024))
- s = tcp.FileLike(s)
+ s = tcp.Reader(s)
ret = s.read(-1)
assert len(ret) == 50 * 1024
+
+ def test_readlog(self):
+ s = cStringIO.StringIO("foobar\nfoobar")
+ s = tcp.Reader(s)
+ assert not s.is_logging()
+ s.start_log()
+ assert s.is_logging()
+ s.readline()
+ assert s.get_log() == "foobar\n"
+ s.read(1)
+ assert s.get_log() == "foobar\nf"
+ s.start_log()
+ assert s.get_log() == ""
+ s.read(1)
+ assert s.get_log() == "o"
+ s.stop_log()
+ tutils.raises(ValueError, s.get_log)
+
+ def test_writelog(self):
+ s = cStringIO.StringIO()
+ s = tcp.Writer(s)
+ s.start_log()
+ assert s.is_logging()
+ s.write("x")
+ assert s.get_log() == "x"
+ s.write("x")
+ assert s.get_log() == "xx"
+