tfRecordExample.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. import tensorflow as tf
  2. from tensorflow.examples.tutorials.mnist import input_data
  3. import numpy as np
  4. import matplotlib.pyplot as plt
  5. import time
  6. # tfRecord defined by tf.train.Example (Protocol Buffer)
  7. # message Example{ Features features=1;}
  8. # message Features{ map<string, Feature> feature=1; }
  9. # message Feature{oneof kind{ BytesList bytes_list=1; FloatList float_list=2; Int64List int64_list=3;}}
  10. def _int64_feature(value):
  11. return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  12. def _bytes_feature(value):
  13. return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  14. def save_mnist_record(output_filename="output_mnist.tfrecords"):
  15. mnist = input_data.read_data_sets("../MNIST_data", dtype=tf.uint8, one_hot=True)
  16. images = mnist.train.images
  17. labels = mnist.train.labels
  18. # define resolution
  19. pixels = images.shape[1]
  20. num_examples = mnist.train.num_examples
  21. writer = tf.python_io.TFRecordWriter(output_filename)
  22. for index in range(num_examples):
  23. # convert img to str
  24. image_raw = images[index].tostring()
  25. # create Example Protocol Buffer
  26. example = tf.train.Example(features=tf.train.Features(feature={
  27. 'pixels': _int64_feature(pixels),
  28. 'label': _int64_feature(np.argmax(labels[index])),
  29. 'image_raw': _bytes_feature(image_raw)
  30. }))
  31. writer.write(example.SerializeToString())
  32. writer.close()
  33. def read_mnist_record(input_filename="output_mnist.tfrecords"):
  34. reader = tf.TFRecordReader()
  35. filename_queue = tf.train.string_input_producer([input_filename])
  36. # read an example
  37. _, serialized_example = reader.read(filename_queue)
  38. # resolve the example
  39. features = tf.parse_single_example(
  40. serialized_example,
  41. features={
  42. # tf.FixedLenFeature return a Tensor
  43. # tf.VarLenFeature return a SparseTensor
  44. 'pixels': tf.FixedLenFeature([], tf.int64),
  45. 'label': tf.FixedLenFeature([], tf.int64),
  46. 'image_raw': tf.FixedLenFeature([], tf.string)
  47. }
  48. )
  49. # convert from str to img
  50. image = tf.decode_raw(features['image_raw'], tf.uint8)
  51. label = tf.cast(features['label'], tf.int32)
  52. pixels = tf.cast(features['pixels'], tf.int32)
  53. sess = tf.Session()
  54. coord = tf.train.Coordinator()
  55. threads = tf.train.start_queue_runners(sess=sess, coord=coord)
  56. for i in range(10):
  57. img, answer, num_pixels = sess.run([image, label, pixels])
  58. print("answer: %d, num of pixels: %d" % (answer, num_pixels))
  59. plt.imshow(img.reshape(28, 28))
  60. plt.show()
  61. time.sleep(3)
  62. def main():
  63. # save_mnist_record()
  64. read_mnist_record()
  65. if __name__ == '__main__':
  66. main()