aboutsummaryrefslogtreecommitdiffstats
path: root/mitmproxy/proxy/protocol/http_replay.py
blob: 7efb0782a33760abbd6891df07364504ec199b9f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import traceback

from mitmproxy import log
from mitmproxy import controller
from mitmproxy import exceptions
from mitmproxy import http
from mitmproxy import flow
from mitmproxy import connections
from netlib.http import http1
from mitmproxy.types import basethread


# TODO: Doesn't really belong in mitmproxy.proxy.protocol...


class RequestReplayThread(basethread.BaseThread):
    """
        Thread that re-sends a flow's request and stores the outcome
        (response or error) back on the flow.
    """
    name = "RequestReplayThread"

    def __init__(self, config, f, event_queue, should_exit):
        """
            event_queue can be a queue or None, if no scripthooks should be
            processed.

            config: proxy configuration; options.mode, options.listen_host,
                options.body_size_limit, upstream_server and clientcerts are
                read while replaying.
            f: the flow to replay. It is marked live until the replay ends.
            should_exit: handed to controller.Channel together with
                event_queue (unused when event_queue is None).
        """
        self.config, self.f = config, f
        f.live = True
        if event_queue:
            self.channel = controller.Channel(event_queue, should_exit)
        else:
            self.channel = None
        super().__init__(
            "RequestReplay (%s)" % f.request.url
        )

    def _log(self, message, level):
        """
            Emit a log entry through the channel if there is one.

            Without a channel there is nowhere to send it, so the message
            goes to stderr instead of crashing on a None channel.
        """
        if self.channel:
            self.channel.tell("log", log.LogEntry(message, level))
        else:
            import sys
            print(message, file=sys.stderr)

    def _setup_upstream(self, server, r, body_size_limit):
        """
            Prepare an already-connected upstream-proxy connection for r.

            For https requests a CONNECT request for r's host and port is
            sent to the upstream proxy, TLS is established over the
            connection and r is switched to the "relative" first-line
            format. Otherwise r uses the "absolute" format.

            Raises exceptions.ReplayException if the upstream proxy answers
            the CONNECT request with a status other than 200.
        """
        if r.scheme == "https":
            connect_request = http.make_connect_request((r.data.host, r.port))
            server.wfile.write(http1.assemble_request(connect_request))
            server.wfile.flush()
            resp = http1.read_response(
                server.rfile,
                connect_request,
                body_size_limit=body_size_limit
            )
            if resp.status_code != 200:
                raise exceptions.ReplayException("Upstream server refuses CONNECT request")
            server.establish_ssl(
                self.config.clientcerts,
                sni=self.f.server_conn.sni
            )
            r.first_line_format = "relative"
        else:
            r.first_line_format = "absolute"

    def _setup_direct(self, server, r):
        """
            Prepare an already-connected direct server connection for r.

            TLS is established for https requests; r is always sent in the
            "relative" first-line format.
        """
        if r.scheme == "https":
            server.establish_ssl(
                self.config.clientcerts,
                sni=self.f.server_conn.sni
            )
        r.first_line_format = "relative"

    def run(self):
        """
            Replay self.f.request and record the outcome on self.f.

            With a channel, the "request" hook may supply an
            http.HTTPResponse; in that case no connection is made. Otherwise
            the request goes to the upstream proxy ("upstream" mode) or
            directly to the request's host and port, and the parsed response
            is stored in f.response. The "response" hook is then asked if a
            channel exists.

            ReplayException / NetlibException are stored as f.error (and the
            "error" hook is asked when there is a channel). Kill and any other
            exception are logged. In all cases the request's original
            first_line_format is restored, the flow is marked not live and
            any server connection is finished.
        """
        r = self.f.request
        first_line_format_backup = r.first_line_format
        server = None
        try:
            self.f.response = None

            # If we have a channel, run script hooks.
            if self.channel:
                request_reply = self.channel.ask("request", self.f)
                if isinstance(request_reply, http.HTTPResponse):
                    self.f.response = request_reply

            if not self.f.response:
                bsl = self.config.options.body_size_limit
                upstream = self.config.options.mode == "upstream"
                if upstream:
                    server_address = self.config.upstream_server.address
                else:
                    server_address = (r.host, r.port)
                # Bound before connect() so that `finally` finishes the
                # connection even when connecting or the setup below fails.
                server = connections.ServerConnection(
                    server_address,
                    (self.config.options.listen_host, 0)
                )
                server.connect()
                if upstream:
                    self._setup_upstream(server, r, bsl)
                else:
                    self._setup_direct(server, r)

                server.wfile.write(http1.assemble_request(r))
                server.wfile.flush()
                self.f.server_conn = server
                self.f.response = http.HTTPResponse.wrap(
                    http1.read_response(
                        server.rfile,
                        r,
                        body_size_limit=bsl
                    )
                )
            if self.channel:
                response_reply = self.channel.ask("response", self.f)
                # NOTE(review): treats a reply equal to the Kill class itself
                # as a kill request; confirm whether Channel.ask already
                # raises Kill in that case.
                if response_reply == exceptions.Kill:
                    raise exceptions.Kill()
        except (exceptions.ReplayException, exceptions.NetlibException) as e:
            self.f.error = flow.Error(str(e))
            if self.channel:
                self.channel.ask("error", self.f)
        except exceptions.Kill:
            # Kill should only be raised if there's a channel in the
            # first place; _log still copes if there isn't one.
            self._log("Connection killed", "info")
        except Exception:
            # Previously this crashed with AttributeError when no channel
            # existed, hiding the real error.
            self._log(traceback.format_exc(), "error")
        finally:
            r.first_line_format = first_line_format_backup
            self.f.live = False
            if server:
                server.finish()