How to use the skorch.dataset.unpack_data function in skorch

To help you get started, we’ve selected a few skorch examples based on popular ways `unpack_data` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: skorch-dev / skorch / skorch / net.py (View on GitHub, external link)
The device to store each inference result on.
          This defaults to CPU memory since there is generally
          more memory available there. For performance reasons
          this might be changed to a specific CUDA device,
          e.g. 'cuda:0'.

        Yields
        ------
        yp : torch tensor
          Result from a forward call on an individual batch.

        """
        dataset = self.get_dataset(X)
        iterator = self.get_iterator(dataset, training=training)
        for data in iterator:
            Xi = unpack_data(data)[0]
            yp = self.evaluation_step(Xi, training=training)
            if isinstance(yp, tuple):
                yield tuple(n.to(device) for n in yp)
            else:
                yield yp.to(device)
GitHub: skorch-dev / skorch / skorch / net.py (View on GitHub, external link)
yi_res = yi if not y_train_is_ph else None
                self.notify('on_batch_begin', X=Xi, y=yi_res, training=True)
                step = self.train_step(Xi, yi, **fit_params)
                train_batch_count += 1
                self.history.record_batch('train_loss', step['loss'].item())
                self.history.record_batch('train_batch_size', get_len(Xi))
                self.notify('on_batch_end', X=Xi, y=yi_res, training=True, **step)
            self.history.record("train_batch_count", train_batch_count)

            if dataset_valid is None:
                self.notify('on_epoch_end', **on_epoch_kwargs)
                continue

            valid_batch_count = 0
            for data in self.get_iterator(dataset_valid, training=False):
                Xi, yi = unpack_data(data)
                yi_res = yi if not y_valid_is_ph else None
                self.notify('on_batch_begin', X=Xi, y=yi_res, training=False)
                step = self.validation_step(Xi, yi, **fit_params)
                valid_batch_count += 1
                self.history.record_batch('valid_loss', step['loss'].item())
                self.history.record_batch('valid_batch_size', get_len(Xi))
                self.notify('on_batch_end', X=Xi, y=yi_res, training=False, **step)
            self.history.record("valid_batch_count", valid_batch_count)

            self.notify('on_epoch_end', **on_epoch_kwargs)
        return self