Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def test_tf_deploy_model_server_workers_unset(sagemaker_session):
    """Deploying without ``model_server_workers`` leaves the workers env var unset."""
    estimator = _build_tf(sagemaker_session)
    estimator.fit(inputs=s3_input("s3://mybucket/train"))
    estimator.deploy(initial_instance_count=1, instance_type="ml.c2.2xlarge")

    # Inspect the fourth call recorded on the session mock. Its third positional
    # argument is expected to carry the container "Environment" mapping
    # (NOTE(review): relies on call ordering inside deploy(); confirm if it changes).
    recorded_call = sagemaker_session.method_calls[3]
    positional_args = recorded_call[1]
    environment = positional_args[2]["Environment"]

    workers_key = MODEL_SERVER_WORKERS_PARAM_NAME.upper()
    assert workers_key not in environment
def test_fit_verify_job_name(strftime, sagemaker_session):
    """``fit`` forwards estimator settings and the generated job name to ``train``."""
    framework = DummyFramework(
        entry_point=SCRIPT_PATH,
        role="DummyRole",
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        enable_cloudwatch_metrics=True,
        tags=TAGS,
        encrypt_inter_container_traffic=True,
    )
    framework.fit(inputs=s3_input("s3://mybucket/train"))

    # Only the keyword arguments of the first recorded train() call are checked.
    _name, _args, train_kwargs = sagemaker_session.train.mock_calls[0]

    assert train_kwargs["hyperparameters"]["sagemaker_enable_cloudwatch_metrics"]

    expected_kwargs = {
        "image": IMAGE_NAME,
        "input_mode": "File",
        "tags": TAGS,
        "job_name": JOB_NAME,
    }
    for key, expected in expected_kwargs.items():
        assert train_kwargs[key] == expected, key

    # Identity check on purpose: the flag must be exactly True, not just truthy.
    assert train_kwargs["encrypt_inter_container_traffic"] is True
    assert framework.latest_training_job.name == JOB_NAME
def test_format_input_s3_input():
input_dict = _Job._format_inputs_to_input_config(
s3_input(
"s3://foo/bar",
distribution="ShardedByS3Key",
compression="gzip",
content_type="whizz",
record_wrapping="bang",
)
)
assert input_dict == [
{
"CompressionType": "gzip",
"ChannelName": "training",
"ContentType": "whizz",
"DataSource": {
"S3DataSource": {
"S3DataType": "S3Prefix",
"S3DataDistributionType": "ShardedByS3Key",
def test_load_config(estimator):
    """``_Job._load_config`` builds input, role, output, resource and stop settings."""
    channel = s3_input(BUCKET_NAME)
    config = _Job._load_config(channel, estimator)

    first_channel = config["input_config"][0]
    assert first_channel["DataSource"]["S3DataSource"]["S3Uri"] == BUCKET_NAME
    assert config["role"] == ROLE

    output_config = config["output_config"]
    assert output_config["S3OutputPath"] == S3_OUTPUT_PATH
    # No KMS key was configured, so none should appear in the output config.
    assert "KmsKeyId" not in output_config

    resources = config["resource_config"]
    assert resources["InstanceCount"] == INSTANCE_COUNT
    assert resources["InstanceType"] == INSTANCE_TYPE
    assert resources["VolumeSizeInGB"] == VOLUME_SIZE

    assert config["stop_condition"]["MaxRuntimeInSeconds"] == MAX_RUNTIME
def test_dict_of_mixed_input_types():
input_list = _Job._format_inputs_to_input_config(
{"a": "s3://foo/bar", "b": s3_input("s3://whizz/bang")}
)
expected = [
{
"ChannelName": "a",
"DataSource": {
"S3DataSource": {
"S3DataDistributionType": "FullyReplicated",
"S3DataType": "S3Prefix",
"S3Uri": "s3://foo/bar",
}
},
},
{
"ChannelName": "b",
"DataSource": {
uri_input,
content_type=content_type,
input_mode=input_mode,
compression=compression,
target_attribute_name=target_attribute_name,
)
return s3_input_result
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
return file_input(uri_input)
if isinstance(uri_input, str) and validate_uri:
raise ValueError(
'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
'"file://"'.format(uri_input)
)
if isinstance(uri_input, str):
s3_input_result = s3_input(
uri_input,
content_type=content_type,
input_mode=input_mode,
compression=compression,
target_attribute_name=target_attribute_name,
)
return s3_input_result
if isinstance(uri_input, (s3_input, file_input, FileSystemInput)):
return uri_input
raise ValueError(
"Cannot format input {}. Expecting one of str, s3_input, file_input or "
"FileSystemInput".format(uri_input)
)
inputs,
estimator,
static_hyperparameters,
metric_definitions,
estimator_name=None,
objective_type=None,
objective_metric_name=None,
parameter_ranges=None,
):
"""Prepare training config for one estimator"""
training_config = _Job._load_config(inputs, estimator)
training_config["input_mode"] = estimator.input_mode
training_config["metric_definitions"] = metric_definitions
if isinstance(inputs, s3_input):
if "InputMode" in inputs.config:
logging.debug(
"Selecting s3_input's input_mode (%s) for TrainingInputMode.",
inputs.config["InputMode"],
)
training_config["input_mode"] = inputs.config["InputMode"]
if isinstance(estimator, sagemaker.algorithm.AlgorithmEstimator):
training_config["algorithm_arn"] = estimator.algorithm_arn
else:
training_config["image"] = estimator.train_image()
training_config["enable_network_isolation"] = estimator.enable_network_isolation()
training_config[
"encrypt_inter_container_traffic"
] = estimator.encrypt_inter_container_traffic
"""
if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("s3://"):
return s3_input(
model_uri,
input_mode="File",
distribution="FullyReplicated",
content_type="application/x-sagemaker-model",
)
if isinstance(model_uri, string_types) and validate_uri and model_uri.startswith("file://"):
return file_input(model_uri)
if isinstance(model_uri, string_types) and validate_uri:
raise ValueError(
'Model URI must be a valid S3 or FILE URI: must start with "s3://" or ' '"file://'
)
if isinstance(model_uri, string_types):
return s3_input(
model_uri,
input_mode="File",
distribution="FullyReplicated",
content_type="application/x-sagemaker-model",
)
raise ValueError("Cannot format model URI {}. Expecting str".format(model_uri))
def records_s3_input(self):
    """Build an ``s3_input`` channel describing this object's training data.

    The channel reads from ``self.s3_data``, uses ``self.s3_data_type`` as its
    S3 data type, and is distributed with ``"ShardedByS3Key"``.
    """
    source_uri = self.s3_data
    data_type = self.s3_data_type
    return s3_input(
        source_uri,
        s3_data_type=data_type,
        distribution="ShardedByS3Key",
    )
content_type=None,
input_mode=None,
compression=None,
target_attribute_name=None,
):
"""
Args:
uri_input:
validate_uri:
content_type:
input_mode:
compression:
target_attribute_name:
"""
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("s3://"):
s3_input_result = s3_input(
uri_input,
content_type=content_type,
input_mode=input_mode,
compression=compression,
target_attribute_name=target_attribute_name,
)
return s3_input_result
if isinstance(uri_input, str) and validate_uri and uri_input.startswith("file://"):
return file_input(uri_input)
if isinstance(uri_input, str) and validate_uri:
raise ValueError(
'URI input {} must be a valid S3 or FILE URI: must start with "s3://" or '
'"file://"'.format(uri_input)
)
if isinstance(uri_input, str):
s3_input_result = s3_input(