Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_integration(self):
    """Smoke-test BestModelRestore inside a complete fit_generator run.

    Trains for 10 epochs of 5 steps with the callback monitoring 'val_loss'.
    Nothing is asserted: the test only checks that training with the callback
    attached runs to completion.
    """
    training_data = some_data_generator(20)
    validation_data = some_data_generator(20)
    restore_callback = BestModelRestore(monitor='val_loss', verbose=True)
    self.model.fit_generator(
        training_data,
        validation_data,
        epochs=10,
        steps_per_epoch=5,
        callbacks=[restore_callback],
    )
def test_save_best_only_with_max(self):
    """Check that BestModelRestore with mode='max' restores the highest-'val_loss' epoch."""
    restore_callback = BestModelRestore(monitor='val_loss', mode='max')
    # The largest value (8) is the third entry, so the expected epoch is 3
    # (the expected epoch is given counting from 1).
    losses_per_epoch = [3, 2, 8, 5, 4]
    self._test_restore_with_val_losses(restore_callback, losses_per_epoch, 3)
def test_basic_restore(self):
    """Check BestModelRestore with its default mode on 'val_loss'."""
    restore_callback = BestModelRestore(monitor='val_loss')
    # With the default mode, the expected epoch (2) is the one holding the
    # smallest value (2, the second entry).
    losses_per_epoch = [3, 2, 8, 5, 4]
    self._test_restore_with_val_losses(restore_callback, losses_per_epoch, 2)
def _init_lr_scheduler_callbacks(self, lr_schedulers):
    """Build the callback list that drives the given learning-rate schedulers.

    Args:
        lr_schedulers: iterable of learning-rate scheduler callbacks.

    Returns:
        A new list of callbacks. When ``self.logging`` is true, each scheduler
        ``i`` is wrapped in an ``LRSchedulerCheckpoint`` saving to
        ``self.lr_scheduler_filename % i`` (via the temporary file
        ``self.lr_scheduler_tmp_filename % i``). Otherwise the schedulers are
        returned unchanged, followed by a ``BestModelRestore`` on
        ``self.monitor_metric`` / ``self.monitor_mode``.
    """
    if not self.logging:
        # Nothing is checkpointed to disk in this mode, so the best weights are
        # restored in memory by BestModelRestore.
        passthrough = list(lr_schedulers)
        passthrough.append(BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True))
        return passthrough

    return [
        LRSchedulerCheckpoint(
            scheduler,
            self.lr_scheduler_filename % index,
            verbose=False,
            temporary_filename=self.lr_scheduler_tmp_filename % index,
        )
        for index, scheduler in enumerate(lr_schedulers)
    ]
def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
callbacks = []
best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
monitor=self.monitor_metric,
mode=self.monitor_mode,
save_best_only=not save_every_epoch,
restore_best=not save_every_epoch,
verbose=not save_every_epoch,
temporary_filename=self.best_checkpoint_tmp_filename)
callbacks.append(best_checkpoint)
if save_every_epoch:
best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
callbacks.append(best_restore)
if initial_epoch > 1:
# We set the current best metric score in the ModelCheckpoint so that
# it does not save checkpoint it would not have saved if the
# optimization was not stopped.
best_epoch_stats = self.get_best_epoch_stats()
best_epoch = best_epoch_stats['epoch'].item()
best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
if not save_every_epoch:
best_checkpoint.best_filename = best_filename
best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()
else:
best_restore.best_weights = torch.load(best_filename, map_location='cpu')
best_restore.current_best = best_epoch_stats[self.monitor_metric].item()