mnist_optimization.py
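"""Train a two-layer fully connected network on MNIST.

The script combines several optimization techniques: exponential
learning-rate decay, L2 regularization of the weight matrices, and an
exponential moving average (EMA) over the trainable variables. It targets
the TensorFlow 1.x graph API.
"""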

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as plt
import os
import numpy as np

# mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# print("basic information of mnist dataset")
# print("mnist training data size: ", mnist.train.num_examples)
# print("mnist validating data size: ", mnist.validation.num_examples)
# print("mnist testing data size: ", mnist.test.num_examples)
# print("mnist example training data: ", mnist.train.images[0])
# print("mnist example training data label: ", mnist.train.labels[0])

# define input and output data size
INPUT_NODE = 784   # 28 * 28 pixels per image
OUTPUT_NODE = 10   # ten digit classes

# params for the neural network
LAYER1_NODE = 500
BATCH_SIZE = 1000
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.999
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 200000
MOVING_AVERAGE_DECAY = 0.99
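# With decay 0.99, each EMA shadow variable is updated per training step as
#   shadow = 0.99 * shadow + 0.01 * variable
# (TensorFlow additionally caps the decay early in training when a step
# counter is passed to ExponentialMovingAverage, as done below).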
def inference(input_tensor, avg_class, reuse=False):
    if avg_class is None:
        with tf.variable_scope("layer1", reuse=reuse):
            weights = tf.get_variable("weights", [INPUT_NODE, LAYER1_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.1))
            layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
        with tf.variable_scope("layer2", reuse=reuse):
            weights = tf.get_variable("weights", [LAYER1_NODE, OUTPUT_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.1))
            layer2 = tf.matmul(layer1, weights) + biases
        return layer2
    else:
        with tf.variable_scope("layer1", reuse=reuse):
            weights = tf.get_variable("weights", [INPUT_NODE, LAYER1_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.1))
            layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights)) + avg_class.average(biases))
        with tf.variable_scope("layer2", reuse=reuse):
            weights = tf.get_variable("weights", [LAYER1_NODE, OUTPUT_NODE],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.1))
            layer2 = tf.matmul(layer1, avg_class.average(weights)) + avg_class.average(biases)
        return layer2
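# Note on variable sharing: both branches resolve "layer1/weights" etc. through
# tf.get_variable, so a second call with reuse=True returns the variables the
# first call created instead of allocating new ones. A minimal sketch of the
# pattern (names here are illustrative, not part of this script):
#
#   logits = inference(x, None)                    # creates the variables
#   shadow_logits = inference(x, ema, reuse=True)  # reuses them via the EMA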
# training process
def train(mnist):
    x = tf.placeholder(tf.float32, [None, INPUT_NODE], name="x-input")
    y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name="y-input")
    # forward propagation without the moving averages; this call creates the variables
    y = inference(x, None)
    # used to store training cycles
    global_step = tf.Variable(0, trainable=False)
    # define EMA function to increase robustness when predicting
    variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_averages_op = variable_averages.apply(tf.trainable_variables())
    # forward propagation with the moving average function, reusing the same variables
    average_y = inference(x, variable_averages, True)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
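    # sparse_softmax_cross_entropy_with_logits expects integer class indices in
    # [0, num_classes), so the one-hot labels fed into y_ are converted with
    # tf.argmax; feeding the one-hot rows directly would instead require
    # tf.nn.softmax_cross_entropy_with_logits.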
    # calc cross_entropy mean for the current batch
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # calc the L2 regularization loss
    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    with tf.variable_scope("", reuse=True):
        regularization = regularizer(tf.get_variable("layer1/weights", [INPUT_NODE, LAYER1_NODE])) + regularizer(
            tf.get_variable("layer2/weights", [LAYER1_NODE, OUTPUT_NODE]))
    loss = cross_entropy_mean + regularization
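    # Only the weight matrices are penalized, not the biases, which is the
    # usual convention. tf.contrib.layers.l2_regularizer(s) computes
    # s * tf.nn.l2_loss(w), i.e. s * sum(w ** 2) / 2.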
    # learning_rate = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ^ (global_step / decay_steps)
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE, global_step, mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
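    # decay_steps is one epoch's worth of batches here, so the rate shrinks by
    # a factor of LEARNING_RATE_DECAY roughly once per epoch. Because staircase
    # defaults to False, the decay is applied continuously rather than in steps.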
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
    # combine backward propagation and the EMA update into a single op
    with tf.control_dependencies([train_step, variable_averages_op]):
        train_op = tf.no_op(name="train")
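    # Running train_op forces both dependencies to run first; the no_op itself
    # does nothing. tf.group(train_step, variable_averages_op) is an equivalent
    # idiom.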
    # evaluate accuracy with the moving-average model (average_y), which the
    # accuracy printouts below refer to as the "average model"
    correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
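    # correct_prediction is a boolean vector; casting to float32 turns it into
    # 0.0/1.0 values, so its mean is exactly the fraction of correctly
    # classified examples in the fed batch.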
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # prepare validation dataset used to monitor (and potentially stop) optimization
        validation_feed = {x: mnist.validation.images,
                           y_: mnist.validation.labels}
        # define test dataset for the final evaluation
        test_feed = {x: mnist.test.images,
                     y_: mnist.test.labels}
        validation_result = [0.0] * (TRAINING_STEPS // 1000)
        test_result = [0.0] * (TRAINING_STEPS // 1000)
        for i in range(TRAINING_STEPS):
            if i % 1000 == 0:
                validate_acc = sess.run(accuracy, feed_dict=validation_feed)
                validation_result[i // 1000] = validate_acc
                # print("after %d training step(s), validation accuracy using average model is %g" % (i, validate_acc))
                test_acc = sess.run(accuracy, feed_dict=test_feed)
                test_result[i // 1000] = test_acc
                # print("after %d training step(s), test accuracy using average model is %g" % (i, test_acc))
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            sess.run(train_op, feed_dict={x: xs, y_: ys})
        print(validation_result)
        print(test_result)
        saver = tf.train.Saver()
        saver.export_meta_graph("model.ckpt.meta.json", as_text=True)
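        # export_meta_graph only serializes the graph definition (MetaGraphDef);
        # to persist the trained weights as well, saver.save(sess, "model.ckpt")
        # would be the usual call.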
    dispImg(validation_result, test_result, "with EMA")
    # img_vector = mnist.train.images[5]
    # img_length = int(np.sqrt(INPUT_NODE))
    # img = np.ndarray([img_length, img_length])
    # # print("image size: ", img_length, "*", img_length)
    # for c in range(INPUT_NODE):
    #     # print("image indices: ", c // img_length, "*", c % img_length)
    #     img[c // img_length][c % img_length] = img_vector[c]
    # plt.figure(num=2, figsize=(15, 8))
    # plt.imshow(img)
    plt.show()
def dispImg(validation_result, test_result, filename):
    # draw a graph of accuracy using matplotlib
    iteration_count = range(0, TRAINING_STEPS, 1000)
    plt.figure(num=1, figsize=(15, 8))
    plt.title("Plot accuracy", size=20)
    plt.xlabel("iteration count", size=14)
    plt.ylabel("accuracy", size=14)
    validation_note = [TRAINING_STEPS - 1000, validation_result[TRAINING_STEPS // 1000 - 1]]
    test_note = [TRAINING_STEPS - 1000, test_result[TRAINING_STEPS // 1000 - 1]]
    plt.annotate('validate-' + str(validation_note), xy=(validation_note[0], validation_note[1]),
                 xytext=(validation_note[0] - 1000, validation_note[1] - 0.1),
                 arrowprops=dict(facecolor='black', shrink=0.05))
    plt.annotate('test-' + str(test_note), xy=(test_note[0], test_note[1]),
                 xytext=(test_note[0] + 1000, test_note[1] - 0.07),
                 arrowprops=dict(facecolor='black', shrink=0.05))
    plt.grid(True)
    plt.plot(iteration_count, validation_result, color='b', linestyle='-', marker='o', label='validation data')
    plt.plot(iteration_count, test_result, linestyle='-.', marker='X', label='test data')
    plt.legend(loc="upper left")
    try:
        os.mkdir('images/')
    except OSError:
        print("directory already exists")
    plt.savefig('images/%s.png' % filename, format='png')
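# On Python 3, os.makedirs('images/', exist_ok=True) expresses the same
# "create the directory if missing" intent without the try/except above.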
def main(argv=None):
    mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
    print("basic information of mnist dataset")
    print("mnist training data size: ", mnist.train.num_examples)
    print("mnist validating data size: ", mnist.validation.num_examples)
    print("mnist testing data size: ", mnist.test.num_examples)
    train(mnist)
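# tf.app.run() parses command-line flags and then calls main(); it is the
# conventional entry point for TensorFlow 1.x scripts.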
if __name__ == '__main__':
    tf.app.run()