mnist_optimization.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import os
import numpy as np

# define input and output data size
INPUT_NODE = 784
OUTPUT_NODE = 10
# params for the neural network
LAYER1_NODE = 500
BATCH_SIZE = 1000
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.999
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 100000
MOVING_AVERAGE_DECAY = 0.99
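# NOTE: BATCH_SIZE = 1000 with 100000 steps is slow on a CPU; smaller values
# (e.g. BATCH_SIZE = 100, TRAINING_STEPS = 30000) are a reasonable trade-off
# if you only want to see the accuracy curves.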


def inference(input_tensor, avg_class, reuse=False):
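    """Forward pass of a 784-500-10 fully connected network.

    When avg_class is None the current variable values are used; otherwise
    avg_class is expected to be a tf.train.ExponentialMovingAverage whose
    shadow (averaged) values are read via avg_class.average(). reuse=True
    makes tf.get_variable return the variables created by an earlier call
    instead of allocating new ones.
    """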
    if avg_class is None:
        with tf.variable_scope("layer1", reuse=reuse):
            weights = tf.get_variable("weights", [INPUT_NODE, LAYER1_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.1))
            layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
        with tf.variable_scope("layer2", reuse=reuse):
            weights = tf.get_variable("weights", [LAYER1_NODE, OUTPUT_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.1))
            layer2 = tf.matmul(layer1, weights) + biases
        return layer2
    else:
        with tf.variable_scope("layer1", reuse=reuse):
            weights = tf.get_variable("weights", [INPUT_NODE, LAYER1_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.1))
            layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights)) + avg_class.average(biases))
        with tf.variable_scope("layer2", reuse=reuse):
            weights = tf.get_variable("weights", [LAYER1_NODE, OUTPUT_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.1))
            layer2 = tf.matmul(layer1, avg_class.average(weights)) + avg_class.average(biases)
        return layer2


# training process
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name="y-input")
    # forward propagation without the moving averages
    y = inference(x, None)
    # used to store training cycles
    global_step = tf.Variable(0, trainable=False)
    # define EMA function to increase robustness when predicting
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
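    # ExponentialMovingAverage keeps a shadow copy of every trainable variable,
    # updated as  shadow = decay * shadow + (1 - decay) * variable.
    # Passing global_step as num_updates lowers the effective decay to
    # min(decay, (1 + step) / (10 + step)) early on, so the shadows track the
    # weights more closely at the start of training.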
    # forward propagation with the moving-average values, reusing the
    # variables created by the inference call above
    average_y = inference(x, variable_averages, True)
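    # sparse_softmax_cross_entropy_with_logits expects integer class indices as
    # labels, so the one-hot y_ is converted with argmax; y must stay raw
    # (unnormalized) logits because the op applies softmax internally.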
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    # calc cross_entropy mean for the current batch
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # calc L2 regularization loss
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    with tf.variable_scope("", reuse=True):
        regularization = regularizer(tf.get_variable("layer1/weights", [INPUT_NODE, LAYER1_NODE])) + regularizer(
            tf.get_variable("layer2/weights", [LAYER1_NODE, OUTPUT_NODE]))
    loss = cross_entropy_mean + regularization
    # learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ^ (global_step / decay_steps)
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, mnist.train.num_examples // BATCH_SIZE, LEARNING_RATE_DECAY)
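    # decay_steps is num_examples // BATCH_SIZE, i.e. the rate shrinks by a
    # factor of LEARNING_RATE_DECAY about once per epoch; with the default
    # staircase=False the decay is applied continuously rather than stepwise.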
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # combine backward propagation and the EMA value updates
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name="train")
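    # running train_op executes both the gradient step and the EMA update;
    # tf.group(train_step, variable_averages_op) would be equivalent.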
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # prepare validation dataset used to decide when to stop optimization
        validation_feed = {x: mnist.validation.images,
                           y_: mnist.validation.labels}
        # define test dataset for the final evaluation
        test_feed = {x: mnist.test.images,
                     y_: mnist.test.labels}
        validation_result = [0.0] * (TRAINING_STEPS // 1000)
        test_result = [0.0] * (TRAINING_STEPS // 1000)
        for i in range(TRAINING_STEPS):
            # record validation and test accuracy every 1000 steps
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validation_feed)
                validation_result[i // 1000] = validate_acc
                # print("after %d training step(s), validation accuracy using average model is %g" % (i, validate_acc))
                test_acc = sess.run(accuracy, feed_dict=test_feed)
                test_result[i // 1000] = test_acc
                # print("after %d training step(s), test accuracy using average model is %g" % (i, test_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})
        print(validation_result)
        print(test_result)
        saver = tf.train.Saver()
        saver.export_meta_graph("model.ckpt.meta.json", as_text=True)
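        # export_meta_graph only saves the graph structure as a text
        # MetaGraphDef; to persist the trained weights as well, call something
        # like saver.save(sess, "model.ckpt") instead.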
        # draw a graph of accuracy using matplotlib
        iteration_count = range(0, TRAINING_STEPS, 1000)
        plt.figure(num=1, figsize=(15, 8))
        plt.title("Plot accuracy", size=20)
        plt.xlabel("iteration count", size=14)
        plt.ylabel("accuracy", size=14)
        validation_note = [TRAINING_STEPS - 1000, validation_result[TRAINING_STEPS // 1000 - 1]]
        test_note = [TRAINING_STEPS - 1000, test_result[TRAINING_STEPS // 1000 - 1]]
        plt.annotate('validate-' + str(validation_note), xy=(validation_note[0], validation_note[1]),
                     xytext=(validation_note[0] - 1000, validation_note[1] - 0.1),
                     arrowprops=dict(facecolor='black', shrink=0.05))
        plt.annotate('test-' + str(test_note), xy=(test_note[0], test_note[1]),
                     xytext=(test_note[0] + 1000, test_note[1] - 0.07),
                     arrowprops=dict(facecolor='black', shrink=0.05))
        plt.grid(True)
        plt.plot(iteration_count, validation_result, color='b', linestyle='-', marker='o', label='validation data')
        plt.plot(iteration_count, test_result, linestyle='-.', marker='X', label='test data')
        plt.legend(loc="upper left")
        try:
            os.mkdir('images/')
        except OSError:
            print("directory already exists")
        plt.savefig('images/mnist_accuracy_evaluation.png', format='png')
        img_vector = mnist.train.images[5]
        img_length = int(np.sqrt(INPUT_NODE))
        img = np.ndarray([img_length, img_length])
        # print("image size: ", img_length, "*", img_length)
        for c in range(INPUT_NODE):
            # print("image indices: ", c // img_length, "*", c % img_length)
            img[c // img_length][c % img_length] = img_vector[c]
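        # the loop above is just row-major index arithmetic; it is equivalent
        # to img = img_vector.reshape(img_length, img_length)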
        plt.figure(num=2, figsize=(15, 8))
        plt.imshow(img)
        plt.show()


def main(argv=None):
    mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
    print("basic information of mnist dataset")
    print("mnist training data size: ", mnist.train.num_examples)
    print("mnist validating data size: ", mnist.validation.num_examples)
    print("mnist testing data size: ", mnist.test.num_examples)
    train(mnist)


if __name__ == '__main__':
    tf.app.run()
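# Usage: python mnist_optimization.py
# input_data.read_data_sets downloads the MNIST archives into ./MNIST_data
# the first time the script runs.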