#!/usr/bin/env python3
'''
You can download the Geometric Matching Module model from https://www.dropbox.com/s/tyhc73xa051grjp/cp_vton_gmm.onnx?dl=0
You can download the Try-On Module model from https://www.dropbox.com/s/q2x97ve2h53j66k/cp_vton_tom.onnx?dl=0
You can download the cloth segmentation model from https://www.dropbox.com/s/qag9vzambhhkvxr/lip_jppnet_384.pb?dl=0
You can find the OpenPose proto in opencv_extra/testdata/dnn/openpose_pose_coco.prototxt
and get .caffemodel using opencv_extra/testdata/dnn/download_models.py
'''
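# Example invocation (a sketch; assumes the downloaded models sit next to this
# script, and person.jpg / cloth.jpg are hypothetical input paths):
#   python virtual_try_on.py -i person.jpg -c cloth.jpg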
import argparse
import os.path
import numpy as np
import cv2 as cv
from numpy import linalg

from common import findFile
from human_parsing import parse_human

backends = (cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_BACKEND_HALIDE, cv.dnn.DNN_BACKEND_INFERENCE_ENGINE, cv.dnn.DNN_BACKEND_OPENCV,
            cv.dnn.DNN_BACKEND_VKCOM, cv.dnn.DNN_BACKEND_CUDA)
targets = (cv.dnn.DNN_TARGET_CPU, cv.dnn.DNN_TARGET_OPENCL, cv.dnn.DNN_TARGET_OPENCL_FP16, cv.dnn.DNN_TARGET_MYRIAD, cv.dnn.DNN_TARGET_HDDL,
           cv.dnn.DNN_TARGET_VULKAN, cv.dnn.DNN_TARGET_CUDA, cv.dnn.DNN_TARGET_CUDA_FP16)
parser = argparse.ArgumentParser(description='Use this script to run virtual try-on using CP-VTON',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--input_image', '-i', required=True, help='Path to image with person.')
parser.add_argument('--input_cloth', '-c', required=True, help='Path to target cloth image.')
parser.add_argument('--gmm_model', '-gmm', default='cp_vton_gmm.onnx', help='Path to Geometric Matching Module .onnx model.')
parser.add_argument('--tom_model', '-tom', default='cp_vton_tom.onnx', help='Path to Try-On Module .onnx model.')
parser.add_argument('--segmentation_model', default='lip_jppnet_384.pb', help='Path to cloth segmentation .pb model.')
parser.add_argument('--openpose_proto', default='openpose_pose_coco.prototxt', help='Path to OpenPose .prototxt; the model was trained on the COCO dataset.')
parser.add_argument('--openpose_model', default='openpose_pose_coco.caffemodel', help='Path to OpenPose .caffemodel; the model was trained on the COCO dataset.')
parser.add_argument('--backend', choices=backends, default=cv.dnn.DNN_BACKEND_DEFAULT, type=int,
                    help="Choose one of computation backends: "
                         "%d: automatically (by default), "
                         "%d: Halide language (http://halide-lang.org/), "
                         "%d: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
                         "%d: OpenCV implementation, "
                         "%d: VKCOM, "
                         "%d: CUDA" % backends)
parser.add_argument('--target', choices=targets, default=cv.dnn.DNN_TARGET_CPU, type=int,
                    help='Choose one of target computation devices: '
                         '%d: CPU target (by default), '
                         '%d: OpenCL, '
                         '%d: OpenCL fp16 (half-float precision), '
                         '%d: NCS2 VPU, '
                         '%d: HDDL VPU, '
                         '%d: Vulkan, '
                         '%d: CUDA, '
                         '%d: CUDA fp16 (half-float precision)' % targets)
args, _ = parser.parse_known_args()
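# The numeric choices above are OpenCV enum values; rather than hard-coding
# integers, one can query the constants (a sketch):
#   python -c "import cv2 as cv; print(cv.dnn.DNN_BACKEND_CUDA, cv.dnn.DNN_TARGET_CUDA)"
# and pass the printed values via --backend / --target.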
def get_pose_map(image, proto_path, model_path, backend, target, height=256, width=192):
    radius = 5
    inp = cv.dnn.blobFromImage(image, 1.0 / 255, (width, height))
    net = cv.dnn.readNet(proto_path, model_path)
    net.setPreferableBackend(backend)
    net.setPreferableTarget(target)
    net.setInput(inp)
    out = net.forward()

    threshold = 0.1
    _, out_c, out_h, out_w = out.shape
    pose_map = np.zeros((height, width, out_c - 1))
    # last label: Background
    for i in range(0, out.shape[1] - 1):
        heatMap = out[0, i, :, :]
        keypoint = np.full((height, width), -1)
        _, conf, _, point = cv.minMaxLoc(heatMap)
        x = width * point[0] // out_w
        y = height * point[1] // out_h
        if conf > threshold and x > 0 and y > 0:
            keypoint[y - radius:y + radius, x - radius:x + radius] = 1
        pose_map[:, :, i] = keypoint
    pose_map = pose_map.transpose(2, 0, 1)
    return pose_map
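# A minimal usage sketch (person_img is any BGR image; assumes the OpenPose
# COCO files listed in the header are available). The COCO model emits 19
# heatmaps, so dropping the background channel yields an (18, 256, 192) map
# of {-1, 1} keypoint masks:
#   pose = get_pose_map(person_img, 'openpose_pose_coco.prototxt',
#                       'openpose_pose_coco.caffemodel',
#                       cv.dnn.DNN_BACKEND_DEFAULT, cv.dnn.DNN_TARGET_CPU)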
class BilinearFilter(object):
    """
    PIL bilinear resize implementation
    image = image.resize((image_width // 16, image_height // 16), Image.BILINEAR)
    """
    def _precompute_coeffs(self, inSize, outSize):
        filterscale = max(1.0, inSize / outSize)
        ksize = int(np.ceil(filterscale)) * 2 + 1

        kk = np.zeros(shape=(outSize * ksize, ), dtype=np.float32)
        bounds = np.empty(shape=(outSize * 2, ), dtype=np.int32)

        centers = (np.arange(outSize) + 0.5) * filterscale + 0.5
        bounds[::2] = np.where(centers - filterscale < 0, 0, centers - filterscale)
        bounds[1::2] = np.where(centers + filterscale > inSize, inSize, centers + filterscale) - bounds[::2]
        xmins = bounds[::2] - centers + 1

        points = np.array([np.arange(row) + xmins[i] for i, row in enumerate(bounds[1::2])]) / filterscale
        for xx in range(0, outSize):
            point = points[xx]
            bilinear = np.where(point < 1.0, 1.0 - abs(point), 0.0)
            ww = np.sum(bilinear)
            kk[xx * ksize : xx * ksize + bilinear.size] = np.where(ww == 0.0, bilinear, bilinear / ww)
        return bounds, kk, ksize

    def _resample_horizontal(self, out, img, ksize, bounds, kk):
        for yy in range(0, out.shape[0]):
            for xx in range(0, out.shape[1]):
                xmin = bounds[xx * 2 + 0]
                xmax = bounds[xx * 2 + 1]
                k = kk[xx * ksize : xx * ksize + xmax]
                out[yy, xx] = np.round(np.sum(img[yy, xmin : xmin + xmax] * k))

    def _resample_vertical(self, out, img, ksize, bounds, kk):
        for yy in range(0, out.shape[0]):
            ymin = bounds[yy * 2 + 0]
            ymax = bounds[yy * 2 + 1]
            k = kk[yy * ksize : yy * ksize + ymax]
            out[yy] = np.round(np.sum(img[ymin : ymin + ymax, 0:out.shape[1]] * k[:, np.newaxis], axis=0))

    def imaging_resample(self, img, xsize, ysize):
        height, width = img.shape[0:2]
        bounds_horiz, kk_horiz, ksize_horiz = self._precompute_coeffs(width, xsize)
        bounds_vert, kk_vert, ksize_vert = self._precompute_coeffs(height, ysize)

        out_hor = np.empty((img.shape[0], xsize), dtype=np.uint8)
        self._resample_horizontal(out_hor, img, ksize_horiz, bounds_horiz, kk_horiz)
        out = np.empty((ysize, xsize), dtype=np.uint8)
        self._resample_vertical(out, out_hor, ksize_vert, bounds_vert, kk_vert)
        return out
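# Equivalence sketch against Pillow (not a dependency of this script; mask is
# any single-channel uint8 image, e.g. the 192x256 body-shape mask built in
# prepare_agnostic below). PIL uses fixed-point coefficients internally, so
# allow a little rounding slack:
#   from PIL import Image
#   ref = np.asarray(Image.fromarray(mask).resize((12, 16), Image.BILINEAR))
#   ours = BilinearFilter().imaging_resample(mask, 12, 16)
#   assert np.abs(ref.astype(int) - ours.astype(int)).max() <= 1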
class CpVton(object):
    def __init__(self, gmm_model, tom_model, backend, target):
        super(CpVton, self).__init__()
        self.gmm_net = cv.dnn.readNet(gmm_model)
        self.tom_net = cv.dnn.readNet(tom_model)
        self.gmm_net.setPreferableBackend(backend)
        self.gmm_net.setPreferableTarget(target)
        self.tom_net.setPreferableBackend(backend)
        self.tom_net.setPreferableTarget(target)

    def prepare_agnostic(self, segm_image, input_image, pose_map, height=256, width=192):
        palette = {
            'Background'   : (0, 0, 0),
            'Hat'          : (128, 0, 0),
            'Hair'         : (255, 0, 0),
            'Glove'        : (0, 85, 0),
            'Sunglasses'   : (170, 0, 51),
            'UpperClothes' : (255, 85, 0),
            'Dress'        : (0, 0, 85),
            'Coat'         : (0, 119, 221),
            'Socks'        : (85, 85, 0),
            'Pants'        : (0, 85, 85),
            'Jumpsuits'    : (85, 51, 0),
            'Scarf'        : (52, 86, 128),
            'Skirt'        : (0, 128, 0),
            'Face'         : (0, 0, 255),
            'Left-arm'     : (51, 170, 221),
            'Right-arm'    : (0, 255, 255),
            'Left-leg'     : (85, 255, 170),
            'Right-leg'    : (170, 255, 85),
            'Left-shoe'    : (255, 255, 0),
            'Right-shoe'   : (255, 170, 0)
        }
        color2label = {val: key for key, val in palette.items()}
        head_labels = ['Hat', 'Hair', 'Sunglasses', 'Face', 'Pants', 'Skirt']

        segm_image = cv.cvtColor(segm_image, cv.COLOR_BGR2RGB)
        phead = np.zeros((1, height, width), dtype=np.float32)
        pose_shape = np.zeros((height, width), dtype=np.uint8)
        for r in range(height):
            for c in range(width):
                pixel = tuple(segm_image[r, c])
                if pixel in color2label:
                    if color2label[pixel] in head_labels:
                        phead[0, r, c] = 1
                    if color2label[pixel] != 'Background':
                        pose_shape[r, c] = 255

        input_image = cv.dnn.blobFromImage(input_image, 1.0 / 127.5, (width, height), mean=(127.5, 127.5, 127.5), swapRB=True)
        input_image = input_image.squeeze(0)
        img_head = input_image * phead - (1 - phead)

        downsample = BilinearFilter()
        down = downsample.imaging_resample(pose_shape, width // 16, height // 16)
        res_shape = cv.resize(down, (width, height), interpolation=cv.INTER_LINEAR)

        res_shape = cv.dnn.blobFromImage(res_shape, 1.0 / 127.5, mean=(127.5, 127.5, 127.5), swapRB=True)
        res_shape = res_shape.squeeze(0)

        agnostic = np.concatenate((res_shape, img_head, pose_map), axis=0)
        agnostic = np.expand_dims(agnostic, axis=0)
        return agnostic.astype(np.float32)
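    # prepare_agnostic stacks 1 body-shape channel, 3 masked-head image
    # channels and 18 pose channels into the 22-channel person representation
    # described in the CP-VTON paper.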
    def get_warped_cloth(self, cloth_img, agnostic, height=256, width=192):
        cloth = cv.dnn.blobFromImage(cloth_img, 1.0 / 127.5, (width, height), mean=(127.5, 127.5, 127.5), swapRB=True)
        self.gmm_net.setInput(agnostic, "input.1")
        self.gmm_net.setInput(cloth, "input.18")
        theta = self.gmm_net.forward()

        grid = self._generate_grid(theta)
        warped_cloth = self._bilinear_sampler(cloth, grid).astype(np.float32)
        return warped_cloth

    def get_tryon(self, agnostic, warp_cloth):
        inp = np.concatenate([agnostic, warp_cloth], axis=1)
        self.tom_net.setInput(inp)
        out = self.tom_net.forward()

        p_rendered, m_composite = np.split(out, [3], axis=1)
        p_rendered = np.tanh(p_rendered)
        m_composite = 1 / (1 + np.exp(-m_composite))

        p_tryon = warp_cloth * m_composite + p_rendered * (1 - m_composite)
        rgb_p_tryon = cv.cvtColor(p_tryon.squeeze(0).transpose(1, 2, 0), cv.COLOR_BGR2RGB)
        rgb_p_tryon = (rgb_p_tryon + 1) / 2
        return rgb_p_tryon
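    # The TOM output holds a rendered person (3 channels) and a composition
    # mask; the tanh / sigmoid applied here mirror the activations CP-VTON
    # puts at the end of its Try-On Module, and (x + 1) / 2 maps the blended
    # result from [-1, 1] into [0, 1] for display.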
    def _compute_L_inverse(self, X, Y):
        N = X.shape[0]

        Xmat = np.tile(X, (1, N))
        Ymat = np.tile(Y, (1, N))
        P_dist_squared = np.power(Xmat - Xmat.transpose(1, 0), 2) + np.power(Ymat - Ymat.transpose(1, 0), 2)
        P_dist_squared[P_dist_squared == 0] = 1

        K = np.multiply(P_dist_squared, np.log(P_dist_squared))
        O = np.ones([N, 1], dtype=np.float32)
        Z = np.zeros([3, 3], dtype=np.float32)
        P = np.concatenate([O, X, Y], axis=1)
        first = np.concatenate((K, P), axis=1)
        second = np.concatenate((P.transpose(1, 0), Z), axis=1)
        L = np.concatenate((first, second), axis=0)
        Li = linalg.inv(L)
        return Li
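    # L is the standard thin-plate-spline system
    #   L = [[K, P], [P^T, 0]],  K_ij = d_ij^2 * log(d_ij^2),  P_i = (1, x_i, y_i),
    # so multiplying Li by the target control-point coordinates yields the TPS
    # kernel weights W and affine coefficients A used in _apply_transformation.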
    def _prepare_to_transform(self, out_h=256, out_w=192, grid_size=5):
        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
        grid_X = np.expand_dims(np.expand_dims(grid_X, axis=0), axis=3)
        grid_Y = np.expand_dims(np.expand_dims(grid_Y, axis=0), axis=3)

        axis_coords = np.linspace(-1, 1, grid_size)
        N = grid_size ** 2
        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)

        P_X = np.reshape(P_X, (-1, 1))
        P_Y = np.reshape(P_Y, (-1, 1))

        P_X = np.expand_dims(np.expand_dims(np.expand_dims(P_X, axis=2), axis=3), axis=4).transpose(4, 1, 2, 3, 0)
        P_Y = np.expand_dims(np.expand_dims(np.expand_dims(P_Y, axis=2), axis=3), axis=4).transpose(4, 1, 2, 3, 0)
        return grid_X, grid_Y, N, P_X, P_Y

    def _expand_torch(self, X, shape):
        if len(X.shape) != len(shape):
            return X.flatten().reshape(shape)
        else:
            axis = [1 if src == dst else dst for src, dst in zip(X.shape, shape)]
            return np.tile(X, axis)
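    # _expand_torch emulates torch.Tensor.expand in this NumPy port: singleton
    # dimensions are tiled up to the requested shape, and a rank mismatch
    # falls back to a plain reshape.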
    def _apply_transformation(self, theta, points, N, P_X, P_Y):
        if len(theta.shape) == 2:
            theta = np.expand_dims(np.expand_dims(theta, axis=2), axis=3)

        batch_size = theta.shape[0]

        P_X_base = np.copy(P_X)
        P_Y_base = np.copy(P_Y)

        Li = self._compute_L_inverse(np.reshape(P_X, (N, -1)), np.reshape(P_Y, (N, -1)))
        Li = np.expand_dims(Li, axis=0)

        # split theta into point coordinates
        Q_X = np.squeeze(theta[:, :N, :, :], axis=3)
        Q_Y = np.squeeze(theta[:, N:, :, :], axis=3)

        Q_X += self._expand_torch(P_X_base, Q_X.shape)
        Q_Y += self._expand_torch(P_Y_base, Q_Y.shape)

        points_b = points.shape[0]
        points_h = points.shape[1]
        points_w = points.shape[2]

        P_X = self._expand_torch(P_X, (1, points_h, points_w, 1, N))
        P_Y = self._expand_torch(P_Y, (1, points_h, points_w, 1, N))

        W_X = self._expand_torch(Li[:, :N, :N], (batch_size, N, N)) @ Q_X
        W_Y = self._expand_torch(Li[:, :N, :N], (batch_size, N, N)) @ Q_Y

        W_X = np.expand_dims(np.expand_dims(W_X, axis=3), axis=4).transpose(0, 4, 2, 3, 1)
        W_X = np.repeat(W_X, points_h, axis=1)
        W_X = np.repeat(W_X, points_w, axis=2)

        W_Y = np.expand_dims(np.expand_dims(W_Y, axis=3), axis=4).transpose(0, 4, 2, 3, 1)
        W_Y = np.repeat(W_Y, points_h, axis=1)
        W_Y = np.repeat(W_Y, points_w, axis=2)

        A_X = self._expand_torch(Li[:, N:, :N], (batch_size, 3, N)) @ Q_X
        A_Y = self._expand_torch(Li[:, N:, :N], (batch_size, 3, N)) @ Q_Y

        A_X = np.expand_dims(np.expand_dims(A_X, axis=3), axis=4).transpose(0, 4, 2, 3, 1)
        A_X = np.repeat(A_X, points_h, axis=1)
        A_X = np.repeat(A_X, points_w, axis=2)

        A_Y = np.expand_dims(np.expand_dims(A_Y, axis=3), axis=4).transpose(0, 4, 2, 3, 1)
        A_Y = np.repeat(A_Y, points_h, axis=1)
        A_Y = np.repeat(A_Y, points_w, axis=2)

        points_X_for_summation = np.expand_dims(np.expand_dims(points[:, :, :, 0], axis=3), axis=4)
        points_X_for_summation = self._expand_torch(points_X_for_summation, points[:, :, :, 0].shape + (1, N))

        points_Y_for_summation = np.expand_dims(np.expand_dims(points[:, :, :, 1], axis=3), axis=4)
        points_Y_for_summation = self._expand_torch(points_Y_for_summation, points[:, :, :, 0].shape + (1, N))

        if points_b == 1:
            delta_X = points_X_for_summation - P_X
            delta_Y = points_Y_for_summation - P_Y
        else:
            delta_X = points_X_for_summation - self._expand_torch(P_X, points_X_for_summation.shape)
            delta_Y = points_Y_for_summation - self._expand_torch(P_Y, points_Y_for_summation.shape)

        dist_squared = np.power(delta_X, 2) + np.power(delta_Y, 2)
        dist_squared[dist_squared == 0] = 1
        U = np.multiply(dist_squared, np.log(dist_squared))

        points_X_batch = np.expand_dims(points[:, :, :, 0], axis=3)
        points_Y_batch = np.expand_dims(points[:, :, :, 1], axis=3)

        if points_b == 1:
            points_X_batch = self._expand_torch(points_X_batch, (batch_size, ) + points_X_batch.shape[1:])
            points_Y_batch = self._expand_torch(points_Y_batch, (batch_size, ) + points_Y_batch.shape[1:])

        points_X_prime = A_X[:, :, :, :, 0] + \
                         np.multiply(A_X[:, :, :, :, 1], points_X_batch) + \
                         np.multiply(A_X[:, :, :, :, 2], points_Y_batch) + \
                         np.sum(np.multiply(W_X, self._expand_torch(U, W_X.shape)), 4)

        points_Y_prime = A_Y[:, :, :, :, 0] + \
                         np.multiply(A_Y[:, :, :, :, 1], points_X_batch) + \
                         np.multiply(A_Y[:, :, :, :, 2], points_Y_batch) + \
                         np.sum(np.multiply(W_Y, self._expand_torch(U, W_Y.shape)), 4)

        return np.concatenate((points_X_prime, points_Y_prime), 3)
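    # Each warped coordinate above follows the TPS interpolant
    #   f(x, y) = a_0 + a_1 * x + a_2 * y + sum_i w_i * U(d_i),
    # with U(d) = d^2 * log(d^2) and (a, w) recovered via Li, so theta only
    # has to predict the 2 * N control-point offsets.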
    def _generate_grid(self, theta):
        grid_X, grid_Y, N, P_X, P_Y = self._prepare_to_transform()
        warped_grid = self._apply_transformation(theta, np.concatenate((grid_X, grid_Y), axis=3), N, P_X, P_Y)
        return warped_grid

    def _bilinear_sampler(self, img, grid):
        x, y = grid[:, :, :, 0], grid[:, :, :, 1]

        H = img.shape[2]
        W = img.shape[3]
        max_y = H - 1
        max_x = W - 1

        # rescale x and y from [-1, 1] to [0, max_x - 1] / [0, max_y - 1]
        x = 0.5 * (x + 1.0) * (max_x - 1)
        y = 0.5 * (y + 1.0) * (max_y - 1)

        # grab 4 nearest corner points for each (x_i, y_i)
        x0 = np.floor(x).astype(int)
        x1 = x0 + 1
        y0 = np.floor(y).astype(int)
        y1 = y0 + 1

        # calculate deltas
        wa = (x1 - x) * (y1 - y)
        wb = (x1 - x) * (y - y0)
        wc = (x - x0) * (y1 - y)
        wd = (x - x0) * (y - y0)

        # clip to range [0, H-1/W-1] to not violate img boundaries
        x0 = np.clip(x0, 0, max_x)
        x1 = np.clip(x1, 0, max_x)
        y0 = np.clip(y0, 0, max_y)
        y1 = np.clip(y1, 0, max_y)

        # get pixel value at corner coords
        img = img.reshape(-1, H, W)
        Ia = img[:, y0, x0].swapaxes(0, 1)
        Ib = img[:, y1, x0].swapaxes(0, 1)
        Ic = img[:, y0, x1].swapaxes(0, 1)
        Id = img[:, y1, x1].swapaxes(0, 1)

        wa = np.expand_dims(wa, axis=0)
        wb = np.expand_dims(wb, axis=0)
        wc = np.expand_dims(wc, axis=0)
        wd = np.expand_dims(wd, axis=0)

        # compute output
        out = wa * Ia + wb * Ib + wc * Ic + wd * Id
        return out
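    # _bilinear_sampler is a NumPy stand-in for grid sampling (cf.
    # torch.nn.functional.grid_sample): each output pixel is the
    # distance-weighted blend of the four integer neighbours of its warped,
    # normalized (x, y) coordinate.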
class CorrelationLayer(object):
    def __init__(self, params, blobs):
        super(CorrelationLayer, self).__init__()

    def getMemoryShapes(self, inputs):
        featureAShape = inputs[0]
        b, _, h, w = featureAShape
        return [[b, h * w, h, w]]

    def forward(self, inputs):
        feature_A, feature_B = inputs
        b, c, h, w = feature_A.shape
        feature_A = feature_A.transpose(0, 1, 3, 2)
        feature_A = np.reshape(feature_A, (b, c, h * w))
        feature_B = np.reshape(feature_B, (b, c, h * w))
        feature_B = feature_B.transpose(0, 2, 1)
        feature_mul = feature_B @ feature_A
        feature_mul = np.reshape(feature_mul, (b, h, w, h * w))
        feature_mul = feature_mul.transpose(0, 1, 3, 2)
        correlation_tensor = feature_mul.transpose(0, 2, 1, 3)
        correlation_tensor = np.ascontiguousarray(correlation_tensor)
        return [correlation_tensor]
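# Output channel k of the correlation tensor holds, for every spatial location
# of feature_B, its dot product with feature_A at location k, i.e. the dense
# feature correlation the GMM uses to match person and cloth features
# (cf. FeatureCorrelation in CP-VTON).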
if __name__ == "__main__":
    if not os.path.isfile(args.gmm_model):
        raise OSError("GMM model does not exist")
    if not os.path.isfile(args.tom_model):
        raise OSError("TOM model does not exist")
    if not os.path.isfile(args.segmentation_model):
        raise OSError("Segmentation model does not exist")
    if not os.path.isfile(findFile(args.openpose_proto)):
        raise OSError("OpenPose proto does not exist")
    if not os.path.isfile(findFile(args.openpose_model)):
        raise OSError("OpenPose model does not exist")

    person_img = cv.imread(args.input_image)
    # center-crop the person image to the 256:192 aspect ratio the models expect
    ratio = 256 / 192
    inp_h, inp_w, _ = person_img.shape
    current_ratio = inp_h / inp_w
    if current_ratio > ratio:
        center_h = inp_h // 2
        out_h = inp_w * ratio
        start = int(center_h - out_h // 2)
        end = int(center_h + out_h // 2)
        person_img = person_img[start:end, ...]
    else:
        center_w = inp_w // 2
        out_w = inp_h / ratio
        start = int(center_w - out_w // 2)
        end = int(center_w + out_w // 2)
        person_img = person_img[:, start:end, :]

    cloth_img = cv.imread(args.input_cloth)
    pose = get_pose_map(person_img, findFile(args.openpose_proto),
                        findFile(args.openpose_model), args.backend, args.target)
    segm_image = parse_human(person_img, args.segmentation_model)
    segm_image = cv.resize(segm_image, (192, 256), interpolation=cv.INTER_LINEAR)

    cv.dnn_registerLayer('Correlation', CorrelationLayer)
    model = CpVton(args.gmm_model, args.tom_model, args.backend, args.target)
    agnostic = model.prepare_agnostic(segm_image, person_img, pose)
    warped_cloth = model.get_warped_cloth(cloth_img, agnostic)
    output = model.get_tryon(agnostic, warped_cloth)
    cv.dnn_unregisterLayer('Correlation')

    winName = 'Virtual Try-On'
    cv.namedWindow(winName, cv.WINDOW_AUTOSIZE)
    cv.imshow(winName, output)
    cv.waitKey()
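    # To also write the result to disk (a sketch; 'tryon.png' is a
    # hypothetical path, and the float image in [0, 1] returned by get_tryon
    # must be scaled to uint8 first):
    #   cv.imwrite('tryon.png', (output * 255).astype(np.uint8))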