Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def check_linop_adjoint(A, dtype=float, device=sp.cpu_device, *,
                        atol=1e-5, rtol=1e-5):
    """Assert that ``A.H`` behaves as the adjoint of the linear operator ``A``.

    Random vectors ``x`` (shape ``A.ishape``) and ``y`` (shape ``A.oshape``)
    are drawn on ``device``, and the adjoint identity
    ``vdot(A x, y) == vdot(x, A.H y)`` is checked. ``vdot`` conjugates its
    first argument, so both sides are the same inner product.

    Args:
        A: Linear operator exposing ``ishape``, ``oshape``, ``A * x`` and
            ``A.H``.
        dtype: Data type of the random test vectors. The default is the
            builtin ``float``, which is exactly what the ``np.float`` alias
            referred to. That alias was removed in NumPy 1.24 and made this
            signature fail at import time.
        device: Device (or anything ``sp.Device`` accepts) on which the test
            vectors are allocated.
        atol: Absolute tolerance of the comparison (keyword-only).
        rtol: Relative tolerance of the comparison (keyword-only).

    Raises:
        AssertionError: If the two inner products differ beyond tolerance.
    """
    device = sp.Device(device)
    xp = device.xp
    x = sp.randn(A.ishape, dtype=dtype, device=device)
    y = sp.randn(A.oshape, dtype=dtype, device=device)
    with device:
        lhs = xp.vdot(A * x, y)
        rhs = xp.vdot(x, A.H * y)
        xp.testing.assert_allclose(lhs, rhs, atol=atol, rtol=rtol)
def check_linop_adjoint(A, dtype=float, device=sp.cpu_device, *,
                        atol=1e-5, rtol=1e-5):
    """Assert that ``A.H`` behaves as the adjoint of the linear operator ``A``.

    Random vectors ``x`` (shape ``A.ishape``) and ``y`` (shape ``A.oshape``)
    are drawn on ``device``, and the adjoint identity
    ``vdot(A x, y) == vdot(x, A.H y)`` is checked. ``vdot`` conjugates its
    first argument, so both sides are the same inner product.

    Args:
        A: Linear operator exposing ``ishape``, ``oshape``, ``A * x`` and
            ``A.H``.
        dtype: Data type of the random test vectors. The default is the
            builtin ``float``, which is exactly what the ``np.float`` alias
            referred to. That alias was removed in NumPy 1.24 and made this
            signature fail at import time.
        device: Device (or anything ``sp.Device`` accepts) on which the test
            vectors are allocated.
        atol: Absolute tolerance of the comparison (keyword-only).
        rtol: Relative tolerance of the comparison (keyword-only).

    Raises:
        AssertionError: If the two inner products differ beyond tolerance.
    """
    device = sp.Device(device)
    xp = device.xp
    x = sp.randn(A.ishape, dtype=dtype, device=device)
    y = sp.randn(A.oshape, dtype=dtype, device=device)
    with device:
        lhs = xp.vdot(A * x, y)
        rhs = xp.vdot(x, A.H * y)
        xp.testing.assert_allclose(lhs, rhs, atol=atol, rtol=rtol)
def test_noncart_sense_model_batch(self):
    """Check the non-Cartesian SENSE operator with ``coil_batch_size=1``.

    Every point of the image grid is used as a sampling coordinate, centred
    at zero. The test verifies the adjoint identity, then compares the
    forward model against an FFT of the coil-weighted image with a loose
    tolerance (atol/rtol 0.1).
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # Builtin ``complex``/``float`` replace the ``np.complex``/``np.float``
    # aliases, which were removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=complex)
    mps = sp.randn(mps_shape, dtype=complex)

    # One coordinate per grid point, shifted so the grid centre is at 0.
    ny, nx = img_shape
    y, x = np.mgrid[:ny, :nx]
    coord = np.stack([np.ravel(y - ny // 2), np.ravel(x - nx // 2)], axis=1)
    coord = coord.astype(float)

    A = linop.Sense(mps, coord=coord, coil_batch_size=1)
    check_linop_adjoint(A, dtype=complex)

    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]).ravel(),
                        (A * img).ravel(), atol=0.1, rtol=0.1)
def test_noncart_sense_model(self):
    """Check the non-Cartesian SENSE operator (no coil batching).

    Every point of the image grid is used as a sampling coordinate, centred
    at zero. The test verifies the adjoint identity, then compares the
    forward model against an FFT of the coil-weighted image with a loose
    tolerance (atol/rtol 0.1).
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # Builtin ``complex``/``float`` replace the ``np.complex``/``np.float``
    # aliases, which were removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=complex)
    mps = sp.randn(mps_shape, dtype=complex)

    # One coordinate per grid point, shifted so the grid centre is at 0.
    ny, nx = img_shape
    y, x = np.mgrid[:ny, :nx]
    coord = np.stack([np.ravel(y - ny // 2), np.ravel(x - nx // 2)], axis=1)
    coord = coord.astype(float)

    A = linop.Sense(mps, coord=coord)
    check_linop_adjoint(A, dtype=complex)

    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]).ravel(),
                        (A * img).ravel(), atol=0.1, rtol=0.1)
def test_sense_model_batch(self):
    """Check the Cartesian SENSE operator with ``coil_batch_size=1``.

    The test verifies the adjoint identity, then checks that ``A * img``
    equals the FFT (last two axes) of the coil-weighted image.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # Builtin ``complex`` replaces the ``np.complex`` alias, which was
    # removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=complex)
    mps = sp.randn(mps_shape, dtype=complex)

    # The former ``mask`` array was built here but never used; it is gone.
    A = linop.Sense(mps, coil_batch_size=1)
    check_linop_adjoint(A, dtype=complex)

    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]),
                        A * img)
def test_noncart_sense_model(self):
    """Check the non-Cartesian SENSE operator (no coil batching).

    Every point of the image grid is used as a sampling coordinate, centred
    at zero. The test verifies the adjoint identity, then compares the
    forward model against an FFT of the coil-weighted image with a loose
    tolerance (atol/rtol 0.1).
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]

    # Builtin ``complex``/``float`` replace the ``np.complex``/``np.float``
    # aliases, which were removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=complex)
    mps = sp.randn(mps_shape, dtype=complex)

    # One coordinate per grid point, shifted so the grid centre is at 0.
    ny, nx = img_shape
    y, x = np.mgrid[:ny, :nx]
    coord = np.stack([np.ravel(y - ny // 2), np.ravel(x - nx // 2)], axis=1)
    coord = coord.astype(float)

    A = linop.Sense(mps, coord=coord)
    check_linop_adjoint(A, dtype=complex)

    npt.assert_allclose(sp.fft(img * mps, axes=[-1, -2]).ravel(),
                        (A * img).ravel(), atol=0.1, rtol=0.1)
def test_sense_model_with_comm(self):
    """Check a distributed SENSE adjoint against a locally computed one.

    Each rank builds ``linop.Sense`` from its strided slice of coil maps
    (``mps[comm.rank::comm.size]``) with the communicator attached. Applying
    ``A.H`` to the matching slice of k-space must equal the coil-combined
    inverse FFT over all coils, computed here without the communicator.
    """
    img_shape = [16, 16]
    mps_shape = [8, 16, 16]
    comm = sp.Communicator()

    # Builtin ``complex`` replaces the ``np.complex`` alias, which was
    # removed in NumPy 1.24.
    img = sp.randn(img_shape, dtype=complex)
    mps = sp.randn(mps_shape, dtype=complex)
    # NOTE(review): presumably reduces in place so every rank shares the
    # same img/mps; confirm sp.Communicator.allreduce semantics.
    comm.allreduce(img)
    comm.allreduce(mps)

    ksp = sp.fft(img * mps, axes=[-1, -2])
    A = linop.Sense(mps[comm.rank::comm.size], comm=comm)

    expected = np.sum(sp.ifft(ksp, axes=[-1, -2]) * mps.conjugate(), 0)
    npt.assert_allclose(A.H(ksp[comm.rank::comm.size]), expected)
def _get_vars(self):
    """Allocate working buffers and randomly initialise ``self.L``.

    Sets the following attributes:

    * ``self.t_idx``: ``sp.ShuffledNumbers(self.num_batches)``.
    * ``self.y_t``: uninitialised array of shape
      ``(self.batch_size,) + self.y.shape[1:]``, allocated on ``self.device``.
    * ``self.L``: random array of shape ``self.L_shape``. It is divided by
      its l2 norm over the last ``self.data_ndim`` axes, and also over
      axis 0 when ``self.multi_channel`` is true.
    * ``self.L_old``: uninitialised array of shape ``self.L_shape``.
    * ``self.R``: ``ConvSparseCoefficients`` built from ``self.y`` and the
      freshly normalised ``self.L``.
    """
    self.t_idx = sp.ShuffledNumbers(self.num_batches)
    xp = self.device.xp
    with self.device:
        batch_shape = (self.batch_size, ) + self.y.shape[1:]
        self.y_t = xp.empty(batch_shape, dtype=self.dtype)

        self.L = sp.randn(self.L_shape, dtype=self.dtype, device=self.device)

        # The norm is always taken over the trailing data axes; in
        # multi-channel mode the leading axis 0 is included as well.
        norm_axes = tuple(range(-self.data_ndim, 0))
        if self.multi_channel:
            norm_axes = (0, ) + norm_axes
        energy = xp.sum(xp.abs(self.L)**2, axis=norm_axes, keepdims=True)
        self.L /= energy**0.5

        self.L_old = xp.empty(self.L_shape, dtype=self.dtype)
        self.R = ConvSparseCoefficients(
            self.y, self.L,
            lamda=self.lamda,
            multi_channel=self.multi_channel,
            mode=self.mode,
            max_iter=self.max_inner_iter,
            max_power_iter=self.max_power_iter,
        )