diff options
-rw-r--r-- | libs/ezsat/Makefile | 3 | ||||
-rw-r--r-- | libs/ezsat/ezminisat.cc | 5 | ||||
-rw-r--r-- | libs/ezsat/ezsat.cc | 46 | ||||
-rw-r--r-- | libs/ezsat/ezsat.h | 19 | ||||
-rw-r--r-- | libs/ezsat/testbench.cc | 6 |
5 files changed, 71 insertions, 8 deletions
diff --git a/libs/ezsat/Makefile b/libs/ezsat/Makefile index 1dcb5d15b..b1f864160 100644 --- a/libs/ezsat/Makefile +++ b/libs/ezsat/Makefile @@ -18,7 +18,8 @@ test: all ./testbench ./demo_bit ./demo_vec - ./demo_cmp + # ./demo_cmp + # ./puzzle3d clean: rm -f demo_bit demo_vec demo_cmp testbench puzzle3d *.o *.d diff --git a/libs/ezsat/ezminisat.cc b/libs/ezsat/ezminisat.cc index 6a6c075f5..3f43f3ece 100644 --- a/libs/ezsat/ezminisat.cc +++ b/libs/ezsat/ezminisat.cc @@ -63,7 +63,8 @@ void ezMiniSAT::clear() #if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL void ezMiniSAT::freeze(int id) { - cnfFrozenVars.insert(bind(id)); + if (!mode_non_incremental()) + cnfFrozenVars.insert(bind(id)); } bool ezMiniSAT::eliminated(int idx) @@ -89,6 +90,8 @@ void ezMiniSAT::alarmHandler(int) bool ezMiniSAT::solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions) { + preSolverCallback(); + solverTimoutStatus = false; if (0) { diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc index 6da363fc1..4c0b624be 100644 --- a/libs/ezsat/ezsat.cc +++ b/libs/ezsat/ezsat.cc @@ -30,6 +30,11 @@ const int ezSAT::FALSE = 2; ezSAT::ezSAT() { + flag_keep_cnf = false; + flag_non_incremental = false; + + non_incremental_solve_used_up = false; + cnfConsumed = false; cnfVariableCount = 0; cnfClausesCount = 0; @@ -588,19 +593,40 @@ int ezSAT::bind(int id, bool auto_freeze) void ezSAT::consumeCnf() { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnfClauses.clear(); } void ezSAT::consumeCnf(std::vector<std::vector<int>> &cnf) { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnf.swap(cnfClauses); cnfClauses.clear(); } +void ezSAT::getFullCnf(std::vector<std::vector<int>> &full_cnf) const +{ + assert(full_cnf.empty()); + full_cnf.insert(full_cnf.end(), cnfClausesBackup.begin(), cnfClausesBackup.end()); + full_cnf.insert(full_cnf.end(), cnfClauses.begin(), cnfClauses.end()); +} + +void ezSAT::preSolverCallback() +{ + assert(!non_incremental_solve_used_up); + if (mode_non_incremental()) + non_incremental_solve_used_up = true; +} + bool ezSAT::solver(const std::vector<int>&, std::vector<bool>&, const std::vector<int>&) { + preSolverCallback(); fprintf(stderr, "************************************************************************\n"); fprintf(stderr, "ERROR: You are trying to use the solve() method of the ezSAT base class!\n"); fprintf(stderr, "Use a dervied class like ezMiniSAT instead.\n"); @@ -1081,16 +1107,26 @@ void ezSAT::printDIMACS(FILE *f, bool verbose) const if (cnfExpressionVariables[i] != 0) fprintf(f, "c %*d: %s\n", digits, cnfExpressionVariables[i], to_string(-i-1).c_str()); + if (mode_keep_cnf()) { + fprintf(f, "c\n"); + fprintf(f, "c %d clauses from backup, %d from current buffer\n", + int(cnfClausesBackup.size()), int(cnfClauses.size())); + } + fprintf(f, "c\n"); } - fprintf(f, "p cnf %d %d\n", cnfVariableCount, int(cnfClauses.size())); + std::vector<std::vector<int>> all_clauses; + getFullCnf(all_clauses); + assert(cnfClausesCount == int(all_clauses.size())); + + fprintf(f, "p cnf %d %d\n", cnfVariableCount, cnfClausesCount); int maxClauseLen = 0; - for (auto &clause : cnfClauses) + for (auto &clause : all_clauses) maxClauseLen = std::max(int(clause.size()), maxClauseLen); if (!verbose) maxClauseLen = std::min(maxClauseLen, 3); - for (auto &clause : cnfClauses) { + for (auto &clause : all_clauses) { for (auto idx : clause) fprintf(f, " %*d", digits, idx); if (maxClauseLen >= int(clause.size())) diff --git a/libs/ezsat/ezsat.h b/libs/ezsat/ezsat.h index 852405566..83e1b23c5 100644 --- a/libs/ezsat/ezsat.h +++ b/libs/ezsat/ezsat.h @@ -48,6 +48,11 @@ public: static const int FALSE; private: + bool flag_keep_cnf; + bool flag_non_incremental; + + bool non_incremental_solve_used_up; + std::map<std::string, int> literalsCache; std::vector<std::string> literals; @@ -57,7 +62,7 @@ private: bool cnfConsumed; int cnfVariableCount, cnfClausesCount; std::vector<int> cnfLiteralVariables, cnfExpressionVariables; - std::vector<std::vector<int>> cnfClauses; + std::vector<std::vector<int>> cnfClauses, cnfClausesBackup; void add_clause(const std::vector<int> &args); void add_clause(const std::vector<int> &args, bool argsPolarity, int a = 0, int b = 0, int c = 0); @@ -67,6 +72,9 @@ private: int bind_cnf_and(const std::vector<int> &args); int bind_cnf_or(const std::vector<int> &args); +protected: + void preSolverCallback(); + public: int solverTimeout; bool solverTimoutStatus; @@ -74,6 +82,12 @@ public: ezSAT(); virtual ~ezSAT(); + void keep_cnf() { flag_keep_cnf = true; } + void non_incremental() { flag_non_incremental = true; } + + bool mode_keep_cnf() const { return flag_keep_cnf; } + bool mode_non_incremental() const { return flag_non_incremental; } + // manage expressions int value(bool val); @@ -155,6 +169,9 @@ public: void consumeCnf(); void consumeCnf(std::vector<std::vector<int>> &cnf); + // use this function to get the full CNF in keep_cnf mode + void getFullCnf(std::vector<std::vector<int>> &full_cnf) const; + std::string cnfLiteralInfo(int idx) const; // simple helpers for build expressions easily diff --git a/libs/ezsat/testbench.cc b/libs/ezsat/testbench.cc index 8332ad919..d20258c37 100644 --- a/libs/ezsat/testbench.cc +++ b/libs/ezsat/testbench.cc @@ -64,6 +64,7 @@ void test_simple() printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT sat; + sat.non_incremental(); sat.assume(sat.OR("A", "B")); sat.assume(sat.NOT(sat.AND("A", "B"))); test(sat); @@ -121,6 +122,8 @@ void test_xorshift32() printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT sat; + sat.keep_cnf(); + xorshift128 rng; std::vector<int> bits = sat.vec_var("i", 32); @@ -137,6 +140,9 @@ void test_xorshift32() test_xorshift32_try(sat, rng()); test_xorshift32_try(sat, rng()); test_xorshift32_try(sat, rng()); + + sat.printDIMACS(stdout, true); + printf("\n"); } // ------------------------------------------------------------------------------------------------------------ |