# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def sagemaker_session(region_name=DEFAULT_REGION):  # type: (str) -> sagemaker.Session
    """Return a sagemaker.Session backed by a boto3 Session for *region_name*."""
    boto_sess = boto3.Session(region_name=region_name)
    return sagemaker.Session(boto_sess)
def test_get_caller_identity_arn_from_describe_notebook_instance(boto_session):
    """The caller-identity ARN should be taken from DescribeNotebookInstance's RoleArn."""
    session = Session(boto_session)
    role_arn = "arn:aws:iam::369233609183:role/service-role/SageMakerRole-20171129T072388"
    session.sagemaker_client.describe_notebook_instance.return_value = {"RoleArn": role_arn}

    assert session.get_caller_identity_arn() == role_arn
    # The notebook-instance lookup must use the well-known instance name, exactly once.
    session.sagemaker_client.describe_notebook_instance.assert_called_once_with(
        NotebookInstanceName="SageMakerInstance"
    )
def test_process(boto_session):
    # NOTE(review): this test appears truncated in the source — the second
    # entry of "inputs" opens on the last line below and never closes, and the
    # call to session.process(...) plus its assertions are missing. Restore
    # the full body from the upstream test file before relying on it.
    session = Session(boto_session)
    process_request_args = {
        "inputs": [
            {
                "InputName": "input-1",
                "S3Input": {
                    "S3Uri": "mocked_s3_uri_from_upload_data",
                    "LocalPath": "/container/path/",
                    "S3DataType": "Archive",
                    "S3InputMode": "File",
                    "S3DownloadMode": "Continuous",
                    "S3DataDistributionType": "FullyReplicated",
                    "S3CompressionType": "None",
                },
            },
            {
def test_delete_model(boto_session):
    """delete_model should forward the model name to the boto SageMaker client."""
    session = Session(boto_session)
    name = "my_model"
    session.delete_model(name)
    boto_session.client().delete_model.assert_called_with(ModelName=name)
def test_get_caller_identity_arn_from_an_user(boto_session):
    """For a plain IAM user, the STS caller-identity ARN is returned unchanged."""
    session = Session(boto_session)
    user_arn = "arn:aws:iam::369233609183:user/mia"

    sts = session.boto_session.client("sts", endpoint_url=STS_ENDPOINT)
    sts.get_caller_identity.return_value = {"Arn": user_arn}
    iam = session.boto_session.client("iam")
    iam.get_role.return_value = {"Role": {"Arn": user_arn}}

    assert session.get_caller_identity_arn() == "arn:aws:iam::369233609183:user/mia"
def sagemaker_session():
    """Build a sagemaker.Session over a mocked boto session, with default_bucket stubbed."""
    mock_boto = Mock(name="boto_session")
    session = sagemaker.Session(boto_session=mock_boto)
    session.default_bucket = Mock(name="default_bucket", return_value=BUCKET_NAME)
    return session
def sagemaker_session():
    """Build a sagemaker.Session with mocked boto/sagemaker clients.

    describe_training_job and endpoint_from_model_data are stubbed to return
    canned fixture responses.
    """
    boto_mock = Mock(name="boto_session", region_name=REGION)
    ims = sagemaker.Session(sagemaker_client=Mock(name="sagemaker_client"), boto_session=boto_mock)
    ims.sagemaker_client.describe_training_job = Mock(
        name="describe_training_job", return_value=TRAINING_JOB_RESPONSE
    )
    # Bug fix: the name was previously passed positionally, which is Mock's
    # `spec` parameter — that spec'd the mock against a str instance
    # (restricting attribute access) instead of naming it.
    ims.endpoint_from_model_data = Mock(
        name="endpoint_from_model_data", return_value=ENDPOINT_FROM_MODEL_RETURNED_NAME
    )
    return ims
def sagemaker_session_complete():
    """Session whose log-stream and describe_* calls return completed-job fixtures."""
    mock_boto = Mock(name="boto_session")
    logs_client = mock_boto.client("logs")
    logs_client.describe_log_streams.return_value = DEFAULT_LOG_STREAMS
    logs_client.get_log_events.side_effect = DEFAULT_LOG_EVENTS

    session = sagemaker.Session(boto_session=mock_boto, sagemaker_client=Mock())
    session.sagemaker_client.describe_training_job.return_value = COMPLETED_DESCRIBE_JOB_RESULT
    session.sagemaker_client.describe_transform_job.return_value = COMPLETED_DESCRIBE_TRANSFORM_JOB_RESULT
    return session
def deploy_endpoint(session, client, endpoint_name, setting, pytorch):
    """Deploy a model to a SageMaker endpoint.

    :param session: boto3 Session used to build the sagemaker.Session
    :param client: boto SageMaker client
    :param endpoint_name: endpoint name; also prefixes the timestamped model name
    :param setting: path to a YAML config with 'model' and 'deploy' sections
    :param pytorch: if truthy deploy a PyTorchModel, otherwise a ChainerModel
    """
    sagemaker_session = sagemaker.Session(
        boto_session=session,
        sagemaker_client=client)

    # Fixes: the file handle was previously leaked (bare open()), and
    # yaml.load was called without a Loader — unsafe for untrusted input and a
    # TypeError on PyYAML >= 6. safe_load suffices for a plain config file.
    with open(setting) as f:
        conf = yaml.safe_load(f)

    model_args = conf['model']
    model_args['sagemaker_session'] = sagemaker_session
    # Timestamp suffix keeps model names unique across redeploys.
    model_args['name'] = endpoint_name + '-model-' + dt.now().strftime('%y%m%d%H%M')
    if pytorch:
        model = PyTorchModel(**model_args)
    else:
        model = ChainerModel(**model_args)

    deploy_args = conf['deploy']
    deploy_args['endpoint_name'] = endpoint_name
    model.deploy(**deploy_args)
def train(source_dir, data_path='doodle/data', training_steps=20000, evaluation_steps=2000,
train_instance_type='local', train_instance_count=1, run_tensorboard_locally=True,
uid=None, role=None, bucket=None, profile_name=None):
assert os.path.exists(source_dir)
boto_session = boto3.Session(profile_name=profile_name)
session = sagemaker.Session(boto_session=boto_session)
role = role if role is not None else sagemaker.get_execution_role()
bucket = bucket if bucket is not None else session.default_bucket()
uid = uid if uid is not None else uuid4()
logger.debug(session.get_caller_identity_arn())
role = session.expand_role(role)
params = {
'train_tfrecord_file': 'train.tfr',
'test_tfrecord_file' : 'test.tfr',
'samples_per_epoch' : 700000,
'save_summary_steps' : 100,
}
output_path = 's3://{}/doodle/model/{}/export'.format(bucket, uid)
checkpoint_path = 's3://{}/doodle/model/{}/ckpt' .format(bucket, uid)
code_location = 's3://{}/doodle/model/{}/source'.format(bucket, uid)