Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 22 additions & 22 deletions litellm/proxy/management_endpoints/key_management_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,9 +640,9 @@ async def _common_key_generation_helper( # noqa: PLR0915
request_type="key", **data_json, table_name="key"
)

response["soft_budget"] = (
data.soft_budget
) # include the user-input soft budget in the response
response[
"soft_budget"
] = data.soft_budget # include the user-input soft budget in the response

response = GenerateKeyResponse(**response)

Expand Down Expand Up @@ -1299,7 +1299,6 @@ async def prepare_key_update_data(
data: Union[UpdateKeyRequest, RegenerateKeyRequest],
existing_key_row: LiteLLM_VerificationToken,
):

data_json: dict = data.model_dump(exclude_unset=True)
data_json.pop("key", None)
data_json.pop("new_key", None)
Expand Down Expand Up @@ -2063,7 +2062,8 @@ async def generate_key_helper_fn( # noqa: PLR0915
if duration is None: # allow tokens that never expire
expires = None
else:
expires = get_budget_reset_time(budget_duration=duration)
duration_s = duration_in_seconds(duration=duration)
expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)

if key_budget_duration is None: # one-time budget
key_reset_at = None
Expand Down Expand Up @@ -2357,10 +2357,10 @@ async def delete_verification_tokens(
try:
if prisma_client:
tokens = [_hash_token_if_needed(token=key) for key in tokens]
_keys_being_deleted: List[LiteLLM_VerificationToken] = (
await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)
_keys_being_deleted: List[
LiteLLM_VerificationToken
] = await prisma_client.db.litellm_verificationtoken.find_many(
where={"token": {"in": tokens}}
)

if len(_keys_being_deleted) == 0:
Expand Down Expand Up @@ -2468,9 +2468,9 @@ async def _rotate_master_key(
from litellm.proxy.proxy_server import proxy_config

try:
models: Optional[List] = (
await prisma_client.db.litellm_proxymodeltable.find_many()
)
models: Optional[
List
] = await prisma_client.db.litellm_proxymodeltable.find_many()
except Exception:
models = None
# 2. process model table
Expand Down Expand Up @@ -2781,11 +2781,11 @@ async def validate_key_list_check(
param="user_id",
code=status.HTTP_403_FORBIDDEN,
)
complete_user_info_db_obj: Optional[BaseModel] = (
await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)
complete_user_info_db_obj: Optional[
BaseModel
] = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_api_key_dict.user_id},
include={"organization_memberships": True},
)

if complete_user_info_db_obj is None:
Expand Down Expand Up @@ -2871,10 +2871,10 @@ async def get_admin_team_ids(
if complete_user_info is None:
return []
# Get all teams that user is an admin of
teams: Optional[List[BaseModel]] = (
await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
teams: Optional[
List[BaseModel]
] = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": complete_user_info.teams}}
)
if teams is None:
return []
Expand Down Expand Up @@ -3512,7 +3512,7 @@ async def key_health(
Checks:
- If key based logging is configured correctly - sends a test log

Usage
Usage

Pass the key in the request header

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,10 @@ async def test_key_token_handling(monkeypatch):
@pytest.mark.asyncio
async def test_budget_reset_and_expires_at_first_of_month(monkeypatch):
"""
Test that when budget_duration, duration, and key_budget_duration are "1mo", budget_reset_at and expires are set to first of next month
Test that budget reset fields are standardized to 1st of next month.
"""
from datetime import datetime, timezone

mock_prisma_client = AsyncMock()
mock_insert_data = AsyncMock(
return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
Expand All @@ -297,52 +299,37 @@ async def test_budget_reset_and_expires_at_first_of_month(monkeypatch):
return_value=MagicMock(token="hashed_token_123", litellm_budget_table=None)
)

from datetime import datetime, timezone

import pytest

from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_helper_fn,
)
from litellm.proxy.proxy_server import prisma_client

# Use monkeypatch to set the prisma_client
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False)

# Test key generation with budget_duration="1mo", duration="1mo", key_budget_duration="1mo"
test_start = datetime.now(timezone.utc)

# Generate key with monthly budget duration
response = await generate_key_helper_fn(
request_type="user",
budget_duration="1mo",
duration="1mo",
key_budget_duration="1mo",
user_id="test_user",
)

print(f"response: {response}\n")
# Get the current date
now = datetime.now(timezone.utc)
# Verify budget_reset_at is standardized to 1st of next month at midnight
budget_reset_at = response.get("budget_reset_at")
assert budget_reset_at is not None
assert budget_reset_at.day == 1, "budget_reset_at should be on 1st of month"
assert budget_reset_at.hour == 0, "budget_reset_at should be at midnight"
assert budget_reset_at.minute == 0, "budget_reset_at should be at midnight"

# Calculate expected reset date (first of next month)
if now.month == 12:
expected_month = 1
expected_year = now.year + 1
# Verify it's next month
if test_start.month == 12:
assert budget_reset_at.month == 1
assert budget_reset_at.year == test_start.year + 1
else:
expected_month = now.month + 1
expected_year = now.year

# Verify budget_reset_at, expires is set to first of next month
for key in ["budget_reset_at", "expires"]:
response_date = response.get(key)
assert response_date is not None, f"{key} not found in response"
assert (
response_date.year == expected_year
), f"Expected year {expected_year}, got {response_date.year} for {key}"
assert (
response_date.month == expected_month
), f"Expected month {expected_month}, got {response_date.month} for {key}"
assert (
response_date.day == 1
), f"Expected day 1, got {response_date.day} for {key}"
assert budget_reset_at.month == test_start.month + 1
assert budget_reset_at.year == test_start.year


@pytest.mark.asyncio
Expand Down Expand Up @@ -1317,6 +1304,213 @@ async def test_update_key_fn_auto_rotate_disable():


@pytest.mark.asyncio
async def test_key_expiration_calculated_from_current_time(monkeypatch):
    """
    Test that key expiration is calculated as duration from current time.

    For duration="1mo" on Oct 15th:
    - expires: Should be Nov 15th (1 month from creation)
    - budget_reset_at: Should be Nov 1st (standardized monthly reset)

    Both expectations are asserted: ``expires`` against
    ``test_start_time + duration_in_seconds("1mo")`` (5 second tolerance), and
    ``budget_reset_at`` against the first day of the month after the test
    started (year and month included, not only the day).
    """
    from datetime import datetime, timedelta, timezone
    from unittest.mock import AsyncMock, MagicMock

    from litellm.litellm_core_utils.duration_parser import duration_in_seconds
    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_helper_fn,
    )

    # Set up mock prisma client
    mock_prisma_client = AsyncMock()
    mock_insert_data = AsyncMock(
        return_value=MagicMock(
            token="hashed_token_123",
            litellm_budget_table=None,
            created_at=datetime.now(timezone.utc),
            updated_at=datetime.now(timezone.utc),
        )
    )
    mock_prisma_client.insert_data = mock_insert_data
    mock_prisma_client.jsonify_object = lambda data: data

    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
    monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False)

    test_start_time = datetime.now(timezone.utc)

    # Generate key with monthly duration
    response = await generate_key_helper_fn(
        request_type="user",
        duration="1mo",
        budget_duration="1mo",
        user_id="test_user",
    )

    print(f"\nTest time: {test_start_time}")
    print(f"expires: {response.get('expires')}")
    print(f"budget_reset_at: {response.get('budget_reset_at')}")

    # Calculate expected values
    duration_seconds = duration_in_seconds("1mo")
    expected_expires = test_start_time + timedelta(seconds=duration_seconds)

    # The month following the test start; December rolls over to January.
    if test_start_time.month == 12:
        expected_reset_year = test_start_time.year + 1
        expected_reset_month = 1
    else:
        expected_reset_year = test_start_time.year
        expected_reset_month = test_start_time.month + 1

    # Verify budget_reset_at is standardized to 1st of next month at midnight
    budget_reset_at = response.get("budget_reset_at")
    assert budget_reset_at is not None
    assert budget_reset_at.day == 1, "budget_reset_at should be 1st of month"
    assert budget_reset_at.hour == 0
    assert budget_reset_at.minute == 0
    # Previously the expected reset date was computed but never checked.
    assert budget_reset_at.month == expected_reset_month
    assert budget_reset_at.year == expected_reset_year

    # Verify expires is calculated from current time
    expires = response.get("expires")
    assert expires is not None

    time_diff = abs((expires - expected_expires).total_seconds())
    assert (
        time_diff < 5
    ), f"expires should be 1 month from creation time. Expected: {expected_expires}, Got: {expires}"

    # expires and budget_reset_at should differ when test runs on non-1st day
    if test_start_time.day != 1:
        assert (
            expires != budget_reset_at
        ), "expires and budget_reset_at should have different values"
        assert (
            expires.day != 1
        ), f"expires should not be on 1st when created on day {test_start_time.day}"


@pytest.mark.asyncio
async def test_key_expiration_with_various_durations(monkeypatch):
    """
    Check ``expires`` for several duration units (seconds, minutes, hours, days).

    For every duration string, the returned ``expires`` must lie within two
    seconds of the moment just before the call plus the expected number of
    seconds.
    """
    from datetime import datetime, timedelta, timezone
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_helper_fn,
    )

    # Stand-in for the DB client: insert_data resolves to a canned token row.
    fake_prisma = AsyncMock()
    inserted_row = MagicMock(
        token="hashed_token_123",
        litellm_budget_table=None,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    fake_prisma.insert_data = AsyncMock(return_value=inserted_row)
    fake_prisma.jsonify_object = lambda data: data

    patches = (
        ("litellm.proxy.proxy_server.prisma_client", fake_prisma),
        ("litellm.proxy.proxy_server.premium_user", False),
    )
    for target, value in patches:
        monkeypatch.setattr(target, value)

    # duration string -> expected number of seconds from now
    seconds_by_duration = {
        "30s": 30,
        "5m": 300,
        "2h": 7200,
        "7d": 604800,
    }

    for duration_str, expected_seconds in seconds_by_duration.items():
        test_start = datetime.now(timezone.utc)

        response = await generate_key_helper_fn(
            request_type="key",
            duration=duration_str,
            user_id="test_user",
        )

        expires = response.get("expires")
        assert expires is not None, f"expires should be set for duration={duration_str}"

        # Distance between the actual and the expected expiration, in seconds
        expected_expires = test_start + timedelta(seconds=expected_seconds)
        drift = expires - expected_expires
        time_diff = abs(drift.total_seconds())

        # Allow up to 2 seconds of slack for the call itself
        assert (
            time_diff < 2
        ), f"duration={duration_str}: Expected {expected_seconds}s from now, got diff of {time_diff}s"


@pytest.mark.asyncio
async def test_key_budget_reset_uses_standardized_time(monkeypatch):
    """
    Check that a monthly budget resets at a predictable, standardized time.

    With budget_duration="1mo", ``budget_reset_at`` is expected to be midnight
    on the 1st of the month after the one in which the test started.
    """
    from datetime import datetime, timezone
    from unittest.mock import AsyncMock, MagicMock

    from litellm.proxy.management_endpoints.key_management_endpoints import (
        generate_key_helper_fn,
    )

    # Stand-in for the DB client: insert_data resolves to a canned token row.
    fake_prisma = AsyncMock()
    inserted_row = MagicMock(
        token="hashed_token_123",
        litellm_budget_table=None,
        created_at=datetime.now(timezone.utc),
        updated_at=datetime.now(timezone.utc),
    )
    fake_prisma.insert_data = AsyncMock(return_value=inserted_row)
    fake_prisma.jsonify_object = lambda data: data

    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", fake_prisma)
    monkeypatch.setattr("litellm.proxy.proxy_server.premium_user", False)

    test_start = datetime.now(timezone.utc)

    # Only a budget duration is supplied; no key duration is passed here.
    response = await generate_key_helper_fn(
        request_type="user",
        budget_duration="1mo",
        user_id="test_user",
    )

    budget_reset_at = response.get("budget_reset_at")
    assert budget_reset_at is not None

    # Standardized reset: 1st of the month, at midnight
    assert budget_reset_at.day == 1
    assert budget_reset_at.hour == 0
    assert budget_reset_at.minute == 0

    # The reset must fall in the next month; December wraps into January.
    wraps_year = test_start.month == 12
    next_month = 1 if wraps_year else test_start.month + 1
    next_month_year = test_start.year + 1 if wraps_year else test_start.year

    assert budget_reset_at.month == next_month
    assert budget_reset_at.year == next_month_year


@pytest.mark.asyncio
async def test_check_team_key_limits_no_existing_keys():
"""
Test _check_team_key_limits when team has no existing keys.
Expand Down
Loading