mnist_train.py 2.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # -*- coding: utf8 -*-
  2. import os
  3. import tensorflow as tf
  4. from tensorflow.examples.tutorials.mnist import input_data
  5. import mnist_inference
  # ---- training hyperparameters ----
  # Mini-batch size and total number of optimisation steps.
  BATCH_SIZE = 1000
  TRAINING_STEPS = 30000

  # Learning-rate schedule: initial rate and the multiplier applied per decay period.
  LEARNING_RATE_BASE = 0.8
  LEARNING_RATE_DECAY = 0.99

  # Scale of the L2 regularizer handed to mnist_inference.inference.
  REGULARIZATION_RATE = 1e-4

  # Decay of the exponential moving average kept for the trainable variables.
  MOVING_AVERAGE_DECAY = 0.99

  # ---- checkpoint location ----
  # saver.save writes MODEL_SAVE_PATH/MODEL_NAME-<global_step>.
  MODEL_SAVE_PATH = "model/"
  MODEL_NAME = "model.ckpt"
  def _make_session_config():
      """Build the tf.ConfigProto used for the training session."""
      # The original code built a GPUOptions with a 0.4 memory fraction but never
      # attached it to the config, so the cap was silently ignored. Attach it here,
      # together with allow_growth, so the session both grows lazily and is capped.
      gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.4,
                                  allow_growth=True)
      return tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)


  def _ensure_model_dir(path):
      """Create *path* if it is missing; report when it already exists."""
      try:
          os.makedirs(path)
      except OSError:
          # Only an already-existing directory is acceptable. Permission errors,
          # a regular file occupying the path, etc. must not be hidden.
          if not os.path.isdir(path):
              raise
          print("directory already exist")


  def train(mnist):
      """Train the fully connected MNIST network and checkpoint it.

      Builds the graph via ``mnist_inference.inference``. It then runs
      ``TRAINING_STEPS`` SGD steps on mini-batches of ``BATCH_SIZE``, using an
      exponentially decaying learning rate and a moving average of the
      trainable variables. Checkpoints are written under ``MODEL_SAVE_PATH``
      every 2500 steps and once more after the final step.

      Args:
          mnist: dataset object from ``input_data.read_data_sets`` with one-hot
              labels (as loaded in ``main``); only ``mnist.train`` is used.
      """
      x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
      y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
      regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
      y = mnist_inference.inference(x, regularizer)
      global_step = tf.Variable(0, trainable=False)

      # Shadow (moving-average) copies of every trainable variable. Passing
      # global_step as num_updates makes the effective decay smaller early on.
      variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
      variable_average_op = variable_average.apply(tf.trainable_variables())

      # Labels are one-hot, so convert them to class indices for the sparse op.
      cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
          logits=y, labels=tf.argmax(y_, 1))
      cross_entropy_mean = tf.reduce_mean(cross_entropy)
      # presumably mnist_inference adds its L2 terms to the 'losses' collection - verify
      loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

      # Continuous (non-staircase) decay: the rate shrinks by LEARNING_RATE_DECAY
      # every num_examples / BATCH_SIZE steps, i.e. roughly once per epoch.
      learning_rate = tf.train.exponential_decay(
          LEARNING_RATE_BASE,
          global_step,
          mnist.train.num_examples / BATCH_SIZE,
          LEARNING_RATE_DECAY)
      train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
          loss, global_step=global_step)
      # A single op that triggers both the gradient step and the average update.
      with tf.control_dependencies([train_step, variable_average_op]):
          train_op = tf.no_op(name='train')

      saver = tf.train.Saver()
      checkpoint_path = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)
      with tf.Session(config=_make_session_config()) as sess:
          sess.run(tf.global_variables_initializer())
          _ensure_model_dir(MODEL_SAVE_PATH)
          for i in range(TRAINING_STEPS):
              xs, ys = mnist.train.next_batch(BATCH_SIZE)
              _, loss_value, step = sess.run([train_op, loss, global_step],
                                             feed_dict={x: xs, y_: ys})
              if i % 2500 == 0:
                  print("after %d training step(s), loss on training batch is %g " % (step, loss_value))
                  saver.save(sess, checkpoint_path, global_step=global_step)
          # The periodic save skips the last iteration unless its index is a
          # multiple of 2500, so persist the fully trained weights explicitly.
          # If the last iteration was already saved, this just rewrites the same
          # checkpoint.
          saver.save(sess, checkpoint_path, global_step=global_step)
  def main(argv=None):
      """Load MNIST with one-hot labels from ../MNIST_data and train the model.

      Args:
          argv: argument list forwarded by ``tf.app.run``; not used here.
      """
      print("start")
      mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
      # Use a distinct marker so the log shows that data loading finished. The
      # original printed "start" a second time, which told the reader nothing.
      print("data loaded")
      train(mnist)
  if __name__ == '__main__':
      # tf.app.run parses command-line flags, then invokes main(argv);
      # naming main explicitly selects the same function it would find itself.
      tf.app.run(main=main)