How to use the sagemaker.estimator.Estimator class in sagemaker

To help you get started, we’ve selected a few sagemaker examples that show popular ways the library is used in public projects.

Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.

github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_generic_to_fit_with_sagemaker_metrics_enabled(sagemaker_session):
    """Check that fitting forwards a truthy ``enable_sagemaker_metrics`` to ``train``.

    An ``Estimator`` is created with ``enable_sagemaker_metrics=True`` and
    fitted; the keyword arguments of the single ``sagemaker_session.train``
    call must then carry a truthy ``enable_sagemaker_metrics`` entry.
    """
    estimator_options = dict(
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session,
        enable_sagemaker_metrics=True,
    )
    estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, **estimator_options)
    estimator.fit()

    train_mock = sagemaker_session.train
    train_mock.assert_called_once()
    # call_args unpacks to (positional args, keyword args); only kwargs matter here.
    _, train_kwargs = train_mock.call_args
    assert train_kwargs["enable_sagemaker_metrics"]
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_generic_to_deploy_network_isolation(sagemaker_session):
    """Check that deploying forwards a truthy ``enable_network_isolation`` to ``create_model``.

    An ``Estimator`` is created with ``enable_network_isolation=True``, fitted
    and deployed; the keyword arguments of the single
    ``sagemaker_session.create_model`` call must then carry a truthy
    ``enable_network_isolation`` entry.
    """
    estimator_options = dict(
        output_path=OUTPUT_PATH,
        enable_network_isolation=True,
        sagemaker_session=sagemaker_session,
    )
    estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, **estimator_options)
    estimator.fit()
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)

    create_model_mock = sagemaker_session.create_model
    create_model_mock.assert_called_once()
    # Index 1 of call_args holds the keyword arguments of the recorded call.
    create_model_kwargs = create_model_mock.call_args[1]
    assert create_model_kwargs["enable_network_isolation"]
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_generic_create_model_vpc_config_override(sagemaker_session):
    vpc_config_a = {"Subnets": ["foo"], "SecurityGroupIds": ["bar"]}
    vpc_config_b = {"Subnets": ["foo", "bar"], "SecurityGroupIds": ["baz"]}

    e = Estimator(
        IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session
    )
    e.fit({"train": "s3://bucket/training-prefix"})
    assert e.get_vpc_config() is None
    assert e.create_model().vpc_config is None
    assert e.create_model(vpc_config_override=vpc_config_a).vpc_config == vpc_config_a
    assert e.create_model(vpc_config_override=None).vpc_config is None

    e.subnets = vpc_config_a["Subnets"]
    e.security_group_ids = vpc_config_a["SecurityGroupIds"]
    assert e.get_vpc_config() == vpc_config_a
    assert e.create_model().vpc_config == vpc_config_a
    assert e.create_model(vpc_config_override=vpc_config_b).vpc_config == vpc_config_b
    assert e.create_model(vpc_config_override=None).vpc_config is None

    with pytest.raises(ValueError):
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_generic_to_fit_with_network_isolation(sagemaker_session):
    """Check that fitting forwards a truthy ``enable_network_isolation`` to ``train``.

    An ``Estimator`` is created with ``enable_network_isolation=True`` and
    fitted; the keyword arguments of the single ``sagemaker_session.train``
    call must then carry a truthy ``enable_network_isolation`` entry.
    """
    estimator_options = dict(
        output_path=OUTPUT_PATH,
        sagemaker_session=sagemaker_session,
        enable_network_isolation=True,
    )
    estimator = Estimator(IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, **estimator_options)
    estimator.fit()

    train_mock = sagemaker_session.train
    train_mock.assert_called_once()
    # call_args unpacks to (positional args, keyword args); only kwargs matter here.
    _, train_kwargs = train_mock.call_args
    assert train_kwargs["enable_network_isolation"]
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_generic_deploy_accelerator_type(sagemaker_session):
    """Check the endpoint request produced by deploying with an accelerator type.

    After fitting and deploying with ``ACCELERATOR_TYPE``, the keyword
    arguments recorded on ``endpoint_from_production_variants`` must have a
    ``name`` starting with ``IMAGE_NAME`` and a first production variant whose
    accelerator type, initial instance count and instance type match the
    values given to ``deploy``.
    """
    estimator = Estimator(
        IMAGE_NAME, ROLE, INSTANCE_COUNT, INSTANCE_TYPE, sagemaker_session=sagemaker_session
    )
    training_inputs = {"train": "s3://bucket/training-prefix"}
    estimator.fit(training_inputs)
    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE, ACCELERATOR_TYPE)

    endpoint_mock = estimator.sagemaker_session.endpoint_from_production_variants
    # call_args unpacks to (positional args, keyword args); only kwargs matter here.
    _, endpoint_kwargs = endpoint_mock.call_args
    print(endpoint_kwargs)
    assert endpoint_kwargs["name"].startswith(IMAGE_NAME)

    first_variant = endpoint_kwargs["production_variants"][0]
    assert first_variant["AcceleratorType"] == ACCELERATOR_TYPE
    assert first_variant["InitialInstanceCount"] == INSTANCE_COUNT
    assert first_variant["InstanceType"] == INSTANCE_TYPE
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_fit_deploy_tags_in_estimator(sagemaker_session):
    tags = [{"Key": "TagtestKey", "Value": "TagtestValue"}]
    estimator = Estimator(
        IMAGE_NAME,
        ROLE,
        INSTANCE_COUNT,
        INSTANCE_TYPE,
        tags=tags,
        sagemaker_session=sagemaker_session,
    )

    estimator.fit()

    estimator.deploy(INSTANCE_COUNT, INSTANCE_TYPE)

    variant = [
        {
            "InstanceType": "c4.4xlarge",
            "VariantName": "AllTraffic",
github aws / aws-step-functions-data-science-sdk-python / tests / unit / test_pipeline.py View on Github external
def linear_learner_estimator():
    """Build a ``sagemaker.estimator.Estimator`` for ``LINEAR_LEARNER_IMAGE`` on a mocked session.

    The estimator is configured with one ``ml.c4.xlarge`` training instance,
    a debugger hook config and a fixed set of hyperparameters.

    NOTE(review): no ``return`` statement is visible in this excerpt;
    presumably the full definition hands ``ll_estimator`` back to the
    caller -- confirm against the original file.
    """
    s3_output_location = 's3://sagemaker/models'
    # A MagicMock stands in for the SageMaker session; only its region is set explicitly.
    sagemaker_session = MagicMock()
    sagemaker_session.boto_region_name = 'us-east-1'

    ll_estimator = sagemaker.estimator.Estimator(
        LINEAR_LEARNER_IMAGE,
        SAGEMAKER_EXECUTION_ROLE, 
        train_instance_count=1, 
        train_instance_type='ml.c4.xlarge',
        train_volume_size=20,
        train_max_run=3600,
        input_mode='File',
        output_path=s3_output_location,
        sagemaker_session=sagemaker_session
    )

    # The debugger hook config is assigned after construction rather than passed to the constructor.
    ll_estimator.debugger_hook_config = DebuggerHookConfig(
        s3_output_path='s3://sagemaker/models/debug'
    )

    ll_estimator.set_hyperparameters(feature_dim=10, predictor_type='regressor', mini_batch_size=32)
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_attach_without_hyperparameters(sagemaker_session):
    """Check that attaching to a job described without hyperparameters yields ``{}``.

    ``describe_training_job`` is replaced by a mock returning a copy of
    ``RETURNED_JOB_DESCRIPTION`` with the ``"HyperParameters"`` entry removed;
    the attached estimator's ``hyperparameters()`` must then equal ``{}``.
    """
    # Work on a copy so the shared module-level description is left untouched.
    job_description = RETURNED_JOB_DESCRIPTION.copy()
    job_description.pop("HyperParameters")

    sagemaker_session.sagemaker_client.describe_training_job = Mock(
        name="describe_training_job", return_value=job_description
    )

    attached = Estimator.attach(training_job_name="job", sagemaker_session=sagemaker_session)

    assert attached.hyperparameters() == {}
github aws / sagemaker-python-sdk / tests / unit / test_estimator.py View on Github external
def test_estimator_transformer_creation_with_optional_params(sagemaker_session):
    """Set up an ``Estimator`` and option values for a transformer-creation test.

    NOTE(review): this excerpt ends after the option values are defined; the
    call that presumably consumes them and the assertions are not visible
    here -- confirm against the original file.
    """
    base_name = "foo"
    estimator = Estimator(
        image_name=IMAGE_NAME,
        role=ROLE,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=INSTANCE_TYPE,
        sagemaker_session=sagemaker_session,
        base_job_name=base_name,
    )
    # Attach a training job object directly instead of calling fit().
    estimator.latest_training_job = _TrainingJob(sagemaker_session, JOB_NAME)

    # Option values; presumably passed to the transformer later -- not shown in this excerpt.
    strategy = "MultiRecord"
    assemble_with = "Line"
    kms_key = "key"
    accept = "text/csv"
    max_concurrent_transforms = 1
    max_payload = 6
    env = {"FOO": "BAR"}
github awslabs / datawig-sagemaker / sagemaker / client.py View on Github external
def train_model(session, data_location, hyperparameters):
    """Fit a SageMaker estimator for the ``ALGORITHM_NAME`` image and return it.

    The image reference is assembled from the caller's account id (looked up
    through STS), the boto session's region and ``ALGORITHM_NAME``. Output is
    written under ``s3://<S3_BUCKET>/output``.

    Args:
        session: SageMaker session exposing ``boto_session``; also passed to
            the estimator as ``sagemaker_session``.
        data_location: Training input handed unchanged to ``estimator.fit``.
        hyperparameters: Hyperparameters handed unchanged to the estimator.

    Returns:
        The estimator after ``fit`` has returned.
    """
    sts_client = session.boto_session.client('sts')
    account_id = sts_client.get_caller_identity()['Account']
    aws_region = session.boto_session.region_name

    image_template = '{}.dkr.ecr.{}.amazonaws.com/{}:latest'
    training_image = image_template.format(account_id, aws_region, ALGORITHM_NAME)
    model_output_path = "s3://{}/output".format(S3_BUCKET)

    estimator_settings = {
        'image_name': training_image,
        'role': ROLE,
        'train_instance_count': 1,
        'train_instance_type': 'ml.c4.2xlarge',
        'output_path': model_output_path,
        'sagemaker_session': session,
        'hyperparameters': hyperparameters,
    }
    trained = sagemaker.estimator.Estimator(**estimator_settings)

    print("Starting train procedure ...")
    trained.fit(data_location)
    print("Training done.")

    return trained