tflearn_usage.py 936 B

1234567891011121314151617181920212223
  1. # -*- coding:utf-8 -*-
  2. import tflearn
  3. from tflearn.layers.core import input_data, fully_connected
  4. from tflearn.layers.conv import conv_2d, max_pool_2d
  5. from tflearn.layers.estimator import regression
  6. import tflearn.datasets.mnist as mnist
  7. trainX, trainY, testX, testY = mnist.load_data(data_dir='../MNIST_data', one_hot=True)
  8. trainX = trainX.reshape([-1, 28, 28, 1])
  9. testX = testX.reshape([-1, 28, 28, 1])
  10. net = input_data(shape=[None, 28, 28, 1], name='input')
  11. net = conv_2d(net, 6, 5, activation='relu')
  12. net = max_pool_2d(net, 2)
  13. net = conv_2d(net, 16, 5, activation='relu')
  14. net = max_pool_2d(net, 2)
  15. net = fully_connected(net, 500, activation='relu')
  16. net = fully_connected(net, 10, activation='relu')
  17. net = regression(net, optimizer='sgd', learning_rate=0.01, loss='categorical_crossentropy')
  18. model = tflearn.DNN(net, tensorboard_verbose=0)
  19. model.fit(trainX, trainY, n_epoch=20, validation_set=([testX, testY]), show_metric=True)