def test_name_from_base_short(sagemaker_short_timestamp):
    sagemaker.utils.name_from_base(NAME, short=True)
    # assert_called_once() raises on failure; the bare `.called_once`
    # attribute is auto-created by Mock and is always truthy.
    sagemaker_short_timestamp.assert_called_once()


def test_name_from_base(sagemaker_timestamp):
    sagemaker.utils.name_from_base(NAME, short=False)
    sagemaker_timestamp.assert_called_once()
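# For context, a minimal sketch of the pytest fixtures these tests assume.
# The fixture bodies here are hypothetical (the project's conftest.py may
# differ): each one patches the timestamp helper that name_from_base calls,
# so the tests can assert it was invoked.
import pytest
from unittest.mock import patch

import sagemaker.utils

NAME = "base_name"


@pytest.fixture
def sagemaker_timestamp():
    with patch("sagemaker.utils.sagemaker_timestamp") as ts:
        ts.return_value = "2017-10-10-14-14-15-139"
        yield ts


@pytest.fixture
def sagemaker_short_timestamp():
    with patch("sagemaker.utils.sagemaker_short_timestamp") as ts:
        ts.return_value = "171010-1414"
        yield ts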
            job_name (str): Name of the training job to be created. If not
                specified, one is generated using the base name given to the
                constructor, if applicable.
        """
        if job_name is not None:
            self._current_job_name = job_name
        else:
            # honor supplied base_job_name or generate it
            if self.base_job_name:
                base_name = self.base_job_name
            elif isinstance(self, sagemaker.algorithm.AlgorithmEstimator):
                base_name = self.algorithm_arn.split("/")[-1]  # pylint: disable=no-member
            else:
                base_name = base_name_from_image(self.train_image())

            self._current_job_name = name_from_base(base_name)

        # If output_path was specified, use it; otherwise initialize it here.
        # For Local Mode with local_code=True we don't need an explicit output_path.
        if self.output_path is None:
            local_code = get_config_value("local.local_code", self.sagemaker_session.config)
            if self.sagemaker_session.local_mode and local_code:
                self.output_path = ""
            else:
                self.output_path = "s3://{}/".format(self.sagemaker_session.default_bucket())

        # Prepare rules and debugger configs for training.
        if self.rules and self.debugger_hook_config is None:
            self.debugger_hook_config = DebuggerHookConfig(s3_output_path=self.output_path)
        # If a DebuggerHookConfig was provided without an S3 output path, default it for the customer.
        if self.debugger_hook_config and not self.debugger_hook_config.s3_output_path:
            self.debugger_hook_config.s3_output_path = self.output_path
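# A worked illustration (image URI and timestamp assumed) of the fallback
# chain above: base_name_from_image extracts the repository name from the
# ECR image URI, and name_from_base appends a timestamp to make the job
# name unique.
from sagemaker.utils import base_name_from_image, name_from_base

base = base_name_from_image("123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest")
print(base)                  # my-algo
print(name_from_base(base))  # e.g. my-algo-2017-10-10-14-14-15-139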
    Returns:
        dict: The container information of this framework model.
    """
    deploy_image = model.image
    if not deploy_image:
        region_name = model.sagemaker_session.boto_session.region_name
        deploy_image = fw_utils.create_image_uri(
            region_name,
            model.__framework_name__,
            instance_type,
            model.framework_version,
            model.py_version,
        )

    base_name = utils.base_name_from_image(deploy_image)
    model.name = model.name or utils.name_from_base(base_name)

    bucket = model.bucket or model.sagemaker_session._default_bucket
    script = os.path.basename(model.entry_point)
    key = "{}/source/sourcedir.tar.gz".format(model.name)

    if model.source_dir and model.source_dir.lower().startswith("s3://"):
        code_dir = model.source_dir
        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
    else:
        code_dir = "s3://{}/{}".format(bucket, key)
        model.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
        s3_operations["S3Upload"] = [
            {"Path": model.source_dir or script, "Bucket": bucket, "Key": key, "Tar": True}
        ]

    deploy_env = dict(model.env)
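# Worked example (bucket and model name assumed) of the code-upload location
# built above: sourcedir.tar.gz lands under the model-name prefix in the
# session's default bucket.
bucket = "my-bucket"
model_name = "my-algo-2017-10-10-14-14-15-139"
key = "{}/source/sourcedir.tar.gz".format(model_name)
print("s3://{}/{}".format(bucket, key))
# s3://my-bucket/my-algo-2017-10-10-14-14-15-139/source/sourcedir.tar.gz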
    def _prepare_job_name_for_tuning(self, job_name=None):
        """Set the current job name before starting tuning."""
        if job_name is not None:
            self._current_job_name = job_name
        else:
            base_name = self.base_tuning_job_name
            if base_name is None:
                estimator = (
                    self.estimator or self.estimator_dict[sorted(self.estimator_dict.keys())[0]]
                )
                base_name = base_name_from_image(estimator.train_image())
            self._current_job_name = name_from_base(
                base_name, max_length=self.TUNING_JOB_NAME_MAX_LENGTH, short=True
            )
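# Sketch of the short, length-capped tuning job name produced above (base
# name assumed; TUNING_JOB_NAME_MAX_LENGTH is 32 in the SDK). short=True
# makes name_from_base use the compact timestamp so more of the base name
# survives the length cap.
from sagemaker.utils import name_from_base

print(name_from_base("xgboost-tuner", max_length=32, short=True))
# e.g. xgboost-tuner-171010-1414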
                specified, one is generated using the base name given to the
                constructor, if applicable.

        Returns:
            str: The supplied or generated job name.
        """
        if job_name is not None:
            return job_name

        if self.base_job_name:
            base_name = self.base_job_name
        else:
            base_name = _SUGGESTION_JOB_BASE_NAME

        return name_from_base(base=base_name)
"""
if isinstance(estimator, sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase):
estimator.prepare_workflow_for_training(
records=inputs, mini_batch_size=mini_batch_size, job_name=job_name
)
else:
estimator.prepare_workflow_for_training(job_name=job_name)
default_bucket = estimator.sagemaker_session.default_bucket()
s3_operations = {}
if job_name is not None:
estimator._current_job_name = job_name
else:
base_name = estimator.base_job_name or utils.base_name_from_image(estimator.train_image())
estimator._current_job_name = utils.name_from_base(base_name)
if estimator.output_path is None:
estimator.output_path = "s3://{}/".format(default_bucket)
if isinstance(estimator, sagemaker.estimator.Framework):
prepare_framework(estimator, s3_operations)
elif isinstance(estimator, amazon_estimator.AmazonAlgorithmEstimatorBase):
prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size)
job_config = job._Job._load_config(inputs, estimator, expand_role=False, validate_uri=False)
train_config = {
"AlgorithmSpecification": {
"TrainingImage": estimator.train_image(),
"TrainingInputMode": estimator.input_mode,
},
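# A hedged usage sketch of the surrounding function: building an Airflow
# training config from an estimator. The Estimator arguments below are
# illustrative, using the SDK's 1.x constructor signature.
from sagemaker.estimator import Estimator
from sagemaker.workflow.airflow import training_config

estimator = Estimator(
    image_name="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-algo:latest",
    role="arn:aws:iam::123456789012:role/SageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.m5.xlarge",
)
config = training_config(estimator=estimator, inputs="s3://my-bucket/train")
# config contains the AlgorithmSpecification block assembled above, plus a
# TrainingJobName generated with name_from_base.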
    def update_data_capture_config(self, data_capture_config):
        """Updates the DataCaptureConfig for the Predictor's associated Amazon SageMaker Endpoint
        with the provided DataCaptureConfig.

        Args:
            data_capture_config (sagemaker.model_monitor.DataCaptureConfig): The
                DataCaptureConfig to update the predictor's endpoint to use.
        """
        endpoint_desc = self.sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=self.endpoint
        )

        new_config_name = name_from_base(base=self.endpoint)

        data_capture_config_dict = None
        if data_capture_config is not None:
            data_capture_config_dict = data_capture_config._to_request_dict()

        self.sagemaker_session.create_endpoint_config_from_existing(
            existing_config_name=endpoint_desc["EndpointConfigName"],
            new_config_name=new_config_name,
            new_data_capture_config_dict=data_capture_config_dict,
        )

        self.sagemaker_session.update_endpoint(
            endpoint_name=self.endpoint, endpoint_config_name=new_config_name
        )
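# Hedged usage sketch for the method above: enabling data capture on an
# existing endpoint. The endpoint name, bucket, and sampling rate are
# illustrative.
from sagemaker.model_monitor import DataCaptureConfig
from sagemaker.predictor import RealTimePredictor

predictor = RealTimePredictor(endpoint="my-endpoint")
predictor.update_data_capture_config(
    data_capture_config=DataCaptureConfig(
        enable_capture=True,
        sampling_percentage=100,
        destination_s3_uri="s3://my-bucket/data-capture",
    )
)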