train_HOG.cpp 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392
  1. #include "opencv2/imgproc.hpp"
  2. #include "opencv2/highgui.hpp"
  3. #include "opencv2/ml.hpp"
  4. #include "opencv2/objdetect.hpp"
  5. #include "opencv2/videoio.hpp"
  6. #include <iostream>
  7. #include <time.h>
  8. using namespace cv;
  9. using namespace cv::ml;
  10. using namespace std;
  11. vector< float > get_svm_detector( const Ptr< SVM >& svm );
  12. void convert_to_ml( const std::vector< Mat > & train_samples, Mat& trainData );
  13. void load_images( const String & dirname, vector< Mat > & img_lst, bool showImages );
  14. void sample_neg( const vector< Mat > & full_neg_lst, vector< Mat > & neg_lst, const Size & size );
  15. void computeHOGs( const Size wsize, const vector< Mat > & img_lst, vector< Mat > & gradient_lst, bool use_flip );
  16. void test_trained_detector( String obj_det_filename, String test_dir, String videofilename );
  17. vector< float > get_svm_detector( const Ptr< SVM >& svm )
  18. {
  19. // get the support vectors
  20. Mat sv = svm->getSupportVectors();
  21. const int sv_total = sv.rows;
  22. // get the decision function
  23. Mat alpha, svidx;
  24. double rho = svm->getDecisionFunction( 0, alpha, svidx );
  25. CV_Assert( alpha.total() == 1 && svidx.total() == 1 && sv_total == 1 );
  26. CV_Assert( (alpha.type() == CV_64F && alpha.at<double>(0) == 1.) ||
  27. (alpha.type() == CV_32F && alpha.at<float>(0) == 1.f) );
  28. CV_Assert( sv.type() == CV_32F );
  29. vector< float > hog_detector( sv.cols + 1 );
  30. memcpy( &hog_detector[0], sv.ptr(), sv.cols*sizeof( hog_detector[0] ) );
  31. hog_detector[sv.cols] = (float)-rho;
  32. return hog_detector;
  33. }
  34. /*
  35. * Convert training/testing set to be used by OpenCV Machine Learning algorithms.
  36. * TrainData is a matrix of size (#samples x max(#cols,#rows) per samples), in 32FC1.
  37. * Transposition of samples are made if needed.
  38. */
  39. void convert_to_ml( const vector< Mat > & train_samples, Mat& trainData )
  40. {
  41. //--Convert data
  42. const int rows = (int)train_samples.size();
  43. const int cols = (int)std::max( train_samples[0].cols, train_samples[0].rows );
  44. Mat tmp( 1, cols, CV_32FC1 ); //< used for transposition if needed
  45. trainData = Mat( rows, cols, CV_32FC1 );
  46. for( size_t i = 0 ; i < train_samples.size(); ++i )
  47. {
  48. CV_Assert( train_samples[i].cols == 1 || train_samples[i].rows == 1 );
  49. if( train_samples[i].cols == 1 )
  50. {
  51. transpose( train_samples[i], tmp );
  52. tmp.copyTo( trainData.row( (int)i ) );
  53. }
  54. else if( train_samples[i].rows == 1 )
  55. {
  56. train_samples[i].copyTo( trainData.row( (int)i ) );
  57. }
  58. }
  59. }
  60. void load_images( const String & dirname, vector< Mat > & img_lst, bool showImages = false )
  61. {
  62. vector< String > files;
  63. glob( dirname, files );
  64. for ( size_t i = 0; i < files.size(); ++i )
  65. {
  66. Mat img = imread( files[i] ); // load the image
  67. if ( img.empty() )
  68. {
  69. cout << files[i] << " is invalid!" << endl; // invalid image, skip it.
  70. continue;
  71. }
  72. if ( showImages )
  73. {
  74. imshow( "image", img );
  75. waitKey( 1 );
  76. }
  77. img_lst.push_back( img );
  78. }
  79. }
  80. void sample_neg( const vector< Mat > & full_neg_lst, vector< Mat > & neg_lst, const Size & size )
  81. {
  82. Rect box;
  83. box.width = size.width;
  84. box.height = size.height;
  85. srand( (unsigned int)time( NULL ) );
  86. for ( size_t i = 0; i < full_neg_lst.size(); i++ )
  87. if ( full_neg_lst[i].cols > box.width && full_neg_lst[i].rows > box.height )
  88. {
  89. box.x = rand() % ( full_neg_lst[i].cols - box.width );
  90. box.y = rand() % ( full_neg_lst[i].rows - box.height );
  91. Mat roi = full_neg_lst[i]( box );
  92. neg_lst.push_back( roi.clone() );
  93. }
  94. }
  95. void computeHOGs( const Size wsize, const vector< Mat > & img_lst, vector< Mat > & gradient_lst, bool use_flip )
  96. {
  97. HOGDescriptor hog;
  98. hog.winSize = wsize;
  99. Mat gray;
  100. vector< float > descriptors;
  101. for( size_t i = 0 ; i < img_lst.size(); i++ )
  102. {
  103. if ( img_lst[i].cols >= wsize.width && img_lst[i].rows >= wsize.height )
  104. {
  105. Rect r = Rect(( img_lst[i].cols - wsize.width ) / 2,
  106. ( img_lst[i].rows - wsize.height ) / 2,
  107. wsize.width,
  108. wsize.height);
  109. cvtColor( img_lst[i](r), gray, COLOR_BGR2GRAY );
  110. hog.compute( gray, descriptors, Size( 8, 8 ), Size( 0, 0 ) );
  111. gradient_lst.push_back( Mat( descriptors ).clone() );
  112. if ( use_flip )
  113. {
  114. flip( gray, gray, 1 );
  115. hog.compute( gray, descriptors, Size( 8, 8 ), Size( 0, 0 ) );
  116. gradient_lst.push_back( Mat( descriptors ).clone() );
  117. }
  118. }
  119. }
  120. }
  121. void test_trained_detector( String obj_det_filename, String test_dir, String videofilename )
  122. {
  123. cout << "Testing trained detector..." << endl;
  124. HOGDescriptor hog;
  125. hog.load( obj_det_filename );
  126. vector< String > files;
  127. glob( test_dir, files );
  128. int delay = 0;
  129. VideoCapture cap;
  130. if ( videofilename != "" )
  131. {
  132. if ( videofilename.size() == 1 && isdigit( videofilename[0] ) )
  133. cap.open( videofilename[0] - '0' );
  134. else
  135. cap.open( videofilename );
  136. }
  137. obj_det_filename = "testing " + obj_det_filename;
  138. namedWindow( obj_det_filename, WINDOW_NORMAL );
  139. for( size_t i=0;; i++ )
  140. {
  141. Mat img;
  142. if ( cap.isOpened() )
  143. {
  144. cap >> img;
  145. delay = 1;
  146. }
  147. else if( i < files.size() )
  148. {
  149. img = imread( files[i] );
  150. }
  151. if ( img.empty() )
  152. {
  153. return;
  154. }
  155. vector< Rect > detections;
  156. vector< double > foundWeights;
  157. hog.detectMultiScale( img, detections, foundWeights );
  158. for ( size_t j = 0; j < detections.size(); j++ )
  159. {
  160. Scalar color = Scalar( 0, foundWeights[j] * foundWeights[j] * 200, 0 );
  161. rectangle( img, detections[j], color, img.cols / 400 + 1 );
  162. }
  163. imshow( obj_det_filename, img );
  164. if( waitKey( delay ) == 27 )
  165. {
  166. return;
  167. }
  168. }
  169. }
  170. int main( int argc, char** argv )
  171. {
  172. const char* keys =
  173. {
  174. "{help h| | show help message}"
  175. "{pd | | path of directory contains positive images}"
  176. "{nd | | path of directory contains negative images}"
  177. "{td | | path of directory contains test images}"
  178. "{tv | | test video file name}"
  179. "{dw | | width of the detector}"
  180. "{dh | | height of the detector}"
  181. "{f |false| indicates if the program will generate and use mirrored samples or not}"
  182. "{d |false| train twice}"
  183. "{t |false| test a trained detector}"
  184. "{v |false| visualize training steps}"
  185. "{fn |my_detector.yml| file name of trained SVM}"
  186. };
  187. CommandLineParser parser( argc, argv, keys );
  188. if ( parser.has( "help" ) )
  189. {
  190. parser.printMessage();
  191. exit( 0 );
  192. }
  193. String pos_dir = parser.get< String >( "pd" );
  194. String neg_dir = parser.get< String >( "nd" );
  195. String test_dir = parser.get< String >( "td" );
  196. String obj_det_filename = parser.get< String >( "fn" );
  197. String videofilename = parser.get< String >( "tv" );
  198. int detector_width = parser.get< int >( "dw" );
  199. int detector_height = parser.get< int >( "dh" );
  200. bool test_detector = parser.get< bool >( "t" );
  201. bool train_twice = parser.get< bool >( "d" );
  202. bool visualization = parser.get< bool >( "v" );
  203. bool flip_samples = parser.get< bool >( "f" );
  204. if ( test_detector )
  205. {
  206. test_trained_detector( obj_det_filename, test_dir, videofilename );
  207. exit( 0 );
  208. }
  209. if( pos_dir.empty() || neg_dir.empty() )
  210. {
  211. parser.printMessage();
  212. cout << "Wrong number of parameters.\n\n"
  213. << "Example command line:\n" << argv[0] << " -dw=64 -dh=128 -pd=/INRIAPerson/96X160H96/Train/pos -nd=/INRIAPerson/neg -td=/INRIAPerson/Test/pos -fn=HOGpedestrian64x128.xml -d\n"
  214. << "\nExample command line for testing trained detector:\n" << argv[0] << " -t -fn=HOGpedestrian64x128.xml -td=/INRIAPerson/Test/pos";
  215. exit( 1 );
  216. }
  217. vector< Mat > pos_lst, full_neg_lst, neg_lst, gradient_lst;
  218. vector< int > labels;
  219. clog << "Positive images are being loaded..." ;
  220. load_images( pos_dir, pos_lst, visualization );
  221. if ( pos_lst.size() > 0 )
  222. {
  223. clog << "...[done] " << pos_lst.size() << " files." << endl;
  224. }
  225. else
  226. {
  227. clog << "no image in " << pos_dir <<endl;
  228. return 1;
  229. }
  230. Size pos_image_size = pos_lst[0].size();
  231. if ( detector_width && detector_height )
  232. {
  233. pos_image_size = Size( detector_width, detector_height );
  234. }
  235. else
  236. {
  237. for ( size_t i = 0; i < pos_lst.size(); ++i )
  238. {
  239. if( pos_lst[i].size() != pos_image_size )
  240. {
  241. cout << "All positive images should be same size!" << endl;
  242. exit( 1 );
  243. }
  244. }
  245. pos_image_size = pos_image_size / 8 * 8;
  246. }
  247. clog << "Negative images are being loaded...";
  248. load_images( neg_dir, full_neg_lst, visualization );
  249. clog << "...[done] " << full_neg_lst.size() << " files." << endl;
  250. clog << "Negative images are being processed...";
  251. sample_neg( full_neg_lst, neg_lst, pos_image_size );
  252. clog << "...[done] " << neg_lst.size() << " files." << endl;
  253. clog << "Histogram of Gradients are being calculated for positive images...";
  254. computeHOGs( pos_image_size, pos_lst, gradient_lst, flip_samples );
  255. size_t positive_count = gradient_lst.size();
  256. labels.assign( positive_count, +1 );
  257. clog << "...[done] ( positive images count : " << positive_count << " )" << endl;
  258. clog << "Histogram of Gradients are being calculated for negative images...";
  259. computeHOGs( pos_image_size, neg_lst, gradient_lst, flip_samples );
  260. size_t negative_count = gradient_lst.size() - positive_count;
  261. labels.insert( labels.end(), negative_count, -1 );
  262. CV_Assert( positive_count < labels.size() );
  263. clog << "...[done] ( negative images count : " << negative_count << " )" << endl;
  264. Mat train_data;
  265. convert_to_ml( gradient_lst, train_data );
  266. clog << "Training SVM...";
  267. Ptr< SVM > svm = SVM::create();
  268. /* Default values to train SVM */
  269. svm->setCoef0( 0.0 );
  270. svm->setDegree( 3 );
  271. svm->setTermCriteria( TermCriteria(TermCriteria::MAX_ITER + TermCriteria::EPS, 1000, 1e-3 ) );
  272. svm->setGamma( 0 );
  273. svm->setKernel( SVM::LINEAR );
  274. svm->setNu( 0.5 );
  275. svm->setP( 0.1 ); // for EPSILON_SVR, epsilon in loss function?
  276. svm->setC( 0.01 ); // From paper, soft classifier
  277. svm->setType( SVM::EPS_SVR ); // C_SVC; // EPSILON_SVR; // may be also NU_SVR; // do regression task
  278. svm->train( train_data, ROW_SAMPLE, labels );
  279. clog << "...[done]" << endl;
  280. if ( train_twice )
  281. {
  282. clog << "Testing trained detector on negative images. This might take a few minutes...";
  283. HOGDescriptor my_hog;
  284. my_hog.winSize = pos_image_size;
  285. // Set the trained svm to my_hog
  286. my_hog.setSVMDetector( get_svm_detector( svm ) );
  287. vector< Rect > detections;
  288. vector< double > foundWeights;
  289. for ( size_t i = 0; i < full_neg_lst.size(); i++ )
  290. {
  291. if ( full_neg_lst[i].cols >= pos_image_size.width && full_neg_lst[i].rows >= pos_image_size.height )
  292. my_hog.detectMultiScale( full_neg_lst[i], detections, foundWeights );
  293. else
  294. detections.clear();
  295. for ( size_t j = 0; j < detections.size(); j++ )
  296. {
  297. Mat detection = full_neg_lst[i]( detections[j] ).clone();
  298. resize( detection, detection, pos_image_size, 0, 0, INTER_LINEAR_EXACT);
  299. neg_lst.push_back( detection );
  300. }
  301. if ( visualization )
  302. {
  303. for ( size_t j = 0; j < detections.size(); j++ )
  304. {
  305. rectangle( full_neg_lst[i], detections[j], Scalar( 0, 255, 0 ), 2 );
  306. }
  307. imshow( "testing trained detector on negative images", full_neg_lst[i] );
  308. waitKey( 5 );
  309. }
  310. }
  311. clog << "...[done]" << endl;
  312. gradient_lst.clear();
  313. clog << "Histogram of Gradients are being calculated for positive images...";
  314. computeHOGs( pos_image_size, pos_lst, gradient_lst, flip_samples );
  315. positive_count = gradient_lst.size();
  316. clog << "...[done] ( positive count : " << positive_count << " )" << endl;
  317. clog << "Histogram of Gradients are being calculated for negative images...";
  318. computeHOGs( pos_image_size, neg_lst, gradient_lst, flip_samples );
  319. negative_count = gradient_lst.size() - positive_count;
  320. clog << "...[done] ( negative count : " << negative_count << " )" << endl;
  321. labels.clear();
  322. labels.assign(positive_count, +1);
  323. labels.insert(labels.end(), negative_count, -1);
  324. clog << "Training SVM again...";
  325. convert_to_ml( gradient_lst, train_data );
  326. svm->train( train_data, ROW_SAMPLE, labels );
  327. clog << "...[done]" << endl;
  328. }
  329. HOGDescriptor hog;
  330. hog.winSize = pos_image_size;
  331. hog.setSVMDetector( get_svm_detector( svm ) );
  332. hog.save( obj_det_filename );
  333. test_trained_detector( obj_det_filename, test_dir, videofilename );
  334. return 0;
  335. }