aboutsummaryrefslogtreecommitdiffstats
path: root/tools/python/xen/xend/server/channel.py
diff options
context:
space:
mode:
Diffstat (limited to 'tools/python/xen/xend/server/channel.py')
-rwxr-xr-xtools/python/xen/xend/server/channel.py378
1 files changed, 378 insertions, 0 deletions
diff --git a/tools/python/xen/xend/server/channel.py b/tools/python/xen/xend/server/channel.py
new file mode 100755
index 0000000000..be98a37fd5
--- /dev/null
+++ b/tools/python/xen/xend/server/channel.py
@@ -0,0 +1,378 @@
+import xen.ext.xc; xc = xen.ext.xc.new()
+from xen.ext import xu
+from messages import msgTypeName
+
+VIRQ_MISDIRECT = 0 # Catch-all interrupt for unbound VIRQs.
+VIRQ_TIMER = 1 # Timebase update, and/or requested timeout.
+VIRQ_DEBUG = 2 # Request guest to dump debug info.
+VIRQ_CONSOLE = 3 # (DOM0) bytes received on emergency console.
+VIRQ_DOM_EXC = 4 # (DOM0) Exceptional event for some domain.
+
+def eventChannel(dom1, dom2):
+ return xc.evtchn_bind_interdomain(dom1=dom1, dom2=dom2)
+
+class ChannelFactory:
+ """Factory for creating channels.
+ Maintains a table of channels.
+ """
+
+ """ Channels indexed by index. """
+ channels = {}
+
+ def __init__(self):
+ """Constructor - do not use. Use the channelFactory function."""
+ self.notifier = xu.notifier()
+
+ def addChannel(self, channel):
+ """Add a channel.
+ """
+ idx = channel.idx
+ self.channels[idx] = channel
+ self.notifier.bind(idx)
+ # Try to wake it up
+ #self.notifier.unmask(idx)
+ #channel.notify()
+
+ def getChannel(self, idx):
+ """Get the channel with the given index (if any).
+ """
+ return self.channels.get(idx)
+
+ def delChannel(self, idx):
+ """Remove the channel with the given index (if any).
+ """
+ if idx in self.channels:
+ del self.channels[idx]
+ self.notifier.unbind(idx)
+
+ def domChannel(self, dom):
+ """Get the channel for the given domain.
+ Construct if necessary.
+ """
+ dom = int(dom)
+ for chan in self.channels.values():
+ if not isinstance(chan, Channel): continue
+ if chan.dom == dom:
+ return chan
+ chan = Channel(self, dom)
+ self.addChannel(chan)
+ return chan
+
+ def virqChannel(self, virq):
+ """Get the channel for the given virq.
+ Construct if necessary.
+ """
+ for chan in self.channels.values():
+ if not isinstance(chan, VirqChannel): continue
+ if chan.virq == virq:
+ return chan
+ chan = VirqChannel(self, virq)
+ self.addChannel(chan)
+ return chan
+
+ def channelClosed(self, channel):
+ """The given channel has been closed - remove it.
+ """
+ self.delChannel(channel.idx)
+
+ def createPort(self, dom):
+ """Create a port for a channel to the given domain.
+ """
+ return xu.port(dom)
+
+def channelFactory():
+ """Singleton constructor for the channel factory.
+ Use this instead of the class constructor.
+ """
+ global inst
+ try:
+ inst
+ except:
+ inst = ChannelFactory()
+ return inst
+
+class BaseChannel:
+ """Abstract superclass for channels.
+
+ The subclass constructor must set idx to the port to use.
+ """
+
+ def __init__(self, factory):
+ self.factory = factory
+ self.idx = -1
+ self.closed = 0
+
+ def getIndex(self):
+ """Get the channel index.
+ """
+ return self.idx
+
+ def notificationReceived(self, type):
+ """Called when a notification is received.
+ Closes the channel on error, otherwise calls
+ handleNotification(type), which should be defined
+ in a subclass.
+ """
+ #print 'notificationReceived> type=', type, self
+ if self.closed: return
+ if type == self.factory.notifier.EXCEPTION:
+ print 'notificationReceived> EXCEPTION'
+ info = xc.evtchn_status(self.idx)
+ if info['status'] == 'unbound':
+ print 'notificationReceived> EXCEPTION closing...'
+ self.close()
+ return
+ self.handleNotification(type)
+
+ def close(self):
+ """Close the channel. Calls channelClosed() on the factory.
+ Override in subclass.
+ """
+ self.factory.channelClosed(self)
+
+ def handleNotification(self, type):
+ """Handle notification.
+ Define in subclass.
+ """
+ pass
+
+
+class VirqChannel(BaseChannel):
+ """A channel for handling a virq.
+ """
+
+ def __init__(self, factory, virq):
+ """Create a channel for the given virq using the given factory.
+
+ Do not call directly, use virqChannel on the factory.
+ """
+ BaseChannel.__init__(self, factory)
+ self.virq = virq
+ # Notification port (int).
+ self.port = xc.evtchn_bind_virq(virq)
+ self.idx = self.port
+ # Clients to call when a virq arrives.
+ self.clients = []
+
+ def __repr__(self):
+ return ('<VirqChannel virq=%d port=%d>'
+ % (self.virq, self.port))
+
+ def getVirq(self):
+ """Get the channel's virq.
+ """
+ return self.virq
+
+ def close(self):
+ """Close the channel. Calls lostChannel(self) on all its clients and
+ channelClosed() on the factory.
+ """
+ for c in self.clients:
+ c.lostChannel(self)
+ del self.clients
+ BaseChannel.close(self)
+
+ def registerClient(self, client):
+ """Register a client. The client will be called with
+ client.virqReceived(virq) when a virq is received.
+ The client will be called with client.lostChannel(self) if the
+ channel is closed.
+ """
+ self.clients.append(client)
+
+ def handleNotification(self, type):
+ for c in self.clients:
+ c.virqReceived(self.virq)
+
+ def notify(self):
+ xc.evtchn_send(self.port)
+
+
+class Channel(BaseChannel):
+ """A control channel to a domain. Messages for the domain device controllers
+ are multiplexed over the channel (console, block devs, net devs).
+ """
+
+ def __init__(self, factory, dom):
+ """Create a channel to the given domain using the given factory.
+
+ Do not call directly, use domChannel on the factory.
+ """
+ BaseChannel.__init__(self, factory)
+ # Domain.
+ self.dom = int(dom)
+ # Domain port (object).
+ self.port = self.factory.createPort(dom)
+ # Channel port (int).
+ self.idx = self.port.local_port
+ # Registered devices.
+ self.devs = []
+ # Devices indexed by the message types they handle.
+ self.devs_by_type = {}
+ # Output queue.
+ self.queue = []
+ self.closed = 0
+
+ def getLocalPort(self):
+ """Get the local port.
+ """
+ return self.port.local_port
+
+ def getRemotePort(self):
+ """Get the remote port.
+ """
+ return self.port.remote_port
+
+ def close(self):
+ """Close the channel. Calls lostChannel() on all its devices and
+ channelClosed() on the factory.
+ """
+ self.closed = 1
+ for d in self.devs:
+ d.lostChannel()
+ self.factory.channelClosed(self)
+ self.devs = []
+ self.devs_by_type = {}
+
+ def registerDevice(self, types, dev):
+ """Register a device controller.
+
+ @param types message types the controller handles
+ @param dev device controller
+ """
+ if self.closed: return
+ self.devs.append(dev)
+ for ty in types:
+ self.devs_by_type[ty] = dev
+
+ def deregisterDevice(self, dev):
+ """Remove the registration for a device controller.
+
+ @param dev device controller
+ """
+ if dev in self.devs:
+ self.devs.remove(dev)
+ types = [ ty for (ty, d) in self.devs_by_type.items() if d == dev ]
+ for ty in types:
+ del self.devs_by_type[ty]
+
+ def getDevice(self, type):
+ """Get the device controller handling a message type.
+
+ @param type message type
+ @returns controller or None
+ """
+ return self.devs_by_type.get(type)
+
+ def getMessageType(self, msg):
+ """Get a 2-tuple of the message type and subtype.
+ """
+ hdr = msg.get_header()
+ return (hdr['type'], hdr.get('subtype'))
+
+ def __repr__(self):
+ return ('<Channel dom=%d ports=%d:%d>'
+ % (self.dom,
+ self.port.local_port,
+ self.port.remote_port))
+
+ def handleNotification(self, type):
+ work = 0
+ work += self.handleRequests()
+ work += self.handleResponses()
+ work += self.handleWrites()
+ if work:
+ self.notify()
+
+ def notify(self):
+ self.port.notify()
+
+ def handleRequests(self):
+ work = 0
+ while 1:
+ msg = self.readRequest()
+ if not msg: break
+ self.requestReceived(msg)
+ work += 1
+ return work
+
+ def requestReceived(self, msg):
+ (ty, subty) = self.getMessageType(msg)
+ #todo: Must respond before writing any more messages.
+ #todo: Should automate this (respond on write)
+ self.port.write_response(msg)
+ dev = self.getDevice(ty)
+ if dev:
+ dev.requestReceived(msg, ty, subty)
+ else:
+ print ("requestReceived> No device: Message type %s %d:%d"
+ % (msgTypeName(ty, subty), ty, subty)), self
+
+ def handleResponses(self):
+ work = 0
+ while 1:
+ msg = self.readResponse()
+ if not msg: break
+ self.responseReceived(msg)
+ work += 1
+ return work
+
+ def responseReceived(self, msg):
+ (ty, subty) = self.getMessageType(msg)
+ dev = self.getDevice(ty)
+ if dev:
+ dev.responseReceived(msg, ty, subty)
+ else:
+ print ("responseReceived> No device: Message type %d:%d"
+ % (msgTypeName(ty, subty), ty, subty)), self
+
+ def handleWrites(self):
+ work = 0
+ # Pull data from producers.
+ for dev in self.devs:
+ work += dev.produceRequests()
+ # Flush the queue.
+ while self.queue and self.port.space_to_write_request():
+ msg = self.queue.pop(0)
+ self.port.write_request(msg)
+ work += 1
+ return work
+
+ def writeRequest(self, msg, notify=1):
+ if self.closed:
+ val = -1
+ elif self.writeReady():
+ self.port.write_request(msg)
+ if notify: self.notify()
+ val = 1
+ else:
+ self.queue.append(msg)
+ val = 0
+ return val
+
+ def writeResponse(self, msg):
+ if self.closed: return -1
+ self.port.write_response(msg)
+ return 1
+
+ def writeReady(self):
+ if self.closed or self.queue: return 0
+ return self.port.space_to_write_request()
+
+ def readRequest(self):
+ if self.closed:
+ return None
+ if self.port.request_to_read():
+ val = self.port.read_request()
+ else:
+ val = None
+ return val
+
+ def readResponse(self):
+ if self.closed:
+ return None
+ if self.port.response_to_read():
+ val = self.port.read_response()
+ else:
+ val = None
+ return val