def _test_dist_operations(sagemaker_session, image_uri, instance_type, dist_backend, train_instance_count=3):
    with timeout(minutes=DEFAULT_TIMEOUT):
        pytorch = PyTorch(entry_point=dist_operations_path,
                          role='SageMakerRole',
                          train_instance_count=train_instance_count,
                          train_instance_type=instance_type,
                          sagemaker_session=sagemaker_session,
                          image_name=image_uri,
                          hyperparameters={'backend': dist_backend})

        pytorch.sagemaker_session.default_bucket()  # ensure the session's default S3 bucket exists
        # the entry-point directory itself is uploaded and reused as a dummy input channel
        fake_input = pytorch.sagemaker_session.upload_data(path=dist_operations_path,
                                                           key_prefix='pytorch/distributed_operations')
        job_name = utils.unique_name_from_base('test-pytorch-dist-ops')
        pytorch.fit({'required_argument': fake_input}, job_name=job_name)

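# Context note: the test snippets on this page assume imports and module-level constants
# defined elsewhere in their source repositories. A minimal sketch of that assumed context
# (all values below are placeholders, not the originals):
import os

from sagemaker import utils
from sagemaker.pytorch import PyTorch

from test_utils import timeout  # hypothetical module exposing the timeout() context manager

DEFAULT_TIMEOUT = 40                                             # minutes; placeholder
dist_operations_path = 'test/resources/distributed_operations'   # placeholder path
mnist_script = 'test/resources/mnist/mnist.py'                   # placeholder path
data_dir = 'test/resources/mnist'                                # placeholder path
training_dir = os.path.join(data_dir, 'training')                # placeholder path
MULTI_GPU_INSTANCE = 'ml.p3.8xlarge'                             # placeholder instance type
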
def test_mnist_gpu(sagemaker_session, image_uri, dist_gpu_backend):
    with timeout(minutes=DEFAULT_TIMEOUT):
        pytorch = PyTorch(entry_point=mnist_script,
                          role='SageMakerRole',
                          train_instance_count=2,
                          image_name=image_uri,
                          train_instance_type=MULTI_GPU_INSTANCE,
                          sagemaker_session=sagemaker_session,
                          hyperparameters={'backend': dist_gpu_backend})

        training_input = sagemaker_session.upload_data(path=os.path.join(data_dir, 'training'),
                                                       key_prefix='pytorch/mnist')
        job_name = utils.unique_name_from_base('test-pytorch-dist-ops')
        pytorch.fit({'training': training_input}, job_name=job_name)

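# Optional follow-up (not part of the original test): once fit() succeeds, the same
# estimator can be deployed to a real-time endpoint. The instance type and input tensor
# shape below are assumptions.
import numpy as np

predictor = pytorch.deploy(initial_instance_count=1, instance_type='ml.m5.xlarge')
try:
    sample = np.random.rand(1, 1, 28, 28).astype(np.float32)  # one fake MNIST-shaped image
    print(predictor.predict(sample))
finally:
    predictor.delete_endpoint()  # avoid leaving a billable endpoint running
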
def _test_dgl_training(sagemaker_session, ecr_image, instance_type):
    dgl = PyTorch(entry_point=DGL_SCRIPT_PATH,
                  role='SageMakerRole',
                  train_instance_count=1,
                  train_instance_type=instance_type,
                  sagemaker_session=sagemaker_session,
                  image_name=ecr_image)

    with timeout(minutes=DEFAULT_TIMEOUT):
        job_name = utils.unique_name_from_base('test-pytorch-dgl-image')
        # no input channels: the DGL entry point fetches or generates its own data
        dgl.fit(job_name=job_name)

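# Optional follow-up (not in the original test): a completed training job can later be
# re-attached by name to inspect its status or artifacts.
attached = PyTorch.attach(job_name, sagemaker_session=sagemaker_session)
print(attached.model_data)  # S3 URI of the trained model artifact
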
def _test_mnist_distributed(sagemaker_session, image_uri, instance_type, dist_backend):
    with timeout(minutes=DEFAULT_TIMEOUT):
        pytorch = PyTorch(entry_point=mnist_script,
                          role='SageMakerRole',
                          train_instance_count=2,
                          train_instance_type=instance_type,
                          sagemaker_session=sagemaker_session,
                          image_name=image_uri,
                          debugger_hook_config=False,
                          hyperparameters={'backend': dist_backend, 'epochs': 2})

        training_input = pytorch.sagemaker_session.upload_data(path=training_dir,
                                                               key_prefix='pytorch/mnist')
        job_name = utils.unique_name_from_base('test-pytorch-mnist')
        pytorch.fit({'training': training_input}, job_name=job_name)

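# Optional follow-up (not in the original test): the trained artifact can be wrapped in a
# deployable PyTorchModel. The framework_version here is an assumption.
from sagemaker.pytorch import PyTorchModel

model = PyTorchModel(model_data=pytorch.model_data,
                     role='SageMakerRole',
                     entry_point=mnist_script,
                     sagemaker_session=sagemaker_session,
                     framework_version='1.4.0')
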
def _test_dist_operations(sagemaker_session, image_uri, instance_type, dist_backend, train_instance_count=3):
    with timeout(minutes=DEFAULT_TIMEOUT):
        pytorch = PyTorch(entry_point=dist_operations_path,
                          role='SageMakerRole',
                          train_instance_count=train_instance_count,
                          train_instance_type=instance_type,
                          sagemaker_session=sagemaker_session,
                          image_name=image_uri,
                          debugger_hook_config=False,  # differs from the earlier variant: Debugger hooks disabled
                          hyperparameters={'backend': dist_backend})

        pytorch.sagemaker_session.default_bucket()  # ensure the session's default S3 bucket exists
        fake_input = pytorch.sagemaker_session.upload_data(path=dist_operations_path,
                                                           key_prefix='pytorch/distributed_operations')
        job_name = utils.unique_name_from_base('test-pytorch-dist-ops')
        pytorch.fit({'required_argument': fake_input}, job_name=job_name)

def _pytorch_estimator(
    sagemaker_session,
    framework_version=defaults.PYTORCH_VERSION,
    train_instance_type=None,
    base_job_name=None,
    **kwargs
):
    return PyTorch(
        entry_point=SCRIPT_PATH,
        framework_version=framework_version,
        py_version=PYTHON_VERSION,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        train_instance_count=INSTANCE_COUNT,
        train_instance_type=train_instance_type if train_instance_type else INSTANCE_TYPE,
        base_job_name=base_job_name,
        **kwargs
    )

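# A hedged usage sketch for the _pytorch_estimator() helper above. ROLE, SCRIPT_PATH,
# PYTHON_VERSION, INSTANCE_COUNT, and INSTANCE_TYPE are constants from the original test
# module and are not reproduced here; the instance type below is an example value.
estimator = _pytorch_estimator(sagemaker_session,
                               train_instance_type='ml.c5.xlarge',
                               base_job_name='pytorch-estimator-test')
print(estimator.train_image())  # resolved training image URI for the configured framework version
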
def test_dist_operations_path_cpu(image_uri, dist_cpu_backend, sagemaker_local_session, tmpdir):
    estimator = PyTorch(entry_point=dist_operations_path,
                        role=ROLE,
                        image_name=image_uri,
                        train_instance_count=2,
                        train_instance_type='local',
                        sagemaker_session=sagemaker_local_session,
                        hyperparameters={'backend': dist_cpu_backend},
                        output_path='file://{}'.format(tmpdir))

    _train_and_assert_success(estimator, str(tmpdir))

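# _train_and_assert_success() lives elsewhere in the original test module; a hypothetical
# sketch of what such a local-mode helper could look like (the channel name and artifact
# path are assumptions, not the original implementation):
def _train_and_assert_success(estimator, output_dir):
    estimator.fit({'training': 'file://' + dist_operations_path})
    # local mode writes the packaged model artifact into the configured output_path
    assert os.path.exists(os.path.join(output_dir, 'model.tar.gz'))
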
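# tensorboard_output_config is used below but constructed elsewhere in the original
# script. A hedged sketch of how it could be built (the S3 prefix and container-local
# path are assumptions):
from sagemaker.debugger import TensorBoardOutputConfig

tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f's3://{args.s3_bucket}/jobs/{job_name}/tensorboard',
    container_local_output_path='/opt/ml/output/tensorboard')
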
train_data = sagemaker.session.s3_input(
    s3_data=f's3://{args.s3_bucket}/{args.data_prefix}/train',
    distribution='FullyReplicated',
    s3_data_type='S3Prefix')
test_data = sagemaker.session.s3_input(
    s3_data=f's3://{args.s3_bucket}/{args.data_prefix}/val',
    distribution='FullyReplicated',
    s3_data_type='S3Prefix')

job = PyTorch(entry_point='sm-entry.py',
              source_dir='.',
              framework_version='1.4.0',
              train_instance_count=1,
              train_instance_type=args.instance,
              hyperparameters=params,
              role=EXECUTION_ROLE_ARN,
              output_path=f's3://{args.s3_bucket}/jobs/{job_name}/output-path',
              base_job_name=args.main_command,
              code_location=f's3://{args.s3_bucket}/jobs/{job_name}',
              # checkpoint_s3_uri=f's3://{S3_BUCKET}/jobs/{job_name}/checkpoints',
              tensorboard_output_config=tensorboard_output_config,
              train_use_spot_instances=args.use_spots,
              # train_max_wait=ONE_HOUR,  # required when train_use_spot_instances is True
              )

job.fit(inputs={'train': train_data, 'test': test_data}, wait=args.wait_for_job)