How to use the pyriemann.classification.MDM function in pyriemann

To help you get started, we’ve selected a few pyriemann examples based on popular ways MDM is used in public projects.
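For orientation, here is a minimal, self-contained sketch of the typical MDM workflow: estimate one covariance matrix per trial, then fit and predict on those SPD matrices. The data shapes, the random data, and the 'lwf' (Ledoit-Wolf) covariance estimator are illustrative assumptions, not taken from the examples below.

import numpy as np
from pyriemann.estimation import Covariances
from pyriemann.classification import MDM

# illustrative data: 20 trials, 8 channels, 128 time samples
X = np.random.randn(20, 8, 128)
y = np.array([0, 1]).repeat(10)

# one SPD covariance matrix per trial (Ledoit-Wolf shrinkage, an arbitrary choice here)
covs = Covariances(estimator='lwf').fit_transform(X)

mdm = MDM(metric='riemann')
mdm.fit(covs, y)
pred = mdm.predict(covs)         # class labels
proba = mdm.predict_proba(covs)  # soft assignments from distances to the class means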

From alexandrebarachant/pyRiemann, tests/test_classification.py:
def test_MDM_predict():
    """Test prediction of MDM"""
    covset = generate_cov(100, 3)
    labels = np.array([0, 1]).repeat(50)
    mdm = MDM(metric='riemann')
    mdm.fit(covset, labels)
    mdm.predict(covset)

    # test fit_predict
    mdm = MDM(metric='riemann')
    mdm.fit_predict(covset, labels)

    # test transform
    mdm.transform(covset)

    # predict proba
    mdm.predict_proba(covset)

    # test n_jobs
    mdm = MDM(metric='riemann', n_jobs=2)
    mdm.fit(covset, labels)
    mdm.predict(covset)
From alexandrebarachant/pyRiemann, examples/motor-imagery/plot_single.py:
lr = LogisticRegression()
csp = CSP(n_components=4, reg='ledoit_wolf', log=True)

clf = Pipeline([('CSP', csp), ('LogisticRegression', lr)])
scores = cross_val_score(clf, epochs_data_train, labels, cv=cv, n_jobs=1)

# Printing the results
class_balance = np.mean(labels == labels[0])
class_balance = max(class_balance, 1. - class_balance)
print("CSP + LDA Classification accuracy: %f / Chance level: %f" %
      (np.mean(scores), class_balance))

###############################################################################
# Display MDM centroid

mdm = MDM()
mdm.fit(cov_data_train, labels)

fig, axes = plt.subplots(1, 2, figsize=[8, 4])
ch_names = [ch.replace('.', '') for ch in epochs.ch_names]

df = pd.DataFrame(data=mdm.covmeans_[0], index=ch_names, columns=ch_names)
g = sns.heatmap(
    df, ax=axes[0], square=True, cbar=False, xticklabels=2, yticklabels=2)
g.set_title('Mean covariance - hands')

df = pd.DataFrame(data=mdm.covmeans_[1], index=ch_names, columns=ch_names)
g = sns.heatmap(
    df, ax=axes[1], square=True, cbar=False, xticklabels=2, yticklabels=2)
plt.xticks(rotation='vertical')
plt.yticks(rotation='horizontal')
g.set_title('Mean covariance - feet')
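After fitting, the class centroids are available in mdm.covmeans_, one SPD matrix per class, ordered like mdm.classes_. A short sketch of inspecting them, assuming cov_data_train and labels from the example above:

mdm = MDM()
mdm.fit(cov_data_train, labels)

print(mdm.classes_)              # class labels, in the order used by covmeans_
print(len(mdm.covmeans_))        # one mean covariance per class
print(mdm.covmeans_[0].shape)    # (n_channels, n_channels)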
From alexandrebarachant/pyRiemann, examples/ERP/plot_classify_MEG_mdm.py:
labels = epochs.events[:, -1]
evoked = epochs.average()

###############################################################################
# Decoding with Xdawn + MDM

n_components = 3  # pick some components

# Define a cross-validation scheme (shuffled 10-fold):
cv = KFold(n_splits=10, shuffle=True, random_state=42)
pr = np.zeros(len(labels))
epochs_data = epochs.get_data()

print('Multiclass classification with XDAWN + MDM')

clf = make_pipeline(XdawnCovariances(n_components), MDM())

for train_idx, test_idx in cv.split(epochs_data):
    y_train, y_test = labels[train_idx], labels[test_idx]

    clf.fit(epochs_data[train_idx], y_train)
    pr[test_idx] = clf.predict(epochs_data[test_idx])

print(classification_report(labels, pr))

###############################################################################
# plot the spatial patterns
xd = XdawnCovariances(n_components)
xd.fit(epochs_data, labels)

evoked.data = xd.Xd_.patterns_.T
evoked.times = np.arange(evoked.data.shape[0])
From alexandrebarachant/pyRiemann, pyriemann/clustering.py:
def _fit_single(X, y=None, n_clusters=2, init='random', random_state=None,
                metric='riemann', max_iter=100, tol=1e-4, n_jobs=1):
    """helper to fit a single run of centroid."""
    # init random state if provided
    mdm = MDM(metric=metric, n_jobs=n_jobs)
    squared_nomrs = [numpy.linalg.norm(x, ord='fro')**2 for x in X]
    mdm.covmeans_ = _init_centroids(X, n_clusters, init,
                                    random_state=random_state,
                                    x_squared_norms=squared_nomrs)
    if y is not None:
        mdm.classes_ = numpy.unique(y)
    else:
        mdm.classes_ = numpy.arange(n_clusters)

    labels = mdm.predict(X)
    k = 0
    while True:
        old_labels = labels.copy()
        mdm.fit(X, old_labels)
        dist = mdm._predict_distances(X)
        labels = mdm.classes_[dist.argmin(axis=1)]
From alexandrebarachant/pyRiemann, pyriemann/clustering.py:
def fit(self, X, y=None):
        """Fit the potato from covariance matrices.

        Parameters
        ----------
        X : ndarray, shape (n_trials, n_channels, n_channels)
            ndarray of SPD matrices.
        y : ndarray | None (default None)
            Not used, here for compatibility with sklearn API.

        Returns
        -------
        self : Potato instance
            The Potato instance.
        """
        self._mdm = MDM(metric=self.metric)

        if y is not None:
            if len(y) != len(X):
                raise ValueError('y must have the same length as X')

            classes = numpy.int32(numpy.unique(y))

            if len(classes) > 2:
                raise ValueError('number of classes must be at most 2')

            if self.pos_label not in classes:
                raise ValueError('y must contain a positive class')

            y_old = numpy.int32(numpy.array(y) == self.pos_label)
        else:
            y_old = numpy.ones(len(X))
From ZhangXiao96/EEGAdversary, lib/Blocks.py:
def __init__(self, mean_metric='riemann', dist_metric='riemann', n_jobs=1, name="MDM"):
        super(MDM, self).__init__(name)
        self.mean_metric = mean_metric
        self.dist_metric = dist_metric
        self.n_jobs = n_jobs
        metric = {'mean': self.mean_metric, 'distance': self.dist_metric}
        self.model = riemann_MDM(metric=metric, n_jobs=self.n_jobs)
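This wrapper also shows that pyriemann's MDM accepts a dict for metric, with separate entries for the class means and for the distance computation. A minimal sketch of that form (the metric values here are illustrative):

from pyriemann.classification import MDM

mdm = MDM(metric={'mean': 'logeuclid', 'distance': 'riemann'}, n_jobs=2)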
From alexandrebarachant/pyRiemann, pyriemann/stats.py:
def __init_transform(self, X):
        """Init tr"""
        self.mdm = MDM(metric=self.metric, n_jobs=self.n_jobs)
        if self.mode == 'ftest':
            self.global_mean = mean_covariance(X, metric=self.mdm.metric_mean)
        elif self.mode == 'pairwise':
            X = pairwise_distance(X, metric=self.mdm.metric_dist)**2
        return X
From alexandrebarachant/pyRiemann, pyriemann/clustering.py:
def transform(self, X):
        """transform."""
        mdm = MDM(metric=self.metric, n_jobs=self.km.n_jobs)
        mdm.covmeans_ = self.covmeans_
        return mdm._predict_distances(X)
From alexandrebarachant/pyRiemann, pyriemann/stats.py:
def __init__(self,
                 n_perms=100,
                 model=MDM(),
                 cv=3,
                 scoring=None,
                 n_jobs=1,
                 random_state=42):
        """Init."""
        self.n_perms = n_perms
        self.model = model
        self.cv = cv
        self.scoring = scoring
        self.n_jobs = n_jobs
        self.random_state = random_state