Skip to content

Commit 2196f40

Browse files
authored
chore: Centralize OmegaConf resolver registration (#1882)
Signed-off-by: ruit <ruit@nvidia.com>
1 parent ed718b4 commit 2196f40

9 files changed

Lines changed: 56 additions & 28 deletions

File tree

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,13 @@
4848
from nemo_rl.environments.utils import create_env
4949
from nemo_rl.experience.rollouts import run_async_nemo_gym_rollout
5050
from nemo_rl.models.generation import configure_generation_config
51-
from nemo_rl.utils.config import load_config, parse_hydra_overrides
51+
from nemo_rl.utils.config import (
52+
load_config,
53+
parse_hydra_overrides,
54+
register_omegaconf_resolvers,
55+
)
5256
from nemo_rl.utils.logger import get_next_experiment_dir
5357

54-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
55-
5658

5759
def parse_args() -> tuple[argparse.Namespace, list[str]]:
5860
"""Parse command line arguments."""
@@ -118,6 +120,7 @@ def collect_trajectories(
118120
def main() -> None:
119121
"""Main entry point."""
120122
# Parse arguments
123+
register_omegaconf_resolvers()
121124
args, overrides = parse_args()
122125

123126
if not args.config:

examples/run_distillation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222
from nemo_rl.data.utils import setup_response_data
2323
from nemo_rl.distributed.virtual_cluster import init_ray
2424
from nemo_rl.models.generation import configure_generation_config
25-
from nemo_rl.utils.config import load_config, parse_hydra_overrides
25+
from nemo_rl.utils.config import (
26+
load_config,
27+
parse_hydra_overrides,
28+
register_omegaconf_resolvers,
29+
)
2630
from nemo_rl.utils.logger import get_next_experiment_dir
2731

28-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
29-
3032

3133
def parse_args() -> tuple[argparse.Namespace, list[str]]:
3234
"""Parse command line arguments."""
@@ -46,6 +48,7 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
4648
def main() -> None:
4749
"""Main entry point."""
4850
# Parse arguments
51+
register_omegaconf_resolvers()
4952
args, overrides = parse_args()
5053

5154
if not args.config:

examples/run_grpo.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
from nemo_rl.data.utils import setup_response_data
2424
from nemo_rl.distributed.virtual_cluster import init_ray
2525
from nemo_rl.models.generation import configure_generation_config
26-
from nemo_rl.utils.config import load_config, parse_hydra_overrides
26+
from nemo_rl.utils.config import (
27+
load_config,
28+
parse_hydra_overrides,
29+
register_omegaconf_resolvers,
30+
)
2731
from nemo_rl.utils.logger import get_next_experiment_dir
2832

29-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
30-
3133

3234
def parse_args() -> tuple[argparse.Namespace, list[str]]:
3335
"""Parse command line arguments."""
@@ -45,6 +47,7 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
4547
def main() -> None:
4648
"""Main entry point."""
4749
# Parse arguments
50+
register_omegaconf_resolvers()
4851
args, overrides = parse_args()
4952

5053
if not args.config:

examples/run_grpo_sliding_puzzle.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,13 @@
3434
SlidingPuzzleMetadata,
3535
)
3636
from nemo_rl.models.generation import configure_generation_config
37-
from nemo_rl.utils.config import load_config, parse_hydra_overrides
37+
from nemo_rl.utils.config import (
38+
load_config,
39+
parse_hydra_overrides,
40+
register_omegaconf_resolvers,
41+
)
3842
from nemo_rl.utils.logger import get_next_experiment_dir
3943

40-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
41-
4244

4345
def parse_args():
4446
"""Parse command line arguments."""
@@ -192,6 +194,7 @@ def setup_puzzle_data(
192194
def main():
193195
"""Main entry point."""
194196
# Parse arguments
197+
register_omegaconf_resolvers()
195198
args, overrides = parse_args()
196199

197200
if not args.config:

examples/run_sft.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@
3030
update_single_dataset_config,
3131
)
3232
from nemo_rl.distributed.virtual_cluster import init_ray
33-
from nemo_rl.utils.config import load_config, parse_hydra_overrides
33+
from nemo_rl.utils.config import (
34+
load_config,
35+
parse_hydra_overrides,
36+
register_omegaconf_resolvers,
37+
)
3438
from nemo_rl.utils.logger import get_next_experiment_dir
3539

36-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
37-
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))
38-
3940

4041
def parse_args():
4142
"""Parse command line arguments."""
@@ -148,6 +149,7 @@ def setup_data(tokenizer: AutoTokenizer, data_config: DataConfig):
148149
def main(is_vlm: bool = False):
149150
"""Main entry point."""
150151
# Parse arguments
152+
register_omegaconf_resolvers()
151153
args, overrides = parse_args()
152154

153155
if not args.config:

examples/run_vlm_grpo.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
from nemo_rl.data.utils import setup_response_data
2424
from nemo_rl.distributed.virtual_cluster import init_ray
2525
from nemo_rl.models.generation import configure_generation_config
26-
from nemo_rl.utils.config import load_config, parse_hydra_overrides
26+
from nemo_rl.utils.config import (
27+
load_config,
28+
parse_hydra_overrides,
29+
register_omegaconf_resolvers,
30+
)
2731
from nemo_rl.utils.logger import get_next_experiment_dir
2832

29-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
30-
3133

3234
def parse_args() -> tuple[argparse.Namespace, list[str]]:
3335
"""Parse command line arguments."""
@@ -42,6 +44,8 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
4244

4345
def main() -> None:
4446
"""Main entry point."""
47+
# Parse arguments
48+
register_omegaconf_resolvers()
4549
args, overrides = parse_args()
4650

4751
if not args.config:

nemo_rl/utils/config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,11 @@ def parse_hydra_overrides(cfg: DictConfig, overrides: list[str]) -> DictConfig:
186186
return cfg
187187
except Exception as e:
188188
raise OverridesError(f"Failed to parse Hydra overrides: {str(e)}") from e
189+
190+
191+
def register_omegaconf_resolvers() -> None:
192+
"""Register shared OmegaConf resolvers used in configs."""
193+
if not OmegaConf.has_resolver("mul"):
194+
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
195+
if not OmegaConf.has_resolver("max"):
196+
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))

research/template_project/single_update.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@
4141
from nemo_rl.models.generation import configure_generation_config
4242
from nemo_rl.models.generation.vllm import VllmGeneration
4343
from nemo_rl.models.policy.lm_policy import Policy
44-
from nemo_rl.utils.config import load_config, parse_hydra_overrides
45-
46-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
44+
from nemo_rl.utils.config import (
45+
load_config,
46+
parse_hydra_overrides,
47+
register_omegaconf_resolvers,
48+
)
4749

4850

4951
def main(config: MasterConfig) -> None:
@@ -178,6 +180,7 @@ def parse_args() -> tuple[argparse.Namespace, list[str]]:
178180

179181
if __name__ == "__main__":
180182
# Parse arguments
183+
register_omegaconf_resolvers()
181184
args, overrides = parse_args()
182185

183186
if not args.config:

tests/unit/test_config_validation.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,15 @@
2727
from nemo_rl.algorithms.rm import MasterConfig as RMMasterConfig
2828
from nemo_rl.algorithms.sft import MasterConfig as SFTMasterConfig
2929
from nemo_rl.evals.eval import MasterConfig as EvalMasterConfig
30-
from nemo_rl.utils.config import load_config_with_inheritance
30+
from nemo_rl.utils.config import (
31+
load_config_with_inheritance,
32+
register_omegaconf_resolvers,
33+
)
3134

3235
# All tests in this module should run first
3336
pytestmark = pytest.mark.run_first
3437

35-
if not OmegaConf.has_resolver("mul"):
36-
OmegaConf.register_new_resolver("mul", lambda a, b: a * b)
37-
38-
if not OmegaConf.has_resolver("max"):
39-
OmegaConf.register_new_resolver("max", lambda a, b: max(a, b))
38+
register_omegaconf_resolvers()
4039

4140

4241
def validate_config_section(

0 commit comments

Comments
 (0)