aboutsummaryrefslogtreecommitdiffstats
path: root/pathod/language/message.py
blob: 566bce60fbe70958ccd8d55cbad95624a5f5382a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import abc
import typing  # noqa

import pyparsing as pp

from mitmproxy.utils import strutils
from . import actions, exceptions, base

# Byte limit Message.log applies when truncating logged values before escaping.
LOG_TRUNCATE = 2 ** 10


class Message:
    """
        Base class for parsed messages: an ordered sequence of tokens in
        which each non-empty ``unique_name`` may appear at most once.

        This class uses ``spec()``, ``values(settings)``,
        ``resolve(settings)`` and an ``actions`` attribute without defining
        them; subclasses are expected to provide them.
    """
    # NOTE(review): ``__metaclass__`` only has an effect on Python 2. On
    # Python 3 it is an ordinary class attribute and Message is not an ABC.
    __metaclass__ = abc.ABCMeta
    # Names of attributes whose (sanitized) values are reported by log().
    logattrs: typing.List[str] = []

    def __init__(self, tokens):
        """
            Store tokens after checking that no unique_name occurs twice.

            Raises exceptions.ParseException on a duplicated clause.
        """
        seen_names = set()
        for token in tokens:
            name = token.unique_name
            if not name:
                # Tokens without a unique name may repeat freely.
                continue
            if name in seen_names:
                raise exceptions.ParseException(
                    "Message has multiple %s clauses, "
                    "but should only have one." % name,
                    0, 0
                )
            seen_names.add(name)
        self.tokens = tokens

    def strike_token(self, name):
        """
            Return a new message of the same class without any token whose
            unique_name equals name.
        """
        kept = [t for t in self.tokens if t.unique_name != name]
        return type(self)(kept)

    def toks(self, klass):
        """
            Return every token that is an instance of klass, in message
            order.
        """
        return list(filter(lambda t: isinstance(t, klass), self.tokens))

    def tok(self, klass):
        """
            Return the first token that is an instance of klass, or None if
            there is no such token.
        """
        matches = self.toks(klass)
        return matches[0] if matches else None

    def length(self, settings):
        """
            Total length of the message's values, before any actions are
            applied.
        """
        total = 0
        for part in self.values(settings):
            total += len(part)
        return total

    def preview_safe(self):
        """
            Return a copy of this message that is safe for previews: all
            PauseAt actions are removed.
        """
        return type(self)(
            [t for t in self.tokens if not isinstance(t, actions.PauseAt)]
        )

    def maximum_length(self, settings):
        """
            Maximum length of the message with all actions applied: the base
            length plus the length of every InjectAt payload.
        """
        total = self.length(settings)
        total += sum(
            len(action.value.get_generator(settings))
            for action in self.actions
            if isinstance(action, actions.InjectAt)
        )
        return total

    @classmethod
    def expr(cls):  # pragma: no cover
        """
            Hook for subclasses: return the parser expression for this
            message type. The base implementation returns None.
        """
        return None

    def log(self, settings):
        """
            Return a dictionary to log when this message is served: one entry
            per name in logattrs, plus "spec".
        """
        def sanitize(value):
            # Never log VALUE specs raw: truncate to LOG_TRUNCATE bytes
            # (per part for value-bearing tokens), then escape.
            if hasattr(value, "values"):
                joined = b"".join(
                    [part[:LOG_TRUNCATE] for part in value.values(settings)]
                )
                return strutils.bytes_to_escaped_str(joined)
            if hasattr(value, "__len__"):
                return strutils.bytes_to_escaped_str(value[:LOG_TRUNCATE])
            return value

        logged = {name: sanitize(getattr(self, name)) for name in self.logattrs}
        logged["spec"] = self.spec()
        return logged

    def freeze(self, settings):
        """
            Return a new message whose tokens are the frozen forms of this
            message's resolved tokens.
        """
        resolved = self.resolve(settings)
        return self.__class__(
            [t.freeze(settings) for t in resolved.tokens]
        )

    def __repr__(self):
        return self.spec()


class NestedMessage(base.Token):
    """
        A nested message, as an escaped string with a preamble.

        The token's value literal holds the spec of another message. That
        spec is parsed on construction with ``nest_type`` and kept as
        ``self.parsed``. Subclasses set ``preamble`` and ``nest_type``.
    """
    # Literal text that introduces this token in a spec; suppressed when parsing.
    preamble = ""
    # Message class used to parse the nested spec. Subclasses must set it:
    # construction fails (AttributeError) if it is left as None.
    nest_type: typing.Optional[typing.Type[Message]] = None

    def __init__(self, value):
        """
            Wrap value (a value literal) and parse its contents as a
            nest_type message.

            Raises exceptions.ParseException if the contents are not valid
            UTF-8 or do not parse completely as a nest_type message.
        """
        super().__init__()
        self.value = value
        try:
            # The nested spec is parsed as text. Report undecodable bytes as a
            # parse error instead of leaking UnicodeDecodeError; callers
            # presumably only handle exceptions.ParseException.
            text = value.val.decode()
        except UnicodeDecodeError as err:
            raise exceptions.ParseException(
                "Nested message is not valid UTF-8: %s" % err, 0, 0
            ) from err
        try:
            self.parsed = self.nest_type(
                self.nest_type.expr().parseString(text, parseAll=True)
            )
        except pp.ParseException as err:
            # Translate pyparsing's error into this package's exception type,
            # keeping the original as the cause for debugging.
            raise exceptions.ParseException(err.msg, err.line, err.col) from err

    @classmethod
    def expr(cls):
        """
            Parser for this token: the (suppressed) preamble followed by a
            value literal. A successful match constructs an instance of cls.
        """
        e = pp.Literal(cls.preamble).suppress()
        e = e + base.TokValueLiteral.expr()
        return e.setParseAction(lambda x: cls(*x))

    def values(self, settings):
        """
            Return a one-element list holding the generator for the value
            literal.
        """
        return [
            self.value.get_generator(settings),
        ]

    def spec(self):
        """
            Spec text for this token: the preamble followed by the value
            literal's own spec.
        """
        return "%s%s" % (self.preamble, self.value.spec())

    def freeze(self, settings):
        """
            Return an equivalent token whose nested message is frozen.

            The frozen nested spec is re-escaped (including single quotes)
            and wrapped in a new value literal.
        """
        f = self.parsed.freeze(settings).spec()
        return self.__class__(
            base.TokValueLiteral(
                strutils.bytes_to_escaped_str(f.encode(), escape_single_quotes=True)
            )
        )