diff options
Diffstat (limited to 'passes/pmgen/pmgen.py')
| -rw-r--r-- | passes/pmgen/pmgen.py | 263 | 
1 files changed, 176 insertions, 87 deletions
| diff --git a/passes/pmgen/pmgen.py b/passes/pmgen/pmgen.py index d9747b065..81052afce 100644 --- a/passes/pmgen/pmgen.py +++ b/passes/pmgen/pmgen.py @@ -3,15 +3,42 @@  import re  import sys  import pprint +import getopt  pp = pprint.PrettyPrinter(indent=4) -pmgfile = sys.argv[1] -assert pmgfile.endswith(".pmg") -prefix = pmgfile[0:-4] -prefix = prefix.split('/')[-1] -outfile = sys.argv[2] - +prefix = None +pmgfiles = list() +outfile = None +debug = False +genhdr = False + +opts, args = getopt.getopt(sys.argv[1:], "p:o:dg") + +for o, a in opts: +    if o == "-p": +        prefix = a +    elif o == "-o": +        outfile = a +    elif o == "-d": +        debug = True +    elif o == "-g": +        genhdr = True + +if outfile is None: +    outfile = "/dev/stdout" + +for a in args: +    assert a.endswith(".pmg") +    if prefix is None and len(args) == 1: +        prefix = a[0:-4] +        prefix = prefix.split('/')[-1] +    pmgfiles.append(a) + +assert prefix is not None + +current_pattern = None +patterns = dict()  state_types = dict()  udata_types = dict()  blocks = list() @@ -77,7 +104,8 @@ def rewrite_cpp(s):      return "".join(t) -with open(pmgfile, "r") as f: +def process_pmgfile(f): +    global current_pattern      while True:          line = f.readline()          if line == "": break @@ -87,14 +115,31 @@ with open(pmgfile, "r") as f:          if len(cmd) == 0 or cmd[0].startswith("//"): continue          cmd = cmd[0] +        if cmd == "pattern": +            if current_pattern is not None: +                block = dict() +                block["type"] = "final" +                block["pattern"] = current_pattern +                blocks.append(block) +            line = line.split() +            assert len(line) == 2 +            assert line[1] not in patterns +            current_pattern = line[1] +            patterns[current_pattern] = len(blocks) +            state_types[current_pattern] = dict() +            udata_types[current_pattern] = dict() +            continue + +        assert current_pattern is not None +          if cmd == "state":              m = re.match(r"^state\s+<(.*?)>\s+(([A-Za-z_][A-Za-z_0-9]*\s+)*[A-Za-z_][A-Za-z_0-9]*)\s*$", line)              assert m              type_str = m.group(1)              states_str = m.group(2)              for s in re.split(r"\s+", states_str): -                assert s not in state_types -                state_types[s] = type_str +                assert s not in state_types[current_pattern] +                state_types[current_pattern][s] = type_str              continue          if cmd == "udata": @@ -103,19 +148,20 @@ with open(pmgfile, "r") as f:              type_str = m.group(1)              udatas_str = m.group(2)              for s in re.split(r"\s+", udatas_str): -                assert s not in udata_types -                udata_types[s] = type_str +                assert s not in udata_types[current_pattern] +                udata_types[current_pattern][s] = type_str              continue          if cmd == "match":              block = dict()              block["type"] = "match" +            block["pattern"] = current_pattern              line = line.split()              assert len(line) == 2 -            assert line[1] not in state_types +            assert line[1] not in state_types[current_pattern]              block["cell"] = line[1] -            state_types[line[1]] = "Cell*"; +            state_types[current_pattern][line[1]] = "Cell*";              block["if"] = list()              block["select"] = list() @@ -158,15 +204,18 @@ with open(pmgfile, "r") as f:                  assert False              blocks.append(block) +            continue          if cmd == "code":              block = dict()              block["type"] = "code" +            block["pattern"] = current_pattern +              block["code"] = list()              block["states"] = set()              for s in line.split()[1:]: -                assert s in state_types +                assert s in state_types[current_pattern]                  block["states"].add(s)              while True: @@ -179,18 +228,37 @@ with open(pmgfile, "r") as f:                  block["code"].append(rewrite_cpp(l.rstrip()))              blocks.append(block) +            continue -with open(outfile, "w") as f: -    print("// Generated by pmgen.py from {}.pgm".format(prefix), file=f) -    print("", file=f) +        assert False -    print("#include \"kernel/yosys.h\"", file=f) -    print("#include \"kernel/sigtools.h\"", file=f) -    print("", file=f) +for fn in pmgfiles: +    with open(fn, "r") as f: +        process_pmgfile(f) + +if current_pattern is not None: +    block = dict() +    block["type"] = "final" +    block["pattern"] = current_pattern +    blocks.append(block) + +current_pattern = None + +if debug: +    pp.pprint(blocks) -    print("YOSYS_NAMESPACE_BEGIN", file=f) +with open(outfile, "w") as f: +    for fn in pmgfiles: +        print("// Generated by pmgen.py from {}".format(fn), file=f)      print("", file=f) +    if genhdr: +        print("#include \"kernel/yosys.h\"", file=f) +        print("#include \"kernel/sigtools.h\"", file=f) +        print("", file=f) +        print("YOSYS_NAMESPACE_BEGIN", file=f) +        print("", file=f) +      print("struct {}_pm {{".format(prefix), file=f)      print("  Module *module;", file=f)      print("  SigMap sigmap;", file=f) @@ -212,17 +280,19 @@ with open(outfile, "w") as f:      print("  int rollback;", file=f)      print("", file=f) -    print("  struct state_t {", file=f) -    for s, t in sorted(state_types.items()): -        print("    {} {};".format(t, s), file=f) -    print("  } st;", file=f) -    print("", file=f) +    for current_pattern in sorted(patterns.keys()): +        print("  struct state_{}_t {{".format(current_pattern), file=f) +        for s, t in sorted(state_types[current_pattern].items()): +            print("    {} {};".format(t, s), file=f) +        print("  }} st_{};".format(current_pattern), file=f) +        print("", file=f) -    print("  struct udata_t {", file=f) -    for s, t in sorted(udata_types.items()): -        print("    {} {};".format(t, s), file=f) -    print("  } ud;", file=f) -    print("", file=f) +        print("  struct udata_{}_t {{".format(current_pattern), file=f) +        for s, t in sorted(udata_types[current_pattern].items()): +            print("    {} {};".format(t, s), file=f) +        print("  }} ud_{};".format(current_pattern), file=f) +        print("", file=f) +    current_pattern = None      for v, n in sorted(ids.items()):          if n[0] == "\\": @@ -258,20 +328,24 @@ with open(outfile, "w") as f:      print("  }", file=f)      print("", file=f) -    print("  void check_blacklist() {", file=f) -    print("    if (!blacklist_dirty)", file=f) -    print("      return;", file=f) -    print("    blacklist_dirty = false;", file=f) -    for index in range(len(blocks)): -        block = blocks[index] -        if block["type"] == "match": -            print("    if (st.{} != nullptr && blacklist_cells.count(st.{})) {{".format(block["cell"], block["cell"]), file=f) -            print("      rollback = {};".format(index+1), file=f) -            print("      return;", file=f) -            print("    }", file=f) -    print("    rollback = 0;", file=f) -    print("  }", file=f) -    print("", file=f) +    for current_pattern in sorted(patterns.keys()): +        print("  void check_blacklist_{}() {{".format(current_pattern), file=f) +        print("    if (!blacklist_dirty)", file=f) +        print("      return;", file=f) +        print("    blacklist_dirty = false;", file=f) +        for index in range(len(blocks)): +            block = blocks[index] +            if block["pattern"] != current_pattern: +                continue +            if block["type"] == "match": +                print("    if (st_{}.{} != nullptr && blacklist_cells.count(st_{}.{})) {{".format(current_pattern, block["cell"], current_pattern, block["cell"]), file=f) +                print("      rollback = {};".format(index+1), file=f) +                print("      return;", file=f) +                print("    }", file=f) +        print("    rollback = 0;", file=f) +        print("  }", file=f) +        print("", file=f) +    current_pattern = None      print("  SigSpec port(Cell *cell, IdString portname) {", file=f)      print("    return sigmap(cell->getPort(portname));", file=f) @@ -294,11 +368,13 @@ with open(outfile, "w") as f:      print("  {}_pm(Module *module, const vector<Cell*> &cells) :".format(prefix), file=f)      print("      module(module), sigmap(module) {", file=f) -    for s, t in sorted(udata_types.items()): -        if t.endswith("*"): -            print("    ud.{} = nullptr;".format(s), file=f) -        else: -            print("    ud.{} = {}();".format(s, t), file=f) +    for current_pattern in sorted(patterns.keys()): +        for s, t in sorted(udata_types[current_pattern].items()): +            if t.endswith("*"): +                print("    ud_{}.{} = nullptr;".format(current_pattern,s), file=f) +            else: +                print("    ud_{}.{} = {}();".format(current_pattern, s, t), file=f) +    current_pattern = None      print("    for (auto cell : module->cells()) {", file=f)      print("      for (auto &conn : cell->connections())", file=f)      print("        add_siguser(conn.second, cell);", file=f) @@ -328,34 +404,52 @@ with open(outfile, "w") as f:      print("  }", file=f)      print("", file=f) -    print("  void run(std::function<void()> on_accept_f) {", file=f) -    print("    on_accept = on_accept_f;", file=f) -    print("    rollback = 0;", file=f) -    print("    blacklist_dirty = false;", file=f) -    for s, t in sorted(state_types.items()): -        if t.endswith("*"): -            print("    st.{} = nullptr;".format(s), file=f) -        else: -            print("    st.{} = {}();".format(s, t), file=f) -    print("    block_0();", file=f) -    print("  }", file=f) -    print("", file=f) - -    print("  void run(std::function<void({}_pm&)> on_accept_f) {{".format(prefix), file=f) -    print("    run([&](){on_accept_f(*this);});", file=f) -    print("  }", file=f) -    print("", file=f) +    for current_pattern in sorted(patterns.keys()): +        print("  void run_{}(std::function<void()> on_accept_f) {{".format(current_pattern), file=f) +        print("    on_accept = on_accept_f;", file=f) +        print("    rollback = 0;", file=f) +        print("    blacklist_dirty = false;", file=f) +        for s, t in sorted(state_types[current_pattern].items()): +            if t.endswith("*"): +                print("    st_{}.{} = nullptr;".format(current_pattern, s), file=f) +            else: +                print("    st_{}.{} = {}();".format(current_pattern, s, t), file=f) +        print("    block_{}();".format(patterns[current_pattern]), file=f) +        print("  }", file=f) +        print("", file=f) +        print("  void run_{}(std::function<void({}_pm&)> on_accept_f) {{".format(current_pattern, prefix), file=f) +        print("    run_{}([&](){{on_accept_f(*this);}});".format(current_pattern), file=f) +        print("  }", file=f) +        print("", file=f) +        print("  void run_{}(std::function<void(state_{}_t&)> on_accept_f) {{".format(current_pattern, current_pattern), file=f) +        print("    run_{}([&](){{on_accept_f(st_{});}});".format(current_pattern, current_pattern), file=f) +        print("  }", file=f) +        print("", file=f) +        print("  void run_{}() {{".format(current_pattern), file=f) +        print("    run_{}([](){{}});".format(current_pattern, current_pattern), file=f) +        print("  }", file=f) +        print("", file=f) +    current_pattern = None      for index in range(len(blocks)):          block = blocks[index]          print("  void block_{}() {{".format(index), file=f) +        current_pattern = block["pattern"] + +        if block["type"] == "final": +            print("    on_accept();", file=f) +            print("    check_blacklist_{}();".format(current_pattern), file=f) +            print("  }", file=f) +            if index+1 != len(blocks): +                print("", file=f) +            continue          const_st = set()          nonconst_st = set()          restore_st = set() -        for i in range(index): +        for i in range(patterns[current_pattern], index):              if blocks[i]["type"] == "code":                  for s in blocks[i]["states"]:                      const_st.add(s) @@ -378,27 +472,27 @@ with open(outfile, "w") as f:              assert False          for s in sorted(const_st): -            t = state_types[s] +            t = state_types[current_pattern][s]              if t.endswith("*"): -                print("    {} const &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f) +                print("    {} const &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)              else: -                print("    const {} &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f) +                print("    const {} &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)          for s in sorted(nonconst_st): -            t = state_types[s] -            print("    {} &{} YS_ATTRIBUTE(unused) = st.{};".format(t, s, s), file=f) +            t = state_types[current_pattern][s] +            print("    {} &{} YS_ATTRIBUTE(unused) = st_{}.{};".format(t, s, current_pattern, s), file=f)          if len(restore_st):              print("", file=f)              for s in sorted(restore_st): -                t = state_types[s] +                t = state_types[current_pattern][s]                  print("    {} backup_{} = {};".format(t, s, s), file=f)          if block["type"] == "code":              print("", file=f)              print("    do {", file=f) -            print("#define reject do { check_blacklist(); goto rollback_label; } while(0)", file=f) -            print("#define accept do { on_accept(); check_blacklist(); if (rollback) goto rollback_label; } while(0)", file=f) +            print("#define reject do {{ check_blacklist_{}(); goto rollback_label; }} while(0)".format(current_pattern), file=f) +            print("#define accept do {{ on_accept(); check_blacklist_{}(); if (rollback) goto rollback_label; }} while(0)".format(current_pattern), file=f)              print("#define branch do {{ block_{}(); if (rollback) goto rollback_label; }} while(0)".format(index+1), file=f)              for line in block["code"]: @@ -417,11 +511,11 @@ with open(outfile, "w") as f:              if len(restore_st) or len(nonconst_st):                  print("", file=f)                  for s in sorted(restore_st): -                    t = state_types[s] +                    t = state_types[current_pattern][s]                      print("    {} = backup_{};".format(s, s), file=f)                  for s in sorted(nonconst_st):                      if s not in restore_st: -                        t = state_types[s] +                        t = state_types[current_pattern][s]                          if t.endswith("*"):                              print("    {} = nullptr;".format(s), file=f)                          else: @@ -470,17 +564,12 @@ with open(outfile, "w") as f:          else:              assert False - +        current_pattern = None          print("  }", file=f)          print("", file=f) -    print("  void block_{}() {{".format(len(blocks)), file=f) -    print("    on_accept();", file=f) -    print("    check_blacklist();", file=f) -    print("  }", file=f)      print("};", file=f) -    print("", file=f) -    print("YOSYS_NAMESPACE_END", file=f) - -# pp.pprint(blocks) +    if genhdr: +        print("", file=f) +        print("YOSYS_NAMESPACE_END", file=f) | 
