google colaboratory環境において
1.Univariate LSTM Modelsの
1-2.Vanilla LSTMを利用し
なのでこのVanilla LSTMのサンプルソースをone-hot化して、実行するように変更してみたのですが
1# univariate lstm example 2from numpy import array 3from numpy import argmax 4 5from keras.models import Sequential 6from keras.layers import LSTM 7from keras.layers import Dense 8 9 10# one hot encode 11def one_hot_encode(sequence, n_unique=1000): 12 encoding = list() 13 for value in sequence: 14 vector = [0 for _ in range(n_unique)] 15 vector[value] = 1 16 encoding.append(vector) 17 return array(encoding) 18 19 20# one hot encode sequence 21def one_hot_encode2(sequence, n_unique=1000): 22 encoding = list() 23 24 #for value in sequence: 25 vector = [0 for _ in range(n_unique)] 26 vector[sequence] = 1 27 encoding.append(vector) 28 a = array(encoding) 29 a = a.reshape(1000) 30 31 return a 32 #return array(encoding) 33 34 35# decode a one hot encoded string 36def one_hot_decode(encoded_seq): 37 return [argmax(vector) for vector in encoded_seq] 38 39 40# split a univariate sequence into samples 41def split_sequence(sequence, n_steps): 42 X, y = list(), list() 43 for i in range(len(sequence)): 44 # find the end of this pattern 45 end_ix = i + n_steps 46 # check if we are beyond the sequence 47 if end_ix > len(sequence)-1: 48 break 49 # gather input and output parts of the pattern 50 seq_x, seq_y = sequence[i:end_ix], sequence[end_ix] 51 X.append(one_hot_encode(seq_x)) 52 y.append(one_hot_encode2(seq_y)) 53 return array(X), array(y) 54 55 56# define input sequence 57raw_seq = [10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 110, 120, 130, 140, 150, 160,170, 180,190,200,210] 58# choose a number of time steps 59n_steps = 3 60 61 62# split into samples 63X, y = split_sequence(raw_seq, n_steps) 64 65# reshape from [samples, timesteps] into [samples, timesteps, features] 66n_features = 1000 67X = X.reshape((X.shape[0], X.shape[1], n_features)) ########## (1) 68 69print(y) ########## (2) 70y = y.reshape(len(y),1000) ########## (3) 71 72 73# define model 74model = Sequential() 75model.add(LSTM(50, activation='relu', batch_input_shape=(X.shape[0], X.shape[1], n_features,),dropout=0.15)) 76model.add(Dense(1000, activation="softmax")) 77model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy']) 78 79 80# fit model 81history =, y, epochs=50, batch_size=2, verbose=1) 82 83 84# demonstrate prediction 85x_input = one_hot_encode(array([190,200,210])) 86 87x_input = x_input.reshape((1, n_steps, n_features)) ########## (4) 88 89 90yhat = model.predict(x_input, verbose=0) 91
# データ内容 [[[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] ... [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]]] #データ形式 (18, 3, 1000)
#データ内容 [[[0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0]] ... [[0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0]] [[0 0 0 ... 0 0 0]]] #データ形式 (18, 1, 1000)
# データ内容 [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]] #データ形式 (18, 1000)
#データ内容 [[[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]]] #データ形式 (1, 3, 1000)
--------------------------------------------------------------------------- InvalidArgumentError Traceback (most recent call last) <ipython-input-54-8021a9079097> in <module>() 11 print(x_input.shape) 12 ---> 13 yhat = model.predict(x_input, batch_size=2, verbose=0) 14 15 #print(yhat) 6 frames /usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/ in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name) 58 ctx.ensure_initialized() 59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name, ---> 60 inputs, attrs, num_outputs) 61 except core._NotOkStatusException as e: 62 if name is not None: InvalidArgumentError: Specified a list with shape [18,1000] from a tensor with shape [1,1000] [[node sequential_2/lstm_2/TensorArrayUnstack/TensorListFromTensor (defined at <ipython-input-44-c1377d111e30>:13) ]] [Op:__inference_predict_function_9471] Function call stack: predict_function