Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""Symmetric (Hermitian) eigendecomposition."""
a_shape = c.GetShape(a)
dtype = a_shape.element_type()
dims = a_shape.dimensions()
assert len(dims) >= 2
m, n = dims[-2:]
assert m == n
batch_dims = tuple(dims[:-2])
num_bd = len(batch_dims)
batch = _prod(batch_dims)
layout = (num_bd, num_bd + 1) + tuple(range(num_bd - 1, -1, -1))
if n <= 32:
kernel = b"cusolver_syevj"
lwork, opaque = cusolver_kernels.build_syevj_descriptor(
np.dtype(dtype), lower, batch, n)
else:
kernel = b"cusolver_syevd"
lwork, opaque = cusolver_kernels.build_syevd_descriptor(
np.dtype(dtype), lower, batch, n)
eigvals_type = _real_type(dtype)
out = c.CustomCall(
kernel,
operands=(a,),
shape_with_layout=_Shape.tuple_shape((
_Shape.array_shape(dtype, dims, layout),
_Shape.array_shape(
np.dtype(eigvals_type), batch_dims + (n,),
tuple(range(num_bd, -1, -1))),
_Shape.array_shape(