Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def assert_serializable(x: transform.Transformation):
    """
    Assert that ``x`` survives both a JSON and a code serialization round trip.

    For each serializer, ``x`` is dumped, loaded back, and dumped again; the
    two dumps must be identical.

    Parameters
    ----------
    x
        The transformation to check.

    Raises
    ------
    AssertionError
        If a round trip changes the serialized representation. The message
        names the serializer that failed and the fully qualified class name
        of ``x``.
    """
    t = fqname_for(x.__class__)

    # JSON round trip: dump -> load -> dump must be stable.
    x_json = dump_json(x)
    y = load_json(x_json)
    assert x_json == dump_json(
        y
    ), f"JSON serialization for transformer {t} does not work"

    # Code round trip: dump -> load -> dump must be stable.
    x_code = dump_code(x)
    z = load_code(x_code)
    assert x_code == dump_code(
        z
    ), f"Code serialization for transformer {t} does not work"
def run_train_and_test(
    env: TrainEnv, forecaster_type: Type[Union[Estimator, Predictor]]
) -> None:
    """
    Build a forecaster from the environment's hyperparameters and log how it
    was set up.

    Parameters
    ----------
    env
        Training environment; its ``hyperparameters`` are passed to the
        forecaster constructor and its ``datasets`` are checked for the
        ``train`` / ``validation`` / ``test`` channels.
    forecaster_type
        Estimator or predictor class whose ``from_hyperparameters`` is used
        to create the forecaster.
    """
    check_gpu_support()

    # Resolve identity/version before logging anything about them.
    fq_name = fqname_for(forecaster_type)
    version = forecaster_type.__version__
    logger.info(f"Using gluonts v{gluonts.__version__}")
    logger.info(f"Using forecaster {fq_name} v{version}")

    forecaster = forecaster_type.from_hyperparameters(**env.hyperparameters)
    reconstruction = dump_code(forecaster)
    logger.info(
        "The forecaster can be reconstructed with the following "
        f"expression: {reconstruction}"
    )

    # Only report the channels that are actually present, in a fixed order.
    channels = ", ".join(
        channel
        for channel in ("train", "validation", "test")
        if channel in env.datasets
    )
    logger.info(f"Using the following data channels: {channels}")
def encode_pydantic_model(v: BaseModel) -> Any:
    """
    Specialization of :func:`encode` for instances of :class:`~BaseModel`.

    The result records the fully qualified class name of ``v`` and, as
    keyword arguments, the recursively encoded contents of ``v.__dict__``.
    """
    encoded = dict(__kind__=kind_inst)
    encoded["class"] = fqname_for(v.__class__)
    encoded["kwargs"] = encode(v.__dict__)
    return encoded
def make_gunicorn_app(
    env: Optional[ServeEnv],
    forecaster_type: Optional[Type[Union[Estimator, Predictor]]],
    settings: Settings,
) -> Application:
    """
    Select the predictor factory used for serving.

    Two modes are supported:

    * dynamic -- when ``forecaster_type`` is given, each call to the factory
      builds a new forecaster via ``forecaster_type.from_hyperparameters``
      from ``request["configuration"]``;
    * static -- otherwise one predictor is deserialized from
      ``env.path.model`` up front and the factory always returns it.

    Parameters
    ----------
    env
        Serving environment; required in static mode.
    forecaster_type
        Forecaster class for dynamic mode, or ``None`` for static mode.
    settings
        Server settings (not read by the code shown here).

    Raises
    ------
    ValueError
        If ``forecaster_type`` is ``None`` and ``env`` is ``None`` as well.
    """
    check_gpu_support()

    if forecaster_type is not None:
        logger.info("Using dynamic predictor factory")
        ctor = forecaster_type.from_hyperparameters
        forecaster_fq_name = fqname_for(forecaster_type)
        forecaster_version = forecaster_type.__version__

        def predictor_factory(request) -> Predictor:
            return ctor(**request["configuration"])

    else:
        logger.info("Using static predictor factory")
        # Explicit check rather than `assert`: assertions are stripped under
        # `python -O`, which would turn a missing env into an AttributeError.
        if env is None:
            raise ValueError(
                "A serve environment is required when no forecaster type "
                "is given (static predictor factory)."
            )
        predictor = Predictor.deserialize(env.path.model)
        forecaster_fq_name = fqname_for(type(predictor))
        forecaster_version = predictor.__version__

        def predictor_factory(request) -> Predictor:
            return predictor

    # NOTE(review): the visible body ends here without using
    # forecaster_fq_name / forecaster_version / settings and without
    # returning an Application; the remainder of this function appears to be
    # missing from this chunk -- confirm against the full source.
def encode_path(v: PurePath) -> Any:
    """
    Specialization of :func:`encode` for instances of :class:`~PurePath`.

    The path is recorded as its fully qualified class name plus a single
    positional argument: the path rendered with ``str``.
    """
    cls_name = fqname_for(v.__class__)
    encoded_args = encode([str(v)])
    return {"__kind__": kind_inst, "class": cls_name, "args": encoded_args}
def from_hyperparameters(cls, **hyperparameters) -> "GluonEstimator":
    """
    Build an instance of ``cls`` from a flat mapping of hyperparameters.

    The hyperparameters are validated through the ``Model`` attribute that
    ``@validated()`` attaches to ``cls.__init__``. A ``Trainer`` is built
    from the same mapping and supplied under the ``"trainer"`` key,
    overriding any ``trainer`` entry in ``hyperparameters``.

    Raises
    ------
    AttributeError
        If ``cls.__init__`` has no (truthy) ``Model`` attribute.
    GluonTSHyperparametersError
        If validation raises ``ValidationError`` (chained as the cause).
    """
    model_cls = getattr(cls.__init__, "Model", None)
    if not model_cls:
        raise AttributeError(
            "Cannot find attribute Model attached to the "
            f"{fqname_for(cls)}. Most probably you have forgotten to mark "
            "the class constructor as @validated()."
        )

    try:
        trainer = from_hyperparameters(Trainer, **hyperparameters)
        arguments = {**hyperparameters, "trainer": trainer}
        validated = model_cls(**arguments)
        return cls(**validated.__dict__)
    except ValidationError as e:
        raise GluonTSHyperparametersError from e
}
if isinstance(v, (list, set, tuple)):
return list(map(encode, v))
if isinstance(v, dict):
return {k: encode(v) for k, v in v.items()}
if isinstance(v, type):
return {"__kind__": kind_type, "class": fqname_for(v)}
if hasattr(v, "__getnewargs_ex__"):
args, kwargs = v.__getnewargs_ex__() # mypy: ignore
return {
"__kind__": kind_inst,
"class": fqname_for(v.__class__),
"args": encode(args),
"kwargs": encode(kwargs),
}
raise RuntimeError(bad_type_msg.format(fqname_for(v.__class__)))
if isinstance(v, dict):
return {k: encode(v) for k, v in v.items()}
if isinstance(v, type):
return {"__kind__": kind_type, "class": fqname_for(v)}
if hasattr(v, "__getnewargs_ex__"):
args, kwargs = v.__getnewargs_ex__() # mypy: ignore
return {
"__kind__": kind_inst,
"class": fqname_for(v.__class__),
"args": encode(args),
"kwargs": encode(kwargs),
}
raise RuntimeError(bad_type_msg.format(fqname_for(v.__class__)))
def encode_np_dtype(v: np.dtype) -> Any:
    """
    Specializes :func:`encode` for invocations where ``v`` is an instance of
    the :class:`~numpy.dtype` class.

    The dtype is recorded as a call to ``numpy.dtype`` with ``v.name`` (e.g.
    ``"float32"``) as its single positional argument.
    """
    return {
        "__kind__": kind_inst,
        # Pin the public constructor instead of using v.__class__. Since
        # NumPy 1.20 every dtype instance belongs to a per-type subclass
        # (``numpy.dtype[float32]``, later ``numpy.dtypes.Float32DType``).
        # The first name cannot be located on decode, and the second is not
        # a ``np.dtype(name)`` constructor. On older NumPy, v.__class__ *is*
        # numpy.dtype, so the output there is unchanged.
        "class": "numpy.dtype",
        "args": encode([v.name]),
    }
ctor = forecaster_type.from_hyperparameters
forecaster_fq_name = fqname_for(forecaster_type)
forecaster_version = forecaster_type.__version__
def predictor_factory(request) -> Predictor:
return ctor(**request["configuration"])
else:
logger.info(f"Using static predictor factory")
assert env is not None
predictor = Predictor.deserialize(env.path.model)
forecaster_fq_name = fqname_for(type(predictor))
forecaster_version = predictor.__version__
def predictor_factory(request) -> Predictor:
return predictor
logger.info(f"Using gluonts v{gluonts.__version__}")
logger.info(f"Using forecaster {forecaster_fq_name} v{forecaster_version}")
execution_params = {
"MaxConcurrentTransforms": settings.number_of_workers,
"BatchStrategy": settings.sagemaker_batch_strategy,
"MaxPayloadInMB": settings.sagemaker_max_payload_in_mb,
}
flask_app = make_flask_app(predictor_factory, execution_params)