機械学習の勉強をしています
線形回帰が進む様子をmatplotlibのanimationにしてみています.
epoch数をタイトルにいれたいのですが,表示されません.
具体的には以下のコードです.
import torch import torch.nn as nn import numpy as np import matplotlib.pyplot as plt import matplotlib.animation as animation from matplotlib.animation import PillowWriter # hyper parameters input_size = 1 output_size = 1 num_epochs = 200 learning_rate = 0.001 # toy dataset # 15 samples, 1 features x_train = np.array([3.3, 4.4, 5.5, 6.71, 6.93, 4.168, 9.779, 6.182, 7.59, 2.167, 7.042, 10.791, 5.313, 7.997, 3.1], dtype=np.float32) y_train = np.array([1.7, 2.76, 2.09, 3.19, 1.694, 1.573, 3.366, 2.596, 2.53, 1.221, 2.827, 3.465, 1.65, 2.904, 1.3], dtype=np.float32) x_train = x_train.reshape(15,1) y_train = y_train.reshape(15,1) # linear regression model class LinearRegression(nn.Module): def __init__(self, input_size, output_size): super(LinearRegression, self).__init__() self.linear = nn.Linear(input_size, output_size) def forward(self, x): out = self.linear(x) return out model = LinearRegression(input_size, output_size) # loss and optimizer # loss function mean squared error # optimizer stochastic gradient descent criterion = nn.MSELoss() optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) # Figure Setting fig = plt.figure() ims = [] # train the model for epoch in range(num_epochs): inputs = torch.from_numpy(x_train) targets = torch.from_numpy(y_train) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, targets) loss.backward() optimizer.step() if (epoch + 1) % 10 == 0: print('Epoch [%d/%d], Loss: %.4f' % (epoch + 1, num_epochs, loss.item())) predicted = model(torch.from_numpy(x_train)).detach().numpy() line, = plt.plot(x_train,predicted,"skyblue") tm = plt.title("Epoch = {0}".format(epoch+1)) ims.append([line,tm]) # save the model torch.save(model.state_dict(), "model.pkl") od, = plt.plot(x_train,y_train,"ro") #plt.title("test") plt.legend([od,line],["Original Data","Fitted Line"]) ani = animation.ArtistAnimation(fig,ims,interval=50,blit=True,repeat_delay=1000) plt.show() print("Save Animation? [y/n]") should_save_animation = str(input()) if should_save_animation == "y": anim.save("anim.gif",writer="pillow")
ArtistAnimationに渡すArtistのリストに問題があると思うのですが,解決策が分かりません.
よろしくお願いいたします.

回答1件
あなたの回答
tips
プレビュー
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2018/11/20 07:33
2018/11/20 08:39 編集
2018/11/20 08:41