optvarnames = tr.tensor_names(collection=CollectionKeys.OPTIMIZER_VARIABLES)
assert len(optvarnames) == 5
for optvarname in optvarnames:
    assert len(tr.tensor(optvarname).steps(ModeKeys.TRAIN)) == 7
    for s in tr.tensor(optvarname).steps(ModeKeys.TRAIN):
        assert tr.tensor(optvarname).value(s, mode=ModeKeys.TRAIN) is not None
    assert len(tr.tensor(optvarname).steps(ModeKeys.EVAL)) == 0
    assert len(tr.tensor(optvarname).steps(ModeKeys.PREDICT)) == 0
assert len(tr.tensor_names(collection=CollectionKeys.LOSSES)) == 1
loss_name = tr.tensor_names(collection=CollectionKeys.LOSSES)[0]
# loss is not saved in predict mode (so 2 fewer steps)
# add one for end of epoch
assert len(tr.tensor(loss_name).steps(ModeKeys.TRAIN)) == 8
assert len(tr.tensor(loss_name).steps(ModeKeys.EVAL)) == 4
assert len(tr.tensor(loss_name).steps(ModeKeys.PREDICT)) == 0
assert len(tr.tensor(loss_name).steps()) == 12
metricnames = tr.tensor_names(collection=CollectionKeys.METRICS)
assert len(metricnames) == 3
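
# --- Sketch (not part of the original test): how a trial handle like `tr` is
# typically opened before running the assertions above. The import paths and
# the output directory are assumptions, not taken from this file.
from smdebug.core.collection import CollectionKeys
from smdebug.core.modes import ModeKeys
from smdebug.trials import create_trial

tr = create_trial("/tmp/smdebug_outputs")  # placeholder path: the hook's out_dir
print(tr.tensor_names(collection=CollectionKeys.METRICS))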
hook.set_mode(ModeKeys.TRAIN)
for i in range(steps):
    batch_size = 32
    data, target = mx.random.randn(batch_size, 1, 28, 28), mx.random.randn(batch_size)
    data = data.as_in_context(mx.cpu(0))
    with autograd.record():
        output = net(data)
        loss = softmax_cross_entropy(output, target)
    loss.backward()
    # update parameters
    trainer.step(batch_size)
    # calculate training metrics
    train_loss += loss.mean().asscalar()
hook.save_scalar("mx_after_train", 1, sm_metric=False)
hook.set_mode(ModeKeys.EVAL)
for i in range(steps):
    batch_size = 32
    data, target = mx.random.randn(batch_size, 1, 28, 28), mx.random.randn(batch_size)
    data = data.as_in_context(mx.cpu(0))
    val_output = net(data)
    loss = softmax_cross_entropy(val_output, target)
dummy_step_creator(
    trial_dir=path, global_step=i + 40, mode="EVAL", mode_step=i, worker_name="worker_0"
)
trial = create_trial(path)
num_workers = len(trial.workers())
assert num_workers == 1
assert trial.loaded_all_steps is True
all_steps = trial.steps(show_incomplete_steps=True)
completed_steps = trial.steps()
assert all_steps == [0, 10, 20, 30, 40, 50, 60, 70]
assert completed_steps == all_steps
assert trial.has_passed_step(30) == StepState.AVAILABLE
assert trial.has_passed_step(23, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
assert trial.has_passed_step(40, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
assert trial.has_passed_step(30, mode=ModeKeys.EVAL) == StepState.AVAILABLE
assert trial.has_passed_step(23, mode=ModeKeys.EVAL) == StepState.UNAVAILABLE
assert trial.has_passed_step(80) == StepState.UNAVAILABLE
assert trial.has_passed_step(80, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
assert trial.has_passed_step(80, mode=ModeKeys.EVAL) == StepState.UNAVAILABLE
assert trial.last_index_token == os.path.join(
    path, "index/000000000/000000000070_worker_0.json"
)
assert trial.last_complete_step == 70
shutil.rmtree(path, ignore_errors=True)
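
# --- Sketch (not from this file): the StepState values asserted above can
# drive a simple wait-for-step loop when a trial is read while a job is still
# writing. Uses only trial.has_passed_step() plus the StepState and ModeKeys
# names already imported by these tests; the helper itself is an assumption.
import time

def wait_for_step(trial, step, mode=ModeKeys.GLOBAL, poll_seconds=2):
    while True:
        state = trial.has_passed_step(step, mode=mode)
        if state == StepState.AVAILABLE:
            return True   # the step has been written and indexed
        if state == StepState.UNAVAILABLE:
            return False  # the job ended without producing this step
        time.sleep(poll_seconds)  # NOT_YET_AVAILABLE: job still running, poll again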
dummy_step_creator(
    trial_dir=path, global_step=i + 40, mode="EVAL", mode_step=i, worker_name="worker_0"
)
trial = create_trial(path)
num_workers = len(trial.workers())
assert num_workers == 1
assert trial.loaded_all_steps is False
all_steps = trial.steps(show_incomplete_steps=True)
completed_steps = trial.steps()
assert all_steps == [0, 10, 20, 30, 40, 50, 60, 70]
assert completed_steps == all_steps
assert trial.has_passed_step(30) == StepState.AVAILABLE
assert trial.has_passed_step(23, mode=ModeKeys.TRAIN) == StepState.UNAVAILABLE
assert trial.has_passed_step(40, mode=ModeKeys.TRAIN) == StepState.NOT_YET_AVAILABLE
assert trial.has_passed_step(30, mode=ModeKeys.EVAL) == StepState.AVAILABLE
assert trial.has_passed_step(23, mode=ModeKeys.EVAL) == StepState.UNAVAILABLE
assert trial.has_passed_step(80) == StepState.NOT_YET_AVAILABLE
assert trial.has_passed_step(80, mode=ModeKeys.TRAIN) == StepState.NOT_YET_AVAILABLE
assert trial.has_passed_step(80, mode=ModeKeys.EVAL) == StepState.NOT_YET_AVAILABLE
assert trial.last_index_token == os.path.join(
    path, "index/000000000/000000000070_worker_0.json"
)
assert trial.last_complete_step == 70
shutil.rmtree(path, ignore_errors=True)
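
# Note the contrast with the block above: with an incomplete job,
# trial.loaded_all_steps is False and steps that have not been written yet
# report StepState.NOT_YET_AVAILABLE instead of StepState.UNAVAILABLE,
# since they may still arrive.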
for s in tr.tensor(tensornames[0]).steps(ModeKeys.TRAIN):
    for w in tr.tensor(tensornames[0]).workers(s, ModeKeys.TRAIN):
        assert tr.tensor(tensornames[0]).value(s, worker=w, mode=ModeKeys.TRAIN) is not None
    assert (
        len(tr.tensor(tensornames[0]).workers(s, ModeKeys.TRAIN))
        == strategy.num_replicas_in_sync
    )

for tname in tr.tensor_names(collection="losses"):
    for s in tr.tensor(tname).steps(ModeKeys.EVAL):
        assert len(tr.tensor(tname).workers(s, ModeKeys.EVAL)) == 1
        assert tr.tensor(tname).value(s, mode=ModeKeys.EVAL) is not None
    if tname != tensornames[0]:
        for s in tr.tensor(tname).steps(ModeKeys.TRAIN):
            assert len(tr.tensor(tname).workers(s, ModeKeys.EVAL)) == 1
            assert tr.tensor(tname).value(s, mode=ModeKeys.EVAL) is not None
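
# --- Sketch (assumed helper, not part of the test): collect a tensor's
# per-replica values into a dict keyed by worker, using the same
# workers()/value() calls exercised above.
def per_worker_values(trial, tname, step, mode):
    return {
        w: trial.tensor(tname).value(step, worker=w, mode=mode)
        for w in trial.tensor(tname).workers(step, mode)
    }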
def test_collection_defaults_json(out_dir, monkeypatch):
    pre_test_clean_up()
    monkeypatch.setenv(
        CONFIG_FILE_PATH_ENV_STR,
        "tests/tensorflow/hooks/test_json_configs/test_collection_defaults.json",
    )
    hook = SessionHook.create_from_json_file()
    # Check save_intervals for each mode
    assert hook.save_config.get_save_config(ModeKeys.TRAIN).save_interval == 2
    assert hook.save_config.get_save_config(ModeKeys.EVAL).save_interval == 3
    assert hook.save_config.get_save_config(ModeKeys.PREDICT).save_interval == 1
    assert hook.save_config.get_save_config(ModeKeys.GLOBAL).save_interval == 1
    # Check include_collections
    assert "weights" in hook.include_collections and "losses" in hook.include_collections
    assert len(hook.include_collections) == 4
    # Check collection configurations for losses
    assert (
        hook.collection_manager.collections["losses"]
        .save_config.get_save_config(ModeKeys.TRAIN)
        .save_interval
        == 2
    )
    assert (
        hook.collection_manager.collections["losses"]
        .save_config.get_save_config(ModeKeys.EVAL)
        .save_interval
        == 4
    )
    assert (
        hook.collection_manager.collections["losses"]
        .save_config.get_save_config(ModeKeys.PREDICT)
        .save_interval
        == 1
    )
    assert (
        hook.collection_manager.collections["losses"]
        .save_config.get_save_config(ModeKeys.GLOBAL)
        .save_interval
        == 5
    )
    # Check collection configuration for weights
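
# --- Sketch (usage assumed, not shown in this test): a SessionHook built from
# the JSON config, as in the test above, is a TF session-run hook, so it would
# typically be attached to a monitored session (or an estimator) like this.
import tensorflow as tf

hook = SessionHook.create_from_json_file()  # SessionHook as imported by the test above
with tf.compat.v1.train.MonitoredSession(hooks=[hook]) as sess:
    pass  # sess.run(...) calls here are intercepted by the hook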
def _get_collections_to_save_for_step(self) -> Set["Collection"]:
    if self._collections_to_save_for_step is None:
        self._assert_prep()
        self._collections_to_save_for_step = set()
        for coll in self._get_all_collections_to_save():
            # gradients and optimizer variables are skipped outside TRAIN mode
            if self.mode in [ModeKeys.EVAL, ModeKeys.PREDICT]:
                if coll.name in [CollectionKeys.GRADIENTS, CollectionKeys.OPTIMIZER_VARIABLES]:
                    continue
            if coll.save_config.should_save_step(self.mode, self.mode_steps[self.mode]):
                self._collections_to_save_for_step.add(coll)

        if self._collections_to_save_for_step:
            if self.mode == ModeKeys.GLOBAL:
                step_str = f"for step {self.step}"
            else:
                step_str = f"for step {self.mode_steps[self.mode]} of mode {self.mode.name}"
            self.logger.debug(
                f"Saving the collections "
                f"{', '.join([x.name for x in self._collections_to_save_for_step])} {step_str}"
            )
    return self._collections_to_save_for_step
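
# --- Illustration only (not the library's implementation): the
# should_save_step() call above typically reduces to an interval check on the
# mode step, roughly like this.
def should_save_step(mode_step, save_interval, save_steps=None):
    if save_steps:                        # an explicit list of steps wins
        return mode_step in save_steps
    return mode_step % save_interval == 0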
def get_keras_mode(mode):
    # Should never be called in TF 1.13 where this is not available
    from tensorflow.python.keras.utils.mode_keys import ModeKeys as KerasModeKeys

    if mode == ModeKeys.TRAIN:
        return KerasModeKeys.TRAIN
    elif mode == ModeKeys.EVAL:
        return KerasModeKeys.TEST
    elif mode == ModeKeys.PREDICT:
        return KerasModeKeys.PREDICT
def _get_exec_function(self, mode):
    if self.distribution_strategy in [
        TFDistributionStrategy.NONE,
        TFDistributionStrategy.HOROVOD,
    ]:
        if mode == ModeKeys.TRAIN:
            x = self.model.train_function
        elif mode == ModeKeys.EVAL:
            x = self.model.test_function
        elif mode == ModeKeys.PREDICT:
            x = self.model.predict_function
        else:
            raise NotImplementedError
    else:
        x = self._get_distributed_model(mode)._distributed_function
    return x