Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
num_entity_vectors=10,
vector_dim=100,
sagemaker_session=sagemaker_session,
)
record_set = prepare_record_set_from_local_files(
data_path, ipinsights.data_location, num_records, FEATURE_DIM, sagemaker_session
)
ipinsights.fit(records=record_set, job_name=job_name)
with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
model = IPInsightsModel(
ipinsights.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
assert isinstance(predictor, RealTimePredictor)
predict_input = [["user_1", "1.1.1.1"]]
result = predictor.predict(predict_input)
assert len(result["predictions"]) == 1
assert 0 > result["predictions"][0]["dot_product"] > -1 # We expect ~ -0.22def test_predict_tensor_request_csv(sagemaker_session):
data = [6.4, 3.2, 0.5, 1.5]
tensor_proto = tf.make_tensor_proto(
values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32
)
predictor = RealTimePredictor(
serializer=tf_csv_serializer,
deserializer=tf_json_deserializer,
sagemaker_session=sagemaker_session,
endpoint=ENDPOINT,
)
mock_response(
json.dumps(CLASSIFICATION_RESPONSE).encode("utf-8"), sagemaker_session, JSON_CONTENT_TYPE
)
result = predictor.predict(tensor_proto)
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(
Accept=JSON_CONTENT_TYPE,
Body="6.4,3.2,0.5,1.5",
ContentType=CSV_CONTENT_TYPE,


def test_delete_endpoint_only():
    """Deleting only the endpoint must leave the endpoint config untouched."""
    sagemaker_session = empty_sagemaker_session()

    # The predictor instance is not needed after the call, so chain it.
    RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session).delete_endpoint(
        delete_endpoint_config=False
    )

    # Endpoint deletion happens; config deletion is explicitly suppressed.
    sagemaker_session.delete_endpoint.assert_called_with(ENDPOINT)
    sagemaker_session.delete_endpoint_config.assert_not_called()


def test_predict_call_with_headers_and_csv():
    """A CSV-configured predictor must send the serialized body with CSV headers."""
    sagemaker_session = ret_csv_sagemaker_session()
    predictor = RealTimePredictor(
        ENDPOINT, sagemaker_session, accept=CSV_CONTENT_TYPE, serializer=csv_serializer
    )

    predictor.predict([1, 2])

    invoke = sagemaker_session.sagemaker_runtime_client.invoke_endpoint
    assert invoke.called
    # Verify the exact keyword arguments passed to the runtime client.
    _, actual_kwargs = invoke.call_args
    assert actual_kwargs == {
        "Accept": CSV_CONTENT_TYPE,
        "Body": "1,2",
        "ContentType": CSV_CONTENT_TYPE,
        "EndpointName": ENDPOINT,
    }


def test_predict_call_pass_through():
    """With no serializer configured, the payload must reach the client unmodified."""
    sagemaker_session = empty_sagemaker_session()
    predictor = RealTimePredictor(ENDPOINT, sagemaker_session)

    payload = "untouched"
    result = predictor.predict(payload)

    invoke = sagemaker_session.sagemaker_runtime_client.invoke_endpoint
    assert invoke.called
    _, actual_kwargs = invoke.call_args
    # No Accept/ContentType headers are added when none were configured.
    assert actual_kwargs == {"Body": payload, "EndpointName": ENDPOINT}
    assert result == RETURN_VALUE
model_data=sparkml_model_data,
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
sagemaker_session=sagemaker_session,
)
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
xgb_model = Model(
model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session
)
model = PipelineModel(
models=[sparkml_model, xgb_model],
role="SageMakerRole",
sagemaker_session=sagemaker_session,
name=endpoint_name,
)
model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor = RealTimePredictor(
endpoint=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=json_serializer,
content_type=CONTENT_TYPE_CSV,
accept=CONTENT_TYPE_CSV,
)
with open(VALID_DATA_PATH, "r") as f:
valid_data = f.read()
assert predictor.predict(valid_data) == "0.714013934135"
with open(INVALID_DATA_PATH, "r") as f:
invalid_data = f.read()
assert predictor.predict(invalid_data) is None
model.delete_model()def test_delete_model_fail():
sagemaker_session = empty_sagemaker_session()
sagemaker_session.sagemaker_client.delete_model = Mock(
side_effect=Exception("Could not find model.")
)
expected_error_message = "One or more models cannot be deleted, please retry."
predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session)
with pytest.raises(Exception) as exception:
predictor.delete_model()
assert expected_error_message in str(exception.val)def test_predict_call_with_headers_and_json():
sagemaker_session = json_sagemaker_session()
predictor = RealTimePredictor(
ENDPOINT,
sagemaker_session,
content_type="not/json",
accept="also/not-json",
serializer=json_serializer,
)
data = [1, 2]
result = predictor.predict(data)
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
expected_request_args = {
"Accept": "also/not-json",
"Body": json.dumps(data),
"ContentType": "not/json",from __future__ import absolute_import
import logging
from pkg_resources import parse_version
import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
logger = logging.getLogger("sagemaker")
class MXNetPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against MXNet Endpoints.
This is able to serialize Python lists, dictionaries, and numpy arrays to
multidimensional tensors for MXNet inference.
"""
def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``MXNetPredictor``.
Args:
endpoint_name (str): The name of the endpoint to perform inference
on.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain."""Placeholder docstring"""
from __future__ import absolute_import
import logging
import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
logger = logging.getLogger("sagemaker")
class TensorFlowPredictor(RealTimePredictor):
"""A ``RealTimePredictor`` for inference against TensorFlow endpoint.
This is able to serialize Python lists, dictionaries, and numpy arrays to
multidimensional tensors for inference
"""
def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``TensorFlowPredictor``.
Args:
endpoint_name (str): The name of the endpoint to perform inference
on.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.