aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--libs/ezsat/ezsat.cc42
1 files changed, 28 insertions, 14 deletions
diff --git a/libs/ezsat/ezsat.cc b/libs/ezsat/ezsat.cc
index 47fdb8efe..8c666ca1f 100644
--- a/libs/ezsat/ezsat.cc
+++ b/libs/ezsat/ezsat.cc
@@ -1371,20 +1371,34 @@ int ezSAT::onehot(const std::vector<int> &vec, bool max_only)
if (max_only == false)
formula.push_back(expression(OpOr, vec));
- // create binary vector
- int num_bits = clog2(vec.size());
- std::vector<int> bits;
- for (int k = 0; k < num_bits; k++)
- bits.push_back(literal());
-
- // add at-most-one clauses using binary encoding
- for (size_t i = 0; i < vec.size(); i++)
- for (int k = 0; k < num_bits; k++) {
- std::vector<int> clause;
- clause.push_back(NOT(vec[i]));
- clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k]));
- formula.push_back(expression(OpOr, clause));
- }
+ if (vec.size() < 8)
+ {
+ // fall-back to simple O(n^2) solution for small cases
+ for (size_t i = 0; i < vec.size(); i++)
+ for (size_t j = i+1; j < vec.size(); j++) {
+ std::vector<int> clause;
+ clause.push_back(NOT(vec[i]));
+ clause.push_back(NOT(vec[j]));
+ formula.push_back(expression(OpOr, clause));
+ }
+ }
+ else
+ {
+ // create binary vector
+ int num_bits = clog2(vec.size());
+ std::vector<int> bits;
+ for (int k = 0; k < num_bits; k++)
+ bits.push_back(literal());
+
+ // add at-most-one clauses using binary encoding
+ for (size_t i = 0; i < vec.size(); i++)
+ for (int k = 0; k < num_bits; k++) {
+ std::vector<int> clause;
+ clause.push_back(NOT(vec[i]));
+ clause.push_back((i & (1 << k)) != 0 ? bits[k] : NOT(bits[k]));
+ formula.push_back(expression(OpOr, clause));
+ }
+ }
return expression(OpAnd, formula);
}