Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_random_brightness(self):
# Seeded brightness-jitter check (ColorJitter(brightness=0.2)) on a batch
# built by tiling one 3 x 3 patch over batch and channel dimensions.
# NOTE(review): this listing is truncated -- the `expected` literal below is
# never closed and no assertion follows; restore the remainder from the
# original test file before relying on it.
torch.manual_seed(42)
f = ColorJitter(brightness=0.2)
input = torch.tensor([[[[0.1, 0.2, 0.3],
[0.6, 0.5, 0.4],
[0.7, 0.8, 1.]]]]) # 1 x 1 x 3 x 3
input = input.repeat(2, 3, 1, 1) # 2 x 3 x 3 x 3 (batch of 2, 3 identical channels)
# The visible expected rows equal the input shifted by +0.1529, with the
# 1.0 entry staying at 1.0000 (values appear capped at 1.0).
expected = torch.tensor([[[[0.2529, 0.3529, 0.4529],
[0.7529, 0.6529, 0.5529],
[0.8529, 0.9529, 1.0000]],
[[0.2529, 0.3529, 0.4529],
[0.7529, 0.6529, 0.5529],
[0.8529, 0.9529, 1.0000]],
[[0.2529, 0.3529, 0.4529],
[0.7529, 0.6529, 0.5529],
def test_sequential(self):
    """Two chained argument-less ColorJitters leave a single image unchanged.

    Both stages are built with ``return_transform=True``. The test expects the
    composed output to equal the 3 x 5 x 5 input (within tolerance) and the
    returned transform to be a 1 x 3 x 3 identity matrix.
    """
    f = nn.Sequential(
        ColorJitter(return_transform=True),
        ColorJitter(return_transform=True),
    )
    img = torch.rand(3, 5, 5)  # 3 x 5 x 5
    expected = img
    expected_transform = torch.eye(3).unsqueeze(0)  # 1 x 3 x 3

    # Run the pipeline once so the image and the transform being checked
    # come from the same forward pass.
    out, transform = f(img)
    assert_allclose(out, expected, atol=1e-4, rtol=1e-5)
    assert_allclose(transform, expected_transform)
def test_color_jitter_batch(self):
    """Two chained argument-less ColorJitters leave a batch unchanged.

    The test expects the output to equal the 2 x 3 x 5 x 5 input (within
    tolerance) and the returned transform to be a batch of two 3 x 3
    identity matrices.
    """
    # NOTE(review): this method name is defined more than once in this
    # listing; if the definitions share a class, only the last is collected.
    f = nn.Sequential(
        ColorJitter(return_transform=True),
        ColorJitter(return_transform=True),
    )
    img = torch.rand(2, 3, 5, 5)  # 2 x 3 x 5 x 5
    expected = img
    expected_transform = torch.eye(3).unsqueeze(0).expand((2, 3, 3))  # 2 x 3 x 3

    # One forward pass, so the image and transform checked belong together.
    out, transform = f(img)
    # Single tolerance-based comparison, matching the other identity-jitter
    # checks; the same tensor was previously asserted twice.
    assert_allclose(out, expected, atol=1e-4, rtol=1e-5)
    assert_allclose(transform, expected_transform)
def test_color_jitter_batch(self):
    """Two chained argument-less ColorJitters leave a batch unchanged.

    The test expects the output to equal the 2 x 3 x 5 x 5 input (within
    tolerance) and the returned transform to be a batch of two 3 x 3
    identity matrices.
    """
    # NOTE(review): this method name is defined more than once in this
    # listing; if the definitions share a class, only the last is collected.
    f = nn.Sequential(
        ColorJitter(return_transform=True),
        ColorJitter(return_transform=True),
    )
    img = torch.rand(2, 3, 5, 5)  # 2 x 3 x 5 x 5
    expected = img
    expected_transform = torch.eye(3).unsqueeze(0).expand((2, 3, 3))  # 2 x 3 x 3

    # One forward pass, so the image and transform checked belong together.
    out, transform = f(img)
    # Single tolerance-based comparison, matching the other identity-jitter
    # checks; the same tensor was previously asserted twice.
    assert_allclose(out, expected, atol=1e-4, rtol=1e-5)
    assert_allclose(transform, expected_transform)
def test_random_saturation_tuple(self):
# Seeded saturation-jitter check with `saturation` passed as the tuple
# (0.8, 1.2) -- presumably a (min, max) factor range; confirm against
# ColorJitter's documentation.
# NOTE(review): this listing is truncated -- the `expected` literal below is
# never closed and no assertion follows; restore the remainder from the
# original test file before relying on it.
torch.manual_seed(42)
f = ColorJitter(saturation=(0.8, 1.2))
input = torch.tensor([[[[0.1, 0.2, 0.3],
[0.6, 0.5, 0.4],
[0.7, 0.8, 1.]],
[[1.0, 0.5, 0.6],
[0.6, 0.3, 0.2],
[0.8, 0.1, 0.2]],
[[0.6, 0.8, 0.7],
[0.9, 0.3, 0.2],
[0.8, 0.4, .5]]]]) # 1 x 3 x 3 x 3
input = input.repeat(2, 1, 1, 1) # 2 x 3 x 3 x 3
expected = torch.tensor([[[[1.8763e-01, 2.5842e-01, 3.3895e-01],
[6.2921e-01, 5.0000e-01, 4.0000e-01],
def test_color_jitter_batch(self):
    """Argument-less ColorJitter acts as an identity on a batch.

    The test checks both the plain module and the ``return_transform=True``
    variant: each output should equal the 2 x 3 x 5 x 5 input (within
    tolerance), and the returned transform should be a batch of two 3 x 3
    identity matrices.
    """
    f = ColorJitter()
    f1 = ColorJitter(return_transform=True)
    img = torch.rand(2, 3, 5, 5)  # 2 x 3 x 5 x 5
    expected = img
    expected_transform = torch.eye(3).unsqueeze(0).expand((2, 3, 3))  # 2 x 3 x 3

    assert_allclose(f(img), expected, atol=1e-4, rtol=1e-5)

    # One forward pass, so the image and transform checked belong together.
    out, transform = f1(img)
    assert_allclose(out, expected, atol=1e-4, rtol=1e-5)
    assert_allclose(transform, expected_transform)
def test_random_hue(self):
# Seeded hue-jitter check (ColorJitter(hue=0.2)) on a 3-channel patch
# repeated over a batch of 2.
# NOTE(review): this listing is truncated -- the `expected` literal below is
# never closed and no assertion follows; restore the remainder from the
# original test file before relying on it.
torch.manual_seed(42)
f = ColorJitter(hue=0.2)
input = torch.tensor([[[[0.1, 0.2, 0.3],
[0.6, 0.5, 0.4],
[0.7, 0.8, 1.]],
[[1.0, 0.5, 0.6],
[0.6, 0.3, 0.2],
[0.8, 0.1, 0.2]],
[[0.6, 0.8, 0.7],
[0.9, 0.3, 0.2],
[0.8, 0.4, .5]]]]) # 1 x 3 x 3 x 3
input = input.repeat(2, 1, 1, 1) # 2 x 3 x 3 x 3
# The visible expected rows match the input's first-channel rows exactly.
expected = torch.tensor([[[[0.1000, 0.2000, 0.3000],
[0.6000, 0.5000, 0.4000],
def test_color_jitter(self):
    """Argument-less ColorJitter acts as an identity on a single image.

    The test checks both the plain module and the ``return_transform=True``
    variant: each output should equal the 3 x 5 x 5 input (within
    tolerance), and the returned transform should be a 1 x 3 x 3 identity
    matrix.
    """
    f = ColorJitter()
    f1 = ColorJitter(return_transform=True)
    img = torch.rand(3, 5, 5)  # 3 x 5 x 5
    expected = img
    expected_transform = torch.eye(3).unsqueeze(0)  # 1 x 3 x 3

    assert_allclose(f(img), expected, atol=1e-4, rtol=1e-5)

    # One forward pass, so the image and transform checked belong together.
    out, transform = f1(img)
    assert_allclose(out, expected, atol=1e-4, rtol=1e-5)
    assert_allclose(transform, expected_transform)
def test_random_hue_tensor(self):
# Seeded hue-jitter check with `hue` passed as the tensor [-0.2, 0.2]
# instead of a scalar; the input is a 3-channel patch repeated over a
# batch of 2.
# NOTE(review): this listing is truncated -- the `expected` literal below is
# never closed and no assertion follows; restore the remainder from the
# original test file before relying on it.
torch.manual_seed(42)
f = ColorJitter(hue=torch.tensor([-0.2, 0.2]))
input = torch.tensor([[[[0.1, 0.2, 0.3],
[0.6, 0.5, 0.4],
[0.7, 0.8, 1.]],
[[1.0, 0.5, 0.6],
[0.6, 0.3, 0.2],
[0.8, 0.1, 0.2]],
[[0.6, 0.8, 0.7],
[0.9, 0.3, 0.2],
[0.8, 0.4, .5]]]]) # 1 x 3 x 3 x 3
input = input.repeat(2, 1, 1, 1) # 2 x 3 x 3 x 3
# The visible expected rows match the input's first-channel rows exactly.
expected = torch.tensor([[[[0.1000, 0.2000, 0.3000],
[0.6000, 0.5000, 0.4000],
def test_sequential(self):
    """Two chained argument-less ColorJitters leave a single image unchanged.

    Both stages are built with ``return_transform=True``. The test expects the
    composed output to equal the 3 x 5 x 5 input (within tolerance) and the
    returned transform to be a 1 x 3 x 3 identity matrix.
    """
    f = nn.Sequential(
        ColorJitter(return_transform=True),
        ColorJitter(return_transform=True),
    )
    img = torch.rand(3, 5, 5)  # 3 x 5 x 5
    expected = img
    expected_transform = torch.eye(3).unsqueeze(0)  # 1 x 3 x 3

    # Run the pipeline once so the image and the transform being checked
    # come from the same forward pass.
    out, transform = f(img)
    assert_allclose(out, expected, atol=1e-4, rtol=1e-5)
    assert_allclose(transform, expected_transform)