質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

Q&A

0回答

1166閲覧

コードによってGPUが全く使われない -keras

0604hana1111

総合スコア0

Keras

Kerasは、TheanoやTensorFlow/CNTK対応のラッパーライブラリです。DeepLearningの数学的部分を短いコードでネットワークとして表現することが可能。DeepLearningの最新手法を迅速に試すことができます。

0グッド

1クリップ

投稿2020/11/13 13:06

前提・実現したいこと

現在、kerasでプログラムを書く上で、GPUを使用できる環境を構築しようとしています。
GPUを認識させることはおそらく出来ているのですが、プログラムによってGPUが使われたり、全く使われなかったりしております。

具体的には、下記の1つ目のコードを実行中はGPUが全く使われません。
CPUの使用率は20%程度で、学習にかかる時間はおよそ20分です。

2つ目のコードの方では、GPU使用率は15%前後で、CPUの使用率も30%を超えます。
学習にかかる時間は20秒程度でした。

なぜ1つ目のコードではGPUが使われないのでしょうか?
学習モデル自体が全く違うので、使用率も変化するのかなと思ったのですが、1つ目のコードであまりにもGPUが使われないので質問させていただきました。
初学者で、もしかしたらあたりまえのことを質問してしまっているかもしれませんが、何卒よろしくお願いいたします。

該当のソースコード1

keras

1%matplotlib inline 2import numpy as np 3import matplotlib.pyplot as plt 4 5from keras.datasets import mnist 6from keras.layers import Dense, Flatten, Reshape 7from keras.layers.advanced_activations import LeakyReLU 8from keras.models import Sequential 9from keras.optimizers import Adam 10 11 12#mnistの形状[28, 28, 1]を定義 13img_rows = 28 14img_cols = 28 15channels = 1 16img_shape = (img_rows, img_cols, channels) 17#generatorが画像を生成するために入力させてあげるノイズの次元 18z_dim = 100 19 20#generator(生成器)の定義するための関数 21def build_generator(img_shape, z_dim): 22 model = Sequential() 23 model.add(Dense(128, input_dim=z_dim)) 24 model.add(LeakyReLU(alpha=0.01)) 25 model.add(Dense(28*28*1, activation='tanh')) 26 model.add(Reshape(img_shape)) 27 return model 28 29#discriminator(識別器)の定義するための関数 30def build_discriminatior(img_shape): 31 model = Sequential() 32 model.add(Flatten(input_shape=img_shape)) 33 model.add(Dense(128)) 34 model.add(LeakyReLU(alpha=0.01)) 35 model.add(Dense(1, activation='sigmoid')) 36 return model 37 38#Ganのモデル定義する(生成器と識別器をつなげてあげる)ための関数 39def build_gan(generator, discriminator): 40 model = Sequential() 41 model.add(generator) 42 model.add(discriminator) 43 return model 44 45 46#実際関数を呼び出してにGANのモデルをコンパイルしてあげる 47discriminator = build_discriminatior(img_shape) 48discriminator.compile(loss='binary_crossentropy', optimizer=Adam(), metrics=['accuracy']) 49generator = build_generator(img_shape, z_dim) 50 51 52 53#識別器の学習機能をオフにしてあげる。こうすることで、識別器と生成者を別々に学習させてあげられる 54discriminator.trainable = False 55 56gan = build_gan(generator, discriminator) 57gan.compile(loss='binary_crossentropy', optimizer=Adam()) 58 59losses = [] 60accuracies = [] 61iteration_checkpoint = [] 62 63x_train =[] 64(x_train, _), (_, _) = mnist.load_data() 65print(x_train.shape[0]) 66idx = np.random.randint(0, x_train.shape[0], 128) 67imgs = x_train[idx] 68print(imgs.shape) 69#学習させてあげるための関数。イテレーション数、バッチサイズ、 何イテレーションで画像を生成して可視化するかを引数にとる 70def train(iterations, batch_size, sample_interval): 71 (x_train, _), (_, _) = mnist.load_data() 72 73 x_train = x_train / 127.5 - 1 74 x_train = np.expand_dims(x_train, axis=3) 75 76 real = np.ones((batch_size, 1)) 77 fake = np.zeros((batch_size, 1)) 78 79 for iteration in range(iterations): 80 81 idx = np.random.randint(0, x_train.shape[0], batch_size) 82 imgs = x_train[idx] 83 z = np.random.normal(0, 1, (batch_size, 100)) 84 gen_imgs = generator.predict(z) 85 86 d_loss_real = discriminator.train_on_batch(imgs, real) 87 d_loss_fake = discriminator.train_on_batch(gen_imgs, fake) 88 d_loss, acc = 0.5 * np.add(d_loss_real, d_loss_fake) 89 90 z = np.random.normal(0, 1, (batch_size, 100)) 91 gen_imgs = generator.predict(z) 92 93 g_loss = gan.train_on_batch(z, real) 94#sample_intervalごとに損失値と精度、チェックポイントを保存 95 if (iteration+1) % sample_interval == 0: 96 losses.append((d_loss, g_loss)) 97 accuracies.append(acc) 98 iteration_checkpoint.append(iteration+1) 99#画像を生成 100 sample_images(generator) 101 102#サンプルとして画像を生成するための関数 103def sample_images(generator, image_grid_rows =4, image_grid_colmuns=4): 104 z = np.random.normal(0, 1, (image_grid_rows*image_grid_colmuns, z_dim)) 105 gen_images = generator.predict(z) 106 107 gen_images = 0.5 * gen_images + 0.5 108 109 fig, axs = plt.subplots(image_grid_rows, image_grid_colmuns, figsize=(4,4), sharex=True, sharey=True) 110 111 cnt = 0 112 for i in range(image_grid_rows): 113 for j in range(image_grid_colmuns): 114 axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap='gray') 115 axs[i, j].axis('off') 116 cnt += 1 117 118 119 120iterations = 20000 121batch_size = 128 122sample_interval = 1000 123 124train(iterations,batch_size,sample_interval)

イメージ説明

該当のソースコード2

keras

1import keras 2from keras.datasets import mnist 3from keras.models import Sequential 4from keras.layers import Dense, Dropout 5from keras.optimizers import RMSprop 6 7batch_size = 128 8num_classes = 10 9epochs = 20 10 11# モデルの生成関数 12def createModel(): 13 model = Sequential() 14 model.add(Dense(512, activation='relu', input_shape=(784,))) 15 model.add(Dropout(0.2)) 16 model.add(Dense(512, activation='relu')) 17 model.add(Dropout(0.2)) 18 model.add(Dense(10, activation='softmax')) 19 20 model.compile(loss='categorical_crossentropy', 21 optimizer=RMSprop(), 22 metrics=['accuracy']) 23 return model 24 25# Mnistデータのロード 26(x_train, y_train), (x_test, y_test) = mnist.load_data() 27 28x_train = x_train.reshape(60000, 784) # 2次元配列を1次元に変換(訓練データ) 29x_test = x_test.reshape(10000, 784) # 2次元配列を1次元に変換(テストデータ) 30x_train = x_train.astype('float32') # int型をfloat32型に変換 31x_test = x_test.astype('float32') # int型をfloat32型に変換 32x_train /= 255 # [0-255]の値を[0.0-1.0]に変換 33x_test /= 255 34 35# 正解ラベルのOne hot vector化 36y_train = keras.utils.to_categorical(y_train, num_classes) 37y_test = keras.utils.to_categorical(y_test, num_classes) 38 39 40# モデルの定義 41model = createModel() 42 43# 学習の実行 44history = model.fit(x_train, y_train, # 画像とラベルデータ 45 batch_size=batch_size, 46 epochs=epochs, # エポック数の指定 47 validation_data=(x_test, y_test)) 48 49# モデル構成の確認 50model.summary() 51 52score = model.evaluate(x_test, y_test, verbose=0) 53print('Test loss:', score[0]) 54print('Test accuracy:', score[1])

イメージ説明

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問