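I am running this reinforcement learning code on Colab. To train from scratch, I am trying to train one part of the model (the VAE), but the following error occurs. The Colab notebook is here (the error occurs in cell 17). I looked up what the ValueError means, but I still don't really understand it. Could you tell me what the problem is?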
Code I tried to run:
#Training VAE
!python3 -m worldmodels.vision.train_vae --load_model 0 --data local
Output and error:
local files that contain episode in controller-rollouts found 0 files
2020-10-11 04:53:35.446580: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-10-11 04:53:35.450185: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2200000000 Hz
2020-10-11 04:53:35.450371: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x1da6a00 executing computations on platform Host. Devices:
2020-10-11 04:53:35.450401: I tensorflow/compiler/xla/service/service.cc:175]   StreamExecutor device (0): <undefined>, <undefined>
results_dir /content/world-models-experiments/vae-training/vae-training
cli
------
Namespace(data='local', dataset='random-rollouts', epochs=10, load_model='0', log_every=100, save_every=1000)
training params
------
{'model': <worldmodels.vision.vae.VAE object at 0x7f1e19900dd8>, 'epochs': 10, 'batch_size': 256, 'log_every': 100, 'save_every': 1000, 'records': []}
vision params
------
{'latent_dim': 32, 'learning_rate': 0.0001, 'load_model': False, 'results_dir': '/content/world-models-experiments/vae-training'}
Traceback (most recent call last):
  File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/content/world-models/worldmodels/vision/train_vae.py", line 117, in <module>
    train(**training_params)
  File "/content/world-models/worldmodels/vision/train_vae.py", line 21, in train
    dataset = shuffle_samples(parse_episode, records, batch_size)
  File "/content/world-models/worldmodels/data/tf_records.py", line 71, in shuffle_samples
    cycle_length=num_cpu
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 1256, in interleave
    block_length, num_parallel_calls)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 3367, in __init__
    map_func, self._transformation_name(), dataset=input_dataset)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 2591, in __init__
    self._function = wrapper_fn._get_concrete_function_internal()
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1366, in _get_concrete_function_internal
    *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1360, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1648, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1541, in _create_graph_function
    capture_by_value=self._capture_by_value),
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/func_graph.py", line 716, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 2585, in wrapper_fn
    ret = _wrapper_helper(*args)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 2530, in _wrapper_helper
    ret = func(*nested_args)
  File "/content/world-models/worldmodels/data/tf_records.py", line 69, in <lambda>
    lambda x: tf.data.TFRecordDataset(x),
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/readers.py", line 295, in __init__
    filenames = _create_or_validate_filenames_dataset(filenames)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/readers.py", line 57, in _create_or_validate_filenames_dataset
    filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1100, in convert_to_tensor
    return convert_to_tensor_v2(value, dtype, preferred_dtype, name)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1158, in convert_to_tensor_v2
    as_ref=False)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1237, in internal_convert_to_tensor
    ret = conversion_func(value, dtype=dtype, name=name, as_ref=as_ref)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 1036, in _TensorTensorConversionFunction
    (dtype.name, t.dtype.name, str(t)))
ValueError: Tensor conversion requested dtype string for Tensor with dtype float32: 'Tensor("args_0:0", shape=(), dtype=float32)'
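Looking at the log again, "found 0 files" on the first line and 'records': [] in the training params suggest that the list of TFRecord files passed to shuffle_samples was empty. If tf_records.py builds its file-name dataset directly from that list (I have not confirmed this in the repo), TensorFlow would likely infer an empty Python list as float32 rather than string, which would match the float32 'args_0:0' tensor that TFRecordDataset rejects. Below is a minimal sketch of a fail-fast check, assuming the rollout files sit in a single directory and contain "episode" in their names. The helper find_episode_records and that directory layout are hypothetical and are not part of the worldmodels repo.

```python
import glob
import os


def find_episode_records(data_dir, keyword="episode"):
    """Hypothetical helper: list files in data_dir whose name contains keyword.

    Raises FileNotFoundError instead of returning an empty list, so a missing
    dataset is reported up front rather than as a dtype ValueError inside tf.data.
    """
    # Match any file in data_dir with the keyword somewhere in its name.
    pattern = os.path.join(data_dir, "*" + keyword + "*")
    records = sorted(glob.glob(pattern))
    if not records:
        raise FileNotFoundError(
            "no files containing %r found in %s; generate or copy the "
            "rollout data before training the VAE" % (keyword, data_dir)
        )
    return records
```

Calling a check like this before shuffle_samples would turn the confusing dtype error into a direct "no files found" message pointing at the directory that was searched.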