diff options
Diffstat (limited to 'backends')
-rw-r--r-- | backends/smt2/smtbmc.py | 13 | ||||
-rw-r--r-- | backends/smt2/smtio.py | 69 |
2 files changed, 59 insertions, 23 deletions
diff --git a/backends/smt2/smtbmc.py b/backends/smt2/smtbmc.py index 3d96b07a0..f74908f87 100644 --- a/backends/smt2/smtbmc.py +++ b/backends/smt2/smtbmc.py @@ -98,7 +98,7 @@ smt.setup("QF_AUFBV") with open(args[0], "r") as f: for line in f: smt.write(line) - smt.getinfo(line) + smt.info(line) if topmod is None: topmod = smt.topmod @@ -106,18 +106,19 @@ if topmod is None: assert topmod is not None assert topmod in smt.modinfo + def write_vcd_model(steps): print("%s Writing model to VCD file." % smt.timestamp()) vcd = mkvcd(open(vcdfile, "w")) - for netname in sorted(smt.modinfo[topmod].wsize.keys()): - width = len(smt.get_net_bin(topmod, netname, "s0")) - vcd.add_net(netname, width) + for netpath in sorted(smt.hiernets(topmod)): + width = len(smt.get_net_bin(topmod, netpath, "s0")) + vcd.add_net([topmod] + netpath, width) for i in range(steps): vcd.set_time(i) - for netname in smt.modinfo[topmod].wsize.keys(): - vcd.set_net(netname, smt.get_net_bin(topmod, netname, "s%d" % i)) + for netpath in sorted(smt.hiernets(topmod)): + vcd.set_net([topmod] + netpath, smt.get_net_bin(topmod, netpath, "s%d" % i)) vcd.set_time(steps) diff --git a/backends/smt2/smtio.py b/backends/smt2/smtio.py index 14ad75e3e..5f93a2fed 100644 --- a/backends/smt2/smtio.py +++ b/backends/smt2/smtio.py @@ -97,7 +97,7 @@ class smtio: self.p.stdin.write(bytes(stmt + "\n", "ascii")) self.p.stdin.flush() - def getinfo(self, stmt): + def info(self, stmt): if not stmt.startswith("; yosys-smt2-"): return @@ -129,6 +129,17 @@ class smtio: self.modinfo[self.curmod].wires.add(fields[2]) self.modinfo[self.curmod].wsize[fields[2]] = int(fields[3]) + def hiernets(self, top): + def hiernets_worker(nets, mod, cursor): + for netname in sorted(self.modinfo[mod].wsize.keys()): + nets.append(cursor + [netname]) + for cellname, celltype in sorted(self.modinfo[mod].cells.items()): + hiernets_worker(nets, celltype, cursor + [cellname]) + + nets = list() + hiernets_worker(nets, top, []) + return nets + def read(self): stmt = [] count_brackets = 0 @@ -282,19 +293,32 @@ class smtio: self.write("(get-value (%s))" % (expr)) return self.parse(self.read())[0][1] - def get_net(self, mod_name, net_name, state_name): - return self.get("(|%s_n %s| %s)" % (mod_name, net_name, state_name)) + def get_net(self, mod_name, net_path, state_name): + def mkexpr(mod, base, path): + if len(path) == 1: + assert mod in self.modinfo + assert path[0] in self.modinfo[mod].wsize + return "(|%s_n %s| %s)" % (mod, path[0], base) + + assert mod in self.modinfo + assert path[0] in self.modinfo[mod].cells + + nextmod = self.modinfo[mod].cells[path[0]] + nextbase = "(|%s_h %s| %s)" % (mod, path[0], base) + return mkexpr(nextmod, nextbase, path[1:]) + + return self.get(mkexpr(mod_name, state_name, net_path)) - def get_net_bool(self, mod_name, net_name, state_name): - v = self.get_net(mod_name, net_name, state_name) + def get_net_bool(self, mod_name, net_path, state_name): + v = self.get_net(mod_name, net_path, state_name) assert v in ["true", "false"] return 1 if v == "true" else 0 - def get_net_hex(self, mod_name, net_name, state_name): - return self.bv2hex(self.get_net(mod_name, net_name, state_name)) + def get_net_hex(self, mod_name, net_path, state_name): + return self.bv2hex(self.get_net(mod_name, net_path, state_name)) - def get_net_bin(self, mod_name, net_name, state_name): - return self.bv2bin(self.get_net(mod_name, net_name, state_name)) + def get_net_bin(self, mod_name, net_path, state_name): + return self.bv2bin(self.get_net(mod_name, net_path, state_name)) def wait(self): self.p.wait() @@ -344,24 +368,35 @@ class mkvcd: self.t = -1 self.nets = dict() - def add_net(self, name, width): + def add_net(self, path, width): + path = tuple(path) assert self.t == -1 key = "n%d" % len(self.nets) - self.nets[name] = (key, width) + self.nets[path] = (key, width) - def set_net(self, name, bits): - assert name in self.nets + def set_net(self, path, bits): + path = tuple(path) assert self.t >= 0 - print("b%s %s" % (bits, self.nets[name][0]), file=self.f) + assert path in self.nets + print("b%s %s" % (bits, self.nets[path][0]), file=self.f) def set_time(self, t): assert t >= self.t if t != self.t: if self.t == -1: print("$var event 1 ! smt_clock $end", file=self.f) - for name in sorted(self.nets): - key, width = self.nets[name] - print("$var wire %d %s %s $end" % (width, key, name), file=self.f) + scope = [] + for path in sorted(self.nets): + while len(scope)+1 > len(path) or (len(scope) > 0 and scope[-1] != path[len(scope)-1]): + print("$upscope $end", file=self.f) + scope = scope[:-1] + while len(scope)+1 < len(path): + print("$scope module %s $end" % path[len(scope)], file=self.f) + scope.append(path[len(scope)-1]) + key, width = self.nets[path] + print("$var wire %d %s %s $end" % (width, key, path[-1]), file=self.f) + for i in range(len(scope)): + print("$upscope $end", file=self.f) print("$enddefinitions $end", file=self.f) self.t = t assert self.t >= 0 |