Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
ob = KbInterpBack(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
ob = KbNufft(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
ob = AdjKbNufft(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
ob = MriSenseNufft(
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
ob = AdjMriSenseNufft(
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
# test 2d tuple inputs
im_sz = (256, 256)
smap = torch.randn(*((1,) + im_sz))
grid_sz = (512, 512)
n_shift = (128, 128)
numpoints = (6, 6)
table_oversamp = (2**10, 2**10)
kbwidth = (2.34, 2.34)
order = (0, 0)
norm = 'None'
ob = KbInterpForw(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
x = params_2d['x']
y = params_2d['y']
ktraj = params_2d['ktraj']
smap = params_2d['smap']
for device in device_list:
x = x.detach().to(dtype=dtype, device=device)
y = y.detach().to(dtype=dtype, device=device)
ktraj = ktraj.detach().to(dtype=dtype, device=device)
sensenufft_ob = MriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints
).to(dtype=dtype, device=device)
adjsensenufft_ob = AdjMriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints
).to(dtype=dtype, device=device)
real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
interp_mats = {
'real_interp_mats': real_mat,
'imag_interp_mats': imag_mat
}
y.requires_grad = True
x = adjsensenufft_ob.forward(y, ktraj, interp_mats)
((x ** 2) / 2).sum().backward()
y_grad = y.grad.clone().detach()
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
ob = KbInterpBack(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order)
ob = KbNufft(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
ob = AdjKbNufft(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
ob = MriSenseNufft(
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
ob = AdjMriSenseNufft(
smap=smap, im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
table_oversamp=table_oversamp, kbwidth=kbwidth, order=order, norm=norm)
# test 3d tuple inputs
im_sz = (10, 256, 256)
smap = torch.randn(*((1,) + im_sz))
grid_sz = (10, 512, 512)
n_shift = (5, 128, 128)
numpoints = (6, 6, 6)
table_oversamp = (2**10, 2**10, 2**10)
kbwidth = (2.34, 2.34, 2.34)
order = (0, 0, 0)
norm = 'None'
ob = KbInterpForw(
im_size=im_sz, grid_size=grid_sz, n_shift=n_shift, numpoints=numpoints,
cur_table = KbNufft(im_sz, order=order, kbwidth=kbwidth).table
check_tables(base_table, cur_table)
cur_table = KbInterpBack(
im_sz, order=order, kbwidth=kbwidth).table
check_tables(base_table, cur_table)
cur_table = KbInterpForw(
im_sz, order=order, kbwidth=kbwidth).table
check_tables(base_table, cur_table)
cur_table = MriSenseNufft(
smap, im_sz, order=order, kbwidth=kbwidth).table
check_tables(base_table, cur_table)
cur_table = AdjMriSenseNufft(
smap, im_sz, order=order, kbwidth=kbwidth).table
check_tables(base_table, cur_table)
x = params_3d['x']
y = params_3d['y']
ktraj = params_3d['ktraj']
smap = params_3d['smap']
for device in device_list:
x = x.detach().to(dtype=dtype, device=device)
y = y.detach().to(dtype=dtype, device=device)
ktraj = ktraj.detach().to(dtype=dtype, device=device)
sensenufft_ob = MriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints
).to(dtype=dtype, device=device)
adjsensenufft_ob = AdjMriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints
).to(dtype=dtype, device=device)
real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
interp_mats = {
'real_interp_mats': real_mat,
'imag_interp_mats': imag_mat
}
x.requires_grad = True
y = sensenufft_ob.forward(x, ktraj, interp_mats)
((y ** 2) / 2).sum().backward()
x_grad = x.grad.clone().detach()
y = params_2d['y']
ktraj = params_2d['ktraj']
smap = params_2d['smap']
for device in device_list:
x = x.detach().to(dtype=dtype, device=device)
y = y.detach().to(dtype=dtype, device=device)
ktraj = ktraj.detach().to(dtype=dtype, device=device)
sensenufft_ob = MriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints,
coilpack=True
).to(dtype=dtype, device=device)
adjsensenufft_ob = AdjMriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints,
coilpack=True
).to(dtype=dtype, device=device)
x_forw = sensenufft_ob(x, ktraj)
y_back = adjsensenufft_ob(y, ktraj)
inprod1 = inner_product(y, x_forw, dim=2)
inprod2 = inner_product(y_back, x, dim=2)
assert torch.norm(inprod1 - inprod2) < norm_tol
y = params_2d['y']
ktraj = params_2d['ktraj']
smap = params_2d['smap']
for device in device_list:
x = x.detach().to(dtype=dtype, device=device)
y = y.detach().to(dtype=dtype, device=device)
ktraj = ktraj.detach().to(dtype=dtype, device=device)
sensenufft_ob = MriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints,
coilpack=True
).to(dtype=dtype, device=device)
adjsensenufft_ob = AdjMriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints,
coilpack=True
).to(dtype=dtype, device=device)
real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
interp_mats = {
'real_interp_mats': real_mat,
'imag_interp_mats': imag_mat
}
x.requires_grad = True
y = sensenufft_ob.forward(x, ktraj, interp_mats)
((y ** 2) / 2).sum().backward()
y = params_2d['y']
ktraj = params_2d['ktraj']
smap = params_2d['smap']
for device in device_list:
x = x.detach().to(dtype=dtype, device=device)
y = y.detach().to(dtype=dtype, device=device)
ktraj = ktraj.detach().to(dtype=dtype, device=device)
sensenufft_ob = MriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints,
coilpack=True
).to(dtype=dtype, device=device)
adjsensenufft_ob = AdjMriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints,
coilpack=True
).to(dtype=dtype, device=device)
real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
interp_mats = {
'real_interp_mats': real_mat,
'imag_interp_mats': imag_mat
}
x_forw = sensenufft_ob(x, ktraj, interp_mats)
y_back = adjsensenufft_ob(y, ktraj, interp_mats)
inprod1 = inner_product(y, x_forw, dim=2)
x = params_2d['x']
y = params_2d['y']
ktraj = params_2d['ktraj']
smap = params_2d['smap']
for device in device_list:
x = x.detach().to(dtype=dtype, device=device)
y = y.detach().to(dtype=dtype, device=device)
ktraj = ktraj.detach().to(dtype=dtype, device=device)
sensenufft_ob = MriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints
).to(dtype=dtype, device=device)
adjsensenufft_ob = AdjMriSenseNufft(
smap=smap,
im_size=im_size,
numpoints=numpoints
).to(dtype=dtype, device=device)
x_forw = sensenufft_ob(x, ktraj)
y_back = adjsensenufft_ob(y, ktraj)
inprod1 = inner_product(y, x_forw, dim=2)
inprod2 = inner_product(y_back, x, dim=2)
assert torch.norm(inprod1 - inprod2) < norm_tol
num_nuffts = 5
else:
dtype = torch.float
if use_toep:
num_nuffts = 50
else:
num_nuffts = 20
cpudevice = torch.device('cpu')
image = image.to(dtype=dtype)
ktraj = ktraj.to(dtype=dtype)
smap = smap.to(dtype=dtype)
kbsense_ob = MriSenseNufft(smap=smap, im_size=im_size).to(
dtype=dtype, device=device)
adjkbsense_ob = AdjMriSenseNufft(
smap=smap, im_size=im_size).to(dtype=dtype, device=device)
adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)
# precompute toeplitz kernel if using toeplitz
if use_toep:
print('using toeplitz for forward/backward')
kern = calc_toep_kernel(adjkbsense_ob, ktraj)
toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)
# precompute the sparse interpolation matrices
if sparse_mats_flag:
print('using sparse interpolation matrices')
real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
interp_mats = {
'real_interp_mats': real_mat,