// kmeans.cpp

#include <algorithm>
#include <iostream>
#include <vector>
#include "opencv2/core.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/imgproc.hpp"
using namespace cv;
using namespace std;
// static void help()
// {
// cout << "\nThis program demonstrates kmeans clustering.\n"
// "It generates an image with random points, then assigns a random number of cluster\n"
// "centers and uses kmeans to move those cluster centers to their representative locations\n"
// "Call\n"
// "./kmeans\n" << endl;
// }
  15. int main( int /*argc*/, char** /*argv*/ )
  16. {
  17. const int MAX_CLUSTERS = 5;
  18. Scalar colorTab[] =
  19. {
  20. Scalar(0, 0, 255),
  21. Scalar(0,255,0),
  22. Scalar(255,100,100),
  23. Scalar(255,0,255),
  24. Scalar(0,255,255)
  25. };
  26. Mat img(500, 500, CV_8UC3);
  27. RNG rng(12345);
  28. for(;;)
  29. {
  30. int k, clusterCount = rng.uniform(2, MAX_CLUSTERS+1);
  31. int i, sampleCount = rng.uniform(1, 1001);
  32. Mat points(sampleCount, 1, CV_32FC2), labels;
  33. clusterCount = MIN(clusterCount, sampleCount);
  34. std::vector<Point2f> centers;
  35. /* generate random sample from multigaussian distribution */
  36. for( k = 0; k < clusterCount; k++ )
  37. {
  38. Point center;
  39. center.x = rng.uniform(0, img.cols);
  40. center.y = rng.uniform(0, img.rows);
  41. Mat pointChunk = points.rowRange(k*sampleCount/clusterCount,
  42. k == clusterCount - 1 ? sampleCount :
  43. (k+1)*sampleCount/clusterCount);
  44. rng.fill(pointChunk, RNG::NORMAL, Scalar(center.x, center.y), Scalar(img.cols*0.05, img.rows*0.05));
  45. }
  46. randShuffle(points, 1, &rng);
  47. double compactness = kmeans(points, clusterCount, labels,
  48. TermCriteria( TermCriteria::EPS+TermCriteria::COUNT, 10, 1.0),
  49. 3, KMEANS_PP_CENTERS, centers);
  50. img = Scalar::all(0);
  51. for( i = 0; i < sampleCount; i++ )
  52. {
  53. int clusterIdx = labels.at<int>(i);
  54. Point ipt = points.at<Point2f>(i);
  55. circle( img, ipt, 2, colorTab[clusterIdx], FILLED, LINE_AA );
  56. }
  57. for (i = 0; i < (int)centers.size(); ++i)
  58. {
  59. Point2f c = centers[i];
  60. circle( img, c, 40, colorTab[i], 1, LINE_AA );
  61. }
  62. cout << "Compactness: " << compactness << endl;
  63. imshow("clusters", img);
  64. char key = (char)waitKey();
  65. if( key == 27 || key == 'q' || key == 'Q' ) // 'ESC'
  66. break;
  67. }
  68. return 0;
  69. }