Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
import datetime as _datetime
import os as _os

from flytekit import configuration as _configuration
from flytekit.common import constants as _common_constants
from flytekit.common.tasks import sdk_runnable as _sdk_runnable
from flytekit.models import types as _type_models, task as _task_models
from flytekit.models.core import identifier as _identifier
from flytekit.sdk.tasks import inputs, outputs, python_task
from flytekit.sdk.types import Types
@inputs(in1=Types.Integer)
@outputs(out1=Types.String)
@python_task
def default_task(wf_params, in1, out1):
    # Intentionally empty: only the task object produced by the decorators is inspected
    # (interface, type and metadata defaults), so the function body never needs to run.
    pass


# Attach a fixed, fully-populated task identifier so the task object has an id.
default_task._id = _identifier.Identifier(
    resource_type=_identifier.ResourceType.TASK,
    project="project",
    domain="domain",
    name="name",
    version="version",
)
def test_default_python_task():
    """
    Check the task type, interface and metadata that ``default_task`` ends up with
    when only ``@inputs``/``@outputs``/``@python_task`` are applied.
    """
    task = default_task
    assert isinstance(task, _sdk_runnable.SdkRunnableTask)

    # Interface: one integer input and one string output, neither carrying a description.
    in1_var = task.interface.inputs['in1']
    assert in1_var.description == ''
    assert in1_var.type == _type_models.LiteralType(simple=_type_models.SimpleType.INTEGER)

    out1_var = task.interface.outputs['out1']
    assert out1_var.description == ''
    assert out1_var.type == _type_models.LiteralType(simple=_type_models.SimpleType.STRING)

    # Identity: the task type and where the wrapped function lives.
    assert task.type == _common_constants.SdkTaskType.PYTHON_TASK
    assert task.task_function_name == 'default_task'
    assert task.task_module == __name__

    # Metadata defaults: zero timeout, no deprecation message, not discoverable.
    meta = task.metadata
    assert meta.timeout == _datetime.timedelta(seconds=0)
    assert meta.deprecated_error_message == ''
    assert meta.discoverable is False
def test_task_execution_identifier():
    """
    Check ``Identifier.is_empty``.

    An identifier whose project/domain/name/version are all empty strings must report empty;
    setting only ``version`` is enough to make it non-empty.
    """
    # Use the ``_identifier`` alias from this file's import block; the bare name ``identifier``
    # is not imported there, so referencing it raised NameError before any assertion ran.
    empty_id = _identifier.Identifier(_identifier.ResourceType.UNSPECIFIED, "", "", "", "")
    not_empty_id = _identifier.Identifier(_identifier.ResourceType.UNSPECIFIED, "", "", "", "version")
    assert empty_id.is_empty
    assert not not_empty_id.is_empty
algorithm_version="0.72",
metric_definitions=[MetricDefinition(name="Minimize", regex="validation:error")]
),
)
# Invoke the training-job task with fixed S3 inputs and a stopping condition.
# Both limits are 12 hours, expressed in seconds.
train_task_exec = simple_training_job_task(
    train='s3://my-bucket/training.csv',
    validation='s3://my-bucket/validation.csv',
    static_hyperparameters=example_hyperparams,
    stopping_condition=StoppingCondition(
        max_runtime_in_seconds=12 * 60 * 60,
        max_wait_time_in_seconds=12 * 60 * 60,
    ).to_flyte_idl(),
)

# Give the resulting node a fixed, fully-populated task identifier.
train_task_exec._id = _identifier.Identifier(
    resource_type=_identifier.ResourceType.TASK,
    project="my_project",
    domain="my_domain",
    name="my_name",
    version="my_version",
)
def test_simple_training_job_task():
assert isinstance(simple_training_job_task, SdkSimpleTrainingJobTask)
assert isinstance(simple_training_job_task, _sdk_task.SdkTask)
assert simple_training_job_task.interface.inputs['train'].description == ''
assert simple_training_job_task.interface.inputs['train'].type == \
_sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['validation'].description == ''
assert simple_training_job_task.interface.inputs['validation'].type == \
_sdk_types.Types.MultiPartCSV.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['static_hyperparameters'].description == ''
assert simple_training_job_task.interface.inputs['static_hyperparameters'].type == \
_sdk_types.Types.Generic.to_flyte_literal_type()
assert simple_training_job_task.interface.inputs['stopping_condition'].type == \
def test_get_node_execution_inputs(mock_client_factory, execution_data_locations):
mock_client = MagicMock()
mock_client.get_node_execution_data = MagicMock(
return_value=_execution_models.NodeExecutionGetDataResponse(
execution_data_locations[0],
execution_data_locations[1]
)
)
mock_client_factory.return_value = mock_client
m = MagicMock()
type(m).id = PropertyMock(
return_value=identifier.NodeExecutionIdentifier(
"node-a",
identifier.WorkflowExecutionIdentifier(
"project",
"domain",
"name",
)
)
)
inputs = engine.FlyteNodeExecution(m).get_inputs()
assert len(inputs.literals) == 1
assert inputs.literals['a'].scalar.primitive.integer == 1
mock_client.get_node_execution_data.assert_called_once_with(
identifier.NodeExecutionIdentifier(
"node-a",
identifier.WorkflowExecutionIdentifier(
deprecated
)
for discoverable, runtime_metadata, timeout, retry_strategy, discovery_version, deprecated in product(
[True, False],
LIST_OF_RUNTIME_METADATA,
[timedelta(days=i) for i in range(3)],
LIST_OF_RETRY_POLICIES,
["1.0"],
["deprecated"]
)
]
LIST_OF_TASK_TEMPLATES = [
task.TaskTemplate(
identifier.Identifier(identifier.ResourceType.TASK, "project", "domain", "name", "version"),
"python",
task_metadata,
interfaces,
{'a': 1, 'b': [1, 2, 3], 'c': 'abc', 'd': {'x': 1, 'y': 2, 'z': 3}},
container=task.Container(
"my_image",
["this", "is", "a", "cmd"],
["this", "is", "an", "arg"],
resources,
{'a': 'b'},
{'d': 'e'}
)
)
for task_metadata, interfaces, resources in product(
LIST_OF_TASK_METADATA,
LIST_OF_INTERFACES,
'python'
),
timeout or _datetime.timedelta(seconds=0),
_literal_models.RetryStrategy(retries),
interruptible,
cache_version,
deprecated
)
interface = get_interface_from_task_info(fn.__annotations__, outputs or [])
task_instance = PythonTask(fn, interface, metadata, outputs, task_obj)
# TODO: One of the things I want to make sure to do is better naming support. At this point, we should already
# be able to determine the name of the task right? Can anyone think of situations where we can't?
# Where does the current instance tracker come into play?
task_instance.id = _identifier_model.Identifier(_identifier_model.ResourceType.TASK, "proj", "dom", "blah", "1")
return task_instance
def update_task_meta(description, host, insecure, project, domain, name):
    """
    Set the description of the task named by {project, domain, name} and mark it ACTIVE.

    The update targets the named entity rather than a single version, so it applies
    across all versions of that task.
    """
    _welcome_message()
    flyte_client = _friendly_client.SynchronousFlyteClient(host, insecure=insecure)

    entity_id = _named_entity.NamedEntityIdentifier(project, domain, name)
    entity_metadata = _named_entity.NamedEntityMetadata(
        description, _named_entity.NamedEntityState.ACTIVE
    )
    flyte_client.update_named_entity(_core_identifier.ResourceType.TASK, entity_id, entity_metadata)

    _click.echo("Successfully updated task")
def from_flyte_idl(cls, p):
    """
    Build an instance of ``cls`` from its protobuf message.

    :param flyteidl.admin.execution_pb2.ExecutionSpec p:
    :return: ExecutionSpec
    """
    launch_plan = _identifier.Identifier.from_flyte_idl(p.launch_plan)
    metadata = ExecutionMetadata.from_flyte_idl(p.metadata)

    # ``notifications`` and ``disable_all`` may be unset on the message; an unset field maps to None.
    notifications = None
    if p.HasField("notifications"):
        notifications = NotificationList.from_flyte_idl(p.notifications)

    disable_all = None
    if p.HasField("disable_all"):
        disable_all = p.disable_all

    return cls(
        launch_plan=launch_plan,
        metadata=metadata,
        notifications=notifications,
        disable_all=disable_all,
        labels=_common_models.Labels.from_flyte_idl(p.labels),
        annotations=_common_models.Annotations.from_flyte_idl(p.annotations),
    )
def from_flyte_idl(cls, pb2_object):
    """
    Build an instance of ``cls`` from its protobuf message.

    When ``launchplan_ref`` is set, the result references that launch plan.
    Otherwise it is built from ``sub_workflow_ref``.

    :param flyteidl.core.workflow_pb2.WorkflowNode pb2_object:
    :rtype: WorkflowNode
    """
    if not pb2_object.HasField('launchplan_ref'):
        return cls(sub_workflow_ref=_identifier.Identifier.from_flyte_idl(pb2_object.sub_workflow_ref))
    return cls(launchplan_ref=_identifier.Identifier.from_flyte_idl(pb2_object.launchplan_ref))
with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
with _data_proxy.LocalWorkingDirectoryContext(task_dir):
with _data_proxy.RemoteDataContext():
output_file_dict = dict()
# This sets the logging level for user code and is the only place an sdk setting gets
# used at runtime. Optionally, Propeller can set an internal config setting which
# takes precedence.
log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
_logging.getLogger().setLevel(log_level)
try:
output_file_dict = self.sdk_task.execute(
_common_engine.EngineContext(
execution_id=_identifier.WorkflowExecutionIdentifier(
project=_internal_config.EXECUTION_PROJECT.get(),
domain=_internal_config.EXECUTION_DOMAIN.get(),
name=_internal_config.EXECUTION_NAME.get()
),
execution_date=_datetime.utcnow(),
stats=_get_stats(
# Stats metric path will be:
# registration_project.registration_domain.app.module.task_name.user_stats
# and it will be tagged with execution-level values for project/domain/wf/lp
"{}.{}.{}.user_stats".format(
_internal_config.TASK_PROJECT.get() or _internal_config.PROJECT.get(),
_internal_config.TASK_DOMAIN.get() or _internal_config.DOMAIN.get(),
_internal_config.TASK_NAME.get() or _internal_config.NAME.get()
),
tags={
'exec_project': _internal_config.EXECUTION_PROJECT.get(),