tree_engine.cpp

#include "opencv2/ml.hpp"
#include "opencv2/core.hpp"
#include "opencv2/core/utility.hpp"
#include <stdio.h>
#include <string>
#include <map>
#include <iostream>

using namespace cv;
using namespace cv::ml;
static void help(char** argv)
{
    printf(
        "\nThis sample demonstrates how to use different decision trees and forests including boosting and random trees.\n"
        "Usage:\n\t%s [-r=<response_column>] [-ts=type_spec] <csv filename>\n"
        "where -r=<response_column> specifies the 0-based index of the response (0 by default)\n"
        "-ts= specifies the var type spec in the form ord[n1,n2-n3,n4-n5,...]cat[m1-m2,m3,m4-m5,...]\n"
        "<csv filename> is the name of the training data file in comma-separated value format\n\n", argv[0]);
}
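
// Train the given model on the train split of "data" and report the error returned by
// StatModel::calcError() on the train subset (test=false) and on the test subset (test=true).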
static void train_and_print_errs(Ptr<StatModel> model, const Ptr<TrainData>& data)
{
    bool ok = model->train(data);
    if( !ok )
    {
        printf("Training failed\n");
    }
    else
    {
        printf( "train error: %f\n", model->calcError(data, false, noArray()) );
        printf( "test error: %f\n\n", model->calcError(data, true, noArray()) );
    }
}
int main(int argc, char** argv)
{
    cv::CommandLineParser parser(argc, argv, "{ help h | | }{r | 0 | }{ts | | }{@input | | }");
    if (parser.has("help"))
    {
        help(argv);
        return 0;
    }
    std::string filename = parser.get<std::string>("@input");
    int response_idx = parser.get<int>("r");
    std::string typespec = parser.get<std::string>("ts");
    if( filename.empty() || !parser.check() )
    {
        parser.printErrors();
        help(argv);
        return 0;
    }
    printf("\nReading in %s...\n\n", filename.c_str());
    const double train_test_split_ratio = 0.5;
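
    // TrainData::loadFromCSV(filename, headerLineCount, responseStartIdx, responseEndIdx, varTypeSpec):
    // no header lines are skipped (0), and the response occupies the single column response_idx.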
    Ptr<TrainData> data = TrainData::loadFromCSV(filename, 0, response_idx, response_idx+1, typespec);
    if( data.empty() )
    {
        printf("ERROR: File %s cannot be read\n", filename.c_str());
        return 0;
    }
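
    // Hold out half of the samples as a test set; calcError(data, true, ...) in
    // train_and_print_errs() is evaluated on this held-out subset.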
    data->setTrainTestSplitRatio(train_test_split_ratio);
    std::cout << "Test/Train: " << data->getNTestSamples() << "/" << data->getNTrainSamples() << std::endl;
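
    // Single decision tree. setCVFolds(0) turns off the built-in cross-validation pruning;
    // setRegressionAccuracy(0) sets the regression termination threshold to zero.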
    printf("======DTREE=====\n");
    Ptr<DTrees> dtree = DTrees::create();
    dtree->setMaxDepth(10);
    dtree->setMinSampleCount(2);
    dtree->setRegressionAccuracy(0);
    dtree->setUseSurrogates(false);
    dtree->setMaxCategories(16);
    dtree->setCVFolds(0);
    dtree->setUse1SERule(false);
    dtree->setTruncatePrunedTree(false);
    dtree->setPriors(Mat());
    train_and_print_errs(dtree, data);
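
    // Boosting in OpenCV ML supports only two-class classification, so run it only when
    // the data is a regression problem (no class labels) or a 2-class problem.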
    if( (int)data->getClassLabels().total() <= 2 ) // regression or 2-class classification problem
    {
        printf("======BOOST=====\n");
        Ptr<Boost> boost = Boost::create();
        boost->setBoostType(Boost::GENTLE);
        boost->setWeakCount(100);
        boost->setWeightTrimRate(0.95);
        boost->setMaxDepth(2);
        boost->setUseSurrogates(false);
        boost->setPriors(Mat());
        train_and_print_errs(boost, data);
    }
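
    // Random forest. setActiveVarCount(0) uses the default feature-subset size at each split
    // (square root of the number of variables); the termination criteria cap the forest at 100 trees.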
    printf("======RTREES=====\n");
    Ptr<RTrees> rtrees = RTrees::create();
    rtrees->setMaxDepth(10);
    rtrees->setMinSampleCount(2);
    rtrees->setRegressionAccuracy(0);
    rtrees->setUseSurrogates(false);
    rtrees->setMaxCategories(16);
    rtrees->setPriors(Mat());
    rtrees->setCalculateVarImportance(true);
    rtrees->setActiveVarCount(0);
    rtrees->setTermCriteria(TermCriteria(TermCriteria::MAX_ITER, 100, 0));
    train_and_print_errs(rtrees, data);
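
    // Run the trained forest on all samples and print the per-variable importance scores
    // gathered during training (enabled above with setCalculateVarImportance(true)).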
    cv::Mat ref_labels = data->getClassLabels();
    cv::Mat test_data = data->getTestSampleIdx();
    cv::Mat predict_labels;
    rtrees->predict(data->getSamples(), predict_labels);

    cv::Mat variable_importance = rtrees->getVarImportance();
    std::cout << "Estimated variable importance" << std::endl;
    for (int i = 0; i < variable_importance.rows; i++) {
        std::cout << "Variable " << i << ": " << variable_importance.at<float>(i, 0) << std::endl;
    }
    return 0;
}