def test_sdk_launch_plan_node():
    @_tasks.inputs(a=_types.Types.Integer)
    @_tasks.outputs(b=_types.Types.Integer)
    @_tasks.python_task()
    def testy_test(wf_params, a, b):
        pass

    @_workflow.workflow_class
    class test_workflow(object):
        a = _workflow.Input(_types.Types.Integer)
        test = testy_test(a=1)
        b = _workflow.Output(test.outputs.b, sdk_type=_types.Types.Integer)

    lp = test_workflow.create_launch_plan()

    lp._id = _identifier.Identifier(_identifier.ResourceType.TASK, 'project', 'domain', 'name', 'version')
    n = _component_nodes.SdkWorkflowNode(sdk_launch_plan=lp)
    assert n.launchplan_ref.project == 'project'
    assert n.launchplan_ref.domain == 'domain'
    assert n.launchplan_ref.name == 'name'
    assert n.launchplan_ref.version == 'version'

    # Test floating ID
    lp._id = _identifier.Identifier(
        _identifier.ResourceType.TASK,
        'new_project',
        'new_domain',
        'new_name',
        'new_version'
    )
    assert n.launchplan_ref.project == 'new_project'
    assert n.launchplan_ref.domain == 'new_domain'

def test_dynamic_launch_plan_yielding_of_constant_workflow():
    outputs = lp_yield_empty_wf.unit_test()
    # TODO: Currently, Flytekit will not return early or do anything special if workflow nodes are detected
    #  in the output of a dynamic task.
    dj_spec = outputs[_sdk_constants.FUTURES_FILE_NAME]
    assert len(dj_spec.nodes) == 1
    assert len(dj_spec.outputs) == 1
    assert dj_spec.outputs[0].var == "out"
    assert len(outputs.keys()) == 1

def test_dynamic_launch_plan_yielding():
    outputs = lp_yield_task.unit_test(num=10)
    # TODO: Currently, Flytekit will not return early or do anything special if workflow nodes are detected
    #  in the output of a dynamic task.
    dj_spec = outputs[_sdk_constants.FUTURES_FILE_NAME]
    assert dj_spec.min_successes == 1

    launch_plan_node = dj_spec.nodes[0]
    node_id = launch_plan_node.id
    assert "models-test-dynamic-wfs-id-lp" in node_id
    assert node_id.endswith("-0")

    # Assert that the output of the dynamic job spec is bound to the single node in the spec: the workflow node
    # containing the launch plan.
    assert dj_spec.outputs[0].var == "out"
    assert dj_spec.outputs[0].binding.promise.node_id == node_id
    assert dj_spec.outputs[0].binding.promise.var == "task_output"

        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        None,
        False,
        None,
        {},
        None,
    )
    # Attach a typed integer input and output to the task, then run it locally.
    t.add_inputs({'value_in': interface.Variable(primitives.Integer.to_flyte_literal_type(), "")})
    t.add_outputs({'value_out': interface.Variable(primitives.Integer.to_flyte_literal_type(), "")})
    out = t.unit_test(value_in=1)
    assert out['value_out'] == 2

    # Calling the task without its required integer input should fail type checking.
    with _pytest.raises(_user_exceptions.FlyteAssertion) as e:
        t()
    assert "value_in" in str(e.value)
    assert "INTEGER" in str(e.value)

def test_fetch_latest(mock_get_engine):
    admin_task = _task_models.Task(
        _identifier.Identifier(_identifier.ResourceType.TASK, "p1", "d1", "n1", "v1"),
        _MagicMock(),
    )
    mock_engine = _MagicMock()
    mock_engine.fetch_latest_task = _MagicMock(return_value=admin_task)
    mock_get_engine.return_value = mock_engine

    # fetch_latest should return whatever the (mocked) engine reports as the latest task.
    task = _task.SdkTask.fetch_latest("p1", "d1", "n1")
    assert task.id == admin_task.id

def test_launch_plan_node():
    workflow_to_test = _workflow.workflow(
        {},
        inputs={
            'required_input': _workflow.Input(_types.Types.Integer),
            'default_input': _workflow.Input(_types.Types.Integer, default=5)
        },
        outputs={
            'out': _workflow.Output([1, 2, 3], sdk_type=[_types.Types.Integer])
        }
    )
    lp = workflow_to_test.create_launch_plan()

    # Test that required input isn't set
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        lp()

    # Test that positional args are rejected
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        lp(1, 2)

    # Test that type checking works
    with _pytest.raises(_user_exceptions.FlyteTypeException):
        lp(required_input='abc', default_input=1)

    # Test that bad arg name is detected
    with _pytest.raises(_user_exceptions.FlyteAssertion):
        lp(required_input=1, bad_arg=1)

    # Test default input is accounted for
    n = lp(required_input=10)

            {},
            {}
        )
    )
    task_node = _workflow.TaskNode(task.id)
    node = _workflow.Node(
        id='my_node',
        metadata=node_metadata,
        inputs=[b0],
        upstream_node_ids=[],
        output_aliases=[],
        task_node=task_node)

    template = _workflow.WorkflowTemplate(
        id=_identifier.Identifier(_identifier.ResourceType.WORKFLOW, "project", "domain", "name", "version"),
        metadata=_workflow.WorkflowMetadata(),
        interface=typed_interface,
        nodes=[node],
        outputs=[b1, b2],
    )

    obj = _workflow_closure.WorkflowClosure(workflow=template, tasks=[task])
    assert len(obj.tasks) == 1

    # Round-trip the closure through its protobuf (flyteidl) form and verify equality.
    obj2 = _workflow_closure.WorkflowClosure.from_flyte_idl(obj.to_flyte_idl())
    assert obj == obj2

    )
    task_metadata = _task.TaskMetadata(
        True,
        _task.RuntimeMetadata(_task.RuntimeMetadata.RuntimeType.FLYTE_SDK, "1.0.0", "python"),
        timedelta(days=1),
        _literals.RetryStrategy(3),
        "0.1.1b0",
        "This is deprecated!"
    )

    cpu_resource = _task.Resources.ResourceEntry(_task.Resources.ResourceName.CPU, "1")
    resources = _task.Resources(requests=[cpu_resource], limits=[cpu_resource])

    task = _task.TaskTemplate(
        _identifier.Identifier(_identifier.ResourceType.TASK, "project", "domain", "name", "version"),
        "python",
        task_metadata,
        typed_interface,
        {'a': 1, 'b': {'c': 2, 'd': 3}},
        container=_task.Container(
            "my_image",
            ["this", "is", "a", "cmd"],
            ["this", "is", "an", "arg"],
            resources,
            {},
            {}
        )
    )

def test_serialize():
    workflow_to_test = _workflow.workflow(
        {},
        inputs={
            'required_input': _workflow.Input(_types.Types.Integer),
            'default_input': _workflow.Input(_types.Types.Integer, default=5)
        }
    )
    workflow_to_test._id = _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v")
    lp = workflow_to_test.create_launch_plan(
        fixed_inputs={'required_input': 5},
        role='iam_role',
    )
    with _configuration.TemporaryConfiguration(
        _os.path.join(_os.path.dirname(_os.path.realpath(__file__)), '../../common/configs/local.config'),
        internal_overrides={
            'image': 'myflyteimage:v123',
            'project': 'myflyteproject',
            'domain': 'development'
        }
    ):
        # The serialized spec should carry the workflow id and the assumable IAM role.
        s = lp.serialize()
        assert s.workflow_id == _identifier.Identifier(_identifier.ResourceType.WORKFLOW, "p", "d", "n", "v").to_flyte_idl()
        assert s.auth_role.assumable_iam_role == 'iam_role'