mnist_train.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # -*- coding: utf8 -*-
  2. import os
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import numpy as np
  6. import matplotlib.pyplot as plt
  7. import mnist_inference
  8. # define input, output, batch and training params
  9. BATCH_SIZE = 50
  10. LEARNING_RATE_BASE = 0.8
  11. LEARNING_RATE_DECAY = 0.99
  12. REGULARIZATION_RATE = 0.0001
  13. TRAINING_STEPS = 10000
  14. MOVING_AVERAGE_DECAY = 0.99
  15. MODEL_SAVE_PATH = "model/"
  16. MODEL_NAME = "model.ckpt"
  17. score_filename = "accuracy_score_cnn.txt"
  18. # train a convolutional neural network
  19. def train(mnist, continue_train=False):
  20. x = tf.placeholder(tf.float32, [BATCH_SIZE,
  21. mnist_inference.IMAGE_SIZE,
  22. mnist_inference.IMAGE_SIZE,
  23. mnist_inference.NUM_CHANNELS], name='x-input')
  24. y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
  25. regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
  26. y = mnist_inference.inference(x, True, regularizer)
  27. global_step = tf.Variable(0, trainable=False)
  28. # moving average, cross entropy, loss function with regularization and learning rate
  29. variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
  30. variable_average_op = variable_average.apply(tf.trainable_variables())
  31. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
  32. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  33. loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
  34. learning_rate = tf.train.exponential_decay(
  35. LEARNING_RATE_BASE,
  36. global_step,
  37. mnist.train.num_examples / BATCH_SIZE,
  38. LEARNING_RATE_DECAY
  39. )
  40. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  41. with tf.control_dependencies([train_step, variable_average_op]):
  42. train_op = tf.no_op(name='train')
  43. # initialize persistence class
  44. saver = tf.train.Saver()
  45. config = tf.ConfigProto(allow_soft_placement=True)
  46. gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4)
  47. config.gpu_options.allow_growth = True
  48. with tf.Session(config=config) as sess:
  49. if continue_train:
  50. ckpt = tf.train.get_checkpoint_state(
  51. MODEL_SAVE_PATH
  52. )
  53. if ckpt and ckpt.model_checkpoint_path:
  54. saver.restore(sess, ckpt.model_checkpoint_path)
  55. else:
  56. sess.run(tf.global_variables_initializer())
  57. # create directory
  58. try:
  59. os.mkdir(MODEL_SAVE_PATH)
  60. except:
  61. print("directory already exist")
  62. # define accuracy
  63. correct_prediction = tf.equal(tf.arg_max(y, 1), tf.arg_max(y_, 1))
  64. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  65. test_result = list(range(int(TRAINING_STEPS / 1000)))
  66. for i in range(TRAINING_STEPS):
  67. xs, ys = mnist.train.next_batch(BATCH_SIZE)
  68. reshaped_xs = np.reshape(xs, (
  69. BATCH_SIZE,
  70. mnist_inference.IMAGE_SIZE,
  71. mnist_inference.IMAGE_SIZE,
  72. mnist_inference.NUM_CHANNELS))
  73. _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: reshaped_xs, y_: ys})
  74. txs = mnist.test.images[0:BATCH_SIZE]
  75. test_feed = {
  76. x: np.reshape(txs, (BATCH_SIZE,
  77. mnist_inference.IMAGE_SIZE,
  78. mnist_inference.IMAGE_SIZE,
  79. mnist_inference.NUM_CHANNELS)),
  80. y_: mnist.test.labels[0:BATCH_SIZE]}
  81. accuracy_score = sess.run(accuracy, feed_dict=test_feed)
  82. test_result[int(i / 1000)] = accuracy_score
  83. if i % 1000 == 0:
  84. print("after %d training step(s), loss on training batch is %g , validation accuracy = %g" % (
  85. step, loss_value, accuracy_score))
  86. saver.save(
  87. sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step
  88. )
  89. # np.savetxt(score_filename, test_result, fmt="%0.4f")
  90. #
  91. # dispImg(test_result, 'accuracy_score')
  92. # plt.show()
  93. def dispImg(test_result, filename):
  94. # draw a graph of accuracy using matplotlib
  95. iteration_count = range(0, TRAINING_STEPS, 1000)
  96. plt.figure(num=1, figsize=(15, 8))
  97. plt.title("Plot accuracy", size=20)
  98. plt.xlabel("iteration count", size=14)
  99. plt.ylabel("accuracy/%", size=14)
  100. test_note = [TRAINING_STEPS - 1000, test_result[TRAINING_STEPS / 1000 - 1]]
  101. plt.annotate('test-' + str(test_note), xy=(test_note[0], test_note[1]),
  102. xytext=(test_note[0] + 1000, test_note[1] - 0.07), arrowprops=dict(facecolor='black', shrink=0.05))
  103. plt.grid(True)
  104. plt.plot(iteration_count, test_result, linestyle='-.', marker='X', label='test data')
  105. plt.legend(loc="upper left")
  106. try:
  107. os.mkdir('images/')
  108. except:
  109. print("directory already exist")
  110. plt.savefig('images/%s.png' % filename, format='png')
  111. def main(argv=None):
  112. mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
  113. print("start")
  114. train(mnist, True)
  115. if __name__ == '__main__':
  116. tf.app.run()