Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def assert_serializable(x: transform.Transformation):
    """
    Assert that ``x`` round-trips through both JSON and code serialization.

    ``x`` is dumped and reloaded once with each serializer. The reloaded
    object is then dumped again and that dump is compared with the dump of
    ``x`` itself.

    Parameters
    ----------
    x
        The transformation to check.

    Raises
    ------
    AssertionError
        If, for either serializer, dumping the reloaded object does not
        give the same result as dumping ``x`` directly.
    """
    # Name of the class under test; only used in the failure messages.
    t = fqname_for(x.__class__)

    y = load_json(dump_json(x))
    z = load_code(dump_code(x))

    # Each message names the serializer whose round trip is being compared,
    # so a failure points at the right one.
    assert dump_json(x) == dump_json(
        y
    ), f"JSON serialization for transformer {t} does not work"
    assert dump_code(x) == dump_code(
        z
    ), f"Code serialization for transformer {t} does not work"
)
compare_vals(len(bar02.x_list), len(bar02.x_list), len(bar03.x_list))
compare_vals(len(bar02.x_dict), len(bar02.x_dict), len(bar03.x_dict))
compare_vals(
len(bar02.input_fields),
len(bar02.input_fields),
len(bar03.input_fields),
)
compare_vals(bar02.x_list, bar02.x_list, bar03.x_list)
compare_vals(bar02.x_dict, bar02.x_dict, bar03.x_dict)
compare_vals(bar02.input_fields, bar02.input_fields, bar03.input_fields)
baz01 = Baz(a="0", b="9", c=Complex(x="1", y="2"), d="42")
baz02 = load_json(dump_json(baz01))
assert type(baz01) == type(baz02)
assert baz01 == baz02
for i in range(5)
]
x_dict = {
i: Foo(
b=random.uniform(0, B),
a=str(random.randint(0, A)),
c=Complex(
x=str(random.uniform(0, C)), y=str(random.uniform(0, C))
),
)
for i in range(6)
}
bar01 = Bar(x_list, input_fields=fields, x_dict=x_dict)
bar02 = load_code(dump_code(bar01))
bar03 = load_json(dump_json(bar02))
def compare_tpes(x, y, z, tpe):
    """Assert that ``x``, ``y`` and ``z`` all have exactly the type ``tpe``."""
    observed = [type(value) for value in (x, y, z)]
    assert [tpe, tpe, tpe] == observed
def compare_vals(x, y, z):
    """Assert that ``x`` equals ``y`` and ``y`` equals ``z``."""
    first_pair_matches = x == y
    # ``y == z`` is only evaluated when the first comparison is truthy,
    # exactly as in a chained ``x == y == z`` comparison.
    assert first_pair_matches and y == z
compare_tpes(bar02.x_list, bar02.x_list, bar03.x_list, tpe=list)
compare_tpes(bar02.x_dict, bar02.x_dict, bar03.x_dict, tpe=dict)
compare_tpes(
bar02.input_fields, bar02.input_fields, bar03.input_fields, tpe=list
)
compare_vals(len(bar02.x_list), len(bar02.x_list), len(bar03.x_list))
compare_vals(len(bar02.x_dict), len(bar02.x_dict), len(bar03.x_dict))
compare_vals(
Serializes a representable Gluon block.
Parameters
----------
rb
The block to export.
model_dir
The path where the model will be saved.
model_name
The name identifying the model.
epoch
The epoch number, which together with the `model_name` identifies the
model parameters.
"""
with (model_dir / f"{model_name}-network.json").open("w") as fp:
print(dump_json(rb), file=fp)
rb.save_parameters(str(model_dir / f"{model_name}-{epoch:04}.params"))
print(dump_json(self.input_transform), file=fp)
# FIXME: also needs to serialize the output_transform
# serialize all remaining constructor parameters
with (path / "parameters.json").open("w") as fp:
parameters = dict(
batch_size=self.batch_size,
prediction_length=self.prediction_length,
freq=self.freq,
ctx=self.ctx,
dtype=self.dtype,
forecast_generator=self.forecast_generator,
input_names=self.input_names,
)
print(dump_json(parameters), file=fp)
def _upload_estimator(self, locations, estimator):
    """
    Write the JSON dump of ``estimator`` to ``locations.estimator_path``.

    Parameters
    ----------
    locations
        Object whose ``estimator_path`` attribute is the path to write to.
    estimator
        The estimator to dump with ``serde.dump_json``.
    """
    logger.info("Uploading estimator config to s3.")

    # Serialize before opening the target, so a serialization failure
    # happens before any file handle is created.
    payload = serde.dump_json(estimator)

    destination = locations.estimator_path
    with self._s3fs.open(destination, "w") as out:
        out.write(payload)