Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""Contains skorch-specific exceptions and warnings."""
class SkorchException(Exception):
    """Base skorch exception.

    Derives from :class:`Exception` (not :class:`BaseException`) so that
    ordinary ``except Exception:`` handlers catch skorch errors, as the
    Python documentation recommends for user-defined exceptions.
    ``except SkorchException`` and ``except BaseException`` keep working.
    """


class NotInitializedError(SkorchException):
    """Module is not initialized, please call the ``.initialize``
    method or train the model by calling ``.fit(...)``.
    """


class SkorchWarning(UserWarning):
    """Base skorch warning."""


class DeviceWarning(SkorchWarning):
    """A problem with a device (e.g. CUDA) was detected."""
def on_epoch_end(self, net, **kwargs):
if self.monitor is None:
do_checkpoint = True
elif callable(self.monitor):
do_checkpoint = self.monitor(net)
else:
try:
do_checkpoint = net.history[-1, self.monitor]
except KeyError as e:
raise SkorchException(
"Monitor value '{}' cannot be found in history. "
"Make sure you have validation data if you use "
"validation scores for checkpointing.".format(e.args[0]))
if do_checkpoint:
target = self.target
if isinstance(self.target, str):
target = self.target.format(
net=net,
last_epoch=net.history[-1],
last_batch=net.history[-1, 'batches', -1],
)
if net.verbose > 0:
print("Checkpoint! Saving model to {}.".format(target))
net.save_params(target)
conjunction with dirname.
"""
if not self.dirname:
return
def _is_truthy_and_not_str(f):
return f and not isinstance(f, str)
if (
_is_truthy_and_not_str(self.f_optimizer) or
_is_truthy_and_not_str(self.f_params) or
_is_truthy_and_not_str(self.f_history) or
_is_truthy_and_not_str(self.f_pickle)
):
raise SkorchException(
'dirname can only be used when f_* are strings')
def on_epoch_end(self, net, **kwargs):
    """Decide whether to checkpoint at the end of an epoch and save if so.

    The decision comes from ``self.monitor``:

    * ``None`` -- always checkpoint;
    * a callable -- called with ``net``; the truthiness of its result
      decides;
    * anything else -- used as a key into the last row of
      ``net.history``; the truthiness of the stored value decides.

    If ``self.event_name`` is not ``None``, the decision is recorded in
    the history under that name (as a ``bool``). When checkpointing,
    ``self.save_model(net)`` is called and a message is passed to
    ``self._sink`` together with ``net.verbose``.

    Parameters
    ----------
    net
      The net being trained. Its ``history`` is read and, when
      ``self.event_name`` is set, written to.

    **kwargs
      Ignored.

    Raises
    ------
    SkorchException
      If ``self.monitor`` is a key missing from the last history row.
    """
    # A raw score such as a loss value is truthy whenever it is non-zero,
    # so monitoring 'x' while 'x_best' exists is most likely a mistake.
    if "{}_best".format(self.monitor) in net.history[-1]:
        warnings.warn(
            "Checkpoint monitor parameter is set to '{0}' and the history "
            "contains '{0}_best'. Perhaps you meant to set the parameter "
            "to '{0}_best'".format(self.monitor), UserWarning)

    if self.monitor is None:
        do_checkpoint = True
    elif callable(self.monitor):
        do_checkpoint = self.monitor(net)
    else:
        try:
            do_checkpoint = net.history[-1, self.monitor]
        except KeyError as e:
            raise SkorchException(
                "Monitor value '{}' cannot be found in history. "
                "Make sure you have validation data if you use "
                "validation scores for checkpointing.".format(e.args[0])
            ) from e

    if self.event_name is not None:
        net.history.record(self.event_name, bool(do_checkpoint))

    if do_checkpoint:
        self.save_model(net)
        # The current epoch's row is already the last history entry (its
        # scores were read via net.history[-1] above), so len(history) is
        # the 1-based number of the current epoch. Adding 1 would report
        # the *next* epoch.
        self._sink("A checkpoint was triggered in epoch {}.".format(
            len(net.history)
        ), net.verbose)
def on_epoch_end(self, net, **kwargs):
if self.monitor is None:
do_checkpoint = True
elif callable(self.monitor):
do_checkpoint = self.monitor(net)
else:
try:
do_checkpoint = net.history[-1, self.monitor]
except KeyError as e:
raise SkorchException(
"Monitor value '{}' cannot be found in history. "
"Make sure you have validation data if you use "
"validation scores for checkpointing.".format(e.args[0]))
if do_checkpoint:
target = self.target
if isinstance(self.target, str):
target = self.target.format(
net=net,
last_epoch=net.history[-1],
last_batch=net.history[-1, 'batches', -1],
)
if net.verbose > 0:
print("Checkpoint! Saving model to {}.".format(target))
net.save_params(target)