# -*- coding: utf-8 -*-
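# Transfer learning demo: fine-tune a pretrained Inception V3 model on the
# five-class flower dataset. With TRAINING = True the final layers are
# retrained and checkpoints are saved; with TRAINING = False the latest
# checkpoint is restored and predictions on random test images are shown.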
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow.contrib.slim as slim
import tensorflow.contrib.slim.python.slim.nets.inception_v3 as inception_v3
# path to the preprocessed flower data produced by the preprocessing script
INPUT_DATA = 'preprocess/flower_processed_data.npy'
# directory where the fine-tuned checkpoints are saved
TRAIN_FILE = 'model/'
# pretrained Inception V3 checkpoint
CKPT_FILE = '../../dataset/inception_v3.ckpt'
# training hyperparameters
LEARNING_RATE = 0.0001
STEPS = 1000
BATCH = 32
N_CLASSES = 5
# layers that are not loaded from the checkpoint, i.e. the final fully
# connected layers, which are retrained for the five flower classes
CHECKPOINT_EXCLUDE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
TRAINABLE_SCOPES = 'InceptionV3/Logits,InceptionV3/AuxLogits'
# True: fine-tune the final layers; False: run predictions on test images
TRAINING = False
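# human-readable class names; the order is assumed to match the integer
# labels assigned by the preprocessing script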
flower_label = ["daisy", "roses", "tulips", "sunflowers", "dandelion"]
def get_tuned_variables():
    """Return the pretrained variables to restore from CKPT_FILE."""
    exclusions = [scope.strip() for scope in CHECKPOINT_EXCLUDE_SCOPES.split(',')]
    variables_to_restore = []
    # enumerate the variables in the Inception V3 model and keep every one
    # that does not fall under an excluded scope
    for var in slim.get_model_variables():
        excluded = False
        for exclusion in exclusions:
            if var.op.name.startswith(exclusion):
                excluded = True
                break
        if not excluded:
            variables_to_restore.append(var)
    return variables_to_restore
def get_trainable_variables():
    """Return the variables in TRAINABLE_SCOPES, i.e. the layers to fine-tune."""
    scopes = [scope.strip() for scope in TRAINABLE_SCOPES.split(',')]
    variables_to_train = []
    for scope in scopes:
        variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope)
        variables_to_train.extend(variables)
    return variables_to_train
def main():
    # load the preprocessed data: [training_images, training_labels,
    # validation_images, validation_labels, test_images, test_labels]
    processed_data = np.load(INPUT_DATA, allow_pickle=True)
    training_images = processed_data[0]
    n_training_example = len(training_images)
    training_labels = processed_data[1]
    validation_images = processed_data[2]
    validation_labels = processed_data[3]
    test_images = processed_data[4]
    test_labels = processed_data[5]
    print("%d training examples, %d validation examples and %d testing examples." % (
        n_training_example, len(validation_labels), len(test_labels)))
    # define inputs: 299x299 RGB images, the size Inception V3 expects
    images = tf.placeholder(
        tf.float32, [None, 299, 299, 3], name='input_images')
    labels = tf.placeholder(tf.int64, [None], name='labels')

    # define the model; is_training=False keeps batch norm and dropout in
    # inference mode, since only the final fully connected layers are tuned
    with slim.arg_scope(inception_v3.inception_v3_arg_scope()):
        logits, _ = inception_v3.inception_v3(
            images, num_classes=N_CLASSES, is_training=False)
    # restrict training to the final layers
    trainable_variables = get_trainable_variables()

    # define the cross-entropy loss and the training step; var_list limits
    # the update to the trainable (final) layers
    tf.losses.softmax_cross_entropy(tf.one_hot(labels, N_CLASSES), logits, weights=1.0)
    train_step = tf.train.RMSPropOptimizer(LEARNING_RATE).minimize(
        tf.losses.get_total_loss(), var_list=trainable_variables)
    # compute accuracy: the fraction of predictions that match the labels
    with tf.name_scope('evaluation'):
        prediction = tf.argmax(logits, 1)
        correct_answer = labels
        correct_prediction = tf.equal(prediction, labels)
        evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    # function that restores the pretrained weights (everything except the
    # excluded final layers) from CKPT_FILE
    load_fn = slim.assign_from_checkpoint_fn(
        CKPT_FILE,
        get_tuned_variables(),
        ignore_missing_vars=True
    )

    # saver for the fine-tuned checkpoints
    saver = tf.train.Saver()

    config = tf.ConfigProto(allow_soft_placement=True)
    # grab GPU memory on demand instead of reserving it all up front
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        # initialize all variables, then overwrite them with saved weights
        init = tf.global_variables_initializer()
        sess.run(init)

        # resume from a fine-tuned checkpoint if one exists; otherwise load
        # the pretrained Inception V3 weights
        ckpt = tf.train.get_checkpoint_state(TRAIN_FILE)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            print('loading tuned variables from %s' % CKPT_FILE)
            load_fn(sess)

        start = 0
        end = BATCH
        if TRAINING:
            for i in range(STEPS):
                # train on one mini-batch
                sess.run(train_step, feed_dict={
                    images: training_images[start:end],
                    labels: training_labels[start:end]
                })

                # periodically save a checkpoint and report validation accuracy
                if i % 20 == 0 or i + 1 == STEPS:
                    saver.save(sess, TRAIN_FILE, global_step=i)
                    validation_accuracy = sess.run(evaluation_step, feed_dict={
                        images: validation_images,
                        labels: validation_labels
                    })
                    print('step %d: validation accuracy = %.1f%%' % (i, validation_accuracy * 100.0))

                # advance the mini-batch window, wrapping around at the end
                # of the training set
                start = end
                if start == n_training_example:
                    start = 0
                end = start + BATCH
                if end > n_training_example:
                    end = n_training_example

            # final accuracy on the held-out test set
            test_accuracy = sess.run(evaluation_step, feed_dict={
                images: test_images,
                labels: test_labels
            })
            print('final test accuracy = %.1f%%' % (test_accuracy * 100.0))
        else:
            while True:
                # pick a random test image and compare the prediction with
                # the correct answer
                index = np.random.randint(0, len(test_labels))
                prediction_score, correct_answer_score = sess.run(
                    [prediction, correct_answer], feed_dict={
                        images: test_images[index:index + 1],
                        labels: test_labels[index:index + 1]
                    })
                result = [(flower_label[x] + str(x)) for x in prediction_score]
                answer = [(flower_label[x] + str(x)) for x in correct_answer_score]
                print('test result: %s, correct answer: %s' % (result, answer))
                plt.imshow(test_images[index])
                plt.show()
                time.sleep(3)
if __name__ == '__main__':
    main()
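# Usage sketch (assuming the data and checkpoint paths above exist; the
# script filename is hypothetical):
#   python finetune_inception_v3.py   # with TRAINING = True  -> fine-tune and save
#   python finetune_inception_v3.py   # with TRAINING = False -> view predictions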