multiThread.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567
  1. import tensorflow as tf
  2. import numpy as np
  3. import threading
  4. import time
  5. import os
  6. import preprocessing
  7. import mnist_inference
  8. import matplotlib.pyplot as plt
  9. # ********** queue operation ***********
  10. def queue_op():
  11. # FIFOQueue & RandomShuffleQueue
  12. # maximum 2 int elements
  13. q = tf.FIFOQueue(2, "int32")
  14. init = q.enqueue_many(([0, 10],))
  15. x = q.dequeue()
  16. y = x + 1
  17. q_inc = q.enqueue([y])
  18. with tf.Session() as sess:
  19. init.run()
  20. for _ in range(5):
  21. # including dequeue, add 1, enqueue
  22. v, _ = sess.run([x, q_inc])
  23. # print(v)
# tf.train.Coordinator enables thread synchronization.
# Its main methods are request_stop, should_stop and join.
  26. def MyLoop(coord, worker_id):
  27. while not coord.should_stop():
  28. if np.random.rand() < 0.1:
  29. print("Stoping from id: %d" % worker_id)
  30. coord.request_stop()
  31. else:
  32. time.sleep(0.5)
  33. print("Working on id: %d" % worker_id)
  34. time.sleep(1)
  35. # coord = tf.train.Coordinator()
  36. # threads = [
  37. # threading.Thread(target=MyLoop, args=(coord, i), ) for i in range(5)
  38. # ]
  39. # # start all threads
  40. # for t in threads:
  41. # t.start()
  42. # # wait for all threads to stop
  43. # coord.join(threads)
  44. # ******** tf.QueueRunner **********
  45. def threads_mgmt():
  46. queue = tf.FIFOQueue(100, 'float')
  47. enqueue_op = queue.enqueue([tf.random_normal([1])])
  48. # create 5 threads
  49. qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
  50. # added to default collection tf.GraphKeys.QUEUE_RUNNERS,
  51. # start_queue_runner() will start all threads in the specified collection
  52. tf.train.add_queue_runner(qr)
  53. out_tensor = queue.dequeue()
  54. with tf.Session() as sess:
  55. coord = tf.train.Coordinator()
  56. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  57. for _ in range(15):
  58. print(sess.run(out_tensor)[0])
  59. time.sleep(0.2)
  60. coord.request_stop()
  61. coord.join(threads)
  62. def _int64_feature(value):
  63. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  64. def _bytes_feature(value):
  65. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  66. # simulate big data situation
  67. def generate_files():
  68. # how many files to write
  69. num_shard = 3
  70. # how much data in a file
  71. instances_per_shard = 6
  72. record_path = "record/"
  73. try:
  74. os.mkdir(record_path)
  75. except:
  76. print("directory already exist")
  77. # data 0000n-of-0000m, n means file No., m means how many files the data has been stored as
  78. for i in range(num_shard):
  79. filename = (os.path.join(record_path, "data.tfrecords-%.5d-of-%.5d" % (i, num_shard)))
  80. writer = tf.python_io.TFRecordWriter(filename)
  81. for j in range(instances_per_shard):
  82. example = tf.train.Example(features=tf.train.Features(feature={
  83. 'i': _int64_feature(i),
  84. 'j': _int64_feature(j)
  85. }))
  86. writer.write(example.SerializeToString())
  87. writer.close()
  88. def read_files():
  89. # 获取文件列表
  90. record_path = "record/"
  91. files = tf.train.match_filenames_once(os.path.join(record_path, "data.tfrecords-*"))
  92. # 1 epochs means 1 cycle
  93. filename_queue = tf.train.string_input_producer(files, num_epochs=1, shuffle=True)
  94. reader = tf.TFRecordReader()
  95. _, serialized_example = reader.read(filename_queue)
  96. features = tf.parse_single_example(
  97. serialized_example,
  98. features={
  99. 'i': tf.FixedLenFeature([], tf.int64),
  100. 'j': tf.FixedLenFeature([], tf.int64),
  101. }
  102. )
  103. with tf.Session() as sess:
  104. # match_filename_once() needs to be initialized
  105. tf.local_variables_initializer().run()
  106. print(sess.run(files))
  107. coord = tf.train.Coordinator()
  108. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  109. for i in range(18):
  110. print(sess.run([features['i'], features['j']]))
  111. coord.request_stop()
  112. coord.join(threads)
  113. return features
  114. def batch_example():
  115. features = read_files()
  116. print("____ end of read files _____")
  117. example, label = features['i'], features['j']
  118. batch_size = 3
  119. # queue capacity, larger means more memory usage, smaller means can be blocked and less efficient
  120. capacity = 1000 + 3 * batch_size
  121. # example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
  122. # min_after_dequeue represent the num of data needed for dequeue operation which is blocked when the num inadequate
  123. example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity,
  124. min_after_dequeue=6)
  125. with tf.Session() as sess:
  126. tf.local_variables_initializer().run()
  127. coord = tf.train.Coordinator()
  128. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  129. # combine
  130. for i in range(6):
  131. curr_exp_b, curr_lab_b = sess.run([example_batch, label_batch])
  132. print(curr_exp_b, curr_lab_b, "lll")
  133. coord.request_stop()
  134. coord.join(threads)
  135. # ************* use inceptionV3 data to generate data for training **************
  136. def write_record(name, image, label):
  137. writer = tf.python_io.TFRecordWriter(name)
  138. for index in range(len(image)):
  139. # convert img to str
  140. image_raw = image[index].tobytes()
  141. print(label[index])
  142. print(image[index].shape[0])
  143. print(image[index].shape[1])
  144. print(image[index].shape[2])
  145. # create Example Protocol Buffer
  146. example = tf.train.Example(features=tf.train.Features(feature={
  147. 'image': _bytes_feature(image_raw),
  148. 'label': _int64_feature(label[index]),
  149. 'height': _int64_feature(image[index].shape[0]),
  150. 'width': _int64_feature(image[index].shape[1]),
  151. 'channels': _int64_feature(image[index].shape[2]),
  152. }))
  153. writer.write(example.SerializeToString())
  154. writer.close()
  155. def generate_record(output_filename="output_flower.tfrecords"):
  156. input_data = "../inceptionv3/preprocess/validation_flower.npy"
  157. processed_data = np.load(input_data, allow_pickle=True)
  158. training_images = processed_data[0]
  159. training_labels = processed_data[1]
  160. input_data = "../inceptionv3/preprocess/test_flower.npy"
  161. processed_data = np.load(input_data, allow_pickle=True)
  162. validation_images = processed_data[0]
  163. validation_labels = processed_data[1]
  164. write_record("output_flower_train.tfrecord", training_images, training_labels)
  165. write_record("output_flower_validation.tfrecord", validation_images, validation_labels)
  166. print("training_images: " + str(len(training_labels)))
  167. print("validation_images: " + str(len(validation_labels)))
  168. def read_record(file_regex="record/output_flower_*.tfrecord"):
  169. files = tf.train.match_filenames_once(file_regex)
  170. filename_queue = tf.train.string_input_producer(files, shuffle=False)
  171. reader = tf.TFRecordReader()
  172. _, serialized_example = reader.read(filename_queue)
  173. features = tf.parse_single_example(
  174. serialized_example,
  175. features={
  176. 'image': tf.FixedLenFeature([], tf.string),
  177. 'label': tf.FixedLenFeature([], tf.int64),
  178. 'height': tf.FixedLenFeature([], tf.int64),
  179. 'width': tf.FixedLenFeature([], tf.int64),
  180. 'channels': tf.FixedLenFeature([], tf.int64)
  181. })
  182. image, label = features['image'], tf.cast(features['label'], tf.int32)
  183. height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
  184. channels = tf.cast(features['channels'], tf.int32)
  185. # image decoding
  186. decoded_img = tf.decode_raw(image, tf.float32)
  187. # decoded_img.set_shape(268203)
  188. decoded_img = tf.reshape(decoded_img,
  189. shape=[height, width, channels])
  190. return decoded_img, label
  191. def tfrecord_parser(record):
  192. features = tf.parse_single_example(
  193. record,
  194. features={
  195. 'image': tf.FixedLenFeature([], tf.string),
  196. 'label': tf.FixedLenFeature([], tf.int64),
  197. 'height': tf.FixedLenFeature([], tf.int64),
  198. 'width': tf.FixedLenFeature([], tf.int64),
  199. 'channels': tf.FixedLenFeature([], tf.int64)
  200. })
  201. image, label = features['image'], tf.cast(features['label'], tf.int32)
  202. height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
  203. channels = tf.cast(features['channels'], tf.int32)
  204. # image decoding
  205. decoded_img = tf.decode_raw(image, tf.uint8)
  206. # decoded_img.set_shape(268203)
  207. # decoded_img.set_shape([height, width, channels])
  208. decoded_img = tf.reshape(decoded_img,
  209. shape=[height, width, channels])
  210. return decoded_img, label
# NOTE: decoding with the wrong image dtype (e.g. uint8 vs. float32) may cause
# "Input to reshape is a tensor with xxx values, but the requested shape has xxx".
# float32 is usually used for training, whereas uint8 is more likely to be used for image storage.
# NOTE: a "must have channel 3 but has channels 1" error is caused by the image preprocessing step.
  214. def process_data(doTrain=True):
  215. image_size = 28
  216. num_channels = 1
  217. num_of_labels = 10
  218. min_after_dequeue = 2000
  219. shuffle_buffer = 10000
  220. num_epochs = 50 # same effect as training_rounds
  221. batch_size = 1000
  222. training_rounds = 5000
  223. training_images = 55000 # 362
  224. validation_images = 5000 # 367
  225. test_images = 10000
  226. train_files = tf.train.match_filenames_once("record/mnist_train.tfrecord")
  227. validation_files = tf.train.match_filenames_once("record/mnist_validation.tfrecord")
  228. test_files = tf.train.match_filenames_once("record/mnist_test.tfrecord")
  229. # ********** define neural network structure and forward propagation **********
  230. learning_rate_base = 0.8
  231. learning_rate_decay = 0.99
  232. regularization_rate = 0.0001
  233. moving_average_decay = 0.99
  234. x = tf.placeholder(tf.float32, [None,
  235. image_size,
  236. image_size,
  237. num_channels], name='x-input')
  238. y_ = tf.placeholder(tf.float32, [None], name='y-input')
  239. regularizer = tf.contrib.layers.l2_regularizer(regularization_rate)
  240. y = mnist_inference.inference(x, True, regularizer)
  241. global_step = tf.Variable(0, trainable=False)
  242. # moving average, cross entropy, loss function with regularization and learning rate
  243. with tf.name_scope("moving_average"):
  244. variable_average = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
  245. variable_average_op = variable_average.apply(tf.trainable_variables())
  246. # calc loss
  247. with tf.name_scope("loss_function"):
  248. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.cast(y_, tf.int32))
  249. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  250. tf.summary.scalar('cross_entropy', cross_entropy_mean)
  251. loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
  252. with tf.name_scope("train_step"):
  253. learning_rate = tf.train.exponential_decay(
  254. learning_rate_base,
  255. global_step,
  256. training_images / batch_size,
  257. learning_rate_decay
  258. )
  259. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  260. with tf.control_dependencies([train_step, variable_average_op]):
  261. train_op = tf.no_op(name='train')
  262. # define accuracy
  263. with tf.name_scope("accuracy_calc"):
  264. prediction = tf.argmax(y, 1)
  265. answer = tf.cast(y_, tf.int64)
  266. correct_prediction = tf.equal(tf.argmax(y, 1), tf.cast(y_, tf.int64))
  267. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  268. tf.summary.scalar('accuracy', accuracy)
  269. # test_result = list(range(int(training_rounds / 500)))
  270. # # ********** original tfrecord data operator **********
  271. # decoded_img, label = read_record("record/mnist_train.tfrecord")
  272. # # img preprocessing
  273. # # distorted_img = tf.image.resize_images(decoded_img, [image_size, image_size], method=0)
  274. # distorted_img = preprocessing.process_for_train(decoded_img, image_size, image_size, None, 1)
  275. # distorted_img.set_shape([image_size, image_size, num_channels])
  276. # # print(distorted_img.shape)
  277. #
  278. # # create batch
  279. # total_sample = training_images + validation_images
  280. # capacity = min_after_dequeue + batch_size * 3
  281. # image_batch, label_batch = tf.train.shuffle_batch([distorted_img, label], batch_size=batch_size,
  282. # capacity=capacity, num_threads=64,
  283. # min_after_dequeue=min_after_dequeue)
  284. # ********** tfrecord dataset **********
  285. dataset = tf.data.TFRecordDataset(train_files)
  286. dataset = dataset.map(tfrecord_parser)
  287. dataset = dataset.map(
  288. lambda image, label: (
  289. preprocessing.process_for_train(tf.image.convert_image_dtype(image, dtype=tf.float32), image_size,
  290. image_size, None, 1), label
  291. # tf.image.resize_images(tf.image.convert_image_dtype(image, dtype=tf.float32), [image_size, image_size]), label
  292. ))
  293. dataset = dataset.shuffle(shuffle_buffer).batch(batch_size)
  294. dataset = dataset.repeat(num_epochs)
  295. # match_filename_once has similar mechanism as placeholder
  296. iterator = dataset.make_initializable_iterator()
  297. image_batch, label_batch = iterator.get_next()
  298. # ********** validation dataset **********
  299. validation_dataset = tf.data.TFRecordDataset(validation_files)
  300. validation_dataset = validation_dataset.map(tfrecord_parser).map(
  301. lambda image, label: (
  302. tf.image.resize_images(tf.image.convert_image_dtype(image, dtype=tf.float32), [image_size, image_size]),
  303. label
  304. ))
  305. validation_dataset = validation_dataset.batch(validation_images)
  306. validation_dataset = validation_dataset.repeat(None)
  307. validation_iterator = validation_dataset.make_initializable_iterator()
  308. validation_image_batch, validation_label_batch = validation_iterator.get_next()
  309. # ********** test dataset **********
  310. test_dataset = tf.data.TFRecordDataset(test_files)
  311. test_dataset = test_dataset.map(tfrecord_parser).map(
  312. lambda image, label: (
  313. tf.image.resize_images(tf.image.convert_image_dtype(image, dtype=tf.float32), [image_size, image_size]),
  314. label
  315. ))
  316. test_dataset = test_dataset.batch(test_images)
  317. test_iterator = test_dataset.make_initializable_iterator()
  318. test_image_batch, test_label_batch = test_iterator.get_next()
  319. # logit = inference(image_batch)
  320. # loss = calc_loss(logit, label_batch)
  321. # train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
  322. # initialize persistence class
  323. saver = tf.train.Saver()
  324. config = tf.ConfigProto(allow_soft_placement=True)
  325. config.gpu_options.allow_growth = True
  326. merged = tf.summary.merge_all()
  327. with tf.Session(config=config) as sess:
  328. writer = tf.summary.FileWriter("log", sess.graph)
  329. sess.run(tf.global_variables_initializer())
  330. sess.run(tf.local_variables_initializer())
  331. # print(sess.run(tf.cast(features['label'], tf.int32)))
  332. coord = tf.train.Coordinator()
  333. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  334. print("start training........")
  335. # for i in range(training_rounds):
  336. i = 0
  337. step = 0
  338. if doTrain:
  339. sess.run(iterator.initializer)
  340. sess.run(validation_iterator.initializer)
  341. while True:
  342. i += 1
  343. try:
  344. # img = sess.run(distorted_img)
  345. # plt.imshow(img)
  346. # plt.show()
  347. xs, ys = sess.run([image_batch, label_batch])
  348. # print(xs.shape)
  349. # print(ys.shape)
  350. if i % 200 == 0:
  351. # config necessary info when training
  352. run_options = tf.RunOptions(
  353. trace_level=tf.RunOptions.FULL_TRACE
  354. )
  355. # record proto when training
  356. run_metadata = tf.RunMetadata()
  357. summary, _, loss_value, step = sess.run([merged, train_op, loss, global_step],
  358. feed_dict={x: xs, y_: ys},
  359. options=run_options, run_metadata=run_metadata)
  360. writer.add_run_metadata(run_metadata, 'step%03d' % i)
  361. writer.add_summary(summary, i)
  362. vxs, vys = sess.run([validation_image_batch, validation_label_batch])
  363. p, a, accuracy_score = sess.run([prediction, answer, accuracy], feed_dict={x: vxs, y_: vys})
  364. print("prediction: \t%s, \nanswer: \t\t%s" % (p[0:10], a[0:10]))
  365. print("after %d steps, loss: %.3f, accuracy: %.3f" % (step, loss_value, accuracy_score))
  366. else:
  367. summary, _, loss_value, step = sess.run([merged, train_op, loss, global_step],
  368. feed_dict={x: xs, y_: ys})
  369. writer.add_summary(summary, i)
  370. except tf.errors.OutOfRangeError:
  371. # i = step
  372. break
  373. sess.run(test_iterator.initializer)
  374. tp = []
  375. ta = []
  376. while True:
  377. try:
  378. txs, tys = sess.run([test_image_batch, test_label_batch])
  379. p, a = sess.run([prediction, answer], feed_dict={x: txs, y_: tys})
  380. tp.extend(p)
  381. ta.extend(a)
  382. except tf.errors.OutOfRangeError:
  383. break
  384. correct = [float(y == y_) for (y, y_) in zip(tp, ta)]
  385. accuracy_score = sum(correct) / len(correct)
  386. print("in total %d steps, total accuracy: %.3f" % (i, accuracy_score))
  387. try:
  388. os.mkdir("model/")
  389. except:
  390. print("directory already exist")
  391. saver.save(
  392. sess, os.path.join("model/", "model.ckpt"), global_step=global_step
  393. )
  394. else:
  395. ckpt = tf.train.get_checkpoint_state("model/")
  396. if ckpt and ckpt.model_checkpoint_path:
  397. sess.run(test_iterator.initializer)
  398. saver.restore(sess, ckpt.model_checkpoint_path)
  399. start = np.random.randint(int(test_images / 3), int(test_images / 2))
  400. length = 10
  401. txs, tys = sess.run([test_image_batch, test_label_batch])
  402. p, a = sess.run([prediction, answer],
  403. feed_dict={x: txs[start:start + length], y_: tys[start:start + length]})
  404. print("prediction: \t%s, \nanswer: \t\t%s" % (p, a))
  405. else:
  406. print("model not exist")
  407. coord.request_stop()
  408. coord.join(threads)
  409. # writer = tf.summary.FileWriter("log", tf.get_default_graph())
  410. writer.close()
  411. # ************* dataset operation **************
  412. def parser(record):
  413. features = tf.parse_single_example(
  414. record,
  415. features={
  416. 'feat1': tf.FixedLenFeature([], tf.int64),
  417. 'feat2': tf.FixedLenFeature([], tf.int64),
  418. })
  419. return features['feat1'], features['feat2']
def dataset_basic_test():
    """Walk through several ways of building and iterating a ``tf.data`` dataset.

    In order: a dataset built from an in-memory list, a TFRecord dataset read
    through a one-shot iterator, a TFRecord dataset whose file list is fed
    through a placeholder and an initializable iterator, and finally a chain
    of map/shuffle/batch/repeat transformations.  Nothing is returned.

    NOTE(review): the file names "file1" and "file2" look like stand-ins -
    confirm they exist before running this function.
    """
    # build a dataset from a tensor
    input_data = [1, 2, 3, 5, 8]
    dataset = tf.data.Dataset.from_tensor_slices(input_data)
    # traverse dataset
    iterator = dataset.make_one_shot_iterator()
    x = iterator.get_next()
    # y is only referenced by the commented-out print loop further down
    y = x * x
    # build a dataset from text files
    # input_files = ["file1", "file2"]
    # dataset = tf.data.TextLineDataset(input_files)
    # build a dataset from TFRecord files
    input_files = ["file1", "file2"]
    dataset = tf.data.TFRecordDataset(input_files)
    # call parser and replace each element with returned value
    dataset = dataset.map(parser)
    # make_one_shot_iterator requires every parameter to be fixed up front;
    # when a placeholder is involved, use an initializable_iterator instead
    # reinitializable_iterator, initialize multiple times for different data source
    # feedable_iterator, use feed_dict to assign iterators to run
    iterator = dataset.make_one_shot_iterator()
    feat1, feat2 = iterator.get_next()
    with tf.Session() as sess:
        # for i in range(len(input_data)):
        # print(sess.run(y))
        # fetches ten elements; there is no OutOfRangeError handling here
        for i in range(10):
            f1, f2 = sess.run([feat1, feat2])
    # build a dataset from TFRecord files, with the file list fed through a placeholder
    input_files = tf.placeholder(tf.string)
    dataset = tf.data.TFRecordDataset(input_files)
    dataset = dataset.map(parser)
    iterator = dataset.make_initializable_iterator()
    feat1, feat2 = iterator.get_next()
    with tf.Session() as sess:
        # the actual file names are supplied when the iterator is initialized
        sess.run(iterator.initializer, feed_dict={
            input_files: ["file1", "file2"]
        })
        # read until the dataset is exhausted
        while True:
            try:
                sess.run([feat1, feat2])
            except tf.errors.OutOfRangeError:
                break
    # dataset high level API
    image_size = 299
    buffer_size = 1000 # min_after_dequeue
    batch_size = 100
    N = 10 # num_epoch
    # NOTE(review): at this point the dataset elements are the (feat1, feat2)
    # pairs returned by parser, while this lambda takes a single argument -
    # confirm this map call works as intended.
    dataset = dataset.map(
        lambda x: preprocessing.process_for_train(x, image_size, image_size, None)
    )
    dataset = dataset.shuffle(buffer_size=buffer_size)
    dataset = dataset.batch(batch_size=batch_size)
    dataset = dataset.repeat(N)
  472. # open tensorboard cmd:
  473. # tensorboard --logdir=/path/to/log --port=6006
  474. if __name__ == '__main__':
  475. # threads_mgmt()
  476. # generate_files()
  477. # read_files()
  478. # batch_example()
  479. # process_data()
  480. # generate_record()
  481. process_data(doTrain=True)
  482. # dataset_basic_test()