Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
eog=False,
exclude='bads')
import numpy as np
from braindecode.datautil.signal_target import SignalAndTarget

# Extract trials, only using EEG channels.
# tmin/tmax are in seconds relative to each event onset; baseline=None skips
# baseline correction and preload=True loads the epoch data into memory.
epoched = mne.Epochs(raw, events, dict(hands=2, feet=3), tmin=1, tmax=4.1,
                     proj=False, picks=eeg_channel_inds,
                     baseline=None, preload=True)

# Convert data from volt to microvolt (scale factor 1e6).
# PyTorch expects float32 for inputs and int64 for labels.
X = (epoched.get_data() * 1e6).astype(np.float32)
# Labels are read from epoched.events rather than `events`: MNE can drop
# epochs, and epoched.events only lists the epochs that were kept, so the
# labels stay aligned with X.
y = (epoched.events[:, 2] - 2).astype(np.int64)  # 2,3 -> 0,1 (hands, feet)

# The first 60 trials are used for training, the remaining ones for testing.
train_set = SignalAndTarget(X[:60], y=y[:60])
test_set = SignalAndTarget(X[60:], y=y[60:])
from torch import nn

from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.torch_ext.util import set_random_seeds

# Whether to run on the GPU. Replace with torch.cuda.is_available() to decide
# automatically based on the current machine.
cuda = False
# Fix the random seeds (for whichever RNGs set_random_seeds covers) so that
# runs are reproducible.
set_random_seeds(seed=20170629, cuda=cuda)

# Length of each network input window in samples; this determines how many
# crops are processed in parallel.
input_time_length = 450

# Two target classes (hands vs. feet); the channel count is read from the
# data, whose second axis is channels (MNE epochs: trials x channels x time).
n_classes = 2
in_chans = train_set.X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
Returns
-------
reduced_set: :class:`.SignalAndTarget`
Dataset with only examples selected.
"""
# probably not necessary
indices = np.array(indices)
if hasattr(dataset.X, "ndim"):
# numpy array
new_X = np.array(dataset.X)[indices]
else:
# list
new_X = [dataset.X[i] for i in indices]
new_y = np.asarray(dataset.y)[indices]
return SignalAndTarget(new_X, new_y)
def concatenate_two_sets(set_a, set_b):
    """
    Concatenate two sets together.

    Parameters
    ----------
    set_a, set_b: :class:`.SignalAndTarget`
        The two datasets to join; `set_a` is passed first to the
        concatenation helper for both inputs and targets.

    Returns
    -------
    concatenated_set: :class:`.SignalAndTarget`
        New dataset built from the joined inputs and the joined targets.
    """
    # Inputs (X) and targets (y) are joined independently with the same
    # helper, X first and then y. Judging by its name, the helper accepts
    # numpy arrays or lists -- confirm against its definition.
    joined_X, joined_y = (
        concatenate_np_array_or_add_lists(first, second)
        for first, second in ((set_a.X, set_b.X), (set_a.y, set_b.y))
    )
    return SignalAndTarget(joined_X, joined_y)
"lowest class"
)
else:
if np.max(np.sum(this_y, axis=1)) > 1:
log.warning(
"Have multiple active classes and will convert to "
"lowest class"
)
this_new_y = np.argmax(this_y, axis=1)
this_new_y[np.sum(this_y, axis=1) == 0] = -1
new_y.append(this_new_y)
y = new_y
if one_label_per_trial:
y = np.array(y, dtype=np.int64)
return SignalAndTarget(X, y)
label 0 or 1, e.g. 0.5.
individual_crops: bool
Returns
-------
outs_per_trial: 2darray or list of 2darrays
Network outputs for each trial, optionally for each crop within trial.
"""
if individual_crops:
assert self.cropped, "Cropped labels only for cropped decoding"
X = _ensure_float32(X)
all_preds = []
with th.no_grad():
dummy_y = np.ones(len(X), dtype=np.int64)
for b_X, _ in self.iterator.get_batches(
SignalAndTarget(X, dummy_y), False
):
b_X_var = np_to_var(b_X)
if self.cuda:
b_X_var = b_X_var.cuda()
all_preds.append(var_to_np(self.network(b_X_var)))
if self.cropped:
outs_per_trial = compute_preds_per_trial_from_crops(
all_preds, self.iterator.input_time_length, X
)
if not individual_crops:
outs_per_trial = np.array(
[np.mean(o, axis=1) for o in outs_per_trial]
)
else:
outs_per_trial = np.concatenate(all_preds)
return outs_per_trial
if optimizer.__class__.__name__ == "AdamW":
schedule_weight_decay = True
optimizer = ScheduledOptimizer(
scheduler,
self.optimizer,
schedule_weight_decay=schedule_weight_decay,
)
loss_function = self.loss
if self.cropped:
loss_function = lambda outputs, targets: self.loss(
th.mean(outputs, dim=2), targets
)
if validation_data is not None:
valid_X = _ensure_float32(validation_data[0])
valid_y = validation_data[1]
valid_set = SignalAndTarget(valid_X, valid_y)
else:
valid_set = None
test_set = None
self.monitors = [LossMonitor()]
if self.cropped:
self.monitors.append(CroppedTrialMisclassMonitor(input_time_length))
else:
self.monitors.append(MisclassMonitor())
if self.extra_monitors is not None:
self.monitors.extend(self.extra_monitors)
self.monitors.append(RuntimeMonitor())
exp = Experiment(
self.network,
train_set,
valid_set,
test_set,