diff options
| author | Aldo Cortesi <aldo@nullcube.com> | 2012-07-21 16:10:54 +1200 | 
|---|---|---|
| committer | Aldo Cortesi <aldo@nullcube.com> | 2012-07-21 16:10:54 +1200 | 
| commit | 2387d2e8ed7d94e42b1ac02a4ea73f54e4c63ab8 (patch) | |
| tree | b6516481e5c125b59def209627453de13abb2889 | |
| parent | ba53d2e4caa34df883a2cd6322d607426c97201b (diff) | |
| download | mitmproxy-2387d2e8ed7d94e42b1ac02a4ea73f54e4c63ab8.tar.gz mitmproxy-2387d2e8ed7d94e42b1ac02a4ea73f54e4c63ab8.tar.bz2 mitmproxy-2387d2e8ed7d94e42b1ac02a4ea73f54e4c63ab8.zip | |
Timeout for TCP clients.
| -rw-r--r-- | netlib/tcp.py | 36 | ||||
| -rw-r--r-- | test/test_tcp.py | 33 | 
2 files changed, 60 insertions, 9 deletions
| diff --git a/netlib/tcp.py b/netlib/tcp.py index 3aee4c74..8771e789 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -1,4 +1,4 @@ -import select, socket, threading, traceback, sys +import select, socket, threading, traceback, sys, time  from OpenSSL import SSL  import certutils @@ -35,8 +35,8 @@ OP_TLS_ROLLBACK_BUG = SSL.OP_TLS_ROLLBACK_BUG  class NetLibError(Exception): pass -  class NetLibDisconnect(Exception): pass +class NetLibTimeout(Exception): pass  class FileLike: @@ -47,15 +47,25 @@ class FileLike:          return getattr(self.o, attr)      def flush(self): -        pass +        if hasattr(self.o, "flush"): +            self.o.flush()      def read(self, length):          result = '' +        start = time.time()          while length > 0:              try:                  data = self.o.read(length)              except (SSL.ZeroReturnError, SSL.SysCallError):                  break +            except SSL.WantReadError: +                if (time.time() - start) < self.o.gettimeout(): +                    time.sleep(0.1) +                    continue +                else: +                    raise NetLibTimeout +            except socket.timeout: +                raise NetLibTimeout              if not data:                  break              result += data @@ -65,7 +75,11 @@ class FileLike:      def write(self, v):          if v:              try: -                return self.o.sendall(v) +                if hasattr(self.o, "sendall"): +                    return self.o.sendall(v) +                else: +                    r = self.o.write(v) +                    return r              except SSL.Error:                  raise NetLibDisconnect() @@ -119,12 +133,18 @@ class TCPClient:              addr = socket.gethostbyname(self.host)              connection = socket.socket(socket.AF_INET, socket.SOCK_STREAM)              connection.connect((addr, self.port)) -            self.rfile = connection.makefile('rb', self.rbufsize) -            self.wfile = connection.makefile('wb', self.wbufsize) +            self.rfile = FileLike(connection.makefile('rb', self.rbufsize)) +            self.wfile = FileLike(connection.makefile('wb', self.wbufsize))          except socket.error, err:              raise NetLibError('Error connecting to "%s": %s' % (self.host, err))          self.connection = connection +    def settimeout(self, n): +        self.connection.settimeout(n) + +    def gettimeout(self): +        self.connection.gettimeout() +      def close(self):          """              Does a hard close of the socket, i.e. a shutdown, followed by a close. @@ -148,8 +168,8 @@ class BaseHandler:      wbufsize = -1      def __init__(self, connection, client_address, server):          self.connection = connection -        self.rfile = self.connection.makefile('rb', self.rbufsize) -        self.wfile = self.connection.makefile('wb', self.wbufsize) +        self.rfile = FileLike(self.connection.makefile('rb', self.rbufsize)) +        self.wfile = FileLike(self.connection.makefile('wb', self.wbufsize))          self.client_address = client_address          self.server = server diff --git a/test/test_tcp.py b/test/test_tcp.py index cb27c63b..d6235b01 100644 --- a/test/test_tcp.py +++ b/test/test_tcp.py @@ -1,4 +1,4 @@ -import cStringIO, threading, Queue +import cStringIO, threading, Queue, time  from netlib import tcp, certutils  import tutils @@ -57,6 +57,12 @@ class DisconnectHandler(tcp.BaseHandler):          self.close() +class HangHandler(tcp.BaseHandler): +    def handle(self): +        while 1: +            time.sleep(1) + +  class TServer(tcp.TCPServer):      def __init__(self, addr, ssl, q, handler, v3_only=False):          tcp.TCPServer.__init__(self, addr) @@ -188,6 +194,31 @@ class TestDisconnect(ServerTestBase):          c.close() +class TestTimeOut(ServerTestBase): +    @classmethod +    def makeserver(cls): +        return TServer(("127.0.0.1", 0), False, cls.q, HangHandler) + +    def test_timeout_client(self): +        c = tcp.TCPClient("127.0.0.1", self.port) +        c.connect() +        c.settimeout(0.1) +        tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + + +class TestSSLTimeOut(ServerTestBase): +    @classmethod +    def makeserver(cls): +        return TServer(("127.0.0.1", 0), True, cls.q, HangHandler) + +    def test_timeout_client(self): +        c = tcp.TCPClient("127.0.0.1", self.port) +        c.connect() +        c.convert_to_ssl() +        c.settimeout(0.1) +        tutils.raises(tcp.NetLibTimeout, c.rfile.read, 10) + +  class TestTCPClient:      def test_conerr(self):          c = tcp.TCPClient("127.0.0.1", 0) | 
