Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Each entry maps an object to the variant under test: the object itself,
# or a copy rebuilt by a JSON dump/load round trip.
serialize_fn_list = [
    lambda obj: obj,  # untouched original
    lambda obj: load_json(dump_json(obj)),  # JSON round trip
]
def assert_serializable(x: transform.Transformation):
    """Assert that transformation *x* survives both serialization round trips.

    Two round trips are checked independently:

    * JSON: ``load_json(dump_json(x))`` must dump back to the same JSON text.
    * Code: ``load_code(dump_code(x))`` must dump back to the same code text.

    Parameters
    ----------
    x
        The transformation instance to check.

    Raises
    ------
    AssertionError
        If a round trip does not reproduce the original serialized form. The
        message names the transformer's fully-qualified class name and the
        serialization format that failed.
    """
    t = fqname_for(x.__class__)

    # JSON round trip: the reloaded object must serialize to identical JSON.
    json_repr = dump_json(x)
    y = load_json(json_repr)
    assert json_repr == dump_json(
        y
    ), f"JSON serialization for transformer {t} does not work"

    # Code round trip: the reloaded object must serialize to identical code.
    code_repr = dump_code(x)
    z = load_code(code_repr)
    assert code_repr == dump_code(
        z
    ), f"Code serialization for transformer {t} does not work"
# Each entry maps an object to the variant under test: the object itself,
# or a copy rebuilt by a JSON dump/load round trip.
serialize_fn_list = [
    lambda obj: obj,  # untouched original
    lambda obj: load_json(dump_json(obj)),  # JSON round trip
]
def test_json_serialization(e) -> None:
    """Check that *e* compares equal to itself after a ``serde`` JSON round trip.

    Equality is decided by ``check_equality(original, restored)``.
    """
    restored = serde.load_json(serde.dump_json(e))
    assert check_equality(e, restored)
# Each entry maps an object to the variant under test: the object itself,
# or a copy rebuilt by a JSON dump/load round trip.
serialize_fn_list = [
    lambda obj: obj,  # untouched original
    lambda obj: load_json(dump_json(obj)),  # JSON round trip
]
# NOTE(review): these two lines look like entries of a serializer list whose
# enclosing brackets are not in this excerpt. As written at top level, each is
# an expression statement building a one-element tuple (note the trailing
# comma) that is immediately discarded — confirm against the original source.
# Round trip through `serde`'s JSON dump/load.
lambda x: serde.load_json(serde.dump_json(x)),
# Round trip through `serde`'s binary dump/load.
lambda x: serde.load_binary(serde.dump_binary(x)),
# Each entry maps an object to the variant under test: the object itself,
# or a copy rebuilt by a JSON dump/load round trip.
serialize_fn_list = [
    lambda obj: obj,  # untouched original
    lambda obj: load_json(dump_json(obj)),  # JSON round trip
]
def deserialize(
    cls, path: Path, ctx: Optional[mx.Context] = None
) -> "SymbolBlockPredictor":
    """Load a ``SymbolBlockPredictor`` previously saved under *path*.

    Constructor keyword arguments are read from ``parameters.json`` and the
    input transformation from ``input_transform.json`` (both decoded with
    ``load_json``). The prediction network is then imported with
    ``import_symb_block``, under the name ``"prediction_net"``, with as many
    inputs as there are entries in the saved ``input_names``. All loading
    happens inside the MXNet context *ctx*.

    Parameters
    ----------
    path
        Directory holding the serialized predictor files.
    ctx
        MXNet context to load into; when ``None``, ``get_mxnet_context()``
        supplies it.

    Returns
    -------
    SymbolBlockPredictor
        A new predictor; the resolved *ctx* is also passed to its constructor
        as the ``ctx`` keyword.
    """
    # NOTE(review): the result is built with SymbolBlockPredictor rather than
    # `cls`, so subclasses calling this get a base-class instance — confirm.
    if ctx is None:
        ctx = get_mxnet_context()

    with mx.Context(ctx):
        params_file = path / "parameters.json"
        with params_file.open("r") as handle:
            ctor_kwargs = load_json(handle.read())
        ctor_kwargs["ctx"] = ctx

        transform_file = path / "input_transform.json"
        with transform_file.open("r") as handle:
            input_transform = load_json(handle.read())

        net = import_symb_block(
            len(ctor_kwargs["input_names"]), path, "prediction_net"
        )

        return SymbolBlockPredictor(
            input_transform=input_transform,
            prediction_net=net,
            **ctor_kwargs,
        )
----------
model_dir
The path where the model is saved.
model_name
The name identifying the model.
epoch
The epoch number, which together with the `model_name` identifies the
model parameters.
Returns
-------
mx.gluon.HybridBlock
The deserialized block.
"""
# Rebuild the block from its JSON description, `<model_name>-network.json`.
# NOTE(review): `cast` is presumably `typing.cast`, which only informs the type
# checker and performs no runtime conversion or check — confirm the import.
with (model_dir / f"{model_name}-network.json").open("r") as fp:
rb = cast(mx.gluon.HybridBlock, load_json(fp.read()))
# Load the weights for `epoch` from `<model_name>-<epoch>.params`, with the
# epoch zero-padded to at least 4 digits (e.g. 7 -> "<model_name>-0007.params").
# Weights are loaded with ctx=mx.current_context(). allow_missing and
# ignore_extra are both False, so missing or unexpected parameters are not
# tolerated (presumably raising; see Gluon `Block.load_parameters`).
rb.load_parameters(
str(model_dir / f"{model_name}-{epoch:04}.params"),
ctx=mx.current_context(),
allow_missing=False,
ignore_extra=False,
)
return rb
def deserialize(
cls, path: Path, ctx: Optional[mx.Context] = None
) -> "SymbolBlockPredictor":
ctx = ctx if ctx is not None else get_mxnet_context()
with mx.Context(ctx):
# deserialize constructor parameters
with (path / "parameters.json").open("r") as fp:
parameters = load_json(fp.read())
parameters["ctx"] = ctx
# deserialize transformation chain
with (path / "input_transform.json").open("r") as fp:
transform = load_json(fp.read())
# deserialize prediction network
num_inputs = len(parameters["input_names"])
prediction_net = import_symb_block(
num_inputs, path, "prediction_net"
)
return SymbolBlockPredictor(
input_transform=transform,
prediction_net=prediction_net,