aboutsummaryrefslogtreecommitdiffstats
path: root/mitmproxy/types.py
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@corte.si>2017-12-19 09:20:29 +1300
committerAldo Cortesi <aldo@corte.si>2017-12-19 10:19:08 +1300
commitcda14830d349f4c1c60af2d1ec563e4894b836c3 (patch)
tree8c8ecb39d712ef07c4144f3da40555c12107feb9 /mitmproxy/types.py
parent38b37ba7f51c72bb94e1a5d2c9c7cf9836cb5a06 (diff)
downloadmitmproxy-cda14830d349f4c1c60af2d1ec563e4894b836c3.tar.gz
mitmproxy-cda14830d349f4c1c60af2d1ec563e4894b836c3.tar.bz2
mitmproxy-cda14830d349f4c1c60af2d1ec563e4894b836c3.zip
types: add validation functions
Diffstat (limited to 'mitmproxy/types.py')
-rw-r--r--mitmproxy/types.py93
1 files changed, 87 insertions, 6 deletions
diff --git a/mitmproxy/types.py b/mitmproxy/types.py
index b6b414ba..713a0ae5 100644
--- a/mitmproxy/types.py
+++ b/mitmproxy/types.py
@@ -56,13 +56,29 @@ class _BaseType:
def completion(
self, manager: _CommandBase, t: typing.Any, s: str
- ) -> typing.Sequence[str]: # pragma: no cover
- pass
+ ) -> typing.Sequence[str]:
+ """
+ Returns a list of completion strings for a given prefix. The strings
+ returned don't necessarily need to be suffixes of the prefix, since
+ completers will do prefix filtering themselves..
+ """
+ raise NotImplementedError
def parse(
- self, manager: _CommandBase, t: typing.Any, s: str
- ) -> typing.Any: # pragma: no cover
- pass
+ self, manager: _CommandBase, typ: typing.Any, s: str
+ ) -> typing.Any:
+ """
+ Parse a string, given the specific type instance (to allow rich type annotations like Choice) and a string.
+
+ Raises exceptions.TypeError if the value is invalid.
+ """
+ raise NotImplementedError
+
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ """
+ Check if data is valid for this type.
+ """
+ raise NotImplementedError
class _BoolType(_BaseType):
@@ -82,6 +98,9 @@ class _BoolType(_BaseType):
"Booleans are 'true' or 'false', got %s" % s
)
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return val in [True, False]
+
class _StrType(_BaseType):
typ = str
@@ -93,6 +112,9 @@ class _StrType(_BaseType):
def parse(self, manager: _CommandBase, t: type, s: str) -> str:
return s
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return isinstance(val, str)
+
class _IntType(_BaseType):
typ = int
@@ -107,6 +129,9 @@ class _IntType(_BaseType):
except ValueError as e:
raise exceptions.TypeError from e
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return isinstance(val, int)
+
class _PathType(_BaseType):
typ = Path
@@ -137,6 +162,9 @@ class _PathType(_BaseType):
def parse(self, manager: _CommandBase, t: type, s: str) -> str:
return s
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return isinstance(val, str)
+
class _CmdType(_BaseType):
typ = Cmd
@@ -148,6 +176,9 @@ class _CmdType(_BaseType):
def parse(self, manager: _CommandBase, t: type, s: str) -> str:
return s
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return val in manager.commands
+
class _ArgType(_BaseType):
typ = Arg
@@ -159,6 +190,9 @@ class _ArgType(_BaseType):
def parse(self, manager: _CommandBase, t: type, s: str) -> str:
return s
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return isinstance(val, str)
+
class _StrSeqType(_BaseType):
typ = typing.Sequence[str]
@@ -170,6 +204,9 @@ class _StrSeqType(_BaseType):
def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[str]:
return [x.strip() for x in s.split(",")]
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return isinstance(val, str)
+
class _CutSpecType(_BaseType):
typ = CutSpec
@@ -224,15 +261,29 @@ class _CutSpecType(_BaseType):
parts = s.split(",") # type: typing.Any
return parts
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ if not isinstance(val, str):
+ return False
+ parts = [x.strip() for x in val.split(",")]
+ for p in parts:
+ for pref in self.valid_prefixes:
+ if p.startswith(pref):
+ break
+ else:
+ return False
+ return True
+
class _BaseFlowType(_BaseType):
- valid_prefixes = [
+ viewmarkers = [
"@all",
"@focus",
"@shown",
"@hidden",
"@marked",
"@unmarked",
+ ]
+ valid_prefixes = viewmarkers + [
"~q",
"~s",
"~a",
@@ -264,6 +315,9 @@ class _FlowType(_BaseFlowType):
)
return flows[0]
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ return isinstance(val, flow.Flow)
+
class _FlowsType(_BaseFlowType):
typ = typing.Sequence[flow.Flow]
@@ -272,6 +326,15 @@ class _FlowsType(_BaseFlowType):
def parse(self, manager: _CommandBase, t: type, s: str) -> typing.Sequence[flow.Flow]:
return manager.call_args("view.resolve", [s])
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ try:
+ for v in val:
+ if not isinstance(v, flow.Flow):
+ return False
+ except TypeError:
+ return False
+ return True
+
class _DataType(_BaseType):
typ = Data
@@ -287,6 +350,17 @@ class _DataType(_BaseType):
) -> typing.Any: # pragma: no cover
raise exceptions.TypeError("data cannot be passed as argument")
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ # FIXME: validate that all rows have equal length, and all columns have equal types
+ try:
+ for row in val:
+ for cell in row:
+ if not (isinstance(cell, str) or isinstance(cell, bytes)):
+ return False
+ except TypeError:
+ return False
+ return True
+
class _ChoiceType(_BaseType):
typ = Choice
@@ -301,6 +375,13 @@ class _ChoiceType(_BaseType):
raise exceptions.TypeError("Invalid choice.")
return s
+ def is_valid(self, manager: _CommandBase, typ: typing.Any, val: typing.Any) -> bool:
+ try:
+ opts = manager.call(typ.options_command)
+ except exceptions.CommandError:
+ return False
+ return val in opts
+
class TypeManager:
def __init__(self, *types):