How to use the `torchkbnufft.nufft.sparse_interp_mat.precomp_sparse_mats` function in torchkbnufft

To help you get started, we’ve selected a few torchkbnufft examples based on popular ways this function is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.

From GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
# Adjoint-gradient check for KbNufft/AdjKbNufft using precomputed sparse
# interpolation matrices, repeated for every device in device_list.
# NOTE(review): x, y, ktraj, device_list, dtype, im_size, numpoints and
# norm_tol are defined outside this excerpt (presumably earlier in the
# enclosing test function) -- confirm against the full test file.
for device in device_list:
        # Detach the inputs from any earlier autograd graph and cast/move them
        # to the dtype and device under test.
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        # Forward and adjoint NUFFT operators built with identical settings.
        kbnufft_ob = KbNufft(
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbnufft_ob = AdjKbNufft(
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        # Precompute the real/imaginary sparse interpolation matrices for this
        # trajectory once; the same dict is passed to both operators.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Backpropagate 0.5 * sum(x**2) through the adjoint operator to obtain
        # the gradient with respect to y.
        y.requires_grad = True
        x = adjkbnufft_ob.forward(y, ktraj, interp_mats)

        ((x ** 2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        # The autograd gradient must match the forward operator applied to the
        # adjoint output (detached so it is not part of the graph).
        y_hat = kbnufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad-y_hat) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
# Adjoint-gradient check for MriSenseNufft/AdjMriSenseNufft using
# precomputed sparse interpolation matrices.
# NOTE(review): this excerpt is the body of a loop whose header is not shown;
# x, y, ktraj, smap, dtype, device, im_size, numpoints and norm_tol come from
# the enclosing test -- confirm against the full file.
x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        # Forward and adjoint SENSE-NUFFT operators sharing the same smap.
        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        # Sparse interpolation matrices are precomputed from the forward
        # operator and reused by both operators below.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Gradient of 0.5 * sum(x**2) with respect to y, where x is the
        # adjoint operator's output.
        y.requires_grad = True
        x = adjsensenufft_ob.forward(y, ktraj, interp_mats)

        ((x ** 2) / 2).sum().backward()
        y_grad = y.grad.clone().detach()

        # Expected gradient: forward operator applied to the detached output.
        y_hat = sensenufft_ob.forward(x.clone().detach(), ktraj, interp_mats)

        assert torch.norm(y_grad-y_hat) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
# Forward-gradient check for KbNufft/AdjKbNufft using precomputed sparse
# interpolation matrices, repeated for every device in device_list.
# NOTE(review): x, y, ktraj, device_list, dtype, im_size, numpoints and
# norm_tol are defined outside this excerpt -- confirm in the full test.
for device in device_list:
        # Fresh, detached copies of the inputs on the device under test.
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbnufft_ob = AdjKbNufft(
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        # Precomputed sparse interpolation matrices shared by both operators.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Backpropagate 0.5 * sum(y**2) through the forward operator to get
        # the gradient with respect to x.
        x.requires_grad = True
        y = kbnufft_ob.forward(x, ktraj, interp_mats)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        # The autograd gradient must match the adjoint operator applied to the
        # (detached) forward output.
        x_hat = adjkbnufft_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad-x_hat) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_sparse_adjoints.py (view on GitHub)
# Dot-product adjoint test for KbNufft/AdjKbNufft with sparse interpolation
# matrices: <y, A x> is compared against <A^H y, x> on every device.
# NOTE(review): x, y, ktraj, device_list, dtype, im_size, numpoints, norm_tol
# and inner_product are defined outside this excerpt -- confirm in the full
# test file.
for device in device_list:
        x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbnufft_ob = KbNufft(
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbnufft_ob = AdjKbNufft(
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        # Precomputed sparse interpolation matrices shared by both operators.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Apply the forward operator to x and the adjoint operator to y.
        x_forw = kbnufft_ob(x, ktraj, interp_mats)
        y_back = adjkbnufft_ob(y, ktraj, interp_mats)

        # NOTE(review): dim=2 is presumably the real/imaginary axis expected
        # by inner_product -- confirm against its definition.
        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)

        assert torch.norm(inprod1 - inprod2) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_sparse_adjoints.py (view on GitHub)
# Dot-product adjoint test for the interpolation-only operators
# KbInterpForw/KbInterpBack with sparse interpolation matrices.
# NOTE(review): this excerpt is the body of a loop whose header is not shown;
# x, y, ktraj, dtype, device, im_size, grid_size, numpoints, norm_tol and
# inner_product come from the enclosing test -- confirm in the full file.
x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        # Interpolation operators also take an explicit grid_size.
        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        # Precomputed sparse interpolation matrices shared by both operators.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        x_forw = kbinterp_ob(x, ktraj, interp_mats)
        y_back = adjkbinterp_ob(y, ktraj, interp_mats)

        # <y, A x> must equal <A^H y, x> up to norm_tol.
        inprod1 = inner_product(y, x_forw, dim=2)
        inprod2 = inner_product(y_back, x, dim=2)

        assert torch.norm(inprod1 - inprod2) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
# Forward-gradient check for KbInterpForw/KbInterpBack using precomputed
# sparse interpolation matrices.
# NOTE(review): this excerpt is the body of a loop whose header is not shown;
# x, y, ktraj, dtype, device, im_size, grid_size, numpoints and norm_tol come
# from the enclosing test -- confirm in the full file.
x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Gradient of 0.5 * sum(y**2) with respect to x through the forward
        # interpolator.
        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj, interp_mats)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        # Expected gradient: the back-interpolator applied to the detached y.
        x_hat = adjkbinterp_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad-x_hat) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
# Forward-gradient check for KbInterpForw/KbInterpBack using precomputed
# sparse interpolation matrices.
# NOTE(review): this excerpt is repeated line for line elsewhere on this
# page; the loop header and the definitions of x, y, ktraj, dtype, device,
# im_size, grid_size, numpoints and norm_tol are not shown.
x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        kbinterp_ob = KbInterpForw(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjkbinterp_ob = KbInterpBack(
            im_size=im_size,
            grid_size=grid_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        # Sparse interpolation matrices shared by both interpolators.
        real_mat, imag_mat = precomp_sparse_mats(ktraj, kbinterp_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Autograd gradient of 0.5 * sum(y**2) with respect to x.
        x.requires_grad = True
        y = kbinterp_ob.forward(x, ktraj, interp_mats)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        # It must match the back-interpolator applied to the detached y.
        x_hat = adjkbinterp_ob.forward(y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad-x_hat) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
# Forward-gradient check for MriSenseNufft/AdjMriSenseNufft using
# precomputed sparse interpolation matrices.
# NOTE(review): this excerpt is the body of a loop whose header is not shown;
# x, y, ktraj, smap, dtype, device, im_size, numpoints and norm_tol come from
# the enclosing test -- confirm in the full file.
x = x.detach().to(dtype=dtype, device=device)
        y = y.detach().to(dtype=dtype, device=device)
        ktraj = ktraj.detach().to(dtype=dtype, device=device)

        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Gradient of 0.5 * sum(y**2) with respect to x through the forward
        # SENSE-NUFFT.
        x.requires_grad = True
        y = sensenufft_ob.forward(x, ktraj, interp_mats)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        # Expected gradient: adjoint SENSE-NUFFT applied to the detached y.
        x_hat = adjsensenufft_ob.forward(
            y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad-x_hat) < norm_tol
From GitHub: mmuckley/torchkbnufft — tests/test_pytorch_sparse_grad_adj_matching.py (view on GitHub)
# Forward-gradient check for MriSenseNufft/AdjMriSenseNufft with
# coilpack=True, using precomputed sparse interpolation matrices.
# NOTE(review): the excerpt starts mid-loop; x, y, smap, dtype, device,
# im_size, numpoints and norm_tol, plus any preparation of x (e.g. detaching
# it so requires_grad can be set), happen outside this view -- confirm in the
# full test file.
ktraj = ktraj.detach().to(dtype=dtype, device=device)

        # Both operators are built with coilpack=True so their settings match.
        sensenufft_ob = MriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints,
            coilpack=True
        ).to(dtype=dtype, device=device)
        adjsensenufft_ob = AdjMriSenseNufft(
            smap=smap,
            im_size=im_size,
            numpoints=numpoints,
            coilpack=True
        ).to(dtype=dtype, device=device)

        real_mat, imag_mat = precomp_sparse_mats(ktraj, sensenufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }

        # Autograd gradient of 0.5 * sum(y**2) with respect to x.
        x.requires_grad = True
        y = sensenufft_ob.forward(x, ktraj, interp_mats)

        ((y ** 2) / 2).sum().backward()
        x_grad = x.grad.clone().detach()

        # It must match the adjoint operator applied to the detached y.
        x_hat = adjsensenufft_ob.forward(
            y.clone().detach(), ktraj, interp_mats)

        assert torch.norm(x_grad-x_hat) < norm_tol
From GitHub: mmuckley/torchkbnufft — profile_torchkbnufft.py (view on GitHub)
dtype=dtype, device=device)
    # Adjoint SENSE-NUFFT operator; it is the operator handed to
    # calc_toep_kernel when the Toeplitz path is enabled.
    adjkbsense_ob = AdjMriSenseNufft(
        smap=smap, im_size=im_size).to(dtype=dtype, device=device)

    # Plain (non-SENSE) adjoint NUFFT; in this excerpt it is only used as the
    # operator passed to precomp_sparse_mats.
    adjkbnufft_ob = AdjKbNufft(im_size=im_size).to(dtype=dtype, device=device)

    # precompute toeplitz kernel if using toeplitz
    if use_toep:
        print('using toeplitz for forward/backward')
        kern = calc_toep_kernel(adjkbsense_ob, ktraj)
        toep_ob = ToepSenseNufft(smap=smap).to(dtype=dtype, device=device)

    # precompute the sparse interpolation matrices
    # (interp_mats stays None when sparse_mats_flag is false)
    if sparse_mats_flag:
        print('using sparse interpolation matrices')
        real_mat, imag_mat = precomp_sparse_mats(ktraj, adjkbnufft_ob)
        interp_mats = {
            'real_interp_mats': real_mat,
            'imag_interp_mats': imag_mat
        }
    else:
        print('not using sparse interpolation matrices')
        interp_mats = None

    # kern and toep_ob exist only because the earlier `if use_toep:` block ran.
    if use_toep:
        # warm-up computation
        for _ in range(num_nuffts):
            x = toep_ob(image.to(device=device), kern.to(
                device=device)).to(cpudevice)
        # run the speed tests
        # NOTE(review): torch.device('cuda') does not compare equal to an
        # indexed device such as torch.device('cuda:0'); confirm how `device`
        # is constructed, otherwise this branch may silently be skipped.
        # NOTE(review): torch.cuda.reset_max_memory_allocated now delegates to
        # torch.cuda.reset_peak_memory_stats (resetting all peak stats) and
        # warns; consider calling reset_peak_memory_stats directly.
        if device == torch.device('cuda'):
            torch.cuda.reset_max_memory_allocated()