test_grabcut.py 2.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071
  1. #!/usr/bin/env python
  2. '''
  3. ===============================================================================
  4. Interactive Image Segmentation using GrabCut algorithm.
  5. ===============================================================================
  6. '''
# Python 2/3 compatibility
from __future__ import print_function

import os
import sys

import numpy as np
import cv2 as cv

from tests_common import NewOpenCVTests
  13. class grabcut_test(NewOpenCVTests):
  14. def verify(self, mask, exp):
  15. maxDiffRatio = 0.02
  16. expArea = np.count_nonzero(exp)
  17. nonIntersectArea = np.count_nonzero(mask != exp)
  18. curRatio = float(nonIntersectArea) / expArea
  19. return curRatio < maxDiffRatio
  20. def scaleMask(self, mask):
  21. return np.where((mask==cv.GC_FGD) + (mask==cv.GC_PR_FGD),255,0).astype('uint8')
  22. def test_grabcut(self):
  23. img = self.get_sample('cv/shared/airplane.png')
  24. mask_prob = self.get_sample("cv/grabcut/mask_probpy.png", 0)
  25. exp_mask1 = self.get_sample("cv/grabcut/exp_mask1py.png", 0)
  26. exp_mask2 = self.get_sample("cv/grabcut/exp_mask2py.png", 0)
  27. if img is None:
  28. self.assertTrue(False, 'Missing test data')
  29. rect = (24, 126, 459, 168)
  30. mask = np.zeros(img.shape[:2], dtype = np.uint8)
  31. bgdModel = np.zeros((1,65),np.float64)
  32. fgdModel = np.zeros((1,65),np.float64)
  33. cv.grabCut(img, mask, rect, bgdModel, fgdModel, 0, cv.GC_INIT_WITH_RECT)
  34. cv.grabCut(img, mask, rect, bgdModel, fgdModel, 2, cv.GC_EVAL)
  35. if mask_prob is None:
  36. mask_prob = mask.copy()
  37. cv.imwrite(self.extraTestDataPath + '/cv/grabcut/mask_probpy.png', mask_prob)
  38. if exp_mask1 is None:
  39. exp_mask1 = self.scaleMask(mask)
  40. cv.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask1py.png', exp_mask1)
  41. self.assertEqual(self.verify(self.scaleMask(mask), exp_mask1), True)
  42. mask = mask_prob
  43. bgdModel = np.zeros((1,65),np.float64)
  44. fgdModel = np.zeros((1,65),np.float64)
  45. cv.grabCut(img, mask, rect, bgdModel, fgdModel, 0, cv.GC_INIT_WITH_MASK)
  46. cv.grabCut(img, mask, rect, bgdModel, fgdModel, 1, cv.GC_EVAL)
  47. if exp_mask2 is None:
  48. exp_mask2 = self.scaleMask(mask)
  49. cv.imwrite(self.extraTestDataPath + '/cv/grabcut/exp_mask2py.png', exp_mask2)
  50. self.assertEqual(self.verify(self.scaleMask(mask), exp_mask2), True)
  51. if __name__ == '__main__':
  52. NewOpenCVTests.bootstrap()