Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def edge_index_from_dict(graph_dict, num_nodes=None):
    """Build an ``edge_index`` tensor from an adjacency dictionary.

    Args:
        graph_dict: Mapping from a source node id to a sized sequence of
            target node ids (one edge ``key -> v`` per entry ``v``).
        num_nodes: Node count forwarded to ``coalesce`` as both matrix
            dimensions. ``None`` is passed through unchanged (presumably
            ``coalesce`` then infers the size from ``edge_index`` -- confirm
            against the installed ``coalesce``).

    Returns:
        A ``2 x E`` ``torch.long`` tensor with self loops removed and
        duplicate edges merged via ``coalesce``.
    """
    row, col = [], []
    for key, value in graph_dict.items():
        row += repeat(key, len(value))
        col += value
    # Force an integer dtype: for an empty graph_dict the lists are empty and
    # torch.tensor([]) would otherwise default to float32, producing a float
    # edge_index.
    edge_index = torch.stack(
        [
            torch.tensor(row, dtype=torch.long),
            torch.tensor(col, dtype=torch.long),
        ],
        dim=0,
    )
    # NOTE: There are duplicated edges and self loops in the datasets. Other
    # implementations do not remove them!
    edge_index, _ = remove_self_loops(edge_index)
    edge_index, _ = coalesce(edge_index, None, num_nodes, num_nodes)
    return edge_index
def norm(edge_index, num_nodes, edge_weight, gcn=False, dtype=None):
    """Row-normalize edge weights after replacing self loops.

    Existing self loops are removed (together with their weights) and
    ``add_self_loops`` is called. The code then appends ``num_nodes``
    self-loop weights, so it assumes ``add_self_loops`` appends exactly one
    loop per node at the end. Every edge weight is finally divided by the
    weighted degree of its source node (``edge_index[0]``).

    Args:
        edge_index: ``2 x E`` index tensor.
        num_nodes: Number of nodes.
        edge_weight: Optional per-edge weights (flattened to 1-D); ``None``
            means all ones.
        gcn: If ``True`` the appended self loops get weight 1, otherwise 0.
        dtype: dtype of the all-ones weights created when ``edge_weight`` is
            ``None``.

    Returns:
        ``(edge_index, edge_weight)`` where the weights are scaled by
        ``1 / deg(source)``; nodes with zero degree use a factor of 0 instead
        of ``inf``.
    """
    if edge_weight is None:
        edge_weight = torch.ones(
            (edge_index.size(1), ), dtype=dtype, device=edge_index.device)
    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    # Filter the weights together with the edges. Discarding the filtered
    # weights would leave edge_weight longer than edge_index whenever the
    # input already contains self loops, misaligning every weight after it.
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    edge_index = add_self_loops(edge_index, num_nodes)
    loop_weight = torch.full(
        (num_nodes, ),
        1 if gcn else 0,
        dtype=edge_weight.dtype,
        device=edge_weight.device)
    edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    row = edge_index[0]
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv = deg.pow(-1)
    # An isolated node has deg == 0 -> inf. With gcn=False its self-loop
    # weight is 0, and inf * 0 would put a NaN into the output.
    deg_inv[deg_inv == float('inf')] = 0
    return edge_index, deg_inv[row] * edge_weight
def __call__(self, data):
    """Append every two-hop connection of ``data`` as an extra edge.

    All existing edges are given the sentinel weight ``1e16``. ``spspmm`` is
    applied to the adjacency with itself (presumably ``A @ A``, i.e.
    two-hop reachability). Self loops are dropped from that product, and the
    resulting edges are concatenated after the original ones before
    ``coalesce`` merges them.

    Without ``edge_attr`` only ``data.edge_index`` is replaced. With
    ``edge_attr``, each two-hop edge receives its product value broadcast to
    the attribute's trailing shape. ``coalesce`` is called with ``op='min'``,
    presumably keeping the element-wise minimum per edge. Afterwards every
    entry at or above the sentinel is reset to 0, which zeroes the
    attributes of newly created edges. This assumes genuine attributes stay
    below ``1e16``.

    Returns ``data`` (modified in place).
    """
    sentinel = 1e16
    edge_index = data.edge_index
    edge_attr = data.edge_attr
    num_nodes = data.num_nodes

    weight = edge_index.new_full(
        (edge_index.size(1), ), sentinel, dtype=torch.float)
    hop_index, hop_weight = spspmm(
        edge_index, weight, edge_index, weight,
        num_nodes, num_nodes, num_nodes)
    hop_index, hop_weight = remove_self_loops(hop_index, hop_weight)

    combined = torch.cat([edge_index, hop_index], dim=1)

    if edge_attr is not None:
        # Reshape the 1-D product values so they broadcast to edge_attr's
        # per-edge feature shape before stacking them underneath it.
        feature_shape = list(edge_attr.size())[1:]
        hop_weight = hop_weight.view(-1, *([1] * (edge_attr.dim() - 1)))
        hop_weight = hop_weight.expand(-1, *feature_shape)
        stacked = torch.cat([edge_attr, hop_weight], dim=0)
        data.edge_index, merged_attr = coalesce(
            combined, stacked, num_nodes, num_nodes,
            op='min', fill_value=sentinel)
        merged_attr[merged_attr >= sentinel] = 0
        data.edge_attr = merged_attr
        return data

    data.edge_index, _ = coalesce(combined, None, num_nodes, num_nodes)
    return data
def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
    """Symmetrically normalize edge weights after replacing self loops.

    Existing self loops are removed (together with their weights) and
    ``add_self_loops`` is called. The code then appends ``num_nodes``
    self-loop weights of 1 (2 when ``improved``), so it assumes
    ``add_self_loops`` appends exactly one loop per node at the end. Each
    edge ``(row, col)`` is finally scaled by
    ``deg(row) ** -0.5 * deg(col) ** -0.5``, where ``deg`` is the weighted
    degree accumulated over ``row``.

    Args:
        edge_index: ``2 x E`` index tensor.
        num_nodes: Number of nodes.
        edge_weight: Optional per-edge weights (flattened to 1-D); ``None``
            means all ones.
        improved: If ``True`` the appended self loops get weight 2 instead
            of 1.
        dtype: dtype of the all-ones weights created when ``edge_weight`` is
            ``None``.

    Returns:
        ``(edge_index, edge_weight)`` with normalized weights; nodes with
        zero degree use a factor of 0 instead of ``inf``.
    """
    if edge_weight is None:
        edge_weight = torch.ones(
            (edge_index.size(1), ), dtype=dtype, device=edge_index.device)
    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    # Filter the weights together with the edges. Discarding the filtered
    # weights would leave edge_weight longer than edge_index whenever the
    # input already contains self loops, misaligning every weight after it.
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    edge_index = add_self_loops(edge_index, num_nodes)
    loop_weight = torch.full(
        (num_nodes, ),
        1 if not improved else 2,
        dtype=edge_weight.dtype,
        device=edge_weight.device)
    edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    # Zero-degree nodes would otherwise propagate inf into the weights.
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def forward(self, x, edge_index):
    """Project node features per head and run message passing.

    Self loops in ``edge_index`` are first removed. ``add_self_loops`` is
    then called with the number of rows of ``x``. The features are
    multiplied by ``self.weight`` and reshaped to
    ``(-1, self.heads, self.out_channels)`` before being handed to
    ``self.propagate``.
    """
    num_nodes = x.size(0)
    edge_index = remove_self_loops(edge_index)[0]
    edge_index = add_self_loops(edge_index, num_nodes=num_nodes)

    projected = torch.mm(x, self.weight)
    per_head = projected.view(-1, self.heads, self.out_channels)
    return self.propagate(edge_index, x=per_head, num_nodes=per_head.size(0))
def __call__(self, data):
    """Replace ``data.edge_index`` with a radius graph over ``data.pos``.

    Every ordered pair ``(i, j)`` of points whose distance is at most
    ``self.r`` (as reported by ``cKDTree.query_ball_tree``) becomes an edge;
    self loops are then removed. ``data`` is modified in place and returned.

    Raises:
        AssertionError: if ``data.pos`` is a CUDA tensor (the check is an
            ``assert``, so it is skipped under ``python -O``).
    """
    pos = data.pos
    # Presumably rejected because cKDTree converts pos to a host NumPy array.
    assert not pos.is_cuda
    tree = scipy.spatial.cKDTree(pos)
    indices = tree.query_ball_tree(tree, self.r)

    row, col = [], []
    for i, neighbors in enumerate(indices):
        row += repeat(i, len(neighbors))
        col += neighbors
    # Force an integer dtype: with no points the lists are empty and
    # torch.tensor would otherwise build a float32 edge_index.
    edge_index = torch.tensor([row, col], dtype=torch.long)
    edge_index, _ = remove_self_loops(edge_index)
    data.edge_index = edge_index
    return data
# Fragment of a dataset processing routine whose enclosing ``def`` and loop
# headers are outside this view. ``split``, ``s``, ``G`` and ``x`` are defined
# there (presumably the split name, its index into ``processed_paths``, a
# networkx-style graph over all nodes, and the node-feature tensor -- confirm
# against the enclosing code).
# Load the node labels for this split and cast them to float.
# NOTE(review): ``.format(split)`` is applied to the already-joined path, so a
# '{' or '}' inside ``self.raw_dir`` would break the lookup; formatting the
# file name before ``osp.join`` would be safer.
y = np.load(osp.join(self.raw_dir, '{}_labels.npy').format(split))
y = torch.from_numpy(y).to(torch.float)
data_list = []
# Graph id of every node, shifted so the smallest id becomes 0.
# NOTE(review): same ``.format``-after-join caveat as above.
path = osp.join(self.raw_dir, '{}_graph_id.npy').format(split)
idx = torch.from_numpy(np.load(path)).to(torch.long)
idx = idx - idx.min()
# Build one Data object per graph id in 0..max(idx).
for i in range(idx.max().item() + 1):
mask = idx == i
# Subgraph of G restricted to the nodes whose graph id is i.
G_s = G.subgraph(mask.nonzero().view(-1).tolist())
# Edge list transposed to 2 x E (assumes G_s.edges yields (u, v) pairs).
edge_index = torch.tensor(list(G_s.edges)).t().contiguous()
# NOTE(review): node ids are re-based by the smallest id that appears in an
# edge. This lines up with x[mask] only if each graph's node ids are
# contiguous and its lowest-id node has at least one edge; subtracting the
# first masked node id would be more robust. A subgraph without edges gives
# an empty tensor, on which ``.min()`` raises.
edge_index = edge_index - edge_index.min()
edge_index, _ = remove_self_loops(edge_index)
data = Data(edge_index=edge_index, x=x[mask], y=y[mask])
# Optional hooks: skip graphs rejected by pre_filter, then apply
# pre_transform.
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
data_list.append(data)
# Collate the collected graphs and save them to processed_paths[s].
torch.save(self.collate(data_list), self.processed_paths[s])