aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorAldo Cortesi <aldo@nullcube.com>2018-03-31 17:06:09 +1300
committerAldo Cortesi <aldo@corte.si>2018-04-01 09:46:32 +1200
commit976b2018a3fc320272ac4f588250977fc08cf9b5 (patch)
tree84543da70376ede6efcaf2e661bf953c7d492ec7
parenta2d45193546962e2e14d1959e1bf008c83b9f3cf (diff)
downloadmitmproxy-976b2018a3fc320272ac4f588250977fc08cf9b5.tar.gz
mitmproxy-976b2018a3fc320272ac4f588250977fc08cf9b5.tar.bz2
mitmproxy-976b2018a3fc320272ac4f588250977fc08cf9b5.zip
asyncio: clean up event loop acquisition
We now acquire the event loop through asyncio.get_event_loop, avoiding having to pass the loop explicity in a bunch of places. This function does not return the currently running loop from within coroutines in versions of Python prior to 3.6.
-rw-r--r--mitmproxy/master.py28
-rw-r--r--test/mitmproxy/test_controller.py89
2 files changed, 20 insertions, 97 deletions
diff --git a/mitmproxy/master.py b/mitmproxy/master.py
index 0fcf312e..31849a88 100644
--- a/mitmproxy/master.py
+++ b/mitmproxy/master.py
@@ -2,7 +2,6 @@ import threading
import contextlib
import asyncio
import signal
-import time
from mitmproxy import addonmanager
from mitmproxy import options
@@ -37,11 +36,10 @@ class Master:
The master handles mitmproxy's main event loop.
"""
def __init__(self, opts):
- self.loop = asyncio.new_event_loop()
- asyncio.set_event_loop(self.loop)
+ loop = asyncio.get_event_loop()
for signame in ('SIGINT', 'SIGTERM'):
- self.loop.add_signal_handler(getattr(signal, signame), self.shutdown)
- self.event_queue = asyncio.Queue(loop=self.loop)
+ loop.add_signal_handler(getattr(signal, signame), self.shutdown)
+ self.event_queue = asyncio.Queue()
self.options = opts or options.Options() # type: options.Options
self.commands = command.CommandManager(self)
@@ -57,9 +55,7 @@ class Master:
@server.setter
def server(self, server):
- server.set_channel(
- controller.Channel(self.loop, self.event_queue)
- )
+ server.set_channel(controller.Channel(asyncio.get_event_loop(), self.event_queue))
self._server = server
@contextlib.contextmanager
@@ -111,18 +107,16 @@ class Master:
self.addons.trigger("running")
while True:
if self.should_exit.is_set():
- self.loop.stop()
+ asyncio.get_event_loop().stop()
return
self.addons.trigger("tick")
- await asyncio.sleep(0.1, loop=self.loop)
+ await asyncio.sleep(0.1)
- def run(self, inject=None):
+ def run(self):
self.start()
- asyncio.ensure_future(self.main(), loop=self.loop)
- asyncio.ensure_future(self.tick(), loop=self.loop)
- if inject:
- asyncio.ensure_future(inject(), loop=self.loop)
- self.loop.run_forever()
+ asyncio.ensure_future(self.main())
+ asyncio.ensure_future(self.tick())
+ asyncio.get_event_loop().run_forever()
self.shutdown()
self.addons.trigger("done")
@@ -214,7 +208,7 @@ class Master:
rt = http_replay.RequestReplayThread(
self.options,
f,
- self.loop,
+ asyncio.get_event_loop(),
self.event_queue,
self.should_exit
)
diff --git a/test/mitmproxy/test_controller.py b/test/mitmproxy/test_controller.py
index e27f6baf..f7c64ed9 100644
--- a/test/mitmproxy/test_controller.py
+++ b/test/mitmproxy/test_controller.py
@@ -1,102 +1,31 @@
import asyncio
-from threading import Thread, Event
-from unittest.mock import Mock
import queue
import pytest
-import sys
from mitmproxy.exceptions import Kill, ControlException
from mitmproxy import controller
-from mitmproxy import master
-from mitmproxy import proxy
from mitmproxy.test import taddons
-class TMsg:
- pass
+@pytest.mark.asyncio
+async def test_master():
+ class TMsg:
+ pass
-
-def test_master():
class tAddon:
def log(self, _):
ctx.master.should_exit.set()
- with taddons.context() as ctx:
- ctx.master.addons.add(tAddon())
+ with taddons.context(tAddon()) as ctx:
assert not ctx.master.should_exit.is_set()
async def test():
msg = TMsg()
msg.reply = controller.DummyReply()
- await ctx.master.event_queue.put(("log", msg))
-
- ctx.master.run(inject=test)
-
-
-# class TestMaster:
-# # def test_simple(self):
-# # class tAddon:
-# # def log(self, _):
-# # ctx.master.should_exit.set()
-
-# # with taddons.context() as ctx:
-# # ctx.master.addons.add(tAddon())
-# # assert not ctx.master.should_exit.is_set()
-# # msg = TMsg()
-# # msg.reply = controller.DummyReply()
-# # ctx.master.event_queue.put(("log", msg))
-# # ctx.master.run()
-# # assert ctx.master.should_exit.is_set()
-
-# # def test_server_simple(self):
-# # m = master.Master(None)
-# # m.server = proxy.DummyServer()
-# # m.start()
-# # m.shutdown()
-# # m.start()
-# # m.shutdown()
-# pass
-
-
-# class TestServerThread:
-# def test_simple(self):
-# m = Mock()
-# t = master.ServerThread(m)
-# t.run()
-# assert m.serve_forever.called
-
-
-# class TestChannel:
-# def test_tell(self):
-# q = queue.Queue()
-# channel = controller.Channel(q, Event())
-# m = Mock(name="test_tell")
-# channel.tell("test", m)
-# assert q.get() == ("test", m)
-# assert m.reply
-
-# def test_ask_simple(self):
-# q = queue.Queue()
-
-# def reply():
-# m, obj = q.get()
-# assert m == "test"
-# obj.reply.send(42)
-# obj.reply.take()
-# obj.reply.commit()
-
-# Thread(target=reply).start()
-
-# channel = controller.Channel(q, Event())
-# assert channel.ask("test", Mock(name="test_ask_simple")) == 42
-
-# def test_ask_shutdown(self):
-# q = queue.Queue()
-# done = Event()
-# done.set()
-# channel = controller.Channel(q, done)
-# with pytest.raises(Kill):
-# channel.ask("test", Mock(name="test_ask_shutdown"))
+ await ctx.master.channel.tell("log", msg)
+
+ asyncio.ensure_future(test())
+ assert not ctx.master.should_exit.is_set()
class TestReply: