How to use the poutyne.framework.callbacks.callbacks.Callback class in Poutyne

To help you get started, we’ve selected a few Poutyne examples based on popular ways it is used in public projects.

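Before looking at the snippets, it helps to see the general pattern they all share: subclass Callback, override the hooks you need (on_epoch_end, on_batch_end, on_backward_end, and so on), and pass an instance to the model's fit through its callbacks argument. The sketch below is only illustrative; the toy network, the loss and the exact Model constructor arguments are assumptions rather than something taken from these snippets.

import torch
import torch.nn as nn

from poutyne.framework import Model, Callback


class PrintLossCallback(Callback):
    """Toy callback that prints the training loss at the end of every epoch."""

    def on_epoch_end(self, epoch_number, logs):
        # `logs` contains the metrics gathered for this epoch (e.g. 'loss', 'val_loss').
        print(f"Epoch {epoch_number}: loss = {logs['loss']:.4f}")


# Toy data and network, only to make the example self-contained.
x_train = torch.randn(128, 10)
y_train = torch.randn(128, 1)

network = nn.Linear(10, 1)
optimizer = torch.optim.SGD(network.parameters(), lr=0.01)

model = Model(network, optimizer, nn.MSELoss())
model.fit(x_train, y_train, epochs=5, batch_size=32, callbacks=[PrintLossCallback()])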

GRAAL-Research/poutyne: poutyne/framework/callbacks/clip_grad.py (view on GitHub)
from torch.nn.utils import clip_grad_norm_, clip_grad_value_

from .callbacks import Callback


class ClipNorm(Callback):
    """
    Uses PyTorch's :func:`~torch.nn.utils.clip_grad_norm_()`
    method to clip gradients.

    See:
        :func:`torch.nn.utils.clip_grad_norm_()`

    """

    def __init__(self, parameters, max_norm, *, norm_type=2):
        super().__init__()
        self.parameters = list(parameters)
        self.max_norm = max_norm
        self.norm_type = norm_type

    def on_backward_end(self, batch_number):
        clip_grad_norm_(self.parameters, self.max_norm, norm_type=self.norm_type)
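Assuming ClipNorm is exported from poutyne.framework.callbacks (an import path inferred from the file location above), a typical setup passes the parameters to clip and a maximum norm, then registers the callback with fit; the network and the commented fit call are placeholders.

import torch.nn as nn
from poutyne.framework.callbacks import ClipNorm

network = nn.Linear(10, 1)  # placeholder network

# Clip the global gradient norm of the network's parameters to 1.0 after each backward pass.
clip_norm = ClipNorm(network.parameters(), max_norm=1.0)

# model.fit(x_train, y_train, epochs=5, callbacks=[clip_norm])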
GRAAL-Research/poutyne: poutyne/framework/callbacks/periodic.py (view on GitHub)
import os
import warnings
import tempfile

from .callbacks import Callback


class PeriodicSaveCallback(Callback):
    """
    The source code of this class is under the MIT License and was copied from the Keras project,
    and has been modified.

    Write a file after every epoch. `filename` can contain named formatting options, which will be
    filled with the value of `epoch` and the keys of `logs` (passed in `on_epoch_end`). For example:
    if `filename` is `weights.{epoch:02d}-{val_loss:.2f}.txt`, then `save_file()` will be called with
    a file descriptor for a file whose name contains the epoch number and the validation loss.

    By default, the file is written atomically to the specified filename so that the training can
    be killed and restarted later using the same filename for periodic file saving. To do so, a
    temporary file is created using the system's `tmp` directory and then is moved to the final
    destination after the checkpoint is made. However, this move is not possible on some systems. To
    address this problem, it is possible to specify the destination of the temporary file using the
    ``temporary_filename`` argument.
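The filename templating described above is plain Python string formatting, so its effect can be previewed directly; the epoch number and validation loss below are made up.

filename = 'weights.{epoch:02d}-{val_loss:.2f}.txt'
logs = {'val_loss': 0.1234}  # illustrative epoch-end logs

# The callback fills the template with the epoch number and the keys of `logs`.
print(filename.format(epoch=7, **logs))  # -> weights.07-0.12.txt

The atomic part then amounts to writing under a temporary name and moving the file into place once the write has finished; such a move can fail when the temporary directory and the destination sit on different filesystems, which is presumably the case the ``temporary_filename`` argument is meant to work around.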
GRAAL-Research/poutyne: poutyne/framework/callbacks/lr_scheduler.py (view on GitHub)
"""
Poutyne's callbacks for learning rate schedulers are just wrappers around PyTorch's learning rate
schedulers and thus take the same arguments, except for the optimizer, which has to be omitted.
"""
import sys
import inspect
import torch.optim.lr_scheduler
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

from .callbacks import Callback


class _PyTorchLRSchedulerWrapper(Callback):
    """
    Default class for the LR scheduling callbacks. Provides default behavior for loading and
    saving the scheduler, as well as for handling the end of an epoch.
    """

    def __init__(self, torch_lr_scheduler, *args, **kwargs):
        super().__init__()
        if len(args) > 0 and isinstance(args[0], Optimizer):
            raise ValueError("In the LR scheduler callbacks, the optimizer is "
                             "automatically passed to the PyTorch's LR scheduler. "
                             "You must remove it from the arguments.")
        self.args = args
        self.kwargs = kwargs
        self.scheduler = None
        self.state_to_load = None
        self.torch_lr_scheduler = torch_lr_scheduler
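Concretely, since torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma) takes the optimizer as its first argument, the corresponding Poutyne callback is built with gamma only and the framework supplies the optimizer itself. The import below assumes the generated wrappers are re-exported by poutyne.framework; adjust the path if needed.

from poutyne.framework import ExponentialLR  # wrapper generated from torch.optim.lr_scheduler

# Same arguments as the PyTorch scheduler, minus the optimizer.
exp_lr = ExponentialLR(gamma=0.95)

# model.fit(x_train, y_train, epochs=10, callbacks=[exp_lr])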
GRAAL-Research/poutyne: poutyne/framework/callbacks/clip_grad.py (view on GitHub)
class ClipNorm(Callback):
    """
    Uses PyTorch's :func:`~torch.nn.utils.clip_grad_norm_()`
    method to clip gradients.

    See:
        :func:`torch.nn.utils.clip_grad_norm_()`

    """

    def __init__(self, parameters, max_norm, *, norm_type=2):
        super().__init__()
        self.parameters = list(parameters)
        self.max_norm = max_norm
        self.norm_type = norm_type

    def on_backward_end(self, batch_number):
        clip_grad_norm_(self.parameters, self.max_norm, norm_type=self.norm_type)


class ClipValue(Callback):
    """
    Uses PyTorch's :func:`~torch.nn.utils.clip_grad_value_()`
    method to clip gradients.

    See:
        :func:`torch.nn.utils.clip_grad_value_()`

    """

    def __init__(self, parameters, clip_value):
        super().__init__()
        self.parameters = list(parameters)
        self.clip_value = clip_value

    def on_backward_end(self, batch_number):
        clip_grad_value_(self.parameters, self.clip_value)
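ClipValue is used the same way, except that it clips every gradient component into [-clip_value, clip_value] instead of rescaling the overall norm; as before, the import path, the network and the commented fit call are placeholders.

import torch.nn as nn
from poutyne.framework.callbacks import ClipValue

network = nn.Linear(10, 1)  # placeholder network

# Clip every gradient component into [-0.5, 0.5] after each backward pass.
clip_value = ClipValue(network.parameters(), clip_value=0.5)

# model.fit(x_train, y_train, epochs=5, callbacks=[clip_value])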
GRAAL-Research/poutyne: poutyne/framework/callbacks/earlystopping.py (view on GitHub)

import numpy as np
from .callbacks import Callback


class EarlyStopping(Callback):
    """
    The source code of this class is under the MIT License and was copied from the Keras project,
    and has been modified.

    Stop training when a monitored quantity has stopped improving.

    Args:
        monitor (str): Quantity to be monitored.
        min_delta (float): Minimum change in the monitored quantity to qualify as an improvement,
            i.e. an absolute change of less than `min_delta` will count as no improvement.
            (Default value = 0)
        patience (int): Number of epochs with no improvement after which training will be stopped.
            (Default value = 0)
        verbose (bool): Whether to print when early stopping is done.
            (Default value = False)
        mode (str): One of {'min', 'max'}. In `min` mode, training will stop when the quantity
GRAAL-Research/poutyne: poutyne/framework/callbacks/delay.py (view on GitHub)
from .callbacks import Callback, CallbackList


class DelayCallback(Callback):
    """
    Delays one or many callbacks for a certain number of epochs or number of batches. If both
    ``epoch_delay`` and ``batch_delay`` are provided, the largest takes precedence.

    Args:
        callbacks (Callback, List[Callback]): A callback or a list of callbacks to delay.
        epoch_delay (int, optional): Number of epochs to delay.
        batch_delay (int, optional): Number of batches to delay. The number of batches can span many
            epochs. When the batch delay expires (i.e. more than `batch_delay` batches have been done), the
            :func:`~poutyne.framework.callbacks.Callback.on_epoch_begin()` method is called on
            the callback(s) before the :func:`~poutyne.framework.callbacks.Callback.on_batch_begin()` method.
    """

    def __init__(self, callbacks, *, epoch_delay=None, batch_delay=None):
        super().__init__()
        if isinstance(callbacks, CallbackList):
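A natural use of DelayCallback is to hold back another callback, for instance the EarlyStopping callback shown earlier, until the model has had a few epochs to warm up; the delay and the early-stopping settings below are illustrative.

from poutyne.framework.callbacks import DelayCallback, EarlyStopping

# Do not let early stopping act during the first 10 epochs.
early_stopping = EarlyStopping(monitor='val_loss', min_delta=0.001, patience=5,
                               verbose=True, mode='min')
delayed_early_stopping = DelayCallback(early_stopping, epoch_delay=10)

# model.fit(x_train, y_train, validation_data=(x_valid, y_valid),
#           epochs=100, callbacks=[delayed_early_stopping])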
GRAAL-Research/poutyne: poutyne/framework/callbacks/logger.py (view on GitHub)
import csv
from .callbacks import Callback


class Logger(Callback):
    def __init__(self, *, batch_granularity=False):
        super().__init__()
        self.batch_granularity = batch_granularity
        self.epoch = 0

    def on_train_begin(self, logs):
        metrics = ['loss'] + self.model.metrics_names

        if self.batch_granularity:
            self.fieldnames = ['epoch', 'batch', 'size', 'time', 'lr']
        else:
            self.fieldnames = ['epoch', 'time', 'lr']
        self.fieldnames += metrics
        self.fieldnames += ['val_' + metric for metric in metrics]
        self._on_train_begin_write(logs)
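This Logger class is the base that concrete loggers build on. Assuming the CSV-writing subclass CSVLogger from the same module is available (its exact signature may differ), usage boils down to the following.

from poutyne.framework.callbacks import CSVLogger

# Write one row of metrics per epoch to a CSV file.
csv_logger = CSVLogger('training_log.csv')

# model.fit(x_train, y_train, epochs=10, callbacks=[csv_logger])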
GRAAL-Research/poutyne: poutyne/framework/callbacks/progress.py (view on GitHub)
import sys

import itertools

from .callbacks import Callback


class ProgressionCallback(Callback):
    def on_train_begin(self, logs):
        self.metrics = ['loss'] + self.model.metrics_names
        self.epochs = self.params['epochs']
        self.steps = self.params['steps']

    def on_train_end(self, logs):
        pass

    def on_epoch_begin(self, epoch_number, logs):
        self.step_times_sum = 0.
        self.epoch_number = epoch_number
        sys.stdout.write("\rEpoch %d/%d" % (self.epoch_number, self.epochs))
        sys.stdout.flush()

    def on_epoch_end(self, epoch_number, logs):
        epoch_total_time = logs['time']
GRAAL-Research/poutyne: poutyne/framework/callbacks/progress.py (view on GitHub)
import sys

import itertools

from .callbacks import Callback


class ProgressionCallback(Callback):
    def on_train_begin(self, logs):
        self.metrics = ['loss'] + self.model.batch_metrics_names + self.model.epoch_metrics_names
        self.epochs = self.params['epochs']
        self.steps = self.params['steps']

    def on_train_end(self, logs):
        pass

    def on_epoch_begin(self, epoch, logs):
        self.step_times_sum = 0.
        self.epoch = epoch
        sys.stdout.write("\rEpoch %d/%d" % (self.epoch, self.epochs))
        sys.stdout.flush()

    def on_epoch_end(self, epoch, logs):
        epoch_total_time = logs['time']
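The same hooks and attributes can be used to roll a simpler progress printer of one's own; the sketch below relies only on what the excerpt above already uses (self.params and the logs dictionary), plus the assumption that logs carries a 'loss' entry at epoch end.

from poutyne.framework import Callback


class SimpleProgress(Callback):
    """Minimal progress printer in the spirit of ProgressionCallback above."""

    def on_train_begin(self, logs):
        # `self.params` is filled in by the framework before training starts.
        self.epochs = self.params['epochs']

    def on_epoch_end(self, epoch_number, logs):
        print(f"Epoch {epoch_number}/{self.epochs} done - loss: {logs['loss']:.4f}")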
GRAAL-Research/poutyne: poutyne/framework/callbacks/logger.py (view on GitHub)
import csv
from .callbacks import Callback


class Logger(Callback):
    def __init__(self, *, batch_granularity=False):
        super().__init__()
        self.batch_granularity = batch_granularity
        self.epoch = 0

    def on_train_begin(self, logs):
        metrics = ['loss'] + self.model.batch_metrics_names

        if self.batch_granularity:
            self.fieldnames = ['epoch', 'batch', 'size', 'time', 'lr']
        else:
            self.fieldnames = ['epoch', 'time', 'lr']
        self.fieldnames += metrics
        self.fieldnames += ['val_' + metric for metric in metrics]
        self._on_train_begin_write(logs)
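To make the field-name logic concrete: with batch_granularity left at False and a model whose batch_metrics_names is, say, ['acc'], the header such a logger would write works out as follows.

metrics = ['loss'] + ['acc']          # ['loss'] + self.model.batch_metrics_names
fieldnames = ['epoch', 'time', 'lr']  # batch_granularity is False
fieldnames += metrics
fieldnames += ['val_' + metric for metric in metrics]
print(fieldnames)
# ['epoch', 'time', 'lr', 'loss', 'acc', 'val_loss', 'val_acc']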