Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
else:
x = x * torch.prod(grid_size)
# scaling coefficient multiply
while len(scaling_coef.shape) < len(x.shape):
scaling_coef = scaling_coef.unsqueeze(0)
# try to broadcast multiply - batch over coil if not enough memory
raise_error = False
try:
x = conj_complex_mult(x, scaling_coef, dim=2)
except RuntimeError as e:
if 'out of memory' in str(e) and not raise_error:
torch.cuda.empty_cache()
for coilind in range(x.shape[1]):
x[:, coilind, ...] = conj_complex_mult(
x[:, coilind:coilind + 1, ...], scaling_coef, dim=2)
raise_error = True
else:
raise e
except BaseException:
raise e
return x
# indexing locations
gridind = (kofflist + Jval.unsqueeze(1)).to(dtype)
distind = torch.round(
(tm - gridind) * L.unsqueeze(1)).to(dtype=int_type)
gridind = gridind.to(int_type)
arr_ind = torch.zeros((M,), dtype=int_type, device=device)
coef = torch.stack((
torch.ones(M, dtype=dtype, device=device),
torch.zeros(M, dtype=dtype, device=device)
))
for d in range(ndims): # spatial dimension
if conjcoef:
coef = conj_complex_mult(
coef,
table[d][:, distind[d, :] + centers[d]],
dim=0
)
else:
coef = complex_mult(
coef,
table[d][:, distind[d, :] + centers[d]],
dim=0
)
arr_ind = arr_ind + torch.remainder(gridind[d, :], dims[d]).view(-1) * \
torch.prod(dims[d + 1:])
return coef, arr_ind
Returns:
tensor: The images after forward and adjoint NUFFT of size
(1, 2) + im_size.
"""
# multiply sensitivities
x = complex_mult(x, smap, dim=1)
# Toeplitz NUFFT
x = fft_filter(
x.unsqueeze(0),
kern.unsqueeze(0),
norm=norm
).squeeze(0)
# conjugate sum
x = torch.sum(conj_complex_mult(x, smap, dim=1), dim=0, keepdim=True)
return x
x = x[tuple(map(slice, crop_starts, crop_ends))]
# scaling
if norm == 'ortho':
x = x * torch.sqrt(torch.prod(grid_size))
else:
x = x * torch.prod(grid_size)
# scaling coefficient multiply
while len(scaling_coef.shape) < len(x.shape):
scaling_coef = scaling_coef.unsqueeze(0)
# try to broadcast multiply - batch over coil if not enough memory
raise_error = False
try:
x = conj_complex_mult(x, scaling_coef, dim=2)
except RuntimeError as e:
if 'out of memory' in str(e) and not raise_error:
torch.cuda.empty_cache()
for coilind in range(x.shape[1]):
x[:, coilind, ...] = conj_complex_mult(
x[:, coilind:coilind + 1, ...], scaling_coef, dim=2)
raise_error = True
else:
raise e
except BaseException:
raise e
return x
interpob (dictionary): A NUFFT interpolation object.
interp_mats (dictionary, default=None): A dictionary of sparse
interpolation matrices. If not None, the NUFFT operation will use
the matrices for interpolation.
Returns:
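tensor: The coil-combined images after adjoint NUFFT, of size
(nbatch, 1, 2) + im_size. If smap is not a tensor, a list of
per-batch images, each of size (1, 2) + im_size, is returned instead.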
"""
# adjoint nufft
x = AdjKbNufftFunction.apply(y, om, interpob, interp_mats)
# conjugate sum
x = list(x)
for i in range(len(x)):
x[i] = torch.sum(conj_complex_mult(
x[i], smap[i], dim=1), dim=0, keepdim=True)
if isinstance(smap, torch.Tensor):
x = torch.stack(x)
return x