123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081 |
- # -*- coding: utf8 -*-
- import os
- import tensorflow as tf
- from tensorflow.examples.tutorials.mnist import input_data
- import mnist_inference
# --- Training hyperparameters and checkpoint location ---

# Mini-batch size and total number of optimization steps.
BATCH_SIZE = 1000
TRAINING_STEPS = 30000

# Learning-rate schedule: initial rate and the exponential decay factor.
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99

# Scale passed to the L2 regularizer.
REGULARIZATION_RATE = 1e-4

# Decay used for the exponential moving average of the trainable variables.
MOVING_AVERAGE_DECAY = 0.99

# Where checkpoints are written: <MODEL_SAVE_PATH>/<MODEL_NAME>-<global_step>.
MODEL_SAVE_PATH = "model/"
MODEL_NAME = "model.ckpt"
- # train a fully connected neural network
def train(mnist):
    """Train the fully connected MNIST network and checkpoint it periodically.

    Builds the graph (forward pass via ``mnist_inference.inference``, a
    cross-entropy loss plus the terms found in the ``'losses'`` collection,
    an exponentially decayed learning rate, and exponential moving averages
    of the trainable variables), then runs ``TRAINING_STEPS`` gradient-descent
    steps. A checkpoint is written under ``MODEL_SAVE_PATH`` every 2500
    iterations and after the final iteration.

    Args:
        mnist: Dataset object exposing ``train.num_examples`` and
            ``train.next_batch(batch_size)``. Labels are fed into a
            ``[None, OUTPUT_NODE]`` placeholder and reduced with ``argmax``,
            so they are presumably one-hot encoded -- confirm with the caller.

    Raises:
        OSError: If ``MODEL_SAVE_PATH`` cannot be created and does not
            already exist as a directory.
    """
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # Moving averages of every trainable variable, driven by global_step.
    variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_average_op = variable_average.apply(tf.trainable_variables())

    # The sparse cross-entropy op takes class indices, hence the argmax over y_.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1)
    )
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # NOTE(review): the 'losses' collection is presumably populated by
    # mnist_inference.inference through the regularizer -- confirm there.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # decay_steps is the number of batches per pass over the training set.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY
    )
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step
    )
    # Running train_op performs the gradient step and the moving-average update.
    with tf.control_dependencies([train_step, variable_average_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()
    checkpoint_path = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    # Create the checkpoint directory up front. Only an already-existing
    # directory is tolerated; any other failure (permissions, missing parent,
    # a regular file in the way) is re-raised instead of being misreported.
    try:
        os.mkdir(MODEL_SAVE_PATH)
    except OSError:
        if not os.path.isdir(MODEL_SAVE_PATH):
            raise
        print("directory already exist")

    config = tf.ConfigProto(allow_soft_placement=True)
    config.gpu_options.allow_growth = True

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys}
            )
            # Also report and save on the last iteration so the fully trained
            # weights are persisted, not just the most recent 2500-step mark.
            is_last_iteration = i == TRAINING_STEPS - 1
            if i % 2500 == 0 or is_last_iteration:
                print("after %d training step(s), loss on training batch is %g " % (step, loss_value))
                saver.save(sess, checkpoint_path, global_step=global_step)
def main(argv=None):
    """Load the MNIST data set with one-hot labels and start training.

    Args:
        argv: Accepted but not used by this function.
    """
    print("start")
    dataset = input_data.read_data_sets("../MNIST_data", one_hot=True)
    print("start")
    train(dataset)
# Script entry point: hand control to TensorFlow's app runner, which calls main.
if __name__ == "__main__":
    tf.app.run(main=main)
|