diff options
Diffstat (limited to 'libs/ezsat/ezsat.cc')
-rw-r--r-- | libs/ezsat/ezsat.cc | 46 |
1 files changed, 41 insertions, 5 deletions
diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc index 6da363fc1..4c0b624be 100644 --- a/libs/ezsat/ezsat.cc +++ b/libs/ezsat/ezsat.cc @@ -30,6 +30,11 @@ const int ezSAT::FALSE = 2; ezSAT::ezSAT() { + flag_keep_cnf = false; + flag_non_incremental = false; + + non_incremental_solve_used_up = false; + cnfConsumed = false; cnfVariableCount = 0; cnfClausesCount = 0; @@ -588,19 +593,40 @@ int ezSAT::bind(int id, bool auto_freeze) void ezSAT::consumeCnf() { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnfClauses.clear(); } void ezSAT::consumeCnf(std::vector<std::vector<int>> &cnf) { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnf.swap(cnfClauses); cnfClauses.clear(); } +void ezSAT::getFullCnf(std::vector<std::vector<int>> &full_cnf) const +{ + assert(full_cnf.empty()); + full_cnf.insert(full_cnf.end(), cnfClausesBackup.begin(), cnfClausesBackup.end()); + full_cnf.insert(full_cnf.end(), cnfClauses.begin(), cnfClauses.end()); +} + +void ezSAT::preSolverCallback() +{ + assert(!non_incremental_solve_used_up); + if (mode_non_incremental()) + non_incremental_solve_used_up = true; +} + bool ezSAT::solver(const std::vector<int>&, std::vector<bool>&, const std::vector<int>&) { + preSolverCallback(); fprintf(stderr, "************************************************************************\n"); fprintf(stderr, "ERROR: You are trying to use the solve() method of the ezSAT base class!\n"); fprintf(stderr, "Use a dervied class like ezMiniSAT instead.\n"); @@ -1081,16 +1107,26 @@ void ezSAT::printDIMACS(FILE *f, bool verbose) const if (cnfExpressionVariables[i] != 0) fprintf(f, "c %*d: %s\n", digits, cnfExpressionVariables[i], to_string(-i-1).c_str()); + if (mode_keep_cnf()) { + fprintf(f, "c\n"); + fprintf(f, "c %d clauses from backup, %d from current buffer\n", + int(cnfClausesBackup.size()), int(cnfClauses.size())); + } + fprintf(f, "c\n"); } - fprintf(f, "p cnf %d %d\n", cnfVariableCount, int(cnfClauses.size())); + std::vector<std::vector<int>> all_clauses; + getFullCnf(all_clauses); + assert(cnfClausesCount == int(all_clauses.size())); + + fprintf(f, "p cnf %d %d\n", cnfVariableCount, cnfClausesCount); int maxClauseLen = 0; - for (auto &clause : cnfClauses) + for (auto &clause : all_clauses) maxClauseLen = std::max(int(clause.size()), maxClauseLen); if (!verbose) maxClauseLen = std::min(maxClauseLen, 3); - for (auto &clause : cnfClauses) { + for (auto &clause : all_clauses) { for (auto idx : clause) fprintf(f, " %*d", digits, idx); if (maxClauseLen >= int(clause.size())) |