DiscreteBayesNetExample.cpp 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. /* ----------------------------------------------------------------------------
  2. * GTSAM Copyright 2010, Georgia Tech Research Corporation,
  3. * Atlanta, Georgia 30332-0415
  4. * All Rights Reserved
  5. * Authors: Frank Dellaert, et al. (see THANKS for the full author list)
  6. * See LICENSE for the license information
  7. * -------------------------------------------------------------------------- */
  8. /**
  9. * @file DiscreteBayesNetExample.cpp
  10. * @brief Discrete Bayes Net example with famous Asia Bayes Network
  11. * @author Frank Dellaert
  12. * @date JULY 10, 2020
  13. */
  14. #include <gtsam/discrete/DiscreteFactorGraph.h>
  15. #include <gtsam/discrete/DiscreteMarginals.h>
  16. #include <gtsam/inference/BayesNet.h>
  17. #include <iomanip>
  18. using namespace std;
  19. using namespace gtsam;
  20. int main(int argc, char **argv) {
  21. DiscreteBayesNet asia;
  22. DiscreteKey Asia(0, 2), Smoking(4, 2), Tuberculosis(3, 2), LungCancer(6, 2),
  23. Bronchitis(7, 2), Either(5, 2), XRay(2, 2), Dyspnea(1, 2);
  24. asia.add(Asia % "99/1");
  25. asia.add(Smoking % "50/50");
  26. asia.add(Tuberculosis | Asia = "99/1 95/5");
  27. asia.add(LungCancer | Smoking = "99/1 90/10");
  28. asia.add(Bronchitis | Smoking = "70/30 40/60");
  29. asia.add((Either | Tuberculosis, LungCancer) = "F T T T");
  30. asia.add(XRay | Either = "95/5 2/98");
  31. asia.add((Dyspnea | Either, Bronchitis) = "9/1 2/8 3/7 1/9");
  32. // print
  33. vector<string> pretty = {"Asia", "Dyspnea", "XRay", "Tuberculosis",
  34. "Smoking", "Either", "LungCancer", "Bronchitis"};
  35. auto formatter = [pretty](Key key) { return pretty[key]; };
  36. asia.print("Asia", formatter);
  37. // Convert to factor graph
  38. DiscreteFactorGraph fg(asia);
  39. // Create solver and eliminate
  40. Ordering ordering;
  41. ordering += Key(0), Key(1), Key(2), Key(3), Key(4), Key(5), Key(6), Key(7);
  42. DiscreteBayesNet::shared_ptr chordal = fg.eliminateSequential(ordering);
  43. // solve
  44. DiscreteFactor::sharedValues mpe = chordal->optimize();
  45. GTSAM_PRINT(*mpe);
  46. // We can also build a Bayes tree (directed junction tree).
  47. // The elimination order above will do fine:
  48. auto bayesTree = fg.eliminateMultifrontal(ordering);
  49. bayesTree->print("bayesTree", formatter);
  50. // add evidence, we were in Asia and we have dyspnea
  51. fg.add(Asia, "0 1");
  52. fg.add(Dyspnea, "0 1");
  53. // solve again, now with evidence
  54. DiscreteBayesNet::shared_ptr chordal2 = fg.eliminateSequential(ordering);
  55. DiscreteFactor::sharedValues mpe2 = chordal2->optimize();
  56. GTSAM_PRINT(*mpe2);
  57. // We can also sample from it
  58. cout << "\n10 samples:" << endl;
  59. for (size_t i = 0; i < 10; i++) {
  60. DiscreteFactor::sharedValues sample = chordal2->sample();
  61. GTSAM_PRINT(*sample);
  62. }
  63. return 0;
  64. }