Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_noncontiguous(self, device):
    """Check ``kornia.filter2D`` on non-contiguous (expanded) inputs.

    Part 1 filters a batch produced by ``expand`` and compares the result with
    filtering a contiguous copy of the same data.  Part 2 convolves an expanded
    2-batch / 2-channel impulse image with a 3x3 all-ones kernel and checks
    that the impulse spreads over its 3x3 neighbourhood.
    """
    batch_size = 3
    # Expand *after* moving to the device: ``.to(device)`` on an expanded
    # (stride-0) tensor can materialize a contiguous copy, which would defeat
    # the purpose of this test.
    inp = torch.rand(3, 5, 5, device=device).expand(batch_size, -1, -1, -1)
    assert not inp.is_contiguous()
    kernel = torch.ones(1, 2, 2).to(device)

    actual = kornia.filter2D(inp, kernel)
    # Reference: identical data laid out contiguously.  (Comparing ``actual``
    # with itself, as before, could never fail.)
    expected = kornia.filter2D(inp.contiguous(), kernel)
    assert_allclose(actual, expected)

    # A single 5 in the centre filtered with a 3x3 all-ones kernel fills the
    # surrounding 3x3 block with 5s; the 2x2 kernel above cannot produce this.
    box_kernel = torch.ones(1, 3, 3).to(device)
    impulse = torch.tensor([[[
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 5., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
    ]]]).expand(2, 2, -1, -1).to(device)
    expected = torch.tensor([[[
        [0., 0., 0., 0., 0.],
        [0., 5., 5., 5., 0.],
        [0., 5., 5., 5., 0.],
        [0., 5., 5., 5., 0.],
        [0., 0., 0., 0., 0.],
    ]]]).expand(2, 2, -1, -1).to(device)
    actual = kornia.filter2D(impulse, box_kernel)
    assert_allclose(actual, expected)
def test_smoke(self, device):
    """Smoke-test ``kornia.filter2D``: output shape, then a box-filter value check."""
    # Filtering with a 3x3 kernel must keep the input's shape.
    random_kernel = torch.rand(1, 3, 3).to(device)
    ones_image = torch.ones(1, 1, 7, 8).to(device)
    assert kornia.filter2D(ones_image, random_kernel).shape == ones_image.shape

    # The value check needs a deterministic kernel.  Reusing the random kernel
    # (as before) made the hard-coded expectation below essentially never hold.
    box_kernel = torch.ones(1, 3, 3).to(device)
    impulse = torch.tensor([[[
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 5., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
    ]]]).to(device)
    expected = torch.tensor([[[
        [0., 0., 0., 0., 0.],
        [0., 5., 5., 5., 0.],
        [0., 5., 5., 5., 0.],
        [0., 5., 5., 5., 0.],
        [0., 0., 0., 0., 0.],
    ]]]).to(device)
    actual = kornia.filter2D(impulse, box_kernel)
    assert_allclose(actual, expected)
def forward(self, input: torch.Tensor):  # type: ignore
    """Filter ``input`` with this module's stored kernel and border mode.

    Delegates to ``kornia.filter2D`` with ``self.kernel`` and
    ``self.border_type`` and returns its result unchanged.
    """
    apply_filter = kornia.filter2D
    return apply_filter(input, self.kernel, self.border_type)
# NOTE(review): excerpt starts mid-function; `ksize`, `img`, `dtype`, the cache
# `GAUS_KERNELS`, `_gaussian_kernel2d`, `np` and `K` are defined outside this view.
# Purpose: fetch (or build and cache) a Gaussian kernel for `ksize`, filter
# `img` with it, and return the result cast to `dtype`.
# Normalize to a tuple so the size is hashable and usable as a cache key.
ksize = tuple(ksize)
if ksize in GAUS_KERNELS:
# NOTE(review): the cache is keyed by ksize only, so a hit returns the tensor with
# whatever device/dtype it was created with on the first call — confirm callers
# never mix devices/dtypes.
kernel = GAUS_KERNELS[ksize]
else:
# Keep the tuple form for the cache key; `ksize` is rebound to a float array below.
_ks = ksize
ksize = np.array(ksize).astype(float)
# Sigma derived from the kernel size; this appears to match the default-sigma
# formula documented for OpenCV's getGaussianKernel — confirm if relied upon.
sigma = 0.3 * ((ksize - 1.0) * 0.5 - 1.0) + 0.8
# NOTE(review): when `dtype` is torch.uint8 (the branch below handles that case),
# this cast turns the kernel weights into integers; fractional weights would
# become 0. Verify the intended kernel dtype.
kernel = (
torch.from_numpy(_gaussian_kernel2d(ksize, sigma))
.to(img.device, non_blocking=True)
.to(dtype, non_blocking=True)
)
GAUS_KERNELS[_ks] = kernel
# `K` is presumably the kornia module alias — confirm against the file's imports.
img = K.filter2D(img, kernel)
if dtype == torch.uint8:
# Round and clamp to the valid 8-bit range before casting back to uint8.
img = torch.clamp(torch.round(img), 0, 255)
return img.to(dtype, non_blocking=True)
# NOTE(review): excerpt starts mid-function; the condition guarding this first
# `raise` (presumably `if img1.shape != img2.shape:`) is outside this view.
# The error messages now report both operands — the previous "Got: {}"
# templates silently dropped the second argument passed to str.format.
raise ValueError("img1 and img2 shapes must be the same. Got: {} and {}".format(img1.shape, img2.shape))
if img1.device != img2.device:
    raise ValueError("img1 and img2 must be in the same device. Got: {} and {}".format(img1.device, img2.device))
if img1.dtype != img2.dtype:
    raise ValueError("img1 and img2 must be in the same dtype. Got: {} and {}".format(img1.dtype, img2.dtype))
# prepare kernel
# Unpacking also rejects inputs that are not 4-D (ValueError on mismatch);
# the names are kept because code after this excerpt may use them.
b, c, h, w = img1.shape
# Match the window to the images' device and dtype, then add a leading dim.
tmp_kernel: torch.Tensor = self.window.to(device=img1.device, dtype=img1.dtype)
tmp_kernel = tmp_kernel.unsqueeze(0)
# compute local mean per channel
mu1: torch.Tensor = filter2D(img1, tmp_kernel)
mu2: torch.Tensor = filter2D(img2, tmp_kernel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
# compute local sigma per channel: E[xy] - E[x]E[y] under the window
sigma1_sq = filter2D(img1 * img1, tmp_kernel) - mu1_sq
sigma2_sq = filter2D(img2 * img2, tmp_kernel) - mu2_sq
sigma12 = filter2D(img1 * img2, tmp_kernel) - mu1_mu2
# SSIM map; self.C1 / self.C2 are presumably the usual stabilizing constants.
ssim_map = ((2.0 * mu1_mu2 + self.C1) * (2.0 * sigma12 + self.C2)) / (
    (mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2)
)
# Per-pixel loss: (1 - SSIM) clamped to [0, 1], then halved.
loss = torch.clamp(-ssim_map + 1.0, min=0, max=1) / 2.0
# Validate that both images agree in shape, device and dtype. The messages
# report both operands — the previous "Got: {}" templates silently dropped the
# second argument passed to str.format.
if img1.shape != img2.shape:
    raise ValueError("img1 and img2 shapes must be the same. Got: {} and {}".format(img1.shape, img2.shape))
if img1.device != img2.device:
    raise ValueError("img1 and img2 must be in the same device. Got: {} and {}".format(img1.device, img2.device))
if img1.dtype != img2.dtype:
    raise ValueError("img1 and img2 must be in the same dtype. Got: {} and {}".format(img1.dtype, img2.dtype))
# prepare kernel
# Unpacking also rejects inputs that are not 4-D (ValueError on mismatch);
# the names are kept because code after this excerpt may use them.
b, c, h, w = img1.shape
# Match the window to the images' device and dtype, then add a leading dim.
tmp_kernel: torch.Tensor = self.window.to(device=img1.device, dtype=img1.dtype)
tmp_kernel = tmp_kernel.unsqueeze(0)
# compute local mean per channel
mu1: torch.Tensor = filter2D(img1, tmp_kernel)
mu2: torch.Tensor = filter2D(img2, tmp_kernel)
mu1_sq = mu1.pow(2)
mu2_sq = mu2.pow(2)
mu1_mu2 = mu1 * mu2
# compute local sigma per channel: E[xy] - E[x]E[y] under the window
sigma1_sq = filter2D(img1 * img1, tmp_kernel) - mu1_sq
sigma2_sq = filter2D(img2 * img2, tmp_kernel) - mu2_sq
sigma12 = filter2D(img1 * img2, tmp_kernel) - mu1_mu2
# SSIM map; self.C1 / self.C2 are presumably the usual stabilizing constants.
ssim_map = ((2.0 * mu1_mu2 + self.C1) * (2.0 * sigma12 + self.C2)) / (
    (mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2)
)
# Per-pixel loss: (1 - SSIM) clamped to [0, 1], then halved.
loss = torch.clamp(-ssim_map + 1.0, min=0, max=1) / 2.0