diff options
| -rw-r--r-- | netlib/http.py | 40 | ||||
| -rw-r--r-- | test/test_http.py | 42 | 
2 files changed, 78 insertions, 4 deletions
| diff --git a/netlib/http.py b/netlib/http.py index 1f5f8901..f0982b6d 100644 --- a/netlib/http.py +++ b/netlib/http.py @@ -128,7 +128,7 @@ def read_http_body(code, rfile, headers, all, limit):              raise HttpError(code, "HTTP Body too large. Limit is %s, content-length was %s"%(limit, l))          content = rfile.read(l)      elif all: -        content = rfile.read(limit if limit else None) +        content = rfile.read(limit if limit else -1)      else:          content = ""      return content @@ -141,7 +141,10 @@ def parse_http_protocol(s):      """      if not s.startswith("HTTP/"):          return None -    major, minor = s.split('/')[1].split('.') +    _, version = s.split('/') +    if "." not in version: +        return None +    major, minor = version.split('.')      major = int(major)      minor = int(minor)      return major, minor @@ -237,8 +240,37 @@ def read_http_body_request(rfile, wfile, headers, httpversion, limit):      return read_http_body(400, rfile, headers, False, limit) -def read_http_body_response(rfile, headers, False, limit): +def read_http_body_response(rfile, headers, all, limit):      """          Read the HTTP body from a server response.      """ -    return read_http_body(500, rfile, headers, False, limit) +    return read_http_body(500, rfile, headers, all, limit) + + +def read_response(rfile, method, body_size_limit): +    line = rfile.readline() +    if line == "\r\n" or line == "\n": # Possible leftover from previous message +        line = rfile.readline() +    if not line: +        raise HttpError(502, "Blank server response.") +    parts = line.strip().split(" ", 2) +    if len(parts) == 2: # handle missing message gracefully +        parts.append("") +    if not len(parts) == 3: +        raise HttpError(502, "Invalid server response: %s."%line) +    proto, code, msg = parts +    httpversion = parse_http_protocol(proto) +    if httpversion is None: +        raise HttpError(502, "Invalid HTTP version: %s."%httpversion) +    try: +        code = int(code) +    except ValueError: +        raise HttpError(502, "Invalid server response: %s."%line) +    headers = read_headers(rfile) +    if code >= 100 and code <= 199: +        return read_response(rfile, method, body_size_limit) +    if method == "HEAD" or code == 204 or code == 304: +        content = "" +    else: +        content = read_http_body_response(rfile, headers, True, body_size_limit) +    return httpversion, code, msg, headers, content diff --git a/test/test_http.py b/test/test_http.py index 3546fec6..b7ee6697 100644 --- a/test/test_http.py +++ b/test/test_http.py @@ -107,6 +107,7 @@ def test_parse_http_protocol():      assert http.parse_http_protocol("HTTP/1.1") == (1, 1)      assert http.parse_http_protocol("HTTP/0.0") == (0, 0)      assert not http.parse_http_protocol("foo/0.0") +    assert not http.parse_http_protocol("HTTP/x")  def test_parse_init_connect(): @@ -183,6 +184,47 @@ class TestReadHeaders:          assert h.lst == [["Header", "one\r\n two"], ["Header2", "three"]] +def test_read_response(): +    def tst(data, method, limit): +        data = textwrap.dedent(data) +        r = cStringIO.StringIO(data) +        return  http.read_response(r, method, limit) + +    tutils.raises("blank server response", tst, "", "GET", None) +    tutils.raises("invalid server response", tst, "foo", "GET", None) +    data = """ +        HTTP/1.1 200 OK +    """ +    assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') +    data = """ +        HTTP/1.1 200 +    """ +    assert tst(data, "GET", None) == ((1, 1), 200, '', odict.ODictCaseless(), '') +    data = """ +        HTTP/x 200 OK +    """ +    tutils.raises("invalid http version", tst, data, "GET", None) +    data = """ +        HTTP/1.1 xx OK +    """ +    tutils.raises("invalid server response", tst, data, "GET", None) + +    data = """ +        HTTP/1.1 100 CONTINUE + +        HTTP/1.1 200 OK +    """ +    assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + +    data = """ +        HTTP/1.1 200 OK + +        foo +    """ +    assert tst(data, "GET", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), 'foo\n') +    assert tst(data, "HEAD", None) == ((1, 1), 200, 'OK', odict.ODictCaseless(), '') + +  def test_parse_url():      assert not http.parse_url("") | 
