mnist_eval.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # -*- coding: utf8 -*-
  2. import time
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import matplotlib.pyplot as plt
  6. import numpy as np
  7. import mnist_inference
  8. import mnist_train
  9. from numpy.random import RandomState
  10. import os
  11. # generate new random dataset for test in 3 secs after close figure window manually
  12. EVAL_INTERVAL_SECS = 3
  13. NUMBER_OF_SAMPLES = 36
  14. FIG_ROWS = 3
  15. # display images and recognition result rather than accuracy diagram
  16. def evaluation(mnist):
  17. with tf.Graph().as_default() as g:
  18. x = tf.placeholder(tf.float32, [NUMBER_OF_SAMPLES,
  19. mnist_inference.IMAGE_SIZE,
  20. mnist_inference.IMAGE_SIZE,
  21. mnist_inference.NUM_CHANNELS], name='x-input')
  22. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='input-y')
  23. # replace accuracy with actual recognition result
  24. y = mnist_inference.inference(x, False, None)
  25. indices = tf.argmax(y, 1)
  26. correct_indices = tf.argmax(y_, 1)
  27. # correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  28. # accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  29. variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
  30. variables_to_restore = variable_averages.variables_to_restore()
  31. saver = tf.train.Saver(variables_to_restore)
  32. while True:
  33. # configure TF to allocate mem properly, rather than consume all GPU mem
  34. config = tf.ConfigProto(allow_soft_placement=True)
  35. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
  36. config.gpu_options.allow_growth = True
  37. with tf.Session(config=config) as sess:
  38. ckpt = tf.train.get_checkpoint_state(
  39. mnist_train.MODEL_SAVE_PATH
  40. )
  41. if ckpt and ckpt.model_checkpoint_path:
  42. saver.restore(sess, ckpt.model_checkpoint_path)
  43. rdm = RandomState(int(time.time()))
  44. sample_index = rdm.randint(0, mnist.validation.num_examples - NUMBER_OF_SAMPLES)
  45. xs = mnist.validation.images[sample_index:sample_index + NUMBER_OF_SAMPLES]
  46. validation_feed = {
  47. x: np.reshape(xs, (NUMBER_OF_SAMPLES,
  48. mnist_inference.IMAGE_SIZE,
  49. mnist_inference.IMAGE_SIZE,
  50. mnist_inference.NUM_CHANNELS)),
  51. y_: mnist.validation.labels[sample_index:sample_index + NUMBER_OF_SAMPLES]}
  52. # txs = mnist.test.images
  53. # test_feed = {
  54. # x: np.reshape(txs, (mnist.test.num_examples,
  55. # mnist_inference.IMAGE_SIZE,
  56. # mnist_inference.IMAGE_SIZE,
  57. # mnist_inference.NUM_CHANNELS)),
  58. # y_: mnist.test.labels}
  59. # # define accuracy score, generate image
  60. # accuracy_score = sess.run(accuracy, feed_dict=test_feed)
  61. # get global step from file name
  62. global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
  63. # print("after %s training step(s), validation accuracy = %g" % (global_step, accuracy_score))
  64. indices_score, correct_indices_score = sess.run(
  65. [indices, correct_indices], feed_dict=validation_feed)
  66. print("after %s training step(s), validation result = \n%s\n, correct answer: \n%s" \
  67. % (global_step, indices_score, correct_indices_score))
  68. fig = plt.figure(1)
  69. fig.set_size_inches(15, 6)
  70. for n in range(1, NUMBER_OF_SAMPLES + 1):
  71. fig.add_subplot(FIG_ROWS, (NUMBER_OF_SAMPLES / FIG_ROWS + 1), n)
  72. plt.title("predict: [%s]\nanswer: [%s]"
  73. % (indices_score[n - 1], correct_indices_score[n - 1]))
  74. plt.imshow(mnist.validation.images[sample_index + n - 1].reshape(28, 28))
  75. # fig.add_subplot(2, 3, 1)
  76. # plt.imshow(mnist.validation.images[sample_index].reshape(28, 28))
  77. # fig.add_subplot(2, 3, 2)
  78. # plt.imshow(mnist.validation.images[sample_index + 1].reshape(28, 28))
  79. # fig.add_subplot(2, 3, 3)
  80. # plt.imshow(mnist.validation.images[sample_index + 2].reshape(28, 28))
  81. # fig.add_subplot(2, 3, 4)
  82. # plt.imshow(mnist.validation.images[sample_index + 3].reshape(28, 28))
  83. # fig.add_subplot(2, 3, 5)
  84. # plt.imshow(mnist.validation.images[sample_index + 4].reshape(28, 28))
  85. # fig.add_subplot(2, 3, 6)
  86. # plt.imshow(mnist.validation.images[sample_index + 5].reshape(28, 28))
  87. plt.subplots_adjust(
  88. top=0.95, bottom=0.05, left=0.05, right=0.95, hspace=0.35, wspace=0.6)
  89. try:
  90. os.mkdir('images/')
  91. except:
  92. print("directory already exist")
  93. plt.savefig('images/mnist_result_evaluation.jpg', format='jpg')
  94. plt.show()
  95. else:
  96. print("no checkpoint file found")
  97. return
  98. time.sleep(EVAL_INTERVAL_SECS)
  99. def main(argv=None):
  100. mnist = input_data.read_data_sets('../MNIST_data', one_hot=True)
  101. print("basic information of mnist dataset")
  102. print("mnist training data size: ", mnist.train.num_examples)
  103. print("mnist validating data size: ", mnist.validation.num_examples)
  104. print("mnist testing data size: ", mnist.test.num_examples)
  105. # print("mnist example training data: ", mnist.train.images[0])
  106. # print("mnist example training data label", mnist.train.labels[0])
  107. evaluation(mnist)
  108. if __name__ == '__main__':
  109. tf.app.run()