UGM_small.cpp (2.7 KB)
  1. /* ----------------------------------------------------------------------------
  2. * GTSAM Copyright 2010, Georgia Tech Research Corporation,
  3. * Atlanta, Georgia 30332-0415
  4. * All Rights Reserved
  5. * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
  6. * See LICENSE for the license information
  7. * -------------------------------------------------------------------------- */
  8. /**
  9. * @file UGM_small.cpp
  10. * @brief UGM (undirected graphical model) examples: small
  11. * @author Frank Dellaert
  12. *
  13. * See http://www.di.ens.fr/~mschmidt/Software/UGM/small.html
  14. */
  15. #include <gtsam/base/Vector.h>
  16. #include <gtsam/discrete/DiscreteFactorGraph.h>
  17. #include <gtsam/discrete/DiscreteMarginals.h>
  18. using namespace std;
  19. using namespace gtsam;
  20. int main(int argc, char** argv) {
  21. // We will assume 2-state variables, where, to conform to the "small" example
  22. // we have 0 == "right answer" and 1 == "wrong answer"
  23. size_t nrStates = 2;
  24. // define variables
  25. DiscreteKey Cathy(1, nrStates), Heather(2, nrStates), Mark(3, nrStates),
  26. Allison(4, nrStates);
  27. // create graph
  28. DiscreteFactorGraph graph;
  29. // add node potentials
  30. graph.add(Cathy, "1 3");
  31. graph.add(Heather, "9 1");
  32. graph.add(Mark, "1 3");
  33. graph.add(Allison, "9 1");
  34. // add edge potentials
  35. graph.add(Cathy & Heather, "2 1 1 2");
  36. graph.add(Heather & Mark, "2 1 1 2");
  37. graph.add(Mark & Allison, "2 1 1 2");
  38. // Print the UGM distribution
  39. cout << "\nUGM distribution:" << endl;
  40. vector<DiscreteFactor::Values> allPosbValues = cartesianProduct(
  41. Cathy & Heather & Mark & Allison);
  42. for (size_t i = 0; i < allPosbValues.size(); ++i) {
  43. DiscreteFactor::Values values = allPosbValues[i];
  44. double prodPot = graph(values);
  45. cout << values[Cathy.first] << " " << values[Heather.first] << " "
  46. << values[Mark.first] << " " << values[Allison.first] << " :\t"
  47. << prodPot << "\t" << prodPot / 3790 << endl;
  48. }
  49. // "Decoding", i.e., configuration with largest value (MPE)
  50. // We use sequential variable elimination
  51. DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
  52. DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
  53. optimalDecoding->print("\noptimalDecoding");
  54. // "Inference" Computing marginals
  55. cout << "\nComputing Node Marginals .." << endl;
  56. DiscreteMarginals marginals(graph);
  57. Vector margProbs = marginals.marginalProbabilities(Cathy);
  58. print(margProbs, "Cathy's Node Marginal:");
  59. margProbs = marginals.marginalProbabilities(Heather);
  60. print(margProbs, "Heather's Node Marginal");
  61. margProbs = marginals.marginalProbabilities(Mark);
  62. print(margProbs, "Mark's Node Marginal");
  63. margProbs = marginals.marginalProbabilities(Allison);
  64. print(margProbs, "Allison's Node Marginal");
  65. return 0;
  66. }