Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def _get_params(self):
    """Derive device, dtype, batching and array shapes from the inputs.

    Normalizes ``self.device`` to an ``sp.Device``. Records ``self.dtype`` and
    ``self.data_ndim`` from ``self.y``. Creates the checkpoint directory when
    ``self.checkpoint_path`` is set. Clamps ``self.batch_size`` to the number
    of samples and sets ``self.num_batches``. Finally sets the filter shape
    ``self.L_shape`` and the per-batch coefficient shape ``self.R_t_shape``.

    Previously ``self.R_t_shape`` was only assigned in ``'full'`` mode. Any
    other mode now gets a shape too, matching the companion ``_get_params``
    that computes ``R_shape``.
    """
    self.device = sp.Device(self.device)
    self.dtype = self.y.dtype
    # Spatial axes = all axes of y except the leading sample axis and, when
    # multi_channel is set (presumably a bool, used here as 0/1), the
    # channel axis.
    self.data_ndim = self.y.ndim - self.multi_channel - 1

    if self.checkpoint_path is not None:
        self.checkpoint_path = pathlib.Path(self.checkpoint_path)
        self.checkpoint_path.mkdir(parents=True, exist_ok=True)

    self.batch_size = min(len(self.y), self.batch_size)
    # Floor division: a trailing partial batch is not counted here.
    self.num_batches = len(self.y) // self.batch_size

    self.L_shape = [self.num_filters] + [self.filt_width] * self.data_ndim
    if self.multi_channel:
        # The channel count is read from axis 1 of y and prepended.
        self.L_shape = [self.y.shape[1]] + self.L_shape

    spatial = self.y.shape[-self.data_ndim:]
    if self.mode == 'full':
        # Coefficient maps are (filt_width - 1) smaller than y per axis.
        r_spatial = [i - self.filt_width + 1 for i in spatial]
    else:
        # Any other mode (presumably 'valid'): (filt_width - 1) larger.
        r_spatial = [i + self.filt_width - 1 for i in spatial]
    self.R_t_shape = [self.batch_size, self.num_filters] + r_spatial
def check_linop_adjoint(A, dtype=float, device=sp.cpu_device):
    """Assert that ``A.H`` behaves as the adjoint of the linear operator ``A``.

    Draws random ``x`` of shape ``A.ishape`` and ``y`` of shape ``A.oshape``.
    It then checks the inner-product identity ``vdot(A x, y) == vdot(x, A.H y)``;
    ``vdot`` conjugates its first argument.

    Args:
        A: linear operator exposing ``ishape``, ``oshape``, ``A * x`` and
            ``A.H`` (presumably a sigpy ``Linop``).
        dtype: dtype of the random test vectors. Defaults to the builtin
            ``float`` (float64). The former default ``np.float`` was only an
            alias of ``float``; NumPy deprecated it in 1.20 and removed it in
            1.24, which made merely defining this function raise.
        device: device on which the vectors are drawn and the check is run.

    Raises:
        AssertionError: if the two inner products differ beyond
            ``atol=1e-5, rtol=1e-5``.
    """
    device = sp.Device(device)
    x = sp.randn(A.ishape, dtype=dtype, device=device)
    y = sp.randn(A.oshape, dtype=dtype, device=device)

    xp = device.xp
    with device:
        lhs = xp.vdot(A * x, y)
        rhs = xp.vdot(x, A.H * y)
        xp.testing.assert_allclose(lhs, rhs, atol=1e-5, rtol=1e-5)
# NOTE(review): check_linop_adjoint is defined more than once in this file
# with the same body; this later definition is the one that takes effect.
# Consider keeping a single copy.
def check_linop_adjoint(A, dtype=float, device=sp.cpu_device):
    """Assert that ``A.H`` behaves as the adjoint of the linear operator ``A``.

    Draws random ``x`` of shape ``A.ishape`` and ``y`` of shape ``A.oshape``.
    It then checks the inner-product identity ``vdot(A x, y) == vdot(x, A.H y)``;
    ``vdot`` conjugates its first argument.

    Args:
        A: linear operator exposing ``ishape``, ``oshape``, ``A * x`` and
            ``A.H`` (presumably a sigpy ``Linop``).
        dtype: dtype of the random test vectors. Defaults to the builtin
            ``float`` (float64). The former default ``np.float`` was only an
            alias of ``float``; NumPy deprecated it in 1.20 and removed it in
            1.24, which made merely defining this function raise.
        device: device on which the vectors are drawn and the check is run.

    Raises:
        AssertionError: if the two inner products differ beyond
            ``atol=1e-5, rtol=1e-5``.
    """
    device = sp.Device(device)
    x = sp.randn(A.ishape, dtype=dtype, device=device)
    y = sp.randn(A.oshape, dtype=dtype, device=device)

    xp = device.xp
    with device:
        lhs = xp.vdot(A * x, y)
        rhs = xp.vdot(x, A.H * y)
        xp.testing.assert_allclose(lhs, rhs, atol=1e-5, rtol=1e-5)
Martin Uecker, Peng Lai, Mark J. Murphy, Patrick Virtue, Michael Elad,
John M. Pauly, Shreyas S. Vasanawala, and Michael Lustig
ESPIRiT - An Eigenvalue Approach to Autocalibrating Parallel MRI:
Where SENSE meets GRAPPA.
Magnetic Resonance in Medicine, 71:990-1001 (2014)
"""
# NOTE(review): excerpt of an ESPIRiT calibration routine. calib_width,
# kernel_width, thresh and device come from an enclosing scope that is not
# shown here.
# The leading axis of ksp is treated as the coil axis; the rest are image axes.
img_ndim = ksp.ndim - 1
num_coils = len(ksp)
with sp.get_device(ksp):
# Get calibration region: resize ksp to calib_width along every image axis
# (sp.resize presumably crops/pads about the centre - confirm), then move the
# result to the target device.
calib_shape = [num_coils] + [calib_width] * img_ndim
calib = sp.resize(ksp, calib_shape)
calib = sp.to_device(calib, device)
device = sp.Device(device)
xp = device.xp
with device:
# Get calibration matrix: sp.array_to_blocks presumably extracts every
# kernel_shape patch of calib at stride 1; the reshape then flattens each
# patch into one row of length prod(kernel_shape).
kernel_shape = [num_coils] + [kernel_width] * img_ndim
kernel_strides = [1] * (img_ndim + 1)
mat = sp.array_to_blocks(calib, kernel_shape, kernel_strides)
mat = mat.reshape([-1, sp.prod(kernel_shape)])
# Perform SVD on calibration matrix; keep only the right singular vectors
# whose singular value exceeds thresh times the largest singular value.
_, S, VH = xp.linalg.svd(mat, full_matrices=False)
VH = VH[S > thresh * S.max(), :]
# Get kernels: reshape each retained singular vector back to kernel_shape.
num_kernels = len(VH)
kernels = VH.reshape([num_kernels] + kernel_shape)
# Image shape = every axis of ksp after the coil axis.
img_shape = ksp.shape[1:]
Args:
mps (array): sensitivity maps of shape [num_coils] + image shape.
weights (array): k-space weights.
coord (array): k-space coordinates of shape [...] + [ndim].
lamda (float): regularization.
Returns:
array: k-space preconditioner of same shape as k-space.
"""
# NOTE(review): excerpt of a k-space preconditioner routine. mps, weights,
# coord and device come from an enclosing signature that is not shown here.
dtype = mps.dtype
if weights is not None:
weights = sp.to_device(weights, device)
device = sp.Device(device)
xp = device.xp
# Image shape = every axis of mps after the leading (coil) axis.
mps_shape = list(mps.shape)
img_shape = mps_shape[1:]
# Grid that is twice as large along every image axis.
img2_shape = [i * 2 for i in img_shape]
ndim = len(img_shape)
scale = sp.prod(img2_shape)**1.5 / sp.prod(img_shape)
with device:
if coord is None:
# No coordinates (presumably Cartesian sampling): idx selects every second
# grid point along each axis of the doubled grid.
idx = (slice(None, None, 2), ) * ndim
ones = xp.zeros(img2_shape, dtype=dtype)
if weights is None:
# Unweighted: set the selected (even-index) points to 1.
ones[idx] = 1
else:
# NOTE(review): the body of this `else:` branch is missing from this excerpt.
def _get_params(self):
    """Infer sizes from the data ``self.y`` and the filter bank ``self.L``.

    Sets the device, dtype, number of samples, filter width, number of
    filters and number of spatial data dimensions. Also sets
    ``self.R_shape``, the shape of the coefficient maps.
    """
    self.device = sp.Device(self.device)
    data, filters = self.y, self.L
    self.dtype = data.dtype
    self.num_data = len(data)
    self.filt_width = filters.shape[-1]
    # multi_channel is used as an index/offset (0 or 1): the filter-count axis
    # sits after the optional channel axis.
    self.num_filters = filters.shape[self.multi_channel]
    self.data_ndim = data.ndim - self.multi_channel - 1

    # In 'full' mode the coefficient maps are (filt_width - 1) smaller than the
    # data along each spatial axis; in any other mode they are that much larger.
    pad = self.filt_width - 1
    offset = -pad if self.mode == 'full' else pad
    spatial = [n + offset for n in data.shape[-self.data_ndim:]]
    self.R_shape = [self.num_data, self.num_filters] + spatial
def __init__(self, y,
             mps_ker_width=16, ksp_calib_width=24,
             lamda=0, device=sp.cpu_device, comm=None,
             weights=None, coord=None, max_iter=10,
             max_inner_iter=10, normalize=True, show_pbar=True):
    """Store the reconstruction settings, then build the data, variables and algorithm.

    Every argument except ``show_pbar`` is stored on the instance under the
    same name; ``device`` is first converted with ``sp.Device``. The dtype
    and ``len(y)`` of ``y`` are recorded as ``self.dtype`` and
    ``self.num_coils``.

    ``_get_data``, ``_get_vars`` and ``_get_alg`` are called in that order.
    The parent constructor then receives ``self.alg``.

    Args:
        y: input data; its length is recorded as the coil count.
        show_pbar (bool): whether to show a progress bar. When ``comm`` is
            given, it is forced off on every rank other than 0.
    """
    # Settings, stored verbatim for the _get_* helpers.
    self.y, self.weights, self.coord = y, weights, coord
    self.mps_ker_width, self.ksp_calib_width = mps_ker_width, ksp_calib_width
    self.lamda, self.normalize = lamda, normalize
    self.max_iter, self.max_inner_iter = max_iter, max_inner_iter
    self.device, self.comm = sp.Device(device), comm
    # Properties derived from the data.
    self.dtype, self.num_coils = y.dtype, len(y)

    # With a communicator, only rank 0 keeps a requested progress bar.
    if comm is not None and show_pbar:
        show_pbar = comm.rank == 0

    self._get_data()
    self._get_vars()
    self._get_alg()
    super().__init__(self.alg, show_pbar=show_pbar)
def use_device(self, device):
# Convert the requested device to an sp.Device and store it on the instance.
self.device = sp.Device(device)