Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_MDM_predict():
    """Test prediction of MDM.

    Checks fit/predict, fit_predict, transform, predict_proba and the
    parallel (``n_jobs=2``) path on 100 random 3x3 covariance matrices
    split into two balanced classes.
    """
    n_trials, n_classes = 100, 2
    covset = generate_cov(n_trials, 3)
    labels = np.array([0, 1]).repeat(n_trials // n_classes)

    # fit then predict: one label per trial, drawn from the training labels
    mdm = MDM(metric='riemann')
    mdm.fit(covset, labels)
    preds = mdm.predict(covset)
    assert len(preds) == n_trials
    assert set(np.unique(preds)) <= set(labels)

    # test fit_predict: a fresh model must agree with fit + predict
    mdm = MDM(metric='riemann')
    np.testing.assert_array_equal(mdm.fit_predict(covset, labels), preds)

    # test transform: one distance per (trial, class centroid)
    dist = mdm.transform(covset)
    assert dist.shape == (n_trials, n_classes)

    # predict proba: one row per trial, each row summing to 1
    proba = mdm.predict_proba(covset)
    assert proba.shape == (n_trials, n_classes)
    np.testing.assert_allclose(proba.sum(axis=1), 1.)

    # test n_jobs: the parallel path must reproduce the serial predictions
    mdm = MDM(metric='riemann', n_jobs=2)
    mdm.fit(covset, labels)
    np.testing.assert_array_equal(mdm.predict(covset), preds)
lr = LogisticRegression()
csp = CSP(n_components=4, reg='ledoit_wolf', log=True)
clf = Pipeline([('CSP', csp), ('LogisticRegression', lr)])
scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=1)
# Printing the results
# Chance level = share of the majority class, i.e. the accuracy of always
# predicting that class. For two classes this equals max(p, 1 - p); counting
# every class keeps it correct when there are more than two.
_, class_counts = np.unique(labels, return_counts=True)
class_balance = class_counts.max() / class_counts.sum()
# The classifier in the pipeline is LogisticRegression, not LDA.
print("CSP + LogisticRegression Classification accuracy: %f / Chance level: %f" %
      (np.mean(scores), class_balance))
###############################################################################
# Display MDM centroid
mdm = MDM()
mdm.fit(cov_data_train, labels)
fig, axes = plt.subplots(1, 2, figsize=[8, 4])
ch_names = [ch.replace('.', '') for ch in epochs.ch_names]
titles = ['Mean covariance - hands', 'Mean covariance - feets']
# One heatmap per class centroid: mdm.covmeans_[0] on the left and
# mdm.covmeans_[1] on the right.
for idx, (ax, title) in enumerate(zip(axes, titles)):
    df = pd.DataFrame(data=mdm.covmeans_[idx], index=ch_names, columns=ch_names)
    g = sns.heatmap(
        df, ax=ax, square=True, cbar=False, xticklabels=2, yticklabels=2)
    # Rotate the tick labels of *this* axes. plt.xticks/plt.yticks only act
    # on the current axes, so previously only the right-hand plot was rotated.
    plt.setp(g.get_xticklabels(), rotation='vertical')
    plt.setp(g.get_yticklabels(), rotation='horizontal')
    g.set_title(title)
labels = epochs.events[:, -1]
evoked = epochs.average()
###############################################################################
# Decoding with Xdawn + MDM
n_components = 3  # pick some components
# Shuffled 10-fold cross-validation. random_state only has an effect (and
# recent scikit-learn only accepts it) when shuffle=True.
cv = KFold(n_splits=10, shuffle=True, random_state=42)
# Out-of-fold predictions, with the same dtype as the labels so the report
# shows the real class values.
pr = np.zeros(len(labels), dtype=labels.dtype)
epochs_data = epochs.get_data()
print('Multiclass classification with XDAWN + MDM')
clf = make_pipeline(XdawnCovariances(n_components), MDM())
for train_idx, test_idx in cv.split(epochs_data):
    clf.fit(epochs_data[train_idx], labels[train_idx])
    pr[test_idx] = clf.predict(epochs_data[test_idx])
print(classification_report(labels, pr))
###############################################################################
# plot the spatial patterns
xd = XdawnCovariances(n_components)
xd.fit(epochs_data, labels)
evoked.data = xd.Xd_.patterns_.T
# NOTE(review): newer MNE versions may not allow assigning evoked.times — confirm
evoked.times = np.arange(evoked.data.shape[0])
def _fit_single(X, y=None, n_clusters=2, init='random', random_state=None,
                metric='riemann', max_iter=100, tol=1e-4, n_jobs=1):
    """Helper to fit a single run of centroid-based clustering.

    An ``MDM`` model is seeded with centroids from ``_init_centroids``. It is
    then refined by alternating two steps (this looks like a k-means / Lloyd
    iteration): refit the MDM on the current labels, then reassign every
    matrix to its nearest centroid.

    NOTE(review): this snippet is truncated. The loop's exit condition
    (presumably using ``tol``, ``max_iter`` and the counter ``k``) and the
    return value are not visible here; confirm against the full source.

    Parameters
    ----------
    X : ndarray
        Matrices to cluster. Each element is passed to
        ``numpy.linalg.norm(x, ord='fro')``, so presumably the shape is
        (n_trials, n_channels, n_channels) — TODO confirm.
    y : array-like | None
        If given, ``numpy.unique(y)`` is used as the cluster labels;
        otherwise the labels are ``numpy.arange(n_clusters)``.
    n_clusters : int
        Number of centroids requested from ``_init_centroids``.
    init : default 'random'
        Initialisation strategy forwarded to ``_init_centroids``.
    random_state : default None
        Forwarded to ``_init_centroids``.
    metric : default 'riemann'
        Metric of the internal ``MDM``.
    max_iter, tol :
        Not used in the visible part of the body.
    n_jobs : default 1
        Forwarded to the internal ``MDM``.
    """
    # random_state is forwarded to _init_centroids below
    mdm = MDM(metric=metric, n_jobs=n_jobs)
    # Squared Frobenius norm of each matrix, forwarded as x_squared_norms.
    # The "nomrs" typo in the local name is kept to leave the code unchanged.
    squared_nomrs = [numpy.linalg.norm(x, ord='fro')**2 for x in X]
    # The initial centroids are injected directly instead of calling fit().
    mdm.covmeans_ = _init_centroids(X, n_clusters, init,
                                    random_state=random_state,
                                    x_squared_norms=squared_nomrs)
    # classes_ is also set by hand because fit() was bypassed
    # (presumably predict() uses it to map centroid indices to labels).
    if y is not None:
        mdm.classes_ = numpy.unique(y)
    else:
        mdm.classes_ = numpy.arange(n_clusters)
    # Initial assignment of every matrix to the seeded centroids.
    labels = mdm.predict(X)
    k = 0  # iteration counter; not used in the visible part of the body
    while True:
        # Keep the previous assignment (presumably for a convergence check).
        old_labels = labels.copy()
        # Update step: refit the MDM (its centroids) on the current labels.
        mdm.fit(X, old_labels)
        # Assignment step: index of the nearest centroid for every matrix,
        # mapped back to a label through classes_.
        dist = mdm._predict_distances(X)
        labels = mdm.classes_[dist.argmin(axis=1)]
def fit(self, X, y=None):
    """Fit the potato from covariance matrices.

    Parameters
    ----------
    X : ndarray, shape (n_trials, n_channels, n_channels)
        ndarray of SPD matrices.
    y : ndarray | None (default None)
        Optional labels. When given, ``y`` must:

        - have the same length as ``X``;
        - contain at most two distinct values (after casting to int32);
        - include ``self.pos_label``.

        Trials labelled ``self.pos_label`` get 1 in ``y_old`` and the others
        get 0. When None, every trial gets 1.

    Returns
    -------
    self : Potato instance
        The Potato instance.

    Raises
    ------
    ValueError
        If ``y`` has a different length than ``X``, has more than two
        classes, or contains no ``self.pos_label`` sample.

    NOTE(review): this snippet is truncated. ``y_old`` (presumably the
    initial inlier mask) is computed, but the code that uses it and the
    return statement are not visible here.
    """
    # A new MDM is built on every call, using the instance's metric.
    self._mdm = MDM(metric=self.metric)
    if y is not None:
        if len(y) != len(X):
            # NOTE(review): the "lenght" typo in this message is left as-is.
            raise ValueError('y must be the same lenght of X')
        # Labels are cast to int32 before the checks below.
        classes = numpy.int32(numpy.unique(y))
        if len(classes) > 2:
            raise ValueError('number of classes must be maximum 2')
        if self.pos_label not in classes:
            raise ValueError('y must contain a positive class')
        # Binary int32 vector: 1 where y equals pos_label, 0 elsewhere.
        y_old = numpy.int32(numpy.array(y) == self.pos_label)
    else:
        # No labels: every trial gets 1 (as float64 ones).
        y_old = numpy.ones(len(X))
def __init__(self, mean_metric='riemann', dist_metric='riemann', n_jobs=1, name="MDM"):
    """Store the metric/parallelism settings and build the wrapped model.

    Parameters
    ----------
    mean_metric : default 'riemann'
        Passed under the ``'mean'`` key of the wrapped model's ``metric``.
    dist_metric : default 'riemann'
        Passed under the ``'distance'`` key of the wrapped model's ``metric``.
    n_jobs : default 1
        Forwarded to the wrapped model.
    name : default "MDM"
        Forwarded to the parent class constructor.
    """
    super(MDM, self).__init__(name)
    self.mean_metric, self.dist_metric, self.n_jobs = (
        mean_metric, dist_metric, n_jobs)
    # riemann_MDM gets separate metrics for the mean and for the distance.
    self.model = riemann_MDM(
        metric=dict(mean=self.mean_metric, distance=self.dist_metric),
        n_jobs=self.n_jobs,
    )
def __init_transform(self, X):
    """Build the internal MDM and prepare ``X`` for the configured mode.

    ``self.mdm`` is always recreated from ``self.metric`` and ``self.n_jobs``.

    - ``'ftest'``: the mean of ``X`` (under the MDM's mean metric) is stored
      on ``self.global_mean`` and ``X`` is returned unchanged.
    - ``'pairwise'``: the squared pairwise distances of ``X`` (under the
      MDM's distance metric) are returned instead of ``X``.
    - Any other mode: ``X`` is returned untouched.
    """
    self.mdm = MDM(metric=self.metric, n_jobs=self.n_jobs)
    model = self.mdm
    if self.mode == 'ftest':
        self.global_mean = mean_covariance(X, metric=model.metric_mean)
        return X
    if self.mode == 'pairwise':
        return pairwise_distance(X, metric=model.metric_dist) ** 2
    return X
def transform(self, X):
    """Return the distance of each matrix in ``X`` to every stored centroid.

    A throw-away MDM, configured with ``self.metric`` and the ``n_jobs`` of
    ``self.km``, receives ``self.covmeans_`` as its centroids. It is used
    only for its distance computation and is never fitted.
    """
    scorer = MDM(metric=self.metric, n_jobs=self.km.n_jobs)
    scorer.covmeans_ = self.covmeans_
    distances = scorer._predict_distances(X)
    return distances
def __init__(self,
             n_perms=100,
             model=None,
             cv=3,
             scoring=None,
             n_jobs=1,
             random_state=42):
    """Init.

    Parameters
    ----------
    n_perms : default 100
        Number of permutations; stored as ``self.n_perms``.
    model : estimator | None, default None
        Estimator to evaluate. ``None`` (the default) creates a new ``MDM()``
        for this instance. The previous default, ``model=MDM()``, was built
        once when the function was defined. Every instance created without
        an explicit model therefore shared — and could mutate — the same
        MDM object.
    cv : default 3
        Stored as ``self.cv`` (presumably a fold count or a CV splitter).
    scoring : default None
        Stored as ``self.scoring``.
    n_jobs : default 1
        Stored as ``self.n_jobs``.
    random_state : default 42
        Stored as ``self.random_state``.
    """
    self.n_perms = n_perms
    # A fresh MDM per instance, so fitting one object cannot leak state
    # into another.
    self.model = MDM() if model is None else model
    self.cv = cv
    self.scoring = scoring
    self.n_jobs = n_jobs
    self.random_state = random_state