Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_batch_unbatch():
    """Batch two trees, check the batched sizes, then unbatch and compare features.

    Both ``tree1()`` and ``tree2()`` are expected to have 5 nodes and 4 edges,
    as pinned by the per-graph size assertions below.
    """
    tree_a, tree_b = tree1(), tree2()
    merged = dgl.batch([tree_a, tree_b])

    # Totals are the sums of the two inputs; per-graph sizes are kept in order.
    assert merged.number_of_nodes() == 10
    assert merged.number_of_edges() == 8
    assert merged.batch_size == 2
    assert merged.batch_num_nodes == [5, 5]
    assert merged.batch_num_edges == [4, 4]

    # Unbatching must hand back graphs whose node/edge 'h' features match
    # the originals, in the same order they were batched.
    split_a, split_b = dgl.unbatch(merged)
    for source, recovered in ((tree_a, split_a), (tree_b, split_b)):
        assert F.allclose(source.ndata['h'], recovered.ndata['h'])
        assert F.allclose(source.edata['h'], recovered.edata['h'])
def test_set2set():
    """Check Set2Set readout output shapes on a single graph and on a batch."""
    feat_dim = 5
    path = dgl.DGLGraph(nx.path_graph(10))
    ctx = F.ctx()

    # First argument matches the 5-wide node features fed below;
    # 3 iterations, 3 layers.
    readout = nn.Set2Set(feat_dim, 3, 3)
    readout.initialize(ctx=ctx)
    print(readout)

    # test#1: basic -- one unbatched graph gives a 1-D readout of length 10.
    node_feats = F.randn((path.number_of_nodes(), feat_dim))
    pooled = readout(path, node_feats)
    assert pooled.shape[0] == 10 and pooled.ndim == 1

    # test#2: batched graph -- three copies give one 10-wide row per graph.
    triple = dgl.batch([path, path, path])
    node_feats = F.randn((triple.number_of_nodes(), feat_dim))
    pooled = readout(triple, node_feats)
    assert pooled.shape[0] == 3 and pooled.shape[1] == 10 and pooled.ndim == 2
# NOTE(review): fragment of a top-k node test -- its enclosing `def` and the
# definition of `g0` are not shown here; TODO confirm against the full file.
# test#1: single graph. Attach a random 10-wide node feature 'x' to g0.
feat0 = F.randn((g0.number_of_nodes(), 10))
g0.ndata['x'] = feat0
# to test the case where k > number of nodes.
dgl.topk_nodes(g0, 'x', 20, idx=-1)
# test correctness
# idx=-1 selects the last feature column (index 9 of 10); the reference is the
# first 5 entries of an argsort of that column (the `True` flag presumably
# means descending -- confirm against the F backend).
val, indices = dgl.topk_nodes(g0, 'x', 5, idx=-1)
ground_truth = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 9, 10), 0, True)[:5], (5,))
assert F.allclose(ground_truth, indices)
# Remove 'x' from g0 before it is reused in the batched test.
g0.ndata.pop('x')
# test#2: batched graph
g1 = dgl.DGLGraph(nx.path_graph(12))
feat1 = F.randn((g1.number_of_nodes(), 10))
bg = dgl.batch([g0, g1])
# Batched node features are the two graphs' features concatenated in order.
bg.ndata['x'] = F.cat([feat0, feat1], 0)
# to test the case where k > number of nodes.
# (k=16 exceeds g1's 12 nodes; g0's size is set outside this fragment.)
dgl.topk_nodes(bg, 'x', 16, idx=1)
# test correctness
# Ascending order on column 0: the reference is computed per graph, then
# stacked, so `indices` is expected to hold one row of 6 indices per graph.
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=False, idx=0)
ground_truth_0 = F.reshape(
F.argsort(F.slice_axis(feat0, -1, 0, 1), 0, False)[:6], (6,))
ground_truth_1 = F.reshape(
F.argsort(F.slice_axis(feat1, -1, 0, 1), 0, False)[:6], (6,))
ground_truth = F.stack([ground_truth_0, ground_truth_1], 0)
assert F.allclose(ground_truth, indices)
# test idx=None
# Without idx, only the returned values are checked: against a per-graph
# F.topk over axis 0 of each graph's features, stacked per graph.
val, indices = dgl.topk_nodes(bg, 'x', 6, descending=True)
assert F.allclose(val, F.stack([F.topk(feat0, 6, 0), F.topk(feat1, 6, 0)], 0))
def test_batch_unbatch1():
    """Batch a tree with an already-batched graph and check unbatching flattens it."""
    tree_a = tree1()
    tree_b = tree2()
    pair = dgl.batch([tree_a, tree_b])
    # Batching (tree_b, pair) is expected to flatten into three member graphs.
    nested = dgl.batch([tree_b, pair])

    assert nested.number_of_nodes() == 15
    assert nested.number_of_edges() == 12
    assert nested.batch_size == 3
    assert nested.batch_num_nodes == [5, 5, 5]
    assert nested.batch_num_edges == [4, 4, 4]

    # Unbatched order follows the flattened input order: tree_b, tree_a, tree_b.
    s1, s2, s3 = dgl.unbatch(nested)
    for expected, got in ((tree_b, s1), (tree_a, s2), (tree_b, s3)):
        assert F.allclose(expected.ndata['h'], got.ndata['h'])
        assert F.allclose(expected.edata['h'], got.edata['h'])

    # Test batching readonly graphs
    # NOTE(review): the snippet ends here; the readonly-batching checks that
    # presumably follow appear to have been truncated -- confirm upstream.
    tree_a.readonly()
def batch(self, samples):
    """Collate a list of samples into a mini-batch.

    Each sample is indexed as ``(source, encoder_tree, decoder_tree)``; only
    fields 0-2 are read. Presumably ``source`` is a token-id sequence and the
    trees are DGL graphs -- confirm against the dataset.

    Returns:
        A tuple ``(padded_source, batched_encoder_trees, batched_decoder_trees)``
        where the sources are converted with ``torch.tensor`` and padded
        batch-first.
    """
    # Gather field 0, then field 1, then field 2 across all samples.
    sources, encoder_trees, decoder_trees = (
        [sample[field] for sample in samples] for field in (0, 1, 2))
    source_tensors = [torch.tensor(seq) for seq in sources]
    return (pad_sequence(source_tensors, batch_first=True),
            dgl.batch(encoder_trees),
            dgl.batch(decoder_trees))
# Collate a list of trees into one batched graph and wrap it in an SSTBatch,
# moving the per-node 'mask', 'x' (word ids) and 'y' (labels) fields to `ctx`
# with `as_in_context` (presumably MXNet NDArrays -- confirm).
# NOTE(review): `ctx` and `data` come from an enclosing scope not shown here;
# the trailing `return batcher_dev` belongs to that enclosing factory function.
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return data.SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].as_in_context(ctx),
wordid=batch_trees.ndata['x'].as_in_context(ctx),
label=batch_trees.ndata['y'].as_in_context(ctx))
return batcher_dev
def collate_fn(batch):
    """Collate ``(graph, pmpd, label)`` triples into a single mini-batch.

    Graphs are merged with ``dgl.batch``; the ``pmpd`` sparse matrices are
    laid along the diagonal of one ``scipy.sparse.block_diag`` matrix; label
    arrays are concatenated along axis 0.

    Returns:
        ``(batched_graph, block_diagonal_pmpd, concatenated_labels)``.
    """
    graph_list, pmpd_list, label_list = zip(*batch)
    return (
        dgl.batch(graph_list),
        sp.block_diag(pmpd_list),
        np.concatenate(label_list, axis=0),
    )
# Collate a list of trees into one batched graph and wrap it in an SSTBatch,
# moving the per-node 'mask', 'x' (word ids) and 'y' (labels) fields to
# `device` with `.to(device)` (presumably PyTorch tensors -- confirm).
# NOTE(review): `device` and `SSTBatch` come from an enclosing/module scope not
# shown here; the trailing `return batcher_dev` belongs to the enclosing
# factory function.
def batcher_dev(batch):
batch_trees = dgl.batch(batch)
return SSTBatch(graph=batch_trees,
mask=batch_trees.ndata['mask'].to(device),
wordid=batch_trees.ndata['x'].to(device),
label=batch_trees.ndata['y'].to(device))
return batcher_dev
###############################################################################
# The learning curve of a run is presented below.
# Plot the recorded training losses (presumably one entry per epoch, each the
# mean over that epoch's minibatches, per the title) and display the figure.
plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()
###############################################################################
# Evaluate the trained model on the test set created earlier. Because this
# tutorial limits the training time, the accuracy printed below is lower than
# what a longer run may reach (:math:`80` % ~ :math:`90` %).
# Switch the model to evaluation mode before scoring the test set.
model.eval()
# Convert a list of (graph, label) tuples to two lists
test_X, test_Y = map(list, zip(*testset))
# Score all test graphs in one forward pass as a single batched graph.
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
# Evaluation needs no gradients; skip building the autograd graph.
with torch.no_grad():
    probs_Y = torch.softmax(model(test_bg), 1)
# Two ways to turn class probabilities into predictions: draw one class per
# graph from the predicted distribution, or take the most probable class.
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
# '{:.4f}' = four decimal places, matching the sampled-accuracy line above.
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))