Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_input_fn_with_input_channels(docker_image, sagemaker_session, opt_ml, processor):
    """Train the iris example whose input_fn reads input channels, then check its outputs.

    The example's ``code`` and ``data`` directories are staged into ``opt_ml``, the code
    directory is uploaded to the session's default bucket under ``test_job``, and a single
    training/evaluation step is run in ``docker_image``. The test expects an exported
    model, a checkpoint and a success marker, and no failure marker.
    """
    iris_dir = os.path.join(SCRIPT_PATH, '../resources/iris')
    entry_script = 'iris_train_input_fn_with_channels.py'
    code_dir = os.path.join(iris_dir, 'code')

    copy_resource(iris_dir, opt_ml, 'code')
    copy_resource(iris_dir, opt_ml, 'data', 'input/data')

    uploaded = fw_utils.tar_and_upload_dir(
        session=sagemaker_session.boto_session,
        bucket=sagemaker_session.default_bucket(),
        s3_key_prefix='test_job',
        script=entry_script,
        directory=code_dir,
    )

    hyperparameters = {'training_steps': 1, 'evaluation_steps': 1}
    create_config_files(entry_script, uploaded.s3_prefix, opt_ml, hyperparameters)

    os.makedirs(os.path.join(opt_ml, 'model'))
    train(docker_image, opt_ml, processor)

    # Checked in order; the first missing artifact reports its own message.
    expected_artifacts = (
        ('model/export/Servo', 'model was not exported'),
        ('model/checkpoint', 'checkpoint was not created'),
        ('output/success', 'Success file was not created'),
    )
    for relative_path, message in expected_artifacts:
        assert file_exists(opt_ml, relative_path), message
    assert not file_exists(opt_ml, 'output/failure'), 'Failure happened'
def get_account(framework, framework_version, py_version="py3"):
    """Pick the AWS account that hosts images for a framework/version/Python combination.

    Returns ``fw_utils.DEFAULT_ACCOUNT`` when ``framework_version`` is listed for
    ``framework`` in ``ORIGINAL_FW_VERSIONS`` or ``SERVING_FW_VERSIONS``, or when
    ``is_mxnet_1_4_py2`` matches. Otherwise returns ``fw_utils.ASIMOV_DEFAULT_ACCOUNT``.
    The lookups happen in that order and stop at the first match.
    """
    if framework_version in ORIGINAL_FW_VERSIONS[framework]:
        return fw_utils.DEFAULT_ACCOUNT
    if framework_version in SERVING_FW_VERSIONS[framework]:
        return fw_utils.DEFAULT_ACCOUNT
    # Special case for MXNet 1.4.1 (1.4) py2. The original note says the Asimov team
    # owns both py2 and py3 otherwise. NOTE(review): confirm the exact ownership rule.
    if is_mxnet_1_4_py2(framework, framework_version, py_version):
        return fw_utils.DEFAULT_ACCOUNT
    return fw_utils.ASIMOV_DEFAULT_ACCOUNT
def create_docker_services(command, tmpdir, hosts, image, additional_volumes, additional_env_vars,
                           customer_script, source_dir, entrypoint, use_gpu=False):
    """Collect the environment variables and optml subdirectories for a local container run.

    For ``'train'``, the ``'output'`` and ``'input'`` subdirectories are selected. For
    ``'serve'``, ``DEFAULT_HOSTING_ENV`` is added. If ``customer_script`` is given, the
    script and ``source_dir`` are uploaded to S3 and referenced through
    ``SAGEMAKER_PROGRAM`` and ``SAGEMAKER_SUBMIT_DIRECTORY``. In both cases the session
    credentials and then ``additional_env_vars`` are appended.

    NOTE(review): as visible here, ``tmpdir``, ``hosts``, ``image``,
    ``additional_volumes``, ``entrypoint`` and ``use_gpu`` are never read. The collected
    ``environment`` and ``optml_dirs`` are never returned either, so this excerpt appears
    truncated. Confirm against the full function.

    Raises:
        ValueError: if ``command`` is neither ``'train'`` nor ``'serve'``.
    """
    environment = []
    session = boto3.Session()
    optml_dirs = set()
    if command == 'train':
        optml_dirs = {'output', 'input'}
    elif command == 'serve':
        environment.extend(DEFAULT_HOSTING_ENV)
        if customer_script:
            # A timestamped key prefix gives each run its own S3 location.
            timestamp = utils.sagemaker_timestamp()
            # [0] takes the first field of the upload result, presumably its S3 prefix.
            s3_script_path = fw_utils.tar_and_upload_dir(session=session,
                                                         bucket=default_bucket(session),
                                                         s3_key_prefix='test-{}'.format(timestamp),
                                                         script=customer_script,
                                                         directory=source_dir)[0]
            # Tell the serving container which script to run and where to fetch it.
            environment.extend([
                'SAGEMAKER_PROGRAM={}'.format(os.path.basename(customer_script)),
                'SAGEMAKER_SUBMIT_DIRECTORY={}'.format(s3_script_path)
            ])
    else:
        raise ValueError('Unexpected command: {}'.format(command))
    # Session credentials go in first; caller-supplied variables are appended last.
    environment.extend(credentials_to_env(session))
    environment.extend(additional_env_vars)
def test_s3_checkpoint_save_timeout(docker_image, opt_ml, sagemaker_session, processor):
    """Train with checkpoints written directly to S3 and verify the job succeeds.

    ``rand_model_emb.py`` from the python_sdk resources is uploaded to the session's
    default bucket and trained for three steps, with ``checkpoint_path`` set to a unique
    S3 prefix in that bucket. The test passes only if the container leaves a success
    marker and no failure marker in ``opt_ml/output``.
    """
    resource_path = os.path.join(SCRIPT_PATH, '../resources/python_sdk')
    default_bucket = sagemaker_session.default_bucket()
    s3_source_archive = fw_utils.tar_and_upload_dir(session=sagemaker_session.boto_session,
                                                    bucket=default_bucket,
                                                    s3_key_prefix='test_job',
                                                    script='rand_model_emb.py',
                                                    directory=resource_path)
    # uuid4 gives every run its own checkpoint prefix, so leftovers from earlier runs
    # never end up in this job's checkpoint location.
    checkpoint_s3_path = 's3://{}/integ-s3-timeout/checkpoints-{}'.format(default_bucket,
                                                                          uuid.uuid4())
    hyperparameters = dict(
        training_steps=3,
        evaluation_steps=3,
        checkpoint_path=checkpoint_s3_path
    )
    create_config_files('rand_model_emb.py', s3_source_archive.s3_prefix, opt_ml, hyperparameters)
    train(docker_image, opt_ml, processor)
    # Check the container's own success/failure markers instead of relying only on
    # train() to surface an error. Without these checks the test asserts nothing.
    assert file_exists(opt_ml, 'output/success'), 'Success file was not created'
    assert not file_exists(opt_ml, 'output/failure'), 'Failure happened'
def test_wide_deep(docker_image, sagemaker_session, opt_ml, processor):
    """Run one training/evaluation step of the wide & deep example and check its outputs.

    The example's code and data are staged into ``opt_ml``, its code directory is
    uploaded to the session's default bucket under ``test_job``, and training runs in
    ``docker_image``. The test expects an exported model, a checkpoint and a success
    marker, and no failure marker.
    """
    example_dir = os.path.join(SCRIPT_PATH, '../resources/wide_deep')
    script_name = 'wide_deep.py'

    copy_resource(example_dir, opt_ml, 'code')
    copy_resource(example_dir, opt_ml, 'data', 'input/data')

    upload = fw_utils.tar_and_upload_dir(
        session=sagemaker_session.boto_session,
        bucket=sagemaker_session.default_bucket(),
        s3_key_prefix='test_job',
        script=script_name,
        directory=os.path.join(example_dir, 'code'),
    )
    create_config_files(script_name, upload.s3_prefix, opt_ml,
                        {'training_steps': 1, 'evaluation_steps': 1})

    model_dir = os.path.join(opt_ml, 'model')
    os.makedirs(model_dir)
    train(docker_image, opt_ml, processor)

    # (relative path, failure message) pairs, checked in this order.
    required_outputs = [
        ('model/export/Servo', 'model was not exported'),
        ('model/checkpoint', 'checkpoint was not created'),
        ('output/success', 'Success file was not created'),
    ]
    for path, failure_message in required_outputs:
        assert file_exists(opt_ml, path), failure_message
    assert not file_exists(opt_ml, 'output/failure'), 'Failure happened'
def prepare_container_def(self, instance_type, accelerator_type=None):
    """Resolve the image URI, container environment and model data location.

    When ``self.entry_point`` is set, ``self.model_data`` is passed to
    ``sagemaker.utils.repack_model`` together with ``entry_point``, ``source_dir`` and
    ``dependencies``, targeting ``s3://<bucket>/<key_prefix>/model.tar.gz``. That target
    becomes ``model_data``. Otherwise ``self.model_data`` is used unchanged.

    NOTE(review): this excerpt ends without a ``return``. ``image``, ``env`` and
    ``model_data`` are computed but not consumed here, so the rest of the method appears
    to be truncated.

    Args:
        instance_type: passed to ``self._get_image_uri`` to select the image.
        accelerator_type: passed to ``self._get_image_uri``; defaults to None.
    """
    image = self._get_image_uri(instance_type, accelerator_type)
    env = self._get_container_env()
    if self.entry_point:
        # The repacked artifact's key prefix is derived from key_prefix, name and image.
        key_prefix = sagemaker.fw_utils.model_code_key_prefix(self.key_prefix, self.name, image)
        # An explicitly configured bucket takes precedence over the session default.
        bucket = self.bucket or self.sagemaker_session.default_bucket()
        # NOTE(review): os.path.join uses the host OS separator, so on Windows this would
        # build an S3 key with backslashes. posixpath.join or "/".join would be safer.
        model_data = "s3://" + os.path.join(bucket, key_prefix, "model.tar.gz")
        sagemaker.utils.repack_model(
            self.entry_point,
            self.source_dir,
            self.dependencies,
            self.model_data,
            model_data,
            self.sagemaker_session,
            kms_key=self.model_kms_key,
        )
    else:
        model_data = self.model_data
def _upload_code(self, key_prefix, repack=False):
    """Upload the inference code, or repack it into the model artifact.

    - In local mode with ``local.local_code`` configured, nothing is uploaded and
      ``self.uploaded_code`` is set to None.
    - Otherwise, when ``repack`` is falsy, ``entry_point``, ``source_dir`` and
      ``dependencies`` are uploaded under ``key_prefix`` with
      ``fw_utils.tar_and_upload_dir``, and the result is stored in
      ``self.uploaded_code``.
    - When ``repack`` is truthy, the code is passed to ``utils.repack_model``,
      targeting ``s3://<bucket>/<key_prefix>/model.tar.gz``.

    NOTE(review): this excerpt is cut off inside the ``utils.repack_model(...)`` call.

    Args:
        key_prefix: S3 key prefix for the uploaded code or the repacked artifact.
        repack: if truthy, repack the code instead of uploading it separately.
    """
    local_code = utils.get_config_value("local.local_code", self.sagemaker_session.config)
    if self.sagemaker_session.local_mode and local_code:
        self.uploaded_code = None
    elif not repack:
        # An explicitly configured bucket takes precedence over the session default.
        bucket = self.bucket or self.sagemaker_session.default_bucket()
        self.uploaded_code = fw_utils.tar_and_upload_dir(
            session=self.sagemaker_session.boto_session,
            bucket=bucket,
            s3_key_prefix=key_prefix,
            script=self.entry_point,
            directory=self.source_dir,
            dependencies=self.dependencies,
        )
    # NOTE(review): this is a separate ``if``, not an ``elif``, so it also runs after the
    # local-code branch has set ``uploaded_code`` to None. Confirm that is intended.
    if repack:
        bucket = self.bucket or self.sagemaker_session.default_bucket()
        # NOTE(review): os.path.join uses the host OS separator; S3 keys need "/".
        repacked_model_data = "s3://" + os.path.join(bucket, key_prefix, "model.tar.gz")
        utils.repack_model(
            inference_script=self.entry_point,
            source_directory=self.source_dir,
            dependencies=self.dependencies,
def prepare_framework(estimator, s3_operations):
    """Prepare S3 operations (specify where to upload ``source_dir``) and
    environment variables related to the framework.

    The code destination is chosen in this order:
    1. ``estimator.code_location``, with ``<job name>/source/sourcedir.tar.gz`` appended.
    2. The existing ``estimator.uploaded_code.s3_prefix``.
    3. The session's default bucket with ``<job name>/source/sourcedir.tar.gz``.

    ``estimator.uploaded_code`` is then replaced with an ``UploadedCode``. Its prefix is
    ``source_dir`` itself when that is already an ``s3://`` URI, and the chosen
    destination otherwise.

    NOTE(review): this excerpt is cut off while building ``s3_operations["S3Upload"]``.

    Args:
        estimator (sagemaker.estimator.Estimator): The framework estimator to
            get information from and update.
        s3_operations (dict): The dict to specify s3 operations (upload
            ``source_dir``).
    """
    if estimator.code_location is not None:
        bucket, key = fw_utils.parse_s3_url(estimator.code_location)
        # NOTE(review): os.path.join uses the host OS separator; S3 keys need "/".
        key = os.path.join(key, estimator._current_job_name, "source", "sourcedir.tar.gz")
    elif estimator.uploaded_code is not None:
        bucket, key = fw_utils.parse_s3_url(estimator.uploaded_code.s3_prefix)
    else:
        # NOTE(review): reads the private ``_default_bucket`` attribute instead of calling
        # ``default_bucket()``. It may be None if the bucket was never resolved; confirm.
        bucket = estimator.sagemaker_session._default_bucket
        key = os.path.join(estimator._current_job_name, "source", "sourcedir.tar.gz")
    script = os.path.basename(estimator.entry_point)
    if estimator.source_dir and estimator.source_dir.lower().startswith("s3://"):
        # source_dir is already in S3, so point at it directly.
        code_dir = estimator.source_dir
        estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
    else:
        code_dir = "s3://{}/{}".format(bucket, key)
        estimator.uploaded_code = fw_utils.UploadedCode(s3_prefix=code_dir, script_name=script)
    s3_operations["S3Upload"] = [
def _validate_args(
    self,
    py_version,
    script_mode,
    framework_version,
    training_steps,
    evaluation_steps,
    requirements_file,
    checkpoint_path,
):
    """Reject argument combinations that are invalid for Python 3 or script mode.

    The checks only apply when ``py_version == "py3"`` or ``script_mode`` is truthy.
    Otherwise the function returns None without checking anything.

    Raises:
        AttributeError: if ``framework_version`` is None, or if any of
            ``training_steps``, ``evaluation_steps``, ``requirements_file`` or
            ``checkpoint_path`` is truthy. These legacy arguments are not accepted in
            script mode.
    """
    if not (py_version == "py3" or script_mode):
        return

    if framework_version is None:
        raise AttributeError(fw.EMPTY_FRAMEWORK_VERSION_ERROR)

    legacy_args = (
        ("training_steps", training_steps),
        ("evaluation_steps", evaluation_steps),
        ("requirements_file", requirements_file),
        ("checkpoint_path", checkpoint_path),
    )
    offending = [name for name, value in legacy_args if value]
    if offending:
        raise AttributeError(
            "{} are deprecated in script mode. Please do not set {}.".format(
                ", ".join(_FRAMEWORK_MODE_ARGS), ", ".join(offending)
            )
        )
'mpi':
{
'enabled': True
}
}
**kwargs: Additional kwargs passed to the Framework constructor.
.. tip::
You can find additional parameters for initializing this class at
:class:`~sagemaker.estimator.Framework` and
:class:`~sagemaker.estimator.EstimatorBase`.
"""
if framework_version is None:
    # No version was requested: warn here, then fall back to TF_VERSION below.
    logger.warning(fw.empty_framework_version_warning(TF_VERSION, self.LATEST_VERSION))
# NOTE(review): any falsy framework_version (e.g. "") also falls back to TF_VERSION,
# but only None triggers the warning above.
self.framework_version = framework_version or TF_VERSION
if not py_version:
    # Default the Python version. This runs after self.framework_version is set, which
    # _only_python_3_supported presumably reads; confirm.
    py_version = "py3" if self._only_python_3_supported() else "py2"
if "enable_sagemaker_metrics" not in kwargs:
    # enable sagemaker metrics for TF v1.15 or greater:
    # (only when the caller has not set enable_sagemaker_metrics explicitly)
    if fw.is_version_equal_or_higher([1, 15], self.framework_version):
        kwargs["enable_sagemaker_metrics"] = True
super(TensorFlow, self).__init__(image_name=image_name, **kwargs)
self.checkpoint_path = checkpoint_path
if py_version == "py2":
    logger.warning("tensorflow py2 container will be deprecated soon.")