How to use the geomstats.backend.einsum function in geomstats

To help you get started, we’ve selected a few geomstats examples, based on popular ways `geomstats.backend.einsum` is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

GitHub: geomstats/geomstats — geomstats/special_orthogonal_group.py (view on GitHub)
def compose(self, point_1, point_2, point_type=None):
        """Compose two elements of SO(n).

        Both inputs are passed through ``self.regularize``. When they are
        rotation vectors they are converted to rotation matrices. The
        matrices are multiplied as a batch, the product is converted back
        to the input representation, and it is regularized again.

        Parameters
        ----------
        point_1 : array-like
            Left factor of the composition.
        point_2 : array-like
            Right factor of the composition.
        point_type : str, optional
            Representation of the points. ``'vector'`` selects the
            rotation-vector representation; any other value is forwarded
            to ``self.regularize`` and the points are used as matrices.
            Defaults to ``self.default_point_type``.

        Returns
        -------
        point_prod : array-like
            The composition of ``point_1`` with ``point_2``, expressed in
            the same representation as the inputs.
        """
        if point_type is None:
            point_type = self.default_point_type
        as_vectors = point_type == 'vector'

        # Regularize the left factor first, then the right one.
        left, right = (
            self.regularize(point, point_type=point_type)
            for point in (point_1, point_2))

        if as_vectors:
            left = self.matrix_from_rotation_vector(left)
            right = self.matrix_from_rotation_vector(right)

        # Batched matrix product: composed[i] = left[i] @ right[i].
        composed = gs.einsum('ijk,ikl->ijl', left, right)

        if as_vectors:
            composed = self.rotation_vector_from_matrix(composed)

        return self.regularize(composed, point_type=point_type)
GitHub: geomstats/geomstats — geomstats/geometry/spd_matrices.py (view on GitHub)
def _aux_inner_product(tangent_vec_a, tangent_vec_b, inv_base_point):
        """Compute the inner-product (auxiliary).

        Parameters
        ----------
        tangent_vec_a : array-like, shape=[..., n, n]
        tangent_vec_b : array-like, shape=[..., n, n]
        inv_base_point : array-like, shape=[..., n, n]

        Returns
        -------
        inner_product : array-like, shape=[..., n, n]
        """
        aux_a = gs.einsum(
            '...ij,...jk->...ik', inv_base_point, tangent_vec_a)
        aux_b = gs.einsum(
            '...ij,...jk->...ik', inv_base_point, tangent_vec_b)
        prod = gs.einsum(
            '...ij,...jk->...ik', aux_a, aux_b)
        inner_product = gs.trace(prod, axis1=-2, axis2=-1)
        return inner_product
GitHub: geomstats/geomstats — geomstats/special_orthogonal_group.py (view on GitHub)
mask_else = ~mask_0
            mask_else_float = gs.cast(mask_else, gs.float32)

            # This avoids division by 0.
            angle += mask_0_float * 1.

            coef_1 += mask_else_float * (gs.sin(angle) / angle)
            coef_2 += mask_else_float * (
                (1 - gs.cos(angle)) / (angle ** 2))

            term_1 = gs.zeros((n_rot_vecs,) + (self.n,) * 2)
            term_2 = gs.zeros_like(term_1)

            coef_1 = gs.squeeze(coef_1, axis=0)
            term_1 = (gs.eye(self.dimension)
                      + gs.einsum('n,njk->njk', coef_1, skew_rot_vec))

            term_2 = (coef_2
                      + gs.einsum('nij,njk->nik', skew_rot_vec, skew_rot_vec))
            #for i in range(n_rot_vecs):
            #    term_1[i] = (gs.eye(self.dimension)
            #                 + coef_1[i] * skew_rot_vec[i])
            #    term_2[i] = (coef_2[i]
            #                 * gs.matmul(skew_rot_vec[i], skew_rot_vec[i]))
            rot_mat = term_1 + term_2

            rot_mat = self.projection(rot_mat)

        else:
            skew_mat = self.skew_matrix_from_vector(rot_vec)
            rot_mat = self.embedding_manifold.group_exp_from_identity(skew_mat)
GitHub: geomstats/geomstats — geomstats/geometry/riemannian_metric.py (view on GitHub)
def while_loop_body(iteration, mean, variance, sq_dist):
                print('pass while loop body')
                print('mean', mean)
                print('points', points)
                logs = self.log(point=points, base_point=mean)
                print('logs', logs, logs.shape)

                tangent_mean = gs.einsum('nk,nj->j', weights, logs)

                print('tangent mean', tangent_mean)

                tangent_mean /= sum_weights

                mean_next = self.exp(
                    tangent_vec=tangent_mean,
                    base_point=mean)

                print('Next mean', mean_next)

                sq_dist = self.squared_dist(mean_next, mean)
                sq_dists_between_iterates.append(sq_dist)

                variance = self.variance(points=points,
                                         weights=weights,
GitHub: geomstats/geomstats — geomstats/geometry/spd_matrices.py (view on GitHub)
inv_sqrt_base_point : array-like, shape=[.., n, n]

        Returns
        -------
        log : array-like, shape=[..., n, n]
        """
        point_near_id = gs.einsum(
            '...ij,...jk->...ik', inv_sqrt_base_point, point)
        point_near_id = gs.einsum(
            '...ij,...jk->...ik', point_near_id, inv_sqrt_base_point)
        point_near_id = GeneralLinear.to_symmetric(point_near_id)
        log_at_id = SPDMatrices.logm(point_near_id)

        log = gs.einsum(
            '...ij,...jk->...ik', sqrt_base_point, log_at_id)
        log = gs.einsum(
            '...ij,...jk->...ik', log, sqrt_base_point)
        return log
GitHub: geomstats/geomstats — examples/loss_and_gradient_se3.py (view on GitHub)
differential_scalar_t = gs.transpose(differential_scalar, axes=(1, 0))

        upper_left_block = gs.hstack(
            (differential_scalar_t, differential_vec[0]))
        upper_right_block = gs.zeros((3, 3))
        lower_right_block = gs.eye(3)
        lower_left_block = gs.zeros((3, 4))

        top = gs.hstack((upper_left_block, upper_right_block))
        bottom = gs.hstack((lower_left_block, lower_right_block))

        differential = gs.vstack((top, bottom))
        differential = gs.expand_dims(differential, axis=0)

        grad = gs.einsum('ni,nij->ni', grad, differential)

    grad = gs.squeeze(grad, axis=0)
    return grad
GitHub: geomstats/geomstats — geomstats/geometry/hypersphere.py (view on GitHub)
Tangent vector at base point to be transported.
        tangent_vec_b : array-like, shape=[..., dim + 1]
            Tangent vector at base point, along which the parallel transport
            is computed.
        base_point : array-like, shape=[..., dim + 1]
            Point on the hypersphere.

        Returns
        -------
        transported_tangent_vec: array-like, shape=[..., dim + 1]
            Transported tangent vector at exp_(base_point)(tangent_vec_b).
        """
        theta = gs.linalg.norm(tangent_vec_b, axis=-1)
        normalized_b = gs.einsum('..., ...i->...i', 1 / theta, tangent_vec_b)
        pb = gs.einsum('...i,...i->...', tangent_vec_a, normalized_b)
        p_orth = tangent_vec_a - gs.einsum('..., ...i->...i', pb, normalized_b)
        transported = \
            - gs.einsum('..., ...i->...i', gs.sin(theta) * pb, base_point)\
            + gs.einsum('..., ...i->...i', gs.cos(theta) * pb, normalized_b)\
            + p_orth
        return transported
GitHub: geomstats/geomstats — geomstats/geometry/hyperbolic_space.py (view on GitHub)
+ COSH_TAYLOR_COEFFS[4] * norm_tangent_vec ** 4
                      + COSH_TAYLOR_COEFFS[6] * norm_tangent_vec ** 6
                      + COSH_TAYLOR_COEFFS[8] * norm_tangent_vec ** 8)
            coef_2 += mask_0_float * (
                      1. + SINH_TAYLOR_COEFFS[3] * norm_tangent_vec ** 2
                      + SINH_TAYLOR_COEFFS[5] * norm_tangent_vec ** 4
                      + SINH_TAYLOR_COEFFS[7] * norm_tangent_vec ** 6
                      + SINH_TAYLOR_COEFFS[9] * norm_tangent_vec ** 8)
            # This avoids dividing by 0.
            norm_tangent_vec += mask_0_float * 1.0
            coef_1 += mask_else_float * (gs.cosh(norm_tangent_vec))
            coef_2 += mask_else_float * (
                (gs.sinh(norm_tangent_vec) / (norm_tangent_vec)))

            exp = (gs.einsum('ni,nj->nj', coef_1, base_point)
                   + gs.einsum('ni,nj->nj', coef_2, tangent_vec))

            hyperbolic_space = HyperbolicSpace(dimension=self.dimension)
            exp = hyperbolic_space.regularize(exp)
            return exp

        elif self.point_type == 'ball':

            norm_base_point = base_point.norm(2,
                                              -1, keepdim=True).expand_as(
                                                base_point)

            lambda_base_point = 1 / (1 - norm_base_point ** 2)

            norm_tangent_vector = tangent_vec.norm(2,
                                                   -1, keepdim=True).expand_as(
                                                    tangent_vec)
GitHub: geomstats/geomstats — geomstats/geometry/special_orthogonal.py (view on GitHub)
mask_else_float = gs.cast(mask_else, gs.float32) + self.epsilon

        regularized_vec = gs.zeros_like(tangent_vec)
        regularized_vec += gs.einsum(
            '...,...i->...i', mask_0_float, tangent_vec)

        tangent_vec_canonical_norm += mask_0_float

        coef = gs.zeros_like(tangent_vec_metric_norm)
        coef += mask_else_float * (
            tangent_vec_metric_norm
            / tangent_vec_canonical_norm)

        coef_tangent_vec = gs.einsum(
            '...,...i->...i', coef, tangent_vec)
        regularized_vec += gs.einsum(
            '...,...i->...i',
            mask_else_float,
            self.regularize(coef_tangent_vec))

        coef += mask_0_float
        regularized_vec = gs.einsum(
            '...,...i->...i', 1. / coef, regularized_vec)
        regularized_vec = gs.einsum(
            '...,...i->...i', mask_else_float, regularized_vec)
        return regularized_vec
GitHub: geomstats/geomstats — geomstats/geometry/invariant_metric.py (view on GitHub)
Parameters
        ----------
        point : array-like, shape=[..., dim]
            Point in the group.

        Returns
        -------
        log : array-like, shape=[..., dim]
            Tangent vector at the identity equal to the Riemannian logarithm
            of point at the identity.
        """
        point = self.group.regularize(point)
        inner_prod_mat = self.inner_product_mat_at_identity
        inv_inner_prod_mat = GeneralLinear.inverse(inner_prod_mat)
        sqrt_inv_inner_prod_mat = gs.linalg.sqrtm(inv_inner_prod_mat)
        log = gs.einsum('...i,...ij->...j', point, sqrt_inv_inner_prod_mat)
        log = self.group.regularize_tangent_vec_at_identity(
            tangent_vec=log, metric=self)
        return log