mnist_dataset.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import os
import numpy as np

# define input and output data size: each image is 28*28 = 784 pixels, 10 digit classes
INPUT_NODE = 784
OUTPUT_NODE = 10
# params for the neural network
LAYER1_NODE = 500
BATCH_SIZE = 1000
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.999
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 100000
MOVING_AVERAGE_DECAY = 0.99


# calc the result of forward propagation (original method)
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
    # use the current variable values when no moving-average class is given
    if avg_class is None:
        layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
        return tf.matmul(layer1, weights2) + biases2
    else:
        # otherwise evaluate the network with the shadow (averaged) variables
        layer1 = tf.nn.relu(
            tf.matmul(input_tensor, avg_class.average(weights1)) + avg_class.average(biases1))
        return tf.matmul(layer1, avg_class.average(weights2)) + avg_class.average(biases2)
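

# For illustration, a minimal NumPy sketch of the same forward pass
# (hypothetical helper, not used by the training code below):
def inference_numpy(input_batch, weights1, biases1, weights2, biases2):
    # layer1 = relu(x * W1 + b1); logits = layer1 * W2 + b2
    # (softmax is applied later, inside the loss function)
    layer1 = np.maximum(np.matmul(input_batch, weights1) + biases1, 0.0)
    return np.matmul(layer1, weights2) + biases2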


# training process
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name="y-input")
    # generate hidden layer params
    weight1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE], stddev=0.1))
    biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
    # generate output layer params
    weight2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE], stddev=0.1))
    biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
    # forward propagation without moving averages
    y = inference(x, None, weight1, biases1, weight2, biases2)
    # stores the number of training steps; not trainable itself
    global_step = tf.Variable(0, trainable=False)
    # exponential moving average (EMA) over the weights makes predictions more robust
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
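    # What apply() maintains: one shadow copy per trainable variable, updated as
    #   shadow = decay * shadow + (1 - decay) * variable.
    # Because global_step is passed as num_updates, TensorFlow actually uses
    # min(MOVING_AVERAGE_DECAY, (1 + step) / (10 + step)), so the average
    # adapts faster early in training.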
    # forward propagation that uses the moving-average variables for evaluation
    average_y = inference(x, variable_averages, weight1, biases1, weight2, biases2)
    # the sparse variant expects class indices, so convert one-hot labels via argmax
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    # calc cross_entropy mean for the current batch
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
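    # per example this loss is -log(softmax(logits)[label]);
    # reduce_mean then averages it over the batch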
    # calc the L2 regularization loss, applied to the weights only
    # (biases are typically left unregularized)
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    regularization = regularizer(weight1) + regularizer(weight2)
    loss = cross_entropy_mean + regularization
    # learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step / decay_steps)
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
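    # Worked example, assuming the standard MNIST split of 55000 training
    # images (so decay_steps = 55000 / 1000 = 55; staircase defaults to False):
    # at global_step = 5500, lr = 0.8 * 0.999 ** (5500 / 55) = 0.8 * 0.999 ** 100 ~ 0.724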
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # combine backpropagation and the EMA update into a single training op
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name="train")
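    # an equivalent formulation: train_op = tf.group(train_step, variable_averages_op)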
    # evaluation: compare predictions from the moving-average model with the labels
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # prepare the validation dataset, used to monitor training
        validation_feed = {x: mnist.validation.images,
                           y_: mnist.validation.labels}
        # define the test dataset for the final evaluation
        test_feed = {x: mnist.test.images,
                     y_: mnist.test.labels}
        # record validation and test accuracy once every 1000 steps
        validation_result = [0.0] * (TRAINING_STEPS // 1000)
        test_result = [0.0] * (TRAINING_STEPS // 1000)
        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validation_feed)
                validation_result[i // 1000] = validate_acc
                # print("after %d training step(s), validation accuracy using average model is %g" % (i, validate_acc))
                test_acc = sess.run(accuracy, feed_dict=test_feed)
                test_result[i // 1000] = test_acc
                # print("after %d training step(s), test accuracy using average model is %g" % (i, test_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})
        print(validation_result)
        print(test_result)
        # draw a graph of accuracy using matplotlib
        iteration_count = range(0, TRAINING_STEPS, 1000)
        plt.figure(num=1, figsize=(15, 8))
        plt.title("Plot accuracy", size=20)
        plt.xlabel("iteration count", size=14)
        plt.ylabel("accuracy", size=14)
        # annotate the last recorded validation and test accuracy values
        validation_note = [TRAINING_STEPS - 1000, validation_result[TRAINING_STEPS // 1000 - 1]]
        test_note = [TRAINING_STEPS - 1000, test_result[TRAINING_STEPS // 1000 - 1]]
        plt.annotate('validate-' + str(validation_note), xy=(validation_note[0], validation_note[1]),
                     xytext=(validation_note[0] - 1000, validation_note[1] - 0.1),
                     arrowprops=dict(facecolor='black', shrink=0.05))
        plt.annotate('test-' + str(test_note), xy=(test_note[0], test_note[1]),
                     xytext=(test_note[0] + 1000, test_note[1] - 0.07),
                     arrowprops=dict(facecolor='black', shrink=0.05))
        plt.grid(True)
        plt.plot(iteration_count, validation_result, color='b', linestyle='-', marker='o', label='validation data')
        plt.plot(iteration_count, test_result, linestyle='-.', marker='X', label='test data')
        plt.legend(loc="upper left")
        try:
            os.mkdir('images/')
        except OSError:
            print("directory already exists")
        plt.savefig('images/mnist_accuracy_evaluation.png', format='png')
        # display one training image: the flat 784-vector laid out as 28*28 pixels
        img_vector = mnist.train.images[5]
        img_length = int(np.sqrt(INPUT_NODE))
        img = np.ndarray([img_length, img_length])
        for c in range(INPUT_NODE):
            # pixel c lands at row c // img_length, column c % img_length
            img[c // img_length][c % img_length] = img_vector[c]
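        # the loop above is equivalent to a vectorized reshape:
        #   img = img_vector.reshape((img_length, img_length))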
        plt.figure(num=2, figsize=(15, 8))
        plt.imshow(img)
        plt.show()


def main(argv=None):
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    print("basic information of mnist dataset")
    print("mnist training data size: ", mnist.train.num_examples)
    print("mnist validating data size: ", mnist.validation.num_examples)
    print("mnist testing data size: ", mnist.test.num_examples)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()