Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add _resolve_ flag to instantiate. #2269

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 58 additions & 13 deletions hydra/_internal/instantiate/_instantiate2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class _Keys(str, Enum):
RECURSIVE = "_recursive_"
ARGS = "_args_"
PARTIAL = "_partial_"
RESOLVE = "_resolve_"


def _is_target(x: Any) -> bool:
Expand Down Expand Up @@ -164,6 +165,8 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
a trace of OmegaConf containers
_partial_: If True, return functools.partial wrapped method or object
False by default. Configure per target.
_resolve_: If to resolve the OmegaConf configuration (bool)
True by default.
:param args: Optional positional parameters pass-through
:param kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
Expand Down Expand Up @@ -213,14 +216,22 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
if kwargs:
config = OmegaConf.merge(config, kwargs)

OmegaConf.resolve(config)
_resolve_ = config.pop(_Keys.RESOLVE, True)

if _resolve_:
OmegaConf.resolve(config)

_recursive_ = config.pop(_Keys.RECURSIVE, True)
_convert_ = config.pop(_Keys.CONVERT, ConvertMode.NONE)
_partial_ = config.pop(_Keys.PARTIAL, False)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config,
*args,
recursive=_recursive_,
convert=_convert_,
partial=_partial_,
resolve=_resolve_,
)
elif OmegaConf.is_list(config):
# Finalize config (convert targets to strings, merge with kwargs)
Expand All @@ -236,14 +247,20 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
_recursive_ = kwargs.pop(_Keys.RECURSIVE, True)
_convert_ = kwargs.pop(_Keys.CONVERT, ConvertMode.NONE)
_partial_ = kwargs.pop(_Keys.PARTIAL, False)
_resolve_ = True

if _partial_:
raise InstantiationException(
"The _partial_ keyword is not compatible with top-level list instantiation"
)

return instantiate_node(
config, *args, recursive=_recursive_, convert=_convert_, partial=_partial_
config,
*args,
recursive=_recursive_,
convert=_convert_,
partial=_partial_,
resolve=_resolve_,
)
else:
raise InstantiationException(
Expand All @@ -256,13 +273,17 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
)


def _convert_node(node: Any, convert: Union[ConvertMode, str]) -> Any:
def _convert_node(
node: Any,
convert: Union[ConvertMode, str],
resolve: bool = True,
) -> Any:
if OmegaConf.is_config(node):
if convert == ConvertMode.ALL:
node = OmegaConf.to_container(node, resolve=True)
node = OmegaConf.to_container(node, resolve=resolve)
elif convert == ConvertMode.PARTIAL:
node = OmegaConf.to_container(
node, resolve=True, structured_config_mode=SCMode.DICT_CONFIG
node, resolve=resolve, structured_config_mode=SCMode.DICT_CONFIG
)
return node

Expand All @@ -273,6 +294,7 @@ def instantiate_node(
convert: Union[str, ConvertMode] = ConvertMode.NONE,
recursive: bool = True,
partial: bool = False,
resolve: bool = True,
) -> Any:
# Return None if config is None
if node is None or (OmegaConf.is_config(node) and node._is_none()):
Expand All @@ -288,6 +310,7 @@ def instantiate_node(
convert = node[_Keys.CONVERT] if _Keys.CONVERT in node else convert
recursive = node[_Keys.RECURSIVE] if _Keys.RECURSIVE in node else recursive
partial = node[_Keys.PARTIAL] if _Keys.PARTIAL in node else partial
convert = node[_Keys.RESOLVE] if _Keys.CONVERT in node else resolve

full_key = node._get_full_key(None)

Expand All @@ -297,6 +320,12 @@ def instantiate_node(
msg += f"\nfull_key: {full_key}"
raise TypeError(msg)

if not isinstance(resolve, bool):
msg = f"Instantiation: _resolve_ flag must be a bool, got {type(recursive)}"
if full_key:
msg += f"\nfull_key: {full_key}"
raise TypeError(msg)

if not isinstance(partial, bool):
msg = f"Instantiation: _partial_ flag must be a bool, got {type( partial )}"
if node and full_key:
Expand All @@ -306,8 +335,10 @@ def instantiate_node(
# If OmegaConf list, create new list of instances if recursive
if OmegaConf.is_list(node):
items = [
instantiate_node(item, convert=convert, recursive=recursive)
for item in node._iter_ex(resolve=True)
instantiate_node(
item, convert=convert, recursive=recursive, resolve=resolve
)
for item in node._iter_ex(resolve=resolve)
]

if convert in (ConvertMode.ALL, ConvertMode.PARTIAL):
Expand All @@ -320,21 +351,29 @@ def instantiate_node(
return lst

elif OmegaConf.is_dict(node):
exclude_keys = set({"_target_", "_convert_", "_recursive_", "_partial_"})
exclude_keys = set(
{"_target_", "_convert_", "_recursive_", "_partial_", "_resolve_"}
)
if _is_target(node):
_target_ = _resolve_target(node.get(_Keys.TARGET), full_key)
kwargs = {}
is_partial = node.get("_partial_", False) or partial
is_resolve = (node.get("_resolve_", None) is None and resolve) or node.get(
"_resolve_", False
)
for key in node.keys():
if key not in exclude_keys:
if OmegaConf.is_missing(node, key) and is_partial:
continue
value = node[key]
if recursive:
value = instantiate_node(
value, convert=convert, recursive=recursive
value,
convert=convert,
recursive=recursive,
resolve=is_resolve,
)
kwargs[key] = _convert_node(value, convert)
kwargs[key] = _convert_node(value, convert, resolve=is_resolve)

return _call_target(_target_, partial, args, kwargs, full_key)
else:
Expand All @@ -347,15 +386,21 @@ def instantiate_node(
for key, value in node.items():
# list items inherits recursive flag from the containing dict.
dict_items[key] = instantiate_node(
value, convert=convert, recursive=recursive
value,
convert=convert,
recursive=recursive,
resolve=resolve,
)
return dict_items
else:
# Otherwise use DictConfig and resolve interpolations lazily.
cfg = OmegaConf.create({}, flags={"allow_objects": True})
for key, value in node.items():
cfg[key] = instantiate_node(
value, convert=convert, recursive=recursive
value,
convert=convert,
recursive=recursive,
resolve=resolve,
)
cfg._set_parent(node)
cfg._metadata.object_type = node._metadata.object_type
Expand Down
13 changes: 13 additions & 0 deletions tests/instantiate/test_instantiate.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,19 @@ def config(request: Any, src: Any) -> Any:
AClass(10, 20, 30, 40),
id="class",
),
param(
{
"_target_": "tests.instantiate.AClass",
"a": "${somethingthatisjustpassed}",
"b": 20,
"c": 30,
"d": 40,
"_resolve_": False,
},
{},
AClass("somethingthatisjustpassed", 20, 30, 40),
id="class+not_resolve",
),
param(
{
"_target_": "tests.instantiate.AClass",
Expand Down
5 changes: 5 additions & 0 deletions website/docs/advanced/instantiate_objects/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def instantiate(config: Any, *args: Any, **kwargs: Any) -> Any:
a trace of OmegaConf containers
_partial_: If True, return functools.partial wrapped method or object
False by default. Configure per target.
_resolve_: If to resolve the OmegaConf configuration (bool)
True by default.
:param args: Optional positional parameters pass-through
:param kwargs: Optional named parameters to override
parameters in the config object. Parameters not present
Expand Down Expand Up @@ -369,6 +371,9 @@ assert bar1.foo is bar2.foo # the `Foo` instance is re-used here
This does not apply if `_partial_=False`,
in which case a new `Foo` instance would be created with each call to `instantiate`.

### OmegaConf resolution
By default, the configuration is resolved via `OmegaConf.resolve`. This can be turned off, using `_resolve_=False`.


### Instantiation of builtins

Expand Down