old_ml.hpp 73 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041
/*M///////////////////////////////////////////////////////////////////////////////////////
//
//  IMPORTANT: READ BEFORE DOWNLOADING, COPYING, INSTALLING OR USING.
//
//  By downloading, copying, installing or using the software you agree to this license.
//  If you do not agree to this license, do not download, install,
//  copy or use the software.
//
//
//                        Intel License Agreement
//
// Copyright (C) 2000, Intel Corporation, all rights reserved.
// Third party copyrights are property of their respective owners.
//
// Redistribution and use in source and binary forms, with or without modification,
// are permitted provided that the following conditions are met:
//
//   * Redistribution's of source code must retain the above copyright notice,
//     this list of conditions and the following disclaimer.
//
//   * Redistribution's in binary form must reproduce the above copyright notice,
//     this list of conditions and the following disclaimer in the documentation
//     and/or other materials provided with the distribution.
//
//   * The name of Intel Corporation may not be used to endorse or promote products
//     derived from this software without specific prior written permission.
//
// This software is provided by the copyright holders and contributors "as is" and
// any express or implied warranties, including, but not limited to, the implied
// warranties of merchantability and fitness for a particular purpose are disclaimed.
// In no event shall the Intel Corporation or contributors be liable for any direct,
// indirect, incidental, special, exemplary, or consequential damages
// (including, but not limited to, procurement of substitute goods or services;
// loss of use, data, or profits; or business interruption) however caused
// and on any theory of liability, whether in contract, strict liability,
// or tort (including negligence or otherwise) arising in any way out of
// the use of this software, even if advised of the possibility of such damage.
//
//M*/
// Legacy ("old") machine-learning API of OpenCV, kept for backward compatibility.
#ifndef OPENCV_OLD_ML_HPP
#define OPENCV_OLD_ML_HPP

#ifdef __cplusplus
# include "opencv2/core.hpp"
#endif

#include "opencv2/core/core_c.h"
#include <limits.h>

#ifdef __cplusplus

#include <map>
#include <iostream>

// Apple defines a check() macro somewhere in the debug headers
// that interferes with a method definition in this header
#undef check

/****************************************************************************************\
*                               Main struct definitions                                  *
\****************************************************************************************/

/* log(2*PI) */
#define CV_LOG2PI (1.8378770664093454835606594728112)

/* columns of <trainData> matrix are training samples */
#define CV_COL_SAMPLE 0

/* rows of <trainData> matrix are training samples */
#define CV_ROW_SAMPLE 1

// Non-zero when the layout flag says samples are the rows of <trainData>.
#define CV_IS_ROW_SAMPLE(flags) ((flags) & CV_ROW_SAMPLE)
// Node of a singly-linked list of training-vector batches.
// <type> selects which member of the <data> union is active
// (presumably a CV_* depth code — TODO confirm against the users of this struct).
struct CvVectors
{
    int type;           // element type tag selecting data.ptr / data.fl / data.db
    int dims, count;    // vector dimensionality and number of vectors in this node
    CvVectors* next;    // next batch in the list (0 terminates the list)
    union
    {
        uchar** ptr;    // generic byte-wise view of the vector pointers
        float** fl;     // single-precision view
        double** db;    // double-precision view
    } data;
};
#if 0
// Disabled predecessor of CvParamGrid (see below); retained for reference only.
/* A structure, representing the lattice range of statmodel parameters.
   It is used for optimizing statmodel parameters by cross-validation method.
   The lattice is logarithmic, so <step> must be greater than 1. */
typedef struct CvParamLattice
{
    double min_val;
    double max_val;
    double step;
}
CvParamLattice;

// Builds a lattice, normalizing the bounds (min<=max) and clamping step to >= 1.
CV_INLINE CvParamLattice cvParamLattice( double min_val, double max_val,
                                         double log_step )
{
    CvParamLattice pl;
    pl.min_val = MIN( min_val, max_val );
    pl.max_val = MAX( min_val, max_val );
    pl.step = MAX( log_step, 1. );
    return pl;
}

// All-zero lattice, meaning "do not vary this parameter".
CV_INLINE CvParamLattice cvDefaultParamLattice( void )
{
    CvParamLattice pl = {0,0,0};
    return pl;
}
#endif
/* Variable type */
#define CV_VAR_NUMERICAL    0
#define CV_VAR_ORDERED      0   // alias: "ordered" and "numerical" are the same code
#define CV_VAR_CATEGORICAL  1

// Type-name strings used when persisting models to XML/YAML storage.
#define CV_TYPE_NAME_ML_SVM         "opencv-ml-svm"
#define CV_TYPE_NAME_ML_KNN         "opencv-ml-knn"
#define CV_TYPE_NAME_ML_NBAYES      "opencv-ml-bayesian"
#define CV_TYPE_NAME_ML_BOOSTING    "opencv-ml-boost-tree"
#define CV_TYPE_NAME_ML_TREE        "opencv-ml-tree"
#define CV_TYPE_NAME_ML_ANN_MLP     "opencv-ml-ann-mlp"
#define CV_TYPE_NAME_ML_CNN         "opencv-ml-cnn"
#define CV_TYPE_NAME_ML_RTREES      "opencv-ml-random-trees"
#define CV_TYPE_NAME_ML_ERTREES     "opencv-ml-extremely-randomized-trees"
#define CV_TYPE_NAME_ML_GBT         "opencv-ml-gradient-boosting-trees"

// Error-set selectors (train set vs. test set).
#define CV_TRAIN_ERROR  0
#define CV_TEST_ERROR   1
// Abstract base of all legacy statistical models: common clear/save/load
// interface plus FileStorage-based (de)serialization hooks.
class CvStatModel
{
public:
    CvStatModel();
    virtual ~CvStatModel();

    // Releases all internal model data.
    virtual void clear();

    // Persist/restore the model to/from an XML/YAML file; <name> selects the
    // storage node (0 presumably falls back to default_model_name — TODO confirm).
    CV_WRAP virtual void save( const char* filename, const char* name=0 ) const;
    CV_WRAP virtual void load( const char* filename, const char* name=0 );

    // Low-level serialization entry points used by save()/load().
    virtual void write( cv::FileStorage& storage, const char* name ) const;
    virtual void read( const cv::FileNode& node );

protected:
    const char* default_model_name;  // default storage node name for this model type
};
/****************************************************************************************\
*                          Normal Bayes Classifier                                       *
\****************************************************************************************/

/* The structure, representing the grid range of statmodel parameters.
   It is used for optimizing statmodel accuracy by varying model parameters,
   the accuracy estimate being computed by cross-validation.
   The grid is logarithmic, so <step> must be greater than 1.
   NOTE(review): this grid type belongs to SVM parameter search (see CvSVM below),
   despite sitting under the Normal Bayes banner. */

class CvMLData;

struct CvParamGrid
{
    // SVM params type
    enum { SVM_C=0, SVM_GAMMA=1, SVM_P=2, SVM_NU=3, SVM_COEF=4, SVM_DEGREE=5 };

    // Default grid: all zeros, i.e. "do not vary this parameter".
    CvParamGrid()
    {
        min_val = max_val = step = 0;
    }
    CvParamGrid( double min_val, double max_val, double log_step );
    //CvParamGrid( int param_id );

    // Validates the grid (definition not visible in this header).
    bool check() const;

    CV_PROP_RW double min_val;  // lower bound of the search range
    CV_PROP_RW double max_val;  // upper bound of the search range
    CV_PROP_RW double step;     // logarithmic step; must be > 1 to vary the parameter
};

// Stores the bounds/step verbatim; no normalization or validation is performed here.
inline CvParamGrid::CvParamGrid( double _min_val, double _max_val, double _log_step )
{
    min_val = _min_val;
    max_val = _max_val;
    step = _log_step;
}
// Normal (Gaussian) naive Bayes classifier over the legacy CvMat API,
// with cv::Mat convenience overloads.
class CvNormalBayesClassifier : public CvStatModel
{
public:
    CV_WRAP CvNormalBayesClassifier();
    virtual ~CvNormalBayesClassifier();

    // Construct-and-train shortcut (CvMat interface).
    CvNormalBayesClassifier( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx=0, const CvMat* sampleIdx=0 );

    // Trains the model; <update>=true presumably updates an existing model
    // instead of retraining from scratch — TODO confirm in the implementation.
    virtual bool train( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx = 0, const CvMat* sampleIdx=0, bool update=false );

    // Predicts class labels for <samples>; optionally outputs per-sample
    // results and probabilities.
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results=0, CV_OUT CvMat* results_prob=0 ) const;
    CV_WRAP virtual void clear();

    // cv::Mat overloads mirroring the CvMat interface above.
    CV_WRAP CvNormalBayesClassifier( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat() );
    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx = cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       bool update=false );
    CV_WRAP virtual float predict( const cv::Mat& samples, CV_OUT cv::Mat* results=0, CV_OUT cv::Mat* results_prob=0 ) const;

    virtual void write( cv::FileStorage& storage, const char* name ) const;
    virtual void read( const cv::FileNode& node );

protected:
    int     var_count, var_all;     // number of used variables / total variables
    CvMat*  var_idx;                // indices of used variables (0 = use all)
    CvMat*  cls_labels;             // distinct class labels
    // Per-class statistics (one matrix per class):
    CvMat** count;
    CvMat** sum;
    CvMat** productsum;
    CvMat** avg;
    CvMat** inv_eigen_values;
    CvMat** cov_rotate_mats;
    CvMat*  c;
};
/****************************************************************************************\
*                          K-Nearest Neighbour Classifier                                *
\****************************************************************************************/

// k Nearest Neighbors
class CvKNearest : public CvStatModel
{
public:
    CV_WRAP CvKNearest();
    virtual ~CvKNearest();

    // Construct-and-train shortcut; max_k bounds the k usable in find_nearest().
    CvKNearest( const CvMat* trainData, const CvMat* responses,
                const CvMat* sampleIdx=0, bool isRegression=false, int max_k=32 );

    // Trains (or, with updateBase=true, extends) the sample base.
    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* sampleIdx=0, bool is_regression=false,
                        int maxK=32, bool updateBase=false );

    // Finds the k nearest stored samples for each row of <samples>; optionally
    // outputs predicted responses, neighbor pointers, neighbor responses and distances.
    virtual float find_nearest( const CvMat* samples, int k, CV_OUT CvMat* results=0,
        const float** neighbors=0, CV_OUT CvMat* neighborResponses=0, CV_OUT CvMat* dist=0 ) const;

    // cv::Mat overloads mirroring the CvMat interface above.
    CV_WRAP CvKNearest( const cv::Mat& trainData, const cv::Mat& responses,
               const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false, int max_k=32 );
    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& sampleIdx=cv::Mat(), bool isRegression=false,
                       int maxK=32, bool updateBase=false );
    virtual float find_nearest( const cv::Mat& samples, int k, cv::Mat* results=0,
                                const float** neighbors=0, cv::Mat* neighborResponses=0,
                                cv::Mat* dist=0 ) const;
    CV_WRAP virtual float find_nearest( const cv::Mat& samples, int k, CV_OUT cv::Mat& results,
                                        CV_OUT cv::Mat& neighborResponses, CV_OUT cv::Mat& dists) const;

    virtual void clear();

    // Simple accessors for the trained state.
    int get_max_k() const;
    int get_var_count() const;
    int get_sample_count() const;
    bool is_regression() const;

    // Internal helpers exposed as virtuals for subclassing:
    // writes prediction results for samples [start, end) given precomputed neighbors.
    virtual float write_results( int k, int k1, int start, int end,
        const float* neighbor_responses, const float* dist, CvMat* _results,
        CvMat* _neighbor_responses, CvMat* _dist, Cv32suf* sort_buf ) const;

    // brute-force neighbor search over the stored sample base.
    virtual void find_neighbors_direct( const CvMat* _samples, int k, int start, int end,
        float* neighbor_responses, const float** neighbors, float* dist ) const;

protected:
    int max_k, var_count;   // neighbor-count bound and feature dimensionality
    int total;              // number of stored training samples
    bool regression;        // true = regression mode, false = classification
    CvVectors* samples;     // linked list of training-sample batches
};
/****************************************************************************************\
*                                   Support Vector Machines                              *
\****************************************************************************************/

// SVM training parameters
struct CvSVMParams
{
    CvSVMParams();
    CvSVMParams( int svm_type, int kernel_type,
                 double degree, double gamma, double coef0,
                 double Cvalue, double nu, double p,
                 CvMat* class_weights, CvTermCriteria term_crit );

    CV_PROP_RW int         svm_type;       // one of CvSVM::{C_SVC, NU_SVC, ONE_CLASS, EPS_SVR, NU_SVR}
    CV_PROP_RW int         kernel_type;    // one of CvSVM::{LINEAR, POLY, RBF, SIGMOID, CHI2, INTER}
    CV_PROP_RW double      degree;         // for poly
    CV_PROP_RW double      gamma;          // for poly/rbf/sigmoid/chi2
    CV_PROP_RW double      coef0;          // for poly/sigmoid
    CV_PROP_RW double      C;              // for CV_SVM_C_SVC, CV_SVM_EPS_SVR and CV_SVM_NU_SVR
    CV_PROP_RW double      nu;             // for CV_SVM_NU_SVC, CV_SVM_ONE_CLASS, and CV_SVM_NU_SVR
    CV_PROP_RW double      p;              // for CV_SVM_EPS_SVR
    CvMat*                 class_weights;  // for CV_SVM_C_SVC
    CV_PROP_RW CvTermCriteria term_crit;   // termination criteria
};
// Computes kernel responses between stored support vectors and a query vector.
// The concrete kernel is selected via the <calc_func> member-function pointer.
struct CvSVMKernel
{
    // Signature shared by all kernel evaluators: fills results[0..vec_count)
    // with K(vecs[i], another).
    typedef void (CvSVMKernel::*Calc)( int vec_count, int vec_size, const float** vecs,
                                       const float* another, float* results );
    CvSVMKernel();
    CvSVMKernel( const CvSVMParams* params, Calc _calc_func );
    virtual bool create( const CvSVMParams* params, Calc _calc_func );
    virtual ~CvSVMKernel();
    virtual void clear();

    // Dispatches to calc_func for the configured kernel type.
    virtual void calc( int vcount, int n, const float** vecs, const float* another, float* results );

    const CvSVMParams* params;  // non-owning pointer to the active SVM parameters
    Calc calc_func;             // selected kernel evaluator

    // Shared core for linear/poly/sigmoid: results[i] = alpha*<vecs[i],another> + beta.
    virtual void calc_non_rbf_base( int vec_count, int vec_size, const float** vecs,
                                    const float* another, float* results,
                                    double alpha, double beta );

    // One evaluator per supported kernel type:
    virtual void calc_intersec( int vcount, int var_count, const float** vecs,
                            const float* another, float* results );
    virtual void calc_chi2( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_linear( int vec_count, int vec_size, const float** vecs,
                              const float* another, float* results );
    virtual void calc_rbf( int vec_count, int vec_size, const float** vecs,
                           const float* another, float* results );
    virtual void calc_poly( int vec_count, int vec_size, const float** vecs,
                            const float* another, float* results );
    virtual void calc_sigmoid( int vec_count, int vec_size, const float** vecs,
                               const float* another, float* results );
};
// Doubly-linked node of the solver's kernel-row cache (see CvSVMSolver::lru_list).
struct CvSVMKernelRow
{
    CvSVMKernelRow* prev;
    CvSVMKernelRow* next;
    float* data;    // cached kernel-row values
};
// Summary of a solved SVM subproblem returned by CvSVMSolver::solve_*.
struct CvSVMSolutionInfo
{
    double obj;             // objective function value at the solution
    double rho;             // decision-function offset
    double upper_bound_p;   // upper bound for positive-class alphas
    double upper_bound_n;   // upper bound for negative-class alphas
    double r; // for Solver_NU
};
// SMO-style optimizer for the SVM dual problem (LIBSVM lineage).
// Strategy hooks (working-set selection, rho computation, row retrieval) are
// member-function pointers so the same loop serves C/nu/one-class/SVR variants.
class CvSVMSolver
{
public:
    typedef bool (CvSVMSolver::*SelectWorkingSet)( int& i, int& j );
    typedef float* (CvSVMSolver::*GetRow)( int i, float* row, float* dst, bool existed );
    typedef void (CvSVMSolver::*CalcRho)( double& rho, double& r );

    CvSVMSolver();

    // Construct-and-create shortcut; see create() for the parameter meaning.
    CvSVMSolver( int count, int var_count, const float** samples, schar* y,
                 int alpha_count, double* alpha, double Cp, double Cn,
                 CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                 SelectWorkingSet select_working_set, CalcRho calc_rho );

    // Binds the solver to the problem data and strategy hooks.
    virtual bool create( int count, int var_count, const float** samples, schar* y,
                         int alpha_count, double* alpha, double Cp, double Cn,
                         CvMemStorage* storage, CvSVMKernel* kernel, GetRow get_row,
                         SelectWorkingSet select_working_set, CalcRho calc_rho );

    virtual ~CvSVMSolver();
    virtual void clear();

    // Core optimization loop shared by all problem types.
    virtual bool solve_generic( CvSVMSolutionInfo& si );

    // Problem-specific entry points: each sets up the corresponding dual
    // problem and runs solve_generic().
    virtual bool solve_c_svc( int count, int var_count, const float** samples, schar* y,
                              double Cp, double Cn, CvMemStorage* storage,
                              CvSVMKernel* kernel, double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_nu_svc( int count, int var_count, const float** samples, schar* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_one_class( int count, int var_count, const float** samples,
                                  CvMemStorage* storage, CvSVMKernel* kernel,
                                  double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_eps_svr( int count, int var_count, const float** samples, const float* y,
                                CvMemStorage* storage, CvSVMKernel* kernel,
                                double* alpha, CvSVMSolutionInfo& si );
    virtual bool solve_nu_svr( int count, int var_count, const float** samples, const float* y,
                               CvMemStorage* storage, CvSVMKernel* kernel,
                               double* alpha, CvSVMSolutionInfo& si );

    // Kernel-row cache access; *_existed reports whether the row was already cached.
    virtual float* get_row_base( int i, bool* _existed );
    virtual float* get_row( int i, float* dst );

    int sample_count;
    int var_count;
    int cache_size;         // total kernel-cache size
    int cache_line_size;    // size of one cached kernel row
    const float** samples;
    const CvSVMParams* params;
    CvMemStorage* storage;
    CvSVMKernelRow lru_list;    // head of the LRU list of cached rows
    CvSVMKernelRow* rows;

    int alpha_count;
    double* G;              // gradient of the dual objective
    double* alpha;
    // -1 - lower bound, 0 - free, 1 - upper bound
    schar* alpha_status;

    schar* y;
    double* b;
    float* buf[2];
    double eps;             // optimization tolerance
    int max_iter;
    double C[2];  // C[0] == Cn, C[1] == Cp
    CvSVMKernel* kernel;

    // Strategy hooks selected per problem type.
    SelectWorkingSet select_working_set_func;
    CalcRho calc_rho_func;
    GetRow get_row_func;

    // Default and nu-variant implementations of the strategy hooks.
    virtual bool select_working_set( int& i, int& j );
    virtual bool select_working_set_nu_svm( int& i, int& j );
    virtual void calc_rho( double& rho, double& r );
    virtual void calc_rho_nu_svm( double& rho, double& r );
    virtual float* get_row_svc( int i, float* row, float* dst, bool existed );
    virtual float* get_row_one_class( int i, float* row, float* dst, bool existed );
    virtual float* get_row_svr( int i, float* row, float* dst, bool existed );
};
// One trained decision function: f(x) = sum_i alpha[i]*K(sv[sv_index[i]], x) - rho.
struct CvSVMDecisionFunc
{
    double rho;         // offset subtracted from the weighted kernel sum
    int sv_count;       // number of support vectors used by this function
    double* alpha;      // support-vector weights
    int* sv_index;      // indices into the model's support-vector array
};
// SVM model
// Legacy support-vector-machine model (classification, regression, one-class),
// with CvMat and cv::Mat interfaces and cross-validated parameter search.
class CvSVM : public CvStatModel
{
public:
    // SVM type
    enum { C_SVC=100, NU_SVC=101, ONE_CLASS=102, EPS_SVR=103, NU_SVR=104 };
    // SVM kernel type
    enum { LINEAR=0, POLY=1, RBF=2, SIGMOID=3, CHI2=4, INTER=5 };
    // SVM params type
    enum { C=0, GAMMA=1, P=2, NU=3, COEF=4, DEGREE=5 };

    CV_WRAP CvSVM();
    virtual ~CvSVM();

    // Construct-and-train shortcut (CvMat interface).
    CvSVM( const CvMat* trainData, const CvMat* responses,
           const CvMat* varIdx=0, const CvMat* sampleIdx=0,
           CvSVMParams params=CvSVMParams() );

    virtual bool train( const CvMat* trainData, const CvMat* responses,
                        const CvMat* varIdx=0, const CvMat* sampleIdx=0,
                        CvSVMParams params=CvSVMParams() );

    // Trains with k-fold cross-validated grid search over the SVM parameters;
    // <balanced>=true presumably balances class subsets across folds — TODO confirm.
    virtual bool train_auto( const CvMat* trainData, const CvMat* responses,
        const CvMat* varIdx, const CvMat* sampleIdx, CvSVMParams params,
        int kfold = 10,
        CvParamGrid Cgrid      = get_default_grid(CvSVM::C),
        CvParamGrid gammaGrid  = get_default_grid(CvSVM::GAMMA),
        CvParamGrid pGrid      = get_default_grid(CvSVM::P),
        CvParamGrid nuGrid     = get_default_grid(CvSVM::NU),
        CvParamGrid coeffGrid  = get_default_grid(CvSVM::COEF),
        CvParamGrid degreeGrid = get_default_grid(CvSVM::DEGREE),
        bool balanced=false );

    // Single-sample / batch prediction; returnDFVal=true returns the raw
    // decision-function value instead of the class label.
    virtual float predict( const CvMat* sample, bool returnDFVal=false ) const;
    virtual float predict( const CvMat* samples, CV_OUT CvMat* results, bool returnDFVal=false ) const;

    // cv::Mat overloads mirroring the CvMat interface above.
    CV_WRAP CvSVM( const cv::Mat& trainData, const cv::Mat& responses,
          const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
          CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train( const cv::Mat& trainData, const cv::Mat& responses,
                       const cv::Mat& varIdx=cv::Mat(), const cv::Mat& sampleIdx=cv::Mat(),
                       CvSVMParams params=CvSVMParams() );

    CV_WRAP virtual bool train_auto( const cv::Mat& trainData, const cv::Mat& responses,
                            const cv::Mat& varIdx, const cv::Mat& sampleIdx, CvSVMParams params,
                            int k_fold = 10,
                            CvParamGrid Cgrid      = CvSVM::get_default_grid(CvSVM::C),
                            CvParamGrid gammaGrid  = CvSVM::get_default_grid(CvSVM::GAMMA),
                            CvParamGrid pGrid      = CvSVM::get_default_grid(CvSVM::P),
                            CvParamGrid nuGrid     = CvSVM::get_default_grid(CvSVM::NU),
                            CvParamGrid coeffGrid  = CvSVM::get_default_grid(CvSVM::COEF),
                            CvParamGrid degreeGrid = CvSVM::get_default_grid(CvSVM::DEGREE),
                            bool balanced=false);
    CV_WRAP virtual float predict( const cv::Mat& sample, bool returnDFVal=false ) const;
    CV_WRAP_AS(predict_all) virtual void predict( cv::InputArray samples, cv::OutputArray results ) const;

    // Trained-model accessors.
    CV_WRAP virtual int get_support_vector_count() const;
    virtual const float* get_support_vector(int i) const;
    virtual CvSVMParams get_params() const { return params; }
    CV_WRAP virtual void clear();
    virtual const CvSVMDecisionFunc* get_decision_function() const { return decision_func; }

    // Default search grid for the given parameter id (C/GAMMA/P/NU/COEF/DEGREE).
    static CvParamGrid get_default_grid( int param_id );

    virtual void write( cv::FileStorage& storage, const char* name ) const;
    virtual void read( const cv::FileNode& node );

    // Number of features actually used (var_idx columns if a subset is selected).
    CV_WRAP int get_var_count() const { return var_idx ? var_idx->cols : var_all; }

protected:
    virtual bool set_params( const CvSVMParams& params );
    // Solves one binary/regression subproblem over the given samples.
    virtual bool train1( int sample_count, int var_count, const float** samples,
                    const void* responses, double Cp, double Cn,
                    CvMemStorage* _storage, double* alpha, double& rho );
    virtual bool do_train( int svm_type, int sample_count, int var_count, const float** samples,
                    const CvMat* responses, CvMemStorage* _storage, double* alpha );
    virtual void create_kernel();
    virtual void create_solver();
    virtual float predict( const float* row_sample, int row_len, bool returnDFVal=false ) const;
    virtual void write_params( cv::FileStorage& fs ) const;
    virtual void read_params( const cv::FileNode& node );
    void optimize_linear_svm();

    CvSVMParams params;             // active training parameters
    CvMat* class_labels;            // distinct class labels (classification)
    int var_all;                    // total number of input variables
    float** sv;                     // support vectors
    int sv_total;                   // number of support vectors
    CvMat* var_idx;                 // selected-variable indices (0 = use all)
    CvMat* class_weights;
    CvSVMDecisionFunc* decision_func;
    CvMemStorage* storage;

    CvSVMSolver* solver;
    CvSVMKernel* kernel;

private:
    // Non-copyable: the model owns raw CvMat/storage resources.
    CvSVM(const CvSVM&);
    CvSVM& operator = (const CvSVM&);
};
/****************************************************************************************\
*                                      Decision Tree                                     *
\****************************************************************************************/

// Parallel views of the same index buffer: 16-bit (compact) or 32-bit form.
struct CvPair16u32s
{
    unsigned short* u;
    int* i;
};

// Direction (-1/+1) for a categorical split: tests bit <idx> of the <subset>
// bitmask (32 categories per int) and maps set->-1, clear->+1.
#define CV_DTREE_CAT_DIR(idx,subset) \
    (2*((subset[(idx)>>5]&(1 << ((idx) & 31)))==0)-1)
// One split candidate at a tree node: either an ordered threshold (ord) or a
// categorical subset bitmask (subset), linked into a list of surrogates via <next>.
struct CvDTreeSplit
{
    int var_idx;            // index of the variable this split tests
    int condensed_idx;
    int inversed;           // non-zero if the split direction is inverted
    float quality;          // split quality score (used to rank candidates)
    CvDTreeSplit* next;     // next (surrogate) split in the list
    union
    {
        int subset[2];      // categorical: bitmask of categories going left (64 max here)
        struct
        {
            float c;            // ordered: threshold value
            int split_point;    // ordered: index of the threshold in the sorted order
        }
        ord;
    };
};
// Node of a trained decision tree, including pruning bookkeeping.
struct CvDTreeNode
{
    int class_idx;          // index of the predicted class (classification)
    int Tn;                 // pruning "stage" tag
    double value;           // node prediction (class label or regression value)

    CvDTreeNode* parent;
    CvDTreeNode* left;
    CvDTreeNode* right;

    CvDTreeSplit* split;    // best split (with surrogates chained via split->next)

    int sample_count;       // number of training samples reaching this node
    int depth;
    int* num_valid;         // per-variable count of non-missing samples (0 = all valid)
    int offset;
    int buf_idx;
    double maxlr;

    // global pruning data
    int complexity;
    double alpha;
    double node_risk, tree_risk, tree_error;

    // cross-validation pruning data
    int* cv_Tn;
    double* cv_node_risk;
    double* cv_node_error;

    // Falls back to sample_count when no per-variable validity array exists.
    int get_num_valid(int vi) { return num_valid ? num_valid[vi] : sample_count; }
    void set_num_valid(int vi, int n) { if( num_valid ) num_valid[vi] = n; }
};
// Training parameters shared by decision trees (and tree ensembles built on them).
struct CvDTreeParams
{
    CV_PROP_RW int   max_categories;        // cluster categories above this count
    CV_PROP_RW int   max_depth;             // maximum tree depth
    CV_PROP_RW int   min_sample_count;      // do not split nodes with fewer samples
    CV_PROP_RW int   cv_folds;              // folds for cross-validation pruning (0 = off)
    CV_PROP_RW bool  use_surrogates;        // build surrogate splits (for missing data)
    CV_PROP_RW bool  use_1se_rule;          // harsher pruning via the 1-SE rule
    CV_PROP_RW bool  truncate_pruned_tree;  // physically remove pruned branches
    CV_PROP_RW float regression_accuracy;   // stop splitting when node error is below this
    const float* priors;                    // optional class prior weights (not owned)

    CvDTreeParams();
    CvDTreeParams( int max_depth, int min_sample_count,
                   float regression_accuracy, bool use_surrogates,
                   int max_categories, int cv_folds,
                   bool use_1se_rule, bool truncate_pruned_tree,
                   const float* priors );
};
// Shared training-data container for the tree-based models below
// (CvDTree / CvForestTree / CvBoostTree).  It pre-processes the input
// matrices once (variable typing, categorical maps, per-variable buffers)
// so that many trees can be grown from the same data; `shared` marks a
// container reused across an ensemble.
527. struct CvDTreeTrainData
528. {
529. CvDTreeTrainData();
// Constructs and immediately calls set_data() with the same arguments.
530. CvDTreeTrainData( const CvMat* trainData, int tflag,
531. const CvMat* responses, const CvMat* varIdx=0,
532. const CvMat* sampleIdx=0, const CvMat* varType=0,
533. const CvMat* missingDataMask=0,
534. const CvDTreeParams& params=CvDTreeParams(),
535. bool _shared=false, bool _add_labels=false );
536. virtual ~CvDTreeTrainData();
// (Re)initializes all internal buffers from the raw training matrices.
// _update_data=true requests an incremental update of an existing setup.
537. virtual void set_data( const CvMat* trainData, int tflag,
538. const CvMat* responses, const CvMat* varIdx=0,
539. const CvMat* sampleIdx=0, const CvMat* varType=0,
540. const CvMat* missingDataMask=0,
541. const CvDTreeParams& params=CvDTreeParams(),
542. bool _shared=false, bool _add_labels=false,
543. bool _update_data=false );
// Makes the deep copy kept in responses_copy (see member below; used by boosting).
544. virtual void do_responses_copy();
// Extracts raw feature vectors / missing masks / responses for the given
// sample subset into caller-provided arrays.
545. virtual void get_vectors( const CvMat* _subsample_idx,
546. float* values, uchar* missing, float* responses, bool get_class_idx=false );
// Builds the root node for the given sample subset.
547. virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
// Serialization of the training parameters (not the data itself).
548. virtual void write_params( cv::FileStorage& fs ) const;
549. virtual void read_params( const cv::FileNode& node );
550. // release all the data
551. virtual void clear();
// --- accessors used by the split-search code -------------------------
552. int get_num_classes() const;
553. int get_var_type(int vi) const;
554. int get_work_var_count() const {return work_var_count;}
// The get_* methods below fill the caller-provided *_buf arrays for the
// samples belonging to node n and return a pointer to the result.
555. virtual const float* get_ord_responses( CvDTreeNode* n, float* values_buf, int* sample_indices_buf );
556. virtual const int* get_class_labels( CvDTreeNode* n, int* labels_buf );
557. virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
558. virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
559. virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
// Ordered-variable data comes back as values sorted ascending plus the
// permutation (sorted_indices) that produced them.
560. virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* sorted_indices_buf,
561. const float** ord_values, const int** sorted_indices, int* sample_indices_buf );
562. virtual int get_child_buf_idx( CvDTreeNode* n );
563. ////////////////////////////////////
// --- node/split allocation and cleanup (backed by the CvSet heaps below) ---
564. virtual bool set_params( const CvDTreeParams& params );
565. virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
566. int storage_idx, int offset );
567. virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
568. int split_point, int inversed, float quality );
569. virtual CvDTreeSplit* new_split_cat( int vi, float quality );
570. virtual void free_node_data( CvDTreeNode* node );
571. virtual void free_train_data();
572. virtual void free_node( CvDTreeNode* node );
// --- dimensions and flags -------------------------------------------
573. int sample_count, var_all, var_count, max_c_count;
574. int ord_var_count, cat_var_count, work_var_count;
575. bool have_labels, have_priors;
576. bool is_classifier;
577. int tflag;
// Non-owning views of the caller's matrices.
578. const CvMat* train_data;
579. const CvMat* responses;
580. CvMat* responses_copy; // used in Boosting
581. int buf_count, buf_size; // buf_size is obsolete, please do not use it, use expression ((int64)buf->rows * (int64)buf->cols / buf_count) instead
582. bool shared;
583. int is_buf_16u;
// Categorical bookkeeping: per-variable category counts/offsets and the
// value->index map; `buf` is the big per-variable work buffer.
584. CvMat* cat_count;
585. CvMat* cat_ofs;
586. CvMat* cat_map;
587. CvMat* counts;
588. CvMat* buf;
// Size of one sub-buffer inside `buf`: one row per work variable plus
// one extra, each sample_count entries long.
589. inline size_t get_length_subbuf() const
590. {
591. size_t res = (size_t)(work_var_count + 1) * (size_t)sample_count;
592. return res;
593. }
594. CvMat* direction;
595. CvMat* split_buf;
596. CvMat* var_idx;
597. CvMat* var_type; // i-th element =
598. // k<0 - ordered
599. // k>=0 - categorical, see k-th element of cat_* arrays
600. CvMat* priors;
601. CvMat* priors_mult;
602. CvDTreeParams params;
// Memory pools backing nodes/splits; the heaps below allocate from them.
603. CvMemStorage* tree_storage;
604. CvMemStorage* temp_storage;
605. CvDTreeNode* data_root;
606. CvSet* node_heap;
607. CvSet* split_heap;
608. CvSet* cv_heap;
609. CvSet* nv_heap;
610. cv::RNG* rng;
611. };
  612. class CvDTree;
  613. class CvForestTree;
  614. namespace cv
  615. {
  616. struct DTreeBestSplitFinder;
  617. struct ForestTreeBestSplitFinder;
  618. }
// Single decision tree (classification or regression) with optional
// surrogate splits for missing data and cross-validation pruning
// (see prune_cv / pruned_tree_idx below).
619. class CvDTree : public CvStatModel
620. {
621. public:
622. CV_WRAP CvDTree();
623. virtual ~CvDTree();
// --- training -------------------------------------------------------
// Overloads: raw CvMat inputs, a CvMLData dataset, a pre-built
// CvDTreeTrainData (used by ensembles), and cv::Mat wrappers.
624. virtual bool train( const CvMat* trainData, int tflag,
625. const CvMat* responses, const CvMat* varIdx=0,
626. const CvMat* sampleIdx=0, const CvMat* varType=0,
627. const CvMat* missingDataMask=0,
628. CvDTreeParams params=CvDTreeParams() );
629. virtual bool train( CvMLData* trainData, CvDTreeParams params=CvDTreeParams() );
630. // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
631. virtual float calc_error( CvMLData* trainData, int type, std::vector<float> *resp = 0 );
632. virtual bool train( CvDTreeTrainData* trainData, const CvMat* subsampleIdx );
// --- prediction -----------------------------------------------------
// Returns the leaf node reached by the sample; preprocessedInput=true
// means the sample is already in the internal (normalized) layout.
633. virtual CvDTreeNode* predict( const CvMat* sample, const CvMat* missingDataMask=0,
634. bool preprocessedInput=false ) const;
635. CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
636. const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
637. const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
638. const cv::Mat& missingDataMask=cv::Mat(),
639. CvDTreeParams params=CvDTreeParams() );
640. CV_WRAP virtual CvDTreeNode* predict( const cv::Mat& sample, const cv::Mat& missingDataMask=cv::Mat(),
641. bool preprocessedInput=false ) const;
642. CV_WRAP virtual cv::Mat getVarImportance();
643. virtual const CvMat* get_var_importance();
644. CV_WRAP virtual void clear();
// --- serialization --------------------------------------------------
645. virtual void read( const cv::FileNode& node );
646. virtual void write( cv::FileStorage& fs, const char* name ) const;
647. // special read & write methods for trees in the tree ensembles
648. virtual void read( const cv::FileNode& node, CvDTreeTrainData* data );
649. virtual void write( cv::FileStorage& fs ) const;
650. const CvDTreeNode* get_root() const;
651. int get_pruned_tree_idx() const;
652. CvDTreeTrainData* get_data();
653. protected:
654. friend struct cv::DTreeBestSplitFinder;
655. virtual bool do_train( const CvMat* _subsample_idx );
// --- tree growing: split search per (variable kind x problem kind) ---
656. virtual void try_split_node( CvDTreeNode* n );
657. virtual void split_node_data( CvDTreeNode* n );
658. virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
659. virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
660. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
661. virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
662. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
663. virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
664. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
665. virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
666. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
// Surrogate splits route samples whose primary split variable is missing.
667. virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
668. virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
669. virtual double calc_node_dir( CvDTreeNode* node );
670. virtual void complete_node_dir( CvDTreeNode* node );
// Clusters category values into k groups to limit the 2^N subset search
// for many-valued categorical variables.
671. virtual void cluster_categories( const int* vectors, int vector_count,
672. int var_count, int* sums, int k, int* cluster_labels );
673. virtual void calc_node_value( CvDTreeNode* node );
// --- cost-complexity pruning via cross-validation -------------------
674. virtual void prune_cv();
675. virtual double update_tree_rnc( int T, int fold );
676. virtual int cut_tree( int T, int fold, double min_alpha );
677. virtual void free_prune_data(bool cut_tree);
678. virtual void free_tree();
// --- node-level (de)serialization helpers ---------------------------
679. virtual void write_node( cv::FileStorage& fs, CvDTreeNode* node ) const;
680. virtual void write_split( cv::FileStorage& fs, CvDTreeSplit* split ) const;
681. virtual CvDTreeNode* read_node( const cv::FileNode& node, CvDTreeNode* parent );
682. virtual CvDTreeSplit* read_split( const cv::FileNode& node );
683. virtual void write_tree_nodes( cv::FileStorage& fs ) const;
684. virtual void read_tree_nodes( const cv::FileNode& node );
685. CvDTreeNode* root;
686. CvMat* var_importance;
687. CvDTreeTrainData* data;
// Headers/wrappers kept so cv::Mat-based train() can feed the CvMat API.
688. CvMat train_data_hdr, responses_hdr;
689. cv::Mat train_data_mat, responses_mat;
690. public:
691. int pruned_tree_idx;
692. };
  693. /****************************************************************************************\
  694. * Random Trees Classifier *
  695. \****************************************************************************************/
  696. class CvRTrees;
// A single tree inside a CvRTrees forest.  Differs from CvDTree mainly
// in find_best_split(), which is delegated to
// cv::ForestTreeBestSplitFinder (searching a random variable subset),
// and in keeping a back-pointer to the owning forest.
697. class CvForestTree: public CvDTree
698. {
699. public:
700. CvForestTree();
701. virtual ~CvForestTree();
// Trains on a shared CvDTreeTrainData with a per-tree (bootstrap) sample subset.
702. virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx, CvRTrees* forest );
703. virtual int get_var_count() const {return data ? data->var_count : 0;}
704. virtual void read( cv::FileStorage& fs, cv::FileNode& node, CvRTrees* forest, CvDTreeTrainData* _data );
705. /* dummy methods to avoid warnings: BEGIN */
706. virtual bool train( const CvMat* trainData, int tflag,
707. const CvMat* responses, const CvMat* varIdx=0,
708. const CvMat* sampleIdx=0, const CvMat* varType=0,
709. const CvMat* missingDataMask=0,
710. CvDTreeParams params=CvDTreeParams() );
711. virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
712. virtual void read( cv::FileStorage& fs, cv::FileNode& node );
713. virtual void read( cv::FileStorage& fs, cv::FileNode& node,
714. CvDTreeTrainData* data );
715. /* dummy methods to avoid warnings: END */
716. protected:
717. friend struct cv::ForestTreeBestSplitFinder;
718. virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
// Owning forest; not freed by this tree.
719. CvRTrees* forest;
720. };
// Random-forest training parameters: per-tree settings inherited from
// CvDTreeParams plus forest-level controls.
721. struct CvRTParams : public CvDTreeParams
722. {
723. //Parameters for the forest
724. CV_PROP_RW bool calc_var_importance; // true <=> RF processes variable importance
725. CV_PROP_RW int nactive_vars; // size of the random variable subset tried at each node
726. CV_PROP_RW CvTermCriteria term_crit; // stop criterion: max tree count and/or forest accuracy
727. CvRTParams();
// Forwards the tree-level arguments to CvDTreeParams and fills term_crit
// from (max_num_of_trees_in_the_forest, forest_accuracy, termcrit_type).
728. CvRTParams( int max_depth, int min_sample_count,
729. float regression_accuracy, bool use_surrogates,
730. int max_categories, const float* priors, bool calc_var_importance,
731. int nactive_vars, int max_num_of_trees_in_the_forest,
732. float forest_accuracy, int termcrit_type );
733. };
// Random Trees (random forest) classifier/regressor: an ensemble of
// CvForestTree grown via grow_forest(), with out-of-bag error tracking
// (oob_error) and optional variable importance.
734. class CvRTrees : public CvStatModel
735. {
736. public:
737. CV_WRAP CvRTrees();
738. virtual ~CvRTrees();
// --- training (CvMat, CvMLData and cv::Mat overloads) ---------------
739. virtual bool train( const CvMat* trainData, int tflag,
740. const CvMat* responses, const CvMat* varIdx=0,
741. const CvMat* sampleIdx=0, const CvMat* varType=0,
742. const CvMat* missingDataMask=0,
743. CvRTParams params=CvRTParams() );
744. virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
// predict() aggregates the individual tree outputs; predict_prob()
// returns a fractional confidence instead of the hard label.
745. virtual float predict( const CvMat* sample, const CvMat* missing = 0 ) const;
746. virtual float predict_prob( const CvMat* sample, const CvMat* missing = 0 ) const;
747. CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
748. const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
749. const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
750. const cv::Mat& missingDataMask=cv::Mat(),
751. CvRTParams params=CvRTParams() );
752. CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
753. CV_WRAP virtual float predict_prob( const cv::Mat& sample, const cv::Mat& missing = cv::Mat() ) const;
754. CV_WRAP virtual cv::Mat getVarImportance();
755. CV_WRAP virtual void clear();
756. virtual const CvMat* get_var_importance();
// Proximity: how often two samples land in the same leaf across trees.
757. virtual float get_proximity( const CvMat* sample1, const CvMat* sample2,
758. const CvMat* missing1 = 0, const CvMat* missing2 = 0 ) const;
759. virtual float calc_error( CvMLData* data, int type , std::vector<float>* resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
760. virtual float get_train_error();
761. virtual void read( cv::FileStorage& fs, cv::FileNode& node );
762. virtual void write( cv::FileStorage& fs, const char* name ) const;
763. CvMat* get_active_var_mask();
764. CvRNG* get_rng();
765. int get_tree_count() const;
766. CvForestTree* get_tree(int i) const;
767. protected:
// Model name used in serialization; overridden by CvERTrees.
768. virtual cv::String getName() const;
// Grows trees until term_crit is met; the ensemble-building core.
769. virtual bool grow_forest( const CvTermCriteria term_crit );
770. // array of the trees of the forest
771. CvForestTree** trees;
772. CvDTreeTrainData* data;
773. CvMat train_data_hdr, responses_hdr;
774. cv::Mat train_data_mat, responses_mat;
775. int ntrees;
776. int nclasses;
777. double oob_error; // out-of-bag error of the current forest
778. CvMat* var_importance;
779. int nsamples;
780. cv::RNG* rng;
781. CvMat* active_var_mask;
782. };
  783. /****************************************************************************************\
  784. * Extremely randomized trees Classifier *
  785. \****************************************************************************************/
// Training-data container for Extremely Randomized Trees.  Overrides the
// data accessors of CvDTreeTrainData; note get_ord_var_data() here
// returns values with a missing-mask instead of a sorted order, and the
// missing mask itself is kept (missing_mask member).
786. struct CvERTreeTrainData : public CvDTreeTrainData
787. {
788. virtual void set_data( const CvMat* trainData, int tflag,
789. const CvMat* responses, const CvMat* varIdx=0,
790. const CvMat* sampleIdx=0, const CvMat* varType=0,
791. const CvMat* missingDataMask=0,
792. const CvDTreeParams& params=CvDTreeParams(),
793. bool _shared=false, bool _add_labels=false,
794. bool _update_data=false );
// Unlike the base class, ordered values come back unsorted, paired with
// a per-sample missing flag.
795. virtual void get_ord_var_data( CvDTreeNode* n, int vi, float* ord_values_buf, int* missing_buf,
796. const float** ord_values, const int** missing, int* sample_buf = 0 );
797. virtual const int* get_sample_indices( CvDTreeNode* n, int* indices_buf );
798. virtual const int* get_cv_labels( CvDTreeNode* n, int* labels_buf );
799. virtual const int* get_cat_var_data( CvDTreeNode* n, int vi, int* cat_values_buf );
800. virtual void get_vectors( const CvMat* _subsample_idx, float* values, uchar* missing,
801. float* responses, bool get_class_idx=false );
802. virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );
// Non-owning view of the caller's missing-value mask.
803. const CvMat* missing_mask;
804. };
// A single Extremely Randomized Tree: same interface as CvForestTree but
// overrides the split-search routines (thresholds are drawn at random
// rather than optimized — implementations live in the .cpp).
805. class CvForestERTree : public CvForestTree
806. {
807. protected:
808. virtual double calc_node_dir( CvDTreeNode* node );
// One override per (variable kind x problem kind) combination.
809. virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
810. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
811. virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
812. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
813. virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
814. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
815. virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
816. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
817. virtual void split_node_data( CvDTreeNode* n );
818. };
// Extremely Randomized Trees ensemble: same public API as CvRTrees,
// overriding only the forest-growing step and the serialized model name.
819. class CvERTrees : public CvRTrees
820. {
821. public:
822. CV_WRAP CvERTrees();
823. virtual ~CvERTrees();
824. virtual bool train( const CvMat* trainData, int tflag,
825. const CvMat* responses, const CvMat* varIdx=0,
826. const CvMat* sampleIdx=0, const CvMat* varType=0,
827. const CvMat* missingDataMask=0,
828. CvRTParams params=CvRTParams());
829. CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
830. const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
831. const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
832. const cv::Mat& missingDataMask=cv::Mat(),
833. CvRTParams params=CvRTParams());
834. virtual bool train( CvMLData* data, CvRTParams params=CvRTParams() );
835. protected:
836. virtual cv::String getName() const;
// Builds CvForestERTree instances instead of plain forest trees.
837. virtual bool grow_forest( const CvTermCriteria term_crit );
838. };
  839. /****************************************************************************************\
  840. * Boosted tree classifier *
  841. \****************************************************************************************/
// Boosting training parameters on top of the per-weak-tree settings
// inherited from CvDTreeParams.
842. struct CvBoostParams : public CvDTreeParams
843. {
844. CV_PROP_RW int boost_type; // one of CvBoost::{DISCRETE,REAL,LOGIT,GENTLE}
845. CV_PROP_RW int weak_count; // number of weak trees to train
846. CV_PROP_RW int split_criteria; // one of CvBoost::{DEFAULT,GINI,MISCLASS,SQERR}
847. CV_PROP_RW double weight_trim_rate; // in (0,1]: drop lowest-weight samples per iteration; see CvBoost::trim_weights
848. CvBoostParams();
849. CvBoostParams( int boost_type, int weak_count, double weight_trim_rate,
850. int max_depth, bool use_surrogates, const float* priors );
851. };
  852. class CvBoost;
// A weak learner inside a CvBoost ensemble.  Overrides the split-quality
// and node-value computations of CvDTree to use the boosting sample
// weights held by the owning ensemble.
853. class CvBoostTree: public CvDTree
854. {
855. public:
856. CvBoostTree();
857. virtual ~CvBoostTree();
// Trains on the ensemble's shared data using the current sample subset/weights.
858. virtual bool train( CvDTreeTrainData* trainData,
859. const CvMat* subsample_idx, CvBoost* ensemble );
// Multiplies the tree's output values by s (used e.g. for shrinkage).
860. virtual void scale( double s );
861. virtual void read( const cv::FileNode& node,
862. CvBoost* ensemble, CvDTreeTrainData* _data );
863. virtual void clear();
864. /* dummy methods to avoid warnings: BEGIN */
865. virtual bool train( const CvMat* trainData, int tflag,
866. const CvMat* responses, const CvMat* varIdx=0,
867. const CvMat* sampleIdx=0, const CvMat* varType=0,
868. const CvMat* missingDataMask=0,
869. CvDTreeParams params=CvDTreeParams() );
870. virtual bool train( CvDTreeTrainData* trainData, const CvMat* _subsample_idx );
871. virtual void read( cv::FileNode& node );
872. virtual void read( cv::FileNode& node, CvDTreeTrainData* data );
873. /* dummy methods to avoid warnings: END */
874. protected:
// Weight-aware overrides of the CvDTree split search / node statistics.
875. virtual void try_split_node( CvDTreeNode* n );
876. virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
877. virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi, uchar* ext_buf = 0 );
878. virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi,
879. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
880. virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi,
881. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
882. virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi,
883. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
884. virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi,
885. float init_quality = 0, CvDTreeSplit* _split = 0, uchar* ext_buf = 0 );
886. virtual void calc_node_value( CvDTreeNode* n );
887. virtual double calc_node_dir( CvDTreeNode* n );
// Owning ensemble; provides weights/subsample masks. Not freed here.
888. CvBoost* ensemble;
889. };
// Boosted binary classifier: a weighted ensemble of CvBoostTree weak
// learners, with sample-weight updating and optional weight trimming
// between iterations.
890. class CvBoost : public CvStatModel
891. {
892. public:
893. // Boosting type
894. enum { DISCRETE=0, REAL=1, LOGIT=2, GENTLE=3 };
895. // Splitting criteria
896. enum { DEFAULT=0, GINI=1, MISCLASS=3, SQERR=4 };
897. CV_WRAP CvBoost();
898. virtual ~CvBoost();
// Constructor that trains immediately (equivalent to default-construct + train()).
899. CvBoost( const CvMat* trainData, int tflag,
900. const CvMat* responses, const CvMat* varIdx=0,
901. const CvMat* sampleIdx=0, const CvMat* varType=0,
902. const CvMat* missingDataMask=0,
903. CvBoostParams params=CvBoostParams() );
// --- training -------------------------------------------------------
// update=true continues training an existing ensemble instead of
// starting from scratch.
904. virtual bool train( const CvMat* trainData, int tflag,
905. const CvMat* responses, const CvMat* varIdx=0,
906. const CvMat* sampleIdx=0, const CvMat* varType=0,
907. const CvMat* missingDataMask=0,
908. CvBoostParams params=CvBoostParams(),
909. bool update=false );
910. virtual bool train( CvMLData* data,
911. CvBoostParams params=CvBoostParams(),
912. bool update=false );
// --- prediction -----------------------------------------------------
// slice selects which weak learners vote; return_sum=true yields the raw
// weighted sum instead of the class label.
913. virtual float predict( const CvMat* sample, const CvMat* missing=0,
914. CvMat* weak_responses=0, CvSlice slice=CV_WHOLE_SEQ,
915. bool raw_mode=false, bool return_sum=false ) const;
916. CV_WRAP CvBoost( const cv::Mat& trainData, int tflag,
917. const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
918. const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
919. const cv::Mat& missingDataMask=cv::Mat(),
920. CvBoostParams params=CvBoostParams() );
921. CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
922. const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
923. const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
924. const cv::Mat& missingDataMask=cv::Mat(),
925. CvBoostParams params=CvBoostParams(),
926. bool update=false );
927. CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
928. const cv::Range& slice=cv::Range::all(), bool rawMode=false,
929. bool returnSum=false ) const;
930. virtual float calc_error( CvMLData* _data, int type , std::vector<float> *resp = 0 ); // type in {CV_TRAIN_ERROR, CV_TEST_ERROR}
// Removes the weak learners outside the given slice from the ensemble.
931. CV_WRAP virtual void prune( CvSlice slice );
932. CV_WRAP virtual void clear();
933. virtual void write( cv::FileStorage& storage, const char* name ) const;
934. virtual void read( cv::FileNode& node );
935. virtual const CvMat* get_active_vars(bool absolute_idx=true);
// --- introspection accessors ----------------------------------------
936. CvSeq* get_weak_predictors();
937. CvMat* get_weights();
938. CvMat* get_subtree_weights();
939. CvMat* get_weak_response();
940. const CvBoostParams& get_params() const;
941. const CvDTreeTrainData* get_data() const;
942. protected:
943. virtual bool set_params( const CvBoostParams& params );
// Re-weights samples after each weak learner is added (boosting step).
944. virtual void update_weights( CvBoostTree* tree );
// Drops the lowest-weight samples per weight_trim_rate to speed training.
945. virtual void trim_weights();
946. virtual void write_params( cv::FileStorage & fs ) const;
947. virtual void read_params( cv::FileNode& node );
948. virtual void initialize_weights(double (&p)[2]);
949. CvDTreeTrainData* data;
950. CvMat train_data_hdr, responses_hdr;
951. cv::Mat train_data_mat, responses_mat;
952. CvBoostParams params;
// Sequence of trained CvBoostTree weak learners.
953. CvSeq* weak;
954. CvMat* active_vars;
955. CvMat* active_vars_abs;
956. bool have_active_cat_vars;
// Per-iteration state: original/accumulated responses, per-sample
// weights, and the current training subsample mask.
957. CvMat* orig_response;
958. CvMat* sum_response;
959. CvMat* weak_eval;
960. CvMat* subsample_mask;
961. CvMat* weights;
962. CvMat* subtree_weights;
963. bool have_subsample;
964. };
  965. /****************************************************************************************\
  966. * Gradient Boosted Trees *
  967. \****************************************************************************************/
  968. // DataType: STRUCT CvGBTreesParams
  969. // Parameters of GBT (Gradient Boosted trees model), including single
  970. // tree settings and ensemble parameters.
  971. //
  972. // weak_count - count of trees in the ensemble
  973. // loss_function_type - loss function used for ensemble training
  974. // subsample_portion - portion of whole training set used for
  975. // every single tree training.
  976. // subsample_portion value is in (0.0, 1.0].
  977. // subsample_portion == 1.0 when whole dataset is
  978. // used on each step. Count of sample used on each
  979. // step is computed as
  980. // int(total_samples_count * subsample_portion).
  981. // shrinkage - regularization parameter.
  982. // Each tree prediction is multiplied on shrinkage value.
// Gradient Boosted Trees parameters (field semantics are described in the
// comment block above): ensemble-level knobs on top of CvDTreeParams.
983. struct CvGBTreesParams : public CvDTreeParams
984. {
985. CV_PROP_RW int weak_count; // trees per ensemble
986. CV_PROP_RW int loss_function_type; // one of CvGBTrees loss enum values
987. CV_PROP_RW float subsample_portion; // fraction of samples used per tree, in (0, 1]
988. CV_PROP_RW float shrinkage; // per-tree regularization multiplier
989. CvGBTreesParams();
990. CvGBTreesParams( int loss_function_type, int weak_count, float shrinkage,
991. float subsample_portion, int max_depth, bool use_surrogates );
992. };
  993. // DataType: CLASS CvGBTrees
  994. // Gradient Boosting Trees (GBT) algorithm implementation.
  995. //
  996. // data - training dataset
  997. // params - parameters of the CvGBTrees
  998. // weak - array[0..(class_count-1)] of CvSeq
  999. // for storing tree ensembles
  1000. // orig_response - original responses of the training set samples
  1001. // sum_response - predictions of the current model on the training dataset.
  1002. // this matrix is updated on every iteration.
  1003. // sum_response_tmp - predictions of the model on the training set on the next
  1004. // step. On every iteration values of sum_responses_tmp are
  1005. // computed via sum_responses values. When the current
  1006. // step is complete sum_response values become equal to
  1007. // sum_responses_tmp.
  1008. // sampleIdx - indices of samples used for training the ensemble.
  1009. // CvGBTrees training procedure takes a set of samples
  1010. // (train_data) and a set of responses (responses).
  1011. // Only pairs (train_data[i], responses[i]), where i is
  1012. // in sample_idx are used for training the ensemble.
  1013. // subsample_train - indices of samples used for training a single decision
1014. // tree on the current step. These indices are counted
1015. // relative to the sample_idx, so that pairs
  1016. // (train_data[sample_idx[i]], responses[sample_idx[i]])
  1017. // are used for training a decision tree.
1018. // Training set is randomly split
  1019. // in two parts (subsample_train and subsample_test)
  1020. // on every iteration accordingly to the portion parameter.
  1021. // subsample_test - relative indices of samples from the training set,
  1022. // which are not used for training a tree on the current
  1023. // step.
  1024. // missing - mask of the missing values in the training set. This
  1025. // matrix has the same size as train_data. 1 - missing
  1026. // value, 0 - not a missing value.
  1027. // class_labels - output class labels map.
  1028. // rng - random number generator. Used for splitting the
  1029. // training set.
  1030. // class_count - count of output classes.
  1031. // class_count == 1 in the case of regression,
  1032. // and > 1 in the case of classification.
  1033. // delta - Huber loss function parameter.
  1034. // base_value - start point of the gradient descent procedure.
  1035. // model prediction is
  1036. // f(x) = f_0 + sum_{i=1..weak_count-1}(f_i(x)), where
  1037. // f_0 is the base value.
  1038. class CvGBTrees : public CvStatModel
  1039. {
  1040. public:
  1041. /*
  1042. // DataType: ENUM
  1043. // Loss functions implemented in CvGBTrees.
  1044. //
  1045. // SQUARED_LOSS
  1046. // problem: regression
  1047. // loss = (x - x')^2
  1048. //
  1049. // ABSOLUTE_LOSS
  1050. // problem: regression
  1051. // loss = abs(x - x')
  1052. //
  1053. // HUBER_LOSS
  1054. // problem: regression
  1055. // loss = delta*( abs(x - x') - delta/2), if abs(x - x') > delta
  1056. // 1/2*(x - x')^2, if abs(x - x') <= delta,
  1057. // where delta is the alpha-quantile of pseudo responses from
  1058. // the training set.
  1059. //
  1060. // DEVIANCE_LOSS
  1061. // problem: classification
  1062. //
  1063. */
  1064. enum {SQUARED_LOSS=0, ABSOLUTE_LOSS, HUBER_LOSS=3, DEVIANCE_LOSS};
  1065. /*
  1066. // Default constructor. Creates a model only (without training).
  1067. // Should be followed by one form of the train(...) function.
  1068. //
  1069. // API
  1070. // CvGBTrees();
  1071. // INPUT
  1072. // OUTPUT
  1073. // RESULT
  1074. */
  1075. CV_WRAP CvGBTrees();
  1076. /*
  1077. // Full form constructor. Creates a gradient boosting model and does the
  1078. // train.
  1079. //
  1080. // API
  1081. // CvGBTrees( const CvMat* trainData, int tflag,
  1082. const CvMat* responses, const CvMat* varIdx=0,
  1083. const CvMat* sampleIdx=0, const CvMat* varType=0,
  1084. const CvMat* missingDataMask=0,
  1085. CvGBTreesParams params=CvGBTreesParams() );
  1086. // INPUT
  1087. // trainData - a set of input feature vectors.
  1088. // size of matrix is
  1089. // <count of samples> x <variables count>
  1090. // or <variables count> x <count of samples>
  1091. // depending on the tflag parameter.
  1092. // matrix values are float.
1093. // tflag - a flag showing how the samples are stored in the
1094. // trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
  1095. // or column by column (tflag=CV_COL_SAMPLE).
  1096. // responses - a vector of responses corresponding to the samples
  1097. // in trainData.
  1098. // varIdx - indices of used variables. zero value means that all
  1099. // variables are active.
  1100. // sampleIdx - indices of used samples. zero value means that all
  1101. // samples from trainData are in the training set.
  1102. // varType - vector of <variables count> length. gives every
  1103. // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
  1104. // varType = 0 means all variables are numerical.
1105. // missingDataMask - a mask of missing values in trainData.
  1106. // missingDataMask = 0 means that there are no missing
  1107. // values.
  1108. // params - parameters of GTB algorithm.
  1109. // OUTPUT
  1110. // RESULT
  1111. */
  1112. CvGBTrees( const CvMat* trainData, int tflag,
  1113. const CvMat* responses, const CvMat* varIdx=0,
  1114. const CvMat* sampleIdx=0, const CvMat* varType=0,
  1115. const CvMat* missingDataMask=0,
  1116. CvGBTreesParams params=CvGBTreesParams() );
  1117. /*
  1118. // Destructor.
  1119. */
  1120. virtual ~CvGBTrees();
  1121. /*
  1122. // Gradient tree boosting model training
  1123. //
  1124. // API
  1125. // virtual bool train( const CvMat* trainData, int tflag,
  1126. const CvMat* responses, const CvMat* varIdx=0,
  1127. const CvMat* sampleIdx=0, const CvMat* varType=0,
  1128. const CvMat* missingDataMask=0,
  1129. CvGBTreesParams params=CvGBTreesParams(),
  1130. bool update=false );
  1131. // INPUT
  1132. // trainData - a set of input feature vectors.
  1133. // size of matrix is
  1134. // <count of samples> x <variables count>
  1135. // or <variables count> x <count of samples>
  1136. // depending on the tflag parameter.
  1137. // matrix values are float.
1138. // tflag - a flag showing how the samples are stored in the
1139. // trainData matrix: row by row (tflag=CV_ROW_SAMPLE)
  1140. // or column by column (tflag=CV_COL_SAMPLE).
  1141. // responses - a vector of responses corresponding to the samples
  1142. // in trainData.
  1143. // varIdx - indices of used variables. zero value means that all
  1144. // variables are active.
  1145. // sampleIdx - indices of used samples. zero value means that all
  1146. // samples from trainData are in the training set.
  1147. // varType - vector of <variables count> length. gives every
  1148. // variable type CV_VAR_CATEGORICAL or CV_VAR_ORDERED.
  1149. // varType = 0 means all variables are numerical.
1150. // missingDataMask - a mask of missing values in trainData.
  1151. // missingDataMask = 0 means that there are no missing
  1152. // values.
  1153. // params - parameters of GTB algorithm.
  1154. // update - is not supported now. (!)
  1155. // OUTPUT
  1156. // RESULT
  1157. // Error state.
  1158. */
  1159. virtual bool train( const CvMat* trainData, int tflag,
  1160. const CvMat* responses, const CvMat* varIdx=0,
  1161. const CvMat* sampleIdx=0, const CvMat* varType=0,
  1162. const CvMat* missingDataMask=0,
  1163. CvGBTreesParams params=CvGBTreesParams(),
  1164. bool update=false );
  1165. /*
  1166. // Gradient tree boosting model training
  1167. //
  1168. // API
  1169. // virtual bool train( CvMLData* data,
  1170. CvGBTreesParams params=CvGBTreesParams(),
  1171. bool update=false ) {return false;}
  1172. // INPUT
  1173. // data - training set.
  1174. // params - parameters of GTB algorithm.
  1175. // update - is not supported now. (!)
  1176. // OUTPUT
  1177. // RESULT
  1178. // Error state.
  1179. */
  1180. virtual bool train( CvMLData* data,
  1181. CvGBTreesParams params=CvGBTreesParams(),
  1182. bool update=false );
  1183. /*
  1184. // Response value prediction
  1185. //
  1186. // API
  1187. // virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
  1188. CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
  1189. int k=-1 ) const;
  1190. // INPUT
  1191. // sample - input sample of the same type as in the training set.
  1192. // missing - missing values mask. missing=0 if there are no
  1193. // missing values in sample vector.
  1194. // weak_responses - predictions of all of the trees.
  1195. // not implemented (!)
  1196. // slice - part of the ensemble used for prediction.
  1197. // slice = CV_WHOLE_SEQ when all trees are used.
  1198. // k - number of ensemble used.
  1199. // k is in {-1,0,1,..,<count of output classes-1>}.
  1200. // in the case of classification problem
  1201. // <count of output classes-1> ensembles are built.
  1202. // If k = -1 ordinary prediction is the result,
  1203. // otherwise function gives the prediction of the
  1204. // k-th ensemble only.
  1205. // OUTPUT
  1206. // RESULT
  1207. // Predicted value.
  1208. */
  1209. virtual float predict_serial( const CvMat* sample, const CvMat* missing=0,
  1210. CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
  1211. int k=-1 ) const;
  1212. /*
  1213. // Response value prediction.
  1214. // Parallel version (in the case of TBB existence)
  1215. //
  1216. // API
  1217. // virtual float predict( const CvMat* sample, const CvMat* missing=0,
  1218. CvMat* weak_responses=0, CvSlice slice = CV_WHOLE_SEQ,
  1219. int k=-1 ) const;
  1220. // INPUT
  1221. // sample - input sample of the same type as in the training set.
  1222. // missing - missing values mask. missing=0 if there are no
  1223. // missing values in sample vector.
  1224. // weak_responses - predictions of all of the trees.
  1225. // not implemented (!)
  1226. // slice - part of the ensemble used for prediction.
  1227. // slice = CV_WHOLE_SEQ when all trees are used.
  1228. // k - number of ensemble used.
  1229. // k is in {-1,0,1,..,<count of output classes-1>}.
  1230. // in the case of classification problem
  1231. // <count of output classes-1> ensembles are built.
  1232. // If k = -1 ordinary prediction is the result,
  1233. // otherwise function gives the prediction of the
  1234. // k-th ensemble only.
  1235. // OUTPUT
  1236. // RESULT
  1237. // Predicted value.
  1238. */
  1239. virtual float predict( const CvMat* sample, const CvMat* missing=0,
  1240. CvMat* weakResponses=0, CvSlice slice = CV_WHOLE_SEQ,
  1241. int k=-1 ) const;
  1242. /*
  1243. // Deletes all the data.
  1244. //
  1245. // API
  1246. // virtual void clear();
  1247. // INPUT
  1248. // OUTPUT
  1249. // delete data, weak, orig_response, sum_response,
  1250. // weak_eval, subsample_train, subsample_test,
// sample_idx, missing, class_labels
  1252. // delta = 0.0
  1253. // RESULT
  1254. */
  1255. CV_WRAP virtual void clear();
  1256. /*
  1257. // Compute error on the train/test set.
  1258. //
  1259. // API
  1260. // virtual float calc_error( CvMLData* _data, int type,
  1261. // std::vector<float> *resp = 0 );
  1262. //
  1263. // INPUT
  1264. // data - dataset
  1265. // type - defines which error is to compute: train (CV_TRAIN_ERROR) or
  1266. // test (CV_TEST_ERROR).
  1267. // OUTPUT
  1268. // resp - vector of predictions
  1269. // RESULT
  1270. // Error value.
  1271. */
  1272. virtual float calc_error( CvMLData* _data, int type,
  1273. std::vector<float> *resp = 0 );
  1274. /*
  1275. //
  1276. // Write parameters of the gtb model and data. Write learned model.
  1277. //
  1278. // API
  1279. // virtual void write( cv::FileStorage& fs, const char* name ) const;
  1280. //
  1281. // INPUT
// fs - file storage to write parameters to.
  1283. // name - model name.
  1284. // OUTPUT
  1285. // RESULT
  1286. */
  1287. virtual void write( cv::FileStorage& fs, const char* name ) const;
  1288. /*
  1289. //
  1290. // Read parameters of the gtb model and data. Read learned model.
  1291. //
  1292. // API
  1293. // virtual void read( cv::FileStorage& fs, cv::FileNode& node );
  1294. //
  1295. // INPUT
  1296. // fs - file storage to read parameters from.
  1297. // node - file node.
  1298. // OUTPUT
  1299. // RESULT
  1300. */
  1301. virtual void read( cv::FileStorage& fs, cv::FileNode& node );
  1302. // new-style C++ interface
  1303. CV_WRAP CvGBTrees( const cv::Mat& trainData, int tflag,
  1304. const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
  1305. const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
  1306. const cv::Mat& missingDataMask=cv::Mat(),
  1307. CvGBTreesParams params=CvGBTreesParams() );
  1308. CV_WRAP virtual bool train( const cv::Mat& trainData, int tflag,
  1309. const cv::Mat& responses, const cv::Mat& varIdx=cv::Mat(),
  1310. const cv::Mat& sampleIdx=cv::Mat(), const cv::Mat& varType=cv::Mat(),
  1311. const cv::Mat& missingDataMask=cv::Mat(),
  1312. CvGBTreesParams params=CvGBTreesParams(),
  1313. bool update=false );
  1314. CV_WRAP virtual float predict( const cv::Mat& sample, const cv::Mat& missing=cv::Mat(),
  1315. const cv::Range& slice = cv::Range::all(),
  1316. int k=-1 ) const;
  1317. protected:
  1318. /*
  1319. // Compute the gradient vector components.
  1320. //
  1321. // API
  1322. // virtual void find_gradient( const int k = 0);
  1323. // INPUT
  1324. // k - used for classification problem, determining current
  1325. // tree ensemble.
  1326. // OUTPUT
  1327. // changes components of data->responses
  1328. // which correspond to samples used for training
  1329. // on the current step.
  1330. // RESULT
  1331. */
  1332. virtual void find_gradient( const int k = 0);
  1333. /*
  1334. //
  1335. // Change values in tree leaves according to the used loss function.
  1336. //
  1337. // API
  1338. // virtual void change_values(CvDTree* tree, const int k = 0);
  1339. //
  1340. // INPUT
  1341. // tree - decision tree to change.
  1342. // k - used for classification problem, determining current
  1343. // tree ensemble.
  1344. // OUTPUT
  1345. // changes 'value' fields of the trees' leaves.
  1346. // changes sum_response_tmp.
  1347. // RESULT
  1348. */
  1349. virtual void change_values(CvDTree* tree, const int k = 0);
  1350. /*
  1351. //
  1352. // Find optimal constant prediction value according to the used loss
  1353. // function.
  1354. // The goal is to find a constant which gives the minimal summary loss
  1355. // on the _Idx samples.
  1356. //
  1357. // API
  1358. // virtual float find_optimal_value( const CvMat* _Idx );
  1359. //
  1360. // INPUT
  1361. // _Idx - indices of the samples from the training set.
  1362. // OUTPUT
  1363. // RESULT
  1364. // optimal constant value.
  1365. */
  1366. virtual float find_optimal_value( const CvMat* _Idx );
  1367. /*
  1368. //
  1369. // Randomly split the whole training set in two parts according
  1370. // to params.portion.
  1371. //
  1372. // API
  1373. // virtual void do_subsample();
  1374. //
  1375. // INPUT
  1376. // OUTPUT
  1377. // subsample_train - indices of samples used for training
  1378. // subsample_test - indices of samples used for test
  1379. // RESULT
  1380. */
  1381. virtual void do_subsample();
  1382. /*
  1383. //
  1384. // Internal recursive function giving an array of subtree tree leaves.
  1385. //
  1386. // API
  1387. // void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
  1388. //
  1389. // INPUT
  1390. // node - current leaf.
  1391. // OUTPUT
  1392. // count - count of leaves in the subtree.
  1393. // leaves - array of pointers to leaves.
  1394. // RESULT
  1395. */
  1396. void leaves_get( CvDTreeNode** leaves, int& count, CvDTreeNode* node );
  1397. /*
  1398. //
  1399. // Get leaves of the tree.
  1400. //
  1401. // API
  1402. // CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
  1403. //
  1404. // INPUT
  1405. // dtree - decision tree.
  1406. // OUTPUT
  1407. // len - count of the leaves.
  1408. // RESULT
  1409. // CvDTreeNode** - array of pointers to leaves.
  1410. */
  1411. CvDTreeNode** GetLeaves( const CvDTree* dtree, int& len );
  1412. /*
  1413. //
  1414. // Is it a regression or a classification.
  1415. //
  1416. // API
  1417. // bool problem_type();
  1418. //
  1419. // INPUT
  1420. // OUTPUT
  1421. // RESULT
  1422. // false if it is a classification problem,
  1423. // true - if regression.
  1424. */
  1425. virtual bool problem_type() const;
  1426. /*
  1427. //
  1428. // Write parameters of the gtb model.
  1429. //
  1430. // API
  1431. // virtual void write_params( cv::FileStorage& fs ) const;
  1432. //
  1433. // INPUT
  1434. // fs - file storage to write parameters to.
  1435. // OUTPUT
  1436. // RESULT
  1437. */
  1438. virtual void write_params( cv::FileStorage& fs ) const;
  1439. /*
  1440. //
  1441. // Read parameters of the gtb model and data.
  1442. //
  1443. // API
  1444. // virtual void read_params( const cv::FileStorage& fs );
  1445. //
  1446. // INPUT
  1447. // fs - file storage to read parameters from.
  1448. // OUTPUT
  1449. // params - parameters of the gtb model.
  1450. // data - contains information about the structure
  1451. // of the data set (count of variables,
  1452. // their types, etc.).
  1453. // class_labels - output class labels map.
  1454. // RESULT
  1455. */
  1456. virtual void read_params( cv::FileStorage& fs, cv::FileNode& fnode );
  1457. int get_len(const CvMat* mat) const;
  1458. CvDTreeTrainData* data;
  1459. CvGBTreesParams params;
  1460. CvSeq** weak;
  1461. CvMat* orig_response;
  1462. CvMat* sum_response;
  1463. CvMat* sum_response_tmp;
  1464. CvMat* sample_idx;
  1465. CvMat* subsample_train;
  1466. CvMat* subsample_test;
  1467. CvMat* missing;
  1468. CvMat* class_labels;
  1469. cv::RNG* rng;
  1470. int class_count;
  1471. float delta;
  1472. float base_value;
  1473. };
  1474. /****************************************************************************************\
  1475. * Artificial Neural Networks (ANN) *
  1476. \****************************************************************************************/
  1477. /////////////////////////////////// Multi-Layer Perceptrons //////////////////////////////
// Training parameters for CvANN_MLP (multi-layer perceptron).
struct CvANN_MLP_TrainParams
{
    CvANN_MLP_TrainParams();
    // term_crit     - iteration-count / epsilon termination criteria.
    // train_method  - BACKPROP or RPROP (see enum below).
    // param1/param2 - method-specific knobs; presumably param1 initializes
    //                 bp_dw_scale (BACKPROP) or rp_dw0 (RPROP) and param2 the
    //                 companion field -- confirm against the implementation.
    CvANN_MLP_TrainParams( CvTermCriteria term_crit, int train_method,
                           double param1, double param2=0 );
    ~CvANN_MLP_TrainParams();

    // Available training algorithms.
    enum { BACKPROP=0, RPROP=1 };

    CV_PROP_RW CvTermCriteria term_crit;  // when to stop training
    CV_PROP_RW int train_method;          // BACKPROP or RPROP

    // backpropagation parameters
    CV_PROP_RW double bp_dw_scale, bp_moment_scale;

    // rprop parameters
    CV_PROP_RW double rp_dw0, rp_dw_plus, rp_dw_minus, rp_dw_min, rp_dw_max;
};
// Multi-layer perceptron (feed-forward artificial neural network).
class CvANN_MLP : public CvStatModel
{
public:
    CV_WRAP CvANN_MLP();
    // Constructs and immediately builds the network topology (see create()).
    CvANN_MLP( const CvMat* layerSizes,
               int activateFunc=CvANN_MLP::SIGMOID_SYM,
               double fparam1=0, double fparam2=0 );
    virtual ~CvANN_MLP();

    // layerSizes   - neuron count per layer; a row vector, judging by the
    //                layer_sizes->cols usage in get_layer_count() below.
    // activateFunc - IDENTITY, SIGMOID_SYM or GAUSSIAN (enum below).
    // fparam1/2    - free parameters of the activation function.
    virtual void create( const CvMat* layerSizes,
                         int activateFunc=CvANN_MLP::SIGMOID_SYM,
                         double fparam1=0, double fparam2=0 );

    // Trains the network on (inputs, outputs) pairs, optionally restricted to
    // sampleIdx and weighted by sampleWeights.
    // flags is a combination of UPDATE_WEIGHTS/NO_INPUT_SCALE/NO_OUTPUT_SCALE.
    virtual int train( const CvMat* inputs, const CvMat* outputs,
                       const CvMat* sampleWeights, const CvMat* sampleIdx=0,
                       CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
                       int flags=0 );
    virtual float predict( const CvMat* inputs, CV_OUT CvMat* outputs ) const;

    // new-style C++ interface: cv::Mat overloads of the methods above
    CV_WRAP CvANN_MLP( const cv::Mat& layerSizes,
                       int activateFunc=CvANN_MLP::SIGMOID_SYM,
                       double fparam1=0, double fparam2=0 );
    CV_WRAP virtual void create( const cv::Mat& layerSizes,
                                 int activateFunc=CvANN_MLP::SIGMOID_SYM,
                                 double fparam1=0, double fparam2=0 );
    CV_WRAP virtual int train( const cv::Mat& inputs, const cv::Mat& outputs,
                               const cv::Mat& sampleWeights, const cv::Mat& sampleIdx=cv::Mat(),
                               CvANN_MLP_TrainParams params = CvANN_MLP_TrainParams(),
                               int flags=0 );
    CV_WRAP virtual float predict( const cv::Mat& inputs, CV_OUT cv::Mat& outputs ) const;

    CV_WRAP virtual void clear();

    // possible activation functions
    enum { IDENTITY = 0, SIGMOID_SYM = 1, GAUSSIAN = 2 };

    // available training flags
    enum { UPDATE_WEIGHTS = 1, NO_INPUT_SCALE = 2, NO_OUTPUT_SCALE = 4 };

    virtual void read( cv::FileStorage& fs, cv::FileNode& node );
    virtual void write( cv::FileStorage& storage, const char* name ) const;

    int get_layer_count() { return layer_sizes ? layer_sizes->cols : 0; }
    const CvMat* get_layer_sizes() { return layer_sizes; }

    // Returns the weight array for the given layer, or 0 when the network is
    // not created or the index is out of range.
    // NOTE(review): the bound is `<=` cols, so `weights` presumably holds
    // cols+1 (or more) entries, e.g. extra input/output scale vectors --
    // confirm against the allocation in the .cpp before tightening.
    double* get_weights(int layer)
    {
        return layer_sizes && weights &&
            (unsigned)layer <= (unsigned)layer_sizes->cols ? weights[layer] : 0;
    }

    virtual void calc_activ_func_deriv( CvMat* xf, CvMat* deriv, const double* bias ) const;

protected:
    virtual bool prepare_to_train( const CvMat* _inputs, const CvMat* _outputs,
                                   const CvMat* _sample_weights, const CvMat* sampleIdx,
                                   CvVectors* _ivecs, CvVectors* _ovecs, double** _sw, int _flags );

    // sequential random backpropagation
    virtual int train_backprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    // RPROP algorithm
    virtual int train_rprop( CvVectors _ivecs, CvVectors _ovecs, const double* _sw );

    virtual void calc_activ_func( CvMat* xf, const double* bias ) const;
    virtual void set_activ_func( int _activ_func=SIGMOID_SYM,
                                 double _f_param1=0, double _f_param2=0 );
    virtual void init_weights();
    virtual void scale_input( const CvMat* _src, CvMat* _dst ) const;
    virtual void scale_output( const CvMat* _src, CvMat* _dst ) const;
    virtual void calc_input_scale( const CvVectors* vecs, int flags );
    virtual void calc_output_scale( const CvVectors* vecs, int flags );

    virtual void write_params( cv::FileStorage& fs ) const;
    virtual void read_params( cv::FileStorage& fs, cv::FileNode& node );

    CvMat* layer_sizes;          // per-layer neuron counts (see get_layer_count)
    CvMat* wbuf;                 // raw buffer; weights[] points into it -- TODO confirm
    CvMat* sample_weights;
    double** weights;            // per-layer weight arrays (see get_weights)
    double f_param1, f_param2;   // activation function parameters
    double min_val, max_val, min_val1, max_val1;
    int activ_func;              // IDENTITY/SIGMOID_SYM/GAUSSIAN
    int max_count, max_buf_sz;
    CvANN_MLP_TrainParams params;
    cv::RNG* rng;
};
  1563. /****************************************************************************************\
  1564. * Data *
  1565. \****************************************************************************************/
// Modes for CvTrainTestSplit::train_sample_part_mode (which union member
// of train_sample_part is active).
#define CV_COUNT 0
#define CV_PORTION 1

// Describes how CvMLData should split samples into train/test subsets:
// either an absolute number of training samples or a fractional portion.
struct CvTrainTestSplit
{
    CvTrainTestSplit();
    CvTrainTestSplit( int train_sample_count, bool mix = true);
    CvTrainTestSplit( float train_sample_portion, bool mix = true);

    union
    {
        int count;      // active when train_sample_part_mode == CV_COUNT
        float portion;  // active when train_sample_part_mode == CV_PORTION
    } train_sample_part;

    int train_sample_part_mode;  // CV_COUNT or CV_PORTION
    bool mix;                    // presumably shuffle before splitting -- see
                                 // CvMLData::mix_train_and_test_idx()
};
// Loads a dataset from a CSV file and exposes its values, responses,
// missing-value mask, variable types and train/test splits to the old
// (C-style) ML training APIs.
class CvMLData
{
public:
    CvMLData();
    virtual ~CvMLData();

    // returns:
    // 0 - OK
    // -1 - file can not be opened or is not correct
    int read_csv( const char* filename );

    const CvMat* get_values() const;   // all parsed values
    const CvMat* get_responses();      // response column (see set_response_idx)
    const CvMat* get_missing() const;  // mask of missing entries

    void set_header_lines_number( int n );
    int get_header_lines_number() const;

    void set_response_idx( int idx ); // old response become predictors, new response_idx = idx
                                      // if idx < 0 there will be no response
    int get_response_idx() const;

    // train/test split management
    void set_train_test_split( const CvTrainTestSplit * spl );
    const CvMat* get_train_sample_idx() const;
    const CvMat* get_test_sample_idx() const;
    void mix_train_and_test_idx();

    // active-predictor (variable index) management
    const CvMat* get_var_idx();
    void chahge_var_idx( int vi, bool state ); // misspelled (kept for backward compatibility),
                                               // use change_var_idx
    void change_var_idx( int vi, bool state ); // state == true to set vi-variable as predictor

    const CvMat* get_var_types();
    int get_var_type( int var_idx ) const;
    // following 2 methods enable to change vars type
    // use these methods to assign CV_VAR_CATEGORICAL type for categorical variable
    // with numerical labels; in the other cases var types are correctly determined automatically
    void set_var_types( const char* str ); // str examples:
                                           // "ord[0-17],cat[18]", "ord[0,2,4,10-12], cat[1,3,5-9,13,14]",
                                           // "cat", "ord" (all vars are categorical/ordered)
    void change_var_type( int var_idx, int type); // type in { CV_VAR_ORDERED, CV_VAR_CATEGORICAL }

    // CSV parsing configuration
    void set_delimiter( char ch );
    char get_delimiter() const;
    void set_miss_ch( char ch );   // character that marks a missing value
    char get_miss_ch() const;

    const std::map<cv::String, int>& get_class_labels_map() const;

protected:
    virtual void clear();
    void str_to_flt_elem( const char* token, float& flt_elem, int& type);
    void free_train_test_idx();

    char delimiter;
    char miss_ch;
    //char flt_separator;

    CvMat* values;
    CvMat* missing;
    CvMat* var_types;
    CvMat* var_idx_mask;

    CvMat* response_out;  // header
    CvMat* var_idx_out;   // mat
    CvMat* var_types_out; // mat

    int header_lines_number;
    int response_idx;
    int train_sample_count;
    bool mix;

    int total_class_count;
    std::map<cv::String, int> class_map;  // class label -> index

    CvMat* train_sample_idx;
    CvMat* test_sample_idx;
    int* sample_idx; // data of train_sample_idx and test_sample_idx
    cv::RNG* rng;
};
// Backward-compatibility aliases: expose the old C-style classes under the
// cv:: namespace with modern-style names.
namespace cv
{

typedef CvStatModel StatModel;
typedef CvParamGrid ParamGrid;
typedef CvNormalBayesClassifier NormalBayesClassifier;
typedef CvKNearest KNearest;
typedef CvSVMParams SVMParams;
typedef CvSVMKernel SVMKernel;
typedef CvSVMSolver SVMSolver;
typedef CvSVM SVM;
typedef CvDTreeParams DTreeParams;
typedef CvMLData TrainData;
typedef CvDTree DecisionTree;
typedef CvForestTree ForestTree;
typedef CvRTParams RandomTreeParams;
typedef CvRTrees RandomTrees;
typedef CvERTreeTrainData ERTreeTRainData; // NOTE: 'TRain' typo is part of the
                                           // public name; kept for compatibility
typedef CvForestERTree ERTree;
typedef CvERTrees ERTrees;
typedef CvBoostParams BoostParams;
typedef CvBoostTree BoostTree;
typedef CvBoost Boost;
typedef CvANN_MLP_TrainParams ANN_MLP_TrainParams;
typedef CvANN_MLP NeuralNet_MLP;
typedef CvGBTreesParams GradientBoostingTreeParams;
typedef CvGBTrees GradientBoostingTrees;

// Custom deleter so cv::Ptr<CvDTreeSplit> can be used; the operator() body
// lives in the .cpp (a plain delete is presumably not the right way to free
// a CvDTreeSplit -- see the implementation).
template<> struct DefaultDeleter<CvDTreeSplit>{ void operator ()(CvDTreeSplit* obj) const; };

}
  1673. #endif // __cplusplus
  1674. #endif // OPENCV_OLD_ML_HPP
  1675. /* End of file. */