mnist_inference.py 1.1 KB

123456789101112131415161718192021222324252627282930313233
  1. # -*- coding: utf8 -*-
  2. import tensorflow as tf
  3. # define basic params
  4. INPUT_NODE = 784
  5. OUTPUT_NODE = 10
  6. LAYER1_NODE = 500
  7. def get_weight_variable(shape, regularizer):
  8. # init and get weights
  9. weights = tf.get_variable("weights", shape, initializer=tf.truncated_normal_initializer(stddev=0.1))
  10. # save regularization loss
  11. if regularizer is not None:
  12. tf.add_to_collection('losses', regularizer(weights))
  13. return weights
  14. def inference(input_tensor, regularizer):
  15. # define layer1 forward propagation
  16. with tf.variable_scope('layer1'):
  17. weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
  18. biases = tf.get_variable("biases", [LAYER1_NODE], initializer=tf.constant_initializer(0.0))
  19. layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
  20. # define layer2 forward propagation
  21. with tf.variable_scope('layer2'):
  22. weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
  23. biases = tf.get_variable("biases", [OUTPUT_NODE], initializer=tf.constant_initializer(0.0))
  24. layer2 = tf.matmul(layer1, weights) + biases
  25. return layer2