# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_task_runner_raises_endrun_with_correct_state_if_client_cant_receive_state_updates(
    monkeypatch
):
    """When the Cloud client errors while fetching task-run info, the
    resulting ENDRUN makes the runner hand back the most recently
    computed state object unchanged."""
    mock_get_info = MagicMock(side_effect=SyntaxError)
    mock_set_state = MagicMock()
    mock_client = MagicMock(
        get_task_run_info=mock_get_info, set_task_run_state=mock_set_state
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.task_runner.Client", MagicMock(return_value=mock_client)
    )

    ## an ENDRUN will cause the TaskRunner to return the most recently computed state
    initial_state = Pending(message="unique message", result=42)
    runner = CloudTaskRunner(task=Task(name="test"))
    outcome = runner.run(state=initial_state, context={"map_index": 1})

    assert mock_get_info.called
    assert outcome is initial_state
def test_set_task_run_state(patch_post):
    """A successful setTaskRunStates mutation returns the exact state
    object that was submitted."""
    patch_post({"data": {"setTaskRunStates": {"states": [{"status": "SUCCESS"}]}}})

    submitted = Pending()
    with set_temporary_config(
        {"cloud.graphql": "http://my-cloud.foo", "cloud.auth_token": "secret_token"}
    ):
        cloud_client = Client()
        outcome = cloud_client.set_task_run_state(
            task_run_id="76-salt", version=0, state=submitted
        )
    assert outcome is submitted
def test_flow_runner_prioritizes_kwarg_states_over_db_states(monkeypatch, state):
    """A state passed directly to run() takes precedence over whatever
    state the backend reports for the flow run."""
    flow = prefect.Flow(name="test")
    stored_state = state("already", result=10)
    mock_get_info = MagicMock(return_value=MagicMock(state=stored_state))
    mock_set_state = MagicMock()
    mock_client = MagicMock(
        get_flow_run_info=mock_get_info, set_flow_run_state=mock_set_state
    )
    monkeypatch.setattr(
        "prefect.engine.cloud.flow_runner.Client", MagicMock(return_value=mock_client)
    )

    CloudFlowRunner(flow=flow).run(state=Pending("let's do this"))

    ## assertions
    assert mock_get_info.call_count == 1  # one time to pull latest state
    assert mock_set_state.call_count == 2  # Pending -> Running -> Success
    recorded = [call[1]["state"] for call in mock_set_state.call_args_list]
    assert recorded == [Running(), Success(result={})]
def test_with_two_finished(self):
    """check_upstream_finished returns the provided state unchanged when
    every upstream state (one Success, one Failed) is finished."""
    initial = Pending()
    checked = TaskRunner(Task()).check_upstream_finished(
        state=initial, upstream_states={1: Success(), 2: Failed()}
    )
    assert checked is initial
def test_get_empty_inputs(self):
    """With no upstream states, get_task_inputs produces an empty mapping."""
    gathered = TaskRunner(task=Task()).get_task_inputs(
        state=Pending(), upstream_states={}
    )
    assert gathered == {}
cached_inputs=complex_result,
result=res3,
cached_parameters={"x": 1, "y": {"z": 2}},
cached_result_expiration=utc_dt,
)
cached_state_naive = state.Cached(
cached_inputs=complex_result,
result=res3,
cached_parameters={"x": 1, "y": {"z": 2}},
cached_result_expiration=naive_dt,
)
running_tags = state.Running()
running_tags.context = dict(tags=["1", "2", "3"])
test_states = [
state.Looped(loop_count=45),
state.Pending(cached_inputs=complex_result),
state.Paused(cached_inputs=complex_result),
state.Retrying(start_time=utc_dt, run_count=3),
state.Retrying(start_time=naive_dt, run_count=3),
state.Scheduled(start_time=utc_dt),
state.Scheduled(start_time=naive_dt),
state.Resume(start_time=utc_dt),
state.Resume(start_time=naive_dt),
running_tags,
state.Submitted(state=state.Retrying(start_time=utc_dt, run_count=2)),
state.Submitted(state=state.Resume(start_time=utc_dt)),
state.Queued(state=state.Pending()),
state.Queued(state=state.Pending(), start_time=utc_dt),
state.Queued(state=state.Retrying(start_time=utc_dt, run_count=2)),
cached_state,
cached_state_naive,
state.TimedOut(cached_inputs=complex_result),
# ---------------------------------------------
# Collect results
# ---------------------------------------------
# terminal tasks determine if the flow is finished
terminal_tasks = self.flow.terminal_tasks()
# reference tasks determine flow state
reference_tasks = self.flow.reference_tasks()
# wait until all terminal tasks are finished
final_tasks = terminal_tasks.union(reference_tasks).union(return_tasks)
final_states = executor.wait(
{
t: task_states.get(t, Pending("Task not evaluated by FlowRunner."))
for t in final_tasks
}
)
# also wait for any children of Mapped tasks to finish, and add them
# to the dictionary to determine flow state
all_final_states = final_states.copy()
for t, s in list(final_states.items()):
if s.is_mapped():
s.map_states = executor.wait(s.map_states)
s.result = [ms.result for ms in s.map_states]
all_final_states[t] = s.map_states
assert isinstance(final_states, dict)
key_states = set(flatten_seq([all_final_states[t] for t in reference_tasks]))
Returns:
- State: the state of the task after running the check
Raises:
- ENDRUN: if the task is not ready to run
"""
if state.is_cached():
assert isinstance(state, Cached) # mypy assert
sanitized_inputs = {key: res.value for key, res in inputs.items()}
if self.task.cache_validator(
state, sanitized_inputs, prefect.context.get("parameters")
):
state._result = state._result.to_result(self.task.result_handler)
return state
else:
state = Pending("Cache was invalid; ready to run.")
if self.task.cache_for is not None:
candidate_states = prefect.context.caches.get(
self.task.cache_key or self.task.name, []
)
sanitized_inputs = {key: res.value for key, res in inputs.items()}
for candidate in candidate_states:
if self.task.cache_validator(
candidate, sanitized_inputs, prefect.context.get("parameters")
):
candidate._result = candidate._result.to_result(
self.task.result_handler
)
return candidate
if self.task.cache_for is not None:
If the provided state is a meta state, the state it wraps is extracted.
Args:
- state (Optional[State]): the initial state of the run
- context (dict): the context to be updated with relevant information
Returns:
- tuple: a tuple of the updated state and context objects
"""
# extract possibly nested meta states -> for example a Submitted( Queued( Retry ) )
while isinstance(state, State) and state.is_meta_state():
state = state.state # type: ignore
state = state or Pending()
return state, context