#前提・実現したいこと
Variational Autoencoderを使った画像の異常検知 後編 (塩尻MLもくもく会#7)
上記のURLのサイト様のコードを参考にVAEとオリジナルデータセットを用いて異常検知を行いたいと考えております。
#発生している問題・エラーメッセージ
Python
1------------------------------------------------------------------------- 2ValueError Traceback (most recent call last) 3<ipython-input-16-68779fe8e4c7> in <module> 4 5 5 6 x_train_b = np.array(x_train_b) 6----> 7 x_train_b = cut_img(x_train_b, 80) 7 8 print("train data:",len(x_train_b)) 8 9<ipython-input-14-cf808f629f5c> in cut_img(x, number, height, width) 10 5 11 6 for i in range(number): 12----> 7 shape_0 = np.random.randint(0,x_shape[0]) 13 8 shape_1 = np.random.randint(0,x_shape[1]-height) 14 9 shape_2 = np.random.randint(0,x_shape[2]-width) 15 16mtrand.pyx in numpy.random.mtrand.RandomState.randint() 17 18_bounded_integers.pyx in numpy.random._bounded_integers._rand_int64() 19 20ValueError: low >= high
#コード
Python
1from __future__ import absolute_import 2from __future__ import division 3from __future__ import print_function 4 5from keras.layers import Lambda, Input, Dense, Reshape 6from keras.models import Model 7from keras.losses import mse 8from keras.utils import plot_model 9from keras import backend as K 10from keras.layers import BatchNormalization, Activation, Flatten 11from keras.layers.convolutional import Conv2DTranspose, Conv2D 12 13import numpy as np 14import matplotlib.pyplot as plt 15import matplotlib.colors as colors 16import os 17from sklearn import metrics 18 19os.chdir('/Users/user_name/desktop/VAE') 20os.getcwd() 21 22def result_score(model, x, name, height=80, width=80, move=2): 23 score = [] 24 25 for k in range(len(x)): 26 max_score = -1000000000 27 if k%100 == 0: 28 print(k) 29 30 for i in range(int((x.shape[1]-height)/move)+1): 31 for j in range(int((x.shape[2]-width)/move)+1): 32 x_sub = x[k, i*move:i*move+height, j*move:j*move+width, 0] 33 x_sub = x_sub.reshape(1, height, width, 1) 34 35 #従来手法 36 if name == "old_": 37 #スコア 38 temp_score = model.evaluate(x_sub, batch_size=1, verbose=0) 39 if temp_score > max_score: 40 max_score = temp_score 41 42 #提案手法 43 else: 44 #スコア 45 mu, sigma = model.predict(x_sub, batch_size=1, verbose=0) 46 loss = 0 47 for o in range(height): 48 for l in range(width): 49 loss += 0.5 * (x_sub[0,o,l,0] - mu[0,o,l,0])**2 / sigma[0,o,l,0] 50 if loss > max_score: 51 max_score = loss 52 53 score.append(max_score) 54 55 return(score) 56 57def cut_img(x, number, height=80, width=80): 58 print("cutting images ...") 59 x_out = [] 60 x_shape = x.shape 61 62 for i in range(number): 63 shape_0 = np.random.randint(0,x_shape[0]) 64 shape_1 = np.random.randint(0,x_shape[1]-height) 65 shape_2 = np.random.randint(0,x_shape[2]-width) 66 temp = x[shape_0, shape_1:shape_1+height, shape_2:shape_2+width, 0] 67 x_out.append(temp.reshape((height, width, x_shape[3]))) 68 69 print("Complete.") 70 x_out = np.array(x_out) 71 72 return x_out 73 74# reparameterization trick 75# instead of sampling from Q(z|X), sample eps = N(0,I) 76# z = z_mean + sqrt(var)*eps 77def sampling(args): 78 z_mean, z_log_var = args 79 batch = K.shape(z_mean)[0] 80 dim = K.int_shape(z_mean)[1] 81 # by default, random_normal has mean=0 and std=1.0 82 epsilon = K.random_normal(shape=(batch, dim)) 83 return z_mean + K.exp(0.5 * z_log_var) * epsilon 84 85# dataset 86from bcn_dataset import BCN_Dataset2 87(x_train, y_train), (x_test, y_test) = BCN_Dataset2.create_bcn() 88 89x_train = x_train.reshape(x_train.shape[0], 224, 224, 3) 90x_test = x_test.reshape(x_test.shape[0], 224, 224, 3) 91 92x_train = x_train.astype('float32') / 255 93x_test = x_test.astype('float32') / 255 94 95x_train_b = [] 96x_test_b = [] 97x_test_n = [] 98 99x_train_shape = x_train.shape 100 101#以下のコードを実行した際にエラーが発生します 102for i in range(len(x_train)): 103 if y_train[i] == 1:#スニーカーは7 104 temp = x_train[i,:,:,:] 105 np.append(x_train_b,temp.reshape((x_train_shape[1],x_train_shape[2],x_train_shape[3]))) 106 107x_train_b = np.array(x_train_b) 108x_train_b = cut_img(x_train_b, 50) 109print("train data:",len(x_train_b))
#試していること
x_trainの中身をprint()で確認致しましたところ、データが入力されていることが確認できたため、現在、x_train_bにデータが入力されていない(データ数が0になっている)ことが原因で上記のエラーメッセージが発生しているのではないかと考えております。しかしなぜ、データが入力されないのかがわからず頭を抱えております。元のサイト様はmnistとFasin-mnistデータセットを使っており、私はオリジナルのデータセット(トレーニングデータ数は60)を使用しているためエラーが発生しているのではないか、という仮説を立てデバッグ作業を行っておりますがうまくいってない状況です。Pythonに詳しい方がいらっしゃいましたら、ご助言をいただけたら幸いです。
#補足
使っているPCはmacOS Catalina バージョン10.15.5
Pythonのバージョンは3.6.5です
Jupiter notebookを使用しています
入力:print(x_train_b)
出力:[]
となっております。
回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2020/10/24 05:12