質問するログイン新規登録

Q&A

1回答

207閲覧

GANでの画像生成について

lake

総合スコア3

PyTorch

PyTorchは、オープンソースのPython向けの機械学習ライブラリ。Facebookの人工知能研究グループが開発を主導しています。強力なGPUサポートを備えたテンソル計算、テープベースの自動微分による柔軟なニューラルネットワークの記述が可能です。

AI(人工知能)

AI(人工知能)とは、言語の理解や推論、問題解決などの知的行動を人間に代わってコンピューターに行わせる技術のことです。

生成AI

学習データを基にテキスト、画像、コードなどの新しいコンテンツを自律的に生成するAI。従来のデータの分類や予測を行うAIとは異なり、0から1を生み出す創造的なアウトプットが可能な点が特徴です。

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

1クリップ

投稿2026/03/29 15:21

0

1

実現したいこと

MNIST画像を生成・識別するGANのモデルを作りたい

発生している問題・分からないこと

簡易的なGANのモデルを作成しているのですが、generatorによって生成した画像が、次のようなmnist画像にかすりもしないような画像しか生成されません。どうすればmnist画像が生成されるようになるのでしょうか。イメージ説明

上がデータローダーから取り出したMNIST,下が生成画像です。

エラーメッセージ

error

1エラーが表示されてはいませんが、出力画像が異なっています

該当のソースコード

python

1前準備 2 3import PIL 4PIL.PILLOW_VERSION = PIL.__version__ 5 6import torch, time, os, pickle 7import numpy as np 8import torch.nn as nn 9import torch.optim as optim 10from torch.utils.data import DataLoader 11from torch.nn import Parameter 12from torchvision import datasets, transforms 13import matplotlib.pyplot as plt 14from torch.autograd import grad 15from IPython import display 16import pylab as pl 17import time 18import torch.nn.functional as F 19 20%matplotlib inline 21 22import os 23os.environ['CUDA_VISIBLE_DEVICES '] = '0' 24 25input_size=32 26seed=42 27batch_size=128 28epochs=20 29device=torch.device("cuda" if torch.cuda.is_available() else "cpu") 30 31def initialize(m): 32 if(type(m)==nn.Linear or type(m)==nn.ConvTranspose2d or type(m)==nn.Conv2d): 33 nn.init.kaiming_normal_(m.weight) 34 #m.bias.data.fill_(0.0) 35 36 37transform=transforms.Compose( [transforms.Resize(input_size), 38 transforms.ToTensor(), 39 transforms.Normalize(mean=(0.5,),std=(0.5,)), 40 ]) 41 42 43pr_dataset=datasets.MNIST("data/mnist",train=True,download=True,transform=transform) 44test_dataset=datasets.MNIST("data/mnist",train=False,download=True,transform=transform) 45 46train_dataset,valid_dataset=torch.utils.data.random_split(pr_dataset,[int(len(pr_dataset)*0.8),len(pr_dataset)-int(len(pr_dataset)*0.8)],generator=torch.Generator().manual_seed(seed)) 47 48train_loader=DataLoader(train_dataset,shuffle=True,batch_size=batch_size) 49valid_loader=DataLoader(valid_dataset,shuffle=False,batch_size=batch_size) 50test_loader=DataLoader(test_dataset,shuffle=False,batch_size=batch_size)

python

1モデルの定義 2 3class generator(nn.Module): 4 def __init__(self,input_size,input_dim=10,output_dim=3):##画像の形状は、(バッチサイズ、クラス数、画像サイズ) 5 super().__init__() 6 self.fc=nn.Sequential( 7 nn.Linear(input_dim,1024), 8 nn.BatchNorm1d(1024), 9 nn.ReLU(), 10 nn.Linear(1024,128*(input_size//4)**2), 11 nn.BatchNorm1d(128*(input_size//4)**2), 12 nn.ReLU(), 13 ) 14 self.conv=nn.Sequential( 15 nn.ConvTranspose2d(128,64,kernel_size=4,stride=2,padding=1), 16 nn.BatchNorm2d(64), 17 nn.ReLU(), 18 nn.ConvTranspose2d(64,output_dim,kernel_size=4,stride=2,padding=1), 19 nn.Tanh() 20 ) 21 22 def forward(self,x): 23 x=self.fc(x) 24 x=x.view(-1,128,input_size//4,input_size//4) 25 x=self.conv(x) 26 return x 27 28class discriminator(nn.Module): 29 def __init__(self,input_size=32,input_dim=3,output_dim=1,sig=False): 30 super().__init__() 31 self.sig=sig 32 self.conv=nn.Sequential( 33 nn.Conv2d(input_dim,64,kernel_size=4,stride=2,padding=1), 34 nn.LeakyReLU(0.2), 35 nn.Conv2d(64,128,kernel_size=4,stride=2,padding=1), 36 nn.BatchNorm2d(128), 37 nn.LeakyReLU(0.2), 38 ) 39 self.fc=nn.Sequential( 40 nn.Linear(128*(input_size//4)**2,1024), 41 nn.BatchNorm1d(1024), 42 nn.LeakyReLU(0.2), 43 nn.Linear(1024,output_dim) 44 ) 45 46 def forward(self,x): 47 x=self.conv(x) 48 x=x.view(-1,128*(input_size//4)**2) 49 x=self.fc(x) 50 if(self.sig): 51 x=torch.sigmoid(x) 52 return x 53 54gen=generator(input_size=input_size,input_dim=10,output_dim=1) 55dis=discriminator(input_size=input_size,input_dim=1,output_dim=1,sig=True) 56gen.apply(initialize) 57dis.apply(initialize) 58gen=gen.to(device) 59dis=dis.to(device) 60gen_optim=optim.Adam(gen.parameters(),lr=0.01) 61dis_optim=optim.Adam(dis.parameters(),lr=0.01) 62lossfc=nn.BCELoss() 63

python

1学習とテスト 2 3for epoch in range(epochs): 4 gen.train() 5 dis.train() 6 for x,t in train_loader: 7 t=F.one_hot(t,num_classes=10).float() 8 nowbatch=np.size(t,0) 9 t=t.to(device) 10 pic_real=x.to(device) 11 real_label=torch.ones(nowbatch,1).to(device) 12 fake_label=torch.zeros(nowbatch,1).to(device) 13 14 gen_optim.zero_grad() 15 pic_fakeone=gen(t) 16 pred_fakeone=dis(pic_fakeone) 17 pred_real=dis(pic_real) 18 gen_loss=lossfc(pred_fakeone,real_label) 19 gen_loss.backward() 20 gen_optim.step() 21 22 dis_optim.zero_grad() 23 pic_faketwo=gen(t).detach() 24 pred_faketwo=dis(pic_faketwo) 25 dis_loss=lossfc(pred_real,real_label)+lossfc(pred_faketwo,fake_label) 26 dis_loss.backward() 27 dis_optim.step() 28 29 gen.eval() 30 dis.eval() 31 num=0 32 tr=0 33 gen_lossline=np.array([]) 34 dis_lossline=np.array([]) 35 for x,t in valid_loader: 36 t=F.one_hot(t,num_classes=10).float() 37 t=t.to(device) 38 nowbatch=np.size(t,0) 39 pic_real=x.to(device) 40 real_label=torch.ones(nowbatch,1).to(device) 41 fake_label=torch.zeros(nowbatch,1).to(device) 42 pic_fakeone=gen(t) 43 pred_fakeone=dis(pic_fakeone) 44 pred_real=dis(pic_real) 45 gen_loss=lossfc(pred_fakeone,real_label) 46 pic_faketwo=gen(t).detach() 47 pred_faketwo=dis(pic_faketwo) 48 dis_loss=lossfc(pred_real,real_label)+lossfc(pred_faketwo,fake_label) 49 gen_lossline=np.append(gen_lossline,gen_loss.item()) 50 gen_lossline=np.append(dis_lossline,dis_loss.item()) 51 num+=np.size(t,0) 52 sirasu=pred_real-real_label 53 ama=pred_faketwo-fake_label 54 tr+=torch.sum((sirasu<0.5).float()).item() 55 tr+=torch.sum((ama<0.5).float()).item() 56 print(str(epoch+1)+"回目の正答率は"+str((tr/num/2))) 57 58 59testdata=test_loader.__iter__() 60image_num=int(input()) 61image=0 62label=0 63 64for ro1 in range(image_num): 65 image,label=next(testdata) 66 67# Display the first real image from the batch 68plt.imshow(image[0].permute(1, 2, 0).cpu().numpy()) # Select first image, permute, convert to numpy for display 69plt.axis('off') 70plt.show() 71 72# Generate an image using the first label from the batch and display it 73gen_input_label = F.one_hot(label[0], num_classes=10).float().to(device) # One-hot encode the first label and move to device 74generated_image = gen(gen_input_label.unsqueeze(0)).cpu().detach()[0].permute(1, 2, 0) # Fixed: Take the first image from the batch output before permuting 75plt.imshow(generated_image.numpy()) 76plt.axis('off') 77plt.show()

試したこと・調べたこと

  • teratailやGoogle等で検索した
  • ソースコードを自分なりに変更した
  • 知人に聞いた
  • その他
上記の詳細・結果

Heの初期化などでパラメータの最適化も行いましたが、全く改善されませんでした

補足

特になし

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

meg_

2026/03/31 13:41 編集

> 簡易的なGANのモデルを作成しているのですが どういう生成モデルを作成したつもりでしょうか?コードを見たところよく分からなかったので教えてください。 数字画像をランダムではなく意図した数字を生成したいのでしょうか?そのために生成モデルにラベルを入力してますか?
meg_

2026/04/01 14:09 編集

【追記】 訂正します。 モデルへの入力データの次元が少なすぎる、ほとんど固定されていることが問題の可能性があるかと思います。 【追記2】 モデルパラメーター初期化処理の変更だけでもノイズではないものを生成できたので回答に書きました。さらに上記【追記】への対策をすることでもっとクリアな画像の生成ができるようになるかと思います。
guest

回答1

0

生成AIに聞いて初期化部分とオプティマイザーを下記に変えたところ数字(らしきもの)を生成できました。

def initialize(m): classname = m.__class__.__name__ if classname.find('Conv') != -1 or classname.find('Linear') != -1: nn.init.normal_(m.weight.data, 0.0, 0.02) # GANの定番 elif classname.find('BatchNorm') != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0)
gen_optim=optim.Adam(gen.parameters(),lr=0.0002, betas=(0.5, 0.999)) dis_optim=optim.Adam(dis.parameters(),lr=0.0002, betas=(0.5, 0.999))
1回目の正答率は1.0 2回目の正答率は0.9070833333333334 3回目の正答率は0.9509583333333333 4回目の正答率は0.9006666666666666 5回目の正答率は0.8587083333333333 6回目の正答率は0.79975 7回目の正答率は0.9030416666666666 8回目の正答率は1.0 9回目の正答率は1.0 10回目の正答率は1.0 11回目の正答率は1.0 12回目の正答率は0.59225 13回目の正答率は0.691625 14回目の正答率は0.802125 15回目の正答率は0.64175 16回目の正答率は0.9070833333333334 17回目の正答率は0.6455833333333333 18回目の正答率は0.8482083333333333 19回目の正答率は0.6017916666666666 20回目の正答率は0.6985833333333333

イメージ説明

AIが言うには下記理由で改善するそうです。

DCGANなどで推奨される「平均0、標準偏差0.02」の正規分布に変えます。
重みの初期化を 「平均0、標準偏差0.02の正規分布」 に変えたことで、ネットワークの各層の出力が極端に偏らなくなり、GeneratorとDiscriminatorが「お互いに学習のヒント(勾配)を出し合える状態」になった証拠です。

通常の分類タスクなどで使われるは、GANにとっては高すぎます。
GANの損失関数は非凸で複雑なため、過去の勾配を引きずりすぎると、更新方向が大きく振動して学習が不安定になります。
ベータを0.5に設定することで、この振動(Oscillation)を抑え、学習を安定させる効果があります。

投稿2026/04/01 14:04

編集2026/04/01 14:23
meg_

総合スコア11104

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだベストアンサーが選ばれていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.29%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問