以下のサイトを参考にしてDGLライブラリを用いて、グラフ向け深層学習を実装しようとしたところ、以下のようなエラーが出ました。
https://qiita.com/K-1/items/62cd7fe04b80868ded8e
コードはここからダウンロードしました。
https://gist.github.com/k1ochiai/cd0279ca79dd74e91a2b5e1187928adb
利用環境は、
anaconda
python version3.6.12
UI: Jupyter Notebook
バックエンド: PyTorch
ライブラリ:DGL(Deep Graph Library), networkx
#コードの内容
python
1import dgl 2import torch 3 4def build_sample_graph(): 5 g = dgl.DGLGraph() 6 g.add_nodes(5) 7 edge_list = [(0, 1), (0, 2), (0, 3), (0, 4), (1, 2), (2, 3)] 8 src, dst = tuple(zip(*edge_list)) 9 g.add_edges(src, dst) 10 g.add_edges(dst, src) 11 12 return g 13 14G = build_sample_graph() 15print('We have %d nodes.' % G.number_of_nodes()) 16print('We have %d edges.' % G.number_of_edges()) 17 18import networkx as nx 19nx_G = G.to_networkx().to_undirected() 20pos = nx.kamada_kawai_layout(nx_G) 21nx.draw(nx_G, pos,with_labels=True, node_color=[[.7, .7, .7]]) 22 23import torch 24import torch.nn as nn 25import torch.nn.functional as F 26 27def gcn_message(edges): 28 print("gcn_message:",edges.src['h']) 29 return {'msg' : edges.src['h']} 30 31def gcn_reduce(nodes): 32 print("gcn_reduce:",nodes.mailbox['msg']) 33 return {'h' : torch.sum(nodes.mailbox['msg'], dim=1)} 34 35class GCNLayer(nn.Module): 36 def __init__(self, in_feats, out_feats): 37 super(GCNLayer, self).__init__() 38 self.linear = nn.Linear(in_feats, out_feats) 39 40 def forward(self, g, inputs): 41 #フォワードパスを定義 42 g.ndata['h'] = inputs 43 g.send(g.edges(), gcn_message) 44 g.recv(g.nodes(), gcn_reduce) 45 h = g.ndata.pop('h') 46 return self.linear(h) 47 48class GCN(nn.Module): 49 def __init__(self, in_feats, num_classes): 50 super(GCN, self).__init__() 51 self.gcn1 = GCNLayer(in_feats, num_classes) 52 53 def forward(self, g, inputs): 54 h = self.gcn1(g, inputs) 55 h = torch.relu(h) 56 57 return h 58 59net = GCN(5, 2) 60inputs = torch.eye(5) 61labeled_nodes = torch.tensor([0, 1, 2, 3, 4]) 62labels = torch.tensor([0, 0, 1, 1, 0]) 63 64optimizer = torch.optim.Adam(net.parameters(), lr=0.01) 65all_logits = [] 66G.set_n_initializer(dgl.init.zero_initializer) 67 68for epoch in range(3): 69 logits = net(G, inputs) 70 all_logits.append(logits.detach()) 71 logp = F.log_softmax(logits, 1) 72 loss = F.nll_loss(logp[labeled_nodes], labels) 73 74 optimizer.zero_grad() 75 loss.backward() 76 optimizer.step() 77 78 print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))
#バグの発生箇所(DGLerror)
python
1--------------------------------------------------------------------------- 2DGLError Traceback (most recent call last) 3<ipython-input-20-adea207aebaa> in <module> 4 4 5 5 for epoch in range(3): 6----> 6 logits = net(G, inputs) 7 7 all_logits.append(logits.detach()) 8 8 logp = F.log_softmax(logits, 1) 9 10~/anaconda/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 11 720 result = self._slow_forward(*input, **kwargs) 12 721 else: 13--> 722 result = self.forward(*input, **kwargs) 14 723 for hook in itertools.chain( 15 724 _global_forward_hooks.values(), 16 17<ipython-input-18-70c26494cadc> in forward(self, g, inputs) 18 5 19 6 def forward(self, g, inputs): 20----> 7 h = self.gcn1(g, inputs) 21 8 h = torch.relu(h) 22 9 23 24~/anaconda/envs/pytorch/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 25 720 result = self._slow_forward(*input, **kwargs) 26 721 else: 27--> 722 result = self.forward(*input, **kwargs) 28 723 for hook in itertools.chain( 29 724 _global_forward_hooks.values(), 30 31<ipython-input-17-a96f9714ed3a> in forward(self, g, inputs) 32 18 def forward(self, g, inputs): 33 19 g.ndata['h'] = inputs 34---> 20 g.send(g.edges(), gcn_message) 35 21 g.recv(g.nodes(), gcn_reduce) 36 22 h = g.ndata.pop('h') 37 38~/anaconda/envs/pytorch/lib/python3.6/site-packages/dgl/heterograph.py in send(self, edges, message_func, etype) 39 5556 DEPRECATE: please use send_and_recv, update_all. 40 5557 """ 41-> 5558 raise DGLError('DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges\n' 42 5559 ' API to compute messages as edge data. Then use DGLGraph.send_and_recv\n' 43 5560 ' and set the message function as dgl.function.copy_e to conduct message\n' 44 45DGLError: DGLGraph.send is deprecated. As a replacement, use DGLGraph.apply_edges 46 API to compute messages as edge data. Then use DGLGraph.send_and_recv 47 and set the message function as dgl.function.copy_e to conduct message 48 aggregation.
と上記のように、DGLerrorが出ました。
DGLGraph.sendではなく、DGLGraph.apply_edgesとDGLGraph.send_and_recvを用いればいいと書かれてありますが、以上のコードをどのように書き換えれば良いのかがわかりません。ご教授願えませんでしょうか?
あなたの回答
tips
プレビュー