Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_smoke(self, device):
    """Unbatched input of shape (3,) yields a quaternion of shape (4,).

    The input is placed on ``device`` like every other test here, so the
    smoke test actually exercises the requested device instead of CPU only.
    """
    angle_axis = torch.zeros(3).to(device)
    quaternion = kornia.angle_axis_to_quaternion(angle_axis)
    assert quaternion.shape == (4,)
def test_small_angle(self, device):
    """A small rotation about x maps to [cos(theta/2), sin(theta/2), 0, 0].

    NOTE(review): ``test_small_angle`` is defined again further down; if both
    copies live in the same class, this one is overwritten and never runs.
    """
    theta = 1e-2
    half = theta / 2
    expected = torch.tensor([np.cos(half), np.sin(half), 0., 0.]).to(device)
    angle_axis = torch.tensor([theta, 0., 0.]).to(device)
    assert_allclose(kornia.angle_axis_to_quaternion(angle_axis), expected)
def test_zero_angle(self, device):
    """A zero angle-axis vector must map to the quaternion [1, 0, 0, 0]."""
    identity = torch.tensor([1., 0., 0., 0.]).to(device)
    zero_rotation = torch.tensor([0., 0., 0.]).to(device)
    result = kornia.angle_axis_to_quaternion(zero_rotation)
    assert_allclose(result, identity)
def test_small_angle(self, device):
    """Check a small rotation about each coordinate axis in turn.

    For an angle-axis vector ``theta * e_i`` the expected quaternion has the
    cosine term first, followed by ``sin(theta / 2)`` in slot ``i + 1`` (the
    same layout as the x-axis case, ``[cos, sin, 0, 0]``). The x-axis
    iteration reproduces the original single-axis check exactly.
    """
    theta = 1e-2
    half_cos, half_sin = np.cos(theta / 2), np.sin(theta / 2)
    for axis in range(3):
        angle_axis_values = [0., 0., 0.]
        angle_axis_values[axis] = theta
        expected_values = [half_cos, 0., 0., 0.]
        # Vector part of the quaternion is offset by one (scalar term first).
        expected_values[axis + 1] = half_sin
        angle_axis = torch.tensor(angle_axis_values).to(device)
        expected = torch.tensor(expected_values).to(device)
        quaternion = kornia.angle_axis_to_quaternion(angle_axis)
        assert_allclose(quaternion, expected)
@pytest.mark.parametrize("batch_size", (1, 3, 8))
def test_smoke_batch(self, device, batch_size):
    """Batched zero input of shape (B, 3) yields output of shape (B, 4).

    NOTE(review): ``test_smoke_batch`` is defined again further down; if both
    copies live in the same class, this one is overwritten and never runs.
    """
    inputs = torch.zeros(batch_size, 3).to(device)
    output_shape = kornia.angle_axis_to_quaternion(inputs).shape
    assert output_shape == (batch_size, 4)
def test_x_rotation(self, device):
    """A quarter turn about x maps to [sqrt(2)/2, sqrt(2)/2, 0, 0].

    NOTE(review): ``test_x_rotation`` is defined again right below; if both
    copies live in the same class, this one is overwritten and never runs.
    """
    root_half = 0.5 * np.sqrt(2)
    quarter_turn = kornia.pi / 2
    angle_axis = torch.tensor([quarter_turn, 0., 0.]).to(device)
    expected = torch.tensor([root_half, root_half, 0., 0.]).to(device)
    assert_allclose(kornia.angle_axis_to_quaternion(angle_axis), expected)
def test_x_rotation(self, device):
    """Quarter turns about x in both directions.

    ``+pi/2`` about x must give ``[sqrt(2)/2, sqrt(2)/2, 0, 0]`` (the original
    check). ``-pi/2`` must flip only the sign of the x component, giving
    ``[sqrt(2)/2, -sqrt(2)/2, 0, 0]``, which also exercises sign handling.
    """
    half_sqrt2 = 0.5 * np.sqrt(2)
    for sign in (1., -1.):
        angle_axis = torch.tensor([sign * kornia.pi / 2, 0., 0.]).to(device)
        expected = torch.tensor([half_sqrt2, sign * half_sqrt2, 0., 0.]).to(device)
        quaternion = kornia.angle_axis_to_quaternion(angle_axis)
        assert_allclose(quaternion, expected)
@pytest.mark.parametrize("batch_size", (1, 3, 8))
def test_smoke_batch(self, device, batch_size):
    """Batched zero input of shape (B, 3) yields output of shape (B, 4).

    ``test_smoke_batch`` is defined twice in this scope. This later copy is
    presumably the one that survives, so it needs its own ``parametrize``
    marker. Without it, pytest looks for a ``batch_size`` fixture and errors
    out.
    """
    angle_axis = torch.zeros(batch_size, 3).to(device)
    quaternion = kornia.angle_axis_to_quaternion(angle_axis)
    assert quaternion.shape == (batch_size, 4)