    def test_temporary_filename_arg_with_differing_checkpoint_filename(self):
        epochs = 10
        tmp_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.tmp.ckpt')
        checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint_{epoch}.ckpt')
        train_gen = some_data_generator(ModelCheckpointTest.batch_size)
        valid_gen = some_data_generator(ModelCheckpointTest.batch_size)
        checkpointer = ModelCheckpoint(checkpoint_filename,
                                       monitor='val_loss',
                                       verbose=True,
                                       period=1,
                                       temporary_filename=tmp_filename)
        self.model.fit_generator(train_gen, valid_gen, epochs=epochs, steps_per_epoch=5, callbacks=[checkpointer])

        # The temporary file must have been removed and one checkpoint must exist for each epoch.
        self.assertFalse(os.path.isfile(tmp_filename))
        for i in range(1, epochs + 1):
            self.assertTrue(os.path.isfile(checkpoint_filename.format(epoch=i)))

    def test_periodic_with_period_of_2(self):
        checkpointer = ModelCheckpoint(self.checkpoint_filename,
                                       monitor='val_loss',
                                       verbose=True,
                                       period=2,
                                       save_best_only=False)
        val_losses = [1] * 10
        has_checkpoints = [False, True] * 5
        self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)

    def test_restore_best_without_save_best_only(self):
        with self.assertRaises(ValueError):
            ModelCheckpoint(self.checkpoint_filename,
                            monitor='val_loss',
                            verbose=True,
                            save_best_only=False,
                            restore_best=True)

        with self.assertRaises(ValueError):
            ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True, restore_best=True)

    def test_integration(self):
        train_gen = some_data_generator(ModelCheckpointTest.batch_size)
        valid_gen = some_data_generator(ModelCheckpointTest.batch_size)
        checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True, save_best_only=True)
        self.model.fit_generator(train_gen, valid_gen, epochs=10, steps_per_epoch=5, callbacks=[checkpointer])

    def test_periodic_with_period_of_1(self):
        checkpointer = ModelCheckpoint(self.checkpoint_filename,
                                       monitor='val_loss',
                                       verbose=True,
                                       period=1,
                                       save_best_only=False)
        val_losses = [1] * 10
        has_checkpoints = [True] * 10
        self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)

    def test_save_best_only(self):
        checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True, save_best_only=True)
        val_losses = [10, 3, 8, 5, 2]
        has_checkpoints = [True, True, False, False, True]
        self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)

    def test_non_atomic_write(self):
        checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.ckpt')
        train_gen = some_data_generator(ModelCheckpointTest.batch_size)
        valid_gen = some_data_generator(ModelCheckpointTest.batch_size)
        checkpointer = ModelCheckpoint(checkpoint_filename,
                                       monitor='val_loss',
                                       verbose=True,
                                       period=1,
                                       atomic_write=False)
        self.model.fit_generator(train_gen, valid_gen, epochs=10, steps_per_epoch=5, callbacks=[checkpointer])

        # With atomic_write=False, the checkpoint is written directly to its final filename.
        self.assertTrue(os.path.isfile(checkpoint_filename))

    def test_temporary_filename_arg(self):
        tmp_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.tmp.ckpt')
        checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.ckpt')
        train_gen = some_data_generator(ModelCheckpointTest.batch_size)
        valid_gen = some_data_generator(ModelCheckpointTest.batch_size)
        checkpointer = ModelCheckpoint(checkpoint_filename,
                                       monitor='val_loss',
                                       verbose=True,
                                       period=1,
                                       temporary_filename=tmp_filename)
        self.model.fit_generator(train_gen, valid_gen, epochs=10, steps_per_epoch=5, callbacks=[checkpointer])

        # The temporary file must have been removed and only the final checkpoint must remain.
        self.assertFalse(os.path.isfile(tmp_filename))
        self.assertTrue(os.path.isfile(checkpoint_filename))

    def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
        callbacks = []
        best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                          monitor=self.monitor_metric,
                                          mode=self.monitor_mode,
                                          save_best_only=not save_every_epoch,
                                          restore_best=not save_every_epoch,
                                          verbose=not save_every_epoch,
                                          temporary_filename=self.best_checkpoint_tmp_filename)
        callbacks.append(best_checkpoint)

        if save_every_epoch:
            best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
            callbacks.append(best_restore)

        if initial_epoch > 1:
            # We set the current best metric score in the ModelCheckpoint so that
            # it does not save a checkpoint it would not have saved if the
            # optimization had not been stopped.