# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
num_entity_vectors=10,
vector_dim=100,
sagemaker_session=sagemaker_session,
)
record_set = prepare_record_set_from_local_files(
data_path, ipinsights.data_location, num_records, FEATURE_DIM, sagemaker_session
)
ipinsights.fit(records=record_set, job_name=job_name)
with timeout_and_delete_endpoint_by_name(job_name, sagemaker_session):
model = IPInsightsModel(
ipinsights.model_data, role="SageMakerRole", sagemaker_session=sagemaker_session
)
predictor = model.deploy(1, cpu_instance_type, endpoint_name=job_name)
assert isinstance(predictor, RealTimePredictor)
predict_input = [["user_1", "1.1.1.1"]]
result = predictor.predict(predict_input)
assert len(result["predictions"]) == 1
assert 0 > result["predictions"][0]["dot_product"] > -1 # We expect ~ -0.22
def test_predict_tensor_request_csv(sagemaker_session):
data = [6.4, 3.2, 0.5, 1.5]
tensor_proto = tf.make_tensor_proto(
values=np.asarray(data), shape=[1, len(data)], dtype=tf.float32
)
predictor = RealTimePredictor(
serializer=tf_csv_serializer,
deserializer=tf_json_deserializer,
sagemaker_session=sagemaker_session,
endpoint=ENDPOINT,
)
mock_response(
json.dumps(CLASSIFICATION_RESPONSE).encode("utf-8"), sagemaker_session, JSON_CONTENT_TYPE
)
result = predictor.predict(tensor_proto)
sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(
Accept=JSON_CONTENT_TYPE,
Body="6.4,3.2,0.5,1.5",
ContentType=CSV_CONTENT_TYPE,
def test_delete_endpoint_only():
    """Deleting just the endpoint must not touch the endpoint configuration."""
    session = empty_sagemaker_session()
    predictor = RealTimePredictor(ENDPOINT, sagemaker_session=session)

    predictor.delete_endpoint(delete_endpoint_config=False)

    # Only the endpoint itself is removed; its config is left in place.
    session.delete_endpoint.assert_called_with(ENDPOINT)
    session.delete_endpoint_config.assert_not_called()
def test_predict_call_with_headers_and_csv():
    """A CSV serializer plus an explicit Accept header must be forwarded verbatim."""
    session = ret_csv_sagemaker_session()
    predictor = RealTimePredictor(
        ENDPOINT, session, accept=CSV_CONTENT_TYPE, serializer=csv_serializer
    )

    predictor.predict([1, 2])

    invoke = session.sagemaker_runtime_client.invoke_endpoint
    assert invoke.called
    # invoke_endpoint must receive the serialized body and both content headers.
    _, kwargs = invoke.call_args
    assert kwargs == {
        "Accept": CSV_CONTENT_TYPE,
        "Body": "1,2",
        "ContentType": CSV_CONTENT_TYPE,
        "EndpointName": ENDPOINT,
    }
def test_predict_call_pass_through():
    """Without a serializer, string payloads reach the runtime client unmodified."""
    session = empty_sagemaker_session()
    predictor = RealTimePredictor(ENDPOINT, session)

    payload = "untouched"
    result = predictor.predict(payload)

    invoke = session.sagemaker_runtime_client.invoke_endpoint
    assert invoke.called
    # No Accept/ContentType headers are added and the body is passed through as-is.
    _, kwargs = invoke.call_args
    assert kwargs == {"Body": payload, "EndpointName": ENDPOINT}
    assert result == RETURN_VALUE
model_data=sparkml_model_data,
env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
sagemaker_session=sagemaker_session,
)
xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
xgb_model = Model(
model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session
)
model = PipelineModel(
models=[sparkml_model, xgb_model],
role="SageMakerRole",
sagemaker_session=sagemaker_session,
name=endpoint_name,
)
model.deploy(1, cpu_instance_type, endpoint_name=endpoint_name)
predictor = RealTimePredictor(
endpoint=endpoint_name,
sagemaker_session=sagemaker_session,
serializer=json_serializer,
content_type=CONTENT_TYPE_CSV,
accept=CONTENT_TYPE_CSV,
)
with open(VALID_DATA_PATH, "r") as f:
valid_data = f.read()
assert predictor.predict(valid_data) == "0.714013934135"
with open(INVALID_DATA_PATH, "r") as f:
invalid_data = f.read()
assert predictor.predict(invalid_data) is None
model.delete_model()
def test_delete_model_fail():
    """delete_model must surface a retry hint when the backend deletion fails.

    Fixes two defects in the original test:
    * ``exception.val`` — pytest's ``ExceptionInfo`` exposes the raised
      exception as ``.value``; ``.val`` raises ``AttributeError``.
    * the assertion lived inside the ``with pytest.raises(...)`` block after
      the raising call, so it was never executed at all.
    """
    sagemaker_session = empty_sagemaker_session()
    # Simulate the service failing to delete the underlying model.
    sagemaker_session.sagemaker_client.delete_model = Mock(
        side_effect=Exception("Could not find model.")
    )
    expected_error_message = "One or more models cannot be deleted, please retry."
    predictor = RealTimePredictor(ENDPOINT, sagemaker_session=sagemaker_session)

    with pytest.raises(Exception) as exception:
        predictor.delete_model()
    # The assert must run AFTER the context manager exits, and read `.value`.
    assert expected_error_message in str(exception.value)
def test_predict_call_with_headers_and_json():
sagemaker_session = json_sagemaker_session()
predictor = RealTimePredictor(
ENDPOINT,
sagemaker_session,
content_type="not/json",
accept="also/not-json",
serializer=json_serializer,
)
data = [1, 2]
result = predictor.predict(data)
assert sagemaker_session.sagemaker_runtime_client.invoke_endpoint.called
expected_request_args = {
"Accept": "also/not-json",
"Body": json.dumps(data),
"ContentType": "not/json",
from __future__ import absolute_import
import logging
from pkg_resources import parse_version
import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.mxnet.defaults import MXNET_VERSION
from sagemaker.predictor import RealTimePredictor, json_serializer, json_deserializer
logger = logging.getLogger("sagemaker")
class MXNetPredictor(RealTimePredictor):
"""A RealTimePredictor for inference against MXNet Endpoints.
This is able to serialize Python lists, dictionaries, and numpy arrays to
multidimensional tensors for MXNet inference.
"""
def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``MXNetPredictor``.
Args:
endpoint_name (str): The name of the endpoint to perform inference
on.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.
"""Placeholder docstring"""
from __future__ import absolute_import
import logging
import sagemaker
from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning
from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME
from sagemaker.predictor import RealTimePredictor
from sagemaker.tensorflow.defaults import TF_VERSION
from sagemaker.tensorflow.predictor import tf_json_serializer, tf_json_deserializer
logger = logging.getLogger("sagemaker")
class TensorFlowPredictor(RealTimePredictor):
"""A ``RealTimePredictor`` for inference against TensorFlow endpoint.
This is able to serialize Python lists, dictionaries, and numpy arrays to
multidimensional tensors for inference
"""
def __init__(self, endpoint_name, sagemaker_session=None):
"""Initialize an ``TensorFlowPredictor``.
Args:
endpoint_name (str): The name of the endpoint to perform inference
on.
sagemaker_session (sagemaker.session.Session): Session object which
manages interactions with Amazon SageMaker APIs and any other
AWS services needed. If not specified, the estimator creates one
using the default AWS configuration chain.