
Commit b62a4f1

fix: Adding a TypeAdapter cache to fix a performance regression with larger templates. (#212)

Signed-off-by: Brian Axelson <[email protected]>

1 parent bc43391 commit b62a4f1
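
In short, the change memoizes pydantic `TypeAdapter` construction: instead of building a new adapter for every field validation, `instantiate_model` now reuses one adapter per type object, keyed by `id()`. A condensed sketch of the pattern follows (it mirrors the diff to `_create_job.py` below; assumes pydantic v2, and the trailing asserts are illustrative only, not part of the commit):

```python
# Illustrative sketch only: mirrors the caching pattern added to
# src/openjd/model/_internal/_create_job.py in this commit (assumes pydantic v2).
from typing import Any, Dict

from pydantic import TypeAdapter

# One adapter per type object; building a TypeAdapter is comparatively
# expensive, and large templates trigger many field validations.
_type_adapter_cache: Dict[int, TypeAdapter] = {}


def get_type_adapter(field_type: Any) -> TypeAdapter:
    # id() of the shared type object is the cache key, so repeated validations
    # against the same model type reuse a single adapter.
    type_key = id(field_type)
    if type_key not in _type_adapter_cache:
        _type_adapter_cache[type_key] = TypeAdapter(field_type)
    return _type_adapter_cache[type_key]


# Repeated lookups for the same type object return the same adapter instance,
# and validation behaves as before.
assert get_type_adapter(str) is get_type_adapter(str)
assert get_type_adapter(str).validate_python("hello") == "hello"
```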

7 files changed: +190 -7 lines changed

DEVELOPMENT.md

Lines changed: 12 additions & 0 deletions
@@ -31,6 +31,7 @@ from any directory of this repository:
 * `hatch build` - To build the installable Python wheel and sdist packages into the `dist/` directory.
 * `hatch run test` - To run the PyTest unit tests found in the `test/` directory. See [Testing](#testing).
 * `hatch run all:test` - To run the PyTest unit tests against all available supported versions of Python.
+* `hatch run benchmark` - To run performance benchmark tests found in the `test/openjd/model/benchmark/` directory.
 * `hatch run lint` - To check that the package's formatting adheres to our standards.
 * `hatch run fmt` - To automatically reformat all code to adhere to our formatting standards.
 * `hatch shell` - Enter a shell environment where you can run the `deadline` command-line directly as it is implemented in your
@@ -94,6 +95,17 @@ You can run tests with:
 Any arguments that you add to these commands are passed through to PyTest. So, if you want to, say, run the
 [Python debugger](https://docs.python.org/3/library/pdb.html) to investigate a test failure then you can run: `hatch run test --pdb`
 
+### Running Benchmarks
+
+Performance benchmark tests are kept separate from the regular test suite and can be run with:
+
+* `hatch run benchmark` - To run all benchmark tests.
+* `hatch run benchmark -k <benchmark>` - To run a specific benchmark test.
+
+Benchmarks are designed to measure performance characteristics of the library and may take longer to run than regular tests.
+
+Benchmarks may include log output which can be enabled following the instructions in the test output section below.
+
 ### Super verbose test output
 
 If you find that you need much more information from a failing test (say you're debugging a

hatch.toml

Lines changed: 2 additions & 1 deletion
@@ -5,7 +5,8 @@ pre-install-commands = [
 
 [envs.default.scripts]
 sync = "pip install -r requirements-testing.txt"
-test = "pytest --cov-config pyproject.toml {args:test}"
+test = "pytest test/ --cov-config pyproject.toml --ignore=test/openjd/model/benchmark {args}"
+benchmark = "pytest test/openjd/model/benchmark --no-cov {args}"
 typing = "mypy {args:src test}"
 style = [
     "ruff check {args:src test}",

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -130,7 +130,6 @@ addopts = [
     "--timeout=30"
 ]
 
-
 [tool.coverage.run]
 branch = true
 parallel = true

src/openjd/model/_internal/_create_job.py

Lines changed: 25 additions & 4 deletions
@@ -1,7 +1,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 
 from contextlib import contextmanager
-from typing import Annotated, Any, Union
+from typing import Annotated, Any, Union, Dict
 
 from pydantic import ValidationError
 from pydantic import TypeAdapter
@@ -13,6 +13,27 @@
 
 __all__ = ("instantiate_model",)
 
+# Cache for TypeAdapter instances
+_type_adapter_cache: Dict[int, TypeAdapter] = {}
+
+
+def get_type_adapter(field_type: Any) -> TypeAdapter:
+    """Get a TypeAdapter for the given field type, using a cache for efficiency.
+    Assumes field_type refers to shared type definitions from the model so two field_types
+    referring to the same underlying model type will have the same id and share an adapter.
+
+    Args:
+        field_type: The type to adapt.
+
+    Returns:
+        A TypeAdapter for the given type.
+    """
+    # Use id as cache key
+    type_key = id(field_type)
+    if type_key not in _type_adapter_cache:
+        _type_adapter_cache[type_key] = TypeAdapter(field_type)
+    return _type_adapter_cache[type_key]
+
 
 @contextmanager
 def capture_validation_errors(
@@ -114,9 +135,9 @@ def instantiate_model( # noqa: C901
         else:
             instantiated = _instantiate_noncollection_value(field_value, symtab, needs_resolve)
 
-        # Validate as the target field type
-        type_adaptor: Any = TypeAdapter(target_field_type)
-        instantiated = type_adaptor.validate_python(instantiated)
+        # Validate as the target field type using cached TypeAdapter
+        type_adapter = get_type_adapter(target_field_type)
+        instantiated = type_adapter.validate_python(instantiated)
         instantiated_fields[target_field_name] = instantiated
 
     if not errors:

test/openjd/model/_internal/test_create_job.py

Lines changed: 35 additions & 1 deletion
@@ -10,7 +10,11 @@
 from openjd.model import SymbolTable
 from openjd.model._format_strings import FormatString
 from openjd.model import SpecificationRevision
-from openjd.model._internal._create_job import instantiate_model
+from openjd.model._internal._create_job import (
+    instantiate_model,
+    get_type_adapter,
+    _type_adapter_cache,
+)
 from openjd.model._types import (
     JobCreateAsMetadata,
     JobCreationMetadata,
@@ -19,6 +23,36 @@
 from openjd.model.v2023_09 import ModelParsingContext as ModelParsingContext_v2023_09
 
 
+class TestGetTypeAdapter:
+    """Tests for the get_type_adapter function."""
+
+    def setup_method(self):
+        """Clear the type adapter cache before each test."""
+        _type_adapter_cache.clear()
+
+    def test_get_type_adapter_caching(self):
+        """Test that get_type_adapter caches TypeAdapters for repeated calls with the same type."""
+        # First call should create a new TypeAdapter
+        str_type = str
+        str_type_id = id(str_type)
+        adapter1 = get_type_adapter(str_type)
+
+        # Second call should return the cached TypeAdapter
+        adapter2 = get_type_adapter(str_type)
+
+        # The adapters should be the same object (not just equal)
+        assert adapter1 is adapter2
+
+        # The cache should contain the entry
+        assert str_type_id in _type_adapter_cache
+
+        # Different types should get different adapters
+        int_type = int
+        adapter_int = get_type_adapter(int_type)
+        assert adapter_int is not adapter1
+        assert id(int_type) in _type_adapter_cache
+
+
 class BaseModelForTesting(OpenJDModel):
     # Specific version doesn't matter for these tests
     revision = SpecificationRevision.UNDEFINED
test/openjd/model/benchmark/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
test/openjd/model/benchmark/ (new benchmark test file)

Lines changed: 115 additions & 0 deletions
@@ -0,0 +1,115 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+import time
+import cProfile
+import pstats
+import io
+import logging
+import pytest
+from openjd.model import create_job, decode_job_template
+
+# Configure logging
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger("openjd.model.benchmark")
+
+
+class TestBenchmarkStepEnvironmentsPerformance:
+    """Benchmark class to verify performance with large numbers of step environments."""
+
+    def test_job_template_with_many_total_step_environments(self):
+        """
+        Benchmark that a job template with many total step environments across multiple steps is processed efficiently.
+
+        This test creates steps with many environments each and verifies the processing time.
+        """
+        # Create a job template with multiple steps, each with step environments
+        num_steps = 100  # Create 100 steps
+        num_step_envs_per_step = 200  # 200 environments per step
+
+        logger.info(
+            f"CREATING JOB TEMPLATE WITH {num_steps} STEPS AND {num_step_envs_per_step} ENVIRONMENTS PER STEP"
+        )
+
+        steps = []
+        for step_num in range(num_steps):
+            steps.append(
+                {
+                    "name": f"TestStep{step_num}",
+                    "script": {
+                        "actions": {"onRun": {"command": "echo", "args": [f"Step {step_num}"]}}
+                    },
+                    "stepEnvironments": [
+                        {"name": f"stepEnv{step_num}_{i}", "variables": {"key": f"value{i}"}}
+                        for i in range(num_step_envs_per_step)
+                    ],
+                }
+            )
+
+        job_template_with_many_total_envs = {
+            "specificationVersion": "jobtemplate-2023-09",
+            "name": "Test Job with Many Total Step Environments",
+            "steps": steps,
+        }
+
+        logger.info("STARTING JOB TEMPLATE PROCESSING")
+
+        # Set up profiler
+        profiler = cProfile.Profile()
+        profiler.enable()
+
+        start_time = time.time()
+
+        try:
+            # Create a proper JobTemplate object from the dictionary using decode_job_template
+            job_template = decode_job_template(template=job_template_with_many_total_envs)
+
+            # Call create_job with the JobTemplate object
+            _ = create_job(job_template=job_template, job_parameter_values={})
+
+            elapsed_time = time.time() - start_time
+            logger.info(f"PERFORMANCE RESULT: create_job completed in {elapsed_time:.2f} seconds")
+
+            # Disable profiler and print results
+            profiler.disable()
+
+            # Log the top 20 functions by cumulative time
+            s = io.StringIO()
+            ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
+            ps.print_stats(20)
+            logger.info("TOP 20 FUNCTIONS BY CUMULATIVE TIME:")
+            for line in s.getvalue().splitlines():
+                logger.info(line)
+
+            # Log the top 20 functions by total time
+            s = io.StringIO()
+            ps = pstats.Stats(profiler, stream=s).sort_stats("time")
+            ps.print_stats(20)
+            logger.info("TOP 20 FUNCTIONS BY TOTAL TIME:")
+            for line in s.getvalue().splitlines():
+                logger.info(line)
+
+            # Verify that the operation completed within a reasonable time
+            assert (
+                elapsed_time < 10
+            ), f"Operation took {elapsed_time:.2f} seconds, which exceeds the 10 second threshold"
+
+        except Exception as e:
+            # Disable profiler in case of exception
+            profiler.disable()
+
+            elapsed_time = time.time() - start_time
+            logger.error(
+                f"ERROR: create_job failed in {elapsed_time:.2f} seconds with error: {str(e)}"
+            )
+
+            # Log profiling information even in case of failure
+            s = io.StringIO()
+            ps = pstats.Stats(profiler, stream=s).sort_stats("cumulative")
+            ps.print_stats(20)
+            logger.info("TOP 20 FUNCTIONS BY CUMULATIVE TIME (BEFORE ERROR):")
+            for line in s.getvalue().splitlines():
+                logger.info(line)
+
+            pytest.fail(f"create_job failed with error: {str(e)}")
