⚡️ Speed up function _find_validator by 44%
#142
📄 44% (0.44x) speedup for _find_validator in mlflow/server/auth/__init__.py
⏱️ Runtime: 1.08 milliseconds → 752 microseconds (best of 47 runs)
📝 Explanation and details
The optimization achieves a 44% speedup by eliminating expensive operations in the logged model validation path, which dominated the original execution time at 82.5% of total runtime.
Key optimizations:
- Local variable bindings for global lookups: Caching LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS, req.path, and req.method in local variables eliminates repeated global dictionary and attribute lookups during iteration.
- Replaced next() generator with an explicit loop: The original used next() with a generator expression, which adds generator creation and resumption overhead on every call. The optimized version uses a direct for loop over .items() and returns on the first match.
- Method-first short-circuiting: Reordering the condition to check m == method before pat.fullmatch(path) avoids regex matching when the methods don't match. Method comparison is a cheap equality check while regex matching is comparatively costly, so entries registered for a different method are skipped almost for free.

Performance impact by test category:
- Logged model paths (regex lookup): 40.3% to 164% faster across the generated tests below.
- Exact-match paths (dictionary lookup): from 6.67% slower to 28.5% faster, on calls that take about 1 μs in both versions.
The optimization is effective because logged model validation was the bottleneck: most of its time went to regex matching, and the method-first check now skips that matching for every entry registered under a different method. Behavior is unchanged, since a validator is still returned only when both the method and the path pattern match.
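The three changes described above can be sketched as a before/after pair. This is a minimal illustration, not the actual MLflow source: the `Request` class, the two validator functions, the contents of both tables, and the `"/mlflow/logged-models" in path` branch condition are all assumptions made so the example runs on its own. Only the `next()`-with-generator shape, the local bindings, and the `m == method and pat.fullmatch(path)` ordering come from the description.

```python
import re

# Hypothetical validators and tables; the real ones live in
# mlflow/server/auth/__init__.py and are not reproduced here.
def validate_get(): return "get"
def validate_tags(): return "tags"

BEFORE_REQUEST_VALIDATORS = {("/api/experiment", "GET"): validate_get}
LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS = {
    (re.compile(r"/mlflow/logged-models/[^/]+"), "GET"): validate_get,
    (re.compile(r"/mlflow/logged-models/[^/]+/tags"), "POST"): validate_tags,
}


class Request:
    """Hypothetical stand-in for flask.Request (path and method only)."""

    def __init__(self, path, method):
        self.path = path
        self.method = method


def find_validator_before(req):
    """Assumed original shape: next() over a generator, regex checked first."""
    if "/mlflow/logged-models" in req.path:  # assumed branch condition
        return next(
            (
                v
                for (pat, method), v in LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS.items()
                if pat.fullmatch(req.path) and method == req.method
            ),
            None,
        )
    return BEFORE_REQUEST_VALIDATORS.get((req.path, req.method))


def find_validator_after(req):
    """Optimized shape: local bindings, explicit loop, method checked first."""
    path = req.path
    method = req.method
    if "/mlflow/logged-models" in path:  # assumed branch condition
        validators = LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS
        for (pat, m), validator in validators.items():
            # The cheap equality test runs first, so fullmatch() is only
            # called for entries registered under the request's method.
            if m == method and pat.fullmatch(path):
                return validator
        return None
    return BEFORE_REQUEST_VALIDATORS.get((path, method))


if __name__ == "__main__":
    req = Request("/mlflow/logged-models/abc/tags", "POST")
    print(find_validator_before(req) is find_validator_after(req))
```

Both functions return the same validator for any request, because the two conditions are combined with `and` and neither has side effects; only the order of evaluation differs.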
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
import re
from typing import Callable

# imports
import pytest
from mlflow.server.auth.__init__ import _find_validator


# --- Dummy classes and handlers to simulate the environment ---
class DummyRequest:
    """A minimal replacement for flask.Request for testing."""

    def __init__(self, path: str, method: str):
        self.path = path
        self.method = method


# Dummy handler functions to use as validators
def handler_a(): return "A"
def handler_b(): return "B"
def handler_c(): return "C"
def handler_d(): return "D"
def handler_e(): return "E"


# Helper to simulate _re_compile_path (for logged model paths)
def _re_compile_path(pattern: str):
    """Simulates Flask's path regex compilation for logged model paths."""
    # Convert Flask-style <param> placeholders to regex
    regex = re.sub(r"<[^>]+>", r"[^/]+", pattern)
    return re.compile(f"^{regex}$")
# --- Simulated validator dictionaries ---
BEFORE_REQUEST_VALIDATORS = {
    # Basic routes
    ("/api/experiment", "GET"): handler_a,
    ("/api/experiment", "POST"): handler_b,
    ("/api/experiment/42", "DELETE"): handler_c,
    # Edge: similar but not exact
    ("/api/experiment/42", "GET"): handler_d,
    # Edge: empty path
    ("", "GET"): handler_e,
}
LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS = {
    # Pattern for logged models with parameterized path
    (_re_compile_path("/mlflow/logged-models/<model_id>"), "GET"): handler_a,
    (_re_compile_path("/mlflow/logged-models/<model_id>/tags"), "POST"): handler_b,
    (_re_compile_path("/mlflow/logged-models/<model_id>/tags/<tag_id>"), "DELETE"): handler_c,
    (_re_compile_path("/mlflow/logged-models/<model_id>/finalize"), "PUT"): handler_d,
    # Edge: pattern with no parameter
    (_re_compile_path("/mlflow/logged-models/"), "GET"): handler_e,
}
from mlflow.server.auth.__init__ import _find_validator

# --- Unit tests ---
# 1. Basic Test Cases
def test_basic_exact_match():
    """Should return correct handler for exact path/method match (experiment GET)."""
    req = DummyRequest("/api/experiment", "GET")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.14μs -> 951ns (19.6% faster)

def test_basic_post_match():
    """Should return correct handler for exact path/method match (experiment POST)."""
    req = DummyRequest("/api/experiment", "POST")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.10μs -> 1.01μs (9.22% faster)

def test_basic_delete_match():
    """Should return correct handler for exact path/method match (experiment/42 DELETE)."""
    req = DummyRequest("/api/experiment/42", "DELETE")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.08μs -> 1.07μs (1.59% faster)

def test_basic_logged_model_get():
    """Should match logged model GET with parameterized path."""
    req = DummyRequest("/mlflow/logged-models/abc123", "GET")
    codeflash_output = _find_validator(req); validator = codeflash_output # 38.7μs -> 22.9μs (69.1% faster)

def test_basic_logged_model_post_tags():
    """Should match logged model POST for tags with parameterized path."""
    req = DummyRequest("/mlflow/logged-models/xyz/tags", "POST")
    codeflash_output = _find_validator(req); validator = codeflash_output # 36.5μs -> 25.6μs (42.6% faster)
# 2. Edge Test Cases
def test_edge_no_match_wrong_method():
    """Should return None if method does not match (experiment/42 PUT)."""
    req = DummyRequest("/api/experiment/42", "PUT")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.11μs -> 1.01μs (10.1% faster)

def test_edge_no_match_wrong_path():
    """Should return None if path does not match (unknown path)."""
    req = DummyRequest("/api/unknown", "GET")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.15μs -> 956ns (20.6% faster)

def test_edge_similar_path_not_match():
    """Should not match if path is similar but not exact."""
    req = DummyRequest("/api/experiment/43", "DELETE")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.08μs -> 972ns (10.7% faster)

def test_edge_logged_model_tag_delete():
    """Should match logged model tag DELETE with parameterized path."""
    req = DummyRequest("/mlflow/logged-models/abc/tags/123", "DELETE")
    codeflash_output = _find_validator(req); validator = codeflash_output # 36.6μs -> 18.3μs (99.7% faster)

def test_edge_logged_model_finalize_put():
    """Should match logged model finalize PUT with parameterized path."""
    req = DummyRequest("/mlflow/logged-models/abc/finalize", "PUT")
    codeflash_output = _find_validator(req); validator = codeflash_output # 37.5μs -> 14.2μs (164% faster)

def test_edge_logged_model_no_param():
    """Should match logged model GET for path with no param."""
    req = DummyRequest("/mlflow/logged-models/", "GET")
    codeflash_output = _find_validator(req); validator = codeflash_output # 36.2μs -> 22.3μs (62.2% faster)

def test_edge_empty_path():
    """Should match empty path if registered."""
    req = DummyRequest("", "GET")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.15μs -> 897ns (28.5% faster)

def test_edge_case_sensitive_method():
    """Should not match if method case is wrong."""
    req = DummyRequest("/api/experiment", "get")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.09μs -> 1.01μs (8.12% faster)

def test_edge_case_sensitive_path():
    """Should not match if path case is wrong."""
    req = DummyRequest("/API/EXPERIMENT", "GET")
    codeflash_output = _find_validator(req); validator = codeflash_output # 1.10μs -> 980ns (12.0% faster)

def test_edge_logged_model_extra_path_segment():
    """Should not match if extra path segment is present."""
    req = DummyRequest("/mlflow/logged-models/abc123/extra", "GET")
    codeflash_output = _find_validator(req); validator = codeflash_output # 36.8μs -> 22.7μs (61.9% faster)

def test_edge_logged_model_missing_param():
    """Should not match if parameter is missing."""
    req = DummyRequest("/mlflow/logged-models/", "POST")
    codeflash_output = _find_validator(req); validator = codeflash_output # 36.0μs -> 25.6μs (40.3% faster)

def test_edge_logged_model_partial_match():
    """Should not match if only part of the path matches."""
    req = DummyRequest("/mlflow/logged-models/abc", "POST")
    codeflash_output = _find_validator(req); validator = codeflash_output # 36.1μs -> 24.9μs (44.7% faster)
# 3. Large Scale Test Cases
def test_large_scale_many_before_request_validators():
    """Test scalability with a large BEFORE_REQUEST_VALIDATORS dict."""
    # Create 1000 unique paths
    large_dict = {(f"/api/exp/{i}", "GET"): (lambda i=i: i) for i in range(1000)}
    global BEFORE_REQUEST_VALIDATORS
    old_dict = BEFORE_REQUEST_VALIDATORS
    BEFORE_REQUEST_VALIDATORS = large_dict

def test_large_scale_many_logged_model_validators():
    """Test scalability with a large LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS dict."""
    # Create 1000 regex patterns for logged models
    large_dict = {
        (_re_compile_path(f"/mlflow/logged-models/{i}"), "GET"): (lambda i=i: i)
        for i in range(1000)
    }
    global LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS
    old_dict = LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS
    LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS = large_dict

def test_large_scale_logged_model_tag_post():
    """Test performance with many tag POST patterns."""
    # Create 500 tag POST patterns
    large_dict = {
        (_re_compile_path(f"/mlflow/logged-models/{i}/tags"), "POST"): (lambda i=i: i)
        for i in range(500)
    }
    global LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS
    old_dict = LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS
    LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS = large_dict

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
import re
from typing import Callable

# imports
import pytest
from mlflow.server.auth.__init__ import _find_validator

# --- Minimal stubs and test setup ---

# Dummy validator functions for testing
def validator_a(): return "A"
def validator_b(): return "B"
def validator_c(): return "C"
def validator_d(): return "D"
def validator_e(): return "E"


# Dummy Request class for testing
class DummyRequest:
    def __init__(self, path: str, method: str):
        self.path = path
        self.method = method


# Helper to simulate _re_compile_path (returns a compiled regex)
def _re_compile_path(path: str):
    # Convert Flask-style path '/mlflow/logged-models/<model_id>' to regex
    # Replace each <param> placeholder with [^/]+
    regex = re.sub(r"<[^>]+>", r"[^/]+", path)
    return re.compile(f"^{regex}$")
# --- Simulated global validator mappings ---

# BEFORE_REQUEST_VALIDATORS: {(path, method): validator}
BEFORE_REQUEST_VALIDATORS = {
    ("/api/experiments/get", "GET"): validator_a,
    ("/api/experiments/delete", "POST"): validator_b,
    ("/api/model/rename", "PUT"): validator_c,
    ("/api/model/delete", "DELETE"): validator_d,
    # For edge cases
    ("", "GET"): validator_e,
}

# LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS: {(compiled_regex, method): validator}
LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS = {
    (_re_compile_path("/mlflow/logged-models/<model_id>"), "GET"): validator_a,
    (_re_compile_path("/mlflow/logged-models/<model_id>/tags"), "POST"): validator_b,
    (_re_compile_path("/mlflow/logged-models/<model_id>/delete"), "DELETE"): validator_c,
    (_re_compile_path("/mlflow/logged-models/<model_id>/finalize"), "PUT"): validator_d,
    # For edge cases
    (_re_compile_path("/mlflow/logged-models/"), "GET"): validator_e,
}
from mlflow.server.auth.__init__ import _find_validator

# --- Unit tests ---
# 1. Basic Test Cases
def test_basic_exact_match():
    # Should find validator_a for GET /api/experiments/get
    req = DummyRequest("/api/experiments/get", "GET")
    codeflash_output = _find_validator(req) # 1.12μs -> 1.04μs (7.28% faster)

def test_basic_different_method():
    # Should find validator_b for POST /api/experiments/delete
    req = DummyRequest("/api/experiments/delete", "POST")
    codeflash_output = _find_validator(req) # 1.18μs -> 1.05μs (12.6% faster)

def test_basic_logged_model_regex_match():
    # Should match regex and find validator_a for GET /mlflow/logged-models/123
    req = DummyRequest("/mlflow/logged-models/123", "GET")
    codeflash_output = _find_validator(req) # 34.8μs -> 21.9μs (58.8% faster)

def test_basic_logged_model_different_method():
    # Should match regex and find validator_b for POST /mlflow/logged-models/456/tags
    req = DummyRequest("/mlflow/logged-models/456/tags", "POST")
    codeflash_output = _find_validator(req) # 36.0μs -> 25.2μs (43.0% faster)

def test_basic_logged_model_delete():
    # Should match regex and find validator_c for DELETE /mlflow/logged-models/789/delete
    req = DummyRequest("/mlflow/logged-models/789/delete", "DELETE")
    codeflash_output = _find_validator(req) # 35.6μs -> 18.9μs (87.9% faster)

def test_basic_logged_model_finalize():
    # Should match regex and find validator_d for PUT /mlflow/logged-models/789/finalize
    req = DummyRequest("/mlflow/logged-models/789/finalize", "PUT")
    codeflash_output = _find_validator(req) # 36.0μs -> 14.3μs (151% faster)

def test_basic_no_match_returns_none():
    # Should return None for unmatched path/method
    req = DummyRequest("/unknown/path", "GET")
    codeflash_output = _find_validator(req) # 1.15μs -> 1.01μs (14.3% faster)

def test_basic_logged_model_no_match_returns_none():
    # Should return None for unmatched logged model path/method
    req = DummyRequest("/mlflow/logged-models/123", "POST") # POST not mapped for this path
    codeflash_output = _find_validator(req) # 35.3μs -> 23.8μs (48.5% faster)
# 2. Edge Test Cases
def test_edge_empty_path():
    # Should find validator_e for empty path and GET method
    req = DummyRequest("", "GET")
    codeflash_output = _find_validator(req) # 1.09μs -> 966ns (12.7% faster)

def test_edge_logged_model_empty_id():
    # Should match regex for /mlflow/logged-models/ and GET
    req = DummyRequest("/mlflow/logged-models/", "GET")
    codeflash_output = _find_validator(req) # 33.5μs -> 21.9μs (52.7% faster)

def test_edge_case_sensitive_method():
    # Should not match if method is wrong case
    req = DummyRequest("/api/experiments/get", "get") # Lowercase
    codeflash_output = _find_validator(req) # 1.07μs -> 953ns (12.4% faster)

def test_edge_case_sensitive_path():
    # Should not match if path is wrong case
    req = DummyRequest("/API/EXPERIMENTS/GET", "GET")
    codeflash_output = _find_validator(req) # 1.07μs -> 1.02μs (5.20% faster)

def test_edge_partial_path_match():
    # Should not match if path is only partially correct
    req = DummyRequest("/api/experiments/get/extra", "GET")
    codeflash_output = _find_validator(req) # 1.23μs -> 1.13μs (8.79% faster)

def test_edge_logged_model_extra_slash():
    # Should not match if extra slash at end
    req = DummyRequest("/mlflow/logged-models/123/", "GET")
    codeflash_output = _find_validator(req) # 35.2μs -> 22.2μs (58.3% faster)

def test_edge_logged_model_id_with_special_chars():
    # Should match even if id contains special chars (except '/')
    req = DummyRequest("/mlflow/logged-models/model_!@#", "GET")
    codeflash_output = _find_validator(req) # 35.7μs -> 22.7μs (56.9% faster)

def test_edge_logged_model_id_with_slash_should_not_match():
    # Should not match if id contains a slash
    req = DummyRequest("/mlflow/logged-models/12/34", "GET")
    codeflash_output = _find_validator(req) # 35.0μs -> 22.3μs (57.2% faster)

def test_edge_method_is_none():
    # Should not match if method is None
    req = DummyRequest("/api/experiments/get", None)
    codeflash_output = _find_validator(req) # 1.07μs -> 1.08μs (1.02% slower)

def test_edge_method_is_empty_string():
    # Should not match if method is empty string
    req = DummyRequest("/api/experiments/get", "")
    codeflash_output = _find_validator(req) # 980ns -> 1.05μs (6.67% slower)

def test_edge_logged_model_method_is_empty_string():
    # Should not match if method is empty string
    req = DummyRequest("/mlflow/logged-models/123", "")
    codeflash_output = _find_validator(req) # 33.2μs -> 12.7μs (162% faster)
# 3. Large Scale Test Cases
def test_large_scale_many_before_request_validators():
    # Create 1000 entries in BEFORE_REQUEST_VALIDATORS
    large_validators = {}
    for i in range(1000):
        path = f"/api/experiments/{i}"
        method = "GET"
        large_validators[(path, method)] = lambda i=i: i

def test_large_scale_many_logged_model_validators():
    # Create 1000 regex entries in LOGGED_MODEL_BEFORE_REQUEST_VALIDATORS
    large_logged_model_validators = {}
    for i in range(1000):
        path = f"/mlflow/logged-models/{i}/action"
        pat = _re_compile_path(path)
        method = "POST"
        large_logged_model_validators[(pat, method)] = lambda i=i: f"logged{i}"

def test_large_scale_performance():
    # Test that function remains performant with large number of validators
    import time
    large_validators = {}
    for i in range(1000):
        path = f"/api/model/{i}"
        method = "PUT"
        large_validators[(path, method)] = lambda i=i: i

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
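Several of the edge cases above (an id containing a slash, a trailing slash, special characters in the id) depend on how the simulated path helper turns each placeholder into `[^/]+`. The following standalone sketch mirrors that test helper under the hypothetical name `compile_path`; it is not MLflow's actual implementation, only an illustration of which paths such a pattern accepts.

```python
import re


def compile_path(pattern: str) -> "re.Pattern[str]":
    """Hypothetical copy of the test helper: <param> becomes one path segment."""
    # Each <...> placeholder matches one or more characters other than '/'.
    regex = re.sub(r"<[^>]+>", "[^/]+", pattern)
    return re.compile(f"^{regex}$")


pat = compile_path("/mlflow/logged-models/<model_id>")

candidates = [
    "/mlflow/logged-models/abc123",     # plain id
    "/mlflow/logged-models/model_!@#",  # special characters other than '/'
    "/mlflow/logged-models/12/34",      # id containing a slash
    "/mlflow/logged-models/123/",       # trailing slash
    "/mlflow/logged-models/",           # empty id
]

if __name__ == "__main__":
    for path in candidates:
        # fullmatch() requires the whole path to match, not just a prefix.
        print(path, bool(pat.fullmatch(path)))
```

Only the first two candidates match: `[^/]+` needs at least one character and cannot cross a `/`, so an empty id, an embedded slash, or a trailing slash all fail.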
To edit these changes, git checkout codeflash/optimize-_find_validator-mhuq6g9k and push.