diff options
Diffstat (limited to 'libs/ezsat')
| -rw-r--r-- | libs/ezsat/Makefile | 8 | ||||
| -rw-r--r-- | libs/ezsat/ezminisat.cc | 88 | ||||
| -rw-r--r-- | libs/ezsat/ezminisat.h | 20 | ||||
| -rw-r--r-- | libs/ezsat/ezsat.cc | 322 | ||||
| -rw-r--r-- | libs/ezsat/ezsat.h | 38 | ||||
| -rw-r--r-- | libs/ezsat/puzzle3d.cc | 4 | ||||
| -rw-r--r-- | libs/ezsat/testbench.cc | 131 |
7 files changed, 365 insertions, 246 deletions
diff --git a/libs/ezsat/Makefile b/libs/ezsat/Makefile index da2355a9b..b1f864160 100644 --- a/libs/ezsat/Makefile +++ b/libs/ezsat/Makefile @@ -3,7 +3,8 @@ CC = clang CXX = clang CXXFLAGS = -MD -Wall -Wextra -ggdb CXXFLAGS += -std=c++11 -O0 -LDLIBS = -lminisat -lstdc++ +LDLIBS = ../minisat/Options.cc ../minisat/SimpSolver.cc ../minisat/Solver.cc ../minisat/System.cc -lm -lstdc++ + all: demo_vec demo_bit demo_cmp testbench puzzle3d @@ -17,10 +18,11 @@ test: all ./testbench ./demo_bit ./demo_vec - ./demo_cmp + # ./demo_cmp + # ./puzzle3d clean: - rm -f demo_bit demo_vec testbench puzzle3d *.o *.d + rm -f demo_bit demo_vec demo_cmp testbench puzzle3d *.o *.d .PHONY: all test clean diff --git a/libs/ezsat/ezminisat.cc b/libs/ezsat/ezminisat.cc index d545834cf..dc4e5d283 100644 --- a/libs/ezsat/ezminisat.cc +++ b/libs/ezsat/ezminisat.cc @@ -25,15 +25,20 @@ #include <limits.h> #include <stdint.h> -#include <signal.h> +#include <csignal> #include <cinttypes> +#include <unistd.h> -#include <minisat/core/Solver.h> +#include "../minisat/Solver.h" +#include "../minisat/SimpSolver.h" ezMiniSAT::ezMiniSAT() : minisatSolver(NULL) { minisatSolver = NULL; foundContradiction = false; + + freeze(TRUE); + freeze(FALSE); } ezMiniSAT::~ezMiniSAT() @@ -50,9 +55,28 @@ void ezMiniSAT::clear() } foundContradiction = false; minisatVars.clear(); +#if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL + cnfFrozenVars.clear(); +#endif ezSAT::clear(); } +#if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL +void ezMiniSAT::freeze(int id) +{ + if (!mode_non_incremental()) + cnfFrozenVars.insert(bind(id)); +} + +bool ezMiniSAT::eliminated(int idx) +{ + idx = idx < 0 ? -idx : idx; + if (minisatSolver != NULL && idx > 0 && idx <= int(minisatVars.size())) + return minisatSolver->isEliminated(minisatVars.at(idx-1)); + return false; +} +#endif + ezMiniSAT *ezMiniSAT::alarmHandlerThis = NULL; clock_t ezMiniSAT::alarmHandlerTimeout = 0; @@ -67,6 +91,8 @@ void ezMiniSAT::alarmHandler(int) bool ezMiniSAT::solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions) { + preSolverCallback(); + solverTimoutStatus = false; if (0) { @@ -90,22 +116,42 @@ contradiction: for (auto id : modelExpressions) modelIdx.push_back(bind(id)); - if (minisatSolver == NULL) - minisatSolver = new Minisat::Solver; + if (minisatSolver == NULL) { + minisatSolver = new Solver; + minisatSolver->verbosity = EZMINISAT_VERBOSITY; + } +#if EZMINISAT_INCREMENTAL std::vector<std::vector<int>> cnf; consumeCnf(cnf); +#else + const std::vector<std::vector<int>> &cnf = this->cnf(); +#endif while (int(minisatVars.size()) < numCnfVariables()) minisatVars.push_back(minisatSolver->newVar()); +#if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL + for (auto idx : cnfFrozenVars) + minisatSolver->setFrozen(minisatVars.at(idx > 0 ? idx-1 : -idx-1), true); + cnfFrozenVars.clear(); +#endif + for (auto &clause : cnf) { Minisat::vec<Minisat::Lit> ps; - for (auto idx : clause) + for (auto idx : clause) { if (idx > 0) ps.push(Minisat::mkLit(minisatVars.at(idx-1))); else ps.push(Minisat::mkLit(minisatVars.at(-idx-1), true)); +#if EZMINISAT_SIMPSOLVER + if (minisatSolver->isEliminated(minisatVars.at(idx > 0 ? idx-1 : -idx-1))) { + fprintf(stderr, "Assert in %s:%d failed! Missing call to ezsat->freeze(): %s (lit=%d)\n", + __FILE__, __LINE__, cnfLiteralInfo(idx).c_str(), idx); + abort(); + } +#endif + } if (!minisatSolver->addClause(ps)) goto contradiction; } @@ -115,20 +161,31 @@ contradiction: Minisat::vec<Minisat::Lit> assumps; - for (auto idx : extraClauses) + for (auto idx : extraClauses) { if (idx > 0) assumps.push(Minisat::mkLit(minisatVars.at(idx-1))); else assumps.push(Minisat::mkLit(minisatVars.at(-idx-1), true)); +#if EZMINISAT_SIMPSOLVER + if (minisatSolver->isEliminated(minisatVars.at(idx > 0 ? idx-1 : -idx-1))) { + fprintf(stderr, "Assert in %s:%d failed! Missing call to ezsat->freeze(): %s\n", __FILE__, __LINE__, cnfLiteralInfo(idx).c_str()); + abort(); + } +#endif + } - sighandler_t old_alarm_sighandler = NULL; + struct sigaction sig_action; + struct sigaction old_sig_action; int old_alarm_timeout = 0; if (solverTimeout > 0) { + sig_action.sa_handler = alarmHandler; + sigemptyset(&sig_action.sa_mask); + sig_action.sa_flags = SA_RESTART; alarmHandlerThis = this; alarmHandlerTimeout = clock() + solverTimeout*CLOCKS_PER_SEC; old_alarm_timeout = alarm(0); - old_alarm_sighandler = signal(SIGALRM, alarmHandler); + sigaction(SIGALRM, &sig_action, &old_sig_action); alarm(1); } @@ -138,12 +195,18 @@ contradiction: if (alarmHandlerTimeout == 0) solverTimoutStatus = true; alarm(0); - signal(SIGALRM, old_alarm_sighandler); + sigaction(SIGALRM, &old_sig_action, NULL); alarm(old_alarm_timeout); } - if (!foundSolution) + if (!foundSolution) { +#if !EZMINISAT_INCREMENTAL + delete minisatSolver; + minisatSolver = NULL; + minisatVars.clear(); +#endif return false; + } modelValues.clear(); modelValues.resize(modelIdx.size()); @@ -161,6 +224,11 @@ contradiction: modelValues[i] = (value == Minisat::lbool(refvalue)); } +#if !EZMINISAT_INCREMENTAL + delete minisatSolver; + minisatSolver = NULL; + minisatVars.clear(); +#endif return true; } diff --git a/libs/ezsat/ezminisat.h b/libs/ezsat/ezminisat.h index 2919aa2e3..ac9c071c3 100644 --- a/libs/ezsat/ezminisat.h +++ b/libs/ezsat/ezminisat.h @@ -20,6 +20,10 @@ #ifndef EZMINISAT_H #define EZMINISAT_H +#define EZMINISAT_SIMPSOLVER 1 +#define EZMINISAT_VERBOSITY 0 +#define EZMINISAT_INCREMENTAL 1 + #include "ezsat.h" #include <time.h> @@ -28,15 +32,25 @@ // don't force ezSAT users to use minisat headers.. namespace Minisat { class Solver; + class SimpSolver; } class ezMiniSAT : public ezSAT { private: - Minisat::Solver *minisatSolver; +#if EZMINISAT_SIMPSOLVER + typedef Minisat::SimpSolver Solver; +#else + typedef Minisat::Solver Solver; +#endif + Solver *minisatSolver; std::vector<int> minisatVars; bool foundContradiction; +#if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL + std::set<int> cnfFrozenVars; +#endif + static ezMiniSAT *alarmHandlerThis; static clock_t alarmHandlerTimeout; static void alarmHandler(int); @@ -45,6 +59,10 @@ public: ezMiniSAT(); virtual ~ezMiniSAT(); virtual void clear(); +#if EZMINISAT_SIMPSOLVER && EZMINISAT_INCREMENTAL + virtual void freeze(int id); + virtual bool eliminated(int idx); +#endif virtual bool solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions); }; diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc index dccc00555..13ed112ed 100644 --- a/libs/ezsat/ezsat.cc +++ b/libs/ezsat/ezsat.cc @@ -19,21 +19,21 @@ #include "ezsat.h" +#include <cmath> #include <algorithm> +#include <cassert> #include <stdlib.h> -#include <assert.h> const int ezSAT::TRUE = 1; const int ezSAT::FALSE = 2; ezSAT::ezSAT() { - literal("TRUE"); - literal("FALSE"); + flag_keep_cnf = false; + flag_non_incremental = false; - assert(literal("TRUE") == TRUE); - assert(literal("FALSE") == FALSE); + non_incremental_solve_used_up = false; cnfConsumed = false; cnfVariableCount = 0; @@ -41,6 +41,12 @@ ezSAT::ezSAT() solverTimeout = 0; solverTimoutStatus = false; + + literal("TRUE"); + literal("FALSE"); + + assert(literal("TRUE") == TRUE); + assert(literal("FALSE") == FALSE); } ezSAT::~ezSAT() @@ -67,6 +73,20 @@ int ezSAT::literal(const std::string &name) return literalsCache.at(name); } +int ezSAT::frozen_literal() +{ + int id = literal(); + freeze(id); + return id; +} + +int ezSAT::frozen_literal(const std::string &name) +{ + int id = literal(name); + freeze(id); + return id; +} + int ezSAT::expression(OpId op, int a, int b, int c, int d, int e, int f) { std::vector<int> args(6); @@ -342,13 +362,19 @@ void ezSAT::clear() cnfLiteralVariables.clear(); cnfExpressionVariables.clear(); cnfClauses.clear(); - cnfAssumptions.clear(); } -void ezSAT::assume(int id) +void ezSAT::freeze(int) +{ +} + +bool ezSAT::eliminated(int) { - cnfAssumptions.insert(id); + return false; +} +void ezSAT::assume(int id) +{ if (id < 0) { assert(0 < -id && -id <= int(expressions.size())); @@ -462,11 +488,32 @@ int ezSAT::bound(int id) const return 0; } -int ezSAT::bind(int id) +std::string ezSAT::cnfLiteralInfo(int idx) const +{ + for (size_t i = 0; i < cnfLiteralVariables.size(); i++) { + if (cnfLiteralVariables[i] == idx) + return to_string(i+1); + if (cnfLiteralVariables[i] == -idx) + return "NOT " + to_string(i+1); + } + for (size_t i = 0; i < cnfExpressionVariables.size(); i++) { + if (cnfExpressionVariables[i] == idx) + return to_string(-i-1); + if (cnfExpressionVariables[i] == -idx) + return "NOT " + to_string(-i-1); + } + return "<unnamed>"; +} + +int ezSAT::bind(int id, bool auto_freeze) { if (id >= 0) { assert(0 < id && id <= int(literals.size())); cnfLiteralVariables.resize(literals.size()); + if (eliminated(cnfLiteralVariables[id-1])) { + fprintf(stderr, "ezSAT: Missing freeze on literal `%s'.\n", to_string(id).c_str()); + abort(); + } if (cnfLiteralVariables[id-1] == 0) { cnfLiteralVariables[id-1] = ++cnfVariableCount; if (id == TRUE) @@ -480,6 +527,17 @@ int ezSAT::bind(int id) assert(0 < -id && -id <= int(expressions.size())); cnfExpressionVariables.resize(expressions.size()); + if (eliminated(cnfExpressionVariables[-id-1])) + { + cnfExpressionVariables[-id-1] = 0; + + // this will recursively call bind(id). within the recursion + // the cnf is pre-set to 0. an idx is allocated there, then it + // is frozen, then it returns here with the new idx already set. + if (auto_freeze) + freeze(id); + } + if (cnfExpressionVariables[-id-1] == 0) { OpId op; @@ -497,7 +555,7 @@ int ezSAT::bind(int id) newArgs.push_back(OR(AND(args[i], NOT(args[i+1])), AND(NOT(args[i]), args[i+1]))); args.swap(newArgs); } - idx = bind(args.at(0)); + idx = bind(args.at(0), false); goto assign_idx; } @@ -505,17 +563,17 @@ int ezSAT::bind(int id) std::vector<int> invArgs; for (auto arg : args) invArgs.push_back(NOT(arg)); - idx = bind(OR(expression(OpAnd, args), expression(OpAnd, invArgs))); + idx = bind(OR(expression(OpAnd, args), expression(OpAnd, invArgs)), false); goto assign_idx; } if (op == OpITE) { - idx = bind(OR(AND(args[0], args[1]), AND(NOT(args[0]), args[2]))); + idx = bind(OR(AND(args[0], args[1]), AND(NOT(args[0]), args[2])), false); goto assign_idx; } for (int i = 0; i < int(args.size()); i++) - args[i] = bind(args[i]); + args[i] = bind(args[i], false); switch (op) { @@ -535,120 +593,45 @@ int ezSAT::bind(int id) void ezSAT::consumeCnf() { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnfClauses.clear(); } void ezSAT::consumeCnf(std::vector<std::vector<int>> &cnf) { - cnfConsumed = true; + if (mode_keep_cnf()) + cnfClausesBackup.insert(cnfClausesBackup.end(), cnfClauses.begin(), cnfClauses.end()); + else + cnfConsumed = true; cnf.swap(cnfClauses); cnfClauses.clear(); } -static bool test_bit(uint32_t bitmask, int idx) +void ezSAT::getFullCnf(std::vector<std::vector<int>> &full_cnf) const { - if (idx > 0) - return (bitmask & (1 << (+idx-1))) != 0; - else - return (bitmask & (1 << (-idx-1))) == 0; + assert(full_cnf.empty()); + full_cnf.insert(full_cnf.end(), cnfClausesBackup.begin(), cnfClausesBackup.end()); + full_cnf.insert(full_cnf.end(), cnfClauses.begin(), cnfClauses.end()); } -bool ezSAT::solver(const std::vector<int> &modelExpressions, std::vector<bool> &modelValues, const std::vector<int> &assumptions) +void ezSAT::preSolverCallback() { - std::vector<int> extraClauses, modelIdx; - std::vector<int> values(numLiterals()); - - for (auto id : assumptions) - extraClauses.push_back(bind(id)); - for (auto id : modelExpressions) - modelIdx.push_back(bind(id)); - - if (cnfVariableCount > 20) { - fprintf(stderr, "*************************************************************************************\n"); - fprintf(stderr, "ERROR: You are trying to use the builtin solver of ezSAT with more than 20 variables!\n"); - fprintf(stderr, "The builtin solver is a dumb brute force solver and only ment for testing and demo\n"); - fprintf(stderr, "purposes. Use a real SAT solve like MiniSAT (e.g. using the ezMiniSAT class) instead.\n"); - fprintf(stderr, "*************************************************************************************\n"); - abort(); - } - - for (uint32_t bitmask = 0; bitmask < (1 << numCnfVariables()); bitmask++) - { - // printf("%07o:", int(bitmask)); - // for (int i = 2; i < numLiterals(); i++) - // if (bound(i+1)) - // printf(" %s=%d", to_string(i+1).c_str(), test_bit(bitmask, bound(i+1))); - // printf(" |"); - // for (int idx = 1; idx <= numCnfVariables(); idx++) - // printf(" %3d", test_bit(bitmask, idx) ? idx : -idx); - // printf("\n"); - - for (auto idx : extraClauses) - if (!test_bit(bitmask, idx)) - goto next; - - for (auto &clause : cnfClauses) { - for (auto idx : clause) - if (test_bit(bitmask, idx)) - goto next_clause; - // printf("failed clause:"); - // for (auto idx2 : clause) - // printf(" %3d", idx2); - // printf("\n"); - goto next; - next_clause:; - // printf("passed clause:"); - // for (auto idx2 : clause) - // printf(" %3d", idx2); - // printf("\n"); - } - - modelValues.resize(modelIdx.size()); - for (int i = 0; i < int(modelIdx.size()); i++) - modelValues[i] = test_bit(bitmask, modelIdx[i]); - - // validate result using eval() - - values[0] = TRUE, values[1] = FALSE; - for (int i = 2; i < numLiterals(); i++) { - int idx = bound(i+1); - values[i] = idx != 0 ? (test_bit(bitmask, idx) ? TRUE : FALSE) : 0; - } - - for (auto id : cnfAssumptions) { - int result = eval(id, values); - if (result != TRUE) { - printInternalState(stderr); - fprintf(stderr, "Variables:"); - for (int i = 0; i < numLiterals(); i++) - fprintf(stderr, " %s=%s", lookup_literal(i+1).c_str(), values[i] == TRUE ? "TRUE" : values[i] == FALSE ? "FALSE" : "UNDEF"); - fprintf(stderr, "\nValidation of solver results failed: got `%d' (%s) for assumption '%d': %s\n", - result, result == FALSE ? "FALSE" : "UNDEF", id, to_string(id).c_str()); - abort(); - } - // printf("OK: %d -> %d\n", id, result); - } - - for (auto id : assumptions) { - int result = eval(id, values); - if (result != TRUE) { - printInternalState(stderr); - fprintf(stderr, "Variables:"); - for (int i = 0; i < numLiterals(); i++) - fprintf(stderr, " %s=%s", lookup_literal(i+1).c_str(), values[i] == TRUE ? "TRUE" : values[i] == FALSE ? "FALSE" : "UNDEF"); - fprintf(stderr, "\nValidation of solver results failed: got `%d' (%s) for assumption '%d': %s\n", - result, result == FALSE ? "FALSE" : "UNDEF", id, to_string(id).c_str()); - abort(); - } - // printf("OK: %d -> %d\n", id, result); - } - - return true; - next:; - } + assert(!non_incremental_solve_used_up); + if (mode_non_incremental()) + non_incremental_solve_used_up = true; +} - return false; +bool ezSAT::solver(const std::vector<int>&, std::vector<bool>&, const std::vector<int>&) +{ + preSolverCallback(); + fprintf(stderr, "************************************************************************\n"); + fprintf(stderr, "ERROR: You are trying to use the solve() method of the ezSAT base class!\n"); + fprintf(stderr, "Use a dervied class like ezMiniSAT instead.\n"); + fprintf(stderr, "************************************************************************\n"); + abort(); } std::vector<int> ezSAT::vec_const(const std::vector<bool> &bits) @@ -994,6 +977,96 @@ std::vector<int> ezSAT::vec_srl(const std::vector<int> &vec1, int shift) return vec; } +std::vector<int> ezSAT::vec_shift(const std::vector<int> &vec1, int shift, int extend_left, int extend_right) +{ + std::vector<int> vec; + for (int i = 0; i < int(vec1.size()); i++) { + int j = i+shift; + if (j < 0) + vec.push_back(extend_right); + else if (j >= int(vec1.size())) + vec.push_back(extend_left); + else + vec.push_back(vec1[j]); + } + return vec; +} + +static int my_clog2(int x) +{ + int result = 0; + for (x--; x > 0; result++) + x >>= 1; + return result; +} + +std::vector<int> ezSAT::vec_shift_right(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right) +{ + int vec2_bits = std::min(my_clog2(vec1.size()) + (vec2_signed ? 1 : 0), int(vec2.size())); + + std::vector<int> overflow_bits(vec2.begin() + vec2_bits, vec2.end()); + int overflow_left = FALSE, overflow_right = FALSE; + + if (vec2_signed) { + int overflow = FALSE; + for (auto bit : overflow_bits) + overflow = OR(overflow, XOR(bit, vec2[vec2_bits-1])); + overflow_left = AND(overflow, NOT(vec2.back())); + overflow_right = AND(overflow, vec2.back()); + } else + overflow_left = vec_reduce_or(overflow_bits); + + std::vector<int> buffer = vec1; + + if (vec2_signed) + while (buffer.size() < vec1.size() + (1 << vec2_bits)) + buffer.push_back(extend_left); + + std::vector<int> overflow_pattern_left(buffer.size(), extend_left); + std::vector<int> overflow_pattern_right(buffer.size(), extend_right); + + buffer = vec_ite(overflow_left, overflow_pattern_left, buffer); + + if (vec2_signed) + buffer = vec_ite(overflow_right, overflow_pattern_left, buffer); + + for (int i = vec2_bits-1; i >= 0; i--) { + std::vector<int> shifted_buffer; + if (vec2_signed && i == vec2_bits-1) + shifted_buffer = vec_shift(buffer, -(1 << i), extend_left, extend_right); + else + shifted_buffer = vec_shift(buffer, 1 << i, extend_left, extend_right); + buffer = vec_ite(vec2[i], shifted_buffer, buffer); + } + + buffer.resize(vec1.size()); + return buffer; +} + +std::vector<int> ezSAT::vec_shift_left(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right) +{ + // vec2_signed is not implemented in vec_shift_left() yet + assert(vec2_signed == false); + + int vec2_bits = std::min(my_clog2(vec1.size()), int(vec2.size())); + + std::vector<int> overflow_bits(vec2.begin() + vec2_bits, vec2.end()); + int overflow = vec_reduce_or(overflow_bits); + + std::vector<int> buffer = vec1; + std::vector<int> overflow_pattern_right(buffer.size(), extend_right); + buffer = vec_ite(overflow, overflow_pattern_right, buffer); + + for (int i = 0; i < vec2_bits; i++) { + std::vector<int> shifted_buffer; + shifted_buffer = vec_shift(buffer, -(1 << i), extend_left, extend_right); + buffer = vec_ite(vec2[i], shifted_buffer, buffer); + } + + buffer.resize(vec1.size()); + return buffer; +} + void ezSAT::vec_append(std::vector<int> &vec, const std::vector<int> &vec1) const { for (auto bit : vec1) @@ -1124,17 +1197,32 @@ void ezSAT::printDIMACS(FILE *f, bool verbose) const if (cnfExpressionVariables[i] != 0) fprintf(f, "c %*d: %s\n", digits, cnfExpressionVariables[i], to_string(-i-1).c_str()); + if (mode_keep_cnf()) { + fprintf(f, "c\n"); + fprintf(f, "c %d clauses from backup, %d from current buffer\n", + int(cnfClausesBackup.size()), int(cnfClauses.size())); + } + fprintf(f, "c\n"); } - fprintf(f, "p cnf %d %d\n", cnfVariableCount, int(cnfClauses.size())); + std::vector<std::vector<int>> all_clauses; + getFullCnf(all_clauses); + assert(cnfClausesCount == int(all_clauses.size())); + + fprintf(f, "p cnf %d %d\n", cnfVariableCount, cnfClausesCount); int maxClauseLen = 0; - for (auto &clause : cnfClauses) + for (auto &clause : all_clauses) maxClauseLen = std::max(int(clause.size()), maxClauseLen); - for (auto &clause : cnfClauses) { + if (!verbose) + maxClauseLen = std::min(maxClauseLen, 3); + for (auto &clause : all_clauses) { for (auto idx : clause) fprintf(f, " %*d", digits, idx); - fprintf(f, " %*d\n", (digits + 1)*int(maxClauseLen - clause.size()) + digits, 0); + if (maxClauseLen >= int(clause.size())) + fprintf(f, " %*d\n", (digits + 1)*int(maxClauseLen - clause.size()) + digits, 0); + else + fprintf(f, " %*d\n", digits, 0); } } diff --git a/libs/ezsat/ezsat.h b/libs/ezsat/ezsat.h index 547edb93b..c5ef6b0a2 100644 --- a/libs/ezsat/ezsat.h +++ b/libs/ezsat/ezsat.h @@ -48,6 +48,11 @@ public: static const int FALSE; private: + bool flag_keep_cnf; + bool flag_non_incremental; + + bool non_incremental_solve_used_up; + std::map<std::string, int> literalsCache; std::vector<std::string> literals; @@ -57,8 +62,7 @@ private: bool cnfConsumed; int cnfVariableCount, cnfClausesCount; std::vector<int> cnfLiteralVariables, cnfExpressionVariables; - std::vector<std::vector<int>> cnfClauses; - std::set<int> cnfAssumptions; + std::vector<std::vector<int>> cnfClauses, cnfClausesBackup; void add_clause(const std::vector<int> &args); void add_clause(const std::vector<int> &args, bool argsPolarity, int a = 0, int b = 0, int c = 0); @@ -68,6 +72,9 @@ private: int bind_cnf_and(const std::vector<int> &args); int bind_cnf_or(const std::vector<int> &args); +protected: + void preSolverCallback(); + public: int solverTimeout; bool solverTimoutStatus; @@ -75,11 +82,19 @@ public: ezSAT(); virtual ~ezSAT(); + void keep_cnf() { flag_keep_cnf = true; } + void non_incremental() { flag_non_incremental = true; } + + bool mode_keep_cnf() const { return flag_keep_cnf; } + bool mode_non_incremental() const { return flag_non_incremental; } + // manage expressions int value(bool val); int literal(); int literal(const std::string &name); + int frozen_literal(); + int frozen_literal(const std::string &name); int expression(OpId op, int a = 0, int b = 0, int c = 0, int d = 0, int e = 0, int f = 0); int expression(OpId op, const std::vector<int> &args); @@ -141,10 +156,10 @@ public: // manage CNF (usually only accessed by SAT solvers) virtual void clear(); + virtual void freeze(int id); + virtual bool eliminated(int idx); void assume(int id); - int bind(int id); - - const std::set<int> &assumed() const { return cnfAssumptions; } + int bind(int id, bool auto_freeze = true); int bound(int id) const; int numCnfVariables() const { return cnfVariableCount; } @@ -154,6 +169,11 @@ public: void consumeCnf(); void consumeCnf(std::vector<std::vector<int>> &cnf); + // use this function to get the full CNF in keep_cnf mode + void getFullCnf(std::vector<std::vector<int>> &full_cnf) const; + + std::string cnfLiteralInfo(int idx) const; + // simple helpers for build expressions easily struct _V { @@ -165,7 +185,7 @@ public: int get(ezSAT *that) { if (name.empty()) return id; - return that->literal(name); + return that->frozen_literal(name); } }; @@ -217,7 +237,7 @@ public: std::vector<int> vec_iff(const std::vector<int> &vec1, const std::vector<int> &vec2); std::vector<int> vec_ite(const std::vector<int> &vec1, const std::vector<int> &vec2, const std::vector<int> &vec3); - std::vector<int> vec_ite(int sel, const std::vector<int> &vec2, const std::vector<int> &vec3); + std::vector<int> vec_ite(int sel, const std::vector<int> &vec1, const std::vector<int> &vec2); std::vector<int> vec_count(const std::vector<int> &vec, int numBits, bool clip = true); std::vector<int> vec_add(const std::vector<int> &vec1, const std::vector<int> &vec2); @@ -245,6 +265,10 @@ public: std::vector<int> vec_shr(const std::vector<int> &vec1, int shift, bool signExtend = false) { return vec_shl(vec1, -shift, signExtend); } std::vector<int> vec_srr(const std::vector<int> &vec1, int shift) { return vec_srl(vec1, -shift); } + std::vector<int> vec_shift(const std::vector<int> &vec1, int shift, int extend_left, int extend_right); + std::vector<int> vec_shift_right(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right); + std::vector<int> vec_shift_left(const std::vector<int> &vec1, const std::vector<int> &vec2, bool vec2_signed, int extend_left, int extend_right); + void vec_append(std::vector<int> &vec, const std::vector<int> &vec1) const; void vec_append_signed(std::vector<int> &vec, const std::vector<int> &vec1, int64_t value); void vec_append_unsigned(std::vector<int> &vec, const std::vector<int> &vec1, uint64_t value); diff --git a/libs/ezsat/puzzle3d.cc b/libs/ezsat/puzzle3d.cc index 56d293260..aee0044b4 100644 --- a/libs/ezsat/puzzle3d.cc +++ b/libs/ezsat/puzzle3d.cc @@ -260,8 +260,10 @@ int main() std::vector<int> modelExpressions; std::vector<bool> modelValues; - for (auto &it : blockinfo) + for (auto &it : blockinfo) { + ez.freeze(it.first); modelExpressions.push_back(it.first); + } int solution_counter = 0; while (1) diff --git a/libs/ezsat/testbench.cc b/libs/ezsat/testbench.cc index cc0fe5734..d20258c37 100644 --- a/libs/ezsat/testbench.cc +++ b/libs/ezsat/testbench.cc @@ -38,11 +38,6 @@ struct xorshift128 { bool test(ezSAT &sat, int assumption = 0) { - for (auto id : sat.assumed()) - printf("%s\n", sat.to_string(id).c_str()); - if (assumption) - printf("%s\n", sat.to_string(assumption).c_str()); - std::vector<int> modelExpressions; std::vector<bool> modelValues; @@ -68,7 +63,8 @@ void test_simple() { printf("==== %s ====\n\n", __PRETTY_FUNCTION__); - ezSAT sat; + ezMiniSAT sat; + sat.non_incremental(); sat.assume(sat.OR("A", "B")); sat.assume(sat.NOT(sat.AND("A", "B"))); test(sat); @@ -76,89 +72,6 @@ void test_simple() // ------------------------------------------------------------------------------------------------------------ -void test_basic_operators(ezSAT &sat, xorshift128 &rng, int iter, bool buildTrees, bool buildClusters, std::vector<bool> &log) -{ - int vars[6] = { - sat.VAR("A"), sat.VAR("B"), sat.VAR("C"), - sat.NOT("A"), sat.NOT("B"), sat.NOT("C") - }; - for (int i = 0; i < iter; i++) { - int assumption = 0, op = rng() % 6, to = rng() % 6; - int a = vars[rng() % 6], b = vars[rng() % 6], c = vars[rng() % 6]; - // printf("--> %d %d:%s %d:%s %d:%s\n", op, a, sat.to_string(a).c_str(), b, sat.to_string(b).c_str(), c, sat.to_string(c).c_str()); - switch (op) - { - case 0: - assumption = sat.NOT(a); - break; - case 1: - assumption = sat.AND(a, b); - break; - case 2: - assumption = sat.OR(a, b); - break; - case 3: - assumption = sat.XOR(a, b); - break; - case 4: - assumption = sat.IFF(a, b); - break; - case 5: - assumption = sat.ITE(a, b, c); - break; - } - // printf(" --> %d:%s\n", to, sat.to_string(assumption).c_str()); - if (buildTrees) - vars[to] = assumption; - if (!buildClusters) - sat.clear(); - sat.assume(assumption); - if (sat.numCnfVariables() < 15) { - printf("%d:\n", int(log.size())); - log.push_back(test(sat)); - } else { - // printf("** skipping large problem **\n"); - } - } -} - -void test_basic_operators(ezSAT &sat, std::vector<bool> &log) -{ - printf("-- %s --\n\n", __PRETTY_FUNCTION__); - - xorshift128 rng; - test_basic_operators(sat, rng, 1000, false, false, log); - for (int i = 0; i < 100; i++) - test_basic_operators(sat, rng, 10, true, false, log); - for (int i = 0; i < 100; i++) - test_basic_operators(sat, rng, 10, false, true, log); -} - -void test_basic_operators() -{ - printf("==== %s ====\n\n", __PRETTY_FUNCTION__); - - ezSAT sat; - ezMiniSAT miniSat; - std::vector<bool> logSat, logMiniSat; - - test_basic_operators(sat, logSat); - test_basic_operators(miniSat, logMiniSat); - - if (logSat != logMiniSat) { - printf("Differences between logSat and logMiniSat:"); - for (int i = 0; i < int(std::max(logSat.size(), logMiniSat.size())); i++) - if (i >= int(logSat.size()) || i >= int(logMiniSat.size()) || logSat[i] != logMiniSat[i]) - printf(" %d", i); - printf("\n"); - abort(); - } else { - printf("Completed %d tests with identical results with ezSAT and ezMiniSAT.\n\n", int(logSat.size())); - } -} - -// ------------------------------------------------------------------------------------------------------------ - void test_xorshift32_try(ezSAT &sat, uint32_t input_pattern) { uint32_t output_pattern = input_pattern; @@ -209,6 +122,8 @@ void test_xorshift32() printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT sat; + sat.keep_cnf(); + xorshift128 rng; std::vector<int> bits = sat.vec_var("i", 32); @@ -225,6 +140,9 @@ void test_xorshift32() test_xorshift32_try(sat, rng()); test_xorshift32_try(sat, rng()); test_xorshift32_try(sat, rng()); + + sat.printDIMACS(stdout, true); + printf("\n"); } // ------------------------------------------------------------------------------------------------------------ @@ -243,7 +161,7 @@ void check(const char *expr1_str, bool expr1, const char *expr2_str, bool expr2) void test_signed(int8_t a, int8_t b, int8_t c) { - ezSAT sat; + ezMiniSAT sat; std::vector<int> av = sat.vec_const_signed(a, 8); std::vector<int> bv = sat.vec_const_signed(b, 8); @@ -262,7 +180,7 @@ void test_signed(int8_t a, int8_t b, int8_t c) void test_unsigned(uint8_t a, uint8_t b, uint8_t c) { - ezSAT sat; + ezMiniSAT sat; if (b < c) b ^= c, c ^= b, b ^= c; @@ -284,7 +202,7 @@ void test_unsigned(uint8_t a, uint8_t b, uint8_t c) void test_count(uint32_t x) { - ezSAT sat; + ezMiniSAT sat; int count = 0; for (int i = 0; i < 32; i++) @@ -338,10 +256,10 @@ void test_onehot() printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT ez; - int a = ez.literal("a"); - int b = ez.literal("b"); - int c = ez.literal("c"); - int d = ez.literal("d"); + int a = ez.frozen_literal("a"); + int b = ez.frozen_literal("b"); + int c = ez.frozen_literal("c"); + int d = ez.frozen_literal("d"); std::vector<int> abcd; abcd.push_back(a); @@ -392,10 +310,10 @@ void test_manyhot() printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT ez; - int a = ez.literal("a"); - int b = ez.literal("b"); - int c = ez.literal("c"); - int d = ez.literal("d"); + int a = ez.frozen_literal("a"); + int b = ez.frozen_literal("b"); + int c = ez.frozen_literal("c"); + int d = ez.frozen_literal("d"); std::vector<int> abcd; abcd.push_back(a); @@ -446,13 +364,13 @@ void test_ordered() printf("==== %s ====\n\n", __PRETTY_FUNCTION__); ezMiniSAT ez; - int a = ez.literal("a"); - int b = ez.literal("b"); - int c = ez.literal("c"); + int a = ez.frozen_literal("a"); + int b = ez.frozen_literal("b"); + int c = ez.frozen_literal("c"); - int x = ez.literal("x"); - int y = ez.literal("y"); - int z = ez.literal("z"); + int x = ez.frozen_literal("x"); + int y = ez.frozen_literal("y"); + int z = ez.frozen_literal("z"); std::vector<int> abc; abc.push_back(a); @@ -512,7 +430,6 @@ void test_ordered() int main() { test_simple(); - test_basic_operators(); test_xorshift32(); test_arith(); test_onehot(); |
