Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
for device in device_list:
    # Check, on every device, that the autograd gradient of sum(x**2 / 2)
    # with respect to y (where x is the adjoint operator applied to y)
    # matches the forward operator applied to x, up to norm_tol.
    placement = dict(dtype=dtype, device=device)

    # detach() cuts any graph left over from a previous iteration before
    # the tensors are moved to the device under test.
    x = x.detach().to(**placement)
    y = y.detach().to(**placement)
    ktraj = ktraj.detach().to(**placement)

    op_kwargs = dict(im_size=im_size, numpoints=numpoints)
    forw_ob = KbNufft(**op_kwargs).to(**placement)
    adj_ob = AdjKbNufft(**op_kwargs).to(**placement)

    # Interpolation matrices are precomputed from the forward operator and
    # handed to both operators below.
    real_interp, imag_interp = precomp_sparse_mats(ktraj, forw_ob)
    interp_mats = dict(
        real_interp_mats=real_interp,
        imag_interp_mats=imag_interp,
    )

    # Differentiate the quadratic loss through the adjoint operator.
    y.requires_grad_(True)
    x = adj_ob.forward(y, ktraj, interp_mats)
    loss = (x ** 2 / 2).sum()
    loss.backward()
    y_grad = y.grad.detach().clone()

    # Reference value: forward operator applied to a graph-free copy of x.
    y_hat = forw_ob.forward(x.detach().clone(), ktraj, interp_mats)

    residual = torch.norm(y_grad - y_hat)
    assert residual < norm_tol
# Gradient check for the adjoint SENSE-NUFFT operator: the autograd gradient
# of sum(x**2 / 2) with respect to y (x = adjoint applied to y) is compared
# with the forward operator applied to x.
# `device`, `dtype`, `smap`, `im_size`, `numpoints` and `norm_tol` are bound
# outside the lines shown in this excerpt.
placement = dict(dtype=dtype, device=device)

# Drop any existing autograd history, then move the inputs.
x = x.detach().to(**placement)
y = y.detach().to(**placement)
ktraj = ktraj.detach().to(**placement)

op_kwargs = dict(smap=smap, im_size=im_size, numpoints=numpoints)
forw_sense_ob = MriSenseNufft(**op_kwargs).to(**placement)
adj_sense_ob = AdjMriSenseNufft(**op_kwargs).to(**placement)

# Interpolation matrices precomputed from the forward operator.
real_interp, imag_interp = precomp_sparse_mats(ktraj, forw_sense_ob)
interp_mats = dict(
    real_interp_mats=real_interp,
    imag_interp_mats=imag_interp,
)

# Backpropagate the quadratic loss through the adjoint operator.
y.requires_grad_(True)
x = adj_sense_ob.forward(y, ktraj, interp_mats)
loss = (x ** 2 / 2).sum()
loss.backward()
y_grad = y.grad.detach().clone()

# Reference value: forward operator applied to a graph-free copy of x.
y_hat = forw_sense_ob.forward(x.detach().clone(), ktraj, interp_mats)

residual = torch.norm(y_grad - y_hat)
assert residual < norm_tol
for device in device_list:
    # Check, on every device, that the autograd gradient of sum(y**2 / 2)
    # with respect to x (where y is the forward operator applied to x)
    # matches the adjoint operator applied to y, up to norm_tol.
    target = {'dtype': dtype, 'device': device}

    # Fresh, graph-free copies of the inputs on the device under test.
    x = x.detach().to(**target)
    y = y.detach().to(**target)
    ktraj = ktraj.detach().to(**target)

    shared_args = {'im_size': im_size, 'numpoints': numpoints}
    nufft_op = KbNufft(**shared_args).to(**target)
    adj_nufft_op = AdjKbNufft(**shared_args).to(**target)

    # Both operators reuse the matrices precomputed from the forward one.
    mat_re, mat_im = precomp_sparse_mats(ktraj, nufft_op)
    interp_mats = {
        'real_interp_mats': mat_re,
        'imag_interp_mats': mat_im,
    }

    # Differentiate the quadratic loss through the forward operator.
    x.requires_grad_(True)
    y = nufft_op.forward(x, ktraj, interp_mats)
    half_energy = (y ** 2 / 2).sum()
    half_energy.backward()
    x_grad = x.grad.detach().clone()

    # Reference value: adjoint operator applied to a graph-free copy of y.
    x_hat = adj_nufft_op.forward(y.detach().clone(), ktraj, interp_mats)

    mismatch = torch.norm(x_grad - x_hat)
    assert mismatch < norm_tol
for device in device_list:
    # Adjointness check on every device: inner_product(y, forward(x)) is
    # compared with inner_product(adjoint(y), x); the norm of their
    # difference must stay below norm_tol.
    placement = dict(dtype=dtype, device=device)

    x = x.detach().to(**placement)
    y = y.detach().to(**placement)
    ktraj = ktraj.detach().to(**placement)

    op_kwargs = dict(im_size=im_size, numpoints=numpoints)
    forw_ob = KbNufft(**op_kwargs).to(**placement)
    adj_ob = AdjKbNufft(**op_kwargs).to(**placement)

    # Interpolation matrices precomputed from the forward operator and
    # shared by both operator calls.
    real_interp, imag_interp = precomp_sparse_mats(ktraj, forw_ob)
    interp_mats = dict(
        real_interp_mats=real_interp,
        imag_interp_mats=imag_interp,
    )

    x_forw = forw_ob(x, ktraj, interp_mats)
    y_back = adj_ob(y, ktraj, interp_mats)

    lhs = inner_product(y, x_forw, dim=2)
    rhs = inner_product(y_back, x, dim=2)

    gap = torch.norm(lhs - rhs)
    assert gap < norm_tol
# Adjointness check for the interpolation operators:
# inner_product(y, forward(x)) is compared with inner_product(adjoint(y), x).
# `device`, `dtype`, `im_size`, `grid_size`, `numpoints` and `norm_tol` are
# bound outside the lines shown in this excerpt.
target = {'dtype': dtype, 'device': device}

x = x.detach().to(**target)
y = y.detach().to(**target)
ktraj = ktraj.detach().to(**target)

interp_args = {
    'im_size': im_size,
    'grid_size': grid_size,
    'numpoints': numpoints,
}
interp_forw_op = KbInterpForw(**interp_args).to(**target)
interp_back_op = KbInterpBack(**interp_args).to(**target)

# Interpolation matrices precomputed from the forward interpolator and
# shared by both operator calls.
mat_re, mat_im = precomp_sparse_mats(ktraj, interp_forw_op)
interp_mats = {
    'real_interp_mats': mat_re,
    'imag_interp_mats': mat_im,
}

x_forw = interp_forw_op(x, ktraj, interp_mats)
y_back = interp_back_op(y, ktraj, interp_mats)

lhs = inner_product(y, x_forw, dim=2)
rhs = inner_product(y_back, x, dim=2)

gap = torch.norm(lhs - rhs)
assert gap < norm_tol
# Gradient check for the forward interpolator: the autograd gradient of
# sum(y**2 / 2) with respect to x (y = forward applied to x) is compared
# with the backward interpolator applied to y.
# `device`, `dtype`, `im_size`, `grid_size`, `numpoints` and `norm_tol` are
# bound outside the lines shown in this excerpt.
placement = dict(dtype=dtype, device=device)

# Drop any existing autograd history, then move the inputs.
x = x.detach().to(**placement)
y = y.detach().to(**placement)
ktraj = ktraj.detach().to(**placement)

op_kwargs = dict(im_size=im_size, grid_size=grid_size, numpoints=numpoints)
forw_interp_ob = KbInterpForw(**op_kwargs).to(**placement)
back_interp_ob = KbInterpBack(**op_kwargs).to(**placement)

# Interpolation matrices precomputed from the forward interpolator.
real_interp, imag_interp = precomp_sparse_mats(ktraj, forw_interp_ob)
interp_mats = dict(
    real_interp_mats=real_interp,
    imag_interp_mats=imag_interp,
)

# Backpropagate the quadratic loss through the forward interpolator.
x.requires_grad_(True)
y = forw_interp_ob.forward(x, ktraj, interp_mats)
loss = (y ** 2 / 2).sum()
loss.backward()
x_grad = x.grad.detach().clone()

# Reference value: backward interpolator applied to a graph-free copy of y.
x_hat = back_interp_ob.forward(y.detach().clone(), ktraj, interp_mats)

residual = torch.norm(x_grad - x_hat)
assert residual < norm_tol
# Gradient check for KbInterpForw against KbInterpBack: differentiating
# sum(y**2 / 2), with y = forward(x), with respect to x should reproduce the
# backward interpolator applied to y (within norm_tol).
# `device`, `dtype`, `im_size`, `grid_size`, `numpoints` and `norm_tol` are
# bound outside the lines shown in this excerpt.
target = {'dtype': dtype, 'device': device}

# Graph-free copies of the inputs on the target dtype/device.
x = x.detach().to(**target)
y = y.detach().to(**target)
ktraj = ktraj.detach().to(**target)

interp_args = {
    'im_size': im_size,
    'grid_size': grid_size,
    'numpoints': numpoints,
}
interp_forw_op = KbInterpForw(**interp_args).to(**target)
interp_back_op = KbInterpBack(**interp_args).to(**target)

# Both interpolators reuse the matrices precomputed from the forward one.
mat_re, mat_im = precomp_sparse_mats(ktraj, interp_forw_op)
interp_mats = {
    'real_interp_mats': mat_re,
    'imag_interp_mats': mat_im,
}

# Autograd side of the comparison.
x.requires_grad_(True)
y = interp_forw_op.forward(x, ktraj, interp_mats)
half_energy = (y ** 2 / 2).sum()
half_energy.backward()
x_grad = x.grad.detach().clone()

# Operator side of the comparison.
x_hat = interp_back_op.forward(y.detach().clone(), ktraj, interp_mats)

mismatch = torch.norm(x_grad - x_hat)
assert mismatch < norm_tol
# Gradient check for the forward SENSE-NUFFT operator: the autograd gradient
# of sum(y**2 / 2) with respect to x (y = forward applied to x) is compared
# with the adjoint operator applied to y.
# `device`, `dtype`, `smap`, `im_size`, `numpoints` and `norm_tol` are bound
# outside the lines shown in this excerpt.
placement = dict(dtype=dtype, device=device)

# Drop any existing autograd history, then move the inputs.
x = x.detach().to(**placement)
y = y.detach().to(**placement)
ktraj = ktraj.detach().to(**placement)

op_kwargs = dict(smap=smap, im_size=im_size, numpoints=numpoints)
forw_sense_ob = MriSenseNufft(**op_kwargs).to(**placement)
adj_sense_ob = AdjMriSenseNufft(**op_kwargs).to(**placement)

# Interpolation matrices precomputed from the forward operator.
real_interp, imag_interp = precomp_sparse_mats(ktraj, forw_sense_ob)
interp_mats = dict(
    real_interp_mats=real_interp,
    imag_interp_mats=imag_interp,
)

# Backpropagate the quadratic loss through the forward operator.
x.requires_grad_(True)
y = forw_sense_ob.forward(x, ktraj, interp_mats)
loss = (y ** 2 / 2).sum()
loss.backward()
x_grad = x.grad.detach().clone()

# Reference value: adjoint operator applied to a graph-free copy of y.
y_frozen = y.detach().clone()
x_hat = adj_sense_ob.forward(y_frozen, ktraj, interp_mats)

residual = torch.norm(x_grad - x_hat)
assert residual < norm_tol
# Gradient check for the forward SENSE-NUFFT operator built with
# coilpack=True: the autograd gradient of sum(y**2 / 2) with respect to x
# (y = forward applied to x) is compared with the adjoint operator applied
# to y.
# `x`, `device`, `dtype`, `smap`, `im_size`, `numpoints` and `norm_tol` are
# bound outside the lines shown in this excerpt.
target = {'dtype': dtype, 'device': device}

ktraj = ktraj.detach().to(**target)

sense_args = {
    'smap': smap,
    'im_size': im_size,
    'numpoints': numpoints,
    'coilpack': True,
}
sense_forw_op = MriSenseNufft(**sense_args).to(**target)
sense_adj_op = AdjMriSenseNufft(**sense_args).to(**target)

# Both operators reuse the matrices precomputed from the forward one.
mat_re, mat_im = precomp_sparse_mats(ktraj, sense_forw_op)
interp_mats = {
    'real_interp_mats': mat_re,
    'imag_interp_mats': mat_im,
}

# Autograd side of the comparison.
x.requires_grad_(True)
y = sense_forw_op.forward(x, ktraj, interp_mats)
half_energy = (y ** 2 / 2).sum()
half_energy.backward()
x_grad = x.grad.detach().clone()

# Operator side of the comparison, fed a graph-free copy of y.
y_frozen = y.detach().clone()
x_hat = sense_adj_op.forward(y_frozen, ktraj, interp_mats)

mismatch = torch.norm(x_grad - x_hat)
assert mismatch < norm_tol
dtype=dtype, device=device)
# NOTE(review): this excerpt has lost its indentation; the nesting of the
# if/else statements below should be confirmed against the original file.
# Build the adjoint SENSE-NUFFT operator from `smap` and `im_size` and move it
# to the requested dtype/device.
adjkbsense_ob = AdjMriSenseNufft(
smap=smap, im_size=im_size).to(dtype=dtype, device=device)
# Adjoint NUFFT operator built from `im_size` only, on the same dtype/device.
adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)
# precompute toeplitz kernel if using toeplitz
# The kernel is computed from the adjoint SENSE operator and the trajectory;
# a ToepSenseNufft operator is also created from `smap`.
if use_toep:
print('using toeplitz for forward/backward')
kern = calc_toep_kernel(adjkbsense_ob, ktraj)
toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)
# precompute the sparse interpolation matrices
# When `sparse_mats_flag` is set, the matrices are derived from the trajectory
# and the adjoint NUFFT operator and stored under the two dictionary keys
# below; otherwise `interp_mats` is set to None.
if sparse_mats_flag:
print('using sparse interpolation matrices')
real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
interp_mats = {
'real_interp_mats': real_mat,
'imag_interp_mats': imag_mat
}
else:
print('not using sparse interpolation matrices')
interp_mats = None
if use_toep:
# warm-up computation
for _ in range(num_nuffts):
x = toep_ob(image.to(device=device), kern.to(
device=device)).to(cpudevice)
# run the speed tests
if device == torch.device('cuda'):
torch.cuda.reset_max_memory_allocated()