今現在、グラフィカルモデルを作成しており、ネットワーク図の描写の部分で困っています。
偏相関を計算して、それをもとにネットワークグラフを描写。描写したグラフをもとに、データ解析をしていこうと思っています。共分散などの計算過程は概ね成功していると思うのですが、最後のグラフ表示でうまくいっていません。
動作イメージは、csvファイルを読み込んでそれを元に計算していく感じです。読み込むcsvファイルは、A行と1列目がヘッダーとラベルになっています。
出ているエラーコードは、次の二つです。片方のエラーを修正したら、もう片方が出てしまうのでいつまでたっても修正できません。valid_labelsとvalid_nodes,posの修正を何度か行っているのですが、改善が見られません。
出てしまうエラーは、
Traceback (most recent call last):
File "ーーー", line 130, in <module>
draw_network_graph(G, pos, labels_dict, corr_mat)
File "ーーー", line 29, in draw_network_graph
nx.draw(G, pos=pos, node_size=500, alpha=0.8)
File "ーーー", line 121, in draw
draw_networkx(G, pos=pos, ax=ax, **kwds)
File "ーーー", line 303, in draw_networkx
draw_networkx_nodes(G, pos, **node_kwds)
File "ーーー", line 427, in draw_networkx_nodes
raise nx.NetworkXError(f"Node {err} has no position.") from err
networkx.exception.NetworkXError: Node 5 has no position.
もう一つのエラーが、
Traceback (most recent call last):
File "ーー", line 127, in <module>
pos = {label: pos[i] if label in pos else (0.0, 0.0) for i, label in enumerate(valid_labels)}
File "ーー", line 127, in <dictcomp>
pos = {label: pos[i] if label in pos else (0.0, 0.0) for i, label in enumerate(valid_labels)}
KeyError: 1
です。
グラフ描写をしているプログラムは、次のようになっています。
ネットワーク図を描画する関数
def draw_network_graph(G, pos, labels_dict, corr_mat):
# 相関係数のラベルを作成する
edge_labels = {(i, j): f"{corr_mat[i, j]:.2f}" for i, j in G.edges()}
# ノードを描画する nx.draw(G, pos=pos, node_size=500, alpha=0.8) # ノードのラベルを描画する nx.draw_networkx_labels(G, pos=pos, labels=labels_dict, font_size=10, font_family='Yu Gothic') # エッジのラベルを描画する nx.draw_networkx_edge_labels(G, pos=pos, edge_labels=edge_labels, font_size=10) # 描画領域を非表示にする plt.axis('off')
ネットワーク図の描画に必要なデータを準備する
G = nx.from_numpy_array(cov_mat)
pos = nx.spring_layout(G)
valid_nodes = set(range(len(labels)))
posのキーをラベルに変更する
pos = {labels[i]: pos[i] for i in range(len(labels)) if i in pos and labels[i] != ''}
0をキーとして値を設定する
if 0 not in pos:
pos[0] = (0.0, 0.0)
ノードとラベルを描画する
valid_labels = [label for label in labels if label != '']
valid_nodes = [i for i, label in enumerate(labels) if label != '']
labels_dict = {i: label for i, label in enumerate(valid_labels)}
pos辞書にすべてのノードの位置情報を追加する
pos = {labels[i]: pos[i] for i in range(len(labels)) if i in pos and labels[i] != ''}
draw_network_graph(G, pos, labels_dict, corr_mat)
plt.show()
諸事情で全文ではなく、グラフ描写のところだけですが、修正箇所を教えていただきたいです。
よろしくお願いいたします
バッドをするには、ログインかつ
こちらの条件を満たす必要があります。
2023/05/19 13:23