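# Consolidated imports assumed by the snippets below (added for context; the
# module paths follow the Emmental package layout and may differ by version).
import logging
import os
import shutil

import yaml

import emmental
from emmental.meta import Meta
from emmental.logging.tensorboard_writer import TensorBoardWriter
from emmental.utils.parse_args import parse_args, parse_args_to_config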
def test_tensorboard_writer(caplog):
    """Unit test of tensorboard_writer."""
    caplog.set_level(logging.INFO)

    emmental.Meta.reset()
    emmental.init()

    log_writer = TensorBoardWriter()
    log_writer.add_config(emmental.Meta.config)
    log_writer.add_scalar(name="step 1", value=0.1, step=1)
    log_writer.add_scalar(name="step 2", value=0.2, step=2)

    config_filename = "config.yaml"
    log_writer.write_config(config_filename)

    # Test config
    with open(os.path.join(emmental.Meta.log_path, config_filename), "r") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    assert config["meta_config"]["verbose"] is True
    assert config["logging_config"]["counter_unit"] == "epoch"
    assert config["logging_config"]["checkpointing"] is False

    log_writer.write_log()
def test_meta(caplog):
    """Unit test of meta."""
    caplog.set_level(logging.INFO)

    dirpath = "temp_test_meta_log_folder"

    Meta.reset()
    emmental.init(dirpath)

    # Check the log folder is created correctly
    assert os.path.isdir(dirpath) is True
    assert Meta.log_path.startswith(dirpath) is True

    # Check the config is created
    assert isinstance(Meta.config, dict) is True
    assert Meta.config["meta_config"] == {
        "seed": None,
        "verbose": True,
        "log_path": "logs",
        "use_exact_log_path": False,
    }

    emmental.Meta.update_config(
        path="tests/shared", filename="emmental-test-config.yaml"
    )
    assert Meta.config["meta_config"] == {
        "seed": 1,
        "verbose": False,
        "log_path": "tests",
        "use_exact_log_path": False,
    }

    # Test unable to find config file
    Meta.reset()
    emmental.init(dirpath)
    emmental.Meta.update_config(path=os.path.dirname(__file__))
    assert Meta.config["meta_config"] == {
        "seed": None,
        "verbose": True,
        "log_path": "logs",
        "use_exact_log_path": False,
    }

    # Remove the temp folder
    shutil.rmtree(dirpath)
"checkpoint_all": True,
"clear_intermediate_checkpoints": True,
"clear_all_checkpoints": False,
},
},
}
# Test default and default args are the same
dirpath = "temp_parse_args"

Meta.reset()
emmental.init(dirpath)
parser = parse_args()
args = parser.parse_args([])
config1 = parse_args_to_config(args)
config2 = emmental.Meta.config

del config2["learner_config"]["global_evaluation_metric_dict"]
del config2["learner_config"]["optimizer_config"]["parameters"]
assert config1 == config2

shutil.rmtree(dirpath)

# Test different checkpoint_metric
dirpath = "temp_parse_args"

Meta.reset()
emmental.init(
    log_dir=dirpath,
    config={
        "logging_config": {
            "checkpointer_config": {
                "checkpoint_metric": {"model/valid/all/accuracy": "max"}
            }
        }
    },
)
assert emmental.Meta.config == {
    "meta_config": {
        "seed": None,
        "verbose": True,
        "log_path": "logs",
        "use_exact_log_path": False,
    },
    "data_config": {"min_data_len": 0, "max_data_len": 0},
    "model_config": {"model_path": None, "device": 0, "dataparallel": True},
    "learner_config": {
        "fp16": False,
        "n_epochs": 1,
        "train_split": ["train"],
        "valid_split": ["valid"],
        "test_split": ["test"],
        "ignore_index": None,
        "global_evaluation_metric_dict": None,
def _set_warmup_scheduler(self, model: EmmentalModel) -> None:
    """Set the warmup learning rate scheduler for the learning process.

    Args:
        model: The model to set up the warmup scheduler for.
    """
    self.warmup_steps = 0
    if Meta.config["learner_config"]["lr_scheduler_config"]["warmup_steps"]:
        warmup_steps = Meta.config["learner_config"]["lr_scheduler_config"][
            "warmup_steps"
        ]
        if warmup_steps < 0:
            raise ValueError("warmup_steps must be greater than or equal to 0.")
        warmup_unit = Meta.config["learner_config"]["lr_scheduler_config"][
            "warmup_unit"
        ]
        if warmup_unit == "epoch":
            self.warmup_steps = int(warmup_steps * self.n_batches_per_epoch)
        elif warmup_unit == "batch":
            self.warmup_steps = int(warmup_steps)
        else:
            raise ValueError(
                f"warmup_unit must be 'batch' or 'epoch', but {warmup_unit} was found."
            )
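# Hedged usage sketch (not part of the snippets above): the lr_scheduler_config
# keys read by _set_warmup_scheduler can be supplied through emmental.init. With
# warmup_unit="epoch", warmup_steps=0.5 and 100 batches per epoch, the learner
# would resolve self.warmup_steps to int(0.5 * 100) == 50 optimizer steps.
emmental.Meta.reset()
emmental.init(
    "logs",
    config={
        "learner_config": {
            "lr_scheduler_config": {"warmup_unit": "epoch", "warmup_steps": 0.5}
        }
    },
)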
if self.counter_unit not in ["sample", "batch", "epoch"]:
    raise ValueError(f"Unrecognized unit: {self.counter_unit}")

# Set up evaluation frequency
self.evaluation_freq = Meta.config["logging_config"]["evaluation_freq"]
if Meta.config["meta_config"]["verbose"]:
    logger.info(f"Evaluating every {self.evaluation_freq} {self.counter_unit}.")

if Meta.config["logging_config"]["checkpointing"]:
    self.checkpointing = True

    # Set up checkpointing frequency
    self.checkpointing_freq = int(
        Meta.config["logging_config"]["checkpointer_config"]["checkpoint_freq"]
    )
    if Meta.config["meta_config"]["verbose"]:
        logger.info(
            f"Checkpointing every "
            f"{self.checkpointing_freq * self.evaluation_freq} "
            f"{self.counter_unit}."
        )

    # Set up checkpointer
    self.checkpointer = Checkpointer()
else:
    self.checkpointing = False
    if Meta.config["meta_config"]["verbose"]:
        logger.info("No checkpointing.")

# Set up the number of samples passed since the last evaluation/checkpointing
# and the total number of samples passed since the learning process began
self.sample_count: int = 0
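# Worked example (added for clarity; values are illustrative): with
# counter_unit == "batch", evaluation_freq == 100 and checkpoint_freq == 2,
# evaluation runs every 100 batches and a checkpoint is written every
# checkpointing_freq * evaluation_freq == 200 batches.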