Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): tail of a graph-building helper whose header (and the creation
# of `g`, presumably with 10 nodes given the 10-row feature tensor) is outside
# this view.
# 17 edges: 0 -> i and i -> 9 for i = 1..8 (16 edges), plus one back edge.
for i in range(1, 9):
    # The two spokes for each i are added back to back, so their edge IDs
    # interleave (0->1, 1->9, 0->2, 2->9, ...).
    g.add_edge(0, i)
    g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
# Random features: one row per node (10) and one row per edge (17); the
# feature width D is defined outside this view.
ncol = F.randn((10, D))
ecol = F.randn((17, D))
if grad:
    # Presumably lets callers differentiate through the features.
    ncol = F.attach_grad(ncol)
    ecol = F.attach_grad(ecol)
g.ndata['h'] = ncol
g.edata['w'] = ecol
# Zero-fill node/edge features that are not set explicitly.
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
return g
7, 9, 13
0, 8, 14
8, 9, 15
9, 0, 16
'''
# Same topology as the (src, dst, eid) table above: the edges are listed in
# edge-ID order, so an edge's position in this list is its edge ID.
g = dgl.graph([(0,1), (1,9), (0,2), (2,9), (0,3), (3,9), (0,4), (4,9),
    (0,5), (5,9), (0,6), (6,9), (0,7), (7,9), (0,8), (8,9), (9,0)])
# Random features: one row per node (10) and one row per edge (17); the
# feature width D is defined outside this view.
ncol = F.randn((10, D))
ecol = F.randn((17, D))
if grad:
    # Presumably lets callers differentiate through the features.
    ncol = F.attach_grad(ncol)
    ecol = F.attach_grad(ecol)
g.ndata['h'] = ncol
g.edata['w'] = ecol
# Zero-fill node/edge features that are not set explicitly.
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
return g
# NOTE(review): the `return` below is the body of an initializer (presumably
# `_init(shape, dtype, ctx, ...)`) whose `def` line is outside this view. It
# produces random values of the requested shape, cast to `dtype` and copied
# to `ctx`.
return F.copy_to(F.astype(F.randn(shape), dtype), ctx)
# Use the random initializer for node and edge features that are not set
# explicitly.
g.set_n_initializer(_init)
g.set_e_initializer(_init)
def _message(edges):
return {'m' : edges.src['h1'] + edges.dst['h2'] + edges.data['h1'] +
edges.data['h2']}
def _reduce(nodes):
    """Sum the ``'m'`` mailbox along axis 1 and store the result as ``'h'``.

    Axis 1 is presumably the per-node incoming-message axis of the mailbox.
    """
    inbox = nodes.mailbox['m']
    summed = F.sum(inbox, 1)
    return {'h': summed}
def _apply(nodes):
return {'h' : nodes.data['h']}
# Register the message / reduce / apply-node callbacks on `g`; presumably
# later message-passing calls use them when no function is passed explicitly.
g.register_message_func(_message)
g.register_reduce_func(_reduce)
g.register_apply_node_func(_apply)
# Zero-fill node/edge features that are not set explicitly.
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
# add nodes and edges
# Start with N nodes carrying random 'h1'/'h2' features (N and D are
# defined outside this view).
g.add_nodes(N)
g.ndata.update({'h1': F.randn((N, D)),
                'h2': F.randn((N, D))})
# NOTE(review): this may begin a separate snippet. The 3 extra nodes get
# their 'h1'/'h2' rows from the graph's node initializer, if one is set.
g.add_nodes(3)
# Two opposite edges, 0 -> 1 and 1 -> 0, each with random 'h1'/'h2'.
g.add_edge(0, 1)
g.add_edge(1, 0)
g.edata.update({'h1': F.randn((2, D)),
                'h2': F.randn((2, D))})
# NOTE(review): argument-less `send()` presumably sends along every edge
# using the registered message function — confirm for this DGL version.
g.send()
# The (private) message index is expected to be all ones, one int64 entry
# per edge, compared on the CPU.
expected = F.copy_to(F.ones((g.number_of_edges(),), dtype=F.int64), F.cpu())
assert F.array_equal(g._get_msg_index().tousertensor(), expected)
# add more edges
# create a graph where 0 is the source and 9 is the sink
# NOTE(review): tail of a graph-building helper whose header (and the creation
# of `g`) is outside this view.
# 17 edges: 0 -> i and i -> 9 for i = 1..8 (16 edges), plus one back edge.
for i in range(1, 9):
    # The two spokes for each i are added back to back, so their edge IDs
    # interleave (0->1, 1->9, 0->2, 2->9, ...).
    g.add_edge(0, i)
    g.add_edge(i, 9)
# add a back flow from 9 to 0
g.add_edge(9, 0)
# Random features: one row per node (10) and one row per edge (17); the
# feature width D is defined outside this view.
ncol = F.randn((10, D))
ecol = F.randn((17, D))
if grad:
    # Presumably lets callers differentiate through the features.
    ncol = F.attach_grad(ncol)
    ecol = F.attach_grad(ecol)
g.ndata['h'] = ncol
g.edata['w'] = ecol
# Zero-fill node/edge features that are not set explicitly.
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
return g
# Nested local scopes: features added inside a scope must disappear once
# that scope exits.
def foo(g):
    # Outer scope adds 'hh' (node) and 'ww' (edge).
    with g.local_scope():
        g.ndata['hh'] = F.ones((g.number_of_nodes(), 3))
        g.edata['ww'] = F.ones((g.number_of_edges(), 4))
        # Inner scope adds 'hhh' / 'www'.
        with g.local_scope():
            g.ndata['hhh'] = F.ones((g.number_of_nodes(), 3))
            g.edata['www'] = F.ones((g.number_of_edges(), 4))
        # NOTE(review): indentation was lost in this copy; these checks only
        # make sense after the inner scope exits, so they sit in the outer
        # scope here.
        assert 'hhh' not in g.ndata
        assert 'www' not in g.edata
foo(g)
# After both scopes exit, the outer scope's additions are gone too.
assert 'hh' not in g.ndata
assert 'ww' not in g.edata
# test initializer1
# Graph with nodes 0 and 1 and edges 0 -> 1 and 1 -> 1; unset node features
# are zero-filled.
g = dgl.graph([(0,1), (1,1)])
g.set_n_initializer(dgl.init.zero_initializer)
def foo(g):
    with g.local_scope():
        # Set 'h' on node 0 only. Node 1's row must come from the zero
        # initializer, so the full column is [[1.], [0.]]. The check has to
        # run inside the scope, where 'h' still exists.
        g.nodes[0].data['h'] = F.ones((1, 1))
        assert F.allclose(g.ndata['h'], F.tensor([[1.], [0.]]))
foo(g)
# test initializer2
# Custom edge initializer that always returns ones of the requested shape
# (dtype, ctx and id_range are ignored).
def foo_e_initializer(shape, dtype, ctx, id_range):
    return F.ones(shape)
# Registered only for field 'h'; other edge fields keep the default
# initializer.
g.set_e_initializer(foo_e_initializer, field='h')
def foo(g):
    with g.local_scope():
        # NOTE(review): `g.edges[0, 1]` presumably selects the edge 0 -> 1 by
        # its endpoints.
        # Setting 'h' on that edge fills the other edge's 'h' with ones
        # (custom initializer).
        g.edges[0, 1].data['h'] = F.ones((1, 1))
        assert F.allclose(g.edata['h'], F.ones((2, 1)))
        # 'w' has no custom initializer, so the other edge's row is expected
        # to be zero.
        g.edges[0, 1].data['w'] = F.ones((1, 1))
        assert F.allclose(g.edata['w'], F.tensor([[1.], [0.]]))
foo(g)
def forward(self, g):
    """Prepare node and edge input features for the stored graph.

    NOTE(review): this copy is truncated — the edge-feature fallback branch
    and the rest of the forward pass are outside this view. Indentation was
    lost; the layout below assumes a new graph is stored only when ``g`` is
    not None, so ``forward(None)`` reuses the previous ``self.g`` — confirm.
    """
    if g is not None:
        # Zero-fill node/edge features that are not set explicitly.
        g.set_n_initializer(dgl.init.zero_initializer)
        g.set_e_initializer(dgl.init.zero_initializer)
        self.g = g
    # 1. Build node features
    if isinstance(self.embed_nodes, nn.Embedding):
        # Learned embedding looked up by the values under GNN_NODE_LABELS_KEY
        # (presumably integer node labels).
        node_features = self.embed_nodes(self.g.ndata[GNN_NODE_LABELS_KEY])
    elif isinstance(self.embed_nodes, torch.Tensor):
        # Plain tensor used as a lookup table, indexed by the same labels.
        node_features = self.embed_nodes[self.g.ndata[GNN_NODE_LABELS_KEY]]
    else:
        # No embedding configured: all-zero features of width node_dim.
        node_features = torch.zeros(self.g.number_of_nodes(), self.node_dim)
    # NOTE(review): assumed to apply after every branch above (indentation
    # lost in this copy).
    node_features = node_features.cuda() if self.is_cuda else node_features
    # 2. Build edge features
    if isinstance(self.embed_edges, nn.Embedding):
        edge_features = self.embed_edges(self.g.edata[GNN_EDGE_LABELS_KEY])
    elif isinstance(self.embed_edges, torch.Tensor):
        edge_features = self.embed_edges[self.g.edata[GNN_EDGE_LABELS_KEY]]
def _init_graph(self, enc_tree, dec_tree):
    """Attach the encoder and decoder cell callbacks to their trees.

    Encoder first, then decoder: each tree gets its cell's message, reduce
    and apply-node functions registered, plus a zero node initializer.
    """
    pairings = ((enc_tree, self.enc_cell), (dec_tree, self.dec_cell))
    for tree, cell in pairings:
        tree.register_message_func(cell.message_func)
        tree.register_reduce_func(cell.reduce_func)
        tree.register_apply_node_func(cell.apply_node_func)
        tree.set_n_initializer(dgl.init.zero_initializer)
# the code here is ugly and we should replace it with DGL bipartite graph
# in the future.
n_inter = max(n_enc, n_dec)
row_inter, col_inter = map(np.concatenate, (row_inter, col_inter))
coo_inter = coo_matrix((np.zeros_like(row_inter), (row_inter, col_inter)), shape=(n_inter, n_inter))
g_inter = dgl.DGLGraph(coo_inter, readonly=True)
# process readout ids
readout_ids = th.cat(readout_ids)
# initialize graph
g_enc.set_n_initializer(dgl.init.zero_initializer)
g_enc.set_e_initializer(dgl.init.zero_initializer)
g_dec.set_n_initializer(dgl.init.zero_initializer)
g_dec.set_e_initializer(dgl.init.zero_initializer)
g_inter.set_n_initializer(dgl.init.zero_initializer)
g_inter.set_n_initializer(dgl.init.zero_initializer)
data['enc'] = th.cat(data['enc'])
data['dec'] = th.cat(data['dec'])
labels = th.cat(labels)
# assign enc graph feature
g_enc.edata['etype'] = etypes['enc']
g_enc.ndata['pos'] = pos_arr['enc']
g_enc.nodes[leaf_ids['enc']].data['x'] = data['enc']
# assign dec graph feature
g_dec.edata['etype'] = etypes['dec']
g_dec.ndata['pos'] = pos_arr['dec']
g_dec.nodes[leaf_ids['dec']].data['x'] = data['dec']
return Batch(g_enc=g_enc, g_dec=g_dec, g_inter=g_inter, readout_ids=readout_ids, leaf_ids=leaf_ids, y=labels)
# NOTE(review): the next lines up to the shift update are the tail of a
# per-graph loop whose header is outside this view (indentation lost).
row.append(src)
col.append(dst)
etypes.append(th.from_numpy(etype))
# update shift
# NOTE(review): `number_of_nodes` / `number_of_edges` are used without a
# call. On a DGLGraph these are methods, so `+=` would raise TypeError. This
# is correct only if the loop's `g` exposes them as plain attributes or
# properties — confirm the type of `g` here.
v_shift += g.number_of_nodes
e_shift += g.number_of_edges
# After the loop, v_shift is the total node count across the batch.
n = v_shift
root_ids = th.tensor(root_ids)
leaf_ids = th.cat(leaf_ids)
pos_arr = th.cat(pos_arr)
etypes = th.cat(etypes)
# One square adjacency over all n nodes; presumably src/dst were already
# offset by v_shift inside the loop so the per-sample graphs are disjoint.
row, col = map(np.concatenate, (row, col))
coo = coo_matrix((np.zeros_like(row), (row, col)), shape=(n, n))
# Note: rebinds `g` (the loop variable above) to the batched graph.
g = dgl.DGLGraph(coo, readonly=True)
# Zero-fill node/edge features that are not set explicitly.
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
data = th.cat(data)
labels = th.cat(labels)
g.edata['etype'] = etypes
g.ndata['pos'] = pos_arr
# Only leaf nodes receive input 'x'; the remaining rows are zero-filled.
g.nodes[leaf_ids].data['x'] = data
return Batch(g=g, readout_ids=root_ids, leaf_ids=leaf_ids, y=labels)