diff --git a/nemo_run/core/execution/launcher.py b/nemo_run/core/execution/launcher.py index f40d52cf..980ad87c 100644 --- a/nemo_run/core/execution/launcher.py +++ b/nemo_run/core/execution/launcher.py @@ -14,6 +14,15 @@ class Launcher(ConfigurableMixin): nsys_profile: bool = False nsys_folder: str = "nsys_profile" nsys_trace: list[str] = field(default_factory=lambda: ["nvtx", "cuda"]) + nsys_extra_args: list[str] = field( + default_factory=lambda: [ + "--force-overwrite=true", + "--capture-range=cudaProfilerApi", + "--capture-range-end=stop", + "--cuda-graph-trace=node", + "--cuda-event-trace=false", + ] + ) def get_nsys_prefix(self, profile_dir: str) -> Optional[list[str]]: """Make a command prefix for nsys profiling""" @@ -27,12 +36,7 @@ def get_nsys_prefix(self, profile_dir: str) -> Optional[list[str]]: ",".join(self.nsys_trace), "-o", f"{profile_out_path}/profile_%p", - "--force-overwrite", - "true", - "--capture-range=cudaProfilerApi", - "--capture-range-end=stop", - "--cuda-graph-trace=node", - ] + ] + self.nsys_extra_args return args def transform(self, cmd: list[str]) -> Optional[Script]: ... diff --git a/nemo_run/package_info.py b/nemo_run/package_info.py index 6316ccf3..0682451f 100644 --- a/nemo_run/package_info.py +++ b/nemo_run/package_info.py @@ -13,7 +13,7 @@ # limitations under the License. from packaging.version import Version -__version__ = '0.5.0rc0.dev0' +__version__ = "0.5.0rc0.dev0" MAJOR = Version(__version__).major MINOR = Version(__version__).minor diff --git a/test/core/execution/test_slurm_templates.py b/test/core/execution/test_slurm_templates.py index 3faa7e32..dbb33690 100644 --- a/test/core/execution/test_slurm_templates.py +++ b/test/core/execution/test_slurm_templates.py @@ -470,11 +470,11 @@ def test_dummy_batch_request_nsys( "nvtx,cuda", "-o", "/nemo_run/nsys_profile/profile_%p", - "--force-overwrite", - "true", + "--force-overwrite=true", "--capture-range=cudaProfilerApi", "--capture-range-end=stop", "--cuda-graph-trace=node", + "--cuda-event-trace=false", ] def test_dummy_batch_request_warn(