Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _test_batch_delay(self, epoch_delay, batch_in_epoch_delay):
# Helper that trains self.model with a DelayCallback wrapping self.mock_callback
# and then builds the list of calls the mock is expected to have received.
# epoch_delay: number of whole epochs folded into the batch delay.
# batch_in_epoch_delay: extra batches added on top of those whole epochs.
# NOTE(review): this excerpt is truncated -- the body of the final `for step`
# loop and whatever check is made against call_list are not visible here.
#
# Convert the (epochs, batches) pair into a single delay expressed in batches.
batch_delay = epoch_delay * DelayCallbackTest.steps_per_epoch + batch_in_epoch_delay
delay_callback = DelayCallback(self.mock_callback, batch_delay=batch_delay)
# Separate generators for training and validation, both using the class-level batch size.
train_generator = some_data_generator(DelayCallbackTest.batch_size)
valid_generator = some_data_generator(DelayCallbackTest.batch_size)
# Validation uses the same number of steps as training (steps_per_epoch).
self.model.fit_generator(train_generator,
valid_generator,
epochs=DelayCallbackTest.epochs,
steps_per_epoch=DelayCallbackTest.steps_per_epoch,
validation_steps=DelayCallbackTest.steps_per_epoch,
callbacks=[delay_callback])
params = {'epochs': DelayCallbackTest.epochs, 'steps': DelayCallbackTest.steps_per_epoch}
# Build the expected call sequence, starting with on_train_begin.
call_list = []
call_list.append(call.on_train_begin({}))
# Expected epochs start at epoch_delay + 1 and run through DelayCallbackTest.epochs
# inclusive (presumably epoch numbers are 1-based -- TODO confirm against the callback API).
for epoch in range(epoch_delay + 1, DelayCallbackTest.epochs + 1):
call_list.append(call.on_epoch_begin(epoch, {}))
# Only the first expected epoch starts part-way through, at batch_in_epoch_delay + 1;
# every later epoch starts at step 1.
start_step = batch_in_epoch_delay + 1 if epoch == epoch_delay + 1 else 1
for step in range(start_step, params['steps'] + 1):
def setUp(self):
    """Create the seeded model, mocked callback and log templates used by the tests."""
    # Seed before constructing nn.Linear, exactly as the original ordering does.
    torch.manual_seed(42)

    # Build the training pieces in the same order, then attach them to the test case.
    network = nn.Linear(1, 1)
    criterion = nn.MSELoss()
    sgd = torch.optim.SGD(network.parameters(), lr=1e-3)
    wrapped_model = Model(network, sgd, criterion)

    self.pytorch_module = network
    self.loss_function = criterion
    self.optimizer = sgd
    self.model = wrapped_model

    # The mock stands in for the wrapped callback; the default DelayCallback
    # is created with no delay arguments.
    mocked = MagicMock()
    self.mock_callback = mocked
    self.delay_callback = DelayCallback(mocked)

    # Log templates: values are matched with ANY, only the keys are fixed.
    self.train_dict = dict(loss=ANY, time=ANY)
    self.log_dict = dict(loss=ANY, val_loss=ANY, time=ANY)
def test_epoch_delay(self):
epoch_delay = 4
delay_callback = DelayCallback(self.mock_callback, epoch_delay=epoch_delay)
train_generator = some_data_generator(DelayCallbackTest.batch_size)
valid_generator = some_data_generator(DelayCallbackTest.batch_size)
self.model.fit_generator(train_generator,
valid_generator,
epochs=DelayCallbackTest.epochs,
steps_per_epoch=DelayCallbackTest.steps_per_epoch,
validation_steps=DelayCallbackTest.steps_per_epoch,
callbacks=[delay_callback])
params = {'epochs': DelayCallbackTest.epochs, 'steps': DelayCallbackTest.steps_per_epoch}
call_list = []
call_list.append(call.on_train_begin({}))
for epoch in range(epoch_delay + 1, DelayCallbackTest.epochs + 1):
call_list.append(call.on_epoch_begin(epoch, {}))
for step in range(1, params['steps'] + 1):
call_list.append(call.on_batch_begin(step, {}))