inception_transfer.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. # -*- utf-8 -*-
  2. import glob
  3. import os.path
  4. import time
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import tensorflow as tf
  8. from tensorflow.python.platform import gfile
  9. import tensorflow.contrib.slim as slim
  10. import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
  11. INPUT_DATA = 'preprocess/flower_processed_data.npy'
  12. TRAIN_FILE = 'model/'
  13. CKPT_FILE = '../../dataset/inception_v3.ckpt'
  14. # params
  15. LEARNING_RATE = 0.0001
  16. STEPS = 1000
  17. BATCH = 32
  18. N_CLASSES = 5
  19. # lasers don't load from ckpt, i.e. the last fc layer
  20. CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
  21. TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
  22. TRAINING = False
  23. flower_label = ["daisy雏菊", "roses玫瑰", "tulips郁金香", "sunflowers向日葵", "dandelion蒲公英"]
  24. def get_tuned_variables():
  25. exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
  26. variables_to_restore = []
  27. # enumerate params in v3 model, check if it need to be loaded
  28. for var in slim.get_model_variables():
  29. excluded = False
  30. for exclusion in exclusions:
  31. if var.op.name.startswith(exclusion):
  32. excluded = True
  33. break
  34. if not excluded:
  35. variables_to_restore.append(var)
  36. return variables_to_restore
  37. def get_trainable_variables():
  38. scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]
  39. variables_to_train = []
  40. for scope in scopes:
  41. variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
  42. variables_to_train.extend(variables)
  43. return variables_to_train
  44. def main():
  45. # processed_data = np.load("preprocess/test_flower.npy", allow_pickle=True)
  46. # test_images = processed_data[0]
  47. # test_labels = processed_data[1]
  48. # load preprocessed data
  49. processed_data = np.load(INPUT_DATA, allow_pickle=True)
  50. training_images = processed_data[0]
  51. n_training_example = len(training_images)
  52. training_labels = processed_data[1]
  53. # np.save("preprocess/training_flower.npy", np.asarray([training_images, training_labels]))
  54. validation_images = processed_data[2]
  55. validation_labels = processed_data[3]
  56. # np.save("preprocess/validation_flower.npy", np.asarray([validation_images, validation_labels]))
  57. test_images = processed_data[4]
  58. test_labels = processed_data[5]
  59. # np.save("preprocess/test_flower.npy", np.asarray([test_images, test_labels]))
  60. print("%d training examples, %d validation examples and %d testing examples." % (
  61. n_training_example, len(validation_labels), len(test_labels)))
  62. # define inputs
  63. images = tf.placeholder(
  64. tf.float32, [None, 299, 299, 3], name='input_images')
  65. labels = tf.placeholder(tf.int64, [None], name='labels')
  66. # define model
  67. with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
  68. logits, _ = inception_v3.inception_v3(images, num_classes=N_CLASSES, is_training=False)
  69. # get trainable variable
  70. trainable_variables = get_trainable_variables()
  71. # define cross entropy
  72. tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
  73. train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(tf.losses.get_total_loss())
  74. # calc accuracy
  75. with tf.name_scope('evaluation'):
  76. prediction = tf.argmax(logits, 1)
  77. correct_answer = labels
  78. correct_prediction = tf.equal(tf.argmax(logits, 1), labels)
  79. evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  80. # define func to load model
  81. load_fn = slim.assign_from_checkpoint_fn(
  82. CKPT_FILE,
  83. get_tuned_variables(),
  84. ignore_missing_vars=True
  85. )
  86. # define saver
  87. saver = tf.train.Saver()
  88. config = tf.ConfigProto(allow_soft_placement=True)
  89. # gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.25)
  90. config.gpu_options.allow_growth = True
  91. with tf.Session(config=config) as sess:
  92. # init
  93. init = tf.global_variables_initializer()
  94. sess.run(init)
  95. ckpt = tf.train.get_checkpoint_state(
  96. TRAIN_FILE
  97. )
  98. if ckpt and ckpt.model_checkpoint_path:
  99. saver.restore(sess, ckpt.model_checkpoint_path)
  100. else:
  101. # load origin model
  102. print('loading tuned variables from %s' % CKPT_FILE)
  103. load_fn(sess)
  104. start = 0
  105. end = BATCH
  106. if TRAINING:
  107. for i in range(STEPS):
  108. sess.run(train_step, feed_dict={
  109. images: training_images[start:end],
  110. labels: training_labels[start:end]
  111. })
  112. if i % 20 == 0 or i + 1 == STEPS:
  113. saver.save(sess, TRAIN_FILE, global_step=i)
  114. validation_accuracy = sess.run(evaluation_step, feed_dict={
  115. images: validation_images,
  116. labels: validation_labels
  117. })
  118. print('step %d: validation accuracy = %.1f%%' % (i, validation_accuracy * 100.0))
  119. start = end
  120. if start == n_training_example:
  121. start = 0
  122. end = start + BATCH
  123. if end > n_training_example:
  124. end = n_training_example
  125. # test accuracy
  126. test_acccuracy = sess.run(evaluation_step, feed_dict={
  127. images: test_images,
  128. labels: test_labels
  129. })
  130. print('final test accuracy = %.1f%%' % (test_acccuracy * 100.0))
  131. else:
  132. while True:
  133. index = np.random.randint(0, len(test_labels) - 2)
  134. # test accuracy
  135. prediction_score, correct_answer_score = sess.run([prediction, correct_answer], feed_dict={
  136. images: test_images[index:index+1],
  137. labels: test_labels[index:index+1]
  138. })
  139. result = [(flower_label[x]+str(x)) for x in prediction_score]
  140. answer = [(flower_label[x]+str(x)) for x in correct_answer_score]
  141. # print(result)
  142. # print(answer)
  143. plt.imshow(test_images[index])
  144. print('test result: %s, correct answer: %s' % (
  145. result, answer))
  146. plt.show()
  147. time.sleep(3)
  148. if __name__ == '__main__':
  149. main()