# tfRecordExample.py (3.6 KB)

import os
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Each TFRecord entry is a serialized tf.train.Example protocol buffer:
#   message Example  { Features features = 1; }
#   message Features { map<string, Feature> feature = 1; }
#   message Feature  { oneof kind { BytesList bytes_list = 1; FloatList float_list = 2; Int64List int64_list = 3; } }
  10. def _int64_feature(value):
  11. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  12. def _bytes_feature(value):
  13. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  14. def save_mnist_record(dataset=0, output_filename="record/output_mnist.tfrecords"):
  15. mnist = input_data.read_data_sets("../MNIST_data", dtype=tf.uint8, one_hot=True)
  16. images = []
  17. labels = []
  18. num_examples=0
  19. if dataset == 0:
  20. images = mnist.train.images
  21. labels = mnist.train.labels
  22. num_examples = mnist.train.num_examples
  23. elif dataset == 1:
  24. images = mnist.validation.images
  25. labels = mnist.validation.labels
  26. num_examples = mnist.validation.num_examples
  27. elif dataset == 2:
  28. images = mnist.test.images
  29. labels = mnist.test.labels
  30. num_examples = mnist.test.num_examples
  31. print(num_examples)
  32. # define resolution
  33. # pixels = images.shape[1]
  34. # print(images[0].shape)
  35. writer = tf.python_io.TFRecordWriter(output_filename)
  36. for index in range(num_examples):
  37. # convert img to str
  38. image_raw = images[index].tostring()
  39. # create Example Protocol Buffer
  40. # example = tf.train.Example(features=tf.train.Features(feature={
  41. # 'pixels': _int64_feature(pixels),
  42. # 'label': _int64_feature(np.argmax(labels[index])),
  43. # 'image_raw': _bytes_feature(image_raw)
  44. # }))
  45. example = tf.train.Example(features=tf.train.Features(feature={
  46. 'image': _bytes_feature(image_raw),
  47. 'label': _int64_feature(np.argmax(labels[index])),
  48. 'height': _int64_feature(28),
  49. 'width': _int64_feature(28),
  50. 'channels': _int64_feature(1),
  51. }))
  52. writer.write(example.SerializeToString())
  53. writer.close()
  54. def read_mnist_record(input_filename="output_mnist.tfrecords"):
  55. reader = tf.TFRecordReader()
  56. filename_queue = tf.train.string_input_producer([input_filename])
  57. # read an example
  58. _, serialized_example = reader.read(filename_queue)
  59. # resolve the example
  60. features = tf.parse_single_example(
  61. serialized_example,
  62. features={
  63. # tf.FixedLenFeature return a Tensor
  64. # tf.VarLenFeature return a SparseTensor
  65. 'pixels': tf.FixedLenFeature([], tf.int64),
  66. 'label': tf.FixedLenFeature([], tf.int64),
  67. 'image_raw': tf.FixedLenFeature([], tf.string)
  68. }
  69. )
  70. # convert from str to img
  71. image = tf.decode_raw(features['image_raw'], tf.uint8)
  72. label = tf.cast(features['label'], tf.int32)
  73. pixels = tf.cast(features['pixels'], tf.int32)
  74. sess = tf.Session()
  75. coord = tf.train.Coordinator()
  76. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  77. for i in range(10):
  78. img, answer, num_pixels = sess.run([image, label, pixels])
  79. print("answer: %d, num of pixels: %d" % (answer, num_pixels))
  80. plt.imshow(img.reshape(28, 28))
  81. plt.show()
  82. time.sleep(3)
  83. def main():
  84. save_mnist_record(0, "record/mnist_train.tfrecord")
  85. save_mnist_record(1, "record/mnist_validation.tfrecord")
  86. save_mnist_record(2, "record/mnist_test.tfrecord")
  87. # read_mnist_record()
  88. if __name__ == '__main__':
  89. main()