How to use the spektral.layers.GlobalAvgPool layer in Spektral

To help you get started, we’ve selected a few Spektral examples based on popular ways this layer is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

GitHub: danielegrattarola/spektral — examples/classification_BDGC_disjoint.py (View on GitHub)
# Block 2: skip-connection graph convolution on the Block 1 outputs
# (X_1, A_1 and I_1 are produced before this excerpt), then MinCut pooling.
gc2 = GraphConvSkip(n_channels,
                    activation=activ,
                    kernel_regularizer=l2(GNN_l2))([X_1, A_1])
# MinCutPool is called on [gc2, A_1, I_1] and returns four tensors:
# X_2 and A_2 are fed to Block 3 in the same [features, adjacency] positions
# as X_1 and A_1, I_2 is passed to the readout below, and M_2 is not used in
# this excerpt.
# NOTE(review): k = int(average_N // 4) presumably targets a quarter of the
# average number of nodes per graph — confirm where average_N is computed.
X_2, A_2, I_2, M_2 = MinCutPool(k=int(average_N // 4),
                                h=mincut_H,
                                activation=activ,
                                kernel_regularizer=l2(pool_l2))([gc2, A_1, I_1])

# Block 3: graph convolution on the pooled graph (X_2, A_2).
X_3 = GraphConvSkip(n_channels,
                    activation=activ,
                    kernel_regularizer=l2(GNN_l2))([X_2, A_2])

# Output block: global average pooling followed by a softmax classifier over
# n_out classes. I_2 is passed together with X_3 — presumably so that
# GlobalAvgPool averages node features per graph in disjoint mode (several
# graphs stacked in one batch); confirm against the layer docs.
avgpool = GlobalAvgPool()([X_3, I_2])
output = Dense(n_out, activation='softmax')(avgpool)

# Build model. X_in, A_in, I_in and target are defined before this excerpt;
# target_tensors=[target] presumably makes compile() take the labels from the
# `target` tensor rather than from arrays passed to fit().
model = Model([X_in, A_in, I_in], output)
model.compile(optimizer='adam', loss='categorical_crossentropy', target_tensors=[target])
model.summary()

# Training setup (TF1 graph mode): reuse the Keras session and the compiled
# model's total loss, but minimize it with a separate tf.train.AdamOptimizer.
# NOTE(review): the 'adam' optimizer given to compile() above is not what
# train_step uses.
sess = K.get_session()
loss = model.total_loss
# Mean categorical accuracy of the model output against the `target` labels.
acc = K.mean(categorical_accuracy(target, model.output))
opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_step = opt.minimize(loss)

# Initialize all variables. This is built after opt.minimize() — NOTE(review):
# keep that order so the optimizer's own variables exist when the initializer
# is created.
init_op = tf.global_variables_initializer()
GitHub: danielegrattarola/spektral — tests/test_layers/test_global_pooling.py (View on GitHub)
def test_global_avg_pool():
    """Exercise GlobalAvgPool through the single, batch and graph mode helpers."""
    for run_mode_check in (_test_single_mode, _test_batch_mode, _test_graph_mode):
        run_mode_check(GlobalAvgPool)
GitHub: danielegrattarola/spektral — tests/test_layers/test_global_pooling.py (View on GitHub)
def test_global_avg_pool():
    """Check GlobalAvgPool in single, batch and graph mode, in that order."""
    pool_cls = GlobalAvgPool
    _test_single_mode(pool_cls)
    _test_batch_mode(pool_cls)
    _test_graph_mode(pool_cls)
GitHub: danielegrattarola/spektral — examples/regression_molecules.py (View on GitHub)
es_patience = 5           # Patience for early stopping

# Train/test split: hold out 10% of the samples for testing. A, X, E and y are
# passed to one call so each *_train / *_test pair is split the same way
# (assuming scikit-learn's train_test_split; its import is outside this excerpt).
A_train, A_test, \
X_train, X_test, \
E_train, E_test, \
y_train, y_test = train_test_split(A, X, E, y, test_size=0.1)

# Model definition (batch mode): fixed-size per-sample inputs with shapes
# X (N, F), A (N, N) and E (N, N, S); N, F and S are defined before this excerpt.
X_in = Input(shape=(N, F))
A_in = Input(shape=(N, N))
E_in = Input(shape=(N, N, S))

# Two edge-conditioned convolutions (32 channels, ReLU) over X, A and E.
# GlobalAvgPool is called on gc2 alone (no graph-index input), presumably
# averaging over the node axis of each sample. The Dense head has no
# activation, i.e. linear regression outputs to match the 'mse' loss below.
gc1 = EdgeConditionedConv(32, activation='relu')([X_in, A_in, E_in])
gc2 = EdgeConditionedConv(32, activation='relu')([gc1, A_in, E_in])
pool = GlobalAvgPool()(gc2)
output = Dense(n_out)(pool)

# Build model
# NOTE(review): `lr` is the legacy Keras name for the learning rate; newer
# Keras versions expect `learning_rate`.
model = Model(inputs=[X_in, A_in, E_in], outputs=output)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer, loss='mse')
model.summary()

# Train model
model.fit([X_train, A_train, E_train],
          y_train,
          batch_size=batch_size,
          validation_split=0.1,
          epochs=epochs,
          callbacks=[
              EarlyStopping(patience=es_patience,  restore_best_weights=True)
GitHub: danielegrattarola/spektral — examples/regression_molecules_disjoint.py (View on GitHub)
# Train/test split: hold out 10% of the samples for testing. A, X, E and y are
# passed to one call so each *_train / *_test pair is split the same way
# (assuming scikit-learn's train_test_split; its import is outside this excerpt).
A_train, A_test, \
X_train, X_test, \
E_train, E_test, \
y_train, y_test = train_test_split(A, X, E, y, test_size=0.1)

# Model definition (disjoint mode): batch_shape starts with None, so the
# number of nodes is not fixed. X is (nodes, F), A is (nodes, nodes) and
# E is (nodes, nodes, S).
# I_in is an int64 vector — presumably the index of the graph each node
# belongs to; confirm against the data loader.
# `target` wraps a float32 placeholder of shape (None, n_out) and is used as
# the regression labels through target_tensors in compile().
X_in = Input(batch_shape=(None, F))
A_in = Input(batch_shape=(None, None))
E_in = Input(batch_shape=(None, None, S))
I_in = Input(batch_shape=(None, ), dtype='int64')
target = Input(tensor=tf.placeholder(tf.float32, shape=(None, n_out), name='target'))

# Two edge-conditioned convolutions (32 channels, ReLU). The graph index I_in
# is passed to GlobalAvgPool, presumably so node features are averaged per
# graph. The linear Dense head matches the 'mse' loss.
gc1 = EdgeConditionedConv(32, activation='relu')([X_in, A_in, E_in])
gc2 = EdgeConditionedConv(32, activation='relu')([gc1, A_in, E_in])
pool = GlobalAvgPool()([gc2, I_in])
output = Dense(n_out)(pool)

# Build model
# NOTE(review): `lr` is the legacy Keras name for the learning rate; newer
# Keras versions expect `learning_rate`.
model = Model(inputs=[X_in, A_in, E_in, I_in], outputs=output)
optimizer = Adam(lr=learning_rate)
model.compile(optimizer=optimizer, loss='mse', target_tensors=target)
model.summary()

# Training setup (TF1 graph mode): reuse the Keras session and the compiled
# model's total loss, but minimize it with a separate tf.train.AdamOptimizer.
# NOTE(review): the Keras Adam `optimizer` passed to compile() is not what
# train_step uses.
sess = K.get_session()
loss = model.total_loss
opt = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_step = opt.minimize(loss)
# Built after opt.minimize() — NOTE(review): keep that order so the
# optimizer's own variables exist when the initializer is created.
init_op = tf.global_variables_initializer()
sess.run(init_op)
GitHub: danielegrattarola/spektral — docs/autogen.py (View on GitHub)
'classes': [
            layers.InnerProduct,
            layers.MinkowskiProduct
        ]
    },
    # Entry for the pooling-layers page (layers/pooling.md): no functions or
    # methods, only the pooling layer classes below (including GlobalAvgPool).
    {
        'page': 'layers/pooling.md',
        'functions': [],
        'methods': [],
        'classes': [
            layers.TopKPool,
            layers.MinCutPool,
            layers.DiffPool,
            layers.SAGPool,
            layers.GlobalSumPool,
            layers.GlobalAvgPool,
            layers.GlobalMaxPool,
            layers.GlobalAttentionPool,
            layers.GlobalAttnSumPool
        ]
    },
    # Entry for the citation-datasets page (datasets/citation.md): lists only
    # the datasets.citation.load_data function.
    {
        'page': 'datasets/citation.md',
        'functions': [
            datasets.citation.load_data
        ],
        'methods': [],
        'classes': []
    },
    {
        'page': 'datasets/graphsage.md',
        'functions': [