質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.50%

  • Python 3.x

    9816questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。

expand_dimsにより、(width, height, depth) から (batch, width, height, depth)に配列を変更する理由

受付中

回答 0

投稿

  • 評価
  • クリップ 0
  • VIEW 938
退会済みユーザー

退会済みユーザー

expand_dimsにより、(width, height, depth) から (batch, width, height, depth)に配列を変更する理由がわかりません。
現在、以下のサイトを見て勉強しています。
http://www.buildinsider.net/small/booktensorflow/0204

コードは

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('epoch', 30, "訓練するEpoch数")
tf.app.flags.DEFINE_string('data_dir', './data/', "訓練データのディレクトリ")
tf.app.flags.DEFINE_string('checkpoint_dir', './checkpoints/',
                          "チェックポイントを保存するディレクトリ")
tf.app.flags.DEFINE_string('test_data', None, "テストデータのパス")

def main(argv=None):
  global_step = tf.Variable(0, trainable=False)

  train_placeholder = tf.placeholder(tf.float32,
                                     shape=[32, 32, 3],
                                     name='input_image')
  label_placeholder = tf.placeholder(tf.int32, shape=[1], name='label')

  # (width, height, depth) -> (batch, width, height, depth)
  image_node = tf.expand_dims(train_placeholder, 0)

  logits = model.inference(image_node)
  total_loss = _loss(logits, label_placeholder)
  train_op = _train(total_loss, global_step)

  top_k_op = tf.nn.in_top_k(logits, label_placeholder, 1)

  with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    total_duration = 0

    for epoch in range(1, FLAGS.epoch + 1):
      start_time = time.time()

      for file_index in range(5):
        print('Epoch %d: %s' % (epoch, filenames[file_index]))
        reader = Cifar10Reader(filenames[file_index])

        for index in range(10000):
          image = reader.read(index)

          _, loss_value = sess.run([train_op, total_loss],
                                   feed_dict={
                                     train_placeholder: image.byte_array,
                                     label_placeholder: image.label
                                   })

          assert not np.isnan(loss_value), \
            'Model diverged with loss = NaN'

        reader.close()

      duration = time.time() - start_time
      total_duration += duration

      prediction = _eval(sess, top_k_op,
                         train_placeholder, label_placeholder)
      print('epoch %d duration = %d sec, prediction = %.3f'
            % (epoch, duration, prediction))

      tf.train.SummaryWriter(FLAGS.checkpoint_dir, sess.graph)

    print('Total duration = %d sec' % total_duration)


のようになっています。
その中に

  # (width, height, depth) -> (batch, width, height, depth)
  image_node = tf.expand_dims(train_placeholder, 0)


と書かれていて、最初 (width, height, depth) の3つの要素しか持っていなかったのにそこにbatchも追加している理由がわかりません。
batchは一度に何枚の画像を学習させるかを決める要素で、画像の読み込みの情報(縦・横・高さ)とは別のカテゴリーではと思います。
でもここでは(batch, width, height, depth)の要素を4つまとめて渡していて、batchと画像の読み込みの情報を一緒の要素に持たせることが一般的なやり方なのでしょうか?
なぜここでexpand_dimsにより、(width, height, depth) から (batch, width, height, depth)に配列を変更しているのでしょうか?

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

まだ回答がついていません

同じタグがついた質問を見る

  • Python 3.x

    9816questions

    Python 3はPythonプログラミング言語の最新バージョンであり、2008年12月3日にリリースされました。