質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

Q&A

0回答

2071閲覧

expand_dimsにより、(width, height, depth) から (batch, width, height, depth)に配列を変更する理由

退会済みユーザー

退会済みユーザー

総合スコア0

Python 3.x

Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

0グッド

0クリップ

投稿2017/04/20 02:43

expand_dimsにより、(width, height, depth) から (batch, width, height, depth)に配列を変更する理由がわかりません。
現在、以下のサイトを見て勉強しています。
http://www.buildinsider.net/small/booktensorflow/0204

コードは

FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_integer('epoch', 30, "訓練するEpoch数") tf.app.flags.DEFINE_string('data_dir', './data/', "訓練データのディレクトリ") tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoints/', "チェックポイントを保存するディレクトリ") tf.app.flags.DEFINE_string('test_data', None, "テストデータのパス") def main(argv=None): global_step = tf.Variable(0, trainable=False) train_placeholder = tf.placeholder(tf.float32, shape=[32, 32, 3], name='input_image') label_placeholder = tf.placeholder(tf.int32, shape=[1], name='label') # (width, height, depth) -> (batch, width, height, depth) image_node = tf.expand_dims(train_placeholder, 0) logits = model.inference(image_node) total_loss = _loss(logits, label_placeholder) train_op = _train(total_loss, global_step) top_k_op = tf.nn.in_top_k(logits, label_placeholder, 1) with tf.Session() as sess: sess.run(tf.initialize_all_variables()) total_duration = 0 for epoch in range(1, FLAGS.epoch + 1): start_time = time.time() for file_index in range(5): print('Epoch %d: %s' % (epoch, filenames[file_index])) reader = Cifar10Reader(filenames[file_index]) for index in range(10000): image = reader.read(index) _, loss_value = sess.run([train_op, total_loss], feed_dict={ train_placeholder: image.byte_array, label_placeholder: image.label }) assert not np.isnan(loss_value), \ 'Model diverged with loss = NaN' reader.close() duration = time.time() - start_time total_duration += duration prediction = _eval(sess, top_k_op, train_placeholder, label_placeholder) print('epoch %d duration = %d sec, prediction = %.3f' % (epoch, duration, prediction)) tf.train.SummaryWriter(FLAGS.checkpoint_dir, sess.graph) print('Total duration = %d sec' % total_duration)

のようになっています。
その中に

# (width, height, depth) -> (batch, width, height, depth) image_node = tf.expand_dims(train_placeholder, 0)

と書かれていて、最初 (width, height, depth) の3つの要素しか持っていなかったのにそこにbatchも追加している理由がわかりません。
batchは一度に何枚の画像を学習させるかを決める要素で、画像の読み込みの情報(縦・横・高さ)とは別のカテゴリーではと思います。
でもここでは(batch, width, height, depth)の要素を4つまとめて渡していて、batchと画像の読み込みの情報を一緒の要素に持たせることが一般的なやり方なのでしょうか?
なぜここでexpand_dimsにより、(width, height, depth) から (batch, width, height, depth)に配列を変更しているのでしょうか?

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問