# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import os
import warnings
from typing import IO, Any, Callable, Dict, List, Union
import yaml
from yacs.config import CfgNode as _CfgNode
# Names exported by ``from <this module> import *``.
__all__ = [
    "CfgNode",
]
class CfgNode(_CfgNode):
    """
    Our own extended version of :class:`yacs.config.CfgNode`.
    It contains the following extra features:

    1. The :meth:`merge_from_file` method supports the "_BASE_" key,
       which allows the new CfgNode to inherit all the attributes from the
       base configuration file(s). "_BASE_" may be a single path or a list
       of paths; relative paths are resolved against the directory of the
       file that references them.
    2. Keys that start with "COMPUTED_" are treated as insertion-only
       "computed" attributes. They can be inserted regardless of whether
       the CfgNode is frozen or not.
    3. With "allow_unsafe=True", it supports pyyaml tags that evaluate
       expressions in config. See examples in
       https://pyyaml.org/wiki/PyYAMLDocumentation#yaml-tags-and-python-types
       Note that this may lead to arbitrary code execution: you must not
       load a config file from untrusted sources before manually inspecting
       the content of the file.
    """

    @classmethod
    def _open_cfg(cls, filename: str) -> Union[IO[str], IO[bytes]]:
        """
        Defines how a config file is opened. May be overridden to support
        different file schemas.
        """
        return open(filename, "r")

    @classmethod
    def load_yaml_with_base(
        cls, filename: str, allow_unsafe: bool = False
    ) -> Dict[str, Any]:
        """
        Just like `yaml.load(open(filename))`, but inherit attributes from its
        `_BASE_` file(s).

        The value of `_BASE_` may be a single path or a list of paths. Bases
        are loaded recursively (they may have their own `_BASE_`), merged in
        order (later bases override earlier ones), and finally the current
        file's own values override everything. The `_BASE_` key itself is
        removed from the result.

        Args:
            filename (str): the file name of the current config. Relative
                `_BASE_` paths are resolved against its directory.
            allow_unsafe (bool): whether to allow loading the config file with
                `yaml.unsafe_load`.

        Returns:
            (dict): the loaded yaml with all bases merged in. If the document
            is not a mapping (e.g. an empty file, which yaml loads as None),
            it is returned unchanged.
        """
        base_key = "_BASE_"

        with cls._open_cfg(filename) as f:
            try:
                cfg = yaml.safe_load(f)
            except yaml.constructor.ConstructorError:
                if not allow_unsafe:
                    raise
                warnings.warn(
                    "Loading config {} with yaml.unsafe_load. Your machine may "
                    "be at risk if the file contains malicious content.".format(
                        filename
                    ),
                    UserWarning,
                )
                # safe_load already consumed (part of) the stream, so re-open
                # the file from the start for the unsafe loader.
                f.close()
                with cls._open_cfg(filename) as f:
                    cfg = yaml.unsafe_load(f)

        def merge_a_into_b(a: Dict[str, Any], b: Dict[str, Any]) -> None:
            # Merge dict a into dict b. Values in a overwrite those in b;
            # nested dicts are merged recursively.
            for k, v in a.items():
                if isinstance(v, dict) and k in b:
                    assert isinstance(
                        b[k], dict
                    ), "Cannot inherit key '{}' from base!".format(k)
                    merge_a_into_b(v, b[k])
                else:
                    b[k] = v

        def load_base(base_cfg_file: str) -> Dict[str, Any]:
            # Resolve one `_BASE_` entry to a path and load it recursively.
            if base_cfg_file.startswith("~"):
                base_cfg_file = os.path.expanduser(base_cfg_file)
            is_absolute = os.path.isabs(base_cfg_file) or base_cfg_file.startswith(
                ("http://", "https://")
            )
            if not is_absolute:
                # The path to a base config is relative to the config file itself.
                base_cfg_file = os.path.join(os.path.dirname(filename), base_cfg_file)
            return cls.load_yaml_with_base(base_cfg_file, allow_unsafe=allow_unsafe)

        # Non-mapping documents (None for an empty file, lists, scalars) and
        # mappings without a base are returned as-is.
        if not isinstance(cfg, dict) or base_key not in cfg:
            return cfg

        base_files = cfg.pop(base_key)
        if isinstance(base_files, str):
            base_files = [base_files]

        base_cfg: Dict[str, Any] = {}
        for base_file in base_files:
            # An empty base file loads as None; treat it as contributing nothing.
            merge_a_into_b(load_base(base_file) or {}, base_cfg)
        merge_a_into_b(cfg, base_cfg)
        return base_cfg

    def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = False) -> None:
        """
        Merge configs from a given yaml file, resolving `_BASE_` inheritance.

        Args:
            cfg_filename: the file name of the yaml config.
            allow_unsafe: whether to allow loading the config file with
                `yaml.unsafe_load`.
        """
        loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe)
        loaded_cfg = type(self)(loaded_cfg)
        self.merge_from_other_cfg(loaded_cfg)

    def merge_from_other_cfg(self, cfg_other: object) -> None:
        """
        Forward to :meth:`yacs.config.CfgNode.merge_from_other_cfg`.

        Args:
            cfg_other (CfgNode): configs to merge from.
        """
        return super().merge_from_other_cfg(cfg_other)

    def merge_from_list(self, cfg_list: List[object]) -> None:
        """
        Forward to :meth:`yacs.config.CfgNode.merge_from_list`.

        Args:
            cfg_list (list): list of configs to merge from (see yacs for the
                expected layout).
        """
        return super().merge_from_list(cfg_list)

    def merge_from_dict(self, cfg_dict: Dict[object, object]) -> "CfgNode":
        """
        Recursively merge a (possibly nested) dict into this config.

        A top-level "self" key is ignored (presumably so that a method's
        ``locals()`` can be passed directly). The argument is not modified.

        Args:
            cfg_dict (dict): dict of configs to merge from.

        Returns:
            CfgNode: this node, to allow call chaining.
        """
        # Filter into a copy instead of popping, so the caller's dict is untouched.
        filtered = {k: v for k, v in cfg_dict.items() if k != "self"}
        merge_dict(self, filtered)
        return self

    def __setattr__(self, name: str, val: Any) -> None:  # pyre-ignore
        """
        Set an attribute, treating "COMPUTED_*" names as insertion-only.

        A "COMPUTED_*" attribute is stored directly as a key, bypassing the
        base class (and therefore its frozen check). Re-setting it with an
        equal value is a no-op; with a different value it raises ``KeyError``.
        All other names go through the yacs ``__setattr__`` unchanged.
        """
        if not name.startswith("COMPUTED_"):
            super().__setattr__(name, val)
            return
        if name in self:
            old_val = self[name]
            if old_val == val:
                return
            raise KeyError(
                "Computed attribute '{}' already exists "
                "with a different value! old={}, new={}.".format(name, old_val, val)
            )
        self[name] = val
def merge_dict(a, b):
    """Recursively merge dict ``b`` into dict ``a`` in place.

    This is an extended version of ``dict.update``. ``dict.update`` replaces
    a nested dict wholesale; ``merge_dict`` descends into it and overwrites
    only the leaf keys present in ``b``. When a nested dict in ``b`` has no
    dict counterpart in ``a`` (the key is missing, or ``a`` holds a
    non-dict value there), a new :class:`CfgNode` is created at that key and
    filled recursively.

    Examples
    --------
    >>> a = {'x': 1, 'y': {'z': 2, 'v': 3}}
    >>> b = {'y': {'z': 3}}
    >>> c = dict(a)
    >>> c.update(b)
    >>> c
    {'x': 1, 'y': {'z': 3}}
    >>> merge_dict(a, b)
    >>> a
    {'x': 1, 'y': {'z': 3, 'v': 3}}

    Parameters
    ----------
    a : dict
        Destination mapping; modified in place.
    b : dict
        Source mapping; not modified.

    Returns
    -------
    None
    """
    for key, value in b.items():
        if not isinstance(value, dict):
            a[key] = value
            continue
        # Nested dict: make sure the destination slot is a dict we can descend
        # into; a missing key or a scalar there is replaced by a fresh node.
        if not isinstance(a.get(key), dict):
            a[key] = CfgNode()
        merge_dict(a[key], value)