Google colaboratoryでstyleganを動かそうと思っていましたが、下記のエラーが発生しました。
エラーメッセージが無いのでなぜエラー出ているのかがわかりません。
どうすれば解決できるのか知りたいです。お願いします。
発生している問題・エラーメッセージ
AssertionError Traceback (most recent call last) <ipython-input-9-e014ef4e5215> in <module>() 35 36 if __name__ == "__main__": ---> 37 main() 1 frames <ipython-input-9-e014ef4e5215> in main() 14 # ダウンロードしたpickelファイルを指定 15 with open("network-snapshot-000300.pkl", "rb") as f: ---> 16 _, _, Gs = pickle.load(f) 17 18 # Print network details. /content/stylegan/dnnlib/tflib/network.py in __setstate__(self, state) 277 278 # Set basic fields. --> 279 assert state["version"] in [2, 3] 280 self.name = state["name"] 281 self.static_kwargs = util.EasyDict(state["static_kwargs"]) AssertionError:
該当のソースコード
%tensorflow_version 1.x !git clone https://github.com/NVlabs/stylegan.git %cd stylegan import os import pickle import numpy as np import PIL.Image import dnnlib import dnnlib.tflib as tflib import config from datetime import datetime def main(): # Initialize TensorFlow. tflib.init_tf() # ダウンロードしたpickelファイルを指定 with open("network-snapshot-000300.pkl", "rb") as f: _, _, Gs = pickle.load(f) # Print network details. Gs.print_layers() # Pick latent vector. # 潜在変数を作成 rnd = np.random.RandomState(5) latents = rnd.randn(1, Gs.input_shape[1]) # Generate image. fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True) images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt) # Save image. os.makedirs(config.result_dir, exist_ok=True) filename = datetime.now().strftime('%Y%m%d%H%M%S') + '.png' png_filename = os.path.join(config.result_dir, filename) PIL.Image.fromarray(images[0], 'RGB').save(png_filename) if __name__ == "__main__": main()
あなたの回答
tips
プレビュー