実現したいこと
MNIST画像を生成・識別するGANのモデルを作りたい
発生している問題・分からないこと
簡易的なGANのモデルを作成しているのですが、generatorによって生成した画像が、次のようなmnist画像にかすりもしないような画像しか生成されません。どうすればmnist画像が生成されるようになるのでしょうか。
上がデータローダーから取り出したMNIST,下が生成画像です。
エラーメッセージ
error
1エラーが表示されてはいませんが、出力画像が異なっています
該当のソースコード
python
1前準備 2 3import PIL 4PIL.PILLOW_VERSION = PIL.__version__ 5 6import torch, time, os, pickle 7import numpy as np 8import torch.nn as nn 9import torch.optim as optim 10from torch.utils.data import DataLoader 11from torch.nn import Parameter 12from torchvision import datasets, transforms 13import matplotlib.pyplot as plt 14from torch.autograd import grad 15from IPython import display 16import pylab as pl 17import time 18import torch.nn.functional as F 19 20%matplotlib inline 21 22import os 23os.environ['CUDA_VISIBLE_DEVICES '] = '0' 24 25input_size=32 26seed=42 27batch_size=128 28epochs=20 29device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 31def initialize(m): 32 if(type(m)==nn.Linear or type(m)==nn.ConvTranspose2d or type(m)==nn.Conv2d): 33 nn.init.kaiming_normal_(m.weight) 34 #m.bias.data.fill_(0.0) 35 36 37transform=transforms.Compose( [transforms.Resize(input_size), 38 transforms.ToTensor(), 39 transforms.Normalize(mean=(0.5,),std=(0.5,)), 40 ]) 41 42 43pr_dataset=datasets.MNIST("data/mnist",train=True,download=True,transform=transform) 44test_dataset=datasets.MNIST("data/mnist",train=False,download=True,transform=transform) 45 46train_dataset,valid_dataset=torch.utils.data.random_split(pr_dataset,[int(len(pr_dataset)*0.8),len(pr_dataset)-int(len(pr_dataset)*0.8)],generator=torch.Generator().manual_seed(seed)) 47 48train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size) 49valid_loader=DataLoader(valid_dataset,shuffle=False,batch_size=batch_size) 50test_loader=DataLoader(test_dataset,shuffle=False,batch_size=batch_size)
python
1モデルの定義 2 3class generator(nn.Module): 4 def __init__(self,input_size,input_dim=10,output_dim=3):##画像の形状は、(バッチサイズ、クラス数、画像サイズ) 5 super().__init__() 6 self.fc=nn.Sequential( 7 nn.Linear(input_dim,1024), 8 nn.BatchNorm1d(1024), 9 nn.ReLU(), 10 nn.Linear(1024,128*(input_size//4)**2), 11 nn.BatchNorm1d(128*(input_size//4)**2), 12 nn.ReLU(), 13 ) 14 self.conv=nn.Sequential( 15 nn.ConvTranspose2d(128,64,kernel_size=4,stride=2,padding=1), 16 nn.BatchNorm2d(64), 17 nn.ReLU(), 18 nn.ConvTranspose2d(64,output_dim,kernel_size=4,stride=2,padding=1), 19 nn.Tanh() 20 ) 21 22 def forward(self,x): 23 x=self.fc(x) 24 x=x.view(-1,128,input_size//4,input_size//4) 25 x=self.conv(x) 26 return x 27 28class discriminator(nn.Module): 29 def __init__(self,input_size=32,input_dim=3,output_dim=1,sig=False): 30 super().__init__() 31 self.sig=sig 32 self.conv=nn.Sequential( 33 nn.Conv2d(input_dim,64,kernel_size=4,stride=2,padding=1), 34 nn.LeakyReLU(0.2), 35 nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1), 36 nn.BatchNorm2d(128), 37 nn.LeakyReLU(0.2), 38 ) 39 self.fc=nn.Sequential( 40 nn.Linear(128*(input_size//4)**2,1024), 41 nn.BatchNorm1d(1024), 42 nn.LeakyReLU(0.2), 43 nn.Linear(1024,output_dim) 44 ) 45 46 def forward(self,x): 47 x=self.conv(x) 48 x=x.view(-1,128*(input_size//4)**2) 49 x=self.fc(x) 50 if(self.sig): 51 x=torch.sigmoid(x) 52 return x 53 54gen=generator(input_size=input_size,input_dim=10,output_dim=1) 55dis=discriminator(input_size=input_size,input_dim=1,output_dim=1,sig=True) 56gen.apply(initialize) 57dis.apply(initialize) 58gen=gen.to(device) 59dis=dis.to(device) 60gen_optim=optim.Adam(gen.parameters(),lr=0.01) 61dis_optim=optim.Adam(dis.parameters(),lr=0.01) 62lossfc=nn.BCELoss() 63
python
1学習とテスト 2 3for epoch in range(epochs): 4 gen.train() 5 dis.train() 6 for x,t in train_loader: 7 t=F.one_hot(t,num_classes=10).float() 8 nowbatch=np.size(t,0) 9 t=t.to(device) 10 pic_real=x.to(device) 11 real_label=torch.ones(nowbatch,1).to(device) 12 fake_label=torch.zeros(nowbatch,1).to(device) 13 14 gen_optim.zero_grad() 15 pic_fakeone=gen(t) 16 pred_fakeone=dis(pic_fakeone) 17 pred_real=dis(pic_real) 18 gen_loss=lossfc(pred_fakeone,real_label) 19 gen_loss.backward() 20 gen_optim.step() 21 22 dis_optim.zero_grad() 23 pic_faketwo=gen(t).detach() 24 pred_faketwo=dis(pic_faketwo) 25 dis_loss=lossfc(pred_real,real_label)+lossfc(pred_faketwo,fake_label) 26 dis_loss.backward() 27 dis_optim.step() 28 29 gen.eval() 30 dis.eval() 31 num=0 32 tr=0 33 gen_lossline=np.array([]) 34 dis_lossline=np.array([]) 35 for x,t in valid_loader: 36 t=F.one_hot(t,num_classes=10).float() 37 t=t.to(device) 38 nowbatch=np.size(t,0) 39 pic_real=x.to(device) 40 real_label=torch.ones(nowbatch,1).to(device) 41 fake_label=torch.zeros(nowbatch,1).to(device) 42 pic_fakeone=gen(t) 43 pred_fakeone=dis(pic_fakeone) 44 pred_real=dis(pic_real) 45 gen_loss=lossfc(pred_fakeone,real_label) 46 pic_faketwo=gen(t).detach() 47 pred_faketwo=dis(pic_faketwo) 48 dis_loss=lossfc(pred_real,real_label)+lossfc(pred_faketwo,fake_label) 49 gen_lossline=np.append(gen_lossline,gen_loss.item()) 50 gen_lossline=np.append(dis_lossline,dis_loss.item()) 51 num+=np.size(t,0) 52 sirasu=pred_real-real_label 53 ama=pred_faketwo-fake_label 54 tr+=torch.sum((sirasu<0.5).float()).item() 55 tr+=torch.sum((ama<0.5).float()).item() 56 print(str(epoch+1)+"回目の正答率は"+str((tr/num/2))) 57 58 59testdata=test_loader.__iter__() 60image_num=int(input()) 61image=0 62label=0 63 64for ro1 in range(image_num): 65 image,label=next(testdata) 66 67# Display the first real image from the batch 68plt.imshow(image[0].permute(1, 2, 0).cpu().numpy()) # Select first image, permute, convert to numpy for display 69plt.axis('off') 70plt.show() 71 72# Generate an image using the first label from the batch and display it 73gen_input_label = F.one_hot(label[0], num_classes=10).float().to(device) # One-hot encode the first label and move to device 74generated_image = gen(gen_input_label.unsqueeze(0)).cpu().detach()[0].permute(1, 2, 0) # Fixed: Take the first image from the batch output before permuting 75plt.imshow(generated_image.numpy()) 76plt.axis('off') 77plt.show()
試したこと・調べたこと
- teratailやGoogle等で検索した
- ソースコードを自分なりに変更した
- 知人に聞いた
- その他
上記の詳細・結果
Heの初期化などでパラメータの最適化も行いましたが、全く改善されませんでした
補足
特になし
