Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def _get_batches_per_epoch_phase(self, net, X, training):
    """Return the number of batches in one training or validation phase.

    Parameters
    ----------
    net
        The net whose batch size is looked up via ``self._get_batch_size``.
    X
        The data of this phase; ``None`` means the phase does not run.
    training : bool
        Passed to ``self._get_batch_size`` to select the training or the
        validation batch size.

    Returns
    -------
    int
        ``0`` if ``X`` is ``None``; otherwise the number of batches needed
        to cover ``get_len(X)`` samples.
    """
    if X is None:
        return 0
    batch_size = self._get_batch_size(net, training)
    n_samples = get_len(X)
    # skorch documents batch_size=-1 as "all data in a single batch";
    # dividing by -1 would produce a negative batch count.
    if batch_size == -1:
        return 1 if n_samples > 0 else 0
    # Exact integer ceiling division (no float rounding for large sizes).
    return int(-(-n_samples // batch_size))
def _get_batches_per_epoch_phase(self, net, dataset, training):
    """Return the number of batches in one training or validation phase.

    Parameters
    ----------
    net
        The net whose batch size is looked up via ``self._get_batch_size``.
    dataset
        The data of this phase; ``None`` means the phase does not run.
    training : bool
        Passed to ``self._get_batch_size`` to select the training or the
        validation batch size.

    Returns
    -------
    int
        ``0`` if ``dataset`` is ``None``; otherwise the number of batches
        needed to cover ``get_len(dataset)`` samples.
    """
    if dataset is None:
        return 0
    batch_size = self._get_batch_size(net, training)
    n_samples = get_len(dataset)
    # skorch documents batch_size=-1 as "all data in a single batch";
    # dividing by -1 would produce a negative batch count.
    if batch_size == -1:
        return 1 if n_samples > 0 else 0
    # Exact integer ceiling division (no float rounding for large sizes).
    return int(-(-n_samples // batch_size))
self.notify('on_batch_begin', X=Xi, y=yi_res, training=True)
step = self.train_step(Xi, yi, **fit_params)
self.history.record_batch('train_loss', step['loss'].item())
self.history.record_batch('train_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
if dataset_valid is None:
self.notify('on_epoch_end', **on_epoch_kwargs)
continue
for Xi, yi in self.get_iterator(dataset_valid, training=False):
yi_res = yi if not y_valid_is_ph else None
self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
step = self.validation_step(Xi, yi, **fit_params)
self.history.record_batch('valid_loss', step['loss'].item())
self.history.record_batch('valid_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
self.notify('on_epoch_end', **on_epoch_kwargs)
return self
self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
self.history.record("train_batch_count", train_batch_count)
if dataset_valid is None:
self.notify('on_epoch_end', **on_epoch_kwargs)
continue
valid_batch_count = 0
for data in self.get_iterator(dataset_valid, training=False):
Xi, yi = unpack_data(data)
yi_res = yi if not y_valid_is_ph else None
self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
step = self.validation_step(Xi, yi, **fit_params)
valid_batch_count += 1
self.history.record_batch('valid_loss', step['loss'].item())
self.history.record_batch('valid_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
self.history.record("valid_batch_count", valid_batch_count)
self.notify('on_epoch_end', **on_epoch_kwargs)
return self