前提・実現したいこと
Google ColaboratoryでStyleGANを実装しようとしています。(初学者です)
generatorをロードするのurlを変更することで他の画像も生成できると記載があった為、ImageNetデータセットで学習する予定です。(imagenetを用いたstyleganがstyleganだけのものがなく、stylegan+clipなどしかなかったため)
generatorのロードの部分でAttributeError: module 'config' has no attribute 'cache_dir'というエラーがでてしまいます(このエラーについてググってみましたが、どうすれば解決するのかよくわかりませんでした)。
以下の記事を参考にして進めていましたが、わかりませんでした。
https://teratail.com/questions/295390
https://qiita.com/pacifinapacific/items/1d6cca0ff4060e12d336
http://cedro3.com/ai/stylegan/
https://qiita.com/Phoeboooo/items/12d21916de56d125f0be
発生している問題・エラーメッセージ
Python
1AttributeError Traceback (most recent call last) 2<ipython-input-9-c2a955d99cdc> in <module> 3 46 4 47 if __name__ == "__main__": 5---> 48 main() 6 7<ipython-input-9-c2a955d99cdc> in main() 8 14 # Load pre-trained network. 9 15 url = 'https://drive.google.com/file/d/1k_H6S-ePszz73lVCZrFRneaV6cbqerTm/view?usp=sharing' # karras2019stylegan-ffhq-1024x1024.pkl 10---> 16 with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: 11 17 _G, _D, Gs = pickle.load(f) 12 18 # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. 13 14AttributeError: module 'config' has no attribute 'cache_dir'
該当のソースコード
Python
1 2#git clone でStyleGANのコードを使えるようにする 3!git clone https://github.com/NVlabs/stylegan.git 4!pip install https://github.com/podgorskiy/dnnlib/releases/download/0.0.1/dnnlib-0.0.1-py3-none-any.whl 5 6#ディレクトリ移動 7!cd stylegan 8!pip install tensorflow==1.15.0 9!pip install tensorflow-gpu==1.15.0 10import os 11import pickle 12import numpy as np 13import PIL.Image 14import dnnlib 15import dnnlib.tflib as tflib 16!pip install config 17import config 18 19def main(): 20 # Initialize TensorFlow. 21 tflib.init_tf() 22 23 # Load pre-trained network. 24 url = 'https://drive.google.com/file/d/1k_H6S-ePszz73lVCZrFRneaV6cbqerTm/view?usp=sharing' # karras2019stylegan-ffhq-1024x1024.pkl 25 with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f: 26 _G, _D, Gs = pickle.load(f) 27 # _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run. 28 # _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run. 29 # Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot. 30 31 # Print network details. 32 Gs.print_layers() 33 34 # Pick latent vector. 35 rnd = np.random.RandomState(10) # seed = 10 36 latents0 = rnd.randn(1, Gs.input_shape[1]) 37 latents1 = rnd.randn(1, Gs.input_shape[1]) 38 latents2 = rnd.randn(1, Gs.input_shape[1]) 39 latents3 = rnd.randn(1, Gs.input_shape[1]) 40 latents4 = rnd.randn(1, Gs.input_shape[1]) 41 latents5 = rnd.randn(1, Gs.input_shape[1]) 42 latents6 = rnd.randn(1, Gs.input_shape[1]) 43 44 num_split = 39 # 2つのベクトルを39分割 45 for i in range(40): 46 latents = latents6+(latents0-latents6)*i/num_split 47 # Generate image. 48 fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) 49 images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt) 50 51 # Save image. 52 os.makedirs(config.result_dir, exist_ok=True) 53 png_filename = os.path.join(config.result_dir, 'photo'+'{0:04d}'.format(i)+'.png') 54 PIL.Image.fromarray(images[0], 'RGB').save(png_filename) 55 56if __name__ == "__main__": 57 main()
補足情報(FW/ツールのバージョンなど)
参考にしたサイトのように公式のコードそのままでも実行することはできませんでした。

回答1件
あなたの回答
tips
プレビュー