Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_any_scheduler_checkpoints(self):
    """Exercise ``LRSchedulerCheckpoint`` around an ``ExponentialLR`` scheduler.

    The checkpoint is built with ``period=1`` and ``self.checkpoint_filename``;
    the actual verification is delegated to ``self._test_checkpointer``.
    """
    scheduler = ExponentialLR(gamma=0.01)
    scheduler_checkpoint = LRSchedulerCheckpoint(scheduler, self.checkpoint_filename, period=1)
    self._test_checkpointer(scheduler_checkpoint, scheduler)
def test_reduce_lr_checkpoints(self):
    """Exercise ``LRSchedulerCheckpoint`` around a ``ReduceLROnPlateau`` scheduler.

    The scheduler monitors ``'loss'`` with ``patience=3``; the checkpoint uses
    ``period=1`` and ``self.checkpoint_filename``. Verification is delegated to
    ``self._test_checkpointer``.
    """
    plateau_scheduler = ReduceLROnPlateau(monitor='loss', patience=3)
    scheduler_checkpoint = LRSchedulerCheckpoint(plateau_scheduler, self.checkpoint_filename, period=1)
    self._test_checkpointer(scheduler_checkpoint, plateau_scheduler)
def test_reduce_lr_on_plateau_integration(self):
    """Train briefly with a checkpointed ``ReduceLROnPlateau`` scheduler.

    Runs ``self.model.fit_generator`` with the checkpoint as the only callback.
    No assertion is made on the saved files: the test passes as long as
    training completes without raising.
    """
    batch_size = OptimizerCheckpointTest.batch_size
    train_generator = some_data_generator(batch_size)
    valid_generator = some_data_generator(batch_size)

    plateau_scheduler = ReduceLROnPlateau(monitor='loss', patience=3)
    training_callbacks = [LRSchedulerCheckpoint(plateau_scheduler, self.checkpoint_filename, period=1)]

    self.model.fit_generator(
        train_generator,
        valid_generator,
        epochs=OptimizerCheckpointTest.epochs,
        steps_per_epoch=5,
        callbacks=training_callbacks,
    )
def test_any_scheduler_integration(self):
    """Train briefly with a checkpointed ``ExponentialLR`` scheduler.

    Runs ``self.model.fit_generator`` with the checkpoint as the only callback.
    No assertion is made on the saved files: the test passes as long as
    training completes without raising.
    """
    batch_size = OptimizerCheckpointTest.batch_size
    train_generator = some_data_generator(batch_size)
    valid_generator = some_data_generator(batch_size)

    scheduler = ExponentialLR(gamma=0.01)
    training_callbacks = [LRSchedulerCheckpoint(scheduler, self.checkpoint_filename, period=1)]

    self.model.fit_generator(
        train_generator,
        valid_generator,
        epochs=OptimizerCheckpointTest.epochs,
        steps_per_epoch=5,
        callbacks=training_callbacks,
    )
def _init_lr_scheduler_callbacks(self, lr_schedulers):
    """Build the learning-rate-scheduler callbacks for a training run.

    When ``self.logging`` is truthy, each scheduler is wrapped in an
    ``LRSchedulerCheckpoint``. Its file names are produced by %-formatting
    ``self.lr_scheduler_filename`` and ``self.lr_scheduler_tmp_filename``
    with the scheduler's position in ``lr_schedulers``.

    Otherwise the schedulers are returned as-is, followed by a
    ``BestModelRestore`` callback configured from ``self.monitor_metric`` and
    ``self.monitor_mode``.

    Args:
        lr_schedulers: Iterable of scheduler callbacks.

    Returns:
        A newly built list of callbacks.
    """
    if not self.logging:
        # NOTE(review): BestModelRestore is added only in the non-logging
        # branch — confirm this placement is intended.
        return list(lr_schedulers) + [
            BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
        ]

    return [
        LRSchedulerCheckpoint(
            scheduler,
            self.lr_scheduler_filename % index,
            verbose=False,
            temporary_filename=self.lr_scheduler_tmp_filename % index,
        )
        for index, scheduler in enumerate(lr_schedulers)
    ]