Skip to content

Conversation

@codeflash-ai
Copy link

@codeflash-ai codeflash-ai bot commented Nov 11, 2025

📄 9% (0.09x) speedup for get_canonical_stage in mlflow/entities/model_registry/model_version_stages.py

⏱️ Runtime : 4.80 milliseconds 4.40 milliseconds (best of 30 runs)

📝 Explanation and details

The optimized code achieves a 9% speedup through two key performance optimizations:

1. EAFP (Easier to Ask for Forgiveness than Permission) Pattern
The original code uses if key not in _CANONICAL_MAPPING: followed by a dictionary lookup, which performs the dictionary search twice for valid keys. The optimized version uses try/except KeyError with direct dictionary access, eliminating the redundant lookup. This is a well-known Python optimization pattern that leverages the fact that dictionary lookups are fast, and exception handling is efficient for uncommon cases.

2. f-string vs .format() Method
The optimized code replaces .format() with f-string syntax (f"Invalid Model Version stage: {stage}..."), which is measurably faster in Python for string formatting operations.

Performance Analysis from Line Profiler:

  • The if key not in _CANONICAL_MAPPING: check (line 3 in original) took 11% of total time
  • The subsequent return _CANONICAL_MAPPING[key] (line 7 in original) took another 4.3%
  • In the optimized version, the single return _CANONICAL_MAPPING[key] in the try block handles both operations more efficiently

Test Case Performance:
The optimization shows consistent improvements across all test scenarios:

  • Basic valid inputs: ~8-10% faster (15.5μs → 14.1μs)
  • Invalid inputs: ~14.6% faster (135μs → 118μs)
  • Large-scale operations: ~7-10% faster (198μs → 184μs)

The optimization is particularly effective for invalid inputs because it eliminates the dictionary membership check entirely, going straight to the lookup attempt. For valid inputs, it removes the double lookup penalty. This makes the function more efficient regardless of whether the input stage is valid or invalid.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 6125 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Generated Regression Tests and Runtime

import pytest # used for our unit tests
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage

function to test

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE

STAGE_NONE = "None"
STAGE_STAGING = "Staging"
STAGE_PRODUCTION = "Production"
STAGE_ARCHIVED = "Archived"

ALL_STAGES = [STAGE_NONE, STAGE_STAGING, STAGE_PRODUCTION, STAGE_ARCHIVED]
_CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage

unit tests

-------------------- Basic Test Cases --------------------

@pytest.mark.parametrize(
"input_stage,expected",
[
# Test canonical stages in their canonical form
("None", "None"),
("Staging", "Staging"),
("Production", "Production"),
("Archived", "Archived"),
# Test canonical stages with different casing
("none", "None"),
("staging", "Staging"),
("production", "Production"),
("archived", "Archived"),
("NONE", "None"),
("STAGING", "Staging"),
("PRODUCTION", "Production"),
("ARCHIVED", "Archived"),
# Test mixed case
("nOnE", "None"),
("StAgInG", "Staging"),
("PrOdUcTiOn", "Production"),
("ArChIvEd", "Archived"),
]
)
def test_basic_canonical_and_case_insensitive(input_stage, expected):
# Should return the canonical stage name regardless of input case
codeflash_output = get_canonical_stage(input_stage) # 15.5μs -> 14.1μs (9.97% faster)

-------------------- Edge Test Cases --------------------

@pytest.mark.parametrize(
"input_stage",
[
"", # Empty string
" ", # Space
"none ", # Trailing space
" none", # Leading space
"No ne", # Internal space
"Staging!", # Special character
"Staging.", # Special character
"Stagings", # Plural
"Prod", # Abbreviation
"Stage", # Partial match
"archivedd", # Typo
"123", # Numeric
"NoneNone", # Repeated valid
"StagingProduction", # Concatenated valid
"NONEE", # Typo
None, # NoneType
123, # Integer
12.34, # Float
["Staging"], # List
{"stage": "Staging"}, # Dict
True, # Boolean
False, # Boolean
]
)
def test_edge_invalid_inputs(input_stage):
# Should raise MlflowException for invalid stage values
with pytest.raises(MlflowException) as excinfo:
get_canonical_stage(input_stage) # 135μs -> 118μs (14.6% faster)
# Check that the exception message contains the invalid value if possible
if isinstance(input_stage, str):
pass

def test_edge_leading_trailing_whitespace():
# Should not accept valid stage names with leading/trailing whitespace
for stage in ALL_STAGES:
with pytest.raises(MlflowException):
get_canonical_stage(" " + stage)
with pytest.raises(MlflowException):
get_canonical_stage(stage + " ")
with pytest.raises(MlflowException):
get_canonical_stage(" " + stage + " ")

def test_edge_non_string_types():
# Should raise for non-string input types
for value in [None, 123, 12.34, True, False, ["Staging"], {"stage": "Staging"}]:
with pytest.raises(Exception):
get_canonical_stage(value)

-------------------- Large Scale Test Cases --------------------

def test_large_scale_all_valid_stages():
# Test a large batch of valid stages with random casing
import random
import string

def random_case(s):
    # Randomly change the case of each character
    return "".join(random.choice([c.upper(), c.lower()]) for c in s)

stages = []
expected = []
for stage in ALL_STAGES:
    for _ in range(250):  # 4 stages x 250 = 1000
        rc_stage = random_case(stage)
        stages.append(rc_stage)
        expected.append(stage)

# Test all at once
for inp, exp in zip(stages, expected):
    codeflash_output = get_canonical_stage(inp) # 203μs -> 185μs (10.2% faster)

def test_large_scale_all_invalid_stages():
# Test a large batch of invalid stages
invalids = []
# Generate 1000 unique invalid strings
for i in range(1000):
# Use a mix of random letters, numbers, and valid stage names with typos
base = random.choice(["none", "staging", "production", "archived"])
typo = base + random.choice(["x", "y", "z", "1", "2", "!", " "])
invalids.append(typo + str(i))
# Add some completely random strings
for i in range(100):
invalids.append("".join(random.choices(string.ascii_letters + string.digits, k=10)))

# Test all invalids
for invalid in invalids:
    with pytest.raises(MlflowException):
        get_canonical_stage(invalid)

def test_large_scale_performance_valid():
# Test that function performs efficiently for large input batches
import time
stages = ["None", "Staging", "Production", "Archived"] * 250 # 1000 elements
start = time.time()
results = [get_canonical_stage(s) for s in stages]
end = time.time()

def test_large_scale_performance_invalid():
# Test that function performs efficiently for large invalid input batches
import time
invalids = ["noone", "stagingg", "prod", "archivedd"] * 250 # 1000 elements
start = time.time()
for s in invalids:
with pytest.raises(MlflowException):
get_canonical_stage(s)
end = time.time()

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

#------------------------------------------------
import pytest # used for our unit tests
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage

function to test

from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE

STAGE_NONE = "None"
STAGE_STAGING = "Staging"
STAGE_PRODUCTION = "Production"
STAGE_ARCHIVED = "Archived"

ALL_STAGES = [STAGE_NONE, STAGE_STAGING, STAGE_PRODUCTION, STAGE_ARCHIVED]
_CANONICAL_MAPPING = {stage.lower(): stage for stage in ALL_STAGES}
from mlflow.entities.model_registry.model_version_stages import
get_canonical_stage

unit tests

=========================

Basic Test Cases

=========================

@pytest.mark.parametrize(
"input_stage,expected",
[
# Test canonical stages in original case
("None", "None"),
("Staging", "Staging"),
("Production", "Production"),
("Archived", "Archived"),
# Test canonical stages in lower case
("none", "None"),
("staging", "Staging"),
("production", "Production"),
("archived", "Archived"),
# Test canonical stages in upper case
("NONE", "None"),
("STAGING", "Staging"),
("PRODUCTION", "Production"),
("ARCHIVED", "Archived"),
# Test canonical stages in mixed case
("NoNe", "None"),
("StaGinG", "Staging"),
("pROdUctiOn", "Production"),
("arCHived", "Archived"),
]
)
def test_get_canonical_stage_basic(input_stage, expected):
# Basic functionality: should return the canonical stage name for all valid inputs
codeflash_output = get_canonical_stage(input_stage) # 15.3μs -> 14.1μs (8.54% faster)

=========================

Edge Test Cases

=========================

@pytest.mark.parametrize(
"input_stage",
[
"", # Empty string
" ", # Single space
" none", # Leading spaces
"none ", # Trailing spaces
" staging ", # Leading/trailing spaces
"Staging\n", # Trailing newline
"Staging\t", # Trailing tab
"Stagi ng", # Internal space
"STAGINGS", # Plural
"STAGE", # Partial match
"archivedd", # Extra character
"prod", # Abbreviation
"NoneType", # Similar to valid
"none\n", # Newline
"None\t", # Tab
"none.", # Punctuation
"none!", # Punctuation
"none_", # Underscore
"none-", # Hyphen
"none123", # Numbers appended
"123none", # Numbers prepended
"NONE ", # Trailing space
" NONE", # Leading space
"NONE\n", # Trailing newline
"\nNONE", # Leading newline
"NoneNone", # Repeated
None, # NoneType input
123, # Integer input
1.23, # Float input
[], # List input
{}, # Dict input
True, # Boolean input
False, # Boolean input
object(), # Arbitrary object
]
)
def test_get_canonical_stage_invalid(input_stage):
# Edge case: invalid input should raise MlflowException
# For non-string input, should raise AttributeError before MlflowException
if not isinstance(input_stage, str):
with pytest.raises(AttributeError):
get_canonical_stage(input_stage)
else:
with pytest.raises(MlflowException) as excinfo:
get_canonical_stage(input_stage)
for stage in ALL_STAGES:
pass

@pytest.mark.parametrize(
"input_stage",
[
"none", # Lowercase valid
"NONE", # Uppercase valid
"None", # Canonical valid
]
)
def test_get_canonical_stage_return_type(input_stage):
# Should always return a string for valid input
codeflash_output = get_canonical_stage(input_stage); result = codeflash_output # 2.78μs -> 2.52μs (10.3% faster)

=========================

Large Scale Test Cases

=========================

def test_get_canonical_stage_large_valid():
# Test many valid inputs in a batch
inputs = []
expected = []
# Generate 250 of each valid stage in random case
for stage in ALL_STAGES:
for i in range(250):
# Randomize case (alternating upper/lower for each char)
s = "".join(
c.upper() if j % 2 == 0 else c.lower()
for j, c in enumerate(stage)
)
inputs.append(s)
expected.append(stage)
# Check all results
for inp, exp in zip(inputs, expected):
codeflash_output = get_canonical_stage(inp) # 198μs -> 184μs (7.71% faster)

def test_get_canonical_stage_large_invalid():
# Test 1000 invalid inputs (all should raise)
invalid_inputs = [
f"invalid_stage_{i}" for i in range(1000)
]
for inp in invalid_inputs:
with pytest.raises(MlflowException):
get_canonical_stage(inp)

def test_get_canonical_stage_performance():
# Performance: ensure function is fast for 1000 valid and 1000 invalid inputs
import time
valid_inputs = [stage for stage in ALL_STAGES for _ in range(250)]
invalid_inputs = [f"invalid_{i}" for i in range(1000)]
start = time.time()
for inp in valid_inputs:
codeflash_output = get_canonical_stage(inp) # 202μs -> 185μs (9.06% faster)
for inp in invalid_inputs:
with pytest.raises(MlflowException):
get_canonical_stage(inp)
elapsed = time.time() - start

=========================

Additional Edge Cases

=========================

def test_get_canonical_stage_case_insensitivity():
# Case insensitivity: all combinations of case for a valid stage
stage = "Production"
variants = [
"production", "PRODUCTION", "PrOdUcTiOn", "pRODUCTION", "PROductioN"
]
for variant in variants:
codeflash_output = get_canonical_stage(variant) # 2.01μs -> 1.82μs (10.5% faster)

def test_get_canonical_stage_unicode():
# Unicode characters: should raise for accented or non-ASCII
invalid_inputs = [
"Prødüction", # accented
"Stagingé", # accented
"Архивед", # Cyrillic
"生产", # Chinese
"None😊", # emoji
]
for inp in invalid_inputs:
with pytest.raises(MlflowException):
get_canonical_stage(inp)

def test_get_canonical_stage_whitespace_only():
# Only whitespace: should raise
for ws in [" ", "\t", "\n", " ", "\n\t"]:
with pytest.raises(MlflowException):
get_canonical_stage(ws)

codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-get_canonical_stage-mhul3nke and push.

Codeflash Static Badge

The optimized code achieves a **9% speedup** through two key performance optimizations:

**1. EAFP (Easier to Ask for Forgiveness than Permission) Pattern**
The original code uses `if key not in _CANONICAL_MAPPING:` followed by a dictionary lookup, which performs the dictionary search twice for valid keys. The optimized version uses `try/except KeyError` with direct dictionary access, eliminating the redundant lookup. This is a well-known Python optimization pattern that leverages the fact that dictionary lookups are fast, and exception handling is efficient for uncommon cases.

**2. f-string vs .format() Method**  
The optimized code replaces `.format()` with f-string syntax (`f"Invalid Model Version stage: {stage}..."`), which is measurably faster in Python for string formatting operations.

**Performance Analysis from Line Profiler:**
- The `if key not in _CANONICAL_MAPPING:` check (line 3 in original) took 11% of total time
- The subsequent `return _CANONICAL_MAPPING[key]` (line 7 in original) took another 4.3% 
- In the optimized version, the single `return _CANONICAL_MAPPING[key]` in the try block handles both operations more efficiently

**Test Case Performance:**
The optimization shows consistent improvements across all test scenarios:
- Basic valid inputs: ~8-10% faster (15.5μs → 14.1μs)
- Invalid inputs: ~14.6% faster (135μs → 118μs) 
- Large-scale operations: ~7-10% faster (198μs → 184μs)

The optimization is particularly effective for invalid inputs because it eliminates the dictionary membership check entirely, going straight to the lookup attempt. For valid inputs, it removes the double lookup penalty. This makes the function more efficient regardless of whether the input stage is valid or invalid.
@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 11, 2025 13:05
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: Medium Optimization Quality according to Codeflash labels Nov 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: Medium Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant