def test_hive_task_query_generation():
with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
context = _common_engine.EngineContext(
execution_id=WorkflowExecutionIdentifier(
project='unit_test',
domain='unit_test',
name='unit_test'
),
execution_date=_datetime.utcnow(),
stats=None, # TODO: A mock stats object that we can read later.
logging=_logging, # TODO: A mock logging object that we can read later.
tmp_dir=user_working_directory
)
references = {
name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type))
for name, variable in _six.iteritems(two_queries.interface.outputs)
}
qubole_hive_jobs = two_queries._generate_plugin_objects(context, references)
assert len(qubole_hive_jobs) == 2
# Deprecated: the query collection is only here for backwards compatibility
assert len(qubole_hive_jobs[0].query_collection.queries) == 1
assert len(qubole_hive_jobs[1].query_collection.queries) == 1
# The output references should now have the same fake S3 path as the formatted queries
assert references['hive_results'].value[0].uri != ''
assert references['hive_results'].value[1].uri != ''
assert references['hive_results'].value[0].uri in qubole_hive_jobs[0].query.query
assert references['hive_results'].value[1].uri in qubole_hive_jobs[1].query.query
from flytekit.sdk import types
import six
GOOD_INPUTS = {
'a': types.Types.Integer,
'name': types.Types.String,
}
GOOD_OUTPUTS = {
'x': types.Types.Integer,
}
GOOD_NOTEBOOK = sdk_runnable.RunnableNotebookTask(
notebook_path="notebooks/good.ipynb",
inputs={
k: interface.Variable(
helpers.python_std_to_sdk_type(v).to_flyte_literal_type(),
''
)
for k, v in six.iteritems(GOOD_INPUTS)
},
outputs={
k: interface.Variable(
helpers.python_std_to_sdk_type(v).to_flyte_literal_type(),
''
)
for k, v in six.iteritems(GOOD_OUTPUTS)
},
task_type=constants.SdkTaskType.PYTHON_TASK,
)
def test_good_notebook():
# Create inputs, just inputs. Outputs need to come later.
# interface = get_interface_from_task_info(fn.__annotations__, outputs or [])
# inputs_map = interface.inputs
inputs_map = get_variable_map(inputs)
# Create promises out of all the inputs. Check for defaults in the function definition.
default_inputs = get_default_args(fn)
input_parameters = []
for input_name, input_variable_obj in inputs_map.items():
# _interface_models.Parameter(var=input_variable_obj, default=None, required=required)
# This is a bit annoying. I'd like to work directly with the Parameter model like above, but for now
# it's easier to use the promise.Input wrapper
# This is also annoying... I already have the literal type, but I have to go back to the SDK type (invoking
# the type engine)... in the constructor, it again turns it back to the literal type when creating the
# Parameter model.
sdk_type = _type_helpers.get_sdk_type_from_literal_type(input_variable_obj.type)
logger.debug(f"Converting literal type {input_variable_obj.type} to sdk type {sdk_type}")
arg_map = {'default': default_inputs[input_name]} if input_name in default_inputs else {}
input_parameters.append(_WorkflowInput(name=input_name, type=sdk_type, **arg_map))
# Fill in call args later - for now this only works for workflows with no inputs
workflow_outputs = fn()
# Iterate through the workflow outputs and collect two things
# 1. Get the outputs and use them to construct the old Output objects
# outputs can be like 5, or 'hi'
# or promise.NodeOutputs (let's just focus on this one first for POC)
# or Input objects from above in the case of a passthrough value.
# 2. Iterate through the outputs and collect all the nodes.
workflow_output_objs = []
all_nodes = []
def __init__(self, value, sdk_type=None, help=None):
"""
:param T value:
:param flytekit.common.types.base_sdk_types.FlyteSdkType sdk_type: If specified, the value provided must
match this type exactly. If not provided, the SDK will attempt to infer the type. It is recommended
that this value be provided, as the SDK might not always be able to infer the correct type.
"""
super(Output, self).__init__(
'',
value,
sdk_type=_type_helpers.python_std_to_sdk_type(sdk_type) if sdk_type else None,
help=help
)
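# Hedged usage sketch for the Output wrapper defined above (assumed to be exposed as
# flytekit.sdk.workflow.Output; the values below are illustrative only). With an
# explicit sdk_type the provided value must match that type; without one, the SDK
# attempts to infer the type from the value.
from flytekit.sdk.types import Types
from flytekit.sdk.workflow import Output

explicit_out = Output(42, sdk_type=Types.Integer, help='validated against Integer')
inferred_out = Output('hello')  # type inferred from the value, when inference succeeds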
def inputs(self):
"""
Returns the inputs to the execution in the standard Python format as dictated by the type engine.
:rtype: dict[Text, T]
"""
if self._inputs is None:
self._inputs = _type_helpers.unpack_literal_map_to_sdk_python_std(
_engine_loader.get_engine().get_node_execution(self).get_inputs()
)
return self._inputs
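# Hedged, self-contained sketch of what the accessor above ultimately does with the
# engine's response: a LiteralMap is unpacked back to Python-standard values via the
# type engine. Import paths are assumed to match the aliases used in this file, and
# the input name 'a' is illustrative.
from flytekit.common.types import helpers as _type_helpers
from flytekit.sdk.types import Types

literal_map = _type_helpers.pack_python_std_map_to_literal_map({'a': 3}, {'a': Types.Integer})
assert _type_helpers.unpack_literal_map_to_sdk_python_std(literal_map, {'a': Types.Integer}) == {'a': 3}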
def __init__(self, type_map, node):
"""
:param dict[Text, flytekit.models.interface.Variable] type_map:
:param SdkNode node:
"""
super(ParameterMapper, self).__init__()
for key, var in _six.iteritems(type_map):
self[key] = self._return_mapping_object(node, _type_helpers.get_sdk_type_from_literal_type(var.type), key)
self._initialized = True
def execute(self, context, inputs):
"""
:param flytekit.engines.common.EngineContext context:
:param flytekit.models.literals.LiteralMap inputs:
:rtype: dict[Text,flytekit.models.common.FlyteIdlEntity]
:returns: This function must return a dictionary mapping 'filenames' to Flyte Interface Entities. These
entities will be used by the engine to pass data from node to node, populate metadata, etc. Each
engine will have different behavior. For instance, the Flyte engine will upload the entities to a remote
working directory (with the names provided), which will in turn allow Flyte Propeller to push the
workflow along, whereas the local engine will merely feed the outputs directly into the next node.
"""
inputs_dict = _type_helpers.unpack_literal_map_to_sdk_python_std(inputs, {
k: _type_helpers.get_sdk_type_from_literal_type(v.type) for k, v in _six.iteritems(self.interface.inputs)
})
outputs_dict = {
name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type))
for name, variable in _six.iteritems(self.interface.outputs)
}
inputs_dict.update(outputs_dict)
with GlobalSparkContext():
_exception_scopes.user_entry_point(self.task_function)(
_sdk_runnable.ExecutionParameters(
execution_date=context.execution_date,
execution_id=context.execution_id,
stats=context.stats,
logging=context.logging,
"""
if not _os.path.exists(working_directory):
tmpdir = _utils.AutoDeletingTempDir("nb_made_")
tmpdir.__enter__()
working_directory = tmpdir.name
_data_proxy.LocalWorkingDirectoryContext(working_directory).__enter__()
_data_proxy.RemoteDataContext()
lm_pb2 = _literals_pb2.LiteralMap()
lm_pb2.ParseFromString(input_bytes)
vm_pb2 = _interface_pb2.VariableMap()
vm_pb2.ParseFromString(variable_map_bytes)
# TODO: Inject vargs and wf_params
return _type_helpers.unpack_literal_map_to_sdk_python_std(
_literals.LiteralMap.from_flyte_idl(lm_pb2),
{
k: _type_helpers.get_sdk_type_from_literal_type(v.type)
for k, v in _six.iteritems(_interface.VariableMap.from_flyte_idl(vm_pb2).variables)
}
)
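# Hedged round-trip sketch under the same assumptions: a LiteralMap built with the
# type-engine helpers can be serialized back to the protobuf bytes consumed above.
# The import paths and the input name 'a' are assumptions for illustration.
from flytekit.common.types import helpers as _helpers
from flytekit.sdk.types import Types
from flyteidl.core import literals_pb2

model_lm = _helpers.pack_python_std_map_to_literal_map({'a': 5}, {'a': Types.Integer})
serialized_bytes = model_lm.to_flyte_idl().SerializeToString()
parsed = literals_pb2.LiteralMap()
parsed.ParseFromString(serialized_bytes)  # yields the same shape as lm_pb2 above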
def local_execute(self, **input_map):
"""
:param dict[Text, T] input_map: Python Std input from users. We will cast these to the appropriate Flyte
literals.
:rtype: dict[Text, T]
:returns: The output produced by this task in Python standard format.
"""
return _engine_loader.get_engine('local').get_task(self).execute(
_type_helpers.pack_python_std_map_to_literal_map(input_map, {
k: _type_helpers.get_sdk_type_from_literal_type(v.type)
for k, v in _six.iteritems(self.interface.inputs)
})
)
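# Hedged usage sketch for local_execute (assumes the legacy flytekit.sdk task
# decorators and that local_execute, as defined above, is available on the decorated
# task object; the task and input names are illustrative, not taken from this file).
# Keyword arguments are plain Python values and the result is a dict of
# Python-standard outputs, as described in the docstring above.
from flytekit.sdk.tasks import inputs, outputs, python_task
from flytekit.sdk.types import Types

@inputs(a=Types.Integer)
@outputs(x=Types.Integer)
@python_task
def add_one(wf_params, a, x):
    x.set(a + 1)

results = add_one.local_execute(a=5)  # expected: {'x': 6}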
# The ID should be set in one of three places:
# 1) When the object is registered (in the code above)
# 2) By the dynamic task code after this runnable object has already been __call__'ed. The SdkNode produced
# maintains a link to this object and will set the ID according to the configuration variables present.
# 3) When SdkLaunchPlan.fetch() is run
super(SdkRunnableLaunchPlan, self).__init__(
None,
_launch_plan_models.LaunchPlanMetadata(
schedule=schedule or _schedule_model.Schedule(''),
notifications=notifications or []
),
_interface_models.ParameterMap(default_inputs),
_type_helpers.pack_python_std_map_to_literal_map(
fixed_inputs,
{
k: _type_helpers.get_sdk_type_from_literal_type(var.type)
for k, var in _six.iteritems(sdk_workflow.interface.inputs) if k in fixed_inputs
}
),
labels or _common_models.Labels({}),
annotations or _common_models.Annotations({}),
auth,
)
self._interface = _interface.TypedInterface(
{k: v.var for k, v in _six.iteritems(default_inputs)},
sdk_workflow.interface.outputs
)
self._upstream_entities = {sdk_workflow}
self._sdk_workflow = sdk_workflow