Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import sys
import os
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm
from ogb.nodeproppred import PygNodePropPredDataset
"""
Run this script to convert a graph from the Open Graph Benchmark (OGB) format
to the GraphSAINT format. Pass the OGB dataset name as the only argument.
Currently, ogbn-products and ogbn-arxiv can be converted by this script.
"""
# The OGB dataset name (e.g. ogbn-products) is the single required CLI argument;
# fail with a usage message instead of a bare IndexError when it is missing.
if len(sys.argv) < 2:
    sys.exit('usage: python {} <ogb-dataset-name>'.format(sys.argv[0]))
dataset = PygNodePropPredDataset(name=sys.argv[1])
# Train/valid/test node index splits provided by the OGB dataset object.
split_idx = dataset.get_idx_split()
train_idx, valid_idx, test_idx = split_idx['train'], split_idx['valid'], split_idx['test']
graph = dataset[0]
# Node count, read from the first dimension of graph.y
# (assumes one label row per node -- confirm for new datasets).
num_node = graph.y.shape[0]
# Output directory for the converted files. Kept as a plain string with a
# trailing slash because output file names are appended by concatenation.
save_dir = './data/' + sys.argv[1] + '/'
# makedirs also creates ./data/ when it is missing; os.mkdir failed in that
# case (the error was only printed and the np.save below then crashed).
# exist_ok=True keeps re-running the script over an existing directory working.
os.makedirs(save_dir, exist_ok=True)
# feats.npy: node feature matrix (graph.x) saved as a NumPy array.
feats = graph.x.numpy()
np.save(save_dir + 'feats.npy', feats)
Hanqing Zeng (zengh@usc.edu); Hongkuan Zhou (hongkuaz@usc.edu)
"""
from graphsaint.globals import *
from graphsaint.pytorch_version.models import GraphSAINT
from graphsaint.pytorch_version.minibatch import Minibatch
from graphsaint.utils import *
from graphsaint.metric import *
from graphsaint.pytorch_version.utils import *
from ogb.nodeproppred import Evaluator
import torch
import time
# OGB Evaluator instance; the dataset name is hard-coded to "ogbn-products" here.
evaluator = Evaluator(name="ogbn-products")
def evaluate_full_batch(model, minibatch, mode='val'):
"""
Full batch evaluation: for validation and test sets only.
When calculating the F1 score, we will mask the relevant root nodes
(e.g., those belonging to the val / test sets).
"""
loss,preds,labels = model.eval_step(*minibatch.one_batch(mode=mode))
if mode == 'val':
node_target = [minibatch.node_val]
elif mode == 'test':
node_target = [minibatch.node_test]
else:
assert mode == 'valtest'
node_target = [minibatch.node_val, minibatch.node_test]
labels = labels.argmax(dim=-1, keepdim=True)