How to use the spektral.layers.GraphAttention layer in spektral

To help you get started, we've selected a few spektral examples based on popular ways GraphAttention is used in public projects.
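Before the excerpts, here is a minimal, self-contained sketch of the layer in Spektral's single mode. The sizes (N, F, n_classes), the hyperparameter values, and the tf.keras imports are assumptions for illustration; note that, depending on the Spektral version, the head-reduction argument is spelled attn_heads_reduction (as in the examples below) or concat_heads (as in the test file below).

# Minimal sketch (assumed sizes, values, and tf.keras imports; mirrors the
# GAT example excerpted further down). GraphAttention takes the node
# features and the adjacency matrix as a pair of inputs.
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from spektral.layers import GraphAttention

N, F, n_classes = 100, 16, 7   # assumed: nodes, features, classes

X_in = Input(shape=(F,))       # node features, one row per node
A_in = Input(shape=(N,))       # adjacency, one row per node

x = GraphAttention(8,
                   attn_heads=4,
                   attn_heads_reduction='concat',
                   activation='elu')([X_in, A_in])
out = GraphAttention(n_classes,
                     attn_heads=1,
                     attn_heads_reduction='average',
                     activation='softmax')([x, A_in])

model = Model(inputs=[X_in, A_in], outputs=out)
model.compile(optimizer='adam', loss='categorical_crossentropy')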


From danielegrattarola/spektral: tests/benchmarks/citation/citation.py

        'fltr': lambda A: localpooling_filter(A),
        'sparse': True
    },
    {
        'layer': ARMAConv,
        'n_layers': 1,
        'kwargs': {
            'T': neighbourhood,
            'K': 1,
            'recurrent': True
        },
        'fltr': lambda A: rescale_laplacian(normalized_laplacian(A), lmax=2),
        'sparse': True
    },
    {
        'layer': GraphAttention,
        'n_layers': neighbourhood,
        'kwargs': {},
        'fltr': lambda A: A,
        'sparse': False
    },
    {
        'layer': GraphSageConv,
        'n_layers': neighbourhood,
        'kwargs': {},
        'fltr': lambda A: A,
        'sparse': True
    },
    {
        'layer': APPNP,
        'n_layers': 1,
        'kwargs': {
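The entry above is one element of a list of benchmark configurations (the excerpt is cut off mid-entry); the driver code that consumes the list is not shown. A hypothetical sketch of how one entry could be turned into a model, where build_model and n_out are illustrative names:

# Hypothetical driver (not the benchmark's actual code): preprocess A with
# the entry's 'fltr', then stack 'n_layers' copies of 'layer' built from
# 'kwargs'.
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model

def build_model(config, A, N, F, n_out):
    fltr = config['fltr'](A)   # for GraphAttention this returns A unchanged
    X_in = Input(shape=(F,))
    A_in = Input(shape=(N,), sparse=config['sparse'])
    x = X_in
    for _ in range(config['n_layers']):
        x = config['layer'](n_out, **config['kwargs'])([x, A_in])
    return Model([X_in, A_in], x), fltr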
From danielegrattarola/spektral: tests/benchmarks/node_classification/node_classification.py

        'sparse': True
    },
    {
        'layer': ARMAConv,
        'n_layers': 1,
        'kwargs': {
            'T': neighbourhood,
            'K': 1,
            'recurrent': True,
            'dropout_rate': dropout_rate
        },
        'fltr': lambda A: rescale_laplacian(normalized_laplacian(A), lmax=2),
        'sparse': True
    },
    {
        'layer': GraphAttention,
        'n_layers': neighbourhood,
        'kwargs': {
            'dropout_rate': dropout_rate
        },
        'fltr': lambda A: A,
        'sparse': False
    },
    {
        'layer': GraphSageConv,
        'n_layers': neighbourhood,
        'kwargs': {},
        'fltr': lambda A: A,
        'sparse': True
    },
    {
        'layer': APPNP,
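This configuration is the same benchmark matrix with dropout added; note that GraphAttention takes dropout_rate directly, applying it internally to the attention coefficients, so no separate Dropout layer is required here:

# Sketch: dropout_rate is an argument of the layer itself (channels=8 and
# the rate are assumed values for illustration).
from spektral.layers import GraphAttention

gat = GraphAttention(8, dropout_rate=0.5)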
From danielegrattarola/spektral: tests/test_layers/test_convolutional.py

        LAYER_K_: ChebConv,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'relu'}
    },
    {
        LAYER_K_: GraphSageConv,
        MODES_K_: [SINGLE],
        KWARGS_K_: {'channels': 8, 'activation': 'relu'}
    },
    {
        LAYER_K_: EdgeConditionedConv,
        MODES_K_: [SINGLE, BATCH],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'edges': True}
    },
    {
        LAYER_K_: GraphAttention,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'attn_heads': 2, 'concat_heads': False, 'activation': 'relu'}
    },
    {
        LAYER_K_: GraphConvSkip,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'relu'}
    },
    {
        LAYER_K_: ARMAConv,
        MODES_K_: [SINGLE, BATCH, MIXED],
        KWARGS_K_: {'channels': 8, 'activation': 'relu', 'order': 2, 'iterations': 2, 'share_weights': True}
    },
    {
        LAYER_K_: APPNP,
        MODES_K_: [SINGLE, BATCH, MIXED],
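The test matrix above exercises GraphAttention in all three of Spektral's data modes (and, unlike the examples, uses the concat_heads spelling). A sketch with assumed sizes:

# Sketch of the three data modes tested above, with assumed sizes
# (a batch of 2 graphs, N=5 nodes, F=8 features).
import numpy as np
from spektral.layers import GraphAttention

gat = GraphAttention(8, attn_heads=2, concat_heads=False, activation='relu')

X1, A1 = np.ones((5, 8)), np.ones((5, 5))        # single: one graph
Xb, Ab = np.ones((2, 5, 8)), np.ones((2, 5, 5))  # batch: stacked graphs
Xm, Am = np.ones((2, 5, 8)), np.ones((5, 5))     # mixed: one A, many X

out = gat([Xb, Ab])                              # shape (2, 5, 8)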
From danielegrattarola/spektral: examples/classification_delaunay.py
l2_reg = 5e-4            # Regularization rate for l2
learning_rate = 1e-3     # Learning rate for Adam
epochs = 20000           # Number of training epochs
batch_size = 32          # Batch size
es_patience = 200        # Patience for early stopping

# Train/test split
A_train, A_test, \
x_train, x_test, \
y_train, y_test = train_test_split(A, X, y, test_size=0.1)

# Model definition
X_in = Input(shape=(N, F))
A_in = Input((N, N))

gc1 = GraphAttention(32, activation='relu', kernel_regularizer=l2(l2_reg))([X_in, A_in])
gc2 = GraphAttention(32, activation='relu', kernel_regularizer=l2(l2_reg))([gc1, A_in])
pool = GlobalAttentionPool(128)(gc2)

output = Dense(n_classes, activation='softmax')(pool)

# Build model
model = Model(inputs=[X_in, A_in], outputs=output)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
model.summary()

# Train model
model.fit([x_train, A_train],
          y_train,
          batch_size=batch_size,
          validation_split=0.1,
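The excerpt cuts off inside model.fit. A plausible completion, assuming the standard Keras EarlyStopping callback together with the epochs and es_patience values defined above:

# Plausible completion (assumed; not shown in the excerpt above).
from tensorflow.keras.callbacks import EarlyStopping

model.fit([x_train, A_train],
          y_train,
          batch_size=batch_size,
          validation_split=0.1,
          epochs=epochs,
          callbacks=[EarlyStopping(patience=es_patience,
                                   restore_best_weights=True)])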
From danielegrattarola/spektral: examples/node_classification_gat.py

n_classes = y.shape[1]  # Number of classes
dropout_rate = 0.25     # Dropout rate applied to the input of GAT layers
l2_reg = 5e-4           # Regularization rate for l2
learning_rate = 1e-2    # Learning rate for Adam
epochs = 20000          # Number of training epochs
es_patience = 200       # Patience for early stopping

# Preprocessing operations
A = add_eye(A).toarray()  # Add self-loops

# Model definition
X_in = Input(shape=(F, ))
A_in = Input(shape=(N, ))

dropout_1 = Dropout(dropout_rate)(X_in)
graph_attention_1 = GraphAttention(gat_channels,
                                   attn_heads=n_attn_heads,
                                   attn_heads_reduction='concat',
                                   dropout_rate=dropout_rate,
                                   activation='elu',
                                   kernel_regularizer=l2(l2_reg),
                                   attn_kernel_regularizer=l2(l2_reg))([dropout_1, A_in])
dropout_2 = Dropout(dropout_rate)(graph_attention_1)
graph_attention_2 = GraphAttention(n_classes,
                                   attn_heads=1,
                                   attn_heads_reduction='average',
                                   dropout_rate=dropout_rate,
                                   activation='softmax',
                                   kernel_regularizer=l2(l2_reg),
                                   attn_kernel_regularizer=l2(l2_reg))([dropout_2, A_in])

# Build model
model = Model(inputs=[X_in, A_in], outputs=graph_attention_2)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer,
              loss='categorical_crossentropy',
              weighted_metrics=['acc'])
model.summary()

# Callbacks
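The excerpt stops at the callbacks. For single-mode node classification, training is typically full-batch with per-node sample weights, which is why the model compiles with weighted_metrics. A sketch, where train_mask and val_mask are hypothetical arrays selecting the labelled nodes:

# Plausible continuation (assumed): full-batch training with node masks.
from tensorflow.keras.callbacks import EarlyStopping

model.fit([X, A], y,
          sample_weight=train_mask,
          batch_size=N, shuffle=False,   # the whole graph is one batch
          epochs=epochs,
          validation_data=([X, A], y, val_mask),
          callbacks=[EarlyStopping(patience=es_patience,
                                   restore_best_weights=True)])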
From danielegrattarola/spektral: docs/autogen.py

# 2) Document all its methods: [classA, (classB, "*")]
# 3) Choose which methods to document (methods listed as strings):
# [classA, (classB, ["method1", "method2", ...]), ...]
# 4) Choose which methods to document (methods listed as qualified names):
# [classA, (classB, [module.classB.method1, module.classB.method2, ...]), ...]

PAGES = [
    {
        'page': 'layers/convolution.md',
        'classes': [
            layers.GraphConv,
            layers.ChebConv,
            layers.GraphSageConv,
            layers.ARMAConv,
            layers.EdgeConditionedConv,
            layers.GraphAttention,
            layers.GraphConvSkip,
            layers.APPNP,
            layers.GINConv
        ]
    },
    {
        'page': 'layers/base.md',
        'functions': [],
        'methods': [],
        'classes': [
            layers.InnerProduct,
            layers.MinkowskiProduct
        ]
    },
    {
        'page': 'layers/pooling.md',
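The PAGES list above drives Spektral's documentation generator; the rendering code is not part of the excerpt. A hypothetical sketch of walking it:

# Hypothetical sketch (autogen.py's rendering code is not shown): each PAGES
# entry maps a Markdown page to the classes documented on it.
for page in PAGES:
    names = [cls.__name__ for cls in page.get('classes', [])]
    print(page['page'], '->', ', '.join(names))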