The device to store each inference result on.
This defaults to CPU memory since there is generally
more memory available there. For performance reasons
this might be changed to a specific CUDA device,
e.g. 'cuda:0'.
Yields
------
yp : torch tensor
Result from a forward call on an individual batch.
"""
dataset = self.get_dataset(X)
iterator = self.get_iterator(dataset, training=training)
for data in iterator:
Xi = unpack_data(data)[0]
yp = self.evaluation_step(Xi, training=training)
if isinstance(yp, tuple):
yield tuple(n.to(device) for n in yp)
else:
yield yp.to(device)
yi_res = yi if not y_train_is_ph else None
self.notify('on_batch_begin', X=Xi, y=yi_res, training=True)
step = self.train_step(Xi, yi, **fit_params)
train_batch_count += 1
self.history.record_batch('train_loss', step['loss'].item())
self.history.record_batch('train_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
self.history.record("train_batch_count", train_batch_count)
if dataset_valid is None:
self.notify('on_epoch_end', **on_epoch_kwargs)
continue
valid_batch_count = 0
for data in self.get_iterator(dataset_valid, training=False):
Xi, yi = unpack_data(data)
yi_res = yi if not y_valid_is_ph else None
self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
step = self.validation_step(Xi, yi, **fit_params)
valid_batch_count += 1
self.history.record_batch('valid_loss', step['loss'].item())
self.history.record_batch('valid_batch_size', get_len(Xi))
self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
self.history.record("valid_batch_count", valid_batch_count)
self.notify('on_epoch_end', **on_epoch_kwargs)
return self