// train_svmsgd.cpp


#include <algorithm>
#include <cmath>
#include <utility>
#include <vector>

#include "opencv2/core.hpp"
#include "opencv2/video/tracking.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/ml.hpp"

using namespace cv;
using namespace cv::ml;
  8. struct Data
  9. {
  10. Mat img;
  11. Mat samples; //Set of train samples. Contains points on image
  12. Mat responses; //Set of responses for train samples
  13. Data()
  14. {
  15. const int WIDTH = 841;
  16. const int HEIGHT = 594;
  17. img = Mat::zeros(HEIGHT, WIDTH, CV_8UC3);
  18. imshow("Train svmsgd", img);
  19. }
  20. };
  21. //Train with SVMSGD algorithm
  22. //(samples, responses) is a train set
  23. //weights is a required vector for decision function of SVMSGD algorithm
  24. bool doTrain(const Mat samples, const Mat responses, Mat &weights, float &shift);
  25. //function finds two points for drawing line (wx = 0)
  26. bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height);
  27. // function finds cross point of line (wx = 0) and segment ( (y = HEIGHT, 0 <= x <= WIDTH) or (x = WIDTH, 0 <= y <= HEIGHT) )
  28. bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint);
  29. //segments' initialization ( (y = HEIGHT, 0 <= x <= WIDTH) and (x = WIDTH, 0 <= y <= HEIGHT) )
  30. void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height);
  31. //redraw points' set and line (wx = 0)
  32. void redraw(Data data, const Point points[2]);
  33. //add point in train set, train SVMSGD algorithm and draw results on image
  34. void addPointRetrainAndRedraw(Data &data, int x, int y, int response);
  35. bool doTrain( const Mat samples, const Mat responses, Mat &weights, float &shift)
  36. {
  37. cv::Ptr<SVMSGD> svmsgd = SVMSGD::create();
  38. cv::Ptr<TrainData> trainData = TrainData::create(samples, cv::ml::ROW_SAMPLE, responses);
  39. svmsgd->train( trainData );
  40. if (svmsgd->isTrained())
  41. {
  42. weights = svmsgd->getWeights();
  43. shift = svmsgd->getShift();
  44. return true;
  45. }
  46. return false;
  47. }
  48. void fillSegments(std::vector<std::pair<Point,Point> > &segments, int width, int height)
  49. {
  50. std::pair<Point,Point> currentSegment;
  51. currentSegment.first = Point(width, 0);
  52. currentSegment.second = Point(width, height);
  53. segments.push_back(currentSegment);
  54. currentSegment.first = Point(0, height);
  55. currentSegment.second = Point(width, height);
  56. segments.push_back(currentSegment);
  57. currentSegment.first = Point(0, 0);
  58. currentSegment.second = Point(width, 0);
  59. segments.push_back(currentSegment);
  60. currentSegment.first = Point(0, 0);
  61. currentSegment.second = Point(0, height);
  62. segments.push_back(currentSegment);
  63. }
  64. bool findCrossPointWithBorders(const Mat &weights, float shift, const std::pair<Point,Point> &segment, Point &crossPoint)
  65. {
  66. int x = 0;
  67. int y = 0;
  68. int xMin = std::min(segment.first.x, segment.second.x);
  69. int xMax = std::max(segment.first.x, segment.second.x);
  70. int yMin = std::min(segment.first.y, segment.second.y);
  71. int yMax = std::max(segment.first.y, segment.second.y);
  72. CV_Assert(weights.type() == CV_32FC1);
  73. CV_Assert(xMin == xMax || yMin == yMax);
  74. if (xMin == xMax && weights.at<float>(1) != 0)
  75. {
  76. x = xMin;
  77. y = static_cast<int>(std::floor( - (weights.at<float>(0) * x + shift) / weights.at<float>(1)));
  78. if (y >= yMin && y <= yMax)
  79. {
  80. crossPoint.x = x;
  81. crossPoint.y = y;
  82. return true;
  83. }
  84. }
  85. else if (yMin == yMax && weights.at<float>(0) != 0)
  86. {
  87. y = yMin;
  88. x = static_cast<int>(std::floor( - (weights.at<float>(1) * y + shift) / weights.at<float>(0)));
  89. if (x >= xMin && x <= xMax)
  90. {
  91. crossPoint.x = x;
  92. crossPoint.y = y;
  93. return true;
  94. }
  95. }
  96. return false;
  97. }
  98. bool findPointsForLine(const Mat &weights, float shift, Point points[2], int width, int height)
  99. {
  100. if (weights.empty())
  101. {
  102. return false;
  103. }
  104. int foundPointsCount = 0;
  105. std::vector<std::pair<Point,Point> > segments;
  106. fillSegments(segments, width, height);
  107. for (uint i = 0; i < segments.size(); i++)
  108. {
  109. if (findCrossPointWithBorders(weights, shift, segments[i], points[foundPointsCount]))
  110. foundPointsCount++;
  111. if (foundPointsCount >= 2)
  112. break;
  113. }
  114. return true;
  115. }
  116. void redraw(Data data, const Point points[2])
  117. {
  118. data.img.setTo(0);
  119. Point center;
  120. int radius = 3;
  121. Scalar color;
  122. CV_Assert((data.samples.type() == CV_32FC1) && (data.responses.type() == CV_32FC1));
  123. for (int i = 0; i < data.samples.rows; i++)
  124. {
  125. center.x = static_cast<int>(data.samples.at<float>(i,0));
  126. center.y = static_cast<int>(data.samples.at<float>(i,1));
  127. color = (data.responses.at<float>(i) > 0) ? Scalar(128,128,0) : Scalar(0,128,128);
  128. circle(data.img, center, radius, color, 5);
  129. }
  130. line(data.img, points[0], points[1],cv::Scalar(1,255,1));
  131. imshow("Train svmsgd", data.img);
  132. }
  133. void addPointRetrainAndRedraw(Data &data, int x, int y, int response)
  134. {
  135. Mat currentSample(1, 2, CV_32FC1);
  136. currentSample.at<float>(0,0) = (float)x;
  137. currentSample.at<float>(0,1) = (float)y;
  138. data.samples.push_back(currentSample);
  139. data.responses.push_back(static_cast<float>(response));
  140. Mat weights(1, 2, CV_32FC1);
  141. float shift = 0;
  142. if (doTrain(data.samples, data.responses, weights, shift))
  143. {
  144. Point points[2];
  145. findPointsForLine(weights, shift, points, data.img.cols, data.img.rows);
  146. redraw(data, points);
  147. }
  148. }
  149. static void onMouse( int event, int x, int y, int, void* pData)
  150. {
  151. Data &data = *(Data*)pData;
  152. switch( event )
  153. {
  154. case EVENT_LBUTTONUP:
  155. addPointRetrainAndRedraw(data, x, y, 1);
  156. break;
  157. case EVENT_RBUTTONDOWN:
  158. addPointRetrainAndRedraw(data, x, y, -1);
  159. break;
  160. }
  161. }
  162. int main()
  163. {
  164. Data data;
  165. setMouseCallback( "Train svmsgd", onMouse, &data );
  166. waitKey();
  167. return 0;
  168. }