# inception_preprocessing.py
# (web-scrape residue — original header showed "3.2 KB" and a run of line numbers)
  1. # -*- utf-8 -*-
  2. import glob
  3. import os.path
  4. import tensorflow as tf
  5. import numpy as np
  6. from tensorflow.python.platform import gfile
  7. INPUT_DATA = "../../dataset/flower_photos"
  8. OUTPUT_FILE = "preprocess/flower_processed_data.npy"
  9. # test and validation ratio
  10. VALIDATION_PERCENTAGE = 10
  11. TEST_PERCENTAGE = 10
  12. def create_image_lists(sess, testing_percentage, validation_percentage):
  13. # '../../dataset/flower_photos', '../../dataset/flower_photos/daisy', '../../dataset/flower_photos/tulips',
  14. # '../../dataset/flower_photos/dandelion', '../../dataset/flower_photos/sunflowers',
  15. # '../../dataset/flower_photos/roses']
  16. subdirs = [x[0] for x in os.walk(INPUT_DATA)]
  17. # print(subdirs)
  18. is_root_dir = True
  19. count = 0
  20. # init datasets
  21. training_images = []
  22. training_labels = []
  23. testing_images = []
  24. testing_labels = []
  25. validation_images = []
  26. validation_labels = []
  27. current_label = 0
  28. # read all subdirs
  29. for sub_dir in subdirs:
  30. if is_root_dir:
  31. is_root_dir = False
  32. continue
  33. extensions = ['jpg', 'jpeg', 'JPG', 'JPEG']
  34. file_list = []
  35. dir_name = os.path.basename(sub_dir)
  36. # print(dir_name)
  37. for extension in extensions:
  38. # find all images in sub_dir
  39. file_glob = os.path.join(INPUT_DATA, dir_name, '*.' + extension)
  40. file_list.extend(glob.glob(file_glob))
  41. if not file_list:
  42. continue
  43. # deal with images
  44. for file_name in file_list:
  45. print(str(current_label) + file_name + "\t\t" + str(count))
  46. count += 1
  47. image_raw_data = gfile.FastGFile(file_name, 'rb').read()
  48. image = tf.image.decode_jpeg(image_raw_data)
  49. if image.dtype != tf.float32:
  50. image = tf.image.convert_image_dtype(image, dtype=tf.float32)
  51. image = tf.image.resize_images(image, [299, 299])
  52. image_value = sess.run(image)
  53. # split dataset randomly
  54. chance = np.random.randint(100)
  55. if chance < validation_percentage:
  56. validation_images.append(image_value)
  57. validation_labels.append(current_label)
  58. elif chance < (validation_percentage + testing_percentage):
  59. testing_images.append(image_value)
  60. testing_labels.append(current_label)
  61. else:
  62. training_images.append(image_value)
  63. training_labels.append(current_label)
  64. current_label += 1
  65. state = np.random.get_state()
  66. np.random.shuffle(training_images)
  67. np.random.set_state(state)
  68. np.random.shuffle(training_labels)
  69. return np.asarray(
  70. [training_images, training_labels, validation_images, validation_labels, testing_images, testing_labels])
  71. def main():
  72. with tf.Session() as sess:
  73. processed_data = create_image_lists(sess, TEST_PERCENTAGE, VALIDATION_PERCENTAGE)
  74. np.save(OUTPUT_FILE, processed_data)
  75. if __name__ == '__main__':
  76. main()
# Demo of the paired-shuffle trick used in create_image_lists: saving the RNG
# state before the first shuffle and restoring it before the second applies
# the SAME permutation to both lists, keeping a[i] paired with b[i].
# a = [1, 2, 3, 4, 5, 6, 7, 8, 9]
# b = [9, 8, 7, 6, 5, 4, 3, 2, 1]
# state = np.random.get_state()
# np.random.shuffle(a)
# np.random.set_state(state)
# np.random.shuffle(b)
# print(a)
# print(b)  # -> a and b are permuted identically