boost.h 3.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. #ifndef _OPENCV_BOOST_H_
  2. #define _OPENCV_BOOST_H_
  3. #include "traincascade_features.h"
  4. #include "old_ml.hpp"
  5. struct CvCascadeBoostParams : CvBoostParams
  6. {
  7. float minHitRate;
  8. float maxFalseAlarm;
  9. CvCascadeBoostParams();
  10. CvCascadeBoostParams( int _boostType, float _minHitRate, float _maxFalseAlarm,
  11. double _weightTrimRate, int _maxDepth, int _maxWeakCount );
  12. virtual ~CvCascadeBoostParams() {}
  13. void write( cv::FileStorage &fs ) const;
  14. bool read( const cv::FileNode &node );
  15. virtual void printDefaults() const;
  16. virtual void printAttrs() const;
  17. virtual bool scanAttr( const std::string prmName, const std::string val);
  18. };
  19. struct CvCascadeBoostTrainData : CvDTreeTrainData
  20. {
  21. CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
  22. const CvDTreeParams& _params );
  23. CvCascadeBoostTrainData( const CvFeatureEvaluator* _featureEvaluator,
  24. int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
  25. const CvDTreeParams& _params = CvDTreeParams() );
  26. virtual void setData( const CvFeatureEvaluator* _featureEvaluator,
  27. int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
  28. const CvDTreeParams& _params=CvDTreeParams() );
  29. void precalculate();
  30. virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
  31. virtual const int* get_class_labels( CvDTreeNode* n, int* labelsBuf );
  32. virtual const int* get_cv_labels( CvDTreeNode* n, int* labelsBuf);
  33. virtual const int* get_sample_indices( CvDTreeNode* n, int* indicesBuf );
  34. virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ordValuesBuf, int* sortedIndicesBuf,
  35. const float** ordValues, const int** sortedIndices, int* sampleIndicesBuf );
  36. virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* catValuesBuf );
  37. virtual float getVarValue( int vi, int si );
  38. virtual void free_train_data();
  39. const CvFeatureEvaluator* featureEvaluator;
  40. cv::Mat valCache; // precalculated feature values (CV_32FC1)
  41. CvMat _resp; // for casting
  42. int numPrecalcVal, numPrecalcIdx;
  43. };
  44. class CvCascadeBoostTree : public CvBoostTree
  45. {
  46. public:
  47. virtual CvDTreeNode* predict( int sampleIdx ) const;
  48. void write( cv::FileStorage &fs, const cv::Mat& featureMap );
  49. void read( const cv::FileNode &node, CvBoost* _ensemble, CvDTreeTrainData* _data );
  50. void markFeaturesInMap( cv::Mat& featureMap );
  51. protected:
  52. virtual void split_node_data( CvDTreeNode* n );
  53. };
  54. class CvCascadeBoost : public CvBoost
  55. {
  56. public:
  57. virtual bool train( const CvFeatureEvaluator* _featureEvaluator,
  58. int _numSamples, int _precalcValBufSize, int _precalcIdxBufSize,
  59. const CvCascadeBoostParams& _params=CvCascadeBoostParams() );
  60. virtual float predict( int sampleIdx, bool returnSum = false ) const;
  61. float getThreshold() const { return threshold; }
  62. void write( cv::FileStorage &fs, const cv::Mat& featureMap ) const;
  63. bool read( const cv::FileNode &node, const CvFeatureEvaluator* _featureEvaluator,
  64. const CvCascadeBoostParams& _params );
  65. void markUsedFeaturesInMap( cv::Mat& featureMap );
  66. protected:
  67. virtual bool set_params( const CvBoostParams& _params );
  68. virtual void update_weights( CvBoostTree* tree );
  69. virtual bool isErrDesired();
  70. float threshold;
  71. float minHitRate, maxFalseAlarm;
  72. };
  73. #endif