diff options
Diffstat (limited to 'netlib')
| -rw-r--r-- | netlib/socks.py | 32 | ||||
| -rw-r--r-- | netlib/tcp.py | 48 | ||||
| -rw-r--r-- | netlib/websockets.py | 16 | 
3 files changed, 51 insertions, 45 deletions
| diff --git a/netlib/socks.py b/netlib/socks.py index 497b8eef..6f9f57bd 100644 --- a/netlib/socks.py +++ b/netlib/socks.py @@ -52,20 +52,6 @@ METHOD = utils.BiDi(  ) -def _read(f, n): -    try: -        d = f.read(n) -        if len(d) == n: -            return d -        else: -            raise SocksError( -                REP.GENERAL_SOCKS_SERVER_FAILURE, -                "Incomplete Read" -            ) -    except socket.error as e: -        raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE, str(e)) - -  class ClientGreeting(object):      __slots__ = ("ver", "methods") @@ -75,9 +61,9 @@ class ClientGreeting(object):      @classmethod      def from_file(cls, f): -        ver, nmethods = struct.unpack("!BB", _read(f, 2)) +        ver, nmethods = struct.unpack("!BB", f.safe_read(2))          methods = array.array("B") -        methods.fromstring(_read(f, nmethods)) +        methods.fromstring(f.safe_read(nmethods))          return cls(ver, methods)      def to_file(self, f): @@ -94,7 +80,7 @@ class ServerGreeting(object):      @classmethod      def from_file(cls, f): -        ver, method = struct.unpack("!BB", _read(f, 2)) +        ver, method = struct.unpack("!BB", f.safe_read(2))          return cls(ver, method)      def to_file(self, f): @@ -112,27 +98,27 @@ class Message(object):      @classmethod      def from_file(cls, f): -        ver, msg, rsv, atyp = struct.unpack("!BBBB", _read(f, 4)) +        ver, msg, rsv, atyp = struct.unpack("!BBBB", f.safe_read(4))          if rsv != 0x00:              raise SocksError(REP.GENERAL_SOCKS_SERVER_FAILURE,                               "Socks Request: Invalid reserved byte: %s" % rsv)          if atyp == ATYP.IPV4_ADDRESS:              # We use tnoa here as ntop is not commonly available on Windows. -            host = socket.inet_ntoa(_read(f, 4)) +            host = socket.inet_ntoa(f.safe_read(4))              use_ipv6 = False          elif atyp == ATYP.IPV6_ADDRESS: -            host = socket.inet_ntop(socket.AF_INET6, _read(f, 16)) +            host = socket.inet_ntop(socket.AF_INET6, f.safe_read(16))              use_ipv6 = True          elif atyp == ATYP.DOMAINNAME: -            length, = struct.unpack("!B", _read(f, 1)) -            host = _read(f, length) +            length, = struct.unpack("!B", f.safe_read(1)) +            host = f.safe_read(length)              use_ipv6 = False          else:              raise SocksError(REP.ADDRESS_TYPE_NOT_SUPPORTED,                               "Socks Request: Unknown ATYP: %s" % atyp) -        port, = struct.unpack("!H", _read(f, 2)) +        port, = struct.unpack("!H", f.safe_read(2))          addr = tcp.Address((host, port), use_ipv6=use_ipv6)          return cls(ver, msg, atyp, addr) diff --git a/netlib/tcp.py b/netlib/tcp.py index 84008e2c..dbe114a1 100644 --- a/netlib/tcp.py +++ b/netlib/tcp.py @@ -24,6 +24,7 @@ OP_NO_SSLv3 = SSL.OP_NO_SSLv3  class NetLibError(Exception): pass  class NetLibDisconnect(NetLibError): pass +class NetLibIncomplete(NetLibError): pass  class NetLibTimeout(NetLibError): pass  class NetLibSSLError(NetLibError): pass @@ -195,10 +196,23 @@ class Reader(_FileLike):                      break          return result +    def safe_read(self, length): +        """ +            Like .read, but is guaranteed to either return length bytes, or +            raise an exception. +        """ +        result = self.read(length) +        if length != -1 and len(result) != length: +            raise NetLibIncomplete( +                "Expected %s bytes, got %s"%(length, len(result)) +            ) +        return result +  class Address(object):      """ -    This class wraps an IPv4/IPv6 tuple to provide named attributes and ipv6 information. +        This class wraps an IPv4/IPv6 tuple to provide named attributes and +        ipv6 information.      """      def __init__(self, address, use_ipv6=False):          self.address = tuple(address) @@ -247,22 +261,28 @@ def close_socket(sock):      """      try:          # We already indicate that we close our end. -        sock.shutdown(socket.SHUT_WR)  # may raise "Transport endpoint is not connected" on Linux +         # may raise "Transport endpoint is not connected" on Linux +        sock.shutdown(socket.SHUT_WR) -        # Section 4.2.2.13 of RFC 1122 tells us that a close() with any -        # pending readable data could lead to an immediate RST being sent (which is the case on Windows). +        # Section 4.2.2.13 of RFC 1122 tells us that a close() with any pending +        # readable data could lead to an immediate RST being sent (which is the +        # case on Windows).          # http://ia600609.us.archive.org/22/items/TheUltimateSo_lingerPageOrWhyIsMyTcpNotReliable/the-ultimate-so_linger-page-or-why-is-my-tcp-not-reliable.html          # -        # This in turn results in the following issue: If we send an error page to the client and then close the socket, -        # the RST may be received by the client before the error page and the users sees a connection error rather than -        # the error page. Thus, we try to empty the read buffer on Windows first. -        # (see https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988) +        # This in turn results in the following issue: If we send an error page +        # to the client and then close the socket, the RST may be received by +        # the client before the error page and the users sees a connection +        # error rather than the error page. Thus, we try to empty the read +        # buffer on Windows first. (see +        # https://github.com/mitmproxy/mitmproxy/issues/527#issuecomment-93782988)          # +          if os.name == "nt":  # pragma: no cover -            # We cannot rely on the shutdown()-followed-by-read()-eof technique proposed by the page above: -            # Some remote machines just don't send a TCP FIN, which would leave us in the unfortunate situation that -            # recv() would block infinitely. -            # As a workaround, we set a timeout here even if we are in blocking mode. +            # We cannot rely on the shutdown()-followed-by-read()-eof technique +            # proposed by the page above: Some remote machines just don't send +            # a TCP FIN, which would leave us in the unfortunate situation that +            # recv() would block infinitely. As a workaround, we set a timeout +            # here even if we are in blocking mode.              sock.settimeout(sock.gettimeout() or 20)              # limit at a megabyte so that we don't read infinitely @@ -292,10 +312,10 @@ class _Connection(object):      def finish(self):          self.finished = True -          # If we have an SSL connection, wfile.close == connection.close          # (We call _FileLike.set_descriptor(conn)) -        # Closing the socket is not our task, therefore we don't call close then. +        # Closing the socket is not our task, therefore we don't call close +        # then.          if type(self.connection) != SSL.Connection:              if not getattr(self.wfile, "closed", False):                  try: diff --git a/netlib/websockets.py b/netlib/websockets.py index 0ad0e294..6d08e101 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -5,7 +5,7 @@ import os  import struct  import io -from . import utils, odict +from . import utils, odict, tcp  # Colleciton of utility functions that implement small portions of the RFC6455  # WebSockets Protocol Useful for building WebSocket clients and servers. @@ -217,8 +217,8 @@ class FrameHeader:          """            read a websockets frame header          """ -        first_byte = utils.bytes_to_int(fp.read(1)) -        second_byte = utils.bytes_to_int(fp.read(1)) +        first_byte = utils.bytes_to_int(fp.safe_read(1)) +        second_byte = utils.bytes_to_int(fp.safe_read(1))          fin = utils.getbit(first_byte, 7)          rsv1 = utils.getbit(first_byte, 6) @@ -235,13 +235,13 @@ class FrameHeader:          if length_code <= 125:              payload_length = length_code          elif length_code == 126: -            payload_length = utils.bytes_to_int(fp.read(2)) +            payload_length = utils.bytes_to_int(fp.safe_read(2))          elif length_code == 127: -            payload_length = utils.bytes_to_int(fp.read(8)) +            payload_length = utils.bytes_to_int(fp.safe_read(8))          # masking key only present if mask bit set          if mask_bit == 1: -            masking_key = fp.read(4) +            masking_key = fp.safe_read(4)          else:              masking_key = None @@ -319,7 +319,7 @@ class Frame(object):            Construct a websocket frame from an in-memory bytestring            to construct a frame from a stream of bytes, use from_file() directly          """ -        return cls.from_file(io.BytesIO(bytestring)) +        return cls.from_file(tcp.Reader(io.BytesIO(bytestring)))      def human_readable(self):          hdr = self.header.human_readable() @@ -351,7 +351,7 @@ class Frame(object):            stream or a disk or an in memory stream reader          """          header = FrameHeader.from_file(fp) -        payload = fp.read(header.payload_length) +        payload = fp.safe_read(header.payload_length)          if header.mask == 1 and header.masking_key:              payload = Masker(header.masking_key)(payload) | 
