Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def norm(edge_index, num_nodes, edge_weight, gcn=False, dtype=None):
    """Row-normalize edge weights (``D^{-1} A``) after re-adding self-loops.

    Existing self-loops are removed together with their weights. One
    self-loop per node is then appended with weight ``1`` if ``gcn`` else
    ``0``. Finally, every edge weight is divided by the weighted degree of its
    source node (accumulated over ``row``).

    Args:
        edge_index: index tensor unpacked as ``row, col``; its second
            dimension is the number of edges ``E``.
        num_nodes: number of nodes ``N``; one self-loop is added per node.
        edge_weight: per-edge weights with ``E`` elements (flattened with
            ``view(-1)``), or ``None`` for all-ones weights.
        gcn: if True, self-loops get weight 1; otherwise 0.
        dtype: dtype of the default all-ones weights.

    Returns:
        ``(edge_index, edge_weight)`` with self-loops appended and weights
        row-normalized. Nodes whose degree is zero contribute weight 0
        instead of NaN.
    """
    if edge_weight is None:
        edge_weight = torch.ones(
            (edge_index.size(1), ), dtype=dtype, device=edge_index.device)
    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    # Filter the weights together with the indices. Dropping only the
    # self-loop columns of edge_index would leave edge_weight longer than
    # edge_index and misalign every later lookup.
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    edge_index = add_self_loops(edge_index, num_nodes)
    # NOTE: assumes add_self_loops appends the N loop edges after the existing
    # ones, which this concatenation order relies on. Confirm against the
    # installed PyG version.
    loop_weight = torch.full(
        (num_nodes, ),
        1 if gcn else 0,
        dtype=edge_weight.dtype,
        device=edge_weight.device)
    edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    row = edge_index[0]
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv = deg.pow(-1)
    # A zero-degree node (e.g. an isolated node when gcn=False, whose only
    # edge is a 0-weight self-loop) would give inf here, and 0 * inf = NaN.
    # Map it to 0 so its edges simply carry no weight.
    deg_inv[deg_inv == float('inf')] = 0
    return edge_index, deg_inv[row] * edge_weight
def forward(self, x, edge_index):
    """Project node features per head and run message passing.

    Self-loops in ``edge_index`` are stripped and then re-added through
    ``add_self_loops`` for ``x.size(0)`` nodes. ``x`` is multiplied by
    ``self.weight`` (``torch.mm``, so ``x`` must be 2-D) and reshaped to
    ``(-1, self.heads, self.out_channels)`` before being handed to
    ``self.propagate``.
    """
    n_nodes = x.size(0)
    edge_index, _ = remove_self_loops(edge_index)
    edge_index = add_self_loops(edge_index, num_nodes=n_nodes)

    projected = torch.mm(x, self.weight)
    per_head = projected.view(-1, self.heads, self.out_channels)
    # The count passed on is taken from the reshaped tensor, as before.
    return self.propagate(edge_index, x=per_head, num_nodes=per_head.size(0))
def norm(edge_index, num_nodes, edge_weight, improved=False, dtype=None):
    """Symmetrically normalize edge weights, GCN style.

    Existing self-loops are removed together with their weights. One
    self-loop per node is then appended with weight ``1`` (or ``2`` when
    ``improved``). Each edge ``(row, col)`` is rescaled to
    ``deg[row]^{-1/2} * w * deg[col]^{-1/2}``, where ``deg`` is the weighted
    degree accumulated over ``row``.

    Args:
        edge_index: index tensor unpacked as ``row, col``; its second
            dimension is the number of edges ``E``.
        num_nodes: number of nodes ``N``; one self-loop is added per node.
        edge_weight: per-edge weights with ``E`` elements (flattened with
            ``view(-1)``), or ``None`` for all-ones weights.
        improved: use self-loop weight 2 instead of 1.
        dtype: dtype of the default all-ones weights.

    Returns:
        ``(edge_index, edge_weight)`` with self-loops appended and weights
        normalized. Zero-degree nodes contribute weight 0.
    """
    if edge_weight is None:
        edge_weight = torch.ones(
            (edge_index.size(1), ), dtype=dtype, device=edge_index.device)
    edge_weight = edge_weight.view(-1)
    assert edge_weight.size(0) == edge_index.size(1)

    # Filter the weights together with the indices. Dropping only the
    # self-loop columns of edge_index would leave edge_weight longer than
    # edge_index and misalign every later lookup.
    edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
    edge_index = add_self_loops(edge_index, num_nodes)
    # NOTE: assumes add_self_loops appends the N loop edges after the existing
    # ones, which this concatenation order relies on. Confirm against the
    # installed PyG version.
    loop_weight = torch.full(
        (num_nodes, ),
        1 if not improved else 2,
        dtype=edge_weight.dtype,
        device=edge_weight.device)
    edge_weight = torch.cat([edge_weight, loop_weight], dim=0)

    row, col = edge_index
    deg = scatter_add(edge_weight, row, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow(-0.5)
    # Zero degree gives inf; map it to 0 so such nodes carry no weight.
    deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
    return edge_index, deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]
def __call__(self, data):
    """Give every node of ``data`` a self-loop and coalesce its edges.

    Only graphs without edge attributes are supported (``data.edge_attr``
    must be ``None``). ``data.edge_index`` is replaced in place and the same
    ``data`` object is returned.
    """
    n = data.num_nodes
    edges = data.edge_index
    assert data.edge_attr is None

    with_loops = add_self_loops(edges, num_nodes=n)
    # coalesce presumably sorts the edges and merges duplicates, e.g. a
    # self-loop that was already present. Verify against the PyG version in use.
    data.edge_index, _ = coalesce(with_loops, None, n, n)
    return data