Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
    """Fitting with ``enable_sagemaker_metrics=True`` passes that flag to ``session.train``."""
    estimator = Estimator(
        IMAGE_NAME,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session,
        enable_sagemaker_metrics=True,
    )
    estimator.fit()

    train = sagemaker_session.train
    train.assert_called_once()
    # call_args unpacks as (positional_args, keyword_args); only the kwargs matter here.
    _, train_kwargs = train.call_args
    assert train_kwargs["enable_sagemaker_metrics"]
def test_generic_to_deploy_network_isolation(sagemaker_session):
    """Deploying a network-isolated estimator passes ``enable_network_isolation`` to ``create_model``."""
    estimator = Estimator(
        IMAGE_NAME,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        enable_network_isolation=True,
        sagemaker_session=sagemaker_session,
    )
    estimator.fit()
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)

    create_model = sagemaker_session.create_model
    create_model.assert_called_once()
    create_model_kwargs = create_model.call_args[1]
    assert create_model_kwargs["enable_network_isolation"]
def test_generic_create_model_vpc_config_override(sagemaker_session):
    """``create_model`` lets ``vpc_config_override`` take precedence over the estimator's VPC config.

    Covers an estimator without a VPC config, one whose ``subnets`` /
    ``security_group_ids`` are set, and a malformed override that must raise
    ``ValueError``.
    """
    vpc_config_a = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
    vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]}

    e = Estimator(
        IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session
    )
    e.fit({"train": "s3://bucket/training-prefix"})

    # No VPC on the estimator: the override is used when given, otherwise None.
    assert e.get_vpc_config() is None
    assert e.create_model().vpc_config is None
    assert e.create_model(vpc_config_override=vpc_config_a).vpc_config == vpc_config_a
    assert e.create_model(vpc_config_override=None).vpc_config is None

    # VPC set on the estimator: it becomes the default, but an explicit override
    # (including an explicit None) still wins.
    e.subnets = vpc_config_a["Subnets"]
    e.security_group_ids = vpc_config_a["SecurityGroupIds"]
    assert e.get_vpc_config() == vpc_config_a
    assert e.create_model().vpc_config == vpc_config_a
    assert e.create_model(vpc_config_override=vpc_config_b).vpc_config == vpc_config_b
    assert e.create_model(vpc_config_override=None).vpc_config is None

    # {"invalid"} is a set, not a dict. Presumably the VPC-config validation
    # rejects it with ValueError — TODO confirm against the SDK's test suite.
    with pytest.raises(ValueError):
        e.create_model(vpc_config_override={"invalid"})
def test_generic_to_fit_with_network_isolation(sagemaker_session):
    """Fitting with ``enable_network_isolation=True`` passes that flag to ``session.train``."""
    estimator = Estimator(
        IMAGE_NAME,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session,
        enable_network_isolation=True,
    )
    estimator.fit()

    train = sagemaker_session.train
    train.assert_called_once()
    # call_args unpacks as (positional_args, keyword_args); only the kwargs matter here.
    _, train_kwargs = train.call_args
    assert train_kwargs["enable_network_isolation"]
def test_generic_deploy_accelerator_type(sagemaker_session):
    """``deploy`` forwards accelerator type, instance count and instance type to the endpoint call."""
    e = Estimator(
        IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session
    )
    e.fit({"train": "s3://bucket/training-prefix"})
    e.deploy(INSTANCE_COUNT, INSTANCE_TYPE, ACCELERATOR_TYPE)

    endpoint_call = e.sagemaker_session.endpoint_from_production_variants
    # Fail with a clear mock assertion rather than a TypeError on ``None[1]``
    # if deploy never reached endpoint creation.
    endpoint_call.assert_called()
    args = endpoint_call.call_args[1]

    assert args["name"].startswith(IMAGE_NAME)
    variant = args["production_variants"][0]
    assert variant["AcceleratorType"] == ACCELERATOR_TYPE
    assert variant["InitialInstanceCount"] == INSTANCE_COUNT
    assert variant["InstanceType"] == INSTANCE_TYPE
def test_fit_deploy_tags_in_estimator(sagemaker_session):
    """Fit and deploy an estimator that was constructed with ``tags``.

    NOTE(review): the body appears truncated. The ``variant`` literal below is
    never closed, and no assertion follows it. Presumably the missing part
    compared the expected production variant and ``tags`` against the mocked
    session's deploy calls — TODO restore from upstream.
    """
    tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
    estimator = Estimator(
        IMAGE_NAME,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        tags=tags,
        sagemaker_session=sagemaker_session,
    )
    estimator.fit()
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)
    # Expected production-variant description (incomplete in this view).
    variant = [
        {
            "InstanceType": "c4.4xlarge",
            "VariantName": "AllTraffic",
def linear_learner_estimator():
    """Build a Linear Learner ``Estimator`` backed by a ``MagicMock`` SageMaker session.

    NOTE(review): presumably registered as a pytest fixture by a decorator
    outside this view — confirm.

    Returns:
        The estimator. It is configured for one ``ml.c4.xlarge`` File-mode
        training instance with output under ``s3://sagemaker/models``. It also
        carries a debugger hook writing to ``s3://sagemaker/models/debug`` and
        regressor hyperparameters (``feature_dim=10``, ``mini_batch_size=32``).
    """
    s3_output_location = 's3://sagemaker/models'
    # A MagicMock stands in for the real session; only its region is pinned.
    sagemaker_session = MagicMock()
    sagemaker_session.boto_region_name = 'us-east-1'
    ll_estimator = sagemaker.estimator.Estimator(
        LINEAR_LEARNER_IMAGE,
        SAGEMAKER_EXECUTION_ROLE,
        train_instance_count=1,
        train_instance_type='ml.c4.xlarge',
        train_volume_size=20,
        train_max_run=3600,
        input_mode='File',
        output_path=s3_output_location,
        sagemaker_session=sagemaker_session,
    )
    ll_estimator.debugger_hook_config = DebuggerHookConfig(
        s3_output_path='s3://sagemaker/models/debug'
    )
    ll_estimator.set_hyperparameters(feature_dim=10, predictor_type='regressor', mini_batch_size=32)
    # Hand the configured estimator back; without this the caller only ever sees None.
    return ll_estimator
def test_attach_without_hyperparameters(sagemaker_session):
    """Attaching to a job whose description lacks ``HyperParameters`` yields an empty dict."""
    # Shallow copy so the shared module-level description is left untouched.
    job_description = RETURNED_JOB_DESCRIPTION.copy()
    job_description.pop("HyperParameters")

    sagemaker_session.sagemaker_client.describe_training_job = Mock(
        name="describe_training_job", return_value=job_description
    )

    attached = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)

    assert attached.hyperparameters() == {}
def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
    """Set up an Estimator with an attached training job and optional transform settings.

    NOTE(review): in this view the test never uses the transform options it
    defines. These are strategy, assemble_with, kms_key, accept,
    max_concurrent_transforms, max_payload and env. It also asserts nothing, so
    the remainder looks truncated. Presumably that was an
    ``estimator.transformer(...)`` call plus assertions on the result — TODO
    restore. As written it only checks that construction does not raise.
    """
    base_name = "foo"
    estimator = Estimator(
        image_name=IMAGE_NAME,
        role=ROLE,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
        base_job_name=base_name,
    )
    # Attach a training job directly instead of running fit().
    estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)
    # Optional transform settings (unused in the visible portion).
    strategy = "MultiRecord"
    assemble_with = "Line"
    kms_key = "key"
    accept = "text/csv"
    max_concurrent_transforms = 1
    max_payload = 6
    env = {"FOO": "BAR"}
def train_model(session, data_location, hyperparameters, *,
                instance_type='ml.c4.2xlarge', instance_count=1):
    """Train the ``ALGORITHM_NAME:latest`` image from the caller's own ECR registry.

    Args:
        session: SageMaker session. Its ``boto_session`` supplies the AWS
            account (via STS ``get_caller_identity``) and the region used to
            build the ECR image URI.
        data_location: training input, passed straight to ``Estimator.fit``.
        hyperparameters: passed through as the Estimator's ``hyperparameters``.
        instance_type: training instance type (default ``'ml.c4.2xlarge'``).
        instance_count: number of training instances (default ``1``).

    Returns:
        The fitted ``sagemaker.estimator.Estimator``.
    """
    account = session.boto_session.client('sts').get_caller_identity()['Account']
    region = session.boto_session.region_name
    # NOTE(review): hard-codes the ``amazonaws.com`` domain. Regions in other AWS
    # partitions may use a different ECR domain — confirm if those are needed.
    image = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'.format(account, region, ALGORITHM_NAME)
    estimator = sagemaker.estimator.Estimator(
        image_name=image,
        role=ROLE,
        train_instance_count=instance_count,
        train_instance_type=instance_type,
        output_path="s3://{}/output".format(S3_BUCKET),
        sagemaker_session=session,
        hyperparameters=hyperparameters,
    )
    print("Starting train procedure ...")
    estimator.fit(data_location)
    print("Training done.")
    return estimator