# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
class CoraDataset(Planetoid):
    """Planetoid 'Cora' citation graph rooted at data/Cora; applies T.TargetIndegree()."""

    def __init__(self):
        name = "Cora"
        here = osp.dirname(osp.realpath(__file__))
        super(CoraDataset, self).__init__(
            osp.join(here, "../..", "data", name), name, T.TargetIndegree()
        )
@register_dataset("citeseer")
class CiteSeerDataset(Planetoid):
    """Planetoid 'CiteSeer' citation graph rooted at data/CiteSeer; applies T.TargetIndegree()."""

    def __init__(self):
        name = "CiteSeer"
        here = osp.dirname(osp.realpath(__file__))
        super(CiteSeerDataset, self).__init__(
            osp.join(here, "../..", "data", name), name, T.TargetIndegree()
        )
@register_dataset("pubmed")
class PubMedDataset(Planetoid):
    """Planetoid 'PubMed' citation graph rooted at data/PubMed; applies T.TargetIndegree()."""

    def __init__(self):
        name = "PubMed"
        here = osp.dirname(osp.realpath(__file__))
        super(PubMedDataset, self).__init__(
            osp.join(here, "../..", "data", name), name, T.TargetIndegree()
        )
@register_dataset("reddit")
class RedditDataset(Reddit):
    """Reddit graph rooted at data/Reddit; applies T.TargetIndegree()."""

    def __init__(self):
        here = osp.dirname(osp.realpath(__file__))
        root = osp.join(here, "../..", "data", "Reddit")
        super(RedditDataset, self).__init__(root, T.TargetIndegree())
# NOTE(review): fragment — the enclosing ``def process(self):`` header is not
# visible in this chunk. Parses GATNE-format raw files from ``raw_dir`` and
# caches the parsed object at the first processed path.
data = read_gatne_data(self.raw_dir)
torch.save(data, self.processed_paths[0])
def __repr__(self):
    """Render this dataset as ``<name>()``."""
    label = self.name
    return "{}()".format(label)
@register_dataset("amazon")
class AmazonDataset(GatneDataset):
    """GATNE 'amazon' dataset rooted at data/amazon."""

    def __init__(self):
        name = "amazon"
        root = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", name)
        super(AmazonDataset, self).__init__(root, name)
@register_dataset("twitter")
class TwitterDataset(GatneDataset):
    """GATNE 'twitter' dataset rooted at data/twitter."""

    def __init__(self):
        name = "twitter"
        root = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", name)
        super(TwitterDataset, self).__init__(root, name)
@register_dataset("youtube")
class YouTubeDataset(GatneDataset):
    """GATNE 'youtube' dataset rooted at data/youtube."""

    def __init__(self):
        name = "youtube"
        root = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", name)
        super(YouTubeDataset, self).__init__(root, name)
# NOTE(review): fragment — looks like the body of ``BlogcatalogDataset.__init__``
# (a complete definition of that class appears later in this file); the
# ``def``/``class`` lines are missing from this chunk.
dataset, filename = "blogcatalog", "blogcatalog"
url = "http://leitang.net/code/social-dimension/data/"
path = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", dataset)
super(BlogcatalogDataset, self).__init__(path, filename, url)
@register_dataset("flickr")
class FlickrDataset(MatlabMatrix):
    """Flickr .mat dataset rooted at data/flickr, downloaded from leitang.net."""

    def __init__(self):
        name, matfile = "flickr", "flickr"
        here = osp.dirname(osp.realpath(__file__))
        super(FlickrDataset, self).__init__(
            osp.join(here, "../..", "data", name),
            matfile,
            "http://leitang.net/code/social-dimension/data/",
        )
@register_dataset("wikipedia")
class WikipediaDataset(MatlabMatrix):
    """Wikipedia 'POS' .mat dataset rooted at data/wikipedia, downloaded from SNAP."""

    def __init__(self):
        name, matfile = "wikipedia", "POS"
        here = osp.dirname(osp.realpath(__file__))
        super(WikipediaDataset, self).__init__(
            osp.join(here, "../..", "data", name),
            matfile,
            "http://snap.stanford.edu/node2vec/",
        )
@register_dataset("ppi")
class PPIDataset(MatlabMatrix):
    """'Homo_sapiens' PPI .mat dataset rooted at data/ppi, downloaded from SNAP."""

    def __init__(self):
        name, matfile = "ppi", "Homo_sapiens"
        here = osp.dirname(osp.realpath(__file__))
        super(PPIDataset, self).__init__(
            osp.join(here, "../..", "data", name),
            matfile,
            "http://snap.stanford.edu/node2vec/",
        )
# NOTE(review): fragment — presumably the body of a ``processed_file_names``
# property; the surrounding ``def`` and class are not visible in this chunk.
return ["data.pt"]
def get(self, idx):
    # Single-graph dataset accessor: only index 0 is valid. The check is an
    # ``assert``, so it is stripped under ``python -O``.
    assert idx == 0
    return self.data
def download(self):
    """Fetch every raw file from ``self.url`` into ``self.raw_dir``."""
    for fname in self.raw_file_names:
        download_url("{}/{}".format(self.url, fname), self.raw_dir)
def process(self):
    """Parse the edgelist/label raw files and cache the result to disk."""
    graph = read_edgelist_label_data(self.raw_dir, self.name)
    torch.save(graph, self.processed_paths[0])
@register_dataset("dblp")
class DBLP(EdgelistLabel):
    """DBLP graph in edgelist+label form, rooted at data/dblp."""

    def __init__(self):
        name = "dblp"
        root = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", name)
        super(DBLP, self).__init__(root, name)
# NOTE(review): fragment — tail of a ``process``-style method; collates the
# single Data object into ``(data, slices)`` and caches the pair at the first
# processed path.
data, slices = self.collate([data])
torch.save((data, slices), self.processed_paths[0])
def __repr__(self):
    """Render this dataset as ``<name>()``."""
    label = self.name
    return '{}()'.format(label)
@register_dataset('cora')
class CoraDataset(Planetoid):
    """Planetoid 'Cora' citation graph rooted at data/Cora; applies T.NormalizeFeatures()."""

    def __init__(self):
        name = 'Cora'
        here = osp.dirname(osp.realpath(__file__))
        super(CoraDataset, self).__init__(
            osp.join(here, '..', 'data', name), name, T.NormalizeFeatures()
        )
@register_dataset('citeseer')
class CiteSeerDataset(Planetoid):
    """Planetoid 'CiteSeer' citation graph rooted at data/CiteSeer; applies T.NormalizeFeatures()."""

    def __init__(self):
        name = 'CiteSeer'
        here = osp.dirname(osp.realpath(__file__))
        super(CiteSeerDataset, self).__init__(
            osp.join(here, '..', 'data', name), name, T.NormalizeFeatures()
        )
@register_dataset('pubmed')
class PubMedDataset(Planetoid):
    """Planetoid 'PubMed' citation graph rooted at data/PubMed; applies T.NormalizeFeatures()."""

    def __init__(self):
        name = 'PubMed'
        here = osp.dirname(osp.realpath(__file__))
        super(PubMedDataset, self).__init__(
            osp.join(here, '..', 'data', name), name, T.NormalizeFeatures()
        )
# NOTE(review): fragment — body of a ``process``-style method whose header is
# missing from this chunk. Unlike the rest of this file it builds TensorFlow
# tensors (``tf``) rather than torch tensors — confirm the intended framework.
smat = scio.loadmat(path)
adj_matrix, group = smat["network"], smat["group"]
# Dense float32 label tensor; ``group`` presumably comes back scipy-sparse
# from loadmat (hence ``.todense()``) — TODO confirm.
y = tf.convert_to_tensor(group.todense(), np.float32)
row_ind, col_ind = adj_matrix.nonzero()
# COO edge index stacked into a (2, num_edges) tensor, plus matching weights.
edge_index = tf.stack([tf.convert_to_tensor(row_ind), tf.convert_to_tensor(col_ind)])
edge_attr = tf.convert_to_tensor(adj_matrix[row_ind, col_ind])
data = Data(edge_index=edge_index, edge_attr=edge_attr, x=None, y=y)
# Cache the assembled Data object via pickle at the first processed path.
with open(self.processed_paths[0], 'wb') as output:
    pickle.dump(data, output, pickle.HIGHEST_PROTOCOL)
@register_dataset("blogcatalog")
class BlogcatalogDataset(MatlabMatrix):
    """BlogCatalog .mat dataset rooted at data/blogcatalog, downloaded from leitang.net."""

    def __init__(self):
        name, matfile = "blogcatalog", "blogcatalog"
        here = osp.dirname(osp.realpath(__file__))
        super(BlogcatalogDataset, self).__init__(
            osp.join(here, "../..", "data", name),
            matfile,
            "http://leitang.net/code/social-dimension/data/",
        )
@register_dataset("flickr")
class FlickrDataset(MatlabMatrix):
    """Flickr .mat dataset rooted at data/flickr, downloaded from leitang.net."""

    def __init__(self):
        name, matfile = "flickr", "flickr"
        here = osp.dirname(osp.realpath(__file__))
        super(FlickrDataset, self).__init__(
            osp.join(here, "../..", "data", name),
            matfile,
            "http://leitang.net/code/social-dimension/data/",
        )
class CiteSeerDataset(Planetoid):
    """Planetoid 'CiteSeer' citation graph rooted at data/CiteSeer; applies T.TargetIndegree()."""

    def __init__(self):
        name = "CiteSeer"
        here = osp.dirname(osp.realpath(__file__))
        super(CiteSeerDataset, self).__init__(
            osp.join(here, "../..", "data", name), name, T.TargetIndegree()
        )
@register_dataset("pubmed")
class PubMedDataset(Planetoid):
    """Planetoid 'PubMed' citation graph rooted at data/PubMed; applies T.TargetIndegree()."""

    def __init__(self):
        name = "PubMed"
        here = osp.dirname(osp.realpath(__file__))
        super(PubMedDataset, self).__init__(
            osp.join(here, "../..", "data", name), name, T.TargetIndegree()
        )
@register_dataset("reddit")
class RedditDataset(Reddit):
    """Reddit graph rooted at data/Reddit; applies T.TargetIndegree()."""

    def __init__(self):
        here = osp.dirname(osp.realpath(__file__))
        root = osp.join(here, "../..", "data", "Reddit")
        super(RedditDataset, self).__init__(root, T.TargetIndegree())
import os.path as osp
import torch
import torch_geometric.transforms as T
from torch_geometric.datasets import Planetoid, Reddit
from . import register_dataset
@register_dataset("cora")
class CoraDataset(Planetoid):
    """Planetoid 'Cora' citation graph rooted at data/Cora; applies T.TargetIndegree()."""

    def __init__(self):
        name = "Cora"
        here = osp.dirname(osp.realpath(__file__))
        super(CoraDataset, self).__init__(
            osp.join(here, "../..", "data", name), name, T.TargetIndegree()
        )
@register_dataset("citeseer")
class CiteSeerDataset(Planetoid):
    """Planetoid 'CiteSeer' citation graph rooted at data/CiteSeer; applies T.TargetIndegree()."""

    def __init__(self):
        name = "CiteSeer"
        here = osp.dirname(osp.realpath(__file__))
        super(CiteSeerDataset, self).__init__(
            osp.join(here, "../..", "data", name), name, T.TargetIndegree()
        )
# NOTE(review): orphaned decorator — the ``class`` definition it should wrap is
# missing from this chunk; the lines that follow belong to an unrelated method
# body, so this line is a syntax error as it stands.
@register_dataset("pubmed")
# NOTE(review): fragment — tail of a Reddit-style ``process`` body; the method
# header and the lines defining ``x``, ``y`` and ``split`` are not visible here.
adj = sp.load_npz(osp.join(self.raw_dir, "reddit_graph.npz"))
# Sparse COO adjacency -> (2, num_edges) long edge_index tensor.
row = torch.from_numpy(adj.row).to(torch.long)
col = torch.from_numpy(adj.col).to(torch.long)
edge_index = torch.stack([row, col], dim=0)
# Coalesce sorts and deduplicates edges; assumes ``x`` is the node-feature
# matrix with one row per node — TODO confirm against the missing header.
edge_index, _ = coalesce(edge_index, None, x.size(0), x.size(0))
data = Data(x=x, edge_index=edge_index, y=y)
# Boolean split masks from an integer vector: 1 = train, 2 = val, 3 = test.
data.train_mask = split == 1
data.val_mask = split == 2
data.test_mask = split == 3
torch.save(self.collate([data]), self.processed_paths[0])
@register_dataset("reddit")
class RedditDataset(Reddit):
    """Reddit graph rooted at data/Reddit; applies T.NormalizeFeatures()."""

    def __init__(self):
        here = osp.dirname(osp.realpath(__file__))
        root = osp.join(here, "..", "data", "Reddit")
        super(RedditDataset, self).__init__(root, T.NormalizeFeatures())
class AmazonDataset(GatneDataset):
    """GATNE 'amazon' dataset rooted at data/amazon."""

    def __init__(self):
        name = "amazon"
        root = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", name)
        super(AmazonDataset, self).__init__(root, name)
@register_dataset("twitter")
class TwitterDataset(GatneDataset):
    """GATNE 'twitter' dataset rooted at data/twitter."""

    def __init__(self):
        name = "twitter"
        root = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", name)
        super(TwitterDataset, self).__init__(root, name)
@register_dataset("youtube")
class YouTubeDataset(GatneDataset):
    """GATNE 'youtube' dataset rooted at data/youtube."""

    def __init__(self):
        name = "youtube"
        root = osp.join(osp.dirname(osp.realpath(__file__)), "../..", "data", name)
        super(YouTubeDataset, self).__init__(root, name)