質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

ただいまの
回答率

90.11%

batch(32)はデータ配列に対してどのようなことを行っているのでしょうか。

受付中

回答 0

投稿

  • 評価
  • クリップ 0
  • VIEW 267

dendenmushi

score 55

前提・実現したいこと

csvデータをpythonでライブラリnumpyやテンソルフローデータセットなどでパースしたい。
例えばbatch(4)であれば4つずつデータを分けていくはずなのですが、なぜbatch(32)としているのか、またそれをコメントアウトすることでなぜ以下のようなエラーが出るのか理由が明確にわからなくアドバイス頂けないでしょうか。

発生している問題・エラーメッセージ

tf.Tensor(7.7, shape=(), dtype=float32)
tf.Tensor(2.6, shape=(), dtype=float32)
tf.Tensor(2.3, shape=(), dtype=float32)

---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-287-3286b6a94e0f> in <module>()
      3 print(features[1])
      4 print(features[3])
----> 5 print(label[0])

C:\ProgramData\Anaconda3\envs\tensorflow_hajimete\lib\site-packages\tensorflow\python\ops\array_ops.py in _slice_helper(tensor, slice_spec, var)
    523         ellipsis_mask=ellipsis_mask,
    524         var=var,
--> 525         name=name)
    526 
    527 

C:\ProgramData\Anaconda3\envs\tensorflow_hajimete\lib\site-packages\tensorflow\python\ops\array_ops.py in strided_slice(input_, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, var, name)
    689       ellipsis_mask=ellipsis_mask,
    690       new_axis_mask=new_axis_mask,
--> 691       shrink_axis_mask=shrink_axis_mask)
    692 
    693   parent_name = name

C:\ProgramData\Anaconda3\envs\tensorflow_hajimete\lib\site-packages\tensorflow\python\ops\gen_array_ops.py in strided_slice(input, begin, end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask, shrink_axis_mask, name)
  10555       else:
  10556         message = e.message
> 10557       _six.raise_from(_core._status_to_exception(e.code, message), None)
  10558 
  10559 

C:\ProgramData\Anaconda3\envs\tensorflow_hajimete\lib\site-packages\six.py in raise_from(value, from_value)

InvalidArgumentError: Index out of range using input dim 0; input has only 0 dims [Op:StridedSlice] name: strided_slice/

該当のソースコード

#IRIS prediction by TensorFlow 
#https://www.tensorflow.org/get_started/eager
from __future__ import absolute_import, division, print_function

import os
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.contrib.eager as tfe


## TensorFlowのバージョンチェック
tf.enable_eager_execution()

print("TensorFlow version: {}".format(tf.VERSION))
print("Eager execution: {}".format(tf.executing_eagerly()))



## 訓練データ(CSV)を指定URLからダウンロード
train_dataset_url = "http://download.tensorflow.org/data/iris_training.csv"

train_dataset_fp = tf.keras.utils.get_file(fname=os.path.basename(train_dataset_url),origin=train_dataset_url)

print("Local copy of the dataset file: {}".format(train_dataset_fp))


## ダウンロードしたデータを整形
def parse_csv(line):
  example_defaults = [[0.], [0.], [0.], [0.], [0]]  # sets field types
  parsed_line = tf.decode_csv(line, example_defaults)
  # First 4 fields are features, combine into single tensor
  features = tf.reshape(parsed_line[:-1], shape=(4,))
  # Last field is the label
  label = tf.reshape(parsed_line[-1], shape=())
  return features, label


## Create the training tf.data.Dataset
train_dataset = tf.data.TextLineDataset(train_dataset_fp)
train_dataset = train_dataset.skip(1)             # skip the first header row
train_dataset = train_dataset.map(parse_csv)      # parse each row
train_dataset = train_dataset.shuffle(buffer_size=1000)  # randomize
#train_dataset = train_dataset.batch(32)

## View a single example entry from a batch
features, label = tfe.Iterator(train_dataset).next()
print(features[0])
print(features[1])
print(features[3])
print(label[0])

試したこと

.batch(32)の行のコメントアウトをはずした場合の出力は以下でした。

tf.Tensor([6.7 3.1 5.6 2.4], shape=(4,), dtype=float32)
tf.Tensor([5.4 3.9 1.3 0.4], shape=(4,), dtype=float32)
tf.Tensor([4.8 3.  1.4 0.3], shape=(4,), dtype=float32)
tf.Tensor(2, shape=(), dtype=int32)


また、以下別サイト説明からも32区切りになるとは思うのですが、なぜ32なのか疑問です。
ここでは4つの葉の長さとそのラベルとして1つを足した5つのデータごとに区切り1つのレコードになっているはずです。
(別サイト説明)
イメージ説明

補足情報(FW/ツールのバージョンなど)

win10
python 3.6
TensorFlow 1.12.0

  • 気になる質問をクリップする

    クリップした質問は、後からいつでもマイページで確認できます。

    またクリップした質問に回答があった際、通知やメールを受け取ることができます。

    クリップを取り消します

  • 良い質問の評価を上げる

    以下のような質問は評価を上げましょう

    • 質問内容が明確
    • 自分も答えを知りたい
    • 質問者以外のユーザにも役立つ

    評価が高い質問は、TOPページの「注目」タブのフィードに表示されやすくなります。

    質問の評価を上げたことを取り消します

  • 評価を下げられる数の上限に達しました

    評価を下げることができません

    • 1日5回まで評価を下げられます
    • 1日に1ユーザに対して2回まで評価を下げられます

    質問の評価を下げる

    teratailでは下記のような質問を「具体的に困っていることがない質問」、「サイトポリシーに違反する質問」と定義し、推奨していません。

    • プログラミングに関係のない質問
    • やってほしいことだけを記載した丸投げの質問
    • 問題・課題が含まれていない質問
    • 意図的に内容が抹消された質問
    • 広告と受け取られるような投稿

    評価が下がると、TOPページの「アクティブ」「注目」タブのフィードに表示されにくくなります。

    質問の評価を下げたことを取り消します

    この機能は開放されていません

    評価を下げる条件を満たしてません

    評価を下げる理由を選択してください

    詳細な説明はこちら

    上記に当てはまらず、質問内容が明確になっていない質問には「情報の追加・修正依頼」機能からコメントをしてください。

    質問の評価を下げる機能の利用条件

    この機能を利用するためには、以下の事項を行う必要があります。

まだ回答がついていません

15分調べてもわからないことは、teratailで質問しよう!

  • ただいまの回答率 90.11%
  • 質問をまとめることで、思考を整理して素早く解決
  • テンプレート機能で、簡単に質問をまとめられる

同じタグがついた質問を見る