diff options
-rw-r--r-- | libs/ezsat/ezsat.cc | 42 |
1 files changed, 28 insertions, 14 deletions
diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc index 47fdb8efe..8c666ca1f 100644 --- a/libs/ezsat/ezsat.cc +++ b/libs/ezsat/ezsat.cc @@ -1371,20 +1371,34 @@ int ezSAT::onehot(const std::vector<int> &vec, bool max_only) if (max_only == false) formula.push_back(expression(OpOr, vec)); - // create binary vector - int num_bits = clog2(vec.size()); - std::vector<int> bits; - for (int k = 0; k < num_bits; k++) - bits.push_back(literal()); - - // add at-most-one clauses using binary encoding - for (size_t i = 0; i < vec.size(); i++) - for (int k = 0; k < num_bits; k++) { - std::vector<int> clause; - clause.push_back(NOT(vec[i])); - clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k])); - formula.push_back(expression(OpOr, clause)); - } + if (vec.size() < 8) + { + // fall-back to simple O(n^2) solution for small cases + for (size_t i = 0; i < vec.size(); i++) + for (size_t j = i+1; j < vec.size(); j++) { + std::vector<int> clause; + clause.push_back(NOT(vec[i])); + clause.push_back(NOT(vec[j])); + formula.push_back(expression(OpOr, clause)); + } + } + else + { + // create binary vector + int num_bits = clog2(vec.size()); + std::vector<int> bits; + for (int k = 0; k < num_bits; k++) + bits.push_back(literal()); + + // add at-most-one clauses using binary encoding + for (size_t i = 0; i < vec.size(); i++) + for (int k = 0; k < num_bits; k++) { + std::vector<int> clause; + clause.push_back(NOT(vec[i])); + clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k])); + formula.push_back(expression(OpOr, clause)); + } + } return expression(OpAnd, formula); } |