Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_smoke(self, device):
    """Check that a single 4-element quaternion yields a 3-element result."""
    zero_quat = torch.zeros(4).to(device)
    result = kornia.quaternion_to_angle_axis(zero_quat)
    expected_shape = (3,)
    assert result.shape == expected_shape
def test_small_angle(self, device):
    """Check the conversion for a small rotation angle (theta = 1e-2).

    The input quaternion is (cos(theta/2), sin(theta/2), 0, 0) and the
    expected angle-axis vector is (theta, 0, 0).
    """
    theta = 1e-2
    half_angle = theta / 2
    components = [np.cos(half_angle), np.sin(half_angle), 0., 0.]
    quat = torch.tensor(components).to(device)
    target = torch.tensor([theta, 0., 0.]).to(device)
    result = kornia.quaternion_to_angle_axis(quat)
    assert_allclose(result, target)
def test_unit_quaternion(self, device):
    """Check that the quaternion (1, 0, 0, 0) maps to the zero vector."""
    identity_quat = torch.tensor([1., 0., 0., 0.]).to(device)
    zero_rotation = torch.tensor([0., 0., 0.]).to(device)
    result = kornia.quaternion_to_angle_axis(identity_quat)
    assert_allclose(result, zero_rotation)
@pytest.mark.parametrize("batch_size", (1, 3, 8))
def test_smoke_batch(self, device, batch_size):
    """Check that a (batch_size, 4) input yields a (batch_size, 3) output."""
    batched_quat = torch.zeros(batch_size, 4).to(device)
    result = kornia.quaternion_to_angle_axis(batched_quat)
    expected_shape = (batch_size, 3)
    assert result.shape == expected_shape
def test_gradcheck(self, device):
    """Run ``gradcheck`` on the conversion near the quaternion (1, 0, 0, 0).

    A tiny offset is added to every component before the tensor is handed
    to ``tensor_to_gradcheck_var``.
    """
    eps = 1e-12
    base_quat = torch.tensor([1., 0., 0., 0.]).to(device)
    grad_input = tensor_to_gradcheck_var(base_quat + eps)
    # evaluate function gradient
    assert gradcheck(
        kornia.quaternion_to_angle_axis,
        (grad_input,),
        raise_exception=True,
    )
def test_smoke(self, device):
    """Smoke test: an all-zero tensor of size 4 must produce shape (3,)."""
    quat_in = torch.zeros(4).to(device)
    out_shape = kornia.quaternion_to_angle_axis(quat_in).shape
    assert out_shape == (3,)
def test_z_rotation(self, device):
    """Check that (sqrt(3)/2, 0, 0, 0.5) converts to (0, 0, pi/3)."""
    components = [np.sqrt(3) / 2, 0., 0., 0.5]
    quat = torch.tensor(components).to(device)
    target = torch.tensor([0., 0., kornia.pi / 3]).to(device)
    result = kornia.quaternion_to_angle_axis(quat)
    assert_allclose(result, target)
def test_y_rotation(self, device):
    """Check that the quaternion (0, 0, 1, 0) converts to (0, pi, 0)."""
    quat = torch.tensor([0., 0., 1., 0.]).to(device)
    target = torch.tensor([0., kornia.pi, 0.]).to(device)
    result = kornia.quaternion_to_angle_axis(quat)
    assert_allclose(result, target)
def test_y_rotation(self, device):
    """Quaternion (0, 0, 1, 0) is expected to give angle-axis (0, pi, 0)."""
    quat_values = [0., 0., 1., 0.]
    quat_in = torch.tensor(quat_values).to(device)
    want = torch.tensor([0., kornia.pi, 0.]).to(device)
    got = kornia.quaternion_to_angle_axis(quat_in)
    assert_allclose(got, want)
def test_z_rotation(self, device):
    """Quaternion (sqrt(3)/2, 0, 0, 0.5) is expected to give (0, 0, pi/3)."""
    first_component = np.sqrt(3) / 2
    quat_in = torch.tensor([first_component, 0., 0., 0.5]).to(device)
    want = torch.tensor([0., 0., kornia.pi / 3]).to(device)
    got = kornia.quaternion_to_angle_axis(quat_in)
    assert_allclose(got, want)