# tf_text_graph_faster_rcnn.py
  1. import argparse
  2. import numpy as np
  3. from tf_text_graph_common import *
  4. def createFasterRCNNGraph(modelPath, configPath, outputPath):
  5. scopesToKeep = ('FirstStageFeatureExtractor', 'Conv',
  6. 'FirstStageBoxPredictor/BoxEncodingPredictor',
  7. 'FirstStageBoxPredictor/ClassPredictor',
  8. 'CropAndResize',
  9. 'MaxPool2D',
  10. 'SecondStageFeatureExtractor',
  11. 'SecondStageBoxPredictor',
  12. 'Preprocessor/sub',
  13. 'Preprocessor/mul',
  14. 'image_tensor')
  15. scopesToIgnore = ('FirstStageFeatureExtractor/Assert',
  16. 'FirstStageFeatureExtractor/Shape',
  17. 'FirstStageFeatureExtractor/strided_slice',
  18. 'FirstStageFeatureExtractor/GreaterEqual',
  19. 'FirstStageFeatureExtractor/LogicalAnd')
  20. # Load a config file.
  21. config = readTextMessage(configPath)
  22. config = config['model'][0]['faster_rcnn'][0]
  23. num_classes = int(config['num_classes'][0])
  24. grid_anchor_generator = config['first_stage_anchor_generator'][0]['grid_anchor_generator'][0]
  25. scales = [float(s) for s in grid_anchor_generator['scales']]
  26. aspect_ratios = [float(ar) for ar in grid_anchor_generator['aspect_ratios']]
  27. width_stride = float(grid_anchor_generator['width_stride'][0])
  28. height_stride = float(grid_anchor_generator['height_stride'][0])
  29. feature_extractor = config['feature_extractor'][0]
  30. if 'type' in feature_extractor and feature_extractor['type'][0] == 'faster_rcnn_nas':
  31. features_stride = 16.0
  32. else:
  33. features_stride = float(feature_extractor['first_stage_features_stride'][0])
  34. first_stage_nms_iou_threshold = float(config['first_stage_nms_iou_threshold'][0])
  35. first_stage_max_proposals = int(config['first_stage_max_proposals'][0])
  36. print('Number of classes: %d' % num_classes)
  37. print('Scales: %s' % str(scales))
  38. print('Aspect ratios: %s' % str(aspect_ratios))
  39. print('Width stride: %f' % width_stride)
  40. print('Height stride: %f' % height_stride)
  41. print('Features stride: %f' % features_stride)
  42. # Read the graph.
  43. writeTextGraph(modelPath, outputPath, ['num_detections', 'detection_scores', 'detection_boxes', 'detection_classes'])
  44. graph_def = parseTextGraph(outputPath)
  45. removeIdentity(graph_def)
  46. nodesToKeep = []
  47. def to_remove(name, op):
  48. if name in nodesToKeep:
  49. return False
  50. return op == 'Const' or name.startswith(scopesToIgnore) or not name.startswith(scopesToKeep) or \
  51. (name.startswith('CropAndResize') and op != 'CropAndResize')
  52. # Fuse atrous convolutions (with dilations).
  53. nodesMap = {node.name: node for node in graph_def.node}
  54. for node in reversed(graph_def.node):
  55. if node.op == 'BatchToSpaceND':
  56. del node.input[2]
  57. conv = nodesMap[node.input[0]]
  58. spaceToBatchND = nodesMap[conv.input[0]]
  59. # Extract paddings
  60. stridedSlice = nodesMap[spaceToBatchND.input[2]]
  61. assert(stridedSlice.op == 'StridedSlice')
  62. pack = nodesMap[stridedSlice.input[0]]
  63. assert(pack.op == 'Pack')
  64. padNodeH = nodesMap[nodesMap[pack.input[0]].input[0]]
  65. padNodeW = nodesMap[nodesMap[pack.input[1]].input[0]]
  66. padH = int(padNodeH.attr['value']['tensor'][0]['int_val'][0])
  67. padW = int(padNodeW.attr['value']['tensor'][0]['int_val'][0])
  68. paddingsNode = NodeDef()
  69. paddingsNode.name = conv.name + '/paddings'
  70. paddingsNode.op = 'Const'
  71. paddingsNode.addAttr('value', [padH, padH, padW, padW])
  72. graph_def.node.insert(graph_def.node.index(spaceToBatchND), paddingsNode)
  73. nodesToKeep.append(paddingsNode.name)
  74. spaceToBatchND.input[2] = paddingsNode.name
  75. removeUnusedNodesAndAttrs(to_remove, graph_def)
  76. # Connect input node to the first layer
  77. assert(graph_def.node[0].op == 'Placeholder')
  78. graph_def.node[1].input.insert(0, graph_def.node[0].name)
  79. # Temporarily remove top nodes.
  80. topNodes = []
  81. while True:
  82. node = graph_def.node.pop()
  83. topNodes.append(node)
  84. if node.op == 'CropAndResize':
  85. break
  86. addReshape('FirstStageBoxPredictor/ClassPredictor/BiasAdd',
  87. 'FirstStageBoxPredictor/ClassPredictor/reshape_1', [0, -1, 2], graph_def)
  88. addSoftMax('FirstStageBoxPredictor/ClassPredictor/reshape_1',
  89. 'FirstStageBoxPredictor/ClassPredictor/softmax', graph_def) # Compare with Reshape_4
  90. addFlatten('FirstStageBoxPredictor/ClassPredictor/softmax',
  91. 'FirstStageBoxPredictor/ClassPredictor/softmax/flatten', graph_def)
  92. # Compare with FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd
  93. addFlatten('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd',
  94. 'FirstStageBoxPredictor/BoxEncodingPredictor/flatten', graph_def)
  95. proposals = NodeDef()
  96. proposals.name = 'proposals' # Compare with ClipToWindow/Gather/Gather (NOTE: normalized)
  97. proposals.op = 'PriorBox'
  98. proposals.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/BiasAdd')
  99. proposals.input.append(graph_def.node[0].name) # image_tensor
  100. proposals.addAttr('flip', False)
  101. proposals.addAttr('clip', True)
  102. proposals.addAttr('step', features_stride)
  103. proposals.addAttr('offset', 0.0)
  104. proposals.addAttr('variance', [0.1, 0.1, 0.2, 0.2])
  105. widths = []
  106. heights = []
  107. for a in aspect_ratios:
  108. for s in scales:
  109. ar = np.sqrt(a)
  110. heights.append((height_stride**2) * s / ar)
  111. widths.append((width_stride**2) * s * ar)
  112. proposals.addAttr('width', widths)
  113. proposals.addAttr('height', heights)
  114. graph_def.node.extend([proposals])
  115. # Compare with Reshape_5
  116. detectionOut = NodeDef()
  117. detectionOut.name = 'detection_out'
  118. detectionOut.op = 'DetectionOutput'
  119. detectionOut.input.append('FirstStageBoxPredictor/BoxEncodingPredictor/flatten')
  120. detectionOut.input.append('FirstStageBoxPredictor/ClassPredictor/softmax/flatten')
  121. detectionOut.input.append('proposals')
  122. detectionOut.addAttr('num_classes', 2)
  123. detectionOut.addAttr('share_location', True)
  124. detectionOut.addAttr('background_label_id', 0)
  125. detectionOut.addAttr('nms_threshold', first_stage_nms_iou_threshold)
  126. detectionOut.addAttr('top_k', 6000)
  127. detectionOut.addAttr('code_type', "CENTER_SIZE")
  128. detectionOut.addAttr('keep_top_k', first_stage_max_proposals)
  129. detectionOut.addAttr('clip', False)
  130. graph_def.node.extend([detectionOut])
  131. addConstNode('clip_by_value/lower', [0.0], graph_def)
  132. addConstNode('clip_by_value/upper', [1.0], graph_def)
  133. clipByValueNode = NodeDef()
  134. clipByValueNode.name = 'detection_out/clip_by_value'
  135. clipByValueNode.op = 'ClipByValue'
  136. clipByValueNode.input.append('detection_out')
  137. clipByValueNode.input.append('clip_by_value/lower')
  138. clipByValueNode.input.append('clip_by_value/upper')
  139. graph_def.node.extend([clipByValueNode])
  140. # Save as text.
  141. for node in reversed(topNodes):
  142. graph_def.node.extend([node])
  143. addSoftMax('SecondStageBoxPredictor/Reshape_1', 'SecondStageBoxPredictor/Reshape_1/softmax', graph_def)
  144. addSlice('SecondStageBoxPredictor/Reshape_1/softmax',
  145. 'SecondStageBoxPredictor/Reshape_1/slice',
  146. [0, 0, 1], [-1, -1, -1], graph_def)
  147. addReshape('SecondStageBoxPredictor/Reshape_1/slice',
  148. 'SecondStageBoxPredictor/Reshape_1/Reshape', [1, -1], graph_def)
  149. # Replace Flatten subgraph onto a single node.
  150. cropAndResizeNodeName = ''
  151. for i in reversed(range(len(graph_def.node))):
  152. if graph_def.node[i].op == 'CropAndResize':
  153. graph_def.node[i].input.insert(1, 'detection_out/clip_by_value')
  154. cropAndResizeNodeName = graph_def.node[i].name
  155. if graph_def.node[i].name == 'SecondStageBoxPredictor/Reshape':
  156. addConstNode('SecondStageBoxPredictor/Reshape/shape2', [1, -1, 4], graph_def)
  157. graph_def.node[i].input.pop()
  158. graph_def.node[i].input.append('SecondStageBoxPredictor/Reshape/shape2')
  159. if graph_def.node[i].name in ['SecondStageBoxPredictor/Flatten/flatten/Shape',
  160. 'SecondStageBoxPredictor/Flatten/flatten/strided_slice',
  161. 'SecondStageBoxPredictor/Flatten/flatten/Reshape/shape',
  162. 'SecondStageBoxPredictor/Flatten_1/flatten/Shape',
  163. 'SecondStageBoxPredictor/Flatten_1/flatten/strided_slice',
  164. 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape/shape']:
  165. del graph_def.node[i]
  166. for node in graph_def.node:
  167. if node.name == 'SecondStageBoxPredictor/Flatten/flatten/Reshape' or \
  168. node.name == 'SecondStageBoxPredictor/Flatten_1/flatten/Reshape':
  169. node.op = 'Flatten'
  170. node.input.pop()
  171. if node.name in ['FirstStageBoxPredictor/BoxEncodingPredictor/Conv2D',
  172. 'SecondStageBoxPredictor/BoxEncodingPredictor/MatMul']:
  173. node.addAttr('loc_pred_transposed', True)
  174. if node.name.startswith('MaxPool2D'):
  175. assert(node.op == 'MaxPool')
  176. assert(cropAndResizeNodeName)
  177. node.input = [cropAndResizeNodeName]
  178. ################################################################################
  179. ### Postprocessing
  180. ################################################################################
  181. addSlice('detection_out/clip_by_value', 'detection_out/slice', [0, 0, 0, 3], [-1, -1, -1, 4], graph_def)
  182. variance = NodeDef()
  183. variance.name = 'proposals/variance'
  184. variance.op = 'Const'
  185. variance.addAttr('value', [0.1, 0.1, 0.2, 0.2])
  186. graph_def.node.extend([variance])
  187. varianceEncoder = NodeDef()
  188. varianceEncoder.name = 'variance_encoded'
  189. varianceEncoder.op = 'Mul'
  190. varianceEncoder.input.append('SecondStageBoxPredictor/Reshape')
  191. varianceEncoder.input.append(variance.name)
  192. varianceEncoder.addAttr('axis', 2)
  193. graph_def.node.extend([varianceEncoder])
  194. addReshape('detection_out/slice', 'detection_out/slice/reshape', [1, 1, -1], graph_def)
  195. addFlatten('variance_encoded', 'variance_encoded/flatten', graph_def)
  196. detectionOut = NodeDef()
  197. detectionOut.name = 'detection_out_final'
  198. detectionOut.op = 'DetectionOutput'
  199. detectionOut.input.append('variance_encoded/flatten')
  200. detectionOut.input.append('SecondStageBoxPredictor/Reshape_1/Reshape')
  201. detectionOut.input.append('detection_out/slice/reshape')
  202. detectionOut.addAttr('num_classes', num_classes)
  203. detectionOut.addAttr('share_location', False)
  204. detectionOut.addAttr('background_label_id', num_classes + 1)
  205. detectionOut.addAttr('nms_threshold', 0.6)
  206. detectionOut.addAttr('code_type', "CENTER_SIZE")
  207. detectionOut.addAttr('keep_top_k', 100)
  208. detectionOut.addAttr('clip', True)
  209. detectionOut.addAttr('variance_encoded_in_target', True)
  210. graph_def.node.extend([detectionOut])
  211. def getUnconnectedNodes():
  212. unconnected = [node.name for node in graph_def.node]
  213. for node in graph_def.node:
  214. for inp in node.input:
  215. if inp in unconnected:
  216. unconnected.remove(inp)
  217. return unconnected
  218. while True:
  219. unconnectedNodes = getUnconnectedNodes()
  220. unconnectedNodes.remove(detectionOut.name)
  221. if not unconnectedNodes:
  222. break
  223. for name in unconnectedNodes:
  224. for i in range(len(graph_def.node)):
  225. if graph_def.node[i].name == name:
  226. del graph_def.node[i]
  227. break
  228. # Save as text.
  229. graph_def.save(outputPath)
  230. if __name__ == "__main__":
  231. parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
  232. 'Faster-RCNN model from TensorFlow Object Detection API. '
  233. 'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
  234. parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
  235. parser.add_argument('--output', required=True, help='Path to output text graph.')
  236. parser.add_argument('--config', required=True, help='Path to a *.config file is used for training.')
  237. args = parser.parse_args()
  238. createFasterRCNNGraph(args.input, args.config, args.output)