How to use the spektral.datasets.qm9 module in spektral

To help you get started, we’ve selected a few spektral examples based on popular ways the qm9 module is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: danielegrattarola/spektral — tests/test_datasets.py (view on GitHub)
def test_qm9():
    """Smoke-test ``qm9.load_data`` for the 'numpy', 'networkx' and 'sdf' return types.

    For the 'numpy' return type, the padded arrays are checked with
    ``correctly_padded``. Every returned array must also have exactly one
    entry per label row. The 'networkx' and 'sdf' calls only check that
    loading does not raise.
    """
    adj, nf, ef, labels = qm9.load_data(return_type='numpy', amount=1000)
    correctly_padded(adj, nf, ef)

    # Adjacency matrices, node features, edge features and labels must all
    # describe the same molecules, i.e. agree on their leading dimension.
    n_graphs = labels.shape[0]
    assert adj.shape[0] == n_graphs
    for features in (nf, ef):
        # Guard in case a feature array is not returned
        # (NOTE(review): confirm whether load_data can yield None here).
        if features is not None:
            assert features.shape[0] == n_graphs

    # Test that it doesn't crash
    qm9.load_data(return_type='networkx', amount=1000)
    qm9.load_data(return_type='sdf', amount=1000)
GitHub: danielegrattarola/spektral — tests/test_datasets.py (view on GitHub)
def test_qm9():
    """Smoke-test ``qm9.load_data`` for the 'numpy', 'networkx' and 'sdf' return types.

    For the 'numpy' return type, the padded arrays are checked with
    ``correctly_padded``. Every returned array must also have exactly one
    entry per label row. The 'networkx' and 'sdf' calls only check that
    loading does not raise.
    """
    adj, nf, ef, labels = qm9.load_data(return_type='numpy', amount=1000)
    correctly_padded(adj, nf, ef)

    # Adjacency matrices, node features, edge features and labels must all
    # describe the same molecules, i.e. agree on their leading dimension.
    n_graphs = labels.shape[0]
    assert adj.shape[0] == n_graphs
    for features in (nf, ef):
        # Guard in case a feature array is not returned
        # (NOTE(review): confirm whether load_data can yield None here).
        if features is not None:
            assert features.shape[0] == n_graphs

    # Test that it doesn't crash
    qm9.load_data(return_type='networkx', amount=1000)
    qm9.load_data(return_type='sdf', amount=1000)
GitHub: danielegrattarola/spektral — examples/regression_molecules.py (view on GitHub)
import matplotlib.pyplot as plt
import numpy as np
from keras.callbacks import EarlyStopping
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split

from spektral.datasets import qm9
from spektral.layers import EdgeConditionedConv, GlobalAvgPool
from spektral.utils import label_to_one_hot

# Load data as dense NumPy arrays (auto-padded, since auto_pad is not disabled here)
A, X, E, y = qm9.load_data(
    return_type='numpy',
    nf_keys='atomic_num',
    ef_keys='type',
    self_loops=True,
    amount=1000,  # Set to None to train on whole dataset
)
y = y[['cv']].values  # Heat capacity at 298.15K

# Preprocessing: one-hot encode node labels (atomic numbers) and edge labels
# (edge types) against the distinct values that actually occur. Zero is left
# out of both vocabularies; presumably it marks padding in X and "no edge"
# in E (NOTE(review): confirm against qm9.load_data).
uniq_X, uniq_E = np.unique(X), np.unique(E)
uniq_X = uniq_X[uniq_X != 0]
uniq_E = uniq_E[uniq_E != 0]
X = label_to_one_hot(X, uniq_X)
E = label_to_one_hot(E, uniq_E)

# Parameters
N = X.shape[-2]  # Number of nodes in the graphs
GitHub: danielegrattarola/spektral — examples/regression_molecules_disjoint.py (view on GitHub)
import tensorflow as tf
from keras.layers import Input, Dense
from keras.models import Model
from keras.optimizers import Adam
from sklearn.model_selection import train_test_split

from spektral.datasets import qm9
from spektral.layers import GlobalAvgPool, EdgeConditionedConv
from spektral.utils import Batch, batch_iterator
from spektral.utils import label_to_one_hot

np.random.seed(0)  # fixed NumPy seed for reproducible runs
# Keras automatically creates a placeholder for sample weights, which must be fed.
# NOTE(review): "dense_1" relies on Keras' automatic layer numbering, so this key
# would no longer match if another Dense layer were built first; confirm.
SW_KEY = 'dense_1_sample_weights:0'

# Load data without padding (auto_pad=False): X and E are consumed below as
# sequences of per-molecule arrays rather than single stacked tensors.
A, X, E, y = qm9.load_data(
    return_type='numpy',
    nf_keys='atomic_num',
    ef_keys='type',
    self_loops=True,
    auto_pad=False,
    amount=1000,  # Set to None to train on whole dataset
)
y = y[['cv']].values  # Heat capacity at 298.15K

# Preprocessing: gather the label vocabularies across all molecules, then
# one-hot encode every molecule individually. Zero is dropped from the edge
# vocabulary (presumably "no edge"; TODO confirm) but kept for nodes.
uniq_X = np.unique([label for mol in X for label in np.unique(mol)])
uniq_E = np.unique([label for mol in E for label in np.unique(mol)])
uniq_E = uniq_E[uniq_E != 0]
X = [label_to_one_hot(mol, uniq_X) for mol in X]
E = [label_to_one_hot(mol, uniq_E) for mol in E]

# Parameters
F = X[0].shape[-1]  # Dimension of node features
GitHub: danielegrattarola/spektral — docs/autogen.py (view on GitHub)
],
        'methods': [],
        'classes': []
    },
    {
        'page': 'datasets/delaunay.md',
        'functions': [
            datasets.delaunay.generate_data
        ],
        'methods': [],
        'classes': []
    },
    {
        'page': 'datasets/qm9.md',
        'functions': [
            datasets.qm9.load_data
        ],
        'methods': [],
        'classes': []
    },
{
        'page': 'datasets/mnist.md',
        'functions': [
            datasets.mnist.load_data
        ],
        'methods': [],
        'classes': []
    },
    {
        'page': 'brain.md',
        'functions': [
            brain.get_fc_graphs