How to use the braindecode.datautil.iterators.get_balanced_batches function in braindecode

To help you get started, we’ve selected a few braindecode examples, based on popular ways it is used in public projects.

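Before the project examples, here is a minimal sketch of the function on its own, assuming the braindecode 0.4-era API that the snippets below use (n_trials, rng, shuffle, plus either batch_size or n_batches):

import numpy as np
from braindecode.datautil.iterators import get_balanced_batches

rng = np.random.RandomState(20170629)
# Split 10 trial indices into balanced batches of roughly batch_size each;
# "balanced" means the resulting batch sizes differ by at most one.
batches = get_balanced_batches(n_trials=10, rng=rng, shuffle=True, batch_size=3)
for batch_inds in batches:
    print(batch_inds)  # each entry is a numpy array of trial indices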

Example from TNTLFreiburg/braindecode: braindecode/visualization/perturbation.py (view on GitHub)
Phase perturbations are sampled for each input individually, but applied to all X of that input
    n_iterations : int
        Number of iterations of the correlation computation; more iterations give more stable estimates
    batch_size : int
        Number of inputs used for one forward pass (concatenated over all inputs)

    Returns
    -------
    pert_corrs : numpy array
        List of length n_layers; each entry holds the perturbation correlations
        averaged over iterations, with shape C x Fr x Fi
        (channels x frequencies x filters)
    """
    rng = np.random.RandomState(seed)

    # Get batch indices
    batch_inds = get_balanced_batches(
        n_trials=len(inputs), rng=rng, shuffle=False, batch_size=batch_size
    )
    # Calculate layer activations and reshape
    log.info("Compute original predictions...")
    orig_preds = [pred_fn(inputs[inds]) for inds in batch_inds]
    use_shape = []
    for l in range(n_layers):
        tmp = list(orig_preds[0][l].shape)
        tmp.extend([1] * (4 - len(tmp)))
        tmp[0] = len(inputs)
        use_shape.append(tmp)
    orig_preds_layers = [
        np.concatenate(
            [orig_preds[o][l] for o in range(len(orig_preds))]
        ).reshape(use_shape[l])
        for l in range(n_layers)
    ]
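The batching pattern above generalizes to any batched forward pass. Below is a hypothetical helper, not part of braindecode: batched_predict and pred_fn are assumed names, and pred_fn is assumed to return a list of per-layer activation arrays, as in the snippet.

import numpy as np
from braindecode.datautil.iterators import get_balanced_batches

def batched_predict(pred_fn, inputs, batch_size, n_layers):
    # shuffle=False keeps batches in trial order, so the per-batch
    # outputs can simply be concatenated back together afterwards.
    rng = np.random.RandomState(0)  # unused when shuffle=False
    batch_inds = get_balanced_batches(
        n_trials=len(inputs), rng=rng, shuffle=False, batch_size=batch_size
    )
    preds = [pred_fn(inputs[inds]) for inds in batch_inds]
    # Re-assemble one array per layer from the per-batch predictions.
    return [
        np.concatenate([batch_preds[l] for batch_preds in preds])
        for l in range(n_layers)
    ]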
Example from TNTLFreiburg/braindecode: braindecode/datautil/splitters.py (view on GitHub)
Index of the test fold (0-based)
    rng: `numpy.random.RandomState`, optional
        Random generator for shuffling; None means no shuffling

    Returns
    -------
    train_set, test_set : tuple of :class:`.SignalAndTarget`
        Training and test sets for the given fold.
    """
    n_trials = len(dataset.X)
    if n_trials < n_folds:
        raise ValueError(
            "Less Trials: {:d} than folds: {:d}".format(n_trials, n_folds)
        )
    shuffle = rng is not None
    folds = get_balanced_batches(n_trials, rng, shuffle, n_batches=n_folds)
    test_inds = folds[i_test_fold]
    all_inds = list(range(n_trials))
    train_inds = np.setdiff1d(all_inds, test_inds)
    assert np.intersect1d(train_inds, test_inds).size == 0
    assert np.array_equal(np.sort(np.union1d(train_inds, test_inds)), all_inds)

    train_set = select_examples(dataset, train_inds)
    test_set = select_examples(dataset, test_inds)
    return train_set, test_set
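This snippet appears to come from a train/test splitting helper; assuming braindecode 0.4's names split_into_train_test and SignalAndTarget, a full cross-validation loop might look like the sketch below (the dataset is a hypothetical dummy):

import numpy as np
from braindecode.datautil.signal_target import SignalAndTarget
from braindecode.datautil.splitters import split_into_train_test

# Dummy dataset: 20 trials of 3-channel, 100-sample signals, binary labels.
X = np.random.randn(20, 3, 100).astype(np.float32)
y = np.random.randint(0, 2, size=20)
dataset = SignalAndTarget(X, y)

rng = np.random.RandomState(39483)  # pass rng=None to disable shuffling
n_folds = 5
for i_fold in range(n_folds):
    train_set, test_set = split_into_train_test(
        dataset, n_folds=n_folds, i_test_fold=i_fold, rng=rng
    )
    print(i_fold, len(train_set.X), len(test_set.X))  # 16 train / 4 test per fold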
Example from TNTLFreiburg/braindecode: braindecode/datautil/splitters.py (view on GitHub)
Index of the test fold (0-based). The validation fold is the immediately preceding fold.
    rng: `numpy.random.RandomState`, optional
        Random generator for shuffling; None means no shuffling

    Returns
    -------
    train_set, valid_set, test_set : tuple of :class:`.SignalAndTarget`
        Training, validation, and test sets for the given fold.
    """
    n_trials = len(dataset.X)
    if n_trials < n_folds:
        raise ValueError(
            "Less Trials: {:d} than folds: {:d}".format(n_trials, n_folds)
        )
    shuffle = rng is not None
    folds = get_balanced_batches(n_trials, rng, shuffle, n_batches=n_folds)
    test_inds = folds[i_test_fold]
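    # When i_test_fold == 0, i_test_fold - 1 is -1, i.e. the last fold
    # (Python's negative indexing makes the validation fold wrap around).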
    valid_inds = folds[i_test_fold - 1]
    all_inds = list(range(n_trials))
    train_inds = np.setdiff1d(all_inds, np.union1d(test_inds, valid_inds))
    assert np.intersect1d(train_inds, valid_inds).size == 0
    assert np.intersect1d(train_inds, test_inds).size == 0
    assert np.intersect1d(valid_inds, test_inds).size == 0
    assert np.array_equal(
        np.sort(np.union1d(train_inds, np.union1d(valid_inds, test_inds))),
        all_inds,
    )

    train_set = select_examples(dataset, train_inds)
    valid_set = select_examples(dataset, valid_inds)
    test_set = select_examples(dataset, test_inds)
    return train_set, valid_set, test_set
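The dummy dataset and rng from the previous sketch work with the three-way splitter as well; assuming braindecode 0.4's name split_into_train_valid_test for the function shown above:

from braindecode.datautil.splitters import split_into_train_valid_test

train_set, valid_set, test_set = split_into_train_valid_test(
    dataset, n_folds=5, i_test_fold=0, rng=np.random.RandomState(39483)
)
# With i_test_fold=0, the validation fold wraps around to the last fold.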
Example from TNTLFreiburg/braindecode: braindecode/datautil/lazy_iterators.py (view on GitHub)
self.n_preds_per_input,
            check_preds_smaller_trial_len=self.check_preds_smaller_trial_len,
        )
        for i_trial, trial_blocks in enumerate(start_stop_blocks_per_trial):
            assert trial_blocks[0][0] == 0
            assert trial_blocks[-1][1] == i_trial_stops[i_trial]

        i_trial_start_stop_block = np.array(
            [
                (i_trial, start, stop)
                for i_trial, block in enumerate(start_stop_blocks_per_trial)
                for start, stop in block
            ]
        )

        batches = get_balanced_batches(
            n_trials=len(i_trial_start_stop_block),
            rng=self.rng,
            shuffle=shuffle,
            batch_size=self.batch_size,
        )

        return [i_trial_start_stop_block[batch_ind] for batch_ind in batches]
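Here get_balanced_batches balances compute windows (crops) rather than whole trials: each row of the table built above is an (i_trial, start, stop) triple, and the function only ever sees row indices. A small self-contained sketch with a hypothetical crop table:

import numpy as np
from braindecode.datautil.iterators import get_balanced_batches

# Hypothetical crop table, one (i_trial, start, stop) row per compute window.
i_trial_start_stop_block = np.array(
    [(0, 0, 100), (0, 50, 150), (1, 0, 100), (1, 50, 150), (2, 0, 100)]
)
rng = np.random.RandomState(328774)
batches = get_balanced_batches(
    n_trials=len(i_trial_start_stop_block),  # "trials" here are crop rows
    rng=rng,
    shuffle=True,
    batch_size=2,
)
# Each batch is an index array; use it to pull rows out of the crop table.
crop_batches = [i_trial_start_stop_block[batch_ind] for batch_ind in batches]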