aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libs/ezsat/Makefile3
-rw-r--r--libs/ezsat/ezminisat.cc5
-rw-r--r--libs/ezsat/ezsat.cc46
-rw-r--r--libs/ezsat/ezsat.h19
-rw-r--r--libs/ezsat/testbench.cc6
5 files changed, 71 insertions, 8 deletions
diff --git a/libs/ezsat/Makefile b/libs/ezsat/Makefile
index 1dcb5d15b..b1f864160 100644
--- a/libs/ezsat/Makefile
+++ b/libs/ezsat/Makefile
@@ -18,7 +18,8 @@ test: all
./testbench
./demo_bit
./demo_vec
- ./demo_cmp
+ # ./demo_cmp
+ # ./puzzle3d
clean:
rm -f demo_bit demo_vec demo_cmp testbench puzzle3d *.o *.d
diff --git a/libs/ezsat/ezminisat.cc b/libs/ezsat/ezminisat.cc
index 6a6c075f5..3f43f3ece 100644
--- a/libs/ezsat/ezminisat.cc
+++ b/libs/ezsat/ezminisat.cc
@@ -63,7 +63,8 @@ void ezMiniSAT::clear()
#if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL
void ezMiniSAT::freeze(int id)
{
- cnfFrozenVars.insert(bind(id));
+ if (!mode_non_incremental())
+ cnfFrozenVars.insert(bind(id));
}
bool ezMiniSAT::eliminated(int idx)
@@ -89,6 +90,8 @@ void ezMiniSAT::alarmHandler(int)
bool ezMiniSAT::solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions)
{
+ preSolverCallback();
+
solverTimoutStatus = false;
if (0) {
diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc
index 6da363fc1..4c0b624be 100644
--- a/libs/ezsat/ezsat.cc
+++ b/libs/ezsat/ezsat.cc
@@ -30,6 +30,11 @@ const int ezSAT::FALSE = 2;
ezSAT::ezSAT()
{
+ flag_keep_cnf = false;
+ flag_non_incremental = false;
+
+ non_incremental_solve_used_up = false;
+
cnfConsumed = false;
cnfVariableCount = 0;
cnfClausesCount = 0;
@@ -588,19 +593,40 @@ int ezSAT::bind(int id, bool auto_freeze)
void ezSAT::consumeCnf()
{
- cnfConsumed = true;
+ if (mode_keep_cnf())
+ cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end());
+ else
+ cnfConsumed = true;
cnfClauses.clear();
}
void ezSAT::consumeCnf(std::vector<std::vector<int>> &cnf)
{
- cnfConsumed = true;
+ if (mode_keep_cnf())
+ cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end());
+ else
+ cnfConsumed = true;
cnf.swap(cnfClauses);
cnfClauses.clear();
}
+void ezSAT::getFullCnf(std::vector<std::vector<int>> &full_cnf) const
+{
+ assert(full_cnf.empty());
+ full_cnf.insert(full_cnf.end(), cnfClausesBackup.begin(), cnfClausesBackup.end());
+ full_cnf.insert(full_cnf.end(), cnfClauses.begin(), cnfClauses.end());
+}
+
+void ezSAT::preSolverCallback()
+{
+ assert(!non_incremental_solve_used_up);
+ if (mode_non_incremental())
+ non_incremental_solve_used_up = true;
+}
+
bool ezSAT::solver(const std::vector<int>&, std::vector<bool>&, const std::vector<int>&)
{
+ preSolverCallback();
fprintf(stderr, "************************************************************************\n");
fprintf(stderr, "ERROR: You are trying to use the solve() method of the ezSAT base class!\n");
fprintf(stderr, "Use a dervied class like ezMiniSAT instead.\n");
@@ -1081,16 +1107,26 @@ void ezSAT::printDIMACS(FILE *f, bool verbose) const
if (cnfExpressionVariables[i] != 0)
fprintf(f, "c %*d: %s\n", digits, cnfExpressionVariables[i], to_string(-i-1).c_str());
+ if (mode_keep_cnf()) {
+ fprintf(f, "c\n");
+ fprintf(f, "c %d clauses from backup, %d from current buffer\n",
+ int(cnfClausesBackup.size()), int(cnfClauses.size()));
+ }
+
fprintf(f, "c\n");
}
- fprintf(f, "p cnf %d %d\n", cnfVariableCount, int(cnfClauses.size()));
+ std::vector<std::vector<int>> all_clauses;
+ getFullCnf(all_clauses);
+ assert(cnfClausesCount == int(all_clauses.size()));
+
+ fprintf(f, "p cnf %d %d\n", cnfVariableCount, cnfClausesCount);
int maxClauseLen = 0;
- for (auto &clause : cnfClauses)
+ for (auto &clause : all_clauses)
maxClauseLen = std::max(int(clause.size()), maxClauseLen);
if (!verbose)
maxClauseLen = std::min(maxClauseLen, 3);
- for (auto &clause : cnfClauses) {
+ for (auto &clause : all_clauses) {
for (auto idx : clause)
fprintf(f, " %*d", digits, idx);
if (maxClauseLen >= int(clause.size()))
diff --git a/libs/ezsat/ezsat.h b/libs/ezsat/ezsat.h
index 852405566..83e1b23c5 100644
--- a/libs/ezsat/ezsat.h
+++ b/libs/ezsat/ezsat.h
@@ -48,6 +48,11 @@ public:
static const int FALSE;
private:
+ bool flag_keep_cnf;
+ bool flag_non_incremental;
+
+ bool non_incremental_solve_used_up;
+
std::map<std::string, int> literalsCache;
std::vector<std::string> literals;
@@ -57,7 +62,7 @@ private:
bool cnfConsumed;
int cnfVariableCount, cnfClausesCount;
std::vector<int> cnfLiteralVariables, cnfExpressionVariables;
- std::vector<std::vector<int>> cnfClauses;
+ std::vector<std::vector<int>> cnfClauses, cnfClausesBackup;
void add_clause(const std::vector<int> &args);
void add_clause(const std::vector<int> &args, bool argsPolarity, int a = 0, int b = 0, int c = 0);
@@ -67,6 +72,9 @@ private:
int bind_cnf_and(const std::vector<int> &args);
int bind_cnf_or(const std::vector<int> &args);
+protected:
+ void preSolverCallback();
+
public:
int solverTimeout;
bool solverTimoutStatus;
@@ -74,6 +82,12 @@ public:
ezSAT();
virtual ~ezSAT();
+ void keep_cnf() { flag_keep_cnf = true; }
+ void non_incremental() { flag_non_incremental = true; }
+
+ bool mode_keep_cnf() const { return flag_keep_cnf; }
+ bool mode_non_incremental() const { return flag_non_incremental; }
+
// manage expressions
int value(bool val);
@@ -155,6 +169,9 @@ public:
void consumeCnf();
void consumeCnf(std::vector<std::vector<int>> &cnf);
+ // use this function to get the full CNF in keep_cnf mode
+ void getFullCnf(std::vector<std::vector<int>> &full_cnf) const;
+
std::string cnfLiteralInfo(int idx) const;
// simple helpers for build expressions easily
diff --git a/libs/ezsat/testbench.cc b/libs/ezsat/testbench.cc
index 8332ad919..d20258c37 100644
--- a/libs/ezsat/testbench.cc
+++ b/libs/ezsat/testbench.cc
@@ -64,6 +64,7 @@ void test_simple()
printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
ezMiniSAT sat;
+ sat.non_incremental();
sat.assume(sat.OR("A", "B"));
sat.assume(sat.NOT(sat.AND("A", "B")));
test(sat);
@@ -121,6 +122,8 @@ void test_xorshift32()
printf("==== %s ====\n\n", __PRETTY_FUNCTION__);
ezMiniSAT sat;
+ sat.keep_cnf();
+
xorshift128 rng;
std::vector<int> bits = sat.vec_var("i", 32);
@@ -137,6 +140,9 @@ void test_xorshift32()
test_xorshift32_try(sat, rng());
test_xorshift32_try(sat, rng());
test_xorshift32_try(sat, rng());
+
+ sat.printDIMACS(stdout, true);
+ printf("\n");
}
// ------------------------------------------------------------------------------------------------------------