Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
State,
Submitted,
Success,
TimedOut,
TriggerFailed,
_MetaState,
)
from prefect.serialization.result_handlers import ResultHandlerSchema
from prefect.serialization.state import StateSchema
# Every State subclass found in prefect.engine.state (excluding _MetaState),
# sorted by class name so the parametrized tests below see a deterministic order.
all_states = sorted(
    {
        cls
        for cls in prefect.engine.state.__dict__.values()
        if isinstance(cls, type)
        and issubclass(cls, prefect.engine.state.State)
        and cls is not _MetaState
    },
    key=lambda c: c.__name__,
)

# The subset of all_states whose default-constructed instances expose a
# `cached_inputs` attribute. all_states is already de-duplicated and sorted by
# name, so a plain filter keeps that order without another set()/sorted() pass.
cached_input_states = [cls for cls in all_states if hasattr(cls(), "cached_inputs")]
@pytest.mark.parametrize("cls", all_states)
def test_create_state_with_no_args(cls):
    """Each state class can be built without arguments and starts with empty defaults."""
    new_state = cls()
    assert new_state.message is None
    assert new_state.result == NoResult
def test_upstream_skipped_states_are_handled_properly(executor):
    """A task mapped over a forced-Skipped upstream is skipped while the flow succeeds."""

    @task
    def skip_task():
        pass

    @task
    def add(x):
        return x + 1

    with Flow(name="test") as flow:
        mapped = add.map(skip_task)

    forced_states = {skip_task: prefect.engine.state.Skipped()}
    flow_state = flow.run(executor=executor, task_states=forced_states)
    mapped_state = flow_state.result[mapped]

    assert flow_state.is_successful()
    assert mapped_state.is_skipped()
def test_deserialize_retrying(self, version_0_4_0):
    """The serialized `retrying` state from the 0.4.0 fixture loads as a Retrying state."""
    payload = version_0_4_0["states"]["retrying"]
    loaded = s.state.StateSchema().load(payload)
    assert isinstance(loaded, prefect.engine.state.Retrying)
# NOTE(review): orphaned class-level field with no enclosing class visible here;
# it is identical to `RetryingSchema.run_count` elsewhere in this file — confirm
# which class it belongs to.
run_count = fields.Int(allow_none=True)
class RunningSchema(BaseStateSchema):
    """Schema bound to the `state.Running` state class."""

    class Meta:
        object_class = state.Running
class FinishedSchema(BaseStateSchema):
    """Schema bound to the `state.Finished` state class."""

    class Meta:
        object_class = state.Finished
class SuccessSchema(FinishedSchema):
    """Schema bound to the `state.Success` state class."""

    class Meta:
        object_class = state.Success
class CachedSchema(SuccessSchema):
    """
    Schema bound to the `state.Cached` state class.

    Adds the cache bookkeeping fields on top of `SuccessSchema`: cached inputs,
    cached parameters and the cache expiration timestamp.
    """

    class Meta:
        object_class = state.Cached

    # cache bookkeeping; every field accepts None
    cached_inputs = ResultSerializerField(allow_none=True)
    cached_parameters = JSONCompatible(allow_none=True)
    cached_result_expiration = fields.DateTime(allow_none=True)
class MappedSchema(SuccessSchema):
    """Schema bound to `state.Mapped`; `result` and `map_states` are excluded."""

    class Meta:
        exclude = ["result", "map_states"]
        object_class = state.Mapped

    @post_load
    def create_object(self, data: dict, **kwargs: Any) -> "prefect.engine.state.State":
        """
        Build a `Mapped` state from deserialized data.

        The `@post_load` decorator must be re-applied: marshmallow resolves hooks
        from the most-derived definition of a method, so an undecorated override
        silently unregisters the parent's post-load hook.

        Any serialized `context` is discarded. The `_result` entry is left in
        `data` for the parent implementation, which stores it (or
        `result.NoResult` when absent) under `result`. Popping it here first
        would only be overwritten with `NoResult` by the parent.

        Args:
            - data (dict): the deserialized field values
            - **kwargs (Any): extra arguments passed by marshmallow (unused)

        Returns:
            - the state object produced by the parent `create_object`
        """
        data.pop("context", None)
        return super().create_object(data)
class ClientFailedSchema(MetaStateSchema):
    """Meta-state schema bound to the `state.ClientFailed` state class."""

    class Meta:
        object_class = state.ClientFailed
class SubmittedSchema(MetaStateSchema):
    """Meta-state schema bound to the `state.Submitted` state class."""

    class Meta:
        object_class = state.Submitted
class QueuedSchema(MetaStateSchema):
    """Meta-state schema bound to `state.Queued`, with a nullable `start_time`."""

    class Meta:
        object_class = state.Queued

    start_time = fields.DateTime(allow_none=True)
class ScheduledSchema(PendingSchema):
    """Schema bound to `state.Scheduled`, with a nullable `start_time`."""

    class Meta:
        object_class = state.Scheduled

    start_time = fields.DateTime(allow_none=True)
"""
data = self._graphql(
"""
query ($taskRunId: String!){
task_runs(filter: {id: $taskRunId}) {
state {
state
result
}
}
}
""",
taskRunId=task_run_id,
)
# NOTE(review): assumes the query matched at least one task run; if
# `task_runs` comes back empty, `[0]` raises IndexError — confirm intent.
# `.get("state", {})` only covers a missing key; an explicit null state would
# presumably make the `state.get(...)` calls below fail.
state = data.task_runs[0].get("state", {})
# Build a state object from the raw `state` name and `result` returned by the query.
# NOTE(review): `prefect.engine.state.TaskState` is not among the state classes
# referenced elsewhere in this file — confirm it exists.
return prefect.engine.state.TaskState(
state=state.get("state", None), result=state.get("result", None)
)
class FailedSchema(FinishedSchema):
    """Schema bound to the `state.Failed` state class."""

    class Meta:
        object_class = state.Failed
class TimedOutSchema(FinishedSchema):
    """Schema bound to `state.TimedOut`, with a nullable `cached_inputs` field."""

    class Meta:
        object_class = state.TimedOut

    cached_inputs = ResultSerializerField(allow_none=True)
class TriggerFailedSchema(FailedSchema):
    """Schema bound to the `state.TriggerFailed` state class."""

    class Meta:
        object_class = state.TriggerFailed
class SkippedSchema(SuccessSchema):
    """Schema bound to `state.Skipped`; the `cached` field is listed in `exclude_fields`."""

    class Meta:
        object_class = state.Skipped
        exclude_fields = ["cached"]
class PausedSchema(PendingSchema):
    """Schema bound to the `state.Paused` state class."""

    class Meta:
        object_class = state.Paused
class StateSchema(OneOfSchema):
    """
    Schema that chooses between several nested state schemas (via `OneOfSchema`).
    """

    # NOTE(review): the rest of this class body (e.g. the OneOfSchema
    # type-to-schema mapping) is not present in this fragment; restore it from
    # the original module.


def flatten_seq(seq):
    """
    Generator that flattens a possibly nested sequence (list-of-lists or any
    other iterable nesting).

    Strings, bytes and `prefect.engine.state.State` objects are yielded as
    single items rather than being iterated into.

    Example:
    ```python
    list(flatten_seq([1, 2, [3, 4], 5, [6, [7]]]))
    >>> [1, 2, 3, 4, 5, 6, 7]
    ```

    Args:
        - seq (Iterable): the sequence to flatten

    Returns:
        - generator: a generator that yields the flattened sequence
    """
    # `collections.Iterable` was removed in Python 3.10; the ABC lives in
    # `collections.abc`.
    import collections.abc

    for item in seq:
        if isinstance(item, collections.abc.Iterable) and not isinstance(
            item, (str, bytes, prefect.engine.state.State)
        ):
            yield from flatten_seq(item)
        else:
            yield item
class ResumeSchema(ScheduledSchema):
    """Schema bound to the `state.Resume` state class."""

    class Meta:
        object_class = state.Resume
class RetryingSchema(ScheduledSchema):
    """Schema bound to `state.Retrying`, with a nullable integer `run_count`."""

    class Meta:
        object_class = state.Retrying

    run_count = fields.Int(allow_none=True)
class RunningSchema(BaseStateSchema):
    """Schema bound to the `state.Running` state class."""

    class Meta:
        object_class = state.Running
class FinishedSchema(BaseStateSchema):
    """Schema bound to the `state.Finished` state class."""

    class Meta:
        object_class = state.Finished
class SuccessSchema(FinishedSchema):
    """Schema bound to the `state.Success` state class."""

    class Meta:
        object_class = state.Success
class CachedSchema(SuccessSchema):
    """
    Schema bound to the `state.Cached` state class.

    Carries the cache bookkeeping fields (cached inputs, cached parameters and
    cache expiration). This definition shadows any earlier `CachedSchema` in the
    module, so it must declare the same fields or they are lost.
    """

    class Meta:
        object_class = state.Cached

    # cache bookkeeping; every field accepts None
    cached_inputs = ResultSerializerField(allow_none=True)
    cached_parameters = JSONCompatible(allow_none=True)
    cached_result_expiration = fields.DateTime(allow_none=True)
"""
Helper function for ensuring only safe values are serialized.
Note that it is up to the user to actively store a Result's value in a
safe way prior to serialization (if they want the result to be available post-serialization).
"""
# When the attribute being serialized is `_result`, read the safe value
# directly from the object's own `_result`.
if context.get("attr") == "_result":
return obj._result.safe_value # type: ignore
# Otherwise use the value supplied in `context` (defaulting to
# `result.NoResult`); a None value is passed through unchanged.
value = context.get("value", result.NoResult)
if value is None:
return value
return value.safe_value
class BaseStateSchema(ObjectSchema):
    """
    Base schema shared by the serialized `state.State` subclasses.

    Serializes the state's `context`, `message` and `_result`. The result is
    routed through `get_safe` so that only its safe value is written.
    """

    class Meta:
        object_class = state.State

    # marshmallow's Dict field is configured with `keys=` / `values=`. The
    # previous `key=` keyword is not a Dict parameter: marshmallow treats it as
    # field metadata (or rejects it in newer releases), so keys were never validated.
    context = fields.Dict(keys=fields.Str(), values=JSONCompatible(), allow_none=True)
    message = fields.String(allow_none=True)
    _result = Nested(StateResultSchema, allow_none=False, value_selection_fn=get_safe)

    @post_load
    def create_object(self, data: dict, **kwargs: Any) -> state.State:
        """
        Turn loaded field values into a state object.

        The serialized `_result` entry (or `result.NoResult` when it is absent)
        is moved to the `result` key before delegating to the parent
        `create_object`.

        Args:
            - data (dict): the deserialized field values
            - **kwargs (Any): extra arguments passed by marshmallow (unused)

        Returns:
            - state.State: the object built by the parent `create_object`
        """
        data["result"] = data.pop("_result", result.NoResult)
        return super().create_object(data)
class PendingSchema(BaseStateSchema):
class Meta:
object_class = state.Pending