AllDiff.cpp 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  /*
   * AllDiff.cpp
   * @brief General "all-different" constraint
   * @date Feb 6, 2012
   * @author Frank Dellaert
   */
  7. #include <gtsam/base/Testable.h>
  8. #include <gtsam_unstable/discrete/AllDiff.h>
  9. #include <gtsam_unstable/discrete/Domain.h>
  10. #include <boost/make_shared.hpp>
  11. namespace gtsam {
  12. /* ************************************************************************* */
  13. AllDiff::AllDiff(const DiscreteKeys& dkeys) : Constraint(dkeys.indices()) {
  14. for (const DiscreteKey& dkey : dkeys) cardinalities_.insert(dkey);
  15. }
  16. /* ************************************************************************* */
  17. void AllDiff::print(const std::string& s, const KeyFormatter& formatter) const {
  18. std::cout << s << "AllDiff on ";
  19. for (Key dkey : keys_) std::cout << formatter(dkey) << " ";
  20. std::cout << std::endl;
  21. }
  22. /* ************************************************************************* */
  23. double AllDiff::operator()(const Values& values) const {
  24. std::set<size_t> taken; // record values taken by keys
  25. for (Key dkey : keys_) {
  26. size_t value = values.at(dkey); // get the value for that key
  27. if (taken.count(value)) return 0.0; // check if value alreday taken
  28. taken.insert(value); // if not, record it as taken and keep checking
  29. }
  30. return 1.0;
  31. }
  32. /* ************************************************************************* */
  33. DecisionTreeFactor AllDiff::toDecisionTreeFactor() const {
  34. // We will do this by converting the allDif into many BinaryAllDiff
  35. // constraints
  36. DecisionTreeFactor converted;
  37. size_t nrKeys = keys_.size();
  38. for (size_t i1 = 0; i1 < nrKeys; i1++)
  39. for (size_t i2 = i1 + 1; i2 < nrKeys; i2++) {
  40. BinaryAllDiff binary12(discreteKey(i1), discreteKey(i2));
  41. converted = converted * binary12.toDecisionTreeFactor();
  42. }
  43. return converted;
  44. }
  45. /* ************************************************************************* */
  46. DecisionTreeFactor AllDiff::operator*(const DecisionTreeFactor& f) const {
  47. // TODO: can we do this more efficiently?
  48. return toDecisionTreeFactor() * f;
  49. }
  50. /* ************************************************************************* */
  51. bool AllDiff::ensureArcConsistency(size_t j,
  52. std::vector<Domain>& domains) const {
  53. // Though strictly not part of allDiff, we check for
  54. // a value in domains[j] that does not occur in any other connected domain.
  55. // If found, we make this a singleton...
  56. // TODO: make a new constraint where this really is true
  57. Domain& Dj = domains[j];
  58. if (Dj.checkAllDiff(keys_, domains)) return true;
  59. // Check all other domains for singletons and erase corresponding values
  60. // This is the same as arc-consistency on the equivalent binary constraints
  61. bool changed = false;
  62. for (Key k : keys_)
  63. if (k != j) {
  64. const Domain& Dk = domains[k];
  65. if (Dk.isSingleton()) { // check if singleton
  66. size_t value = Dk.firstValue();
  67. if (Dj.contains(value)) {
  68. Dj.erase(value); // erase value if true
  69. changed = true;
  70. }
  71. }
  72. }
  73. return changed;
  74. }
  75. /* ************************************************************************* */
  76. Constraint::shared_ptr AllDiff::partiallyApply(const Values& values) const {
  77. DiscreteKeys newKeys;
  78. // loop over keys and add them only if they do not appear in values
  79. for (Key k : keys_)
  80. if (values.find(k) == values.end()) {
  81. newKeys.push_back(DiscreteKey(k, cardinalities_.at(k)));
  82. }
  83. return boost::make_shared<AllDiff>(newKeys);
  84. }
  85. /* ************************************************************************* */
  86. Constraint::shared_ptr AllDiff::partiallyApply(
  87. const std::vector<Domain>& domains) const {
  88. DiscreteFactor::Values known;
  89. for (Key k : keys_) {
  90. const Domain& Dk = domains[k];
  91. if (Dk.isSingleton()) known[k] = Dk.firstValue();
  92. }
  93. return partiallyApply(known);
  94. }
  95. /* ************************************************************************* */
  96. } // namespace gtsam