test_kmeans.py 1.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. #!/usr/bin/env python
  2. '''
  3. K-means clustering test
  4. '''
  5. # Python 2/3 compatibility
  6. from __future__ import print_function
  7. import numpy as np
  8. import cv2 as cv
  9. from numpy import random
  10. import sys
  11. PY3 = sys.version_info[0] == 3
  12. if PY3:
  13. xrange = range
  14. from tests_common import NewOpenCVTests
  15. def make_gaussians(cluster_n, img_size):
  16. points = []
  17. ref_distrs = []
  18. sizes = []
  19. for _ in xrange(cluster_n):
  20. mean = (0.1 + 0.8*random.rand(2)) * img_size
  21. a = (random.rand(2, 2)-0.5)*img_size*0.1
  22. cov = np.dot(a.T, a) + img_size*0.05*np.eye(2)
  23. n = 100 + random.randint(900)
  24. pts = random.multivariate_normal(mean, cov, n)
  25. points.append( pts )
  26. ref_distrs.append( (mean, cov) )
  27. sizes.append(n)
  28. points = np.float32( np.vstack(points) )
  29. return points, ref_distrs, sizes
  30. def getMainLabelConfidence(labels, nLabels):
  31. n = len(labels)
  32. labelsDict = dict.fromkeys(range(nLabels), 0)
  33. labelsConfDict = dict.fromkeys(range(nLabels))
  34. for i in range(n):
  35. labelsDict[labels[i][0]] += 1
  36. for i in range(nLabels):
  37. labelsConfDict[i] = float(labelsDict[i]) / n
  38. return max(labelsConfDict.values())
  39. class kmeans_test(NewOpenCVTests):
  40. def test_kmeans(self):
  41. np.random.seed(10)
  42. cluster_n = 5
  43. img_size = 512
  44. points, _, clusterSizes = make_gaussians(cluster_n, img_size)
  45. term_crit = (cv.TERM_CRITERIA_EPS, 30, 0.1)
  46. _ret, labels, centers = cv.kmeans(points, cluster_n, None, term_crit, 10, 0)
  47. self.assertEqual(len(centers), cluster_n)
  48. offset = 0
  49. for i in range(cluster_n):
  50. confidence = getMainLabelConfidence(labels[offset : (offset + clusterSizes[i])], cluster_n)
  51. offset += clusterSizes[i]
  52. self.assertGreater(confidence, 0.9)
# When run directly, hand control to the shared OpenCV test harness.
if __name__ == '__main__':
    NewOpenCVTests.bootstrap()