HMMExample.cpp 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. /* ----------------------------------------------------------------------------
  2. * GTSAM Copyright 2010-2020, Georgia Tech Research Corporation,
  3. * Atlanta, Georgia 30332-0415
  4. * All Rights Reserved
  5. * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
  6. * See LICENSE for the license information
  7. * -------------------------------------------------------------------------- */
  8. /**
  9. * @file HMMExample.cpp
  10. * @brief Hidden Markov Model example, discrete.
  11. * @author Frank Dellaert
  12. * @date July 12, 2020
  13. */
  14. #include <gtsam/discrete/DiscreteFactorGraph.h>
  15. #include <gtsam/discrete/DiscreteMarginals.h>
  16. #include <gtsam/inference/BayesNet.h>
  17. #include <iomanip>
  18. #include <sstream>
  19. using namespace std;
  20. using namespace gtsam;
  21. int main(int argc, char **argv) {
  22. const int nrNodes = 4;
  23. const size_t nrStates = 3;
  24. // Define variables as well as ordering
  25. Ordering ordering;
  26. vector<DiscreteKey> keys;
  27. for (int k = 0; k < nrNodes; k++) {
  28. DiscreteKey key_i(k, nrStates);
  29. keys.push_back(key_i);
  30. ordering.emplace_back(k);
  31. }
  32. // Create HMM as a DiscreteBayesNet
  33. DiscreteBayesNet hmm;
  34. // Define backbone
  35. const string transition = "8/1/1 1/8/1 1/1/8";
  36. for (int k = 1; k < nrNodes; k++) {
  37. hmm.add(keys[k] | keys[k - 1] = transition);
  38. }
  39. // Add some measurements, not needed for all time steps!
  40. hmm.add(keys[0] % "7/2/1");
  41. hmm.add(keys[1] % "1/9/0");
  42. hmm.add(keys.back() % "5/4/1");
  43. // print
  44. hmm.print("HMM");
  45. // Convert to factor graph
  46. DiscreteFactorGraph factorGraph(hmm);
  47. // Create solver and eliminate
  48. // This will create a DAG ordered with arrow of time reversed
  49. DiscreteBayesNet::shared_ptr chordal =
  50. factorGraph.eliminateSequential(ordering);
  51. chordal->print("Eliminated");
  52. // solve
  53. DiscreteFactor::sharedValues mpe = chordal->optimize();
  54. GTSAM_PRINT(*mpe);
  55. // We can also sample from it
  56. cout << "\n10 samples:" << endl;
  57. for (size_t k = 0; k < 10; k++) {
  58. DiscreteFactor::sharedValues sample = chordal->sample();
  59. GTSAM_PRINT(*sample);
  60. }
  61. // Or compute the marginals. This re-eliminates the FG into a Bayes tree
  62. cout << "\nComputing Node Marginals .." << endl;
  63. DiscreteMarginals marginals(factorGraph);
  64. for (int k = 0; k < nrNodes; k++) {
  65. Vector margProbs = marginals.marginalProbabilities(keys[k]);
  66. stringstream ss;
  67. ss << "marginal " << k;
  68. print(margProbs, ss.str());
  69. }
  70. // TODO(frank): put in the glue to have DiscreteMarginals produce *arbitrary*
  71. // joints efficiently, by the Bayes tree shortcut magic. All the code is there
  72. // but it's not yet connected.
  73. return 0;
  74. }