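# Snippets collected from the sagemaker-python-sdk (v1-era) unit tests and
# tuner module. A consolidated, hedged import block: the constants
# (ALL_REQ_ARGS, BUCKET_NAME, PREFIX, REGION, FEATURE_DIM, MINI_BATCH_SIZE,
# TRAIN_INSTANCE_TYPE) and the sagemaker_session fixture are defined per test
# module in the original repo, with values that differ by algorithm.
import pytest
from mock import patch  # the original tests use the mock package; unittest.mock also works

from sagemaker.amazon import amazon_estimator
from sagemaker.amazon.amazon_estimator import RecordSet, registry
from sagemaker.amazon.ipinsights import IPInsights
from sagemaker.amazon.kmeans import KMeans
from sagemaker.amazon.ntm import NTM
from sagemaker.amazon.pca import PCA
from sagemaker.amazon.randomcutforest import RandomCutForest, RandomCutForestPredictor
from sagemaker.job import _Job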

# base_fit is a mock of the parent class's fit(); in the original test module
# it is injected via a @patch decorator (target assumed below).
@patch("sagemaker.amazon.amazon_estimator.AmazonAlgorithmEstimatorBase.fit")
def test_call_fit(base_fit, sagemaker_session):
    pca = PCA(base_job_name="pca", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    pca.fit(data, MINI_BATCH_SIZE)

    # fit() must delegate to the base implementation exactly once, passing the
    # RecordSet and mini-batch size through unchanged.
    base_fit.assert_called_once()
    assert len(base_fit.call_args[0]) == 2
    assert base_fit.call_args[0][0] == data
    assert base_fit.call_args[0][1] == MINI_BATCH_SIZE

def test_model_image(sagemaker_session):
    randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    randomcutforest.fit(data, MINI_BATCH_SIZE)

    model = randomcutforest.create_model()
    assert model.image == registry(REGION, "randomcutforest") + "/randomcutforest:1"

def test_predictor_type(sagemaker_session):
    randomcutforest = RandomCutForest(sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    randomcutforest.fit(data, MINI_BATCH_SIZE)

    model = randomcutforest.create_model()
    predictor = model.deploy(1, TRAIN_INSTANCE_TYPE)
    assert isinstance(predictor, RandomCutForestPredictor)
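
# A hedged usage sketch for the deployed predictor (hypothetical input; in the
# v1 record protocol, RandomCutForest responses carry a "score" label):
#   import numpy as np
#   results = predictor.predict(np.array([[0.1] * FEATURE_DIM]))
#   scores = [record.label["score"].float32_tensor.values[0] for record in results]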

def test_call_fit_none_mini_batch_size(sagemaker_session):
    ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    # mini_batch_size is optional; fit() must accept its omission without raising.
    ntm.fit(data)

def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    kmeans = KMeans(base_job_name="kmeans", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    kmeans._prepare_for_training(data)

    # When no mini_batch_size is given, KMeans falls back to its default of 5000.
    assert kmeans.mini_batch_size == 5000

def test_prepare_for_training_wrong_value_upper_mini_batch_size(sagemaker_session):
    ntm = NTM(base_job_name="ntm", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS)
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    # NTM rejects mini_batch_size values above its upper bound of 10000.
    with pytest.raises(ValueError):
        ntm._prepare_for_training(data, 10001)

def test_prepare_for_training_wrong_type_mini_batch_size(sagemaker_session):
    ipinsights = IPInsights(
        base_job_name="ipinsights", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
    )
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    # A non-integer mini_batch_size is rejected.
    with pytest.raises((TypeError, ValueError)):
        ipinsights._prepare_for_training(data, "some")

def test_prepare_for_training_no_mini_batch_size(sagemaker_session):
    # Same test name as the KMeans variant above: these snippets come from
    # separate test modules, so the duplication is harmless in the originals.
    randomcutforest = RandomCutForest(
        base_job_name="randomcutforest", sagemaker_session=sagemaker_session, **ALL_REQ_ARGS
    )
    data = RecordSet(
        "s3://{}/{}".format(BUCKET_NAME, PREFIX),
        num_records=1,
        feature_dim=FEATURE_DIM,
        channel="train",
    )
    randomcutforest._prepare_for_training(data)

    # With no value given, RandomCutForest falls back to the module's default.
    assert randomcutforest.mini_batch_size == MINI_BATCH_SIZE

def test_format_inputs_to_input_config_list_not_all_records():
    records = RecordSet(s3_data=BUCKET_NAME, num_records=1, feature_dim=1)
    inputs = [records, "mock"]

    # Mixing RecordSet objects with anything else in the list is rejected.
    with pytest.raises(ValueError) as ex:
        _Job._format_inputs_to_input_config(inputs)
    assert "List compatible only with RecordSets or FileSystemRecordSets." in str(ex.value)

# The fragment below is the tail of a helper docstring plus its body; the
# enclosing signature, restored here for context, appears to match the
# module-level prepare_amazon_algorithm_estimator in sagemaker/tuner.py.
def prepare_amazon_algorithm_estimator(estimator, inputs, mini_batch_size=None):
    """Set up an Amazon-algorithm estimator's feature_dim and mini_batch_size from its training data.

    Args:
        estimator: An estimator for a built-in Amazon algorithm to update.
        inputs: The training data. Valid forms:

            * (sagemaker.amazon.amazon_estimator.RecordSet) - A collection of
              Amazon :class:`~Record` objects serialized and stored in S3. For
              use with an estimator for an Amazon algorithm.
            * (list[sagemaker.amazon.amazon_estimator.RecordSet]) - A list of
              :class:`~sagemaker.amazon.amazon_estimator.RecordSet` objects,
              where each instance is a different channel of training data.
        mini_batch_size: The size of each mini-batch to use when training.
    """
    if isinstance(inputs, list):
        # With multiple channels, feature_dim comes from the "train" channel.
        for record in inputs:
            if isinstance(record, amazon_estimator.RecordSet) and record.channel == "train":
                estimator.feature_dim = record.feature_dim
                break
    elif isinstance(inputs, amazon_estimator.RecordSet):
        estimator.feature_dim = inputs.feature_dim
    else:
        raise TypeError("Training data must be represented in RecordSet or list of RecordSets")
    estimator.mini_batch_size = mini_batch_size
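
# A minimal usage sketch under the assumptions above: after the helper runs,
# the estimator carries the feature_dim of the "train" channel and the given
# mini-batch size. The S3 URI and dimensions are illustrative, and
# some_estimator stands in for any Amazon-algorithm estimator (hypothetical).
record_set = amazon_estimator.RecordSet(
    "s3://my-bucket/prefix",
    num_records=100,
    feature_dim=8,
    channel="train",
)
prepare_amazon_algorithm_estimator(some_estimator, record_set, mini_batch_size=50)
assert some_estimator.feature_dim == 8
assert some_estimator.mini_batch_size == 50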