diff options
| -rw-r--r-- | netlib/http.py | 132 | ||||
| -rw-r--r-- | test/test_http.py | 57 | 
2 files changed, 126 insertions, 63 deletions
diff --git a/netlib/http.py b/netlib/http.py index 413c73a1..774bac6c 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -1,4 +1,5 @@  import string, urlparse, binascii +import sys  import odict, utils @@ -88,14 +89,14 @@ def read_headers(fp):              # We're being liberal in what we accept, here.              if i > 0:                  name = line[:i] -                value = line[i+1:].strip() +                value = line[i + 1:].strip()                  ret.append([name, value])              else:                  return None      return odict.ODictCaseless(ret) -def read_chunked(fp, headers, limit, is_request): +def read_chunked(fp, limit, is_request):      """          Read a chunked HTTP body. @@ -103,10 +104,9 @@ def read_chunked(fp, headers, limit, is_request):      """      # FIXME: Should check if chunked is the final encoding in the headers      # http://tools.ietf.org/html/draft-ietf-httpbis-p1-messaging-16#section-3.3 3.3 2. -    content = ""      total = 0      code = 400 if is_request else 502 -    while 1: +    while True:          line = fp.readline(128)          if line == "":              raise HttpErrorConnClosed(code, "Connection closed prematurely") @@ -114,27 +114,19 @@ def read_chunked(fp, headers, limit, is_request):              try:                  length = int(line, 16)              except ValueError: -                # FIXME: Not strictly correct - this could be from the server, in which -                # case we should send a 502. -                raise HttpError(code, "Invalid chunked encoding length: %s"%line) -            if not length: -                break +                raise HttpError(code, "Invalid chunked encoding length: %s" % line)              total += length              if limit is not None and total > limit: -                msg = "HTTP Body too large."\ -                      " Limit is %s, chunked content length was at least %s"%(limit, total) +                msg = "HTTP Body too large." \ +                      " Limit is %s, chunked content length was at least %s" % (limit, total)                  raise HttpError(code, msg) -            content += fp.read(length) -            line = fp.readline(5) -            if line != '\r\n': +            chunk = fp.read(length) +            suffix = fp.readline(5) +            if suffix != '\r\n':                  raise HttpError(code, "Malformed chunked body") -    while 1: -        line = fp.readline() -        if line == "": -            raise HttpErrorConnClosed(code, "Connection closed prematurely") -        if line == '\r\n' or line == '\n': -            break -    return content +            yield line, chunk, '\r\n' +            if length == 0: +                return  def get_header_tokens(headers, key): @@ -264,6 +256,7 @@ def parse_init_http(line):  def connection_close(httpversion, headers):      """          Checks the message to see if the client connection should be closed according to RFC 2616 Section 8.1 +        Note that a connection should be closed as well if the response has been read until end of the stream.      """      # At first, check if we have an explicit Connection header.      if "connection" in headers: @@ -280,7 +273,7 @@ def connection_close(httpversion, headers):  def parse_response_line(line):      parts = line.strip().split(" ", 2) -    if len(parts) == 2: # handle missing message gracefully +    if len(parts) == 2:  # handle missing message gracefully          parts.append("")      if len(parts) != 3:          return None @@ -308,26 +301,27 @@ def read_response(rfile, request_method, body_size_limit, include_body=True):          raise HttpErrorConnClosed(502, "Server disconnect.")      parts = parse_response_line(line)      if not parts: -        raise HttpError(502, "Invalid server response: %s"%repr(line)) +        raise HttpError(502, "Invalid server response: %s" % repr(line))      proto, code, msg = parts      httpversion = parse_http_protocol(proto)      if httpversion is None: -        raise HttpError(502, "Invalid HTTP version in line: %s"%repr(proto)) +        raise HttpError(502, "Invalid HTTP version in line: %s" % repr(proto))      headers = read_headers(rfile)      if headers is None:          raise HttpError(502, "Invalid headers.") -    # Parse response body according to http://tools.ietf.org/html/rfc7230#section-3.3 -    if request_method in ["HEAD", "CONNECT"] or (code in [204, 304]) or 100 <= code <= 199: -        content = "" -    elif include_body: -        content = read_http_body(rfile, headers, body_size_limit, False) +    if include_body: +        content = read_http_body(rfile, headers, body_size_limit, request_method, code, False)      else:          content = None  # if include_body==False then a None content means the body should be read separately      return httpversion, code, msg, headers, content -def read_http_body(rfile, headers, limit, is_request): +def read_http_body(*args, **kwargs): +    return "".join(content for _, content, _ in read_http_body_chunked(*args, **kwargs)) + + +def read_http_body_chunked(rfile, headers, limit, request_method, response_code, is_request, max_chunk_size=None):      """          Read an HTTP message body: @@ -336,23 +330,69 @@ def read_http_body(rfile, headers, limit, is_request):              limit: Size limit.              is_request: True if the body to read belongs to a request, False otherwise      """ -    if has_chunked_encoding(headers): -        content = read_chunked(rfile, headers, limit, is_request) -    elif "content-length" in headers: -        try: -            l = int(headers["content-length"][0]) -            if l < 0: -                raise ValueError() -        except ValueError: -            raise HttpError(400 if is_request else 502, "Invalid content-length header: %s"%headers["content-length"]) -        if limit is not None and l > limit: -            raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l)) -        content = rfile.read(l) -    elif is_request: -        content = "" +    if max_chunk_size is None: +        max_chunk_size = limit or sys.maxint + +    expected_size = expected_http_body_size(headers, is_request, request_method, response_code) + +    if expected_size is None: +        if has_chunked_encoding(headers): +            # Python 3: yield from +            for x in read_chunked(rfile, limit, is_request): +                yield x +        else:  # pragma: nocover +            raise HttpError(400 if is_request else 502, "Content-Length unknown but no chunked encoding") +    elif expected_size >= 0: +        if limit is not None and expected_size > limit: +            raise HttpError(400 if is_request else 509, +                            "HTTP Body too large. Limit is %s, content-length was %s" % (limit, expected_size)) +        bytes_left = expected_size +        while bytes_left: +            chunk_size = min(bytes_left, max_chunk_size) +            yield "", rfile.read(chunk_size), "" +            bytes_left -= chunk_size      else: -        content = rfile.read(limit if limit else -1) +        bytes_left = limit or -1 +        while bytes_left: +            chunk_size = min(bytes_left, max_chunk_size) +            content = rfile.read(chunk_size) +            if not content: +                return +            yield "", content, "" +            bytes_left -= chunk_size          not_done = rfile.read(1)          if not_done:              raise HttpError(400 if is_request else 509, "HTTP Body too large. Limit is %s," % limit) -    return content
\ No newline at end of file + + +def expected_http_body_size(headers, is_request, request_method, response_code): +    """ +        Returns the expected body length: +         - a positive integer, if the size is known in advance +         - None, if the size in unknown in advance (chunked encoding) +         - -1, if all data should be read until end of stream. +    """ + +    # Determine response size according to http://tools.ietf.org/html/rfc7230#section-3.3 +    if request_method: +        request_method = request_method.upper() + +    if (not is_request and ( +                            request_method == "HEAD" or +                        (request_method == "CONNECT" and response_code == 200) or +                        response_code in [204, 304] or +                        100 <= response_code <= 199)): +        return 0 +    if has_chunked_encoding(headers): +        return None +    if "content-length" in headers: +        try: +            size = int(headers["content-length"][0]) +            if size < 0: +                raise ValueError() +            return size +        except ValueError: +            raise HttpError(400 if is_request else 502, "Invalid content-length header: %s" % headers["content-length"]) +    if is_request: +        return 0 +    return -1 diff --git a/test/test_http.py b/test/test_http.py index df351dc7..497e80e2 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -16,26 +16,30 @@ def test_has_chunked_encoding():  def test_read_chunked(): + +    h = odict.ODictCaseless() +    h["transfer-encoding"] = ["chunked"]      s = cStringIO.StringIO("1\r\na\r\n0\r\n") -    tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) + +    tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True)      s = cStringIO.StringIO("1\r\na\r\n0\r\n\r\n") -    assert http.read_chunked(s, None, None, True) == "a" +    assert http.read_http_body(s, h, None, "GET", None, True) == "a"      s = cStringIO.StringIO("\r\n\r\n1\r\na\r\n0\r\n\r\n") -    assert http.read_chunked(s, None, None, True) == "a" +    assert http.read_http_body(s, h, None, "GET", None, True) == "a"      s = cStringIO.StringIO("\r\n") -    tutils.raises("closed prematurely", http.read_chunked, s, None, None, True) +    tutils.raises("closed prematurely", http.read_http_body, s, h, None, "GET", None, True)      s = cStringIO.StringIO("1\r\nfoo") -    tutils.raises("malformed chunked body", http.read_chunked, s, None, None, True) +    tutils.raises("malformed chunked body", http.read_http_body, s, h, None, "GET", None, True)      s = cStringIO.StringIO("foo\r\nfoo") -    tutils.raises(http.HttpError, http.read_chunked, s, None, None, True) +    tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", None, True)      s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") -    tutils.raises("too large", http.read_chunked, s, None, 2, True) +    tutils.raises("too large", http.read_http_body, s, h, 2, "GET", None, True)  def test_connection_close(): @@ -63,54 +67,73 @@ def test_get_header_tokens():  def test_read_http_body_request():      h = odict.ODictCaseless()      r = cStringIO.StringIO("testing") -    assert http.read_http_body(r, h, None, True) == "" +    assert http.read_http_body(r, h, None, "GET", None, True) == ""  def test_read_http_body_response():      h = odict.ODictCaseless()      s = cStringIO.StringIO("testing") -    assert http.read_http_body(s, h, None, False) == "testing" +    assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"  def test_read_http_body():      # test default case      h = odict.ODictCaseless()      h["content-length"] = [7]      s = cStringIO.StringIO("testing") -    assert http.read_http_body(s, h, None, False) == "testing" +    assert http.read_http_body(s, h, None, "GET", 200, False) == "testing"      # test content length: invalid header      h["content-length"] = ["foo"]      s = cStringIO.StringIO("testing") -    tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) +    tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False)      # test content length: invalid header #2      h["content-length"] = [-1]      s = cStringIO.StringIO("testing") -    tutils.raises(http.HttpError, http.read_http_body, s, h, None, False) +    tutils.raises(http.HttpError, http.read_http_body, s, h, None, "GET", 200, False)      # test content length: content length > actual content      h["content-length"] = [5]      s = cStringIO.StringIO("testing") -    tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) +    tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False)      # test content length: content length < actual content      s = cStringIO.StringIO("testing") -    assert len(http.read_http_body(s, h, None, False)) == 5 +    assert len(http.read_http_body(s, h, None, "GET", 200, False)) == 5      # test no content length: limit > actual content      h = odict.ODictCaseless()      s = cStringIO.StringIO("testing") -    assert len(http.read_http_body(s, h, 100, False)) == 7 +    assert len(http.read_http_body(s, h, 100, "GET", 200, False)) == 7      # test no content length: limit < actual content      s = cStringIO.StringIO("testing") -    tutils.raises(http.HttpError, http.read_http_body, s, h, 4, False) +    tutils.raises(http.HttpError, http.read_http_body, s, h, 4, "GET", 200, False)      # test chunked      h = odict.ODictCaseless()      h["transfer-encoding"] = ["chunked"]      s = cStringIO.StringIO("5\r\naaaaa\r\n0\r\n\r\n") -    assert http.read_http_body(s, h, 100, False) == "aaaaa" +    assert http.read_http_body(s, h, 100, "GET", 200, False) == "aaaaa" +def test_expected_http_body_size(): +    # gibber in the content-length field +    h = odict.ODictCaseless() +    h["content-length"] = ["foo"] +    tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) +    # negative number in the content-length field +    h = odict.ODictCaseless() +    h["content-length"] = ["-7"] +    tutils.raises(http.HttpError, http.expected_http_body_size, h, False, "GET", 200) +    # explicit length +    h = odict.ODictCaseless() +    h["content-length"] = ["5"] +    assert http.expected_http_body_size(h, False, "GET", 200) == 5 +    # no length +    h = odict.ODictCaseless() +    assert http.expected_http_body_size(h, False, "GET", 200) == -1 +    # no length request +    h = odict.ODictCaseless() +    assert http.expected_http_body_size(h, True, "GET", None) == 0  def test_parse_http_protocol():      assert http.parse_http_protocol("HTTP/1.1") == (1, 1)  | 
