Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
return ds, skorch.dataset.Dataset(corpus.valid[:200], y=None)
### Perform training
net = skorch.NeuralNetClassifier(
module=wavetorch.WaveCell,
# Training configuration
max_epochs=cfg['training']['N_epochs'],
batch_size=cfg['training']['batch_size'],
lr=cfg['training']['lr'],
# train_split=skorch.dataset.CVSplit(cfg['training']['N_folds'], stratified=True, random_state=cfg['seed']),
optimizer=torch.optim.Adam,
criterion=torch.nn.CrossEntropyLoss,
callbacks=[
ClipDesignRegion,
skorch.callbacks.EpochScoring('accuracy', lower_is_better=False, on_train=True, name='train_acc'),
skorch.callbacks.Checkpoint(monitor=None, fn_prefix='1234_', dirname='test', f_params="params_{last_epoch[epoch]}.pt", f_optimizer='optimizer.pt', f_history='history.json')
],
callbacks__print_log__keys_ignored=None,
train_split=None,
# These all get passed as options to WaveCell
module__Nx=cfg['geom']['Nx'],
module__Ny=cfg['geom']['Ny'],
module__h=cfg['geom']['h'],
module__dt=cfg['geom']['dt'],
module__init=cfg['geom']['init'],
module__c0=cfg['geom']['c0'],
module__c1=cfg['geom']['c1'],
module__sigma=cfg['geom']['pml']['max'],
module__N=cfg['geom']['pml']['N'],
module__p=cfg['geom']['pml']['p'],
y_test,
batch_size,
device,
lr,
max_epochs,
):
torch.manual_seed(0)
net = NeuralNetClassifier(
ClassifierModule,
batch_size=batch_size,
optimizer=torch.optim.Adadelta,
lr=lr,
device=device,
max_epochs=max_epochs,
callbacks=[
('tr_acc', EpochScoring(
'accuracy',
lower_is_better=False,
on_train=True,
name='train_acc',
)),
],
)
net.fit(X_train, y_train)
y_pred = net.predict(X_test)
score = accuracy_score(y_test, y_pred)
return score
def _default_callbacks(self):
    """Build the default list of ``(name, callback)`` pairs.

    The pairs are returned in a fixed order: the epoch timer, the two
    batch-level loss scorers (``train_loss`` and ``valid_loss``), the
    epoch-level ``'accuracy'`` scorer recorded as ``valid_acc``, and the
    log printer last.
    """
    # Callbacks are instantiated in the same order in which they are listed.
    timer = EpochTimer()

    # Both loss scorers use ``noop`` as their target extractor; only the
    # train scorer sets ``on_train=True``.
    train_loss_cb = BatchScoring(
        train_loss_score,
        name='train_loss',
        on_train=True,
        target_extractor=noop,
    )
    valid_loss_cb = BatchScoring(
        valid_loss_score,
        name='valid_loss',
        target_extractor=noop,
    )

    # Accuracy is a higher-is-better metric, hence lower_is_better=False.
    valid_acc_cb = EpochScoring(
        'accuracy',
        name='valid_acc',
        lower_is_better=False,
    )

    log_printer = PrintLog()

    names = ('epoch_timer', 'train_loss', 'valid_loss', 'valid_acc', 'print_log')
    instances = (timer, train_loss_cb, valid_loss_cb, valid_acc_cb, log_printer)
    return list(zip(names, instances))