# -*- coding: utf8 -*-
"""Train the MNIST fully connected network defined in ``mnist_inference``.

Uses the TensorFlow 1.x graph/session API. Checkpoints are written to
``MODEL_SAVE_PATH`` with file prefix ``MODEL_NAME``.
"""
import os

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

import mnist_inference

# define input, output, batch and training params
BATCH_SIZE = 1000
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 30000
MOVING_AVERAGE_DECAY = 0.99
# Log the training loss and write a checkpoint every this many loop iterations.
LOG_AND_SAVE_INTERVAL = 2500
# Upper bound on the fraction of GPU memory this process may claim.
GPU_MEMORY_FRACTION = 0.4
MODEL_SAVE_PATH = "model/"
MODEL_NAME = "model.ckpt"


def _ensure_dir(path):
    """Create directory ``path`` (and parents) unless it already exists.

    Any other failure (e.g. permission denied, path is a regular file) is
    re-raised instead of being silently ignored.
    """
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise


# train a fully connected neural network
def train(mnist):
    """Train the network on ``mnist.train`` and checkpoint it periodically.

    Args:
        mnist: dataset object as returned by
            ``input_data.read_data_sets(..., one_hot=True)``. Only
            ``mnist.train.num_examples`` and ``mnist.train.next_batch`` are
            used here.

    Side effects:
        Creates ``MODEL_SAVE_PATH`` if needed, prints the training-batch loss
        every ``LOG_AND_SAVE_INTERVAL`` iterations and on the final iteration,
        and saves a checkpoint at those same points.
    """
    x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
    y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')

    regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
    y = mnist_inference.inference(x, regularizer)
    global_step = tf.Variable(0, trainable=False)

    # moving average, cross entropy, loss function with regularization and learning rate
    variable_average = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    variable_average_op = variable_average.apply(tf.trainable_variables())

    # Labels are fed one-hot (y_); the sparse op wants class indices, hence argmax.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=y, labels=tf.argmax(y_, 1))
    cross_entropy_mean = tf.reduce_mean(cross_entropy)
    # NOTE(review): assumes mnist_inference.inference adds its regularization
    # terms to the 'losses' collection; tf.add_n fails on an empty list.
    loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))

    # Decay the learning rate by LEARNING_RATE_DECAY once per pass over the data.
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
    )
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(
        loss, global_step=global_step)

    # Run the gradient step and the moving-average update as one op.
    with tf.control_dependencies([train_step, variable_average_op]):
        train_op = tf.no_op(name='train')

    # initialize persistence class
    saver = tf.train.Saver()

    # gpu_options must be attached to the ConfigProto, otherwise the memory
    # fraction cap has no effect. allow_growth lets usage grow up to that cap.
    gpu_options = tf.GPUOptions(
        per_process_gpu_memory_fraction=GPU_MEMORY_FRACTION,
        allow_growth=True,
    )
    config = tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options)

    checkpoint_path = os.path.join(MODEL_SAVE_PATH, MODEL_NAME)

    with tf.Session(config=config) as sess:
        sess.run(tf.global_variables_initializer())

        # create directory
        _ensure_dir(MODEL_SAVE_PATH)

        for i in range(TRAINING_STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run(
                [train_op, loss, global_step], feed_dict={x: xs, y_: ys})

            # Also log/save on the last iteration so the final weights persist.
            is_last_step = i == TRAINING_STEPS - 1
            if i % LOG_AND_SAVE_INTERVAL == 0 or is_last_step:
                print("after %d training step(s), loss on training batch is %g "
                      % (step, loss_value))
                saver.save(sess, checkpoint_path, global_step=global_step)


def main(argv=None):
    """Entry point for ``tf.app.run``: load MNIST from disk and train.

    ``argv`` is accepted because ``tf.app.run`` passes it; it is unused here.
    """
    print("start")
    mnist = input_data.read_data_sets("../MNIST_data", one_hot=True)
    # NOTE(review): duplicate message; probably meant to signal that data loading finished.
    print("start")
    train(mnist)


if __name__ == '__main__':
    tf.app.run()