# This file is a part of OpenCV project.
# It is subject to the license terms in the LICENSE file found in the top-level directory
# of this distribution and at http://opencv.org/license.html.
#
# Copyright (C) 2020, Intel Corporation, all rights reserved.
# Third party copyrights are property of their respective owners.
#
# Use this script to get the text graph representation (.pbtxt) of an EfficientDet
# deep learning network trained in https://github.com/google/automl.
# Then you can import it with a binary frozen graph (.pb) using the readNetFromTensorflow() function.
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
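#
# A minimal usage sketch (the .pb/.pbtxt file names below are only examples, not
# files shipped with this repository):
#
#     python tf_text_graph_efficientdet.py --input efficientdet-d0_frozen.pb \
#                                          --output efficientdet-d0.pbtxt
#
#     import cv2 as cv
#     net = cv.dnn.readNetFromTensorflow('efficientdet-d0_frozen.pb', 'efficientdet-d0.pbtxt')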
import argparse
import re
from math import sqrt
from tf_text_graph_common import *


class AnchorGenerator:
    def __init__(self, min_level, aspect_ratios, num_scales, anchor_scale):
        self.min_level = min_level
        self.aspect_ratios = aspect_ratios
        self.anchor_scale = anchor_scale
        self.scales = [2**(float(s) / num_scales) for s in range(num_scales)]

    def get(self, layer_id):
        widths = []
        heights = []
        for s in self.scales:
            for a in self.aspect_ratios:
                base_anchor_size = 2**(self.min_level + layer_id) * self.anchor_scale
                heights.append(base_anchor_size * s * a[1])
                widths.append(base_anchor_size * s * a[0])
        return widths, heights
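

# For reference, AnchorGenerator with the default training config (min_level=3,
# anchor_scale=4.0, num_scales=3, aspect ratio pairs (1.0, 1.0), (1.4, 0.7),
# (0.7, 1.4)) yields base_anchor_size = 2**3 * 4.0 = 32 at the first pyramid
# level (layer_id=0), i.e. unit-scale anchors of 32x32, 44.8x22.4 and 22.4x44.8
# pixels; the remaining scales 2**(1/3) and 2**(2/3) enlarge them by ~1.26x and ~1.59x.
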
def createGraph(modelPath, outputPath, min_level, aspect_ratios, num_scales,
                anchor_scale, num_classes, image_width, image_height):
    print('Min level: %d' % min_level)
    print('Anchor scale: %f' % anchor_scale)
    print('Num scales: %d' % num_scales)
    print('Aspect ratios: %s' % str(aspect_ratios))
    print('Number of classes: %d' % num_classes)
    print('Input image size: %dx%d' % (image_width, image_height))

    # Read the graph.
    _inpNames = ['image_arrays']
    outNames = ['detections']

    writeTextGraph(modelPath, outputPath, outNames)
    graph_def = parseTextGraph(outputPath)

    def getUnconnectedNodes():
        unconnected = []
        for node in graph_def.node:
            if node.op == 'Const':
                continue
            unconnected.append(node.name)
            for inp in node.input:
                if inp in unconnected:
                    unconnected.remove(inp)
        return unconnected

    nodesToKeep = ['truediv']  # Keep preprocessing nodes

    removeIdentity(graph_def)

    scopesToKeep = ('image_arrays', 'efficientnet', 'resample_p6', 'resample_p7',
                    'fpn_cells', 'class_net', 'box_net', 'Reshape', 'concat')

    addConstNode('scale_w', [2.0], graph_def)
    addConstNode('scale_h', [2.0], graph_def)
    nodesToKeep += ['scale_w', 'scale_h']

    for node in graph_def.node:
        if re.match(r'efficientnet-(.*)/blocks_\d+/se/mul_1', node.name):
            node.input[0], node.input[1] = node.input[1], node.input[0]

        if re.match(r'fpn_cells/cell_\d+/fnode\d+/resample(.*)/nearest_upsampling/Reshape_1$', node.name):
            node.op = 'ResizeNearestNeighbor'
            node.input[1] = 'scale_w'
            node.input.append('scale_h')

            for inpNode in graph_def.node:
                if inpNode.name == node.name[:node.name.rfind('_')]:
                    node.input[0] = inpNode.input[0]

        if re.match(r'box_net/box-predict(_\d)*/separable_conv2d$', node.name):
            node.addAttr('loc_pred_transposed', True)

        # Replace RealDiv with Mul by the inverse scale for compatibility
        if node.op == 'RealDiv':
            for inpNode in graph_def.node:
                if inpNode.name != node.input[1] or not 'value' in inpNode.attr:
                    continue

                tensor = inpNode.attr['value']['tensor'][0]
                if not 'float_val' in tensor:
                    continue
                scale = float(inpNode.attr['value']['tensor'][0]['float_val'][0])
                addConstNode(inpNode.name + '/inv', [1.0 / scale], graph_def)
                nodesToKeep.append(inpNode.name + '/inv')

                node.input[1] = inpNode.name + '/inv'
                node.op = 'Mul'
                break

    def to_remove(name, op):
        if name in nodesToKeep:
            return False
        return op == 'Const' or not name.startswith(scopesToKeep)

    removeUnusedNodesAndAttrs(to_remove, graph_def)

    # Attach unconnected preprocessing
    assert(graph_def.node[1].name == 'truediv' and graph_def.node[1].op == 'RealDiv')
    graph_def.node[1].input.insert(0, 'image_arrays')
    graph_def.node[2].input.insert(0, 'truediv')
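
    # Build one PriorBox node per feature pyramid level (five levels, i.e. P3..P7
    # with the default min_level=3). Each one takes the tensor feeding the
    # corresponding 'Reshape_%d' node plus the network input, and stores the
    # anchor widths/heights computed by AnchorGenerator above.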
    priors_generator = AnchorGenerator(min_level, aspect_ratios, num_scales, anchor_scale)
    priorBoxes = []
    for i in range(5):
        inpName = ''
        for node in graph_def.node:
            if node.name == 'Reshape_%d' % (i * 2 + 1):
                inpName = node.input[0]
                break

        priorBox = NodeDef()
        priorBox.name = 'PriorBox_%d' % i
        priorBox.op = 'PriorBox'
        priorBox.input.append(inpName)
        priorBox.input.append(graph_def.node[0].name)  # image_arrays input node

        priorBox.addAttr('flip', False)
        priorBox.addAttr('clip', False)

        widths, heights = priors_generator.get(i)

        priorBox.addAttr('width', widths)
        priorBox.addAttr('height', heights)
        priorBox.addAttr('variance', [1.0, 1.0, 1.0, 1.0])

        graph_def.node.extend([priorBox])
        priorBoxes.append(priorBox.name)

    addConstNode('concat/axis_flatten', [-1], graph_def)

    def addConcatNode(name, inputs, axisNodeName):
        concat = NodeDef()
        concat.name = name
        concat.op = 'ConcatV2'
        for inp in inputs:
            concat.input.append(inp)
        concat.input.append(axisNodeName)
        graph_def.node.extend([concat])

    addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')

    sigmoid = NodeDef()
    sigmoid.name = 'concat/sigmoid'
    sigmoid.op = 'Sigmoid'
    sigmoid.input.append('concat')
    graph_def.node.extend([sigmoid])

    addFlatten(sigmoid.name, sigmoid.name + '/Flatten', graph_def)
    addFlatten('concat_1', 'concat_1/Flatten', graph_def)

    detectionOut = NodeDef()
    detectionOut.name = 'detection_out'
    detectionOut.op = 'DetectionOutput'

    detectionOut.input.append('concat_1/Flatten')
    detectionOut.input.append(sigmoid.name + '/Flatten')
    detectionOut.input.append('PriorBox/concat')

    detectionOut.addAttr('num_classes', num_classes)
    detectionOut.addAttr('share_location', True)
    detectionOut.addAttr('background_label_id', num_classes + 1)
    detectionOut.addAttr('nms_threshold', 0.6)
    detectionOut.addAttr('confidence_threshold', 0.2)
    detectionOut.addAttr('top_k', 100)
    detectionOut.addAttr('keep_top_k', 100)
    detectionOut.addAttr('code_type', "CENTER_SIZE")
    graph_def.node.extend([detectionOut])

    graph_def.node[0].attr['shape'] = {
        'shape': {
            'dim': [
                {'size': -1},
                {'size': image_height},
                {'size': image_width},
                {'size': 3}
            ]
        }
    }

    while True:
        unconnectedNodes = getUnconnectedNodes()
        unconnectedNodes.remove(detectionOut.name)
        if not unconnectedNodes:
            break

        for name in unconnectedNodes:
            for i in range(len(graph_def.node)):
                if graph_def.node[i].name == name:
                    del graph_def.node[i]
                    break

    # Save as text
    graph_def.save(outputPath)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
                                                 'an EfficientDet model trained in https://github.com/google/automl. '
                                                 'Then pass it together with the .pb file to cv::dnn::readNetFromTensorflow function.')
    parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
    parser.add_argument('--output', required=True, help='Path to output text graph.')
    parser.add_argument('--min_level', default=3, type=int, help='Parameter from training config')
    parser.add_argument('--num_scales', default=3, type=int, help='Parameter from training config')
    parser.add_argument('--anchor_scale', default=4.0, type=float, help='Parameter from training config')
    parser.add_argument('--aspect_ratios', default=[1.0, 1.0, 1.4, 0.7, 0.7, 1.4],
                        nargs='+', type=float, help='Parameter from training config')
    parser.add_argument('--num_classes', default=90, type=int, help='Number of classes to detect')
    parser.add_argument('--width', default=512, type=int, help='Network input width')
    parser.add_argument('--height', default=512, type=int, help='Network input height')
    args = parser.parse_args()

    ar = args.aspect_ratios
    assert(len(ar) % 2 == 0)
    ar = list(zip(ar[::2], ar[1::2]))

    createGraph(args.input, args.output, args.min_level, ar, args.num_scales,
                args.anchor_scale, args.num_classes, args.width, args.height)