diff --git a/DEVELOPMENT.md b/DEVELOPMENT.md
index 641b21f..c9efd04 100644
--- a/DEVELOPMENT.md
+++ b/DEVELOPMENT.md
@@ -31,6 +31,7 @@ from any directory of this repository:
 * `hatch build` - To build the installable Python wheel and sdist packages into the `dist/` directory.
 * `hatch run test` - To run the PyTest unit tests found in the `test/` directory. See [Testing](#testing).
 * `hatch run all:test` - To run the PyTest unit tests against all available supported versions of Python.
+* `hatch run benchmark` - To run the performance benchmark tests found in the `test/openjd/model/benchmark/` directory.
 * `hatch run lint` - To check that the package's formatting adheres to our standards.
 * `hatch run fmt` - To automatically reformat all code to adhere to our formatting standards.
 * `hatch shell` - Enter a shell environment where you can run the `deadline` command-line directly as it is implemented in your
@@ -94,6 +95,17 @@ You can run tests with:
 Any arguments that you add to these commands are passed through to PyTest. So, if you want to, say, run the
 [Python debugger](https://docs.python.org/3/library/pdb.html) to investigate a test failure then you can run:
 `hatch run test --pdb`
+
+### Running Benchmarks
+
+Performance benchmark tests are kept separate from the regular test suite and can be run with:
+
+* `hatch run benchmark` - To run all benchmark tests.
+* `hatch run benchmark -k <pattern>` - To run a specific benchmark test that matches the given pattern.
+
+Benchmarks are designed to measure performance characteristics of the library and may take longer to run than regular tests.
+
+Benchmarks may produce log output, which can be enabled by following the instructions in the [Super verbose test output](#super-verbose-test-output) section below.
+
 ### Super verbose test output
 
 If you find that you need much more information from a failing test (say you're debugging a
diff --git a/hatch.toml b/hatch.toml
index 0a2b20b..a701765 100644
--- a/hatch.toml
+++ b/hatch.toml
@@ -5,7 +5,8 @@ pre-install-commands = [
 
 [envs.default.scripts]
 sync = "pip install -r requirements-testing.txt"
-test = "pytest --cov-config pyproject.toml {args:test}"
+test = "pytest test/ --cov-config pyproject.toml --ignore=test/openjd/model/benchmark {args}"
+benchmark = "pytest test/openjd/model/benchmark --no-cov {args}"
 typing = "mypy {args:src test}"
 style = [
   "ruff check {args:src test}",
diff --git a/pyproject.toml b/pyproject.toml
index 12824f6..6a3822a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -130,7 +130,6 @@ addopts = [
     "--timeout=30"
 ]
 
-
 [tool.coverage.run]
 branch = true
 parallel = true
diff --git a/src/openjd/model/_internal/_create_job.py b/src/openjd/model/_internal/_create_job.py
index 40f73c5..49e1573 100644
--- a/src/openjd/model/_internal/_create_job.py
+++ b/src/openjd/model/_internal/_create_job.py
@@ -1,7 +1,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 
 from contextlib import contextmanager
-from typing import Annotated, Any, Union
+from typing import Annotated, Any, Union, Dict
 
 from pydantic import ValidationError
 from pydantic import TypeAdapter
@@ -13,6 +13,27 @@
 
 __all__ = ("instantiate_model",)
 
+# Cache for TypeAdapter instances
+_type_adapter_cache: Dict[int, TypeAdapter] = {}
+
+
+def get_type_adapter(field_type: Any) -> TypeAdapter:
+    """Get a TypeAdapter for the given field type, using a cache for efficiency.
+    Assumes that field_type refers to shared type definitions from the model, so two field_type
+    values that refer to the same underlying model type have the same id and share an adapter.
+
+    Args:
+        field_type: The type to adapt.
+
+    Returns:
+        A TypeAdapter for the given type.
+    """
+    # Use id as cache key
+    type_key = id(field_type)
+    if type_key not in _type_adapter_cache:
+        _type_adapter_cache[type_key] = TypeAdapter(field_type)
+    return _type_adapter_cache[type_key]
+
 
 @contextmanager
 def capture_validation_errors(
@@ -114,9 +135,9 @@ def instantiate_model(  # noqa: C901
         else:
             instantiated = _instantiate_noncollection_value(field_value, symtab, needs_resolve)
 
-            # Validate as the target field type
-            type_adaptor: Any = TypeAdapter(target_field_type)
-            instantiated = type_adaptor.validate_python(instantiated)
+            # Validate as the target field type using cached TypeAdapter
+            type_adapter = get_type_adapter(target_field_type)
+            instantiated = type_adapter.validate_python(instantiated)
             instantiated_fields[target_field_name] = instantiated
 
     if not errors:
diff --git a/test/openjd/model/_internal/test_create_job.py b/test/openjd/model/_internal/test_create_job.py
index 442c9d2..b1e9f47 100644
--- a/test/openjd/model/_internal/test_create_job.py
+++ b/test/openjd/model/_internal/test_create_job.py
@@ -10,7 +10,11 @@
 from openjd.model import SymbolTable
 from openjd.model._format_strings import FormatString
 from openjd.model import SpecificationRevision
-from openjd.model._internal._create_job import instantiate_model
+from openjd.model._internal._create_job import (
+    instantiate_model,
+    get_type_adapter,
+    _type_adapter_cache,
+)
 from openjd.model._types import (
     JobCreateAsMetadata,
     JobCreationMetadata,
@@ -19,6 +23,36 @@
 from openjd.model.v2023_09 import ModelParsingContext as ModelParsingContext_v2023_09
 
 
+class TestGetTypeAdapter:
+    """Tests for the get_type_adapter function."""
+
+    def setup_method(self):
+        """Clear the type adapter cache before each test."""
+        _type_adapter_cache.clear()
+
+    def test_get_type_adapter_caching(self):
+        """Test that get_type_adapter caches TypeAdapters for repeated calls with the same type."""
+        # First call should create a new TypeAdapter
+        str_type = str
+        str_type_id = id(str_type)
+        adapter1 = get_type_adapter(str_type)
+
+        # Second call should return the cached TypeAdapter
+        adapter2 = get_type_adapter(str_type)
+
+        # The adapters should be the same object (not just equal)
+        assert adapter1 is adapter2
+
+        # The cache should contain the entry
+        assert str_type_id in _type_adapter_cache
+
+        # Different types should get different adapters
+        int_type = int
+        adapter_int = get_type_adapter(int_type)
+        assert adapter_int is not adapter1
+        assert id(int_type) in _type_adapter_cache
+
+
 class BaseModelForTesting(OpenJDModel):
     # Specific version doesn't matter for these tests
     revision = SpecificationRevision.UNDEFINED
diff --git a/test/openjd/model/benchmark/__init__.py b/test/openjd/model/benchmark/__init__.py
new file mode 100644
index 0000000..8d929cc
--- /dev/null
+++ b/test/openjd/model/benchmark/__init__.py
@@ -0,0 +1 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
diff --git a/test/openjd/model/benchmark/test_benchmark_step_environments.py b/test/openjd/model/benchmark/test_benchmark_step_environments.py
new file mode 100644
index 0000000..8f3abe0
--- /dev/null
+++ b/test/openjd/model/benchmark/test_benchmark_step_environments.py
@@ -0,0 +1,115 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+import time
+import cProfile
+import pstats
+import io
+import logging
+import pytest
+from openjd.model import create_job, decode_job_template
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger("openjd.model.benchmark")
+
+
+class TestBenchmarkStepEnvironmentsPerformance:
+    """Benchmark class to verify performance with large numbers of step environments."""
+
+    def test_job_template_with_many_total_step_environments(self):
+        """
+        Benchmark that a job template with many total step environments across multiple steps is processed efficiently.
+
+        This test creates steps with many environments each and verifies the processing time.
+        """
+        # Create a job template with multiple steps, each with step environments
+        num_steps = 100  # Create 100 steps
+        num_step_envs_per_step = 200  # 200 environments per step
+
+        logger.info(
+            f"CREATING JOB TEMPLATE WITH {num_steps} STEPS AND {num_step_envs_per_step} ENVIRONMENTS PER STEP"
+        )
+
+        steps = []
+        for step_num in range(num_steps):
+            steps.append(
+                {
+                    "name": f"TestStep{step_num}",
+                    "script": {
+                        "actions": {"onRun": {"command": "echo", "args": [f"Step {step_num}"]}}
+                    },
+                    "stepEnvironments": [
+                        {"name": f"stepEnv{step_num}_{i}", "variables": {"key": f"value{i}"}}
+                        for i in range(num_step_envs_per_step)
+                    ],
+                }
+            )
+
+        job_template_with_many_total_envs = {
+            "specificationVersion": "jobtemplate-2023-09",
+            "name": "Test Job with Many Total Step Environments",
+            "steps": steps,
+        }
+
+        logger.info("STARTING JOB TEMPLATE PROCESSING")
+
+        # Set up profiler
+        profiler = cProfile.Profile()
+        profiler.enable()
+
+        start_time = time.time()
+
+        try:
+            # Create a proper JobTemplate object from the dictionary using decode_job_template
+            job_template = decode_job_template(template=job_template_with_many_total_envs)
+
+            # Call create_job with the JobTemplate object
+            _ = create_job(job_template=job_template, job_parameter_values={})
+
+            elapsed_time = time.time() - start_time
+            logger.info(f"PERFORMANCE RESULT: create_job completed in {elapsed_time:.2f} seconds")
+
+            # Disable profiler and print results
+            profiler.disable()
+
+            # Log the top 20 functions by cumulative time
+            s = io.StringIO()
+            ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
+            ps.print_stats(20)
+            logger.info("TOP 20 FUNCTIONS BY CUMULATIVE TIME:")
+            for line in s.getvalue().splitlines():
+                logger.info(line)
+
+            # Log the top 20 functions by total time
+            s = io.StringIO()
+            ps = pstats.Stats(profiler, stream=s).sort_stats("time")
+            ps.print_stats(20)
+            logger.info("TOP 20 FUNCTIONS BY TOTAL TIME:")
+            for line in s.getvalue().splitlines():
+                logger.info(line)
+
+            # Verify that the operation completed within a reasonable time
+            assert (
+                elapsed_time < 10
+            ), f"Operation took {elapsed_time:.2f} seconds, which exceeds the 10 second threshold"
+
+        except Exception as e:
+            # Disable profiler in case of exception
+            profiler.disable()
+
+            elapsed_time = time.time() - start_time
+            logger.error(
+                f"ERROR: create_job failed in {elapsed_time:.2f} seconds with error: {str(e)}"
+            )
+
+            # Log profiling information even in case of failure
+            s = io.StringIO()
+            ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
+            ps.print_stats(20)
+            logger.info("TOP 20 FUNCTIONS BY CUMULATIVE TIME (BEFORE ERROR):")
+            for line in s.getvalue().splitlines():
+                logger.info(line)
+
+            pytest.fail(f"create_job failed with error: {str(e)}")