diff options
| author | Aldo Cortesi <aldo@nullcube.com> | 2015-04-24 15:09:21 +1200 | 
|---|---|---|
| committer | Aldo Cortesi <aldo@nullcube.com> | 2015-04-24 15:09:21 +1200 | 
| commit | f22bc0b4c74776bcc312fed1f4ceede83f869a6e (patch) | |
| tree | 7d8b947d9940bb0faa68fe21d924642f6c3d1667 | |
| parent | 3519871f340cb0466fc6935d6e8e3b7822d36c52 (diff) | |
| download | mitmproxy-f22bc0b4c74776bcc312fed1f4ceede83f869a6e.tar.gz mitmproxy-f22bc0b4c74776bcc312fed1f4ceede83f869a6e.tar.bz2 mitmproxy-f22bc0b4c74776bcc312fed1f4ceede83f869a6e.zip  | |
websocket: interface refactoring
- Separate out FrameHeader. We need to deal with this separately in many circumstances.
- Simpler equality scheme.
- Bits are now specified by truthiness - we don't care about the integer value. This means lots of validation is not needed any more.
| -rw-r--r-- | netlib/utils.py | 16 |
| -rw-r--r-- | netlib/websockets.py | 303 |
| -rw-r--r-- | test/test_websockets.py | 49 |
3 files changed, 201 insertions, 167 deletions
diff --git a/netlib/utils.py b/netlib/utils.py index 66bbdb5e..44bed43a 100644 --- a/netlib/utils.py +++ b/netlib/utils.py @@ -49,3 +49,19 @@ def hexdump(s):              (o, x, cleanBin(part, True))          )      return parts + + +def setbit(byte, offset, value): +    """ +        Set a bit in a byte to 1 if value is truthy, 0 if not. +    """ +    if value: +        return byte | (1 << offset) +    else: +        return byte & ~(1 << offset) + + +def getbit(byte, offset): +    mask = 1 << offset +    if byte & mask: +        return True diff --git a/netlib/websockets.py b/netlib/websockets.py index 7c127563..016e75c2 100644 --- a/netlib/websockets.py +++ b/netlib/websockets.py @@ -1,5 +1,4 @@  from __future__ import absolute_import -  import base64  import hashlib  import os @@ -83,23 +82,6 @@ def server_handshake_headers(key):      ) -def get_payload_length_pair(payload_bytestring): -    """ -     A websockets frame contains an initial length_code, and an optional -     extended length code to represent the actual length if length code is -     larger than 125 -    """ -    actual_length = len(payload_bytestring) - -    if actual_length <= 125: -        length_code = actual_length -    elif actual_length >= 126 and actual_length <= 65535: -        length_code = 126 -    else: -        length_code = 127 -    return (length_code, actual_length) - -  def make_length_code(len):      """       A websockets frame contains an initial length_code, and an optional @@ -132,40 +114,113 @@ def create_server_nonce(client_nonce):      ) -def frame_header_bytes( -    opcode = 0, -    payload_length = 0, -    fin = 0, -    rsv1 = 0, -    rsv2 = 0, -    rsv3 = 0, -    mask = 0, -    masking_key = None, -    length_code = None -): -    first_byte = (fin << 7) | (rsv1 << 6) |\ -                 (rsv2 << 4) | (rsv3 << 4) | opcode - -    if length_code is None: -        length_code = make_length_code(payload_length) - -    second_byte = (mask << 7) | length_code - -    b = 
chr(first_byte) + chr(second_byte) - -    if payload_length < 126: -        pass -    elif payload_length < MAX_16_BIT_INT: -        # '!H' pack as 16 bit unsigned short -        # add 2 byte extended payload length -        b += struct.pack('!H', payload_length) -    elif payload_length < MAX_64_BIT_INT: -        # '!Q' = pack as 64 bit unsigned long long -        # add 8 bytes extended payload length -        b += struct.pack('!Q', payload_length) -    if masking_key is not None: -        b += masking_key -    return b +DEFAULT = object() +class FrameHeader: +    def __init__( +        self, +        opcode = OPCODE.TEXT, +        payload_length = 0, +        fin = False, +        rsv1 = False, +        rsv2 = False, +        rsv3 = False, +        masking_key = None, +        mask = DEFAULT, +        length_code = DEFAULT +    ): +        self.opcode = opcode +        self.payload_length = payload_length +        self.fin = fin +        self.rsv1 = rsv1 +        self.rsv2 = rsv2 +        self.rsv3 = rsv3 +        self.mask = mask +        self.masking_key = masking_key +        self.length_code = length_code + +    def to_bytes(self): +        first_byte = utils.setbit(0, 7, self.fin) +        first_byte = utils.setbit(first_byte, 6, self.rsv1) +        first_byte = utils.setbit(first_byte, 5, self.rsv2) +        first_byte = utils.setbit(first_byte, 4, self.rsv3) +        first_byte = first_byte | self.opcode + +        if self.length_code is DEFAULT: +            length_code = make_length_code(self.payload_length) +        else: +            length_code = self.length_code + +        if self.mask is DEFAULT: +            mask = bool(self.masking_key) +        else: +            mask = self.mask + +        second_byte = (mask << 7) | length_code + +        b = chr(first_byte) + chr(second_byte) + +        if self.payload_length < 126: +            pass +        elif self.payload_length < MAX_16_BIT_INT: +            # '!H' pack as 16 bit unsigned short +         
   # add 2 byte extended payload length +            b += struct.pack('!H', self.payload_length) +        elif self.payload_length < MAX_64_BIT_INT: +            # '!Q' = pack as 64 bit unsigned long long +            # add 8 bytes extended payload length +            b += struct.pack('!Q', self.payload_length) +        if self.masking_key is not None: +            b += self.masking_key +        return b + +    @classmethod +    def from_file(klass, fp): +        """ +          read a websockets frame header +        """ +        first_byte = utils.bytes_to_int(fp.read(1)) +        second_byte = utils.bytes_to_int(fp.read(1)) + +        fin = utils.getbit(first_byte, 7) +        rsv1 = utils.getbit(first_byte, 6) +        rsv2 = utils.getbit(first_byte, 5) +        rsv3 = utils.getbit(first_byte, 4) +        # grab right most 4 bits by and-ing with 00001111 +        opcode = first_byte & 15 +        # grab left most bit +        mask_bit = second_byte >> 7 +        # grab the next 7 bits +        length_code = second_byte & 127 + +        # payload_lengthy > 125 indicates you need to read more bytes +        # to get the actual payload length +        if length_code <= 125: +            payload_length = length_code +        elif length_code == 126: +            payload_length = utils.bytes_to_int(fp.read(2)) +        elif length_code == 127: +            payload_length = utils.bytes_to_int(fp.read(8)) + +        # masking key only present if mask bit set +        if mask_bit == 1: +            masking_key = fp.read(4) +        else: +            masking_key = None + +        return klass( +            fin = fin, +            rsv1 = rsv1, +            rsv2 = rsv2, +            rsv3 = rsv3, +            opcode = opcode, +            mask = mask_bit, +            length_code = length_code, +            payload_length = payload_length, +            masking_key = masking_key, +        ) + +    def __eq__(self, other): +        return self.to_bytes() == other.to_bytes()  
class Frame(object): @@ -194,27 +249,10 @@ class Frame(object):           |                     Payload Data continued ...                |           +---------------------------------------------------------------+      """ -    def __init__( -        self, -        fin,                          # decmial integer 1 or 0 -        opcode,                       # decmial integer 1 - 4 -        payload = "",                 # bytestring -        masking_key = None,           # 32 bit byte string -        mask_bit = 0,                 # decimal integer 1 or 0 -        payload_length_code = None,   # decimal integer 1 - 127 -        rsv1 = 0,                     # decimal integer 1 or 0 -        rsv2 = 0,                     # decimal integer 1 or 0 -        rsv3 = 0,                     # decimal integer 1 or 0 -    ): -        self.fin = fin -        self.rsv1 = rsv1 -        self.rsv2 = rsv2 -        self.rsv3 = rsv3 -        self.opcode = opcode -        self.mask_bit = mask_bit -        self.payload_length_code = payload_length_code -        self.masking_key = masking_key +    def __init__(self, payload = "", **kwargs):          self.payload = payload +        kwargs["payload_length"] = kwargs.get("payload_length", len(payload)) +        self.header = FrameHeader(**kwargs)      @classmethod      def default(cls, message, from_client = False): @@ -230,10 +268,10 @@ class Frame(object):              masking_key = None          return cls( +            message,              fin = 1, # final frame              opcode = OPCODE.TEXT, # text -            mask_bit = mask_bit, -            payload = message, +            mask = mask_bit,              masking_key = masking_key,          ) @@ -243,30 +281,30 @@ class Frame(object):              Frame has not been corrupted.          
"""          constraints = [ -            0 <= self.fin <= 1, -            0 <= self.rsv1 <= 1, -            0 <= self.rsv2 <= 1, -            0 <= self.rsv3 <= 1, -            1 <= self.opcode <= 4, -            0 <= self.mask_bit <= 1, +            0 <= self.header.fin <= 1, +            0 <= self.header.rsv1 <= 1, +            0 <= self.header.rsv2 <= 1, +            0 <= self.header.rsv3 <= 1, +            1 <= self.header.opcode <= 4, +            0 <= self.header.mask <= 1,              #1 <= self.payload_length_code <= 127, -            1 <= len(self.masking_key) <= 4 if self.mask_bit else True, -            self.masking_key is not None if self.mask_bit else True +            1 <= len(self.header.masking_key) <= 4 if self.header.mask else True, +            self.header.masking_key is not None if self.header.mask else True          ]          if not all(constraints):              return False          return True -    def human_readable(self): # pragma: nocover +    def human_readable(self):          return "\n".join([ -            ("fin                   - " + str(self.fin)), -            ("rsv1                  - " + str(self.rsv1)), -            ("rsv2                  - " + str(self.rsv2)), -            ("rsv3                  - " + str(self.rsv3)), -            ("opcode                - " + str(self.opcode)), -            ("mask_bit              - " + str(self.mask_bit)), -            ("payload_length_code   - " + str(self.payload_length_code)), -            ("masking_key           - " + repr(str(self.masking_key))), +            ("fin                   - " + str(self.header.fin)), +            ("rsv1                  - " + str(self.header.rsv1)), +            ("rsv2                  - " + str(self.header.rsv2)), +            ("rsv3                  - " + str(self.header.rsv3)), +            ("opcode                - " + str(self.header.opcode)), +            ("mask                  - " + str(self.header.mask)), +            ("length_code           - " + 
str(self.header.length_code)), +            ("masking_key           - " + repr(str(self.header.masking_key))),              ("payload               - " + repr(str(self.payload))),          ]) @@ -284,18 +322,9 @@ class Frame(object):              If you haven't checked is_valid_frame() then there's no guarentees              that the serialized bytes will be correct. see safe_to_bytes()          """ -        b = frame_header_bytes( -            opcode = self.opcode, -            fin = self.fin, -            rsv1 = self.rsv1, -            rsv2 = self.rsv2, -            rsv3 = self.rsv3, -            mask = self.mask_bit, -            masking_key = self.masking_key, -            payload_length = len(self.payload) if self.payload else 0 -        ) -        if self.masking_key: -            b += apply_mask(self.payload, self.masking_key) +        b = self.header.to_bytes() +        if self.header.masking_key: +            b += apply_mask(self.payload, self.header.masking_key)          else:              b += self.payload          return b @@ -312,66 +341,20 @@ class Frame(object):            fp is a "file like" object that could be backed by a network            stream or a disk or an in memory stream reader          """ -        first_byte = utils.bytes_to_int(fp.read(1)) -        second_byte = utils.bytes_to_int(fp.read(1)) - -        # grab the left most bit -        fin = first_byte >> 7 -        # grab right most 4 bits by and-ing with 00001111 -        opcode = first_byte & 15 -        # grab left most bit -        mask_bit = second_byte >> 7 -        # grab the next 7 bits -        payload_length = second_byte & 127 - -        # payload_lengthy > 125 indicates you need to read more bytes -        # to get the actual payload length -        if payload_length <= 125: -            actual_payload_length = payload_length - -        elif payload_length == 126: -            actual_payload_length = utils.bytes_to_int(fp.read(2)) - -        elif payload_length == 127: -  
          actual_payload_length = utils.bytes_to_int(fp.read(8)) - -        # masking key only present if mask bit set -        if mask_bit == 1: -            masking_key = fp.read(4) -        else: -            masking_key = None - -        payload = fp.read(actual_payload_length) +        header = FrameHeader.from_file(fp) +        payload = fp.read(header.payload_length) -        if mask_bit == 1 and masking_key: -            payload = apply_mask(payload, masking_key) +        if header.mask == 1 and header.masking_key: +            payload = apply_mask(payload, header.masking_key)          return cls( -            fin = fin, -            opcode = opcode, -            mask_bit = mask_bit, -            payload_length_code = payload_length, -            payload = payload, -            masking_key = masking_key, +            payload, +            fin = header.fin, +            opcode = header.opcode, +            mask = header.mask, +            payload_length = header.payload_length, +            masking_key = header.masking_key,          )      def __eq__(self, other): -        if self.payload_length_code is None: -            myplc = make_length_code(len(self.payload)) -        else: -            myplc = self.payload_length_code -        if other.payload_length_code is None: -            otherplc = make_length_code(len(other.payload)) -        else: -            otherplc = other.payload_length_code -        return ( -            self.fin == other.fin and -            self.rsv1 == other.rsv1 and -            self.rsv2 == other.rsv2 and -            self.rsv3 == other.rsv3 and -            self.opcode == other.opcode and -            self.mask_bit == other.mask_bit and -            self.masking_key == other.masking_key and -            self.payload == other.payload, -            myplc == otherplc -        ) +        return self.to_bytes() == other.to_bytes() diff --git a/test/test_websockets.py b/test/test_websockets.py index bf8ec5cd..06876e0b 100644 --- 
a/test/test_websockets.py +++ b/test/test_websockets.py @@ -1,10 +1,9 @@ -from netlib import tcp, test, websockets, http +import cStringIO  import os -from nose.tools import raises +from nose.tools import raises -def test_frame_header_bytes(): -    assert websockets.frame_header_bytes() +from netlib import tcp, test, websockets, http  class WebSocketsEchoHandler(tcp.BaseHandler): @@ -119,12 +118,12 @@ class TestWebSockets(test.ServerTestBase):          assert frame.is_valid()          frame = f() -        frame.fin = 2 +        frame.header.fin = 2          assert not frame.is_valid()          frame = f() -        frame.mask_bit = 1 -        frame.masking_key = "foobbarboo" +        frame.header.mask_bit = 1 +        frame.header.masking_key = "foobbarboo"          assert not frame.is_valid()      def test_serialization_bijection(self): @@ -181,3 +180,39 @@ class TestBadHandshake(test.ServerTestBase):          client = WebSocketsClient(("127.0.0.1", self.port))          client.connect()          client.send_message("hello") + + +class TestFrameHeader: +    def test_roundtrip(self): +        def round(*args, **kwargs): +            f = websockets.FrameHeader(*args, **kwargs) +            bytes = f.to_bytes() +            f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) +            assert f == f2 +        round() +        round(fin=1) +        round(rsv1=1) +        round(rsv2=1) +        round(rsv3=1) +        round(payload_length=1) +        round(payload_length=100) +        round(payload_length=1000) +        round(payload_length=10000) +        round(opcode=websockets.OPCODE.PING) +        round(masking_key="test") + +    def test_funky(self): +        f = websockets.FrameHeader(masking_key="test", mask=False) +        bytes = f.to_bytes() +        f2 = websockets.FrameHeader.from_file(cStringIO.StringIO(bytes)) +        assert not f2.mask + + +class TestFrame: +    def test_roundtrip(self): +        def round(*args, **kwargs): +            f = 
websockets.Frame(*args, **kwargs) +            bytes = f.to_bytes() +            f2 = websockets.Frame.from_file(cStringIO.StringIO(bytes)) +            assert f == f2 +        round("test")  | 
