fast_template.py

#!/usr/bin/env python

""" Based on the original C++ code referenced below.

    FastMatchTemplate.cpp

    Copyright 2010 Tristen Georgiou
    tristen_georgiou@hotmail.com

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA.
"""
import rospy
from rbx1_vision.ros2opencv2 import ROS2OpenCV2
from sensor_msgs.msg import Image, RegionOfInterest
from geometry_msgs.msg import Point
import sys
import cv2.cv as cv
import cv2
import numpy as np
from math import pow
  34. """=============================================================================
  35. // Assumes that source image exists and numDownPyrs > 1, no ROIs for either
  36. // image, and both images have the same depth and number of channels
  37. """
class FastMatchTemplate(ROS2OpenCV2):
    def __init__(self, node_name):
        ROS2OpenCV2.__init__(self, node_name)

        self.matchPercentage = rospy.get_param("~matchPercentage", 70)
        self.findMultipleTargets = rospy.get_param("~findMultipleTargets", False)
        self.numMaxima = rospy.get_param("~numMaxima", 1)
        self.numDownPyrs = rospy.get_param("~numDownPyrs", 2)
        self.searchExpansion = rospy.get_param("~searchExpansion", 15)
        self.use_depth_for_detection = rospy.get_param("~use_depth_for_detection", False)
        self.fov_width = rospy.get_param("~fov_width", 1.094)
        self.fov_height = rospy.get_param("~fov_height", 1.094)
        self.max_object_size = rospy.get_param("~max_object_size", 0.28)

        self.foundPointsList = list()
        self.confidencesList = list()

        # Initialize the detection box
        self.detect_box = None

        # Initialize a couple of intermediate image variables
        self.grey = None
        self.small_image = None

        # What kind of detector do we want to load
        self.detector_type = "template"
        self.detector_loaded = False

        rospy.loginfo("Waiting for video topics to become available...")

        # Wait until the image topics are ready before starting
        rospy.wait_for_message("input_rgb_image", Image)

        if self.use_depth_for_detection:
            rospy.wait_for_message("input_depth_image", Image)

        rospy.loginfo("Ready.")

    def process_image(self, cv_image):
        #return_image = cv.CreateMat(self.frame_size[1], self.frame_size[0], cv.CV_8UC3)
        #cv.Copy(cv.fromarray(cv_image), return_image)

        # STEP 1: Load a detector if one is specified
        if self.detector_type and not self.detector_loaded:
            self.detector_loaded = self.load_detector(self.detector_type)

        # STEP 2: Detect the object
        self.detect_box = self.detect_roi(self.detector_type, cv_image)

        return cv_image

    def load_detector(self, detector):
        if detector == "template":
            #try:
            # Read in the template image
            template_file = rospy.get_param("~template_file", "")
            #self.template = cv2.equalizeHist(cv2.cvtColor(cv2.imread(template_file, cv.CV_LOAD_IMAGE_COLOR), cv2.COLOR_BGR2GRAY))
            #self.template = cv2.imread(template_file, cv.CV_LOAD_IMAGE_GRAYSCALE)
            #self.template = cv2.Sobel(self.template_image, cv.CV_32F, 1, 1)
            self.template = cv2.imread(template_file, cv.CV_LOAD_IMAGE_COLOR)
            cv2.imshow("Template", self.template)
            return True
            #except:
                #rospy.loginfo("Exception loading template image!")
                #return False
        else:
            return False

    def detect_roi(self, detector, cv_image):
        # Default to no detection so an unknown detector type never returns
        # an undefined value
        detect_box = None

        if detector == "template":
            detect_box = self.match_template(cv_image)

        return detect_box

    def match_template(self, cv_image):
        frame = np.array(cv_image, dtype=np.uint8)

        W, H = frame.shape[1], frame.shape[0]
        w, h = self.template.shape[1], self.template.shape[0]
        width = W - w + 1
        height = H - h + 1

        # Make sure that the template image is smaller than the source
        if W < w or H < h:
            rospy.loginfo("Template image must be smaller than video frame.")
            return False

        if frame.dtype != self.template.dtype:
            rospy.loginfo("Template and video frame must have same depth and number of channels.")
            return False

        # Create copies of the images to modify
        frame_copy = frame.copy()
        template_copy = self.template.copy()

        # Down pyramid the images
        for k in range(self.numDownPyrs):
            # Start with the source image; pyrDown allocates the output array itself
            W = (W + 1) / 2
            H = (H + 1) / 2

            frame_small = cv2.pyrDown(frame_copy)

            # frame_window = "PyrDown " + str(k)
            # cv.NamedWindow(frame_window, cv.CV_NORMAL)
            # cv.ShowImage(frame_window, cv.fromarray(frame_small))
            # cv.ResizeWindow(frame_window, 640, 480)

            # Prepare for next loop, if any
            frame_copy = frame_small.copy()

            # Next, do the target
            w = (w + 1) / 2
            h = (h + 1) / 2

            template_small = cv2.pyrDown(template_copy)

            # template_window = "Template PyrDown " + str(k)
            # cv.NamedWindow(template_window, cv.CV_NORMAL)
            # cv.ShowImage(template_window, cv.fromarray(template_small))
            # cv.ResizeWindow(template_window, 640, 480)

            # Prepare for next loop, if any
            template_copy = template_small.copy()

        # Perform the match on the shrunken images
        small_frame_width = frame_copy.shape[1]
        small_frame_height = frame_copy.shape[0]

        small_template_width = template_copy.shape[1]
        small_template_height = template_copy.shape[0]

        result_width = small_frame_width - small_template_width + 1
        result_height = small_frame_height - small_template_height + 1

        result_mat = cv.CreateMat(result_height, result_width, cv.CV_32FC1)
        result = np.array(result_mat, dtype=np.float32)

        cv2.matchTemplate(frame_copy, template_copy, cv.CV_TM_CCOEFF_NORMED, result)

        cv2.imshow("Result", result)

        # NOTE: this fixed return short-circuits the refined full-resolution
        # search below, which is therefore never reached as written
        return (0, 0, 100, 100)

#        # Find the best match location
#        (minValue, maxValue, minLoc, maxLoc) = cv2.minMaxLoc(result)
#
#        # Transform point back to original image
#        target_location = Point()
#        target_location.x, target_location.y = maxLoc
#
#        return (target_location.x, target_location.y, w, h)

        # Find the top match locations
        locations = self.MultipleMaxLoc(result, self.numMaxima)

        foundPointsList = list()
        confidencesList = list()

        W, H = frame.shape[1], frame.shape[0]
        w, h = self.template.shape[1], self.template.shape[0]

        # Search the large images at the returned locations
        for currMax in range(self.numMaxima):
            # Transform the point to its corresponding point in the larger image
            #locations[currMax].x *= int(pow(2.0, self.numDownPyrs))
            #locations[currMax].y *= int(pow(2.0, self.numDownPyrs))
            locations[currMax].x += w / 2
            locations[currMax].y += h / 2

            searchPoint = locations[currMax]

            print "Search Point", searchPoint

            # If we are searching for multiple targets and we have found a target or
            # multiple targets, we don't want to search in the same location(s) again
#            if self.findMultipleTargets and len(foundPointsList) != 0:
#                thisTargetFound = False
#                numPoints = len(foundPointsList)
#
#                for currPoint in range(numPoints):
#                    foundPoint = foundPointsList[currPoint]
#                    if (abs(searchPoint.x - foundPoint.x) <= self.searchExpansion * 2) and (abs(searchPoint.y - foundPoint.y) <= self.searchExpansion * 2):
#                        thisTargetFound = True
#                        break
#
#                # If the current target has been found, continue onto the next point
#                if thisTargetFound:
#                    continue

            # Set the source image's ROI to slightly larger than the target image,
            # centred at the current point
            searchRoi = RegionOfInterest()
            searchRoi.x_offset = searchPoint.x - w / 2 - self.searchExpansion
            searchRoi.y_offset = searchPoint.y - h / 2 - self.searchExpansion
            searchRoi.width = w + self.searchExpansion * 2
            searchRoi.height = h + self.searchExpansion * 2

            #print (searchRoi.x_offset, searchRoi.y_offset, searchRoi.width, searchRoi.height)

            # Make sure the ROI doesn't extend outside of the image
            if searchRoi.x_offset < 0:
                searchRoi.x_offset = 0

            if searchRoi.y_offset < 0:
                searchRoi.y_offset = 0

            if (searchRoi.x_offset + searchRoi.width) > (W - 1):
                numPixelsOver = (searchRoi.x_offset + searchRoi.width) - (W - 1)
                print "NUM PIXELS OVER", numPixelsOver
                searchRoi.width -= numPixelsOver

            if (searchRoi.y_offset + searchRoi.height) > (H - 1):
                numPixelsOver = (searchRoi.y_offset + searchRoi.height) - (H - 1)
                searchRoi.height -= numPixelsOver

            mask = (searchRoi.x_offset, searchRoi.y_offset, searchRoi.width, searchRoi.height)

            frame_mat = cv.fromarray(frame)

            searchImage = cv.CreateMat(searchRoi.height, searchRoi.width, cv.CV_8UC3)
            searchImage = cv.GetSubRect(frame_mat, mask)
            searchArray = np.array(searchImage, dtype=np.uint8)

            # Perform the search on the large images
            result_width = searchRoi.width - w + 1
            result_height = searchRoi.height - h + 1

            result_mat = cv.CreateMat(result_height, result_width, cv.CV_32FC1)
            result = np.array(result_mat, dtype=np.float32)

            cv2.matchTemplate(searchArray, self.template, cv.CV_TM_CCOEFF_NORMED, result)

            # Find the best match location
            (minValue, maxValue, minLoc, maxLoc) = cv2.minMaxLoc(result)
            maxValue *= 100

            # Transform the point back to the original image
            target_location = Point()
            target_location.x, target_location.y = maxLoc
            target_location.x += searchRoi.x_offset - w / 2 + self.searchExpansion
            target_location.y += searchRoi.y_offset - h / 2 + self.searchExpansion

            if maxValue >= self.matchPercentage:
                # Add the point to the list
                foundPointsList.append(maxLoc)
                confidencesList.append(maxValue)

                # If we are only looking for a single target, we have found it, so we
                # can return
                if not self.findMultipleTargets:
                    break

        if len(foundPointsList) == 0:
            rospy.loginfo("Target was not found to required confidence")

        return (target_location.x, target_location.y, w, h)

    def MultipleMaxLoc(self, result, numMaxima):
        # Initialize the list of output locations (use distinct Point objects
        # so that updating one entry cannot affect the others)
        #locations = np.empty((numMaxima, 2), dtype=np.uint8)
        locations = [Point() for i in range(numMaxima)]

        # Create an array for tracking the maxima
        maxima = [0.0] * numMaxima

        result_width = result.shape[1]
        result_height = result.shape[0]

        # Extract the raw data for analysis
        for y in range(result_height):
            for x in range(result_width):
                data = result[y, x]

                # Insert the data value into the array if it is greater than any of the
                # other array values, and bump the other values below it, down
                for j in range(numMaxima):
                    # Require at least 50% confidence on the sub-sampled image
                    # in order to make this as fast as possible
                    if data > 0.5 and data > maxima[j]:
                        # Move the maxima down
                        k = numMaxima - 1
                        while k > j:
                            maxima[k] = maxima[k-1]
                            locations[k] = locations[k-1]
                            k = k - 1

                        # Insert the value as a new Point so the shifted entries
                        # (which may share objects) are not modified in place
                        maxima[j] = data
                        locations[j] = Point(x, y, 0)
                        break

        return locations
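
    # Note (added commentary): for numMaxima == 1 the scan above is equivalent
    # to taking cv2.minMaxLoc(result)[3], keeping the location only when the
    # peak value clears the 0.5 confidence floor used on the sub-sampled image.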

#//=============================================================================
#
#void
#DrawFoundTargets(Mat* image,
#                 const Size& size,
#                 const vector<Point>& pointsList,
#                 const vector<double>& confidencesList,
#                 int red,
#                 int green,
#                 int blue)
#{
#    int numPoints = pointsList.size();
#    for(int currPoint = 0; currPoint < numPoints; currPoint++)
#    {
#        const Point& point = pointsList[currPoint];
#
#        // write the confidences to stdout
#        rospy.loginfo("\nTarget found at (%d, %d), with confidence = %3.3f %%.\n",
#                      point.x,
#                      point.y,
#                      confidencesList[currPoint]);
#
#        // draw a circle at the center
#        circle(*image, point, 2, CV_RGB(red, green, blue));
#
#        // draw a rectangle around the found target
#        Point topLeft;
#        topLeft.x = point.x - size.width / 2;
#        topLeft.y = point.y - size.height / 2;
#
#        Point bottomRight;
#        bottomRight.x = point.x + size.width / 2;
#        bottomRight.y = point.y + size.height / 2;
#
#        rectangle(*image, topLeft, bottomRight, CV_RGB(red, green, blue));
#    }
#}
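
# A rough Python counterpart of the C++ helper above, kept here only as an
# illustrative sketch: it is not called anywhere in this node, and the
# function name, the (width, height) tuple used for size, and the default
# colour are assumptions rather than part of the original code.
def draw_found_targets(image, size, pointsList, confidencesList, red=0, green=255, blue=0):
    """ Draw a marker and bounding box for each found target (illustrative only). """
    for point, confidence in zip(pointsList, confidencesList):
        # Write the confidence to the ROS log
        rospy.loginfo("Target found at (%d, %d), with confidence = %3.3f %%.",
                      point.x, point.y, confidence)

        # Draw a circle at the center of the match
        cv2.circle(image, (int(point.x), int(point.y)), 2, (blue, green, red))

        # Draw a rectangle around the found target (size is a (width, height) tuple)
        top_left = (int(point.x - size[0] / 2), int(point.y - size[1] / 2))
        bottom_right = (int(point.x + size[0] / 2), int(point.y + size[1] / 2))
        cv2.rectangle(image, top_left, bottom_right, (blue, green, red))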

def main(args):
    FMT = FastMatchTemplate("fast_match_template")
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print "Shutting down fast match template node."
        cv.DestroyAllWindows()

if __name__ == '__main__':
    main(sys.argv)
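
# Example invocation (illustrative; the camera topic below is an assumption,
# but "input_rgb_image" and the private parameters are the names this node reads):
#
#   rosrun rbx1_vision fast_template.py \
#       input_rgb_image:=/camera/rgb/image_color \
#       _template_file:=/path/to/template.png \
#       _matchPercentage:=70 _numDownPyrs:=2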