@codeflash-ai codeflash-ai bot commented Nov 13, 2025

📄 26% (0.26x) speedup for ExperimentalUIJWTToken.get_key_object_from_ui_hash_key in litellm/proxy/auth/auth_checks.py

⏱️ Runtime : 963 microseconds → 766 microseconds (best of 169 runs)

📝 Explanation and details

The optimization applies LRU caching to the salt key retrieval using @lru_cache(maxsize=1) on a new _cached_get_salt_key() function that wraps the original _get_salt_key() call.

Key change: Instead of calling _get_salt_key() directly in decrypt_value_helper(), it now calls _cached_get_salt_key() which caches the result after the first call.

Why this provides a speedup: The line profiler shows that _get_salt_key() consumes 99.7% of the execution time in decrypt_value_helper() (4.36 seconds out of 4.37 seconds total). This function performs expensive operations like environment variable lookups and imports from litellm.proxy.proxy_server. Since the salt key is static within a process lifecycle, caching eliminates this repeated overhead.

Performance impact based on function references: The function is called from _user_api_key_auth_builder() via ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(), which is part of the authentication flow for UI login tokens. This authentication happens on every request that uses UI-based JWT tokens, making this a hot path where the caching provides significant value.

Test case performance: The optimization shows consistent 20-35% improvements across all test scenarios, with particularly strong gains in the large-scale test cases (34% for the loop test). This indicates the optimization scales well with increased usage patterns where the same process handles multiple decryption operations, which is typical in a proxy server handling multiple concurrent requests.
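
One caveat follows from the caching assumption above: `@lru_cache` holds the value for the life of the process, so if the salt key were ever rotated at runtime, the cached copy would go stale unless the cache is explicitly cleared. A hedged sketch (the function body is a placeholder):

```python
from functools import lru_cache

@lru_cache(maxsize=1)
def _cached_get_salt_key() -> str:
    return "sk-salt"  # placeholder; real code would fetch the key

_cached_get_salt_key()            # first call: cache miss, key fetched
_cached_get_salt_key()            # second call: served from cache
info = _cached_get_salt_key.cache_info()

# If the salt key could change mid-process, the cache would need an
# explicit reset before the next fetch:
_cached_get_salt_key.cache_clear()
```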

Correctness verification report:

| Test | Status |
|------|--------|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 128 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 60.0% |
🌀 Generated Regression Tests and Runtime
import base64
import hashlib
import json
import os
import sys
# Patch nacl.secret.SecretBox in the decrypt_value function for testing
import types
from typing import Optional

# imports
import pytest
from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken

# --- Minimal stubs and helpers to make the test self-contained ---

# Dummy UserAPIKeyAuth class for testing
class UserAPIKeyAuth:
    def __init__(self, **kwargs):
        self.__dict__.update(kwargs)

    def __eq__(self, other):
        if not isinstance(other, UserAPIKeyAuth):
            return False
        return self.__dict__ == other.__dict__

# Minimal nacl.secret.SecretBox stub using XOR for demonstration (not secure, just for test)
class DummySecretBox:
    def __init__(self, key):
        self.key = key
    def decrypt(self, value):
        # XOR with key for demonstration (not secure!)
        if not value:
            return b""
        key = self.key
        key_len = len(key)
        return bytes([b ^ key[i % key_len] for i, b in enumerate(value)])

# --- Helper functions for test setup ---

def encrypt_and_b64encode(data: dict, signing_key: str) -> str:
    # Serialize to JSON
    json_str = json.dumps(data)
    json_bytes = json_str.encode("utf-8")
    # Use SHA256 for 32-byte key
    hash_object = hashlib.sha256(signing_key.encode())
    hash_bytes = hash_object.digest()
    # Use DummySecretBox for testing
    box = DummySecretBox(hash_bytes)
    # XOR encrypt (same as decrypt for this dummy)
    encrypted = box.decrypt(json_bytes)
    # base64 encode
    return base64.b64encode(encrypted).decode("utf-8")

# --- Unit tests ---

# 1. BASIC TEST CASES

def test_valid_token_returns_user_api_key_auth():
    """Basic: Valid token should return correct UserAPIKeyAuth object"""
    data = {"user_id": "abc123", "scopes": ["read", "write"]}
    hashed_token = encrypt_and_b64encode(data, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 13.5μs -> 10.5μs (28.5% faster)

def test_valid_token_with_extra_fields():
    """Basic: Valid token with extra fields"""
    data = {"user_id": "xyz", "scopes": ["admin"], "email": "[email protected]", "active": True}
    hashed_token = encrypt_and_b64encode(data, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 11.7μs -> 9.26μs (26.4% faster)

def test_valid_token_with_empty_scopes():
    """Basic: Valid token with empty list field"""
    data = {"user_id": "empty", "scopes": []}
    hashed_token = encrypt_and_b64encode(data, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 11.2μs -> 8.67μs (29.4% faster)

# 2. EDGE TEST CASES

def test_invalid_base64_returns_none():
    """Edge: Invalid base64 string returns None"""
    invalid_b64 = "!!!not_base64!!!"
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(invalid_b64); result = codeflash_output # 11.1μs -> 8.60μs (29.4% faster)

def test_wrong_key_returns_none():
    """Edge: Encrypted with wrong key returns None (decryption fails)"""
    data = {"user_id": "fail", "scopes": ["none"]}
    # Encrypt with a different key
    hashed_token = encrypt_and_b64encode(data, "wrong_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 11.4μs -> 9.34μs (21.6% faster)



def test_empty_string_returns_none():
    """Edge: Empty string input returns None"""
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(""); result = codeflash_output # 17.4μs -> 14.2μs (23.1% faster)


def test_unicode_characters_in_token():
    """Edge: Token with unicode characters in fields"""
    data = {"user_id": "用户", "scopes": ["读", "写"]}
    hashed_token = encrypt_and_b64encode(data, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 16.9μs -> 13.6μs (24.6% faster)

def test_token_with_nested_dict():
    """Edge: Token with nested dictionary fields"""
    data = {"user_id": "nested", "profile": {"age": 30, "country": "US"}}
    hashed_token = encrypt_and_b64encode(data, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 12.7μs -> 10.3μs (23.5% faster)

# 3. LARGE SCALE TEST CASES

def test_large_number_of_fields():
    """Large Scale: Token with many fields"""
    data = {f"field_{i}": i for i in range(500)}
    data["user_id"] = "bulk"
    hashed_token = encrypt_and_b64encode(data, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 30.1μs -> 27.1μs (11.3% faster)
    for i in range(500):
        pass

def test_large_string_field():
    """Large Scale: Token with a very large string field"""
    big_string = "x" * 5000
    data = {"user_id": "big", "blob": big_string}
    hashed_token = encrypt_and_b64encode(data, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 22.4μs -> 19.7μs (13.8% faster)

def test_many_tokens_in_a_loop():
    """Large Scale: Decrypt many tokens in a loop to check stability"""
    for i in range(100):
        data = {"user_id": f"user{i}", "scopes": [str(j) for j in range(10)]}
        hashed_token = encrypt_and_b64encode(data, "test_master_key")
        codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 489μs -> 365μs (34.1% faster)

def test_large_nested_structure():
    """Large Scale: Token with large nested structure"""
    nested = {"items": [{"id": i, "val": [j for j in range(10)]} for i in range(50)]}
    nested["user_id"] = "nested_bulk"
    hashed_token = encrypt_and_b64encode(nested, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 17.2μs -> 14.6μs (18.0% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import base64
import hashlib
import json
import os

# imports
import pytest
from litellm.proxy.auth.auth_checks import ExperimentalUIJWTToken


# --- Minimal stub/mock for UserAPIKeyAuth ---
class UserAPIKeyAuth:
    def __init__(self, **kwargs):
        # Save all fields for comparison
        self.__dict__.update(kwargs)

    def __eq__(self, other):
        return isinstance(other, UserAPIKeyAuth) and self.__dict__ == other.__dict__

    def __repr__(self):
        return f"UserAPIKeyAuth({self.__dict__})"

def dummy_encrypt(plaintext: str, signing_key: str) -> str:
    # Simulate nacl.secret.SecretBox encryption
    # prepend b"nonce" to plaintext, then base64 encode
    b = b"nonce" + plaintext.encode("utf-8")
    return base64.b64encode(b).decode("utf-8")

# --- TESTS ---

# --- BASIC TEST CASES ---
def test_basic_valid_token_single_field():
    # Test with a single field
    payload = {"api_key": "abc123"}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 12.7μs -> 10.1μs (25.9% faster)

def test_basic_valid_token_multiple_fields():
    # Test with multiple fields
    payload = {"api_key": "xyz", "user_id": "user42", "scope": "read"}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 11.5μs -> 9.10μs (26.2% faster)

def test_basic_valid_token_numeric_and_bool_fields():
    # Test with numeric and boolean fields
    payload = {"api_key": "num", "active": True, "quota": 123}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 10.9μs -> 9.03μs (20.8% faster)

# --- EDGE TEST CASES ---
def test_edge_empty_token():
    # Empty plaintext: json.loads("") raises, so the call should raise
    hashed_token = dummy_encrypt("", "test_master_key")
    with pytest.raises(Exception):
        ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token)

def test_edge_invalid_base64():
    # Not a valid base64 string
    invalid_b64 = "!!!not_base64!!!"
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(invalid_b64); result = codeflash_output # 16.0μs -> 12.2μs (31.4% faster)

def test_edge_decryption_fails():
    # Valid base64, but not valid encrypted value
    # For our dummy, must start with b"nonce"
    bad_encrypted = base64.b64encode(b"badstuff").decode("utf-8")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(bad_encrypted); result = codeflash_output # 13.7μs -> 11.5μs (19.1% faster)


def test_edge_missing_fields():
    # JSON with missing expected fields (should still succeed, but result is empty)
    payload = {}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 16.7μs -> 13.7μs (22.0% faster)

def test_edge_extra_fields():
    # JSON with extra fields
    payload = {"api_key": "abc", "extra1": "foo", "extra2": 42}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 12.6μs -> 10.3μs (22.9% faster)


def test_edge_token_wrong_key():
    # Encrypt with wrong key, decryption fails
    payload = {"api_key": "abc"}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "wrong_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 16.9μs -> 13.9μs (22.2% faster)

def test_edge_token_with_unicode():
    # Unicode characters in payload
    payload = {"api_key": "测试", "user_id": "ユーザー"}
    plaintext = json.dumps(payload, ensure_ascii=False)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 12.6μs -> 10.4μs (21.0% faster)

def test_edge_token_with_special_json_types():
    # JSON with null, array, object
    payload = {"api_key": None, "roles": ["admin", "user"], "meta": {"x": 1}}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 12.1μs -> 9.57μs (26.7% faster)

# --- LARGE SCALE TEST CASES ---
def test_large_token_many_fields():
    # Large payload with many fields
    payload = {f"field_{i}": f"value_{i}" for i in range(500)}
    # Add a known field for checking
    payload["api_key"] = "large_key"
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 36.1μs -> 33.0μs (9.38% faster)
    for i in range(500):
        pass

def test_large_token_long_strings():
    # Payload with very long string values
    long_str = "x" * 1000
    payload = {"api_key": long_str, "description": long_str}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 17.1μs -> 14.6μs (17.3% faster)

def test_large_token_large_json_array():
    # Payload with a large array
    payload = {"api_key": "array_key", "data": list(range(1000))}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 22.0μs -> 18.6μs (18.0% faster)

def test_large_token_large_nested_object():
    # Payload with large nested object
    nested = {f"sub_{i}": i for i in range(500)}
    payload = {"api_key": "nested_key", "nested": nested}
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 27.3μs -> 24.3μs (12.2% faster)

def test_large_token_performance():
    # Performance: test with 1000 fields, ensure it doesn't take too long
    payload = {f"field_{i}": f"value_{i}" for i in range(1000)}
    payload["api_key"] = "perf_key"
    plaintext = json.dumps(payload)
    hashed_token = dummy_encrypt(plaintext, "test_master_key")
    codeflash_output = ExperimentalUIJWTToken.get_key_object_from_ui_hash_key(hashed_token); result = codeflash_output # 60.0μs -> 55.6μs (7.95% faster)
    for i in range(1000):
        pass
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes, run `git checkout codeflash/optimize-ExperimentalUIJWTToken.get_key_object_from_ui_hash_key-mhwz9ol2` and push.


@codeflash-ai codeflash-ai bot requested a review from mashraf-222 November 13, 2025 05:17
@codeflash-ai codeflash-ai bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Nov 13, 2025