mnist_eval.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. # -*- coding: utf8 -*-
  2. import time
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import matplotlib.pyplot as plt
  6. import mnist_inference
  7. import mnist_train
  8. from numpy.random import RandomState
  9. import os
  10. # generate new random dataset for test in 3 secs after close figure window manually
  11. EVAL_INTERVAL_SECS = 3
  12. NUMBER_OF_SAMPLES = 36
  13. FIG_ROWS = 3
  14. # display images and recognition result rather than accuracy diagram
  15. def evaluation(mnist):
  16. with tf.Graph().as_default() as g:
  17. x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='input-x')
  18. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='input-y')
  19. # move sample picking into each cycle
  20. # rdm = RandomState(int(time.time()))
  21. # sample_index = rdm.randint(0, mnist.validation.num_examples)
  22. # validation_feed = {
  23. # x: mnist.validation.images[sample_index:sample_index + 6],
  24. # y_: mnist.validation.labels[sample_index:sample_index + 6]}
  25. # replace accuracy with actual recognition result
  26. y = mnist_inference.inference(x, None)
  27. indices = tf.argmax(y, 1)
  28. correct_indices = tf.argmax(y_, 1)
  29. # correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  30. # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  31. variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
  32. variables_to_restore = variable_averages.variables_to_restore()
  33. saver = tf.train.Saver(variables_to_restore)
  34. while True:
  35. # configure TF to allocate mem properly, rather than consume all GPU mem
  36. config = tf.ConfigProto(allow_soft_placement=True)
  37. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
  38. config.gpu_options.allow_growth = True
  39. with tf.Session(config=config) as sess:
  40. ckpt = tf.train.get_checkpoint_state(
  41. mnist_train.MODEL_SAVE_PATH
  42. )
  43. if ckpt and ckpt.model_checkpoint_path:
  44. saver.restore(sess, ckpt.model_checkpoint_path)
  45. rdm = RandomState(int(time.time()))
  46. sample_index = rdm.randint(0, mnist.validation.num_examples - NUMBER_OF_SAMPLES)
  47. validation_feed = {
  48. x: mnist.validation.images[sample_index:sample_index + NUMBER_OF_SAMPLES],
  49. y_: mnist.validation.labels[sample_index:sample_index + NUMBER_OF_SAMPLES]}
  50. # get global step from file name
  51. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  52. indices_score, correct_indices_score = sess.run(
  53. [indices, correct_indices], feed_dict=validation_feed)
  54. # accuracy_score = sess.run(accuracy, feed_dict=validation_feed)
  55. # print "after %s training step(s), validation accuracy = %g" % (global_step, accuracy_score)
  56. print("after %s training step(s), validation result = \n%s\n, correct answer: \n%s" \
  57. % (global_step, indices_score, correct_indices_score))
  58. fig = plt.figure(1)
  59. fig.set_size_inches(15,6)
  60. for n in range(1, NUMBER_OF_SAMPLES + 1):
  61. fig.add_subplot(FIG_ROWS, (NUMBER_OF_SAMPLES / FIG_ROWS + 1), n)
  62. plt.title("predict: [%s]\nanswer: [%s]"
  63. % (indices_score[n - 1], correct_indices_score[n - 1]))
  64. plt.imshow(mnist.validation.images[sample_index + n - 1].reshape(28, 28))
  65. # fig.add_subplot(2, 3, 1)
  66. # plt.imshow(mnist.validation.images[sample_index].reshape(28, 28))
  67. # fig.add_subplot(2, 3, 2)
  68. # plt.imshow(mnist.validation.images[sample_index + 1].reshape(28, 28))
  69. # fig.add_subplot(2, 3, 3)
  70. # plt.imshow(mnist.validation.images[sample_index + 2].reshape(28, 28))
  71. # fig.add_subplot(2, 3, 4)
  72. # plt.imshow(mnist.validation.images[sample_index + 3].reshape(28, 28))
  73. # fig.add_subplot(2, 3, 5)
  74. # plt.imshow(mnist.validation.images[sample_index + 4].reshape(28, 28))
  75. # fig.add_subplot(2, 3, 6)
  76. # plt.imshow(mnist.validation.images[sample_index + 5].reshape(28, 28))
  77. plt.subplots_adjust(
  78. top=0.95, bottom=0.05, left=0.05, right=0.95, hspace=0.5, wspace=0.55)
  79. try:
  80. os.mkdir('images/')
  81. except:
  82. print("directory already exist")
  83. plt.savefig('images/mnist_result_evaluation.jpg', format='jpg')
  84. plt.show()
  85. else:
  86. print("no checkpoint file found")
  87. return
  88. time.sleep(EVAL_INTERVAL_SECS)
  89. def main(argv=None):
  90. mnist = input_data.read_data_sets('../MNIST_data', one_hot=True)
  91. evaluation(mnist)
  92. if __name__ == '__main__':
  93. tf.app.run()