image data generatorを利用してデータ拡張を実施しています。
その出力したデータをnumpy形式に保存したいと思って以下のコードを記載しました。
以下で保存した場合、5次元配列(15000,1,28,28,1)=(要素数、元のnumpyファイルのバッチ数、縦ピクセル、横ピクセル、チャネル数)になってしまいます。
これを15000,28,28,1の4次元配列で保存させたいのですが、方法ございますでしょうか?
Python
1def onehot_to_str(label): 2 """ 3 ワンホットベクトル形式のラベルをカタカナ文字に変換する 4 """ 5 dic_katakana = {"a":0,"i":1,"u":2,"e":3,"o":4,"ka":5,"ki":6,"ku":7,"ke":8,"ko":9,"sa":10,"si":11,"su":12,"se":13,"so":14} 6 label_int = np.argmax(label) 7 for key, value in dic_katakana.items(): 8 if value==label_int: 9 return key 10X=[] 11y=[] 12data = np.load("../1_data/katakana/ImageDataGenerator/original/train_data.npy") 13for i in range(len(data)): 14 # 画像読み込み 15 data = np.load("../1_data/katakana/ImageDataGenerator/original/train_data.npy") # パスは適宜変更すること 16 label = np.load("../1_data/katakana/ImageDataGenerator/original/train_label.npy") # パス 17 data = data[i:i+1] 18 label = label[i:i+1] 19 label_katakana = onehot_to_str(label) 20 21 # 軸をN,H,W,Cに入れ替え 22 data = data.transpose(0,2,3,1) 23 24 # ImageDataGeneratorのオブジェクト生成 25 datagen = ImageDataGenerator( 26 rescale=1./255, 27 width_shift_range=0.2, 28 height_shift_range=0.2, 29 zoom_range = 0.1, 30 shear_range = 0.2, 31 ) 32 33 # 生成後枚数 34 num_image = 5 35 36 # 生成 37 g = datagen.flow(data, save_to_dir="../1_data/imagedatagenerator/" +str(label_katakana), 38 save_format='png', save_prefix='out_%s_from_npy_'%label_katakana) 39 for i in range(num_image): 40 batches = g.next() 41 X.append(batches) 42 y.append(label) 43 print(batches.shape) 44X=np.array(X) 45Y=np.array(Y)
回答1件
あなたの回答
tips
プレビュー