Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_jit(self):
    """Traced RgbToGrayscale must agree with the eager module on new input.

    Random data (instead of all-ones) gives each channel a different value, so
    a mismatch in how channels are combined cannot hide. The traced module is
    evaluated on an input different from the tracing example so that values
    accidentally recorded into the trace would be detected.
    """
    batch_size, channels, height, width = 2, 3, 64, 64
    gray = color.RgbToGrayscale()
    # torch.jit.trace records the operations run on this example input.
    example = torch.rand(batch_size, channels, height, width)
    gray_traced = torch.jit.trace(gray, example)
    # Compare on a fresh input, not the one used for tracing.
    img = torch.rand(batch_size, channels, height, width)
    assert_allclose(gray(img), gray_traced(img))
def test_jit_unbatched(self):
    """Traced RgbToGrayscale must agree with eager on an unbatched image.

    Uses a (channels, height, width) input without a batch dimension.
    """
    # Distinct name on purpose: a second method called ``test_jit`` in the
    # same class would silently replace the first one, so only one would run.
    channels, height, width = 3, 64, 64
    gray = color.RgbToGrayscale()
    example = torch.rand(channels, height, width)
    gray_traced = torch.jit.trace(gray, example)
    # Evaluate on a fresh input rather than the tracing example.
    img = torch.rand(channels, height, width)
    assert_allclose(gray(img), gray_traced(img))
def test_rgb_to_grayscale(self):
    """A single (C, H, W) RGB image is reduced to one channel."""
    c, h, w = 3, 4, 5
    rgb = torch.ones(c, h, w)
    converter = color.RgbToGrayscale()
    out = converter(rgb)
    assert out.shape == (1, h, w)
def test_rgb_to_grayscale_batch(self):
    """A batch of (B, C, H, W) RGB images keeps B and becomes single-channel."""
    b, c, h, w = 2, 3, 4, 5
    rgb_batch = torch.ones(b, c, h, w)
    converter = color.RgbToGrayscale()
    expected_shape = (b, 1, h, w)
    assert converter(rgb_batch).shape == expected_shape