Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_task_runner_skips_upstream_check_for_parent_mapped_task_but_not_children(
    executor,
):
    """
    The parent of a mapped task skips its upstream trigger check; its mapped
    children do not.

    `add` has an ``all_failed`` trigger while every upstream state is Success, so
    the parent still returns a Mapped state but every child ends in TriggerFailed.
    """
    add = AddTask(trigger=prefect.triggers.all_failed)
    ex = Edge(SuccessTask(), add, key="x")
    ey = Edge(ListTask(), add, key="y", mapped=True)
    runner = TaskRunner(add)
    with executor.start():
        res = runner.run(
            upstream_states={ex: Success(result=1), ey: Success(result=[1, 2, 3])},
            executor=executor,
        )
        # Check the parent's type before touching `map_states`, so a non-Mapped
        # result fails with a clear assertion instead of an AttributeError.
        assert isinstance(res, Mapped)
        res.map_states = executor.wait(res.map_states)
    # One child per element of the mapped upstream result [1, 2, 3]; without this
    # check an empty `map_states` would let the `all(...)` below pass vacuously.
    assert len(res.map_states) == 3
    assert all(isinstance(s, TriggerFailed) for s in res.map_states)
def test_get_inputs_from_upstream_with_non_key_edges(self):
    """Only keyed edges contribute task inputs; an edge without a key is ignored."""
    runner = TaskRunner(task=Task())
    pending = Pending()
    upstream = {
        Edge(1, 2, key="x"): Success(result=1),
        Edge(1, 2): Success(result=2),  # no key -> must not appear in inputs
    }
    inputs = runner.get_task_inputs(state=pending, upstream_states=upstream)
    assert inputs == {"x": Result(1)}
def test_get_inputs_from_upstream_reads_secret_results(self):
    """
    An upstream result is read back through the upstream task's result handler;
    here that handler resolves the secret "foo" from the context.
    """
    handler = SecretResultHandler(prefect.tasks.secrets.Secret(name="foo"))
    stored = SafeResult("1", result_handler=JSONResultHandler())
    upstream_state = Success(result=stored)
    with prefect.context(secrets={"foo": 42}):
        runner = TaskRunner(task=Task())
        pending = Pending()
        secret_edge = Edge(Task(result_handler=handler), 2, key="x")
        inputs = runner.get_task_inputs(
            state=pending, upstream_states={secret_edge: upstream_state}
        )
    # The expected input carries the resolved secret value and keeps the original
    # SafeResult as its safe_value.
    expected = Result(value=42, result_handler=handler)
    expected.safe_value = stored
    assert inputs == {"x": expected}
def test_viz_reflects_multiple_mapping_if_flow_state_provided(self):
    """
    Visualize a flow with two chained mapped tasks, passing a flow_state that
    holds a Mapped state for each of them.
    """
    ipython = MagicMock(
        get_ipython=lambda: MagicMock(config=dict(IPKernelApp=True))
    )
    add = AddTask(name="a_nice_task")
    list_task = Task(name="a_list_task")
    map_state1 = Mapped(map_states=[Success(), TriggerFailed()])
    map_state2 = Mapped(map_states=[Success(), Failed()])
    # Inject a fake IPython module whose get_ipython() reports an IPKernelApp
    # config (presumably so visualize() behaves as if inside a notebook kernel).
    with patch.dict("sys.modules", IPython=ipython):
        with Flow(name="test") as f:
            first_res = add.map(x=list_task, y=8)
            with pytest.warns(UserWarning):  # making a copy of a task with dependencies
                res = first_res.map(x=first_res, y=9)
        graph = f.visualize(
            flow_state=Success(
                result={
                    res: map_state1,
                    list_task: Success(),
                    first_res: map_state2,
                }
            )
        )
    # NOTE(review): `graph` is not asserted on in the visible code, so this test
    # currently only checks that visualize() runs; the assertions on
    # `graph.source` appear to have been lost.
def test_viz_reflects_mapping_if_flow_state_provided(self):
    """
    A mapped task is drawn as one node per mapped child, colored by that child's
    state (green for Success, red for Failed per the colors asserted below), with
    one dashed edge per input and child.
    """
    ipython = MagicMock(
        get_ipython=lambda: MagicMock(config=dict(IPKernelApp=True))
    )
    add = AddTask(name="a_nice_task")
    list_task = Task(name="a_list_task")
    map_state = Mapped(map_states=[Success(), Failed()])
    # Inject a fake IPython module whose get_ipython() reports an IPKernelApp
    # config (presumably so visualize() behaves as if inside a notebook kernel).
    with patch.dict("sys.modules", IPython=ipython):
        # Explicit name, matching the other Flow(...) constructions in these tests.
        with Flow(name="test") as f:
            res = add.map(x=list_task, y=8)
        graph = f.visualize(
            flow_state=Success(result={res: map_state, list_task: Success()})
        )
    # one colored node for each mapped result
    assert 'label="a_nice_task <map>" color="#00800080"' in graph.source
    assert 'label="a_nice_task <map>" color="#FF000080"' in graph.source
    assert 'label=a_list_task color="#00800080"' in graph.source
    assert 'label=8 color="#00000080"' in graph.source
    # two edges for each input to add()
    for var in ["x", "y"]:
        for index in [0, 1]:
            assert "{0} [label={1} style=dashed]".format(index, var) in graph.source
def test_flow_runner_makes_copy_of_task_results_dict():
    """
    Ensure the flow runner copies the ``task_states`` dict passed to ``flow.run``
    instead of mutating the caller's dict in place.
    """
    flow = Flow(name="test")
    upstream, downstream = Task(), Task()
    flow.add_edge(upstream, downstream)
    initial_states = {upstream: Pending()}

    final_state = flow.run(task_states=initial_states)

    assert final_state.result[upstream] == Success(result=None)
    # The caller's dict still holds only the original Pending entry.
    assert initial_states == {upstream: Pending()}
def test_all_failed_fail(self):
    """An ``all_failed`` trigger raises ENDRUN(TriggerFailed) when an upstream succeeded."""
    guarded = Task(trigger=prefect.triggers.all_failed)
    pending = Pending()
    mixed_upstream = {1: Success(), 2: Failed()}
    with pytest.raises(ENDRUN) as exc:
        TaskRunner(guarded).check_task_trigger(
            state=pending, upstream_states=mixed_upstream
        )
    end_state = exc.value.state
    assert isinstance(end_state, TriggerFailed)
    assert 'Trigger was "all_failed"' in str(end_state)
def test_task_runner_handles_outputs_prior_to_setting_state(self, client):
    """
    CloudTaskRunner processes upstream results with their result handler before
    reporting task states to the (mocked) client.
    """
    @prefect.task(
        cache_for=datetime.timedelta(days=1), result_handler=JSONResultHandler()
    )
    def add(x, y):
        return x + y

    result = Result(1, result_handler=JSONResultHandler())
    assert result.safe_value is NoResult
    x_state, y_state = Success(result=result), Success(result=result)
    upstream_states = {
        Edge(Task(), Task(), key="x"): x_state,
        Edge(Task(), Task(), key="y"): y_state,
    }
    CloudTaskRunner(task=add).run(upstream_states=upstream_states)
    # Identity check, consistent with the `is NoResult` precondition above:
    # a safe value was stored on the shared upstream Result, so it was handled.
    assert result.safe_value is not NoResult

    # assertions on client interaction
    assert client.get_task_run_info.call_count == 0  # never called
    # Three updates: Pending -> Running, Running -> Successful, Successful -> Cached
    assert client.set_task_run_state.call_count == 3
    states = [call[1]["state"] for call in client.set_task_run_state.call_args_list]
    # NOTE(review): `states` is collected but never asserted on in the visible
    # code; presumably the intended checks are Running, Successful, Cached in order.
# NOTE(review): fragment of a larger FlowRunner method. `task`, `task_state`,
# `task_states`, `task_contexts`, `task_runner_state_handlers` and `executor`
# are bound outside this view (presumably by an enclosing loop over the flow's
# tasks -- confirm).
# Build the map of every input edge of `task` to the state feeding it.
upstream_states = {} # type: Dict[Edge, Union[State, Iterable]]
# -- process each edge to the task
# Each incoming edge maps to the upstream task's recorded state. If the upstream
# task has no entry in task_states, a Pending placeholder is used instead.
for edge in self.flow.edges_to(task):
upstream_states[edge] = task_states.get(
edge.upstream_task, Pending(message="Task state not available.")
)
# augment edges with upstream constants
# Each constant input registered for this task becomes a synthetic edge from a
# Constant task, marked Success with a Result backed by ConstantResultHandler.
# run_task therefore receives constants in the same upstream_states map as
# real upstream results.
for key, val in self.flow.constants[task].items():
edge = Edge(
upstream_task=prefect.tasks.core.constants.Constant(val),
downstream_task=task,
key=key,
)
upstream_states[edge] = Success(
"Auto-generated constant value",
result=Result(val, result_handler=ConstantResultHandler(val)),
)
# -- run the task
# Submit run_task to the executor inside a prefect.context carrying the task's
# name and tags. The explicit `context` argument is a copy of prefect.context
# with this task's entries from task_contexts applied last, so they override.
# task_states[task] stores whatever executor.submit returns (presumably a
# future or state handle, depending on the executor -- confirm).
with prefect.context(task_full_name=task.name, task_tags=task.tags):
task_states[task] = executor.submit(
self.run_task,
task=task,
state=task_state,
upstream_states=upstream_states,
context=dict(prefect.context, **task_contexts.get(task, {})),
task_runner_state_handlers=task_runner_state_handlers,
executor=executor,
)
# NOTE(review): fragment of a larger if/elif chain that picks the final flow
# state. The opening `if`, and the definitions of `key_states` (presumably the
# reference tasks' states) and `return_states`, are outside this view.
# check if any reference (key) task failed
elif any(s.is_failed() for s in key_states):
self.logger.info("Flow run FAILED: some reference tasks failed.")
state = Failed(message="Some reference tasks failed.", result=return_states)
# check if all reference tasks succeeded
# NOTE: all() is True for an empty key_states; whether that case can reach this
# branch depends on the earlier branches, which are not visible here.
elif all(s.is_successful() for s in key_states):
self.logger.info("Flow run SUCCESS: all reference tasks succeeded")
state = Success(
message="All reference tasks succeeded.", result=return_states
)
# Remaining case: no reference task failed, but at least one reference state is
# neither failed nor successful (an unanticipated state). It is still reported
# as Success.
else:
self.logger.info("Flow run SUCCESS: no reference tasks failed")
state = Success(message="No reference tasks failed.", result=return_states)
return state