multiThread.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. import tensorflow as tf
  2. import numpy as np
  3. import threading
  4. import time
  5. import os
  6. import preprocessing
  7. import mnist_inference
  8. import matplotlib.pyplot as plt
  9. # ********** queue operation ***********
  10. def queue_op():
  11. # FIFOQueue & RandomShuffleQueue
  12. # maximum 2 int elements
  13. q = tf.FIFOQueue(2, "int32")
  14. init = q.enqueue_many(([0, 10],))
  15. x = q.dequeue()
  16. y = x + 1
  17. q_inc = q.enqueue([y])
  18. with tf.Session() as sess:
  19. init.run()
  20. for _ in range(5):
  21. # including dequeue, add 1, enqueue
  22. v, _ = sess.run([x, q_inc])
  23. # print(v)
  24. # tf.train.Coordinator enable thread synchronization
  25. # request_stop, should_stop, join
  26. def MyLoop(coord, worker_id):
  27. while not coord.should_stop():
  28. if np.random.rand() < 0.1:
  29. print("Stoping from id: %d" % worker_id)
  30. coord.request_stop()
  31. else:
  32. time.sleep(0.5)
  33. print("Working on id: %d" % worker_id)
  34. time.sleep(1)
  35. # coord = tf.train.Coordinator()
  36. # threads = [
  37. # threading.Thread(target=MyLoop, args=(coord, i), ) for i in range(5)
  38. # ]
  39. # # start all threads
  40. # for t in threads:
  41. # t.start()
  42. # # wait for all threads to stop
  43. # coord.join(threads)
  44. # ******** tf.QueueRunner **********
  45. def threads_mgmt():
  46. queue = tf.FIFOQueue(100, 'float')
  47. enqueue_op = queue.enqueue([tf.random_normal([1])])
  48. # create 5 threads
  49. qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
  50. # added to default collection tf.GraphKeys.QUEUE_RUNNERS,
  51. # start_queue_runner() will start all threads in the specified collection
  52. tf.train.add_queue_runner(qr)
  53. out_tensor = queue.dequeue()
  54. with tf.Session() as sess:
  55. coord = tf.train.Coordinator()
  56. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  57. for _ in range(15):
  58. print(sess.run(out_tensor)[0])
  59. time.sleep(0.2)
  60. coord.request_stop()
  61. coord.join(threads)
  62. def _int64_feature(value):
  63. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  64. def _bytes_feature(value):
  65. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  66. # simulate big data situation
  67. def generate_files():
  68. # how many files to write
  69. num_shard = 3
  70. # how much data in a file
  71. instances_per_shard = 6
  72. record_path = "record/"
  73. try:
  74. os.mkdir(record_path)
  75. except:
  76. print("directory already exist")
  77. # data 0000n-of-0000m, n means file No., m means how many files the data has been stored as
  78. for i in range(num_shard):
  79. filename = (os.path.join(record_path, "data.tfrecords-%.5d-of-%.5d" % (i, num_shard)))
  80. writer = tf.python_io.TFRecordWriter(filename)
  81. for j in range(instances_per_shard):
  82. example = tf.train.Example(features=tf.train.Features(feature={
  83. 'i': _int64_feature(i),
  84. 'j': _int64_feature(j)
  85. }))
  86. writer.write(example.SerializeToString())
  87. writer.close()
  88. def read_files():
  89. # 获取文件列表
  90. record_path = "record/"
  91. files = tf.train.match_filenames_once(os.path.join(record_path, "data.tfrecords-*"))
  92. # 1 epochs means 1 cycle
  93. filename_queue = tf.train.string_input_producer(files, num_epochs=1, shuffle=True)
  94. reader = tf.TFRecordReader()
  95. _, serialized_example = reader.read(filename_queue)
  96. features = tf.parse_single_example(
  97. serialized_example,
  98. features={
  99. 'i': tf.FixedLenFeature([], tf.int64),
  100. 'j': tf.FixedLenFeature([], tf.int64),
  101. }
  102. )
  103. with tf.Session() as sess:
  104. # match_filename_once() needs to be initialized
  105. tf.local_variables_initializer().run()
  106. print(sess.run(files))
  107. coord = tf.train.Coordinator()
  108. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  109. for i in range(18):
  110. print(sess.run([features['i'], features['j']]))
  111. coord.request_stop()
  112. coord.join(threads)
  113. return features
  114. def batch_example():
  115. features = read_files()
  116. print("____ end of read files _____")
  117. example, label = features['i'], features['j']
  118. batch_size = 3
  119. # queue capacity, larger means more memory usage, smaller means can be blocked and less efficient
  120. capacity = 1000 + 3 * batch_size
  121. # example_batch, label_batch = tf.train.batch([example, label], batch_size=batch_size, capacity=capacity)
  122. # min_after_dequeue represent the num of data needed for dequeue operation which is blocked when the num inadequate
  123. example_batch, label_batch = tf.train.shuffle_batch([example, label], batch_size=batch_size, capacity=capacity,
  124. min_after_dequeue=6)
  125. with tf.Session() as sess:
  126. tf.local_variables_initializer().run()
  127. coord = tf.train.Coordinator()
  128. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  129. # combine
  130. for i in range(6):
  131. curr_exp_b, curr_lab_b = sess.run([example_batch, label_batch])
  132. print(curr_exp_b, curr_lab_b, "lll")
  133. coord.request_stop()
  134. coord.join(threads)
  135. # ************* use inceptionV3 data to generate data for training **************
  136. def write_record(name, image, label):
  137. writer = tf.python_io.TFRecordWriter(name)
  138. for index in range(len(image)):
  139. # convert img to str
  140. image_raw = image[index].tobytes()
  141. print(label[index])
  142. print(image[index].shape[0])
  143. print(image[index].shape[1])
  144. print(image[index].shape[2])
  145. # create Example Protocol Buffer
  146. example = tf.train.Example(features=tf.train.Features(feature={
  147. 'image': _bytes_feature(image_raw),
  148. 'label': _int64_feature(label[index]),
  149. 'height': _int64_feature(image[index].shape[0]),
  150. 'width': _int64_feature(image[index].shape[1]),
  151. 'channels': _int64_feature(image[index].shape[2]),
  152. }))
  153. writer.write(example.SerializeToString())
  154. writer.close()
  155. def generate_record(output_filename="output_flower.tfrecords"):
  156. input_data = "../inceptionv3/preprocess/validation_flower.npy"
  157. processed_data = np.load(input_data, allow_pickle=True)
  158. training_images = processed_data[0]
  159. training_labels = processed_data[1]
  160. input_data = "../inceptionv3/preprocess/test_flower.npy"
  161. processed_data = np.load(input_data, allow_pickle=True)
  162. validation_images = processed_data[0]
  163. validation_labels = processed_data[1]
  164. write_record("output_flower_train.tfrecord", training_images, training_labels)
  165. write_record("output_flower_validation.tfrecord", validation_images, validation_labels)
  166. print("training_images: " + str(len(training_labels)))
  167. print("validation_images: " + str(len(validation_labels)))
  168. def read_record(file_regex="record/output_flower_*.tfrecord"):
  169. files = tf.train.match_filenames_once(file_regex)
  170. filename_queue = tf.train.string_input_producer(files, shuffle=False)
  171. reader = tf.TFRecordReader()
  172. _, serialized_example = reader.read(filename_queue)
  173. features = tf.parse_single_example(
  174. serialized_example,
  175. features={
  176. 'image': tf.FixedLenFeature([], tf.string),
  177. 'label': tf.FixedLenFeature([], tf.int64),
  178. 'height': tf.FixedLenFeature([], tf.int64),
  179. 'width': tf.FixedLenFeature([], tf.int64),
  180. 'channels': tf.FixedLenFeature([], tf.int64)
  181. })
  182. image, label = features['image'], tf.cast(features['label'], tf.int32)
  183. height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
  184. channels = tf.cast(features['channels'], tf.int32)
  185. # image decoding
  186. decoded_img = tf.decode_raw(image, tf.float32)
  187. # decoded_img.set_shape(268203)
  188. decoded_img = tf.reshape(decoded_img,
  189. shape=[height, width, channels])
  190. return decoded_img, label
  191. def tfrecord_parser(record):
  192. features = tf.parse_single_example(
  193. record,
  194. features={
  195. 'image': tf.FixedLenFeature([], tf.string),
  196. 'label': tf.FixedLenFeature([], tf.int64),
  197. 'height': tf.FixedLenFeature([], tf.int64),
  198. 'width': tf.FixedLenFeature([], tf.int64),
  199. 'channels': tf.FixedLenFeature([], tf.int64)
  200. })
  201. image, label = features['image'], tf.cast(features['label'], tf.int32)
  202. height, width = tf.cast(features['height'], tf.int32), tf.cast(features['width'], tf.int32)
  203. channels = tf.cast(features['channels'], tf.int32)
  204. # image decoding
  205. decoded_img = tf.decode_raw(image, tf.uint8)
  206. # decoded_img.set_shape(268203)
  207. # decoded_img.set_shape([height, width, channels])
  208. decoded_img = tf.reshape(decoded_img,
  209. shape=[height, width, channels])
  210. return decoded_img, label
  211. # ** wrong image dtype may cause " Input to reshape is a tensor with xxx values, but the requested shape has xxx "
  212. # such as uint8 and float32, float32 is usually used for training, whereas uint8 more likely used for image storage
  213. # ** must have channel 3 but has channels 1 problem is caused by image preprocessing
  214. def process_data(doTrain=True):
  215. image_size = 28
  216. num_channels = 1
  217. num_of_labels = 10
  218. min_after_dequeue = 2000
  219. shuffle_buffer = 10000
  220. num_epochs = 50 # same effect as training_rounds
  221. batch_size = 500
  222. training_rounds = 5000
  223. training_images = 55000 # 362
  224. validation_images = 5000 # 367
  225. test_images = 10000
  226. train_files = tf.train.match_filenames_once("record/mnist_train.tfrecord")
  227. validation_files = tf.train.match_filenames_once("record/mnist_validation.tfrecord")
  228. test_files = tf.train.match_filenames_once("record/mnist_test.tfrecord")
  229. # ********** define neural network structure and forward propagation **********
  230. learning_rate_base = 0.8
  231. learning_rate_decay = 0.99
  232. regularization_rate = 0.0001
  233. moving_average_decay = 0.99
  234. x = tf.placeholder(tf.float32, [None,
  235. image_size,
  236. image_size,
  237. num_channels], name='x-input')
  238. y_ = tf.placeholder(tf.float32, [None], name='y-input')
  239. regularizer = tf.contrib.layers.l2_regularizer(regularization_rate)
  240. y = mnist_inference.inference(x, True, regularizer)
  241. global_step = tf.Variable(0, trainable=False)
  242. # moving average, cross entropy, loss function with regularization and learning rate
  243. variable_average = tf.train.ExponentialMovingAverage(moving_average_decay, global_step)
  244. variable_average_op = variable_average.apply(tf.trainable_variables())
  245. # calc loss
  246. cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.cast(y_, tf.int32))
  247. cross_entropy_mean = tf.reduce_mean(cross_entropy)
  248. loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
  249. learning_rate = tf.train.exponential_decay(
  250. learning_rate_base,
  251. global_step,
  252. training_images / batch_size,
  253. learning_rate_decay
  254. )
  255. train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)
  256. with tf.control_dependencies([train_step, variable_average_op]):
  257. train_op = tf.no_op(name='train')
  258. # define accuracy
  259. prediction = tf.argmax(y, 1)
  260. answer = tf.cast(y_, tf.int64)
  261. correct_prediction = tf.equal(tf.argmax(y, 1), tf.cast(y_, tf.int64))
  262. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  263. # test_result = list(range(int(training_rounds / 500)))
  264. # # ********** original tfrecord data operator **********
  265. # decoded_img, label = read_record("record/mnist_train.tfrecord")
  266. # # img preprocessing
  267. # # distorted_img = tf.image.resize_images(decoded_img, [image_size, image_size], method=0)
  268. # distorted_img = preprocessing.process_for_train(decoded_img, image_size, image_size, None, 1)
  269. # distorted_img.set_shape([image_size, image_size, num_channels])
  270. # # print(distorted_img.shape)
  271. #
  272. # # create batch
  273. # total_sample = training_images + validation_images
  274. # capacity = min_after_dequeue + batch_size * 3
  275. # image_batch, label_batch = tf.train.shuffle_batch([distorted_img, label], batch_size=batch_size,
  276. # capacity=capacity, num_threads=64,
  277. # min_after_dequeue=min_after_dequeue)
  278. # ********** tfrecord dataset **********
  279. dataset = tf.data.TFRecordDataset(train_files)
  280. dataset = dataset.map(tfrecord_parser)
  281. dataset = dataset.map(
  282. lambda image, label: (
  283. preprocessing.process_for_train(tf.image.convert_image_dtype(image, dtype=tf.float32), image_size,
  284. image_size, None, 1), label
  285. # tf.image.resize_images(tf.image.convert_image_dtype(image, dtype=tf.float32), [image_size, image_size]), label
  286. ))
  287. dataset = dataset.shuffle(shuffle_buffer).batch(batch_size)
  288. dataset = dataset.repeat(num_epochs)
  289. # match_filename_once has similar mechanism as placeholder
  290. iterator = dataset.make_initializable_iterator()
  291. image_batch, label_batch = iterator.get_next()
  292. # ********** validation dataset **********
  293. validation_dataset = tf.data.TFRecordDataset(validation_files)
  294. validation_dataset = validation_dataset.map(tfrecord_parser).map(
  295. lambda image, label: (
  296. tf.image.resize_images(tf.image.convert_image_dtype(image, dtype=tf.float32), [image_size, image_size]),
  297. label
  298. ))
  299. validation_dataset = validation_dataset.batch(validation_images)
  300. validation_dataset = validation_dataset.repeat(None)
  301. validation_iterator = validation_dataset.make_initializable_iterator()
  302. validation_image_batch, validation_label_batch = validation_iterator.get_next()
  303. # ********** test dataset **********
  304. test_dataset = tf.data.TFRecordDataset(test_files)
  305. test_dataset = test_dataset.map(tfrecord_parser).map(
  306. lambda image, label: (
  307. tf.image.resize_images(tf.image.convert_image_dtype(image, dtype=tf.float32), [image_size, image_size]),
  308. label
  309. ))
  310. test_dataset = test_dataset.batch(test_images)
  311. test_iterator = test_dataset.make_initializable_iterator()
  312. test_image_batch, test_label_batch = test_iterator.get_next()
  313. # logit = inference(image_batch)
  314. # loss = calc_loss(logit, label_batch)
  315. # train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
  316. # initialize persistence class
  317. saver = tf.train.Saver()
  318. config = tf.ConfigProto(allow_soft_placement=True)
  319. config.gpu_options.allow_growth = True
  320. with tf.Session(config=config) as sess:
  321. sess.run(tf.global_variables_initializer())
  322. sess.run(tf.local_variables_initializer())
  323. # print(sess.run(tf.cast(features['label'], tf.int32)))
  324. coord = tf.train.Coordinator()
  325. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  326. print("start training........")
  327. # for i in range(training_rounds):
  328. i = 0
  329. step = 0
  330. if doTrain:
  331. sess.run(iterator.initializer)
  332. sess.run(validation_iterator.initializer)
  333. while True:
  334. i += 1
  335. try:
  336. # img = sess.run(distorted_img)
  337. # plt.imshow(img)
  338. # plt.show()
  339. xs, ys = sess.run([image_batch, label_batch])
  340. # print(xs.shape)
  341. # print(ys.shape)
  342. _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
  343. if i % 200 == 0:
  344. vxs, vys = sess.run([validation_image_batch, validation_label_batch])
  345. p, a, accuracy_score = sess.run([prediction, answer, accuracy], feed_dict={x: vxs, y_: vys})
  346. print("prediction: \t%s, \nanswer: \t\t%s" % (p[0:10], a[0:10]))
  347. print("after %d steps, loss: %.3f, accuracy: %.3f" % (step, loss_value, accuracy_score))
  348. except tf.errors.OutOfRangeError:
  349. # i = step
  350. break
  351. sess.run(test_iterator.initializer)
  352. tp = []
  353. ta = []
  354. while True:
  355. try:
  356. txs, tys = sess.run([test_image_batch, test_label_batch])
  357. p, a = sess.run([prediction, answer], feed_dict={x: txs, y_: tys})
  358. tp.extend(p)
  359. ta.extend(a)
  360. except tf.errors.OutOfRangeError:
  361. break
  362. correct = [float(y == y_) for (y, y_) in zip(tp, ta)]
  363. accuracy_score = sum(correct) / len(correct)
  364. print("in total %d steps, total accuracy: %.3f" % (i, accuracy_score))
  365. try:
  366. os.mkdir("model/")
  367. except:
  368. print("directory already exist")
  369. saver.save(
  370. sess, os.path.join("model/", "model.ckpt"), global_step=global_step
  371. )
  372. else:
  373. ckpt = tf.train.get_checkpoint_state("model/")
  374. if ckpt and ckpt.model_checkpoint_path:
  375. sess.run(test_iterator.initializer)
  376. saver.restore(sess, ckpt.model_checkpoint_path)
  377. start = np.random.randint(int(test_images/3), int(test_images/2))
  378. length = 10
  379. txs, tys = sess.run([test_image_batch, test_label_batch])
  380. p, a = sess.run([prediction, answer], feed_dict={x: txs[start:start+length], y_: tys[start:start+length]})
  381. print("prediction: \t%s, \nanswer: \t\t%s" % (p, a))
  382. else:
  383. print("model not exist")
  384. coord.request_stop()
  385. coord.join(threads)
  386. # ************* dataset operation **************
  387. def parser(record):
  388. features = tf.parse_single_example(
  389. record,
  390. features={
  391. 'feat1': tf.FixedLenFeature([], tf.int64),
  392. 'feat2': tf.FixedLenFeature([], tf.int64),
  393. })
  394. return features['feat1'], features['feat2']
  395. def dataset_basic_test():
  396. # 从tensor构建数据集
  397. input_data = [1, 2, 3, 5, 8]
  398. dataset = tf.data.Dataset.from_tensor_slices(input_data)
  399. # traverse dataset
  400. iterator = dataset.make_one_shot_iterator()
  401. x = iterator.get_next()
  402. y = x * x
  403. # 从文本构建数据集
  404. # input_files = ["file1", "file2"]
  405. # dataset = tf.data.TextLineDataset(input_files)
  406. # 从tfrecord构建数据集
  407. input_files = ["file1", "file2"]
  408. dataset = tf.data.TFRecordDataset(input_files)
  409. # call parser and replace each element with returned value
  410. dataset = dataset.map(parser)
  411. # make_one_shot_iterator 所有参数必须确定, 使用placeholder需使用initializable_iterator
  412. # reinitializable_iterator, initialize multiple times for different data source
  413. # feedable_iterator, use feed_dict to assign iterators to run
  414. iterator = dataset.make_one_shot_iterator()
  415. feat1, feat2 = iterator.get_next()
  416. with tf.Session() as sess:
  417. # for i in range(len(input_data)):
  418. # print(sess.run(y))
  419. for i in range(10):
  420. f1, f2 = sess.run([feat1, feat2])
  421. # 从tfrecord构建数据集, placeholder
  422. input_files = tf.placeholder(tf.string)
  423. dataset = tf.data.TFRecordDataset(input_files)
  424. dataset = dataset.map(parser)
  425. iterator = dataset.make_initializable_iterator()
  426. feat1, feat2 = iterator.get_next()
  427. with tf.Session() as sess:
  428. sess.run(iterator.initializer, feed_dict={
  429. input_files: ["file1", "file2"]
  430. })
  431. while True:
  432. try:
  433. sess.run([feat1, feat2])
  434. except tf.errors.OutOfRangeError:
  435. break
  436. # dataset high level API
  437. image_size = 299
  438. buffer_size = 1000 # min_after_dequeue
  439. batch_size = 100
  440. N = 10 # num_epoch
  441. dataset = dataset.map(
  442. lambda x: preprocessing.process_for_train(x, image_size, image_size, None)
  443. )
  444. dataset = dataset.shuffle(buffer_size=buffer_size)
  445. dataset = dataset.batch(batch_size=batch_size)
  446. dataset = dataset.repeat(N)
  447. if __name__ == '__main__':
  448. # threads_mgmt()
  449. # generate_files()
  450. # read_files()
  451. # batch_example()
  452. # process_data()
  453. # generate_record()
  454. process_data(doTrain=False)
  455. # dataset_basic_test()