Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def save_load_from_file(conf: Any, resolve: bool, expected: Any) -> None:
    """Round-trip ``conf`` through an open temporary text file and compare.

    The config is written with ``OmegaConf.save`` to the temporary file
    object, read back with ``OmegaConf.load`` from a freshly opened handle,
    and the result is asserted to equal ``expected``.

    Args:
        conf: Config passed to ``OmegaConf.save``.
        resolve: Forwarded as ``OmegaConf.save(..., resolve=resolve)``.
        expected: Value the reloaded config must equal; ``None`` means
            "expect ``conf`` itself".

    Raises:
        AssertionError: If the reloaded config differs from ``expected``.
    """
    if expected is None:
        expected = conf
    # Create the file before entering ``try`` so ``finally`` can never hit an
    # unbound name (and mask the real error) if creation itself fails.
    # delete=False keeps the file on disk after the write handle is closed so
    # it can be reopened by name; it is removed explicitly in ``finally``.
    fp = tempfile.NamedTemporaryFile(mode="wt", delete=False, encoding="utf-8")
    try:
        with fp:
            OmegaConf.save(conf, fp.file, resolve=resolve)  # type: ignore
        with open(fp.name, "rt", encoding="utf-8") as handle:
            c2 = OmegaConf.load(handle)
        assert c2 == expected
    finally:
        os.unlink(fp.name)
def test_load_empty_file(tmpdir: str) -> None:
    """An empty YAML file must load as an empty config."""
    target = Path(tmpdir).joinpath("test.yaml")
    target.touch()
    loaded = OmegaConf.load(target)
    assert loaded == {}
def test_load_duplicate_keys_sub() -> None:
    """Loading YAML with a repeated key inside a nested mapping must fail.

    The key ``b`` appears twice under ``a``; ``OmegaConf.load`` is expected
    to reject it with PyYAML's ``ConstructorError``.
    """
    from yaml.constructor import ConstructorError

    content = """
a:
  b: 1
  c: 2
  b: 3
"""
    # Create the file outside ``try`` so ``finally`` cannot hit an unbound
    # name. delete=False keeps it on disk after the handle is closed so it can
    # be loaded by name; it is removed explicitly in ``finally``.
    fp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with fp:
            fp.write(content.encode("utf-8"))
        with pytest.raises(ConstructorError):
            OmegaConf.load(fp.name)
    finally:
        os.unlink(fp.name)
def test_load_duplicate_keys_top() -> None:
    """Loading YAML with a repeated top-level key must fail.

    The top-level key ``a`` appears twice; ``OmegaConf.load`` is expected to
    reject it with PyYAML's ``ConstructorError``.
    """
    from yaml.constructor import ConstructorError

    content = """
a:
  b: 1
a:
  b: 2
"""
    # Create the file outside ``try`` so ``finally`` cannot hit an unbound
    # name. delete=False keeps it on disk after the handle is closed so it can
    # be loaded by name; it is removed explicitly in ``finally``.
    fp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with fp:
            fp.write(content.encode("utf-8"))
        with pytest.raises(ConstructorError):
            OmegaConf.load(fp.name)
    finally:
        os.unlink(fp.name)
def save_load_from_filename(
    conf: Any, resolve: bool, expected: Any, file_class: Type[Any]
) -> None:
    """Round-trip ``conf`` through a temporary file addressed by path.

    The temporary file's name is wrapped with ``file_class``. The config is
    saved to that path with ``OmegaConf.save``, loaded back with
    ``OmegaConf.load``, and the result is asserted to equal ``expected``.

    Args:
        conf: Config passed to ``OmegaConf.save``.
        resolve: Forwarded as ``OmegaConf.save(..., resolve=resolve)``.
        expected: Value the reloaded config must equal; ``None`` means
            "expect ``conf`` itself".
        file_class: Callable applied to the temp file name to build the path
            object handed to ``save``/``load``.

    Raises:
        AssertionError: If the reloaded config differs from ``expected``.
    """
    if expected is None:
        expected = conf
    # delete=False keeps the file on disk after our handle is closed, so it
    # can be written and read by name; it is removed explicitly in
    # ``finally``. Our handle is closed first because reopening a still-open
    # NamedTemporaryFile by name is platform-dependent per the tempfile docs.
    # The file is created outside ``try`` so ``finally`` never sees an
    # unbound name.
    fp = tempfile.NamedTemporaryFile(delete=False)
    try:
        fp.close()
        filepath = file_class(fp.name)
        OmegaConf.save(conf, filepath, resolve=resolve)
        c2 = OmegaConf.load(filepath)
        assert c2 == expected
    finally:
        os.unlink(fp.name)
def save_load__from_filename_deprecated(conf: DictConfig) -> None:
    """Round-trip ``conf`` through the ``conf.save(path)`` method.

    The config is saved by name with its own ``save`` method, loaded back
    with ``OmegaConf.load``, and asserted equal to the original.

    Args:
        conf: Config to save and reload.

    Raises:
        AssertionError: If the reloaded config differs from ``conf``.
    """
    # delete=False keeps the file on disk after our handle is closed, so it
    # can be written and read by name; it is removed explicitly in
    # ``finally``. Our handle is closed first because reopening a still-open
    # NamedTemporaryFile by name is platform-dependent per the tempfile docs.
    # The file is created outside ``try`` so ``finally`` never sees an
    # unbound name.
    fp = tempfile.NamedTemporaryFile(delete=False)
    try:
        fp.close()
        conf.save(fp.name)
        c2 = OmegaConf.load(fp.name)
        assert conf == c2
    finally:
        os.unlink(fp.name)
def read_config_file(config_id: str = "sample_config") -> ConfigType:
    """Read the YAML config ``<project_root>/config/<config_id>.yaml``.

    The project root is the directory that contains the first path component
    named ``codes`` in the resolved location of this module.

    Args:
        config_id (str, optional): Id (file stem) of the config file to read.
            Defaults to "sample_config".

    Returns:
        ConfigType: Config object.

    Raises:
        FileNotFoundError: If this module does not live under a directory
            named ``codes``, so no project root can be determined.
        AssertionError: If the loaded config is not a ``DictConfig``.
    """
    # Match whole path components rather than splitting the path string on
    # "/codes": the string split breaks with Windows separators and also
    # matches directories such as "/codesign".
    parts = Path(__file__).resolve().parts
    try:
        codes_index = parts.index("codes")
    except ValueError:
        raise FileNotFoundError(
            f"cannot locate project root: no 'codes' directory in {Path(*parts)}"
        ) from None
    project_root = str(Path(*parts[:codes_index]))
    config_name = f"{config_id}.yaml"
    config = OmegaConf.load(os.path.join(project_root, "config", config_name))
    assert isinstance(config, DictConfig)
    return config
# NOTE(review): fragment of a config-loader method whose signature is not
# visible here; `search_path`, `filename`, `record_load` and `self` are bound
# by code outside this view.
if search_path is not None:
    # Candidate location is "<search_path.path>/<filename>".
    fullpath = "{}/{}".format(search_path.path, filename)
    # Search paths prefixed with "pkg://" are read via resource_stream rather
    # than from the filesystem (presumably package-bundled configs).
    is_pkg = search_path.path.startswith("pkg://")
    if is_pkg:
        # Strip the "pkg://" prefix, then split into the module and resource
        # names that resource_stream expects.
        fullpath = fullpath[len("pkg://") :]
        module_name, resource_name = ConfigLoader._split_module_and_resource(
            fullpath
        )
        with resource_stream(module_name, resource_name) as stream:
            loaded_cfg = OmegaConf.load(stream)
        if record_load:
            # Record which file was loaded, from which search path/provider.
            self.all_config_checked.append(
                (filename, search_path.path, search_path.provider)
            )
    elif os.path.exists(fullpath):
        loaded_cfg = OmegaConf.load(fullpath)
        if record_load:
            self.all_config_checked.append(
                (filename, search_path.path, search_path.provider)
            )
    else:
        # This should never happen because we just searched for it and found it
        # NOTE(review): `assert` is stripped under `python -O`; then this
        # branch falls through and `loaded_cfg` would be unbound at `return`.
        assert False, "'{}' not found".format(fullpath)
# NOTE(review): when `search_path` is None, `loaded_cfg` must already have
# been bound before this fragment (not visible here) — confirm.
return loaded_cfg
def load_log(directory, trial_file=None):
    """Load a run's config and its selected trial log from ``directory``.

    Reads ``.hydra/config.yaml`` when ``directory`` contains a ``.hydra``
    entry, otherwise ``config.yaml``. If more than one ``trial_*.dat`` file
    exists, loads ``trial_file`` (relative to ``directory``) when given, or
    else the matching file with the largest ``os.path.getctime``. It then
    appends the loaded log to ``logs[log_dir]`` and the config to
    ``configs[log_dir]``. Returns None.

    Args:
        directory: Run directory to read from.
        trial_file: Optional trial-log file name, relative to ``directory``.
    """
    # A ".hydra" entry presumably marks a Hydra output directory, whose config
    # lives under ".hydra/".
    if '.hydra' in os.listdir(directory):
        full_conf = OmegaConf.load(f"{directory}/.hydra/config.yaml")
    else:
        full_conf = OmegaConf.load(f"{directory}/config.yaml")
    trial_files = glob.glob(f"{directory}/trial_*.dat")
    # NOTE(review): `> 1` skips directories with exactly one trial file —
    # confirm this is intended rather than `>= 1`.
    if len(trial_files) > 1:
        if trial_file is not None:
            last_trial_log = f"{directory}/{trial_file}"
        else:
            # NOTE: on Unix getctime is the last metadata-change time, not the
            # creation time, so this may not pick the newest trial.
            last_trial_log = max(trial_files, key=os.path.getctime)
        # NOTE(review): torch.load unpickles the file; only use on trusted logs.
        vis_log = torch.load(last_trial_log)
        # NOTE(review): `logs`, `configs` and `log_dir` are not defined in this
        # function; they must be module-level names — verify.
        logs[log_dir].append(vis_log)
        configs[log_dir].append(full_conf)