#環境
- Ubuntu 18.04
- Jupyter Notebook
- Python 3.6.8
- Tensorflow 1.12.0
解決したい問題
TFRecordに学習データを保存したが、それをパースできない
詳細
画像データをTFRecordに保存しようと考え、以下のようなコードを実行しました。
python
1writer = tf.python_io.TFRecordWriter('training.tfrecord') 2for image_path, label in zip(X_train, y_train): # (X_train, y_train) = (画像ファイルのパス, 整数のラベル) 3 image = cv2.imread(image_path) 4 image = cv2.resize(image, (150, 150)) / 255.0 # 画像サイズを揃えて、正規化してtfrecordに入れる 5 ex = tf.train.Example( 6 features = tf.train.Features( 7 feature={ 8 'image' : tf.train.Feature(float_list = tf.train.FloatList(value=image.ravel())), 9 'label' : tf.train.Feature(int64_list = tf.train.Int64List(value=[label])) 10 } 11 ) 12 ) 13 writer.write(ex.SerializeToString()) 14writer.close()
これを最も簡単な方法として、次のような方法で取り出したところうまく行きました。
python
1for record in tf.python_io.tf_record_iterator('test.tfrecord'): 2 example = tf.train.Example() 3 example.ParseFromString(record) 4 5 img = example.features.feature['image'].float_list.value 6 label = example.features.feature['label'].int64_list.value[0]
最後に取り出されたデータを画像に戻して表示すると
のようにうまく行っています。
しかしながら、これをモデルに入れるためにDataset APIを用いる形にするとうまく行きません。具体的には、
python
1# 関数をいくつか定義しておく 2def _parse_function(example_proto): 3 features = { 4 'label' : tf.FixedLenFeature((), tf.int64), 5 'image' : tf.FixedLenFeature((), tf.float32) 6 } 7 parsed_features = tf.parse_single_example(example_proto, features) 8 9 return parsed_features['image'], parsed_features['label'] 10 11def read_image(images, labels): 12 label = tf.cast(labels, tf.int32) 13 images = tf.cast(images, tf.float32) 14 image = tf.reshape(images, [150, 150, 3]) 15 16# データの読み込み 17dataset = tf.data.TFRecordDataset('training.tfrecord') 18dataset = dataset.map(_parse_function) 19dataset = dataset.map(read_image) # <- ここでエラー
read_imageを適用する段階でエラーが出ます。
bash
1ValueError: Cannot reshape a tensor with 1 elements to shape [150,150,3] (67500 elements) for 'Reshape' (op: 'Reshape') with input shapes: [], [3] and with input tensors computed as partial shapes: input[1] = [150,150,3].
これは画像をベクトルから[150, 150, 3]に変えられないということを言っていると思ったので、_parse_functionを適用直後の段階でdatasetの中身を確認してみると
bash
1<MapDataset shapes: ((), ()), types: (tf.float32, tf.int64)>
となっており、中身が全く入っていないことがわかりました。データをシャッフルしたりバッチ化して読み込みたいので、うまく行っている方ではなくこちらの方法を用いたいと考えているのですが、どのようにしたらデータを読み込むことができるでしょうか?
よろしくおねがいいたします。
あなたの回答
tips
プレビュー