⚡️ Speed up function get_canonical_stage by 9%
#132
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
📄 9% (0.09x) speedup for
get_canonical_stageinmlflow/entities/model_registry/model_version_stages.py⏱️ Runtime :
4.80 milliseconds→4.40 milliseconds(best of30runs)📝 Explanation and details
The optimized code achieves a 9% speedup through two key performance optimizations:
1. EAFP (Easier to Ask for Forgiveness than Permission) Pattern
The original code uses
if key not in _CANONICAL_MAPPING:followed by a dictionary lookup, which performs the dictionary search twice for valid keys. The optimized version usestry/except KeyErrorwith direct dictionary access, eliminating the redundant lookup. This is a well-known Python optimization pattern that leverages the fact that dictionary lookups are fast, and exception handling is efficient for uncommon cases.2. f-string vs .format() Method
The optimized code replaces
.format()with f-string syntax (f"Invalid Model Version stage: {stage}..."), which is measurably faster in Python for string formatting operations.Performance Analysis from Line Profiler:
if key not in _CANONICAL_MAPPING:check (line 3 in original) took 11% of total timereturn _CANONICAL_MAPPING[key](line 7 in original) took another 4.3%return _CANONICAL_MAPPING[key]in the try block handles both operations more efficientlyTest Case Performance:
The optimization shows consistent improvements across all test scenarios:
The optimization is particularly effective for invalid inputs because it eliminates the dictionary membership check entirely, going straight to the lookup attempt. For valid inputs, it removes the double lookup penalty. This makes the function more efficient regardless of whether the input stage is valid or invalid.
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import pytest # used for our unit tests
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage
function to test
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
STAGE_NONE = "None"
STAGE_STAGING = "Staging"
STAGE_PRODUCTION = "Production"
STAGE_ARCHIVED = "Archived"
ALL_STAGES = [STAGE_NONE, STAGE_STAGING, STAGE_PRODUCTION, STAGE_ARCHIVED]
_CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage
unit tests
-------------------- Basic Test Cases --------------------
@pytest.mark.parametrize(
"input_stage,expected",
[
# Test canonical stages in their canonical form
("None", "None"),
("Staging", "Staging"),
("Production", "Production"),
("Archived", "Archived"),
# Test canonical stages with different casing
("none", "None"),
("staging", "Staging"),
("production", "Production"),
("archived", "Archived"),
("NONE", "None"),
("STAGING", "Staging"),
("PRODUCTION", "Production"),
("ARCHIVED", "Archived"),
# Test mixed case
("nOnE", "None"),
("StAgInG", "Staging"),
("PrOdUcTiOn", "Production"),
("ArChIvEd", "Archived"),
]
)
def test_basic_canonical_and_case_insensitive(input_stage, expected):
# Should return the canonical stage name regardless of input case
codeflash_output = get_canonical_stage(input_stage) # 15.5μs -> 14.1μs (9.97% faster)
-------------------- Edge Test Cases --------------------
@pytest.mark.parametrize(
"input_stage",
[
"", # Empty string
" ", # Space
"none ", # Trailing space
" none", # Leading space
"No ne", # Internal space
"Staging!", # Special character
"Staging.", # Special character
"Stagings", # Plural
"Prod", # Abbreviation
"Stage", # Partial match
"archivedd", # Typo
"123", # Numeric
"NoneNone", # Repeated valid
"StagingProduction", # Concatenated valid
"NONEE", # Typo
None, # NoneType
123, # Integer
12.34, # Float
["Staging"], # List
{"stage": "Staging"}, # Dict
True, # Boolean
False, # Boolean
]
)
def test_edge_invalid_inputs(input_stage):
# Should raise MlflowException for invalid stage values
with pytest.raises(MlflowException) as excinfo:
get_canonical_stage(input_stage) # 135μs -> 118μs (14.6% faster)
# Check that the exception message contains the invalid value if possible
if isinstance(input_stage, str):
pass
def test_edge_leading_trailing_whitespace():
# Should not accept valid stage names with leading/trailing whitespace
for stage in ALL_STAGES:
with pytest.raises(MlflowException):
get_canonical_stage(" " + stage)
with pytest.raises(MlflowException):
get_canonical_stage(stage + " ")
with pytest.raises(MlflowException):
get_canonical_stage(" " + stage + " ")
def test_edge_non_string_types():
# Should raise for non-string input types
for value in [None, 123, 12.34, True, False, ["Staging"], {"stage": "Staging"}]:
with pytest.raises(Exception):
get_canonical_stage(value)
-------------------- Large Scale Test Cases --------------------
def test_large_scale_all_valid_stages():
# Test a large batch of valid stages with random casing
import random
import string
def test_large_scale_all_invalid_stages():
# Test a large batch of invalid stages
invalids = []
# Generate 1000 unique invalid strings
for i in range(1000):
# Use a mix of random letters, numbers, and valid stage names with typos
base = random.choice(["none", "staging", "production", "archived"])
typo = base + random.choice(["x", "y", "z", "1", "2", "!", " "])
invalids.append(typo + str(i))
# Add some completely random strings
for i in range(100):
invalids.append("".join(random.choices(string.ascii_letters + string.digits, k=10)))
def test_large_scale_performance_valid():
# Test that function performs efficiently for large input batches
import time
stages = ["None", "Staging", "Production", "Archived"] * 250 # 1000 elements
start = time.time()
results = [get_canonical_stage(s) for s in stages]
end = time.time()
def test_large_scale_performance_invalid():
# Test that function performs efficiently for large invalid input batches
import time
invalids = ["noone", "stagingg", "prod", "archivedd"] * 250 # 1000 elements
start = time.time()
for s in invalids:
with pytest.raises(MlflowException):
get_canonical_stage(s)
end = time.time()
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import pytest # used for our unit tests
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage
function to test
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
STAGE_NONE = "None"
STAGE_STAGING = "Staging"
STAGE_PRODUCTION = "Production"
STAGE_ARCHIVED = "Archived"
ALL_STAGES = [STAGE_NONE, STAGE_STAGING, STAGE_PRODUCTION, STAGE_ARCHIVED]
_CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage
unit tests
=========================
Basic Test Cases
=========================
@pytest.mark.parametrize(
"input_stage,expected",
[
# Test canonical stages in original case
("None", "None"),
("Staging", "Staging"),
("Production", "Production"),
("Archived", "Archived"),
# Test canonical stages in lower case
("none", "None"),
("staging", "Staging"),
("production", "Production"),
("archived", "Archived"),
# Test canonical stages in upper case
("NONE", "None"),
("STAGING", "Staging"),
("PRODUCTION", "Production"),
("ARCHIVED", "Archived"),
# Test canonical stages in mixed case
("NoNe", "None"),
("StaGinG", "Staging"),
("pROdUctiOn", "Production"),
("arCHived", "Archived"),
]
)
def test_get_canonical_stage_basic(input_stage, expected):
# Basic functionality: should return the canonical stage name for all valid inputs
codeflash_output = get_canonical_stage(input_stage) # 15.3μs -> 14.1μs (8.54% faster)
=========================
Edge Test Cases
=========================
@pytest.mark.parametrize(
"input_stage",
[
"", # Empty string
" ", # Single space
" none", # Leading spaces
"none ", # Trailing spaces
" staging ", # Leading/trailing spaces
"Staging\n", # Trailing newline
"Staging\t", # Trailing tab
"Stagi ng", # Internal space
"STAGINGS", # Plural
"STAGE", # Partial match
"archivedd", # Extra character
"prod", # Abbreviation
"NoneType", # Similar to valid
"none\n", # Newline
"None\t", # Tab
"none.", # Punctuation
"none!", # Punctuation
"none_", # Underscore
"none-", # Hyphen
"none123", # Numbers appended
"123none", # Numbers prepended
"NONE ", # Trailing space
" NONE", # Leading space
"NONE\n", # Trailing newline
"\nNONE", # Leading newline
"NoneNone", # Repeated
None, # NoneType input
123, # Integer input
1.23, # Float input
[], # List input
{}, # Dict input
True, # Boolean input
False, # Boolean input
object(), # Arbitrary object
]
)
def test_get_canonical_stage_invalid(input_stage):
# Edge case: invalid input should raise MlflowException
# For non-string input, should raise AttributeError before MlflowException
if not isinstance(input_stage, str):
with pytest.raises(AttributeError):
get_canonical_stage(input_stage)
else:
with pytest.raises(MlflowException) as excinfo:
get_canonical_stage(input_stage)
for stage in ALL_STAGES:
pass
@pytest.mark.parametrize(
"input_stage",
[
"none", # Lowercase valid
"NONE", # Uppercase valid
"None", # Canonical valid
]
)
def test_get_canonical_stage_return_type(input_stage):
# Should always return a string for valid input
codeflash_output = get_canonical_stage(input_stage); result = codeflash_output # 2.78μs -> 2.52μs (10.3% faster)
=========================
Large Scale Test Cases
=========================
def test_get_canonical_stage_large_valid():
# Test many valid inputs in a batch
inputs = []
expected = []
# Generate 250 of each valid stage in random case
for stage in ALL_STAGES:
for i in range(250):
# Randomize case (alternating upper/lower for each char)
s = "".join(
c.upper() if j % 2 == 0 else c.lower()
for j, c in enumerate(stage)
)
inputs.append(s)
expected.append(stage)
# Check all results
for inp, exp in zip(inputs, expected):
codeflash_output = get_canonical_stage(inp) # 198μs -> 184μs (7.71% faster)
def test_get_canonical_stage_large_invalid():
# Test 1000 invalid inputs (all should raise)
invalid_inputs = [
f"invalid_stage_{i}" for i in range(1000)
]
for inp in invalid_inputs:
with pytest.raises(MlflowException):
get_canonical_stage(inp)
def test_get_canonical_stage_performance():
# Performance: ensure function is fast for 1000 valid and 1000 invalid inputs
import time
valid_inputs = [stage for stage in ALL_STAGES for _ in range(250)]
invalid_inputs = [f"invalid_{i}" for i in range(1000)]
start = time.time()
for inp in valid_inputs:
codeflash_output = get_canonical_stage(inp) # 202μs -> 185μs (9.06% faster)
for inp in invalid_inputs:
with pytest.raises(MlflowException):
get_canonical_stage(inp)
elapsed = time.time() - start
=========================
Additional Edge Cases
=========================
def test_get_canonical_stage_case_insensitivity():
# Case insensitivity: all combinations of case for a valid stage
stage = "Production"
variants = [
"production", "PRODUCTION", "PrOdUcTiOn", "pRODUCTION", "PROductioN"
]
for variant in variants:
codeflash_output = get_canonical_stage(variant) # 2.01μs -> 1.82μs (10.5% faster)
def test_get_canonical_stage_unicode():
# Unicode characters: should raise for accented or non-ASCII
invalid_inputs = [
"Prødüction", # accented
"Stagingé", # accented
"Архивед", # Cyrillic
"生产", # Chinese
"None😊", # emoji
]
for inp in invalid_inputs:
with pytest.raises(MlflowException):
get_canonical_stage(inp)
def test_get_canonical_stage_whitespace_only():
# Only whitespace: should raise
for ws in [" ", "\t", "\n", " ", "\n\t"]:
with pytest.raises(MlflowException):
get_canonical_stage(ws)
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
To edit these changes
git checkout codeflash/optimize-get_canonical_stage-mhul3nkeand push.