Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
from .callbacks import Callback
class ClipNorm(Callback):
    """
    Uses PyTorch's :func:`~torch.nn.utils.clip_grad_norm_()`
    method to clip gradient.

    Args:
        parameters: Iterable of the parameters whose gradients are clipped. It is
            copied into a list when the callback is created.
        max_norm: Passed as the ``max_norm`` argument of ``clip_grad_norm_``.
        norm_type: Passed as the ``norm_type`` argument of ``clip_grad_norm_``.
            (Default value = 2)

    See:
        :func:`torch.nn.utils.clip_grad_norm_()`
    """

    def __init__(self, parameters, max_norm, *, norm_type=2):
        super().__init__()
        # Copy into a list so the same parameters can be iterated on every call
        # to on_backward_end, even if an iterator/generator was passed in.
        self.parameters = list(parameters)
        self.max_norm = max_norm
        self.norm_type = norm_type

    def on_backward_end(self, batch_number):
        """Clip the gradients of the stored parameters; ``batch_number`` is unused."""
        clip_grad_norm_(self.parameters, self.max_norm, norm_type=self.norm_type)
portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
import os
import warnings
import tempfile
from .callbacks import Callback
class PeriodicSaveCallback(Callback):
    """
    The source code of this class is under the MIT License and was copied from the Keras project,
    and has been modified.

    Write a file after every epoch. `filename` can contain named formatting options, which will be
    filled with the value of `epoch` and the keys in `logs` (passed in `on_epoch_end`). For example,
    if `filename` is `weights.{epoch:02d}-{val_loss:.2f}.txt`, then `save_file()` will be called
    with a file descriptor for a file with the epoch number and the validation loss in the filename.

    By default, the file is written atomically to the specified filename so that the training can
    be killed and restarted later using the same filename for periodic file saving. To do so, a
    temporary file is created using the system's `tmp` directory and is then moved to the final
    destination after the checkpoint is made. On some systems, this move is not possible.
    To address this problem, it is possible to specify the destination of the temporary file using
    the ``temporary_filename`` argument.
    """
Poutyne's callbacks for learning rate schedulers are just wrappers around `PyTorch's learning rate
schedulers <https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate>`_ and thus have
the same arguments, except for the optimizer, which has to be omitted.
"""
import sys
import inspect
import torch.optim.lr_scheduler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from .callbacks import Callback
class _PyTorchLRSchedulerWrapper(Callback):
    """
    Base callback for the learning rate scheduler wrappers.

    The constructor only records the scheduler class and the arguments meant for
    it; no scheduler instance is created here (``self.scheduler`` starts as ``None``).

    Args:
        torch_lr_scheduler: The PyTorch scheduler class (or factory) being wrapped.
        *args: Positional arguments kept for the scheduler. The optimizer must not
            be among them.
        **kwargs: Keyword arguments kept for the scheduler.

    Raises:
        ValueError: If the first positional argument is an ``Optimizer``.
    """

    def __init__(self, torch_lr_scheduler, *args, **kwargs):
        super().__init__()

        # Reject the common mistake of passing the optimizer explicitly.
        optimizer_given = bool(args) and isinstance(args[0], Optimizer)
        if optimizer_given:
            raise ValueError("In the LR scheduler callbacks, the optimizer is "
                             "automatically passed to the PyTorch's LR scheduler. "
                             "You must remove it from the arguments.")

        self.torch_lr_scheduler = torch_lr_scheduler
        self.args = args
        self.kwargs = kwargs

        # Nothing instantiated or pending yet.
        self.scheduler = None
        self.state_to_load = None
See:
:func:`torch.nn.utils.clip_grad_norm_()`
"""
def __init__(self, parameters, max_norm, *, norm_type=2):
    """
    Store the clipping settings.

    Args:
        parameters: Iterable of parameters; copied into a list.
        max_norm: Kept as ``self.max_norm``.
        norm_type: Kept as ``self.norm_type``. (Default value = 2)
    """
    super().__init__()
    # Scalar settings first, then the materialized parameter list.
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.parameters = list(parameters)
def on_backward_end(self, batch_number):
    """Clip the gradient norm of the stored parameters; ``batch_number`` is unused."""
    params = self.parameters
    limit = self.max_norm
    clip_grad_norm_(params, limit, norm_type=self.norm_type)
class ClipValue(Callback):
    """
    Gradient-clipping callback built on PyTorch's
    :func:`~torch.nn.utils.clip_grad_value_()`.

    Args:
        parameters: Iterable of the parameters whose gradients are clipped. It is
            copied into a list when the callback is created.
        clip_value: Passed as the second argument of ``clip_grad_value_``.

    See:
        :func:`torch.nn.utils.clip_grad_value_()`
    """

    def __init__(self, parameters, clip_value):
        super().__init__()
        self.clip_value = clip_value
        self.parameters = list(parameters)

    def on_backward_end(self, batch_number):
        """Clip the gradients of the stored parameters; ``batch_number`` is unused."""
        params, limit = self.parameters, self.clip_value
        clip_grad_value_(params, limit)
The above copyright notice and this permission notice shall be included in all copies or substantial
portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT
NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES
OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""
import numpy as np
from .callbacks import Callback
class EarlyStopping(Callback):
"""
The source code of this class is under the MIT License and was copied from the Keras project,
and has been modified.
Stop training when a monitored quantity has stopped improving.
Args:
monitor (int): Quantity to be monitored.
min_delta (float): Minimum change in the monitored quantity to qualify as an improvement,
i.e. an absolute change of less than min_delta, will count as no improvement.
(Default value = 0)
patience (int): Number of epochs with no improvement after which training will be stopped.
(Default value = 0)
verbose (bool): Whether to print when early stopping is done.
(Default value = False)
mode (str): One of {'min', 'max'}. In `min` mode, training will stop when the quantity
from .callbacks import Callback, CallbackList
class DelayCallback(Callback):
"""
Delays one or many callbacks for a certain number of epochs or number of batches. If both
``epoch_delay`` and ``batch_delay`` are provided, the larger of the two delays takes precedence.
Args:
callbacks (Callback, List[Callback]): A callback or a list of callbacks to delay.
epoch_delay (int, optional): Number of epochs to delay.
batch_delay (int, optional): Number of batches to delay. The number of batches can span many
epochs. When the batch delay expires (i.e. more than `batch_delay` batches have been done), the
:func:`~poutyne.framework.callbacks.Callback.on_epoch_begin()` method is called on
the callback(s) before the :func:`~poutyne.framework.callbacks.Callback.on_batch_begin()` method.
"""
def __init__(self, callbacks, *, epoch_delay=None, batch_delay=None):
super().__init__()
if isinstance(callbacks, CallbackList):
import csv
from .callbacks import Callback
class Logger(Callback):
    """
    Base logging callback that prepares the list of field names to log.

    Args:
        batch_granularity (bool): When true, the field names also include the
            per-batch columns ``'batch'`` and ``'size'``. (Default value = False)
    """

    def __init__(self, *, batch_granularity=False):
        super().__init__()
        self.epoch = 0
        self.batch_granularity = batch_granularity

    def on_train_begin(self, logs):
        """
        Build ``self.fieldnames`` and hand ``logs`` to ``_on_train_begin_write``.

        The field names are the leading columns, followed by the loss and the model's
        metric names, followed by the same names prefixed with ``val_``.
        """
        tracked = ['loss'] + self.model.metrics_names

        if self.batch_granularity:
            leading = ['epoch', 'batch', 'size', 'time', 'lr']
        else:
            leading = ['epoch', 'time', 'lr']

        self.fieldnames = leading + tracked + ['val_' + name for name in tracked]
        # NOTE(review): _on_train_begin_write is not defined in the visible code;
        # presumably supplied by subclasses.
        self._on_train_begin_write(logs)
import sys
import itertools
from .callbacks import Callback
class ProgressionCallback(Callback):
    """
    Console progress callback: caches the run settings when training starts and
    writes an ``Epoch i/N`` header to standard output at the start of each epoch.
    """

    def on_train_begin(self, logs):
        """Cache the metric names and the ``'epochs'``/``'steps'`` entries of ``self.params``."""
        self.metrics = ['loss'] + self.model.metrics_names
        run_params = self.params
        self.epochs = run_params['epochs']
        self.steps = run_params['steps']

    def on_train_end(self, logs):
        """Nothing to do when training ends."""

    def on_epoch_begin(self, epoch_number, logs):
        """Reset the step-time accumulator and print the epoch header."""
        self.step_times_sum = 0.
        self.epoch_number = epoch_number
        # The leading carriage return rewrites the current console line.
        header = "\rEpoch %d/%d" % (epoch_number, self.epochs)
        out = sys.stdout
        out.write(header)
        out.flush()
def on_epoch_end(self, epoch_number, logs):
epoch_total_time = logs['time']
import sys
import itertools
from .callbacks import Callback
class ProgressionCallback(Callback):
    """
    Console progress callback: caches the run settings when training starts and
    writes an ``Epoch i/N`` header to standard output at the start of each epoch.
    """

    def on_train_begin(self, logs):
        """Cache the batch and epoch metric names and the ``'epochs'``/``'steps'`` parameters."""
        model = self.model
        self.metrics = ['loss'] + model.batch_metrics_names + model.epoch_metrics_names
        run_params = self.params
        self.epochs = run_params['epochs']
        self.steps = run_params['steps']

    def on_train_end(self, logs):
        """Nothing to do when training ends."""

    def on_epoch_begin(self, epoch, logs):
        """Reset the step-time accumulator and print the epoch header."""
        self.step_times_sum = 0.
        self.epoch = epoch
        # The leading carriage return rewrites the current console line.
        header = "\rEpoch %d/%d" % (epoch, self.epochs)
        out = sys.stdout
        out.write(header)
        out.flush()
def on_epoch_end(self, epoch, logs):
epoch_total_time = logs['time']
import csv
from .callbacks import Callback
class Logger(Callback):
    """
    Base logging callback that prepares the list of field names to log.

    Args:
        batch_granularity (bool): When true, the field names also include the
            per-batch columns ``'batch'`` and ``'size'``. (Default value = False)
    """

    def __init__(self, *, batch_granularity=False):
        super().__init__()
        self.batch_granularity = batch_granularity
        self.epoch = 0

    def on_train_begin(self, logs):
        """
        Build ``self.fieldnames`` and hand ``logs`` to ``_on_train_begin_write``.

        The field names are the leading columns, followed by the loss and the model's
        metric names, followed by the same names prefixed with ``val_``.
        """
        model = self.model
        # Include epoch-level metrics as well as batch-level ones, so the field
        # names cover the same metrics as ProgressionCallback. Fall back to batch
        # metrics only when the model does not expose ``epoch_metrics_names``.
        epoch_metrics = getattr(model, 'epoch_metrics_names', [])
        metrics = ['loss'] + list(model.batch_metrics_names) + list(epoch_metrics)

        if self.batch_granularity:
            self.fieldnames = ['epoch', 'batch', 'size', 'time', 'lr']
        else:
            self.fieldnames = ['epoch', 'time', 'lr']
        self.fieldnames += metrics
        self.fieldnames += ['val_' + metric for metric in metrics]
        # NOTE(review): _on_train_begin_write is not defined in the visible code;
        # presumably supplied by subclasses.
        self._on_train_begin_write(logs)