12 changes: 12 additions & 0 deletions DEVELOPMENT.md
@@ -31,6 +31,7 @@ from any directory of this repository:
* `hatch build` - To build the installable Python wheel and sdist packages into the `dist/` directory.
* `hatch run test` - To run the PyTest unit tests found in the `test/` directory. See [Testing](#testing).
* `hatch run all:test` - To run the PyTest unit tests against all available supported versions of Python.
* `hatch run benchmark` - To run performance benchmark tests found in the `test/openjd/model/benchmark/` directory.
* `hatch run lint` - To check that the package's formatting adheres to our standards.
* `hatch run fmt` - To automatically reformat all code to adhere to our formatting standards.
* `hatch shell` - Enter a shell environment where you can run the `deadline` command-line directly as it is implemented in your
@@ -94,6 +95,17 @@ You can run tests with:
Any arguments that you add to these commands are passed through to PyTest. So, if you want to, say, run the
[Python debugger](https://docs.python.org/3/library/pdb.html) to investigate a test failure then you can run: `hatch run test --pdb`

### Running Benchmarks

Performance benchmark tests are kept separate from the regular test suite and can be run with:

* `hatch run benchmark` - To run all benchmark tests.
* `hatch run benchmark -k <benchmark>` - To run a specific benchmark test.

Benchmarks are designed to measure performance characteristics of the library and may take longer to run than regular tests.

Benchmarks may emit log output, which can be enabled by following the instructions in the [Super verbose test output](#super-verbose-test-output) section below.

### Super verbose test output

If you find that you need much more information from a failing test (say you're debugging a
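For example, an invocation along these lines should run just the new step-environments benchmark with live log output (the `-k` expression and the standard pytest `--log-cli-level` flag are illustrative, not a documented project convention): `hatch run benchmark -k test_job_template_with_many_total_step_environments --log-cli-level=INFO`.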
3 changes: 2 additions & 1 deletion hatch.toml
@@ -5,7 +5,8 @@ pre-install-commands = [

[envs.default.scripts]
sync = "pip install -r requirements-testing.txt"
test = "pytest --cov-config pyproject.toml {args:test}"
test = "pytest test/ --cov-config pyproject.toml --ignore=test/openjd/model/benchmark {args}"
benchmark = "pytest test/openjd/model/benchmark --no-cov {args}"
Contributor: Should we be running these in GitHub Actions to prevent regressions before merge?

Contributor Author: I was looking at that - I'd like to, but it looks like the commands run live in a shared config currently. If there's a simple way to add that, I'm happy to do it though!

typing = "mypy {args:src test}"
style = [
"ruff check {args:src test}",
1 change: 0 additions & 1 deletion pyproject.toml
@@ -130,7 +130,6 @@ addopts = [
"--timeout=30"
]


[tool.coverage.run]
branch = true
parallel = true
29 changes: 25 additions & 4 deletions src/openjd/model/_internal/_create_job.py
@@ -1,7 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

from contextlib import contextmanager
from typing import Annotated, Any, Union
from typing import Annotated, Any, Union, Dict

from pydantic import ValidationError
from pydantic import TypeAdapter
@@ -13,6 +13,27 @@

__all__ = ("instantiate_model",)

# Cache for TypeAdapter instances
_type_adapter_cache: Dict[int, TypeAdapter] = {}


def get_type_adapter(field_type: Any) -> TypeAdapter:
"""Get a TypeAdapter for the given field type, using a cache for efficiency.
Assumes field_type refers to shared type definitions from the model so two field_types
referring to the same underlying model type will have the same id and share an adapter.

Args:
field_type: The type to adapt.

Returns:
A TypeAdapter for the given type.
"""
# Use id as cache key
type_key = id(field_type)
if type_key not in _type_adapter_cache:
_type_adapter_cache[type_key] = TypeAdapter(field_type)
return _type_adapter_cache[type_key]


@contextmanager
def capture_validation_errors(
@@ -114,9 +135,9 @@ def instantiate_model( # noqa: C901
else:
instantiated = _instantiate_noncollection_value(field_value, symtab, needs_resolve)

# Validate as the target field type
type_adaptor: Any = TypeAdapter(target_field_type)
instantiated = type_adaptor.validate_python(instantiated)
# Validate as the target field type using cached TypeAdapter
type_adapter = get_type_adapter(target_field_type)
instantiated = type_adapter.validate_python(instantiated)
instantiated_fields[target_field_name] = instantiated

if not errors:
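As a rough illustration of the caching pattern introduced above, here is a minimal standalone sketch (assuming pydantic v2; the `cached_adapter` helper, `_adapter_cache` name, and the `List[int]` example are illustrative, not the library's real field types or call sites). Repeated validation against the same type object reuses one `TypeAdapter` instead of rebuilding its validation schema on every call:

```python
# Minimal sketch of the TypeAdapter caching pattern (assumes pydantic v2 is installed).
# `cached_adapter` and the List[int] example are illustrative only.
from typing import Any, Dict, List

from pydantic import TypeAdapter

_adapter_cache: Dict[int, TypeAdapter] = {}


def cached_adapter(field_type: Any) -> TypeAdapter:
    """Build a TypeAdapter once per type object and reuse it on later calls."""
    # id() works as a cache key only because the cached types are long-lived,
    # module-level objects; ids of short-lived objects can be reused after
    # garbage collection.
    key = id(field_type)
    if key not in _adapter_cache:
        _adapter_cache[key] = TypeAdapter(field_type)
    return _adapter_cache[key]


# Without caching, each call rebuilds the validation schema:
#   TypeAdapter(List[int]).validate_python(["1", "2", "3"])
# With caching, the schema is built once and reused across validations:
list_of_int = List[int]
assert cached_adapter(list_of_int).validate_python(["1", "2", "3"]) == [1, 2, 3]
assert cached_adapter(list_of_int) is cached_adapter(list_of_int)
```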
36 changes: 35 additions & 1 deletion test/openjd/model/_internal/test_create_job.py
@@ -10,7 +10,11 @@
from openjd.model import SymbolTable
from openjd.model._format_strings import FormatString
from openjd.model import SpecificationRevision
from openjd.model._internal._create_job import instantiate_model
from openjd.model._internal._create_job import (
    instantiate_model,
    get_type_adapter,
    _type_adapter_cache,
)
from openjd.model._types import (
    JobCreateAsMetadata,
    JobCreationMetadata,
Expand All @@ -19,6 +23,36 @@
from openjd.model.v2023_09 import ModelParsingContext as ModelParsingContext_v2023_09


class TestGetTypeAdapter:
"""Tests for the get_type_adapter function."""

def setup_method(self):
"""Clear the type adapter cache before each test."""
_type_adapter_cache.clear()

def test_get_type_adapter_caching(self):
"""Test that get_type_adapter caches TypeAdapters for repeated calls with the same type."""
# First call should create a new TypeAdapter
str_type = str
str_type_id = id(str_type)
adapter1 = get_type_adapter(str_type)

# Second call should return the cached TypeAdapter
adapter2 = get_type_adapter(str_type)

# The adapters should be the same object (not just equal)
assert adapter1 is adapter2

# The cache should contain the entry
assert str_type_id in _type_adapter_cache

# Different types should get different adapters
int_type = int
adapter_int = get_type_adapter(int_type)
assert adapter_int is not adapter1
assert id(int_type) in _type_adapter_cache


class BaseModelForTesting(OpenJDModel):
# Specific version doesn't matter for these tests
revision = SpecificationRevision.UNDEFINED
1 change: 1 addition & 0 deletions test/openjd/model/benchmark/__init__.py
@@ -0,0 +1 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
115 changes: 115 additions & 0 deletions test/openjd/model/benchmark/test_benchmark_step_environments.py
@@ -0,0 +1,115 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.

import time
import cProfile
import pstats
import io
import logging
import pytest
from openjd.model import create_job, decode_job_template

# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger("openjd.model.benchmark")


class TestBenchmarkStepEnvironmentsPerformance:
"""Benchmark class to verify performance with large numbers of step environments."""

def test_job_template_with_many_total_step_environments(self):
"""
Benchmark that a job template with many total step environments across multiple steps is processed efficiently.
This test creates steps with many environments each and verifies the processing time.
"""
# Create a job template with multiple steps, each with step environments
num_steps = 100 # Create 100 steps
num_step_envs_per_step = 200 # 200 environments per step

logger.info(
f"CREATING JOB TEMPLATE WITH {num_steps} STEPS AND {num_step_envs_per_step} ENVIRONMENTS PER STEP"
)

steps = []
for step_num in range(num_steps):
steps.append(
{
"name": f"TestStep{step_num}",
"script": {
"actions": {"onRun": {"command": "echo", "args": [f"Step {step_num}"]}}
},
"stepEnvironments": [
{"name": f"stepEnv{step_num}_{i}", "variables": {"key": f"value{i}"}}
for i in range(num_step_envs_per_step)
],
}
)

job_template_with_many_total_envs = {
"specificationVersion": "jobtemplate-2023-09",
"name": "Test Job with Many Total Step Environments",
"steps": steps,
}

logger.info("STARTING JOB TEMPLATE PROCESSING")

# Set up profiler
profiler = cProfile.Profile()
profiler.enable()

start_time = time.time()

try:
# Create a proper JobTemplate object from the dictionary using decode_job_template
job_template = decode_job_template(template=job_template_with_many_total_envs)

# Call create_job with the JobTemplate object
_ = create_job(job_template=job_template, job_parameter_values={})

elapsed_time = time.time() - start_time
logger.info(f"PERFORMANCE RESULT: create_job completed in {elapsed_time:.2f} seconds")

# Disable profiler and print results
profiler.disable()

# Log the top 20 functions by cumulative time
s = io.StringIO()
ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
ps.print_stats(20)
logger.info("TOP 20 FUNCTIONS BY CUMULATIVE TIME:")
for line in s.getvalue().splitlines():
logger.info(line)

# Log the top 20 functions by total time
s = io.StringIO()
ps = pstats.Stats(profiler, stream=s).sort_stats("time")
ps.print_stats(20)
logger.info("TOP 20 FUNCTIONS BY TOTAL TIME:")
for line in s.getvalue().splitlines():
logger.info(line)

# Verify that the operation completed within a reasonable time
assert (
elapsed_time < 10
), f"Operation took {elapsed_time:.2f} seconds, which exceeds the 10 second threshold"
Comment on lines +93 to +96

Contributor: How close are we to this threshold?

Contributor Author: It's going to depend on the host, which is part of why I've relegated it to "benchmark", but I'm running it in about 4 seconds currently.


        except Exception as e:
            # Disable profiler in case of exception
            profiler.disable()

            elapsed_time = time.time() - start_time
            logger.error(
                f"ERROR: create_job failed in {elapsed_time:.2f} seconds with error: {str(e)}"
            )

            # Log profiling information even in case of failure
            s = io.StringIO()
            ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
            ps.print_stats(20)
            logger.info("TOP 20 FUNCTIONS BY CUMULATIVE TIME (BEFORE ERROR):")
            for line in s.getvalue().splitlines():
                logger.info(line)

            pytest.fail(f"create_job failed with error: {str(e)}")