How to use the spektral.layers.GlobalAttentionPool layer in spektral

To help you get started, we’ve selected a few spektral examples, based on popular ways it is used in public projects.


danielegrattarola/spektral: tests/test_layers/test_global_pooling.py
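from spektral.layers import GlobalAttentionPool

# F (presumably the input feature size) and the _test_*_mode helpers are
# defined elsewhere in test_global_pooling.py and are not shown here.
def test_global_attention_pool():
    F_ = 10          # number of output channels for the pooling layer
    assert F_ != F   # output width must differ from the input width
    _test_single_mode(GlobalAttentionPool, channels=F_)
    _test_batch_mode(GlobalAttentionPool, channels=F_)
    _test_graph_mode(GlobalAttentionPool, channels=F_)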
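The test runs GlobalAttentionPool in the three data modes the suite covers, with channels=F_ chosen to differ from the input feature size, presumably so the checks can tell the output width apart from the input width. The modes differ mainly in how nodes are grouped into graphs: single mode holds one graph as an (N, F) matrix, batch mode a dense (B, N, F) tensor, and graph mode (called disjoint mode in later spektral documentation) stacks every node into one (N_total, F) matrix plus a vector saying which graph each node belongs to. The pure-Python sketch below illustrates graph mode using the gated attention formula from spektral's documentation; `disjoint_mode_pool`, `affine` and the weights are hypothetical, and this is not spektral's own code.

```python
import math


def sigmoid(v):
    """Numerically safe logistic function."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def affine(x, w, b):
    """One node's features x (F,) times w (F, C) plus b (C,)."""
    return [sum(x[f] * w[f][c] for f in range(len(x))) + b[c] for c in range(len(b))]


def disjoint_mode_pool(x, graph_index, w_attn, b_attn, w_feat, b_feat):
    """Pool node features x (N_total, F) into one (C,) vector per graph.

    graph_index[i] is the id (0, 1, ...) of the graph that node i belongs to,
    so the result has shape (n_graphs, C).
    """
    n_graphs = max(graph_index) + 1
    channels = len(b_feat)
    out = [[0.0] * channels for _ in range(n_graphs)]
    for node, g in zip(x, graph_index):
        gate = [sigmoid(v) for v in affine(node, w_attn, b_attn)]  # sigmoid(X W1 + b1)
        feat = affine(node, w_feat, b_feat)                          # X W2 + b2
        for c in range(channels):
            out[g][c] += gate[c] * feat[c]  # gated features, summed per graph
    return out


# Hypothetical example: F = 3 input features, C = 2 output channels.
w_attn = [[0, 0], [0, 0], [0, 0]]  # zero attention weights and bias: every gate is sigmoid(0) = 0.5
b_attn = [0, 0]
w_feat = [[1, 0], [0, 1], [1, 1]]
b_feat = [0, 0]
x = [[1, 0, 0], [0, 1, 0], [1, 1, 1]]

print(disjoint_mode_pool(x, [0, 0, 1], w_attn, b_attn, w_feat, b_feat))
# -> [[0.5, 0.5], [1.0, 1.0]]   (two graphs, graph mode)
print(disjoint_mode_pool(x, [0, 0, 0], w_attn, b_attn, w_feat, b_feat))
# -> [[1.5, 1.5]]               (one graph, shaped like single-mode output)
```

Putting every node in graph 0 reproduces the single-mode result, and the per-graph outputs of the two-graph call add up to it, which is the property that lets one summation rule serve all three modes.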
danielegrattarola/spektral: examples/classification_delaunay.py
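# Imports assume TensorFlow 2 (tf.keras) and a pre-1.0 spektral release,
# where the attention convolution was still called GraphAttention.
from sklearn.model_selection import train_test_split
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from spektral.layers import GraphAttention, GlobalAttentionPool

# A, X, y, N, F, n_classes, l2_reg and learning_rate are defined earlier
# in the original script (not shown in this excerpt).
epochs = 20000           # Number of training epochs
batch_size = 32          # Batch size
es_patience = 200        # Patience for early stopping

# Train/test split
A_train, A_test, \
x_train, x_test, \
y_train, y_test = train_test_split(A, X, y, test_size=0.1)

# Model definition (batch mode: one dense N x N adjacency matrix per graph)
X_in = Input(shape=(N, F))
A_in = Input((N, N))

gc1 = GraphAttention(32, activation='relu', kernel_regularizer=l2(l2_reg))([X_in, A_in])
gc2 = GraphAttention(32, activation='relu', kernel_regularizer=l2(l2_reg))([gc1, A_in])
pool = GlobalAttentionPool(128)(gc2)  # (batch, N, 32) -> (batch, 128)

output = Dense(n_classes, activation='softmax')(pool)

# Build model
model = Model(inputs=[X_in, A_in], outputs=output)
optimizer = Adam(learning_rate=learning_rate)
model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['acc'])
model.summary()

# Train model
# The callback list is truncated in this excerpt; early stopping with
# es_patience is assumed from the variable defined above.
model.fit([x_train, A_train],
          y_train,
          batch_size=batch_size,
          validation_split=0.1,
          epochs=epochs,
          callbacks=[
              EarlyStopping(patience=es_patience)
          ])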
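In this model, GlobalAttentionPool(128) is the step that turns the per-node output of the second GraphAttention layer, shape (batch, N, 32), into one 128-dimensional vector per graph, so that an ordinary Dense softmax layer can classify whole graphs. Spektral's documentation describes the layer as gated attention pooling: X' = sum over nodes i of (sigmoid(X W1 + b1) * (X W2 + b2))_i, where * is the element-wise product. The pure-Python sketch below applies that formula in batch mode, one graph per sample; `batch_mode_pool`, `gated_attention_pool` and the tiny weights are hypothetical stand-ins for the layer's learned kernels, not spektral's implementation.

```python
import math


def sigmoid(v):
    """Numerically safe logistic function."""
    if v >= 0:
        return 1.0 / (1.0 + math.exp(-v))
    z = math.exp(v)
    return z / (1.0 + z)


def affine(x, w, b):
    """One node's features x (F,) times w (F, C) plus b (C,)."""
    return [sum(x[f] * w[f][c] for f in range(len(x))) + b[c] for c in range(len(b))]


def gated_attention_pool(graph, w_attn, b_attn, w_feat, b_feat):
    """Pool one graph's node features (N, F) into a single (C,) vector."""
    channels = len(b_feat)
    out = [0.0] * channels
    for node in graph:
        gate = [sigmoid(v) for v in affine(node, w_attn, b_attn)]  # sigmoid(X W1 + b1)
        feat = affine(node, w_feat, b_feat)                          # X W2 + b2
        for c in range(channels):
            out[c] += gate[c] * feat[c]  # gate each channel, then sum over nodes
    return out


def batch_mode_pool(batch, w_attn, b_attn, w_feat, b_feat):
    """Batch mode: input (B, N, F) -> output (B, C), one vector per graph."""
    return [gated_attention_pool(g, w_attn, b_attn, w_feat, b_feat) for g in batch]


# Hypothetical example: B = 2 graphs, N = 2 nodes, F = 2 features, C = 3 channels.
w_attn = [[0, 0, 0], [0, 0, 0]]  # zero attention weights and bias: every gate is 0.5
b_attn = [0, 0, 0]
w_feat = [[1, 0, 1], [0, 1, 1]]
b_feat = [0, 0, 0]
batch = [[[1, 2], [3, 4]], [[0, 1], [1, 0]]]

print(batch_mode_pool(batch, w_attn, b_attn, w_feat, b_feat))
# -> [[2.0, 3.0, 5.0], [0.5, 0.5, 1.0]]
```

Because the node dimension is summed away, the pooled width depends only on the number of channels (3 here, 128 in the model above), not on N, which is why a fixed-size Dense layer can follow it.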
danielegrattarola/spektral: docs/autogen.py
            layers.MinkowskiProduct
        ]
    },
    {
        'page': 'layers/pooling.md',
        'functions': [],
        'methods': [],
        'classes': [
            layers.TopKPool,
            layers.MinCutPool,
            layers.DiffPool,
            layers.SAGPool,
            layers.GlobalSumPool,
            layers.GlobalAvgPool,
            layers.GlobalMaxPool,
            layers.GlobalAttentionPool,
            layers.GlobalAttnSumPool
        ]
    },
    {
        'page': 'datasets/citation.md',
        'functions': [
            datasets.citation.load_data
        ],
        'methods': [],
        'classes': []
    },
    {
        'page': 'datasets/graphsage.md',
        'functions': [
            datasets.graphsage.load_data
        ],