
import argparse
import cv2 as cv
import numpy as np
import os

"""
Link to original paper : https://arxiv.org/abs/1812.11703
Link to original repo : https://github.com/STVIR/pysot

You can download the pre-trained weights of the Tracker Model from https://drive.google.com/file/d/11bwgPFVkps9AH2NOD1zBDdpF_tQghAB-/view?usp=sharing
You can download the target net (target branch of SiamRPN++) from https://drive.google.com/file/d/1dw_Ne3UMcCnFsaD6xkZepwE4GEpqq7U_/view?usp=sharing
You can download the search net (search branch of SiamRPN++) from https://drive.google.com/file/d/1Lt4oE43ZSucJvze3Y-Z87CVDreO-Afwl/view?usp=sharing
You can download the head model (RPN Head) from https://drive.google.com/file/d/1zT1yu12mtj3JQEkkfKFJWiZ71fJ-dQTi/view?usp=sharing
"""

class ModelBuilder():
    """ This class builds the SiamRPN++ tracker model from the imported ONNX nets
    """
    def __init__(self, target_net, search_net, rpn_head):
        super(ModelBuilder, self).__init__()
        # Build the target branch
        self.target_net = target_net
        # Build the search branch
        self.search_net = search_net
        # Build the RPN head
        self.rpn_head = rpn_head

    def template(self, z):
        """ Takes the template crop of size (1, 3, 127, 127) as input and generates the kernels
        """
        self.target_net.setInput(z)
        outNames = self.target_net.getUnconnectedOutLayersNames()
        self.zfs_1, self.zfs_2, self.zfs_3 = self.target_net.forward(outNames)

    def track(self, x):
        """ Takes the search crop of size (1, 3, 255, 255) as input and generates the classification score and bounding box regression
        """
        self.search_net.setInput(x)
        outNames = self.search_net.getUnconnectedOutLayersNames()
        xfs_1, xfs_2, xfs_3 = self.search_net.forward(outNames)
        self.rpn_head.setInput(np.stack([self.zfs_1, self.zfs_2, self.zfs_3]), 'input_1')
        self.rpn_head.setInput(np.stack([xfs_1, xfs_2, xfs_3]), 'input_2')
        outNames = self.rpn_head.getUnconnectedOutLayersNames()
        cls, loc = self.rpn_head.forward(outNames)
        return {'cls': cls, 'loc': loc}
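
# Rough usage sketch (this mirrors what SiamRPNTracker does below): call
# template() once with the (1, 3, 127, 127) crop of the object in the first
# frame, then call track() with a (1, 3, 255, 255) crop of each following frame
# to obtain the raw 'cls' / 'loc' outputs of the RPN head.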

class Anchors:
    """ This class generates anchors.
    """
    def __init__(self, stride, ratios, scales, image_center=0, size=0):
        self.stride = stride
        self.ratios = ratios
        self.scales = scales
        self.image_center = image_center
        self.size = size
        self.anchor_num = len(self.scales) * len(self.ratios)
        self.anchors = self.generate_anchors()

    def generate_anchors(self):
        """
        Generate anchors based on the predefined configuration
        """
        anchors = np.zeros((self.anchor_num, 4), dtype=np.float32)
        size = self.stride**2
        count = 0
        for r in self.ratios:
            ws = int(np.sqrt(size * 1. / r))
            hs = int(ws * r)
            for s in self.scales:
                w = ws * s
                h = hs * s
                anchors[count][:] = [-w * 0.5, -h * 0.5, w * 0.5, h * 0.5][:]
                count += 1
        return anchors
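
# With the configuration used below (stride=8, ratios=[0.33, 0.5, 1, 2, 3],
# scales=[8]), generate_anchors() yields 5 base anchors, one per aspect ratio,
# stored as (x1, y1, x2, y2) offsets centred on the origin.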

class SiamRPNTracker:
    def __init__(self, model):
        super(SiamRPNTracker, self).__init__()
        self.anchor_stride = 8
        self.anchor_ratios = [0.33, 0.5, 1, 2, 3]
        self.anchor_scales = [8]
        self.track_base_size = 8
        self.track_context_amount = 0.5
        self.track_exemplar_size = 127
        self.track_instance_size = 255
        self.track_lr = 0.4
        self.track_penalty_k = 0.04
        self.track_window_influence = 0.44
        self.score_size = (self.track_instance_size - self.track_exemplar_size) // \
                          self.anchor_stride + 1 + self.track_base_size
        self.anchor_num = len(self.anchor_ratios) * len(self.anchor_scales)
        hanning = np.hanning(self.score_size)
        window = np.outer(hanning, hanning)
        self.window = np.tile(window.flatten(), self.anchor_num)
        self.anchors = self.generate_anchor(self.score_size)
        self.model = model
    def get_subwindow(self, im, pos, model_sz, original_sz, avg_chans):
        """
        Args:
            im: BGR input image frame
            pos: position of the center of the crop
            model_sz: exemplar / target image size
            original_sz: original / search image size
            avg_chans: channel average used for padding
        Return:
            im_patch: sub-window extracted from the given image
        """
        if isinstance(pos, float):
            pos = [pos, pos]
        sz = original_sz
        im_h, im_w, im_d = im.shape
        c = (original_sz + 1) / 2
        cx, cy = pos
        context_xmin = np.floor(cx - c + 0.5)
        context_xmax = context_xmin + sz - 1
        context_ymin = np.floor(cy - c + 0.5)
        context_ymax = context_ymin + sz - 1
        left_pad = int(max(0., -context_xmin))
        top_pad = int(max(0., -context_ymin))
        right_pad = int(max(0., context_xmax - im_w + 1))
        bottom_pad = int(max(0., context_ymax - im_h + 1))
        context_xmin += left_pad
        context_xmax += left_pad
        context_ymin += top_pad
        context_ymax += top_pad
        if any([top_pad, bottom_pad, left_pad, right_pad]):
            size = (im_h + top_pad + bottom_pad, im_w + left_pad + right_pad, im_d)
            te_im = np.zeros(size, np.uint8)
            te_im[top_pad:top_pad + im_h, left_pad:left_pad + im_w, :] = im
            if top_pad:
                te_im[0:top_pad, left_pad:left_pad + im_w, :] = avg_chans
            if bottom_pad:
                te_im[im_h + top_pad:, left_pad:left_pad + im_w, :] = avg_chans
            if left_pad:
                te_im[:, 0:left_pad, :] = avg_chans
            if right_pad:
                te_im[:, im_w + left_pad:, :] = avg_chans
            im_patch = te_im[int(context_ymin):int(context_ymax + 1),
                             int(context_xmin):int(context_xmax + 1), :]
        else:
            im_patch = im[int(context_ymin):int(context_ymax + 1),
                          int(context_xmin):int(context_xmax + 1), :]
        if not np.array_equal(model_sz, original_sz):
            im_patch = cv.resize(im_patch, (model_sz, model_sz))
        im_patch = im_patch.transpose(2, 0, 1)
        im_patch = im_patch[np.newaxis, :, :, :]
        im_patch = im_patch.astype(np.float32)
        return im_patch
    def generate_anchor(self, score_size):
        """
        Args:
            score_size: size of the score map
        Return:
            anchor: anchors for the predetermined values of stride, ratio, and scale
        """
        anchors = Anchors(self.anchor_stride, self.anchor_ratios, self.anchor_scales)
        anchor = anchors.anchors
        x1, y1, x2, y2 = anchor[:, 0], anchor[:, 1], anchor[:, 2], anchor[:, 3]
        anchor = np.stack([(x1 + x2) * 0.5, (y1 + y2) * 0.5, x2 - x1, y2 - y1], 1)
        total_stride = anchors.stride
        anchor_num = anchors.anchor_num
        anchor = np.tile(anchor, score_size * score_size).reshape((-1, 4))
        ori = - (score_size // 2) * total_stride
        xx, yy = np.meshgrid([ori + total_stride * dx for dx in range(score_size)],
                             [ori + total_stride * dy for dy in range(score_size)])
        xx, yy = np.tile(xx.flatten(), (anchor_num, 1)).flatten(), \
                 np.tile(yy.flatten(), (anchor_num, 1)).flatten()
        anchor[:, 0], anchor[:, 1] = xx.astype(np.float32), yy.astype(np.float32)
        return anchor
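
    # Note (derived from the defaults above): score_size evaluates to
    # (255 - 127) // 8 + 1 + 8 = 25, so generate_anchor() returns a
    # (5 * 25 * 25, 4) = (3125, 4) array of (cx, cy, w, h) anchors laid out
    # on an 8-pixel grid centred on the search region.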
    def _convert_bbox(self, delta, anchor):
        """
        Args:
            delta: localisation output of the RPN head
            anchor: anchors of the predetermined anchor size
        Return:
            delta: predicted bounding boxes
        """
        delta_transpose = np.transpose(delta, (1, 2, 3, 0))
        delta_contig = np.ascontiguousarray(delta_transpose)
        delta = delta_contig.reshape(4, -1)
        delta[0, :] = delta[0, :] * anchor[:, 2] + anchor[:, 0]
        delta[1, :] = delta[1, :] * anchor[:, 3] + anchor[:, 1]
        delta[2, :] = np.exp(delta[2, :]) * anchor[:, 2]
        delta[3, :] = np.exp(delta[3, :]) * anchor[:, 3]
        return delta
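
    # The decoding above follows the usual anchor-box parameterisation:
    #   cx = dx * anchor_w + anchor_cx,  cy = dy * anchor_h + anchor_cy
    #   w  = anchor_w * exp(dw),         h  = anchor_h * exp(dh)
    # where (dx, dy, dw, dh) are the four rows of the reshaped 'loc' output.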
    def _softmax(self, x):
        """
        Softmax along axis 1 (the class dimension)
        """
        x = x.astype(dtype=np.float32)
        x_max = x.max(axis=1)[:, np.newaxis]
        e_x = np.exp(x - x_max)
        div = np.sum(e_x, axis=1)[:, np.newaxis]
        y = e_x / div
        return y
    def _convert_score(self, score):
        """
        Args:
            score: classification output of the RPN head
        Return:
            score: probability of the positive class for each anchor
        """
        score_transpose = np.transpose(score, (1, 2, 3, 0))
        score_con = np.ascontiguousarray(score_transpose)
        score_view = score_con.reshape(2, -1)
        score = np.transpose(score_view, (1, 0))
        score = self._softmax(score)
        return score[:, 1]
    def _bbox_clip(self, cx, cy, width, height, boundary):
        """
        Clip the bounding box to the image boundary and enforce a minimum size
        """
        bbox_h, bbox_w = boundary
        cx = max(0, min(cx, bbox_w))
        cy = max(0, min(cy, bbox_h))
        width = max(10, min(width, bbox_w))
        height = max(10, min(height, bbox_h))
        return cx, cy, width, height
    def init(self, img, bbox):
        """
        Args:
            img(np.ndarray): BGR input image frame
            bbox: (x, y, w, h): initial bounding box
        """
        x, y, w, h = bbox
        self.center_pos = np.array([x + (w - 1) / 2, y + (h - 1) / 2])
        self.h = h
        self.w = w
        w_z = self.w + self.track_context_amount * np.add(h, w)
        h_z = self.h + self.track_context_amount * np.add(h, w)
        s_z = round(np.sqrt(w_z * h_z))
        self.channel_average = np.mean(img, axis=(0, 1))
        z_crop = self.get_subwindow(img, self.center_pos, self.track_exemplar_size, s_z, self.channel_average)
        self.model.template(z_crop)
    def track(self, img):
        """
        Args:
            img(np.ndarray): BGR image
        Return:
            bbox(list): [x, y, width, height]
        """
        w_z = self.w + self.track_context_amount * np.add(self.w, self.h)
        h_z = self.h + self.track_context_amount * np.add(self.w, self.h)
        s_z = np.sqrt(w_z * h_z)
        scale_z = self.track_exemplar_size / s_z
        s_x = s_z * (self.track_instance_size / self.track_exemplar_size)
        x_crop = self.get_subwindow(img, self.center_pos, self.track_instance_size, round(s_x), self.channel_average)
        outputs = self.model.track(x_crop)
        score = self._convert_score(outputs['cls'])
        pred_bbox = self._convert_bbox(outputs['loc'], self.anchors)

        def change(r):
            return np.maximum(r, 1. / r)

        def sz(w, h):
            pad = (w + h) * 0.5
            return np.sqrt((w + pad) * (h + pad))

        # scale penalty
        s_c = change(sz(pred_bbox[2, :], pred_bbox[3, :]) /
                     (sz(self.w * scale_z, self.h * scale_z)))
        # aspect ratio penalty
        r_c = change((self.w / self.h) /
                     (pred_bbox[2, :] / pred_bbox[3, :]))
        penalty = np.exp(-(r_c * s_c - 1) * self.track_penalty_k)
        pscore = penalty * score
        # window penalty
        pscore = pscore * (1 - self.track_window_influence) + \
                 self.window * self.track_window_influence
        best_idx = np.argmax(pscore)
        bbox = pred_bbox[:, best_idx] / scale_z
        lr = penalty[best_idx] * score[best_idx] * self.track_lr
        cpx, cpy = self.center_pos
        x, y, w, h = bbox
        cx = x + cpx
        cy = y + cpy
        # smooth bbox
        width = self.w * (1 - lr) + w * lr
        height = self.h * (1 - lr) + h * lr
        # clip boundary
        cx, cy, width, height = self._bbox_clip(cx, cy, width, height, img.shape[:2])
        # update state
        self.center_pos = np.array([cx, cy])
        self.w = width
        self.h = height
        bbox = [cx - width / 2, cy - height / 2, width, height]
        best_score = score[best_idx]
        return {'bbox': bbox, 'best_score': best_score}

def get_frames(video_name):
    """
    Args:
        video_name: path to the input video file (falls back to the default camera if empty)
    Yield:
        frame: next frame of the video / camera stream
    """
    cap = cv.VideoCapture(video_name if video_name else 0)
    while True:
        ret, frame = cap.read()
        if ret:
            yield frame
        else:
            break

def main():
    """ Sample SiamRPN++ Tracker
    """
    # Computation backends supported by layers
    backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_HALIDE, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV,
                cv.dnn.DNN_BACKEND_VKCOM, cv.dnn.DNN_BACKEND_CUDA)
    # Target devices for computation
    targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD,
               cv.dnn.DNN_TARGET_VULKAN, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16)

    parser = argparse.ArgumentParser(description='Use this script to run the SiamRPN++ visual tracker',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_video', type=str, help='Path to input video file. Skip this argument to capture frames from a camera.')
    parser.add_argument('--target_net', type=str, default='target_net.onnx', help='Path to the part of SiamRPN++ run on the target frame.')
    parser.add_argument('--search_net', type=str, default='search_net.onnx', help='Path to the part of SiamRPN++ run on the search frame.')
    parser.add_argument('--rpn_head', type=str, default='rpn_head.onnx', help='Path to the RPN Head ONNX model.')
    parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
                        help="Select a computation backend: "
                             "%d: automatically (by default), "
                             "%d: Halide, "
                             "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
                             "%d: OpenCV implementation, "
                             "%d: VKCOM, "
                             "%d: CUDA" % backends)
    parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
                        help='Select a target device: '
                             '%d: CPU target (by default), '
                             '%d: OpenCL, '
                             '%d: OpenCL FP16, '
                             '%d: Myriad, '
                             '%d: Vulkan, '
                             '%d: CUDA, '
                             '%d: CUDA FP16 (half-float preprocess)' % targets)
    args, _ = parser.parse_known_args()

    if args.input_video and not os.path.isfile(args.input_video):
        raise OSError("Input video file does not exist")
    if not os.path.isfile(args.target_net):
        raise OSError("Target Net does not exist")
    if not os.path.isfile(args.search_net):
        raise OSError("Search Net does not exist")
    if not os.path.isfile(args.rpn_head):
        raise OSError("RPN Head Net does not exist")

    # Load the networks
    target_net = cv.dnn.readNetFromONNX(args.target_net)
    target_net.setPreferableBackend(args.backend)
    target_net.setPreferableTarget(args.target)
    search_net = cv.dnn.readNetFromONNX(args.search_net)
    search_net.setPreferableBackend(args.backend)
    search_net.setPreferableTarget(args.target)
    rpn_head = cv.dnn.readNetFromONNX(args.rpn_head)
    rpn_head.setPreferableBackend(args.backend)
    rpn_head.setPreferableTarget(args.target)

    model = ModelBuilder(target_net, search_net, rpn_head)
    tracker = SiamRPNTracker(model)

    first_frame = True
    cv.namedWindow('SiamRPN++ Tracker', cv.WINDOW_AUTOSIZE)
    for frame in get_frames(args.input_video):
        if first_frame:
            try:
                init_rect = cv.selectROI('SiamRPN++ Tracker', frame, False, False)
            except:
                exit()
            tracker.init(frame, init_rect)
            first_frame = False
        else:
            outputs = tracker.track(frame)
            bbox = list(map(int, outputs['bbox']))
            x, y, w, h = bbox
            cv.rectangle(frame, (x, y), (x + w, y + h), (0, 255, 0), 3)
        cv.imshow('SiamRPN++ Tracker', frame)
        key = cv.waitKey(1)
        if key == ord("q"):
            break

if __name__ == '__main__':
    main()