import numpy as np
from braindecode.models.shallow_fbcsp import ShallowFBCSPNet
from braindecode.models.util import to_dense_prediction_model
from braindecode.torch_ext.util import np_to_var, set_random_seeds
from torch import optim

# Set to True if you want to use the GPU; you can also use
# torch.cuda.is_available() to determine if CUDA is available on your machine.
cuda = False
set_random_seeds(seed=20170629, cuda=cuda)

# input_time_length determines how many crops are processed in parallel
input_time_length = 450
n_classes = 2
# train_set is assumed to be loaded earlier (a braindecode SignalAndTarget
# dataset with trials in train_set.X)
in_chans = train_set.X.shape[1]
# final_conv_length determines the size of the receptive field of the ConvNet
model = ShallowFBCSPNet(in_chans=in_chans, n_classes=n_classes,
                        input_time_length=input_time_length,
                        final_conv_length=12).create_network()
to_dense_prediction_model(model)
if cuda:
    model.cuda()

optimizer = optim.Adam(model.parameters())

# Determine the output size with a dummy forward pass
test_input = np_to_var(
    np.ones((2, in_chans, input_time_length, 1), dtype=np.float32))
if cuda:
    test_input = test_input.cuda()
out = model(test_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]
print("{:d} predictions per input/trial".format(n_preds_per_input))
import logging
from braindecode.datautil.splitters import split_into_two_sets
from braindecode.datautil.iterators import CropsFromTrialsIterator
from braindecode.models.deep4 import Deep4Net

log = logging.getLogger(__name__)

# valid_set_fraction and batch_size are assumed to be set earlier
# (e.g. 0.2 and 60 in the original braindecode example)
train_set, valid_set = split_into_two_sets(
    train_set, first_set_fraction=1 - valid_set_fraction)

set_random_seeds(seed=20190706, cuda=cuda)

n_classes = 4
n_chans = int(train_set.X.shape[1])
model = 'shallow'  # set earlier in the original script; 'shallow' or 'deep'
if model == 'shallow':
    model = ShallowFBCSPNet(n_chans, n_classes,
                            input_time_length=input_time_length,
                            final_conv_length=30).create_network()
elif model == 'deep':
    model = Deep4Net(n_chans, n_classes,
                     input_time_length=input_time_length,
                     final_conv_length=2).create_network()
to_dense_prediction_model(model)
if cuda:
    model.cuda()
log.info("Model: \n{:s}".format(str(model)))

# Determine the number of predictions per input with a dummy forward pass
dummy_input = np_to_var(train_set.X[:1, :, :, None])
if cuda:
    dummy_input = dummy_input.cuda()
out = model(dummy_input)
n_preds_per_input = out.cpu().data.numpy().shape[2]

optimizer = optim.Adam(model.parameters())

iterator = CropsFromTrialsIterator(batch_size=batch_size,
                                   input_time_length=input_time_length,
                                   n_preds_per_input=n_preds_per_input)
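# A minimal sketch (following braindecode 0.4 tutorial conventions, not part
# of the original snippet) of one training epoch with this iterator:
# get_batches yields numpy arrays, and the per-crop outputs are averaged over
# time before computing the loss.
import torch as th
import torch.nn.functional as F

model.train()
for batch_X, batch_y in iterator.get_batches(train_set, shuffle=True):
    net_in = np_to_var(batch_X)
    net_target = np_to_var(batch_y)
    if cuda:
        net_in = net_in.cuda()
        net_target = net_target.cuda()
    optimizer.zero_grad()
    outputs = th.mean(model(net_in), dim=2, keepdim=False)
    loss = F.nll_loss(outputs, net_target)
    loss.backward()
    optimizer.step()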
# Excerpt from braindecode's model ``compile`` logic; the signature below is
# reconstructed from the attributes used in the body.
def compile(self, loss, optimizer, extra_monitors=None,
            cropped=False, iterator_seed=0):
    self.loss = loss
    self._ensure_network_exists()
    if cropped:
        # to_dense_prediction_model replaces strides with dilations, so a
        # non-default dilation suggests the conversion already happened
        model_already_dense = np.any(
            [
                hasattr(m, "dilation")
                and (m.dilation != 1)
                and (m.dilation != (1, 1))
                for m in self.network.modules()
            ]
        )
        if not model_already_dense:
            to_dense_prediction_model(self.network)
        else:
            log.info("Seems model was already converted to dense model...")
    if not hasattr(optimizer, "step"):
        # Allow passing an optimizer name; resolve it to an optimizer class
        optimizer_class = find_optimizer(optimizer)
        optimizer = optimizer_class(self.network.parameters())
    self.optimizer = optimizer
    self.extra_monitors = extra_monitors
    # Already setting the random state here, so multiple calls to fit
    # will lead to different batches being drawn
    self.seed_rng = RandomState(iterator_seed)
    self.cropped = cropped
    self.compiled = True
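# The dilation check above guards against converting a network twice. A
# standalone sketch of the same idea (the function name here is ours, not
# braindecode's):
def seems_dense_already(network):
    # Any module with a non-default dilation indicates the network was
    # already converted to a dense prediction model
    return any(
        hasattr(m, "dilation") and m.dilation not in (1, (1, 1))
        for m in network.modules()
    )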
# Replace the classifier layer of the shallow net with a 40-channel conv
# layer; reduced_deep_model is assumed to be built analogously from the deep
# net (with 60 output channels, so that 60 + 40 = 100 matches final_conv)
reduced_shallow_model = nn.Sequential()
for name, module in shallow_model.named_children():
    if name == "conv_classifier":
        new_conv_layer = nn.Conv2d(
            module.in_channels,
            40,
            kernel_size=module.kernel_size,
            stride=module.stride,
        )
        reduced_shallow_model.add_module(
            "shallow_final_conv", new_conv_layer
        )
        break
    reduced_shallow_model.add_module(name, module)

to_dense_prediction_model(reduced_deep_model)
to_dense_prediction_model(reduced_shallow_model)
self.reduced_deep_model = reduced_deep_model
self.reduced_shallow_model = reduced_shallow_model
self.final_conv = nn.Conv2d(
    100, n_classes, kernel_size=(1, 1), stride=1
)
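# For context, a hedged sketch of how a forward pass could combine the two
# reduced models: concatenate their feature maps along the channel dimension
# (60 deep + 40 shallow = 100 channels, matching final_conv) and classify per
# time step. Braindecode's actual hybrid net pads to align the two time
# lengths; here we simply crop to the shorter output.
import torch as th
import torch.nn.functional as F

def forward(self, x):
    deep_out = self.reduced_deep_model(x)        # (batch, 60, time_d, 1)
    shallow_out = self.reduced_shallow_model(x)  # (batch, 40, time_s, 1)
    n_preds = min(deep_out.size(2), shallow_out.size(2))
    merged = th.cat(
        (deep_out[:, :, :n_preds], shallow_out[:, :, :n_preds]), dim=1
    )
    return F.log_softmax(self.final_conv(merged), dim=1)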