質問をすることでしか得られない、回答やアドバイスがある。

15分調べてもわからないことは、質問しよう!

新規登録して質問してみよう
ただいま回答率
85.48%
Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

Q&A

0回答

634閲覧

Tnesorflow2のSubclassing記法で定義したグラフにダミーのデータを入れてTensorboardにグラフを表示する方法

hiretamago

総合スコア0

Python

Pythonは、コードの読みやすさが特徴的なプログラミング言語の1つです。 強い型付け、動的型付けに対応しており、後方互換性がないバージョン2系とバージョン3系が使用されています。 商用製品の開発にも無料で使用でき、OSだけでなく仮想環境にも対応。Unicodeによる文字列操作をサポートしているため、日本語処理も標準で可能です。

0グッド

0クリップ

投稿2020/07/02 08:35

前提・実現したいこと

■ 前提

言語:python 3.6
ライブラリ:Tensorflow 2.2.0

(参考)
開発環境:AWSのEMRのZeppelin

terminal

1$ cat /etc/system-release 2Amazon Linux AMI release 2018.03

■ 実現したいこと

VAEのグラフをdefine-by-runの記述が可能なTensorflow2で書き、ダミーのデータを流して仮のグラフを取得し、tensorboardで定義したグラフを確認したいと思っています。
しかし、描画が上手くできないため質問させていただきました。

Tensorflow2のAPIは大きく分けて

  • Sequential(積層型)モデル: コンパクトで簡単な書き方
  • Functional(関数型)API: 複雑なモデルも定義できる柔軟な書き方
  • Subclassing(サブクラス化)モデル: 難易度は少し上がるが、フルカスタマイズが可能

の3つですが、柔軟かつ自由にグラフを書きたいので、Subclassingの記法を採用しています。

発生している問題・エラーメッセージ

EncoderとDecoderのグラフ表示は上手く出来ているのですが、
reparameterizationTrick部分のグラフ表示が上手く出来ず、Tensorboardとterminalで次のエラーが出てしまいます。

Tensorboardでのエラー

Graph visualization failed. Error: Malformed GraphDef. This can sometimes be caused by a bad network connection or difficulty reconciling multiple GraphDefs; for the latter case, please refer to https://github.com/tensorflow/tensorboard/issues/1929.

terminalでのエラー

ValueError: Cannot combine GraphDefs because nodes share a name but contents are different: Identity

該当のソースコード

reparameterizationTrickのグラフを作るコードだけを載せています。

python

1import os 2# tensorflowはCPUのみを使う 3os.environ["CUDA_VISIBLE_DEVICES"]="-1" 4import tensorflow as tf 5tf.random.set_seed(0) 6import tensorboard 7import pathlib as pl 8import re 9import datetime 10import numpy as np 11np.random.seed(0) 12 13from PIL import Image 14import matplotlib.pyplot as plt 15

python

1class NeuralNetworkReparameterizationTrick(tf.keras.Model): 2 def __init__(self, *args, **kwargs): 3 super().__init__(*args, **kwargs) 4 5 self._dim_inputs_mean = None # batch_sizeを無視した入力の次元(タプル形式) 6 self._dim_inputs_logvar = None # batch_sizeを無視した入力の次元(タプル形式) 7 8 @property 9 def dim_inputs_mean(self): 10 return self._dim_inputs_mean 11 12 @dim_inputs_mean.setter 13 def dim_inputs_mean(self, dim_inputs_mean): 14 """ 15 dim_inputs_mean -> (m, ) : m次元のデータ 16 dim_inputs_mean -> (m, n, ) : m行n列のデータ 17 """ 18 if type(dim_inputs_mean)==tuple: 19 self._dim_inputs_mean = dim_inputs_mean 20 else: 21 print("batch_sizeを無視したデータのサイズをタプル形式で指定してください。") 22 exit() 23 24 @property 25 def dim_inputs_logvar(self): 26 return self._dim_inputs_logvar 27 28 @dim_inputs_logvar.setter 29 def dim_inputs_logvar(self, dim_inputs_logvar): 30 """ 31 dim_inputs_logvar -> (m, ) : m次元のデータ 32 dim_inputs_logvar -> (m, n, ) : m行n列のデータ 33 """ 34 if type(dim_inputs_logvar)==tuple: 35 self._dim_inputs_logvar = dim_inputs_logvar 36 else: 37 print("batch_sizeを無視したデータのサイズをタプル形式で指定してください。") 38 exit() 39 40 def call(self, mean, logvar): 41 std = tf.exp(logvar*0.5, name="std") 42 eps = tf.random.normal((std.shape[1],), name="eps") 43 z = mean + eps * std 44 return z 45 def get_symbolic_model(self, dim_inputs_mean=(50,), dim_inputs_logvar=(50,), name_inputs_mean=None, name_inputs_logvar=None, name_model=None): 46 """ 47 Imperative API で作成したモデルをベースに 48 Symbolic API ( Functional API ) のモデルとして再作成した 49 “仮のモデル”を生成する。 50 """ 51 self.dim_inputs_mean = dim_inputs_mean 52 self.dim_inputs_logvar = dim_inputs_logvar 53 if name_inputs_mean is not None: 54 inputs_mean = tf.keras.layers.Input(shape=self.dim_inputs_mean, name=name_inputs_mean) 55 else: 56 inputs_mean = tf.keras.layers.Input(shape=self.dim_inputs_mean) 57 if name_inputs_logvar is not None: 58 inputs_logvar = tf.keras.layers.Input(shape=self.dim_inputs_logvar, name=name_inputs_logvar) 59 else: 60 inputs_logvar = tf.keras.layers.Input(shape=self.dim_inputs_logvar) 61 if name_model is not None: 62 outputs = tf.keras.Model(inputs=[inputs_mean, inputs_logvar], outputs=self.call(mean=inputs_mean, logvar=inputs_logvar), name=name_model) 63 else: 64 outputs = tf.keras.Model(inputs=[inputs_mean, inputs_logvar], outputs=self.call(mean=inputs_mean, logvar=inputs_logvar)) 65 outputs.dim_inputs_mean = self.dim_inputs_mean 66 outputs.dim_inputs_logvar = self.dim_inputs_logvar 67 return outputs 68 69 def output_tensorboard_structure(self, dim_inputs_mean=(50,), dim_inputs_logvar=(50,), name_inputs_mean=None, name_inputs_logvar=None, name_model=None, profiler_outdir="/tmp/sample_tensorboard/sample_logs"): 70 """ 71 Imperative API で作成したモデルをベースに 72 Symbolic API ( Functional API ) のモデルとして再作成した 73 “仮のモデル”を生成する。 74 """ 75 76 s_model = self.get_symbolic_model(dim_inputs_mean=model.reparameterizationTrick.dim_inputs_mean, dim_inputs_logvar=model.reparameterizationTrick.dim_inputs_logvar, name_inputs_mean=name_inputs_mean, name_inputs_logvar=name_inputs_logvar, name_model=name_model) 77 @tf.function 78 def traceme(x_mean, x_logvar): 79 return s_model(inputs=[x_mean, x_logvar]) 80 81 writer = tf.summary.create_file_writer(logdir=profiler_outdir) 82 tf.summary.trace_on(graph=True, profiler=True) 83 # Forward pass 84 dim1_plus_dim_inputs_mean = tuple([1]+list(self.dim_inputs_mean)) 85 dim1_plus_dim_inputs_logvar = tuple([1]+list(self.dim_inputs_logvar)) 86 sample_data_mean = tf.zeros(dim1_plus_dim_inputs_mean) 87 sample_data_logvar = tf.ones(dim1_plus_dim_inputs_logvar) 88 89 traceme(sample_data_mean, sample_data_logvar) 90 with writer.as_default(): 91 tf.summary.trace_export(name="my_func_trace", step=0, profiler_outdir=profiler_outdir) 92 tf.summary.trace_off() 93 return None 94 95 96class NeuralNetwork(tf.keras.Model): 97 98 def __init__(self, *args, **kwargs): 99 """ 100 グラフの要素を定義する。 101 ただし、入力は定義しない。 102 """ 103 super().__init__(*args, **kwargs) 104 self.reparameterizationTrick = NeuralNetworkReparameterizationTrick(name="ReparameterizationTrick") 105 106 def call(self, inputs): 107 """ 108 __init__で定義された要素を繋ぎ合わせて、 109 グラフのフォワードパスを定義する。 110 """ 111 mean, logvar = inputs 112 x = self.reparameterizationtrick(mean, logvar) 113 outputs =x 114 return outputs 115 116# モデルの生成 117model = NeuralNetwork() 118####################################################################################### 119# 一部のグラフの描画(reparameterizationTrick) 120####################################################################################### 121# 仮のモデルの次元を指定 122model.reparameterizationTrick.dim_inputs_mean = (20,) 123model.reparameterizationTrick.dim_inputs_logvar = (20,) 124# 仮のモデルを取得 125name_inputs_mean="inputs_mean" 126name_inputs_logvar="inputs_logvar" 127name_model="symbolic_model_reparameterizationTrick" 128s_model = model.reparameterizationTrick.get_symbolic_model(dim_inputs_mean=model.reparameterizationTrick.dim_inputs_mean, dim_inputs_logvar=model.reparameterizationTrick.dim_inputs_logvar, name_inputs_mean=name_inputs_mean, name_inputs_logvar=name_inputs_logvar, name_model=name_model) 129# グラフのsummaryを表示 130s_model.summary() 131# グラフの構造のPNGを保存 132to_file_reparameterizationTrick = '/tmp/'+name_model+'.png' 133tf.keras.utils.plot_model(model=s_model, show_shapes=True, show_layer_names=True, to_file=to_file_reparameterizationTrick, rankdir='TB', dpi=300) 134# グラフの構造をTensorboardで閲覧出来るようにする 135stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S.%f") 136logdir = 'test_tensorboard_logs/func/'+name_model+'_%s' % stamp 137profiler_outdir = "/tmp/" + logdir 138path = pl.Path(profiler_outdir).parent.as_posix() 139model.reparameterizationTrick.output_tensorboard_structure(dim_inputs_mean=model.reparameterizationTrick.dim_inputs_mean, dim_inputs_logvar=model.reparameterizationTrick.dim_inputs_logvar, name_inputs_mean=name_inputs_mean, name_inputs_logvar=name_inputs_logvar, name_model=name_model, profiler_outdir=profiler_outdir) 140 141print("次のコマンドを実行してからアクセスすること。") 142print("tensorboard --logdir {}".format(path)) 143#######################################################################################

試したこと

EncoderとreparameterizationTrickのグラフのモデルのsummaryを比較をしたところ、
入力層を除く各層の入力と出力のデータ形式に違いがあることは確認しました。
Encoder:タプル
reparameterizationTrick:タプル内包リスト
データ形式の違いがバグ解消のヒントなのでは?と考えています。

気になる質問をクリップする

クリップした質問は、後からいつでもMYページで確認できます。

またクリップした質問に回答があった際、通知やメールを受け取ることができます。

バッドをするには、ログインかつ

こちらの条件を満たす必要があります。

guest

あなたの回答

tips

太字

斜体

打ち消し線

見出し

引用テキストの挿入

コードの挿入

リンクの挿入

リストの挿入

番号リストの挿入

表の挿入

水平線の挿入

プレビュー

まだ回答がついていません

会員登録して回答してみよう

アカウントをお持ちの方は

15分調べてもわからないことは
teratailで質問しよう!

ただいまの回答率
85.48%

質問をまとめることで
思考を整理して素早く解決

テンプレート機能で
簡単に質問をまとめる

質問する

関連した質問