def test_hive_task_query_generation():
    with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
        context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project='unit_test',
                domain='unit_test',
                name='unit_test'
            ),
            execution_date=_datetime.utcnow(),
            stats=None,  # TODO: A mock stats object that we can read later.
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=user_working_directory
        )
        references = {
            name: _task_output.OutputReference(_type_helpers.get_sdk_type_from_literal_type(variable.type))
            for name, variable in _six.iteritems(two_queries.interface.outputs)
        }

def test_hive_task_dynamic_job_spec_generation():
    with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
        context = _common_engine.EngineContext(
            execution_id=WorkflowExecutionIdentifier(
                project='unit_test',
                domain='unit_test',
                name='unit_test'
            ),
            execution_date=_datetime.utcnow(),
            stats=None,  # TODO: A mock stats object that we can read later.
            logging=_logging,  # TODO: A mock logging object that we can read later.
            tmp_dir=user_working_directory
        )
        dj_spec = two_queries._produce_dynamic_job_spec(context, _literals.LiteralMap(literals={}))

        # Bindings
        assert len(dj_spec.outputs[0].binding.collection.bindings) == 2
        assert isinstance(dj_spec.outputs[0].binding.collection.bindings[0].scalar.schema, Schema)

def test_single_step_entrypoint_in_proc():
    with _TemporaryConfiguration(os.path.join(os.path.dirname(__file__), 'fake.config'),
                                 internal_overrides={
                                     'project': 'test',
                                     'domain': 'development'
                                 }):
        with _utils.AutoDeletingTempDir("in") as input_dir:
            literal_map = _type_helpers.pack_python_std_map_to_literal_map(
                {'a': 9}, _type_map_from_variable_map(_task_defs.add_one.interface.inputs))
            input_file = os.path.join(input_dir.name, "inputs.pb")
            _utils.write_proto_to_file(literal_map.to_flyte_idl(), input_file)

            with _utils.AutoDeletingTempDir("out") as output_dir:
                _execute_task(
                    _task_defs.add_one.task_module,
                    _task_defs.add_one.task_function_name,
                    input_file,
                    output_dir.name,
                    False
                )

                p = _utils.load_proto_from_file(
                    _literals_pb2.LiteralMap,

def test_create_at_known_location():
    with _test_utils.LocalTestFileSystem():
        with _utils.AutoDeletingTempDir('test') as wd:
            b = _schema_impl.Schema.create_at_known_location(wd.name, schema_type=_schema_impl.SchemaType())
            assert b.local_path is None
            assert b.remote_location == wd.name + "/"
            assert b.mode == 'wb'

            with b as w:
                w.write(_pd.DataFrame.from_dict({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]}))

            df = _pd.read_parquet(_os.path.join(wd.name, "000000"))
            assert list(df['a']) == [1, 2, 3, 4]
            assert list(df['b']) == [5, 6, 7, 8]

def _execute_user_code(self, inputs):
    """
    :param flytekit.models.literals.LiteralMap inputs:
    :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
    """
    with _common_utils.AutoDeletingTempDir("user_dir") as user_working_directory:
        return self.sdk_task.execute(
            _common_engine.EngineContext(
                execution_id=WorkflowExecutionIdentifier(
                    project='unit_test',
                    domain='unit_test',
                    name='unit_test'
                ),
                execution_date=_datetime.utcnow(),
                stats=MockStats(),
                logging=_logging,  # TODO: A mock logging object that we can read later.
                tmp_dir=user_working_directory
            ),
            inputs
        )

def execute(self, inputs, context=None):
    """
    Just execute the task and write the outputs to where they belong
    :param flytekit.models.literals.LiteralMap inputs:
    :param dict[Text, Text] context:
    :rtype: dict[Text, flytekit.models.common.FlyteIdlEntity]
    """
    with _common_utils.AutoDeletingTempDir("engine_dir") as temp_dir:
        with _common_utils.AutoDeletingTempDir("task_dir") as task_dir:
            with _data_proxy.LocalWorkingDirectoryContext(task_dir):
                with _data_proxy.RemoteDataContext():
                    output_file_dict = dict()

                    # This sets the logging level for user code and is the only place an sdk setting gets
                    # used at runtime. Optionally, Propeller can set an internal config setting which
                    # takes precedence.
                    log_level = _internal_config.LOGGING_LEVEL.get() or _sdk_config.LOGGING_LEVEL.get()
                    _logging.getLogger().setLevel(log_level)

                    try:
                        output_file_dict = self.sdk_task.execute(
                            _common_engine.EngineContext(
                                execution_id=_identifier.WorkflowExecutionIdentifier(
                                    project=_internal_config.EXECUTION_PROJECT.get(),
                                    domain=_internal_config.EXECUTION_DOMAIN.get(),

def inject_inputs(variable_map_bytes, input_bytes, working_directory):
    """
    This method forwards necessary context into the notebook kernel. Ideally, this code shouldn't be duplicating what
    is in the underlying engine, but for now...
    :param bytes variable_map_bytes:
    :param bytes input_bytes:
    :param Text working_directory:
    :rtype: dict[Text, Any]
    """
    if not _os.path.exists(working_directory):
        tmpdir = _utils.AutoDeletingTempDir("nb_made_")
        tmpdir.__enter__()
        working_directory = tmpdir.name

    _data_proxy.LocalWorkingDirectoryContext(working_directory).__enter__()
    _data_proxy.RemoteDataContext()

    lm_pb2 = _literals_pb2.LiteralMap()
    lm_pb2.ParseFromString(input_bytes)

    vm_pb2 = _interface_pb2.VariableMap()
    vm_pb2.ParseFromString(variable_map_bytes)

    # TODO: Inject vargs and wf_params
    return _type_helpers.unpack_literal_map_to_sdk_python_std(
        _literals.LiteralMap.from_flyte_idl(lm_pb2),
        {
            k: _type_helpers.get_sdk_type_from_literal_type(v.type)

def get_inputs(self):
    """
    :rtype: flytekit.models.literals.LiteralMap
    """
    client = _FlyteClientManager(_platform_config.URL.get(), insecure=_platform_config.INSECURE.get()).client
    url_blob = client.get_execution_data(self.sdk_workflow_execution.id)

    if url_blob.inputs.bytes > 0:
        with _common_utils.AutoDeletingTempDir() as t:
            tmp_name = _os.path.join(t.name, "inputs.pb")
            _data_proxy.Data.get_data(url_blob.inputs.url, tmp_name)
            return _literals.LiteralMap.from_flyte_idl(
                _common_utils.load_proto_from_file(_literals_pb2.LiteralMap, tmp_name)
            )

    return _literals.LiteralMap({})
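
Each snippet above relies on the same pattern: AutoDeletingTempDir is a context manager that creates a scratch directory on entry, exposes its path through .name, and removes the directory and anything written into it on exit. Below is a minimal, self-contained sketch of that pattern; the flytekit.common.utils import path is assumed from the 0.x-style module names used in these examples.

import os as _os

from flytekit.common import utils as _common_utils  # assumed import path (flytekit 0.x layout)


with _common_utils.AutoDeletingTempDir("example") as scratch:
    # Files written under scratch.name exist only for the lifetime of this block.
    out_path = _os.path.join(scratch.name, "outputs.pb")
    with open(out_path, "wb") as f:
        f.write(b"")
# On exit, the temporary directory and its contents have been deleted.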