Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _test_checkpointer(self, checkpointer, lr_scheduler):
    """Drive ``checkpointer`` through a short simulated training run and verify it.

    For every epoch, one batch is pushed through ``self._update_model`` and
    the checkpointer's batch/epoch callbacks are fired. The helper then
    asserts that the epoch's checkpoint file exists on disk and records the
    scheduler state (converted with ``torch_to_numpy``). After training ends,
    ``self._test_checkpoint`` reloads every file and compares it with the
    recorded states.

    Args:
        checkpointer: Callback under test.
        lr_scheduler: Scheduler callback whose ``scheduler.state_dict()`` is
            snapshotted each epoch.
    """
    n_epochs = OptimizerCheckpointTest.epochs
    batch_size = OptimizerCheckpointTest.batch_size
    data = some_data_generator(batch_size)
    expected_states = {}

    checkpointer.set_params({'epochs': n_epochs, 'steps': 1})
    checkpointer.set_model(self.model)
    checkpointer.on_train_begin({})
    for current_epoch in range(1, n_epochs + 1):
        checkpointer.on_epoch_begin(current_epoch, {})
        checkpointer.on_batch_begin(1, {})
        batch_loss = self._update_model(data)
        checkpointer.on_batch_end(1, {'batch': 1, 'size': batch_size, 'loss': batch_loss})
        checkpointer.on_epoch_end(current_epoch, {'epoch': current_epoch, 'loss': batch_loss, 'val_loss': 1})

        checkpoint_path = self.checkpoint_filename.format(epoch=current_epoch)
        self.assertTrue(os.path.isfile(checkpoint_path))
        # copy=True is requested, presumably so the snapshot does not share
        # storage with the live scheduler state (torch_to_numpy not visible here).
        expected_states[current_epoch] = torch_to_numpy(lr_scheduler.scheduler.state_dict(), copy=True)
    checkpointer.on_train_end({})

    self._test_checkpoint(expected_states, lr_scheduler)
def _test_checkpoint(self, scheduler_states, lr_scheduler):
    """Check that each epoch's checkpoint restores the scheduler state recorded for it.

    Args:
        scheduler_states: Mapping from epoch number to the scheduler state
            (already converted with ``torch_to_numpy``) captured at that epoch.
        lr_scheduler: Scheduler callback. Its ``load_state`` is called with
            each epoch's checkpoint filename, and ``scheduler.state_dict()`` is
            compared afterwards.
    """
    for epoch_num in scheduler_states:
        expected = scheduler_states[epoch_num]
        lr_scheduler.load_state(self.checkpoint_filename.format(epoch=epoch_num))
        restored = torch_to_numpy(lr_scheduler.scheduler.state_dict())
        self.assertEqual(expected, restored)
# NOTE(review): orphaned fragment. These lines are a verbatim copy of the body
# of `_test_restore_with_val_losses` (defined just below), but they have no
# `def` line and no `generator = some_data_generator(...)` setup. As a
# result, `checkpointer`, `val_losses`, `generator`, `best_epoch` and `self`
# appear unbound in this view. This looks like a copy/paste or extraction
# artifact; confirm and remove.
best_epoch_weights = None
checkpointer.set_params({'epochs': len(val_losses), 'steps': 1})
checkpointer.set_model(self.model)
checkpointer.on_train_begin({})
# One simulated epoch (a single batch) per validation loss; epochs are 1-based.
for epoch, val_loss in enumerate(val_losses, 1):
checkpointer.on_epoch_begin(epoch, {})
checkpointer.on_batch_begin(1, {})
loss = self._update_model(generator)
checkpointer.on_batch_end(1, {'batch': 1, 'size': BestModelRestoreTest.batch_size, 'loss': loss})
checkpointer.on_epoch_end(epoch, {'epoch': epoch, 'loss': loss, 'val_loss': val_loss})
# Snapshot the weights at the epoch that is expected to be restored.
if epoch == best_epoch:
best_epoch_weights = torch_to_numpy(self.model.get_weight_copies())
checkpointer.on_train_end({})
# After training ends, the model should hold the snapshotted weights.
final_weights = torch_to_numpy(self.model.get_weight_copies())
self.assertEqual(best_epoch_weights, final_weights)
def _test_restore_with_val_losses(self, checkpointer, val_losses, best_epoch):
    """Run one epoch per entry of ``val_losses`` and check the right weights are restored.

    The model weights right after epoch ``best_epoch`` (1-based) are
    snapshotted. After ``checkpointer.on_train_end`` the model's weights must
    equal that snapshot.

    Args:
        checkpointer: Callback under test.
        val_losses: Validation loss reported at the end of each epoch, in order.
        best_epoch: 1-based epoch whose weights are expected after training.
    """
    batch_size = BestModelRestoreTest.batch_size
    data = some_data_generator(batch_size)
    expected_weights = None

    checkpointer.set_params({'epochs': len(val_losses), 'steps': 1})
    checkpointer.set_model(self.model)
    checkpointer.on_train_begin({})
    for epoch_idx, epoch_val_loss in enumerate(val_losses, start=1):
        checkpointer.on_epoch_begin(epoch_idx, {})
        checkpointer.on_batch_begin(1, {})
        train_loss = self._update_model(data)
        batch_logs = {'batch': 1, 'size': batch_size, 'loss': train_loss}
        checkpointer.on_batch_end(1, batch_logs)
        epoch_logs = {'epoch': epoch_idx, 'loss': train_loss, 'val_loss': epoch_val_loss}
        checkpointer.on_epoch_end(epoch_idx, epoch_logs)
        if epoch_idx == best_epoch:
            expected_weights = torch_to_numpy(self.model.get_weight_copies())
    checkpointer.on_train_end({})

    self.assertEqual(expected_weights, torch_to_numpy(self.model.get_weight_copies()))
def _compute_loss_and_metrics(self, x, y, return_loss_tensor=False, return_pred=False):
    """Forward one batch through the model and evaluate its loss and metrics.

    Args:
        x: Raw input batch. After ``_process_input`` it is passed positionally
            to the model: a list/tuple is unpacked, anything else is passed as
            the single argument.
        y: Target batch, passed (after ``_process_input``) to the loss function
            and to the metrics.
        return_loss_tensor: If True, the loss object produced by
            ``loss_function`` is returned as-is; otherwise it is converted
            with ``float``.
        return_pred: If True, the predictions are returned converted with
            ``torch_to_numpy``; otherwise ``None`` is returned in their place.

    Returns:
        Tuple ``(loss, metrics, pred_y)``.
    """
    x, y = self._process_input(x, y)
    model_args = x if isinstance(x, (list, tuple)) else (x, )
    predictions = self.model(*model_args)

    loss = self.loss_function(predictions, y)
    if not return_loss_tensor:
        loss = float(loss)

    # Batch metrics and epoch-metric accumulation run without gradient tracking.
    with torch.no_grad():
        metrics = self._compute_metrics(predictions, y)
        for metric in self.epoch_metrics:
            metric(predictions, y)

    return loss, metrics, (torch_to_numpy(predictions) if return_pred else None)
def _validate(self, step_iterator, return_pred=False, return_ground_truth=False):
    """Evaluate the model over ``step_iterator`` with training mode disabled.

    Each item yielded by ``step_iterator`` is ``(step, (x, y))``. The batch's
    loss, metrics and size are written onto ``step``.

    Args:
        step_iterator: Iterable of ``(step, (x, y))`` pairs. Its ``loss`` and
            ``metrics`` attributes are read and returned after the loop.
        return_pred: If True, collect each batch's predictions.
        return_ground_truth: If True, collect each batch's targets converted
            with ``torch_to_numpy``.

    Returns:
        Tuple ``(loss, metrics, pred_list, true_list)``. Each list is ``None``
        unless the corresponding flag was set.
    """
    pred_list = [] if return_pred else None
    true_list = [] if return_ground_truth else None

    with self._set_training_mode(False):
        for step, (x, y) in step_iterator:
            step.loss, step.metrics, batch_pred = self._compute_loss_and_metrics(x, y, return_pred=return_pred)
            if pred_list is not None:
                pred_list.append(batch_pred)
            if true_list is not None:
                true_list.append(torch_to_numpy(y))
            step.size = self._get_batch_size(x, y)

    return step_iterator.loss, step_iterator.metrics, pred_list, true_list
samples. See the :func:`fit_generator()` method for details on the types of generators
supported. This should only yield input data ``x`` and not the target ``y``.
steps (int, optional): Number of iterations done on ``generator``.
(Defaults to the number of steps needed to see the entire dataset.)
Returns:
List of the predictions of each batch with tensors converted into Numpy arrays.
"""
# NOTE(review): the `def` line and the opening of this block's docstring are
# missing from this view. It appears to be the body of a method that predicts
# from a generator (see the docstring fragment above); confirm against the
# full file.
# When `steps` is not given, fall back to the generator's length if it has one.
if steps is None and hasattr(generator, '__len__'):
steps = len(generator)
pred_y = []
# Run with training mode disabled; one numpy-converted prediction is
# appended per batch yielded by the step iterator.
with self._set_training_mode(False):
for _, x in _get_step_iterator(steps, generator):
x = self._process_input(x)
# Wrap a single (non-list/tuple) input so it can be unpacked into the model.
x = x if isinstance(x, (tuple, list)) else (x, )
pred_y.append(torch_to_numpy(self.model(*x)))
return pred_y