def test_backward():
    g = create_test_heterograph()
    x = F.randn((3, 5))
    F.attach_grad(x)
    g.nodes['user'].data['h'] = x
    with F.record_grad():
        g.multi_update_all(
            {'plays' : (fn.copy_u('h', 'm'), fn.sum('m', 'y')),
             'wishes': (fn.copy_u('h', 'm'), fn.sum('m', 'y'))},
            'sum')
        y = g.nodes['game'].data['y']
        F.backward(y, F.ones(y.shape))
    print(F.grad(x))
    assert F.array_equal(F.grad(x), F.tensor([[2., 2., 2., 2., 2.],
                                              [2., 2., 2., 2., 2.],
                                              [2., 2., 2., 2., 2.]]))
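# A minimal, self-contained sketch of the same ``multi_update_all`` pattern.
# The toy heterograph below is an assumption for illustration (the test above
# builds its graph with create_test_heterograph()); per-relation aggregates
# are combined with the 'sum' cross-type reducer.
import dgl
import dgl.function as fn
import torch

g = dgl.heterograph({
    ('user', 'plays', 'game'): ([0, 1, 2], [0, 0, 1]),
    ('user', 'wishes', 'game'): ([0, 2], [1, 0]),
})
g.nodes['user'].data['h'] = torch.ones(3, 5)
g.multi_update_all(
    {'plays' : (fn.copy_u('h', 'm'), fn.sum('m', 'y')),
     'wishes': (fn.copy_u('h', 'm'), fn.sum('m', 'y'))},
    'sum')  # cross-type reducer: sum the per-relation results
print(g.nodes['game'].data['y'])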
def message_func_edge(edges):
    if len(edges.src[fld].shape) == 1:
        return {'m' : edges.src[fld] * edges.data['e1']}
    else:
        return {'m' : edges.src[fld] * edges.data['e2']}

def reduce_func(nodes):
    return {fld : mx.nd.sum(nodes.mailbox['m'], axis=1)}

def apply_func(nodes):
    return {fld : 2 * nodes.data[fld]}

g = simple_graph()

# update all: the builtin copy_src/sum pair must match the plain UDFs
# (message_func, defined outside this excerpt, mirrors fn.copy_src).
v1 = g.ndata[fld]
g.update_all(fn.copy_src(src=fld, out='m'), fn.sum(msg='m', out=fld), apply_func)
v2 = g.ndata[fld]
g.set_n_repr({fld : v1})
g.update_all(message_func, reduce_func, apply_func)
v3 = g.ndata[fld]
assert np.allclose(v2.asnumpy(), v3.asnumpy(), rtol=1e-05, atol=1e-05)

# update all with edge weights
v1 = g.ndata[fld]
g.update_all(fn.src_mul_edge(src=fld, edge='e1', out='m'),
             fn.sum(msg='m', out=fld), apply_func)
v2 = g.ndata[fld]
g.set_n_repr({fld : v1})
g.update_all(fn.src_mul_edge(src=fld, edge='e2', out='m'),
             fn.sum(msg='m', out=fld), apply_func)
v3 = g.ndata[fld].squeeze()
g.set_n_repr({fld : v1})
g.update_all(message_func_edge, reduce_func, apply_func)
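# Standalone sketch of the builtin-vs-UDF equivalence the test above checks.
# The graph and field names here are illustrative; newer DGL spells the
# builtin fn.src_mul_edge as fn.u_mul_e.
import dgl
import dgl.function as fn
import torch

g = dgl.graph(([0, 1, 2], [1, 2, 0]))
g.ndata['h'] = torch.randn(3, 4)
g.edata['e1'] = torch.randn(3, 1)

# Builtin pairing: multiply source feature by edge weight, then sum.
g.update_all(fn.u_mul_e('h', 'e1', 'm'), fn.sum('m', 'h_builtin'))

# Equivalent hand-written UDFs.
def msg_udf(edges):
    return {'m': edges.src['h'] * edges.data['e1']}

def reduce_udf(nodes):
    return {'h_udf': nodes.mailbox['m'].sum(dim=1)}

g.update_all(msg_udf, reduce_udf)
assert torch.allclose(g.ndata['h_builtin'], g.ndata['h_udf'])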
skip_start = (0 == self.n_layers-1)
if skip_start:
    h = mx.nd.concat(h, self.activation(h))
else:
    h = self.activation(h)

for i, layer in enumerate(self.layers):
    new_history = h.copy().detach()
    history_str = 'h_{}'.format(i)
    history = nf.layers[i].data[history_str]
    h = h - history

    nf.layers[i].data['h'] = h
    nf.block_compute(i,
                     fn.copy_src(src='h', out='m'),
                     fn.sum(msg='m', out='h'),
                     layer)
    h = nf.layers[i+1].data.pop('activation')

    # update history
    if i < nf.num_layers-1:
        nf.layers[i].data[history_str] = new_history

return h
def graphsage_cv_train(g, ctx, args, n_classes, train_nid, test_nid, n_test_samples, distributed):
    n0_feats = g.nodes[0].data['features']
    num_nodes = g.number_of_nodes()
    in_feats = n0_feats.shape[1]
    g_ctx = n0_feats.context

    norm = mx.nd.expand_dims(1./g.in_degrees().astype('float32'), 1)
    g.set_n_repr({'norm': norm.as_in_context(g_ctx)})
    degs = g.in_degrees().astype('float32').asnumpy()
    degs[degs > args.num_neighbors] = args.num_neighbors
    g.set_n_repr({'subg_norm': mx.nd.expand_dims(mx.nd.array(1./degs, ctx=g_ctx), 1)})
    n_layers = args.n_layers

    g.update_all(fn.copy_src(src='features', out='m'),
                 fn.sum(msg='m', out='preprocess'),
                 lambda node : {'preprocess': node.data['preprocess'] * node.data['norm']})
    for i in range(n_layers):
        g.init_ndata('h_{}'.format(i), (num_nodes, args.n_hidden), 'float32')
        g.init_ndata('agg_h_{}'.format(i), (num_nodes, args.n_hidden), 'float32')

    model = GraphSAGETrain(in_feats,
                           args.n_hidden,
                           n_classes,
                           n_layers,
                           args.dropout,
                           prefix='GraphSAGE')
    model.initialize(ctx=ctx)
    loss_fcn = gluon.loss.SoftmaxCELoss()
from .nnutils import GRUUpdate, cuda
from dgl import batch, bfs_edges_generator
import dgl.function as DGLF
import numpy as np
# torch and torch.nn are used below but were missing from the excerpt:
import torch
import torch.nn as nn

MAX_NB = 8

def level_order(forest, roots):
    edges = bfs_edges_generator(forest, roots)
    _, leaves = forest.find_edges(edges[-1])
    edges_back = bfs_edges_generator(forest, roots, reverse=True)
    yield from reversed(edges_back)
    yield from edges

enc_tree_msg = [DGLF.copy_src(src='m', out='m'), DGLF.copy_src(src='rm', out='rm')]
enc_tree_reduce = [DGLF.sum(msg='m', out='s'), DGLF.sum(msg='rm', out='accum_rm')]
enc_tree_gather_msg = DGLF.copy_edge(edge='m', out='m')
enc_tree_gather_reduce = DGLF.sum(msg='m', out='m')

class EncoderGatherUpdate(nn.Module):
    def __init__(self, hidden_size):
        nn.Module.__init__(self)
        self.hidden_size = hidden_size
        self.W = nn.Linear(2 * hidden_size, hidden_size)

    def forward(self, nodes):
        x = nodes.data['x']
        m = nodes.data['m']
        return {
            'h': torch.relu(self.W(torch.cat([x, m], 1))),
        }
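# Hypothetical wiring for the gather step above: pull each edge's state 'm'
# into its destination node, then let EncoderGatherUpdate combine it with the
# node feature 'x'. The toy graph and sizes are assumptions, not taken from
# the original model.
import dgl

hidden_size = 4
tree = dgl.graph(([0, 1, 2], [2, 2, 0]))
tree.ndata['x'] = torch.randn(3, hidden_size)
tree.edata['m'] = torch.randn(3, hidden_size)
tree.update_all(enc_tree_gather_msg, enc_tree_gather_reduce,
                EncoderGatherUpdate(hidden_size))
print(tree.ndata['h'].shape)  # torch.Size([3, 4])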
skip_start = (0 == self.n_layers-1)
if skip_start:
    h = torch.cat((h, self.activation(h)), dim=1)
else:
    h = self.activation(h)

for i, layer in enumerate(self.layers):
    new_history = h.clone().detach()
    history_str = 'h_{}'.format(i)
    history = nf.layers[i].data[history_str]
    h = h - history

    nf.layers[i].data['h'] = h
    nf.block_compute(i,
                     fn.copy_src(src='h', out='m'),
                     lambda node : {'h': node.mailbox['m'].mean(dim=1)},
                     layer)
    h = nf.layers[i+1].data.pop('activation')

    # update history
    if i < nf.num_layers-1:
        nf.layers[i].data[history_str] = new_history

return h
def forward(self, nf):
    nf.layers[0].data['activation'] = nf.layers[0].data['features']

    for i, layer in enumerate(self.layers):
        h = nf.layers[i].data.pop('activation')
        nf.layers[i].data['h'] = h
        nf.block_compute(i,
                         fn.copy_src(src='h', out='m'),
                         fn.sum(msg='m', out='h'),
                         layer)

    return nf.layers[-1].data.pop('activation')
#
# GCN implementation with DGL
# ``````````````````````````````````````````
# We first define the message and reduce functions as usual. Since the
# aggregation on a node :math:`u` only involves summing over the neighbors'
# representations :math:`h_v`, we can simply use builtin functions:

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')

###############################################################################
# We then define the node UDF for ``apply_nodes``, which is a fully-connected
# layer:

class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        if self.activation is not None:
            h = self.activation(h)
        return {'h' : h}
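###############################################################################
# A full layer can then chain the message/reduce pair with the node UDF. The
# sketch below is one way to do it (the ``GCNLayer`` name is ours, not part of
# the excerpt): ``update_all`` aggregates neighbor features, then
# ``apply_nodes`` runs the fully-connected transform.

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCNLayer, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        # Aggregate neighbor representations, then apply the per-node layer.
        g.ndata['h'] = feature
        g.update_all(gcn_msg, gcn_reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')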
# Fragment: the opening ``if`` of this branch falls outside the excerpt.
    final_vec = torch.mv(torch.t(latent_rel_vecs), norm_r_att_scores)
else:
    final_vec = latent_rel_vecs.mean(dim=0).to(self.device)  # mean pooling
final_vecs.append(torch.cat((final_vec, s_vec), dim=0))

logits = self.hidden2output(torch.stack(final_vecs))
if not ana_mode:
    return logits
else:
    return logits, path_att_scores, qa_pair_att_scores
gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')

class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        h = self.activation(h)
        return {'h' : h}
class GraphConvLayer(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GraphConvLayer, self).__init__()

# From the layer's forward pass (the method header falls outside the excerpt):
def apply_func(nodes):
    # Start from the aggregate 'u' when edge gating is on, else from 'h'.
    u = nodes.data['u'] if self.edge_gate else nodes.data['h']
    if self.bias:
        u = u + self.bias
    if self.activation is not None:
        u = self.activation(u)
    if self.edge_gate:
        # Sigmoid gate convexly mixes the new aggregate with the old state.
        a = torch.sigmoid(self.linear(torch.cat([u, nodes.data['h']], 1)))
        u = u * a + nodes.data['h'] * (1 - a)
    return {'h': u}

if self.edge_gate:
    g.update_all(message_func, fn.sum(msg='msg', out='u'), apply_func)
else:
    g.update_all(message_func, fn.sum(msg='msg', out='h'), apply_func)
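# The ``message_func`` used above is defined outside this excerpt. A plausible
# stand-in that just forwards source features into the 'msg' field would be:
import dgl.function as fn
message_func = fn.copy_src(src='h', out='msg')
# With edge gating on, apply_func then mixes the aggregated 'u' with the old
# 'h' through the learned gate a, i.e. h <- a * u + (1 - a) * h.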