diff --git a/pyproject.toml b/pyproject.toml index 2c10323..fd849a2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,8 +29,8 @@ classifiers = [ "Intended Audience :: End Users/Desktop" ] dependencies = [ - "openjd-sessions >= 0.10.1,< 0.11", - "openjd-model >= 0.7,< 0.9" + "openjd-sessions >= 0.10.3,< 0.11", + "openjd-model >= 0.8,< 0.9" ] [project.urls] diff --git a/src/openjd/cli/_common/__init__.py b/src/openjd/cli/_common/__init__.py index 5a51bee..be487bd 100644 --- a/src/openjd/cli/_common/__init__.py +++ b/src/openjd/cli/_common/__init__.py @@ -9,7 +9,7 @@ import yaml import os -from ._extensions import add_extensions_argument, process_extensions_argument +from ._extensions import add_extensions_argument, process_extensions_argument, SUPPORTED_EXTENSIONS from ._job_from_template import ( job_from_template, get_job_params, @@ -33,6 +33,7 @@ "read_job_template", "read_environment_template", "validate_task_parameters", + "SUPPORTED_EXTENSIONS", ] diff --git a/src/openjd/cli/_common/_extensions.py b/src/openjd/cli/_common/_extensions.py index bd1a8f8..02fee28 100644 --- a/src/openjd/cli/_common/_extensions.py +++ b/src/openjd/cli/_common/_extensions.py @@ -4,7 +4,7 @@ from typing import Optional # This is the list of Open Job Description extensions with implemented support -SUPPORTED_EXTENSIONS = ["TASK_CHUNKING"] +SUPPORTED_EXTENSIONS = ["TASK_CHUNKING", "REDACTED_ENV_VARS"] def add_extensions_argument(run_parser: ArgumentParser): diff --git a/src/openjd/cli/_run/_local_session/_session_manager.py b/src/openjd/cli/_run/_local_session/_session_manager.py index 3d847ba..2413505 100644 --- a/src/openjd/cli/_run/_local_session/_session_manager.py +++ b/src/openjd/cli/_run/_local_session/_session_manager.py @@ -17,10 +17,14 @@ EnvironmentType, ) from ._logs import LocalSessionLogHandler, LogEntry, LoggingTimestampFormat +from ..._common import SUPPORTED_EXTENSIONS + from openjd.model import ( IntRangeExpr, Job, JobParameterValues, + RevisionExtensions, + SpecificationRevision, Step, StepParameterSpaceIterator, TaskParameterSet, @@ -83,6 +87,9 @@ def __init__( environments: Optional[list[Any]] = None, should_print_logs: bool = True, retain_working_dir: bool = False, + revision_extensions: RevisionExtensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=SUPPORTED_EXTENSIONS + ), ): self.session_id = session_id self._action_ended = Event() @@ -98,6 +105,7 @@ def __init__( path_mapping_rules=self._path_mapping_rules, callback=self._action_callback, retain_working_dir=retain_working_dir, + revision_extensions=revision_extensions, ) self._should_print_logs = should_print_logs diff --git a/src/openjd/cli/_run/_run_command.py b/src/openjd/cli/_run/_run_command.py index a113d07..a2d09f1 100644 --- a/src/openjd/cli/_run/_run_command.py +++ b/src/openjd/cli/_run/_run_command.py @@ -34,6 +34,8 @@ StepParameterSpaceIterator, ParameterValue, ParameterValueType, + RevisionExtensions, + SpecificationRevision, TaskParameterSet, ) from openjd.sessions import PathMappingRule, LOG @@ -328,6 +330,9 @@ def _run_local_session( path_mapping_rules: Optional[list[PathMappingRule]], should_print_logs: bool = True, retain_working_dir: bool = False, + revision_extensions: RevisionExtensions = RevisionExtensions( + spec_rev=SpecificationRevision.v2023_09, supported_extensions=[] + ), ) -> OpenJDCliResult: """ Creates a Session object and listens for log messages to synchronously end the session. @@ -346,6 +351,7 @@ def _run_local_session( environments=[env.environment for env in environments] if environments else [], should_print_logs=should_print_logs, retain_working_dir=retain_working_dir, + revision_extensions=revision_extensions, ) as session: for dep_step in step_list: step_name = dep_step.name @@ -512,6 +518,12 @@ def do_run(args: Namespace) -> OpenJDCliResult: except RuntimeError as rte: return OpenJDCliResult(status="error", message=str(rte)) + # Create a RevisionExtensions object with the default specification version and enabled extensions + # We use the default v2023_09 since that's what we're currently supporting + revision_extensions = RevisionExtensions( + spec_rev=the_job.revision, supported_extensions=extensions + ) + return _run_local_session( job=the_job, job_parameter_values=job_parameter_values, @@ -524,4 +536,5 @@ def do_run(args: Namespace) -> OpenJDCliResult: path_mapping_rules=path_mapping_rules, should_print_logs=(args.output == "human-readable"), retain_working_dir=args.preserve, + revision_extensions=revision_extensions, ) diff --git a/test/openjd/cli/templates/redacted_env.yaml b/test/openjd/cli/templates/redacted_env.yaml new file mode 100644 index 0000000..1dde925 --- /dev/null +++ b/test/openjd/cli/templates/redacted_env.yaml @@ -0,0 +1,52 @@ +specificationVersion: "jobtemplate-2023-09" +extensions: + - REDACTED_ENV_VARS +name: Test Redacted Env +description: Test redacted environment variables + +jobEnvironments: + - name: RedactedEnv + script: + actions: + onEnter: + command: python + args: ["{{Env.File.Enter}}"] + onExit: + command: python + args: ["{{Env.File.Exit}}"] + embeddedFiles: + - name: Enter + type: TEXT + data: | + print("Setting redacted vars..") + print(f"openjd_redacted_env: SECRETVAR=SECRETVAL") + print(f"openjd_redacted_env: KEYSPACE =SECRETVAL") + print(f"openjd_redacted_env: VALSPACE= SPACEVAL") + print(f'openjd_redacted_env: "MULTILINE=first_line\\nsecond_line\\nthird_line"') + - name: Exit + type: TEXT + data: | + import os + print(f"SECRETVAR is {os.environ.get('SECRETVAR')}") + print(f"KEYSPACE is {os.environ.get('KEYSPACE')}") + print(f"VALSPACE is {os.environ.get('VALSPACE')}") + print(f"MULTILINE is {os.environ.get('VALSPACE')} END") + print("first_line") + print("second_line") + print("third_line") +steps: + - name: CheckVars + script: + actions: + onRun: + command: python + args: ["{{Task.File.Run}}"] + embeddedFiles: + - name: Run + type: TEXT + data: | + import os + print(f"SECRETVAR is {os.environ.get('SECRETVAR')}") + print(f"KEYSPACE is {os.environ.get('KEYSPACE')}") + print(f"VALSPACE is {os.environ.get('VALSPACE')}") + print(f"MULTILINE is {os.environ.get('VALSPACE')} END") diff --git a/test/openjd/cli/test_redacted_env.py b/test/openjd/cli/test_redacted_env.py new file mode 100644 index 0000000..2affa5c --- /dev/null +++ b/test/openjd/cli/test_redacted_env.py @@ -0,0 +1,45 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + +from pathlib import Path +import re + +from . import run_openjd_cli_main, format_capsys_outerr + +TEMPLATE_DIR = Path(__file__).parent / "templates" + + +def test_run_job_with_redacted_env(capsys): + """Test that environment variables set with openjd_redacted_env are properly handled.""" + outerr = run_openjd_cli_main( + capsys, + args=[ + "run", + str(TEMPLATE_DIR / "redacted_env.yaml"), + ], + expected_exit_code=0, + ) + + # Verify the environment variables were set + for expected_message_regex in [ + "Setting redacted vars", + "SECRETVAR is \\*\\*\\*\\*\\*\\*\\*\\*", + "KEYSPACE is None", + "VALSPACE is \\*\\*\\*\\*\\*\\*\\*\\*", + "MULTILINE is \\*\\*\\*\\*\\*\\*\\*\\*", + ]: + assert re.search( + expected_message_regex, outerr.out + ), f"Regex r'{expected_message_regex}' not matched in:\n{format_capsys_outerr(outerr)}" + + # Verify the openjd_redacted_env lines are not in the output + for unexpected_message in [ + "openjd_redacted_env: SECRETVAR=SECRETVAL", + "openjd_redacted_env: KEYSPACE =SECRETVAL", + "openjd_redacted_env: VALSPACE= SPACEVAL", + "first_line", + "second_line", + "third_line", + ]: + assert ( + unexpected_message not in outerr.out + ), f"Found unexpected line in output:\n{format_capsys_outerr(outerr)}"