# cityscapes_semsegm_test_enet.py (4.6 KB)

  1. import numpy as np
  2. import sys
  3. import os
  4. import fnmatch
  5. import argparse
  6. try:
  7. import cv2 as cv
  8. except ImportError:
  9. raise ImportError('Can\'t find OpenCV Python module. If you\'ve built it from sources without installation, '
  10. 'configure environment variable PYTHONPATH to "opencv_build_dir/lib" directory (with "python3" subdirectory if required)')
  11. try:
  12. import torch
  13. except ImportError:
  14. raise ImportError('Can\'t find pytorch. Please install it by following instructions on the official site')
  15. from torch.utils.serialization import load_lua
  16. from pascal_semsegm_test_fcn import eval_segm_result, get_conf_mat, get_metrics, DatasetImageFetch, SemSegmEvaluation
  17. from imagenet_cls_test_alexnet import Framework, DnnCaffeModel
  18. class NormalizePreproc:
  19. def __init__(self):
  20. pass
  21. @staticmethod
  22. def process(img):
  23. image_data = np.array(img).transpose(2, 0, 1).astype(np.float32)
  24. image_data = np.expand_dims(image_data, 0)
  25. image_data /= 255.0
  26. return image_data
  27. class CityscapesDataFetch(DatasetImageFetch):
  28. img_dir = ''
  29. segm_dir = ''
  30. segm_files = []
  31. colors = []
  32. i = 0
  33. def __init__(self, img_dir, segm_dir, preproc):
  34. self.img_dir = img_dir
  35. self.segm_dir = segm_dir
  36. self.segm_files = sorted([img for img in self.locate('*_color.png', segm_dir)])
  37. self.colors = self.get_colors()
  38. self.data_prepoc = preproc
  39. self.i = 0
  40. @staticmethod
  41. def get_colors():
  42. result = []
  43. colors_list = (
  44. (0, 0, 0), (128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156), (190, 153, 153), (153, 153, 153),
  45. (250, 170, 30), (220, 220, 0), (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0),
  46. (0, 0, 142), (0, 0, 70), (0, 60, 100), (0, 80, 100), (0, 0, 230), (119, 11, 32))
  47. for c in colors_list:
  48. result.append(DatasetImageFetch.pix_to_c(c))
  49. return result
  50. def __iter__(self):
  51. return self
  52. def next(self):
  53. if self.i < len(self.segm_files):
  54. segm_file = self.segm_files[self.i]
  55. segm = cv.imread(segm_file, cv.IMREAD_COLOR)[:, :, ::-1]
  56. segm = cv.resize(segm, (1024, 512), interpolation=cv.INTER_NEAREST)
  57. img_file = self.rreplace(self.img_dir + segm_file[len(self.segm_dir):], 'gtFine_color', 'leftImg8bit')
  58. assert os.path.exists(img_file)
  59. img = cv.imread(img_file, cv.IMREAD_COLOR)[:, :, ::-1]
  60. img = cv.resize(img, (1024, 512))
  61. self.i += 1
  62. gt = self.color_to_gt(segm, self.colors)
  63. img = self.data_prepoc.process(img)
  64. return img, gt
  65. else:
  66. self.i = 0
  67. raise StopIteration
  68. def get_num_classes(self):
  69. return len(self.colors)
  70. @staticmethod
  71. def locate(pattern, root_path):
  72. for path, dirs, files in os.walk(os.path.abspath(root_path)):
  73. for filename in fnmatch.filter(files, pattern):
  74. yield os.path.join(path, filename)
  75. @staticmethod
  76. def rreplace(s, old, new, occurrence=1):
  77. li = s.rsplit(old, occurrence)
  78. return new.join(li)
  79. class TorchModel(Framework):
  80. net = object
  81. def __init__(self, model_file):
  82. self.net = load_lua(model_file)
  83. def get_name(self):
  84. return 'Torch'
  85. def get_output(self, input_blob):
  86. tensor = torch.FloatTensor(input_blob)
  87. out = self.net.forward(tensor).numpy()
  88. return out
  89. class DnnTorchModel(DnnCaffeModel):
  90. net = cv.dnn.Net()
  91. def __init__(self, model_file):
  92. self.net = cv.dnn.readNetFromTorch(model_file)
  93. def get_output(self, input_blob):
  94. self.net.setBlob("", input_blob)
  95. self.net.forward()
  96. return self.net.getBlob(self.net.getLayerNames()[-1])
  97. if __name__ == "__main__":
  98. parser = argparse.ArgumentParser()
  99. parser.add_argument("--imgs_dir", help="path to Cityscapes validation images dir, imgsfine/leftImg8bit/val")
  100. parser.add_argument("--segm_dir", help="path to Cityscapes dir with segmentation, gtfine/gtFine/val")
  101. parser.add_argument("--model", help="path to torch model, download it here: "
  102. "https://www.dropbox.com/sh/dywzk3gyb12hpe5/AAD5YkUa8XgMpHs2gCRgmCVCa")
  103. parser.add_argument("--log", help="path to logging file")
  104. args = parser.parse_args()
  105. prep = NormalizePreproc()
  106. df = CityscapesDataFetch(args.imgs_dir, args.segm_dir, prep)
  107. fw = [TorchModel(args.model),
  108. DnnTorchModel(args.model)]
  109. segm_eval = SemSegmEvaluation(args.log)
  110. segm_eval.process(fw, df)