Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
)
# Seed the random number generators (presumably Python/NumPy/torch, plus
# CUDA when `cuda` is set) before building the model.
set_random_seeds(seed=20190706, cuda=cuda)
# Number of output classes, hard-coded here.
# NOTE(review): confirm it matches the label set produced for train_set.
n_classes = 4
# Axis 1 of train_set.X is used as the channel count and axis 2 as the input
# length in samples; both are cast to int for consistency.
n_chans = int(train_set.X.shape[1])
input_time_length = int(train_set.X.shape[2])
# `model` arrives as a string selector ("shallow" / "deep") and is rebound to
# the constructed network; code after this section relies on that name.
# final_conv_length="auto" presumably sizes the last conv layer to cover the
# remaining time axis -- confirm against the network implementation.
if model == "shallow":
    model = ShallowFBCSPNet(
        n_chans,
        n_classes,
        input_time_length=input_time_length,
        final_conv_length="auto",
    ).create_network()
elif model == "deep":
    model = Deep4Net(
        n_chans,
        n_classes,
        input_time_length=input_time_length,
        final_conv_length="auto",
    ).create_network()
else:
    # Without this branch an unknown selector leaves `model` as a str and the
    # script fails later with an unrelated AttributeError (.cuda()/.parameters()).
    raise ValueError(
        "Unknown model {!r}; expected 'shallow' or 'deep'".format(model)
    )
if cuda:
    model.cuda()
log.info("Model: \n{:s}".format(str(model)))
optimizer = optim.Adam(model.parameters())
# NOTE(review): presumably yields batches of `batch_size` with balanced
# sizes -- confirm against BalancedBatchSizeIterator.
iterator = BalancedBatchSizeIterator(batch_size=batch_size)
stop_criterion = Or(
[
MaxEpochs(max_epochs),
# Build train/test datasets from the continuous recordings using the marker
# definition and interval `ival` (both defined outside this section).
train_set = create_signal_target_from_raw_mne(train_cnt, marker_def, ival)
test_set = create_signal_target_from_raw_mne(test_cnt, marker_def, ival)
# Hold out part of the training data: the first returned set keeps a
# (1 - valid_set_fraction) share and the second becomes the validation set.
train_set, valid_set = split_into_two_sets(
    train_set, first_set_fraction=1 - valid_set_fraction)
set_random_seeds(seed=20190706, cuda=cuda)
# Number of output classes, hard-coded here.
# NOTE(review): confirm it matches the label set produced for train_set.
n_classes = 4
# Axis 1 of train_set.X is used as the channel count.
n_chans = int(train_set.X.shape[1])
# NOTE(review): input_time_length is not computed in this section; it must be
# defined beforehand -- confirm it holds the intended input window length.
# `model` arrives as a string selector and is rebound to the network.
if model == 'shallow':
    model = ShallowFBCSPNet(n_chans, n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=30).create_network()
elif model == 'deep':
    model = Deep4Net(n_chans, n_classes,
                     input_time_length=input_time_length,
                     final_conv_length=2).create_network()
else:
    # Fail fast: otherwise `model` stays a str and the next call fails with
    # an unrelated error.
    raise ValueError(
        "Unknown model {!r}; expected 'shallow' or 'deep'".format(model))
# Called only for its effect on `model` (its return value is discarded);
# presumably makes the network emit dense, time-resolved predictions.
to_dense_prediction_model(model)
if cuda:
    model.cuda()
log.info("Model: \n{:s}".format(str(model)))
# Probe the network with the first training example (a trailing singleton
# axis is appended) to learn how many outputs it produces along axis 2.
dummy_input = np_to_var(train_set.X[:1, :, :, None])
if cuda:
    dummy_input = dummy_input.cuda()
out = model(dummy_input)
# Only the size is needed: read it directly instead of copying the output to
# the host and converting it to a NumPy array.
n_preds_per_input = out.size(2)
optimizer = optim.Adam(model.parameters())
def __init__(self, in_chans, n_classes, input_time_length):
super(HybridNetModule, self).__init__()
deep_model = Deep4Net(
in_chans,
n_classes,
n_filters_time=20,
n_filters_spat=30,
n_filters_2=40,
n_filters_3=50,
n_filters_4=60,
input_time_length=input_time_length,
final_conv_length=2,
).create_network()
shallow_model = ShallowFBCSPNet(
in_chans,
n_classes,
input_time_length=input_time_length,
n_filters_time=30,
n_filters_spat=40,