UGM_chain.cpp 2.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
/* ----------------------------------------------------------------------------
 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
 * See LICENSE for the license information
 * -------------------------------------------------------------------------- */
/**
 * @file UGM_chain.cpp
 * @brief UGM (undirected graphical model) examples: chain
 * @author Frank Dellaert
 * @author Abhijit Kundu
 *
 * See http://www.di.ens.fr/~mschmidt/Software/UGM/chain.html
 * for more explanation. This code demos the same example using GTSAM.
 */
  17. #include <gtsam/base/timing.h>
  18. #include <gtsam/discrete/DiscreteFactorGraph.h>
  19. #include <gtsam/discrete/DiscreteMarginals.h>
  20. #include <iomanip>
  21. using namespace std;
  22. using namespace gtsam;
  23. int main(int argc, char** argv) {
  24. // Set Number of Nodes in the Graph
  25. const int nrNodes = 60;
  26. // Each node takes 1 of 7 possible states denoted by 0-6 in following order:
  27. // ["VideoGames" "Industry" "GradSchool" "VideoGames(with PhD)"
  28. // "Industry(with PhD)" "Academia" "Deceased"]
  29. const size_t nrStates = 7;
  30. // define variables
  31. vector<DiscreteKey> nodes;
  32. for (int i = 0; i < nrNodes; i++) {
  33. DiscreteKey dk(i, nrStates);
  34. nodes.push_back(dk);
  35. }
  36. // create graph
  37. DiscreteFactorGraph graph;
  38. // add node potentials
  39. graph.add(nodes[0], ".3 .6 .1 0 0 0 0");
  40. for (int i = 1; i < nrNodes; i++) graph.add(nodes[i], "1 1 1 1 1 1 1");
  41. const std::string edgePotential =
  42. ".08 .9 .01 0 0 0 .01 "
  43. ".03 .95 .01 0 0 0 .01 "
  44. ".06 .06 .75 .05 .05 .02 .01 "
  45. "0 0 0 .3 .6 .09 .01 "
  46. "0 0 0 .02 .95 .02 .01 "
  47. "0 0 0 .01 .01 .97 .01 "
  48. "0 0 0 0 0 0 1";
  49. // add edge potentials
  50. for (int i = 0; i < nrNodes - 1; i++)
  51. graph.add(nodes[i] & nodes[i + 1], edgePotential);
  52. cout << "Created Factor Graph with " << nrNodes << " variable nodes and "
  53. << graph.size() << " factors (Unary+Edge).";
  54. // "Decoding", i.e., configuration with largest value
  55. // We use sequential variable elimination
  56. DiscreteBayesNet::shared_ptr chordal = graph.eliminateSequential();
  57. DiscreteFactor::sharedValues optimalDecoding = chordal->optimize();
  58. optimalDecoding->print("\nMost Probable Explanation (optimalDecoding)\n");
  59. // "Inference" Computing marginals for each node
  60. // Here we'll make use of DiscreteMarginals class, which makes use of
  61. // bayes-tree based shortcut evaluation of marginals
  62. DiscreteMarginals marginals(graph);
  63. cout << "\nComputing Node Marginals ..(BayesTree based)" << endl;
  64. gttic_(Multifrontal);
  65. for (vector<DiscreteKey>::iterator it = nodes.begin(); it != nodes.end();
  66. ++it) {
  67. // Compute the marginal
  68. Vector margProbs = marginals.marginalProbabilities(*it);
  69. // Print the marginals
  70. cout << "Node#" << setw(4) << it->first << " : ";
  71. print(margProbs);
  72. }
  73. gttoc_(Multifrontal);
  74. tictoc_print_();
  75. return 0;
  76. }