前提・実現したいこと
GoogleColaboratoryでpix2pixを導入し、自作のデータセットを用いて機械学習を実現させようと取り組んでいます。あらかじめ作られているデータセットをダウンロードして学習させることはできたのですが、自作のデータセットで行うと以下のようなエラーが発生しました。このエラーの原因と対策について考えていきたいです。
自作のデータセットは、アプリケーションを用いてある動画を1フレーム毎区切りってモノクロにし、編集ソフトを用いて元の画像(jpg,256×256)とモノクロの画像(jpg,256×256)をつなぎ合わせて1枚の画像(jpg,512×256)にしたものです。
発生している問題・エラーメッセージ
100% 1/1 [00:00<00:00, 2.02it/s] Traceback (most recent call last): File "/usr/local/lib/python3.6/dist-packages/numpy/lib/shape_base.py", line 748, in array_split Nsections = len(indices_or_sections) + 1 TypeError: object of type 'float' has no len() During handling of the above exception, another exception occurred: Traceback (most recent call last): File "make_dataset.py", line 134, in <module> size=args.img_size) File "make_dataset.py", line 76, in build_HDF5 arr_chunks = np.array_split(np.arange(num_files), num_chunks) File "/usr/local/lib/python3.6/dist-packages/numpy/lib/shape_base.py", line 754, in array_split raise ValueError('number sections must be larger than 0.') ValueError: number sections must be larger than 0.
該当のソースコード
GoogleColaboratory
1!python make_dataset.py ../../datasets/facades/ 3 --img_size 256
※ソースコードを追加しました。(make_dataset.py)
Python
1import os 2import cv2 3import h5py 4import parmap 5import argparse 6import numpy as np 7from pathlib import Path 8from tqdm import tqdm as tqdm 9import matplotlib.pylab as plt 10 11 12def format_image(img_path, size, nb_channels): 13 """ 14 Load img with opencv and reshape 15 """ 16 17 if nb_channels == 1: 18 img = cv2.imread(img_path, 0) 19 img = np.expand_dims(img, axis=-1) 20 else: 21 img = cv2.imread(img_path) 22 img = img[:, :, ::-1] # GBR to RGB 23 24 w = img.shape[1] 25 26 # Slice image in 2 to get both parts 27 img_full = img[:, :w // 2, :] 28 img_sketch = img[:, w // 2:, :] 29 30 img_full = cv2.resize(img_full, (size, size), interpolation=cv2.INTER_AREA) 31 img_sketch = cv2.resize(img_sketch, (size, size), interpolation=cv2.INTER_AREA) 32 33 if nb_channels == 1: 34 img_full = np.expand_dims(img_full, -1) 35 img_sketch = np.expand_dims(img_sketch, -1) 36 37 img_full = np.expand_dims(img_full, 0).transpose(0, 3, 1, 2) 38 img_sketch = np.expand_dims(img_sketch, 0).transpose(0, 3, 1, 2) 39 40 return img_full, img_sketch 41 42 43def build_HDF5(jpeg_dir, nb_channels, data_dir, size=256): 44 """ 45 Gather the data in a single HDF5 file. 46 """ 47 48 data_dir = os.path.join(data_dir, 'processed') 49 50 # Put train data in HDF5 51 file_name = os.path.basename(jpeg_dir.rstrip("/")) 52 hdf5_file = os.path.join(data_dir, "%s_data.h5" % file_name) 53 with h5py.File(hdf5_file, "w") as hfw: 54 55 for dset_type in ["train", "test", "val"]: 56 57 list_img = [img for img in Path(jpeg_dir).glob('%s/*.jpg' % dset_type)] 58 list_img = [str(img) for img in list_img] 59 list_img.extend(list(Path(jpeg_dir).glob('%s/*.png' % dset_type))) 60 list_img = list(map(str, list_img)) 61 list_img = np.array(list_img) 62 63 data_full = hfw.create_dataset("%s_data_full" % dset_type, 64 (0, nb_channels, size, size), 65 maxshape=(None, 3, size, size), 66 dtype=np.uint8) 67 68 data_sketch = hfw.create_dataset("%s_data_sketch" % dset_type, 69 (0, nb_channels, size, size), 70 maxshape=(None, 3, size, size), 71 dtype=np.uint8) 72 73 num_files = len(list_img) 74 chunk_size = 100 75 num_chunks = num_files / chunk_size 76 arr_chunks = np.array_split(np.arange(num_files), num_chunks) 77 78 for chunk_idx in tqdm(arr_chunks): 79 80 list_img_path = list_img[chunk_idx].tolist() 81 output = parmap.map(format_image, list_img_path, size, nb_channels, pm_parallel=False) 82 83 arr_img_full = np.concatenate([o[0] for o in output], axis=0) 84 arr_img_sketch = np.concatenate([o[1] for o in output], axis=0) 85 86 # Resize HDF5 dataset 87 data_full.resize(data_full.shape[0] + arr_img_full.shape[0], axis=0) 88 data_sketch.resize(data_sketch.shape[0] + arr_img_sketch.shape[0], axis=0) 89 90 data_full[-arr_img_full.shape[0]:] = arr_img_full.astype(np.uint8) 91 data_sketch[-arr_img_sketch.shape[0]:] = arr_img_sketch.astype(np.uint8) 92 93def check_HDF5(jpeg_dir, nb_channels): 94 """ 95 Plot images with landmarks to check the processing 96 """ 97 98 # Get hdf5 file 99 file_name = os.path.basename(jpeg_dir.rstrip("/")) 100 hdf5_file = os.path.join(data_dir, "%s_data.h5" % file_name) 101 102 with h5py.File(hdf5_file, "r") as hf: 103 data_full = hf["train_data_full"] 104 data_sketch = hf["train_data_sketch"] 105 for i in range(data_full.shape[0]): 106 plt.figure() 107 img = data_full[i, :, :, :].transpose(1,2,0) 108 img2 = data_sketch[i, :, :, :].transpose(1,2,0) 109 img = np.concatenate((img, img2), axis=1) 110 if nb_channels == 1: 111 plt.imshow(img[:, :, 0], cmap="gray") 112 else: 113 plt.imshow(img) 114 plt.show() 115 plt.clf() 116 plt.close() 117 118 119if __name__ == '__main__': 120 121 parser = argparse.ArgumentParser(description='Build dataset') 122 parser.add_argument('jpeg_dir', type=str, help='path to jpeg images') 123 parser.add_argument('nb_channels', type=int, help='number of image channels') 124 parser.add_argument('--img_size', default=256, type=int, 125 help='Desired Width == Height') 126 parser.add_argument('--do_plot', action="store_true", 127 help='Plot the images to make sure the data processing went OK') 128 parser.add_argument('--data_dir', default='../../data', type=str, help='Data directory') 129 args = parser.parse_args() 130 131 build_HDF5(args.jpeg_dir, 132 args.nb_channels, 133 args.data_dir, 134 size=args.img_size) 135 136 if args.do_plot: 137 check_HDF5(args.jpeg_dir, args.nb_channels) 138 139
試したこと
データセットの作り方に問題があると考え調べてみたところ、データセットの画像には何かしら意味付けがされているのではないか、単に編集アプリで作るだけではいけないのではないかと考えました。が、理解が浅く、どうすればよいのか分からない状態です。
補足情報(FW/ツールのバージョンなど)
自分はPythonと機械学習、共に初心者です。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2019/10/09 13:07
2019/10/10 07:05 編集
退会済みユーザー
2019/10/10 11:47
2019/10/10 14:26