Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
}
if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)
if update:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:updateEndpoint'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createEndpoint'
kwargs[Field.Parameters.value] = parameters
super(EndpointStep, self).__init__(state_id, **kwargs)
class TuningStep(Task):
"""
Creates a Task State to execute a SageMaker HyperParameterTuning Job.
"""
def __init__(self, state_id, tuner, job_name, data, wait_for_completion=True, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
tuner (sagemaker.tuner.HyperparameterTuner): The tuner to use in the TuningStep.
job_name (str or Placeholder): Specify a tuning job name. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
data: Information about the training data. Please refer to the ``fit()`` method of the associated estimator in the tuner, as this can take any of the following forms:
* (str) - The S3 location where training data is saved.
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
channels for training data, you can specify a dict mapping channel names to
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import
from stepfunctions.inputs import ExecutionInput, StepInput
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
from stepfunctions.steps.utils import tags_dict_to_kv_list
from sagemaker.workflow.airflow import training_config, transform_config, model_config, tuning_config
from sagemaker.model import Model, FrameworkModel
from sagemaker.model_monitor import DataCaptureConfig
class TrainingStep(Task):
"""
Creates a Task State to execute a `SageMaker Training Job `_. The TrainingStep will also create a model by default, and the model shares the same name as the training job.
"""
def __init__(self, state_id, estimator, job_name, data=None, hyperparameters=None, mini_batch_size=None, experiment_config=None, wait_for_completion=True, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
estimator (sagemaker.estimator.EstimatorBase): The estimator for the training step. Can be a `BYO estimator, Framework estimator `_ or `Amazon built-in algorithm estimator `_.
job_name (str or Placeholder): Specify a training job name, this is required for the training job to run. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
data: Information about the training data. Please refer to the ``fit()`` method of the associated estimator, as this can take any of the following forms:
* (str) - The S3 location where training data is saved.
* (dict[str, str] or dict[str, sagemaker.session.s3_input]) - If using multiple
channels for training data, you can specify a dict mapping channel names to
heartbeat_seconds (int, optional): Positive integer specifying heartbeat timeout for the state in seconds. This value should be lower than the one specified for `timeout_seconds`. If more time than the specified heartbeat elapses between heartbeats from the task, then the interpreter fails the state with a `States.Timeout` Error Name.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
if wait_for_callback:
kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage.waitForTaskToken'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::sqs:sendMessage'
super(SqsSendMessageStep, self).__init__(state_id, **kwargs)
class EmrCreateClusterStep(Task):
"""
Creates a Task state to create and start running a cluster (job flow). See `Call Amazon EMR with Step Functions `_ for more details.
"""
def __init__(self, state_id, wait_for_completion=True, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
"""
if wait_for_completion:
else:
raise ValueError("Expected 'model' parameter to be of type 'sagemaker.model.Model', but received type '{}'".format(type(model).__name__))
if 'S3Operations' in parameters:
del parameters['S3Operations']
if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)
kwargs[Field.Parameters.value] = parameters
kwargs[Field.Resource.value] = 'arn:aws:states:::sagemaker:createModel'
super(ModelStep, self).__init__(state_id, **kwargs)
class EndpointConfigStep(Task):
"""
Creates a Task State to `create an endpoint configuration in SageMaker `_.
"""
def __init__(self, state_id, endpoint_config_name, model_name, initial_instance_count, instance_type, variant_name='AllTraffic', data_capture_config=None, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
endpoint_config_name (str or Placeholder): The name of the endpoint configuration to create. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
model_name (str or Placeholder): The name of the SageMaker model to attach to the endpoint configuration. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
initial_instance_count (int or Placeholder): The initial number of instances to run in the ``Endpoint`` created from this ``Model``.
instance_type (str or Placeholder): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'.
variant_name (str, optional): The name of the production variant.
data_capture_config (sagemaker.model_monitor.DataCaptureConfig, optional): Specifies
configuration related to Endpoint data capture for use with
to the ModelStep to save the trained model in Sagemaker.
Args:
model_name (str, optional): Specify a model name. If not provided, training job name will be used as the model name.
Returns:
sagemaker.model.Model: Sagemaker model representation of the expected trained model.
"""
model = self.estimator.create_model()
if model_name:
model.name = model_name
else:
model.name = self.job_name
model.model_data = self.output()["ModelArtifacts"]["S3ModelArtifacts"]
return model
class TransformStep(Task):
"""
Creates a Task State to execute a `SageMaker Transform Job `_.
"""
def __init__(self, state_id, transformer, job_name, model_name, data, data_type='S3Prefix', content_type=None, compression_type=None, split_type=None, experiment_config=None, wait_for_completion=True, tags=None, input_filter=None, output_filter=None, join_source=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
transformer (sagemaker.transformer.Transformer): The SageMaker transformer to use in the TransformStep.
job_name (str or Placeholder): Specify a transform job name. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
model_name (str or Placeholder): Specify a model name for the transform job to use. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
data (str): Input data location in S3.
data_type (str): What the S3 location defines (default: 'S3Prefix').
Valid values:
if isinstance(job_name, (ExecutionInput, StepInput)):
parameters['TransformJobName'] = job_name
parameters['ModelName'] = model_name
if experiment_config is not None:
parameters['ExperimentConfig'] = experiment_config
if tags:
parameters['Tags'] = tags_dict_to_kv_list(tags)
kwargs[Field.Parameters.value] = parameters
super(TransformStep, self).__init__(state_id, **kwargs)
class ModelStep(Task):
"""
Creates a Task State to `create a model in SageMaker `_.
"""
def __init__(self, state_id, model, model_name=None, instance_type=None, tags=None, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
model (sagemaker.model.Model): The SageMaker model to use in the ModelStep. If :py:class:`TrainingStep` was used to train the model and saving the model is the next step in the workflow, the output of :py:func:`TrainingStep.get_expected_model()` can be passed here.
model_name (str or Placeholder, optional): Specify a model name, this is required for creating the model. We recommend to use :py:class:`~stepfunctions.inputs.ExecutionInput` placeholder collection to pass the value dynamically in each execution.
instance_type (str, optional): The EC2 instance type to deploy this Model to. For example, 'ml.p2.xlarge'. This parameter is typically required when the estimator used is not an `Amazon built-in algorithm `_.
tags (list[dict], optional): `List to tags `_ to associate with the resource.
"""
if isinstance(model, FrameworkModel):
parameters = model_config(model=model, instance_type=instance_type, role=model.role, image=model.image)
def __init__(self, state_id, **kwargs):
    """
    Build a Task state bound to the DynamoDB ``putItem`` service integration.

    The ``Resource`` field of the state is always set by this constructor; any
    value supplied by the caller under that key is overwritten.

    Args:
        state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
        comment (str, optional): Human-readable comment or description. (default: None)
        input_path (str, optional): Path applied to the state's raw input to select some or all of it; that selection is used by the state. (default: '$')
        parameters (dict, optional): The value of this field becomes the effective input for the state.
        result_path (str, optional): Path specifying the raw input's combination with or replacement by the state's result. (default: '$')
        output_path (str, optional): Path applied to the state's output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
    """
    put_item_arn = 'arn:aws:states:::dynamodb:putItem'
    # Pin the resource before delegating so the base Task state sees it
    # together with the rest of the caller's keyword arguments.
    kwargs.update({Field.Resource.value: put_item_arn})
    super(DynamoDBPutItemStep, self).__init__(state_id, **kwargs)
class DynamoDBDeleteItemStep(Task):
"""
Creates a Task state to delete an item from DynamoDB. See `Call DynamoDB APIs with Step Functions `_ for more details.
"""
def __init__(self, state_id, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::dynamodb:deleteItem'
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:createCluster'
super(EmrCreateClusterStep, self).__init__(state_id, **kwargs)
class EmrTerminateClusterStep(Task):
"""
Creates a Task state to shut down a cluster (job flow). See `Call Amazon EMR with Step Functions `_ for more details.
"""
def __init__(self, state_id, wait_for_completion=True, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
"""
if wait_for_completion:
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
wait_for_completion (bool, optional): Boolean value set to `True` if the Task state should wait to complete before proceeding to the next step in the workflow. (default: True)
"""
if wait_for_completion:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep.sync'
else:
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:addStep'
super(EmrAddStepStep, self).__init__(state_id, **kwargs)
class EmrCancelStepStep(Task):
"""
Creates a Task state to cancel a pending step in a running cluster. See `Call Amazon EMR with Step Functions `_ for more details.
"""
def __init__(self, state_id, **kwargs):
"""
Args:
state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
comment (str, optional): Human-readable comment or description. (default: None)
input_path (str, optional): Path applied to the state’s raw input to select some or all of it; that selection is used by the state. (default: '$')
parameters (dict, optional): The value of this field becomes the effective input for the state.
result_path (str, optional): Path specifying the raw input’s combination with or replacement by the state’s result. (default: '$')
output_path (str, optional): Path applied to the state’s output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
"""
kwargs[Field.Resource.value] = 'arn:aws:states:::elasticmapreduce:cancelStep'
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import
from stepfunctions.steps.states import Task
from stepfunctions.steps.fields import Field
class DynamoDBGetItemStep(Task):

    """
    Task state that gets an item from DynamoDB. See the AWS Step Functions developer guide topic "Call DynamoDB APIs with Step Functions" for more details.
    """

    def __init__(self, state_id, **kwargs):
        """
        Build a Task state bound to the DynamoDB ``getItem`` service integration.

        The ``Resource`` field of the state is always set by this constructor;
        any value supplied by the caller under that key is overwritten.

        Args:
            state_id (str): State name whose length **must be** less than or equal to 128 unicode characters. State names **must be** unique within the scope of the whole state machine.
            comment (str, optional): Human-readable comment or description. (default: None)
            input_path (str, optional): Path applied to the state's raw input to select some or all of it; that selection is used by the state. (default: '$')
            parameters (dict, optional): The value of this field becomes the effective input for the state.
            result_path (str, optional): Path specifying the raw input's combination with or replacement by the state's result. (default: '$')
            output_path (str, optional): Path applied to the state's output after the application of `result_path`, producing the effective output which serves as the raw input for the next state. (default: '$')
        """
        get_item_arn = 'arn:aws:states:::dynamodb:getItem'
        # Pin the resource before delegating so the base Task state sees it
        # together with the rest of the caller's keyword arguments.
        kwargs.update({Field.Resource.value: get_item_arn})
        super(DynamoDBGetItemStep, self).__init__(state_id, **kwargs)