13 changes: 11 additions & 2 deletions services/budapp/budapp/credential_ops/services.py
@@ -254,7 +254,11 @@ async def add_or_generate_credential(self, request: CredentialRequest, user_id:
return db_credential

async def update_proxy_cache(
self, project_id: UUID, api_key: Optional[str] = None, expiry: Optional[datetime] = None
self,
project_id: UUID,
api_key: Optional[str] = None,
expiry: Optional[datetime] = None,
evaluation_id: Optional[UUID] = None,
):
"""Update the proxy cache in Redis with the latest endpoints and adapters for a given project.

@@ -263,8 +267,12 @@ async def update_proxy_cache(
the Redis cache with this information. Now includes authentication metadata for API usage tracking.

Args:
api_key (str): The API key to associate with the project and its models.
project_id (UUID): The unique identifier of the project whose endpoints and adapters are to be cached.
api_key (str): The API key to associate with the project and its models.
expiry (datetime): Optional expiry time for the API key.
evaluation_id (UUID): Optional evaluation ID to associate with requests made using this API key.
When provided, all inference requests using this API key will be tagged with this evaluation_id
for tracking evaluation metrics (tokens used, request count, etc.).

Returns:
None
@@ -363,6 +371,7 @@ async def update_proxy_cache(
"api_key_id": str(key_info["credential_id"]) if key_info.get("credential_id") else None,
"user_id": str(key_info["user_id"]) if key_info.get("user_id") else None,
"api_key_project_id": str(project_id), # project_id is always available
"evaluation_id": str(evaluation_id) if evaluation_id else None,
}

ttl = None
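For illustration, a minimal sketch of the cache write that update_proxy_cache performs, assuming a redis-py client and a hypothetical key layout — only the metadata fields (api_key_id, user_id, api_key_project_id, evaluation_id) and the TTL-based expiry come from the diff above:

import json
from datetime import timedelta
from uuid import uuid4

import redis  # assumes the redis-py client

# Hypothetical key layout; the real cache key format is not shown in this diff.
api_key = "temp-eval-key-abc123"
cache_key = f"proxy_cache:{api_key}"

payload = {
    # Metadata fields mirrored from the hunk above; a temporary evaluation
    # key has no DB-backed credential, so api_key_id and user_id stay None.
    "api_key_id": None,
    "user_id": None,
    "api_key_project_id": str(uuid4()),
    "evaluation_id": str(uuid4()),  # tags every inference made with this key
}

r = redis.Redis(host="localhost", port=6379)
ttl = int(timedelta(hours=24).total_seconds())  # matches the 24-hour key expiry
r.set(cache_key, json.dumps(payload), ex=ttl)   # TTL gives automatic cleanup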
16 changes: 13 additions & 3 deletions services/budapp/budapp/eval_ops/services.py
@@ -2426,14 +2426,19 @@ async def _generate_temporary_evaluation_key(
self,
project_id: uuid.UUID,
experiment_id: uuid.UUID,
evaluation_id: uuid.UUID,
) -> str:
"""Generate temporary API key for evaluation (no DB storage).

The key is only cached in Redis with 24-hour TTL for automatic cleanup.
The evaluation_id is stored in the API key metadata so that all inference
requests made using this key are tagged with the evaluation_id for tracking
metrics like total tokens used, request count, etc.

Args:
project_id: Project ID to associate the credential with
experiment_id: Experiment ID for logging purposes
evaluation_id: Evaluation ID to tag all inference requests made with this key

Returns:
The generated API key string
@@ -2451,14 +2456,17 @@ async def _generate_temporary_evaluation_key(
expiry = datetime.now() + timedelta(hours=24)

# Update Redis cache directly (no DB storage)
# Include evaluation_id in metadata for tracking inference requests
await CredentialService(self.session).update_proxy_cache(
project_id=project_id,
api_key=api_key,
expiry=expiry,
evaluation_id=evaluation_id,
)

logger.info(
f"Generated temporary evaluation key for experiment {experiment_id}, valid until {expiry.isoformat()}"
f"Generated temporary evaluation key for experiment {experiment_id}, "
f"evaluation {evaluation_id}, valid until {expiry.isoformat()}"
)

return api_key
@@ -4939,9 +4947,11 @@ async def _trigger_evaluations_for_experiment_and_get_response(
experiment_service = ExperimentService(self.session)
project_id = await experiment_service.get_first_active_project()

# Generate temporary evaluation credential
# Generate temporary evaluation credential with evaluation_id for tracking
_api_key = await experiment_service._generate_temporary_evaluation_key(
project_id=project_id, experiment_id=experiment_id
project_id=project_id,
experiment_id=experiment_id,
evaluation_id=evaluation_id,
)

# Build evaluation request with dynamic values
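A hedged usage sketch of the flow this enables: the evaluation runner calls the gateway with the temporary key, and the gateway — not the caller — attaches the evaluation_id it finds in the key's cached metadata. The gateway URL, endpoint path, and response shape below are assumptions for illustration:

import requests  # assumes the requests library is available

GATEWAY_URL = "http://localhost:3000"  # hypothetical gateway address

def run_eval_inference(api_key: str, model: str, prompt: str) -> str:
    """Send one inference request using the temporary evaluation key.

    The gateway resolves the key in Redis, finds evaluation_id in its
    metadata, and tags the analytics/inference records itself; the caller
    never passes evaluation_id explicitly.
    """
    resp = requests.post(
        f"{GATEWAY_URL}/v1/chat/completions",  # endpoint path is an assumption
        headers={"Authorization": f"Bearer {api_key}"},
        json={"model": model, "messages": [{"role": "user", "content": prompt}]},
        timeout=60,
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]  # OpenAI-style shape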
2 changes: 2 additions & 0 deletions services/budgateway/tensorzero-internal/src/analytics.rs
@@ -58,6 +58,7 @@ pub struct GatewayAnalyticsDatabaseInsert {
pub user_id: Option<String>,
pub project_id: Option<Uuid>,
pub endpoint_id: Option<Uuid>,
pub evaluation_id: Option<Uuid>, // Evaluation ID for tracking evaluation-related requests

/// Performance metrics
#[serde(serialize_with = "serialize_datetime")]
@@ -123,6 +124,7 @@ impl GatewayAnalyticsDatabaseInsert {
user_id: None,
project_id: None,
endpoint_id: None,
evaluation_id: None,
request_timestamp: now,
response_timestamp: now,
gateway_processing_ms: 0,
@@ -63,6 +63,9 @@ pub async fn analytics_middleware(
// Extract selected request headers
record.request_headers = extract_important_headers(&headers);

// Note: Auth metadata headers are extracted from response headers (after endpoint handlers copy them)
// because the auth middleware runs after this middleware on the request path

// Store analytics in request extensions
let analytics_arc = Arc::new(tokio::sync::Mutex::new(analytics));
request.extensions_mut().insert(analytics_arc.clone());
@@ -130,6 +133,9 @@ pub async fn analytics_middleware(
// Extract selected response headers
analytics.record.response_headers = extract_important_headers(response.headers());

// Extract auth metadata from response headers (copied by endpoint handlers)
extract_auth_metadata(response.headers(), &mut analytics.record);

// Check if request was blocked
if response.status() == StatusCode::FORBIDDEN
|| response.status() == StatusCode::TOO_MANY_REQUESTS
@@ -443,6 +449,55 @@ fn extract_important_headers(headers: &HeaderMap) -> HashMap<String, String> {
result
}

/// Extract auth metadata (injected into request headers by the auth middleware
/// and copied to response headers by endpoint handlers)
fn extract_auth_metadata(headers: &HeaderMap, record: &mut GatewayAnalyticsDatabaseInsert) {
// Extract api_key_id
if let Some(api_key_id) = headers.get("x-tensorzero-api-key-id") {
if let Ok(id_str) = api_key_id.to_str() {
record.api_key_id = Some(id_str.to_string());
tracing::debug!("Captured api_key_id for analytics: {}", id_str);
}
}

// Extract user_id
if let Some(user_id) = headers.get("x-tensorzero-user-id") {
if let Ok(id_str) = user_id.to_str() {
record.user_id = Some(id_str.to_string());
tracing::debug!("Captured user_id for analytics: {}", id_str);
}
}

// Extract project_id
if let Some(project_id) = headers.get("x-tensorzero-project-id") {
if let Ok(id_str) = project_id.to_str() {
if let Ok(uuid) = Uuid::parse_str(id_str) {
record.project_id = Some(uuid);
tracing::debug!("Captured project_id for analytics: {}", id_str);
}
}
}

// Extract endpoint_id
if let Some(endpoint_id) = headers.get("x-tensorzero-endpoint-id") {
if let Ok(id_str) = endpoint_id.to_str() {
if let Ok(uuid) = Uuid::parse_str(id_str) {
record.endpoint_id = Some(uuid);
tracing::debug!("Captured endpoint_id for analytics: {}", id_str);
}
}
}

// Extract evaluation_id
if let Some(evaluation_id) = headers.get("x-tensorzero-evaluation-id") {
if let Ok(id_str) = evaluation_id.to_str() {
if let Ok(uuid) = Uuid::parse_str(id_str) {
record.evaluation_id = Some(uuid);
tracing::debug!("Captured evaluation_id for analytics: {}", id_str);
}
}
}
}
Comment on lines +453 to +499
Severity: medium

This function repeats the same extraction pattern for each metadata field. It can be made more concise and maintainable by chaining the operations with and_then, which is more idiomatic in Rust and reduces nesting, making the function shorter and easier to read.

/// Extract auth metadata from request headers (injected by auth middleware)
fn extract_auth_metadata(headers: &HeaderMap, record: &mut GatewayAnalyticsDatabaseInsert) {
    if let Some(api_key_id) = headers.get("x-tensorzero-api-key-id").and_then(|v| v.to_str().ok()) {
        record.api_key_id = Some(api_key_id.to_string());
        tracing::debug!("Captured api_key_id for analytics: {}", api_key_id);
    }

    if let Some(user_id) = headers.get("x-tensorzero-user-id").and_then(|v| v.to_str().ok()) {
        record.user_id = Some(user_id.to_string());
        tracing::debug!("Captured user_id for analytics: {}", user_id);
    }

    if let Some(uuid) = headers.get("x-tensorzero-project-id").and_then(|v| v.to_str().ok()).and_then(|s| Uuid::parse_str(s).ok()) {
        record.project_id = Some(uuid);
        tracing::debug!("Captured project_id for analytics: {}", uuid);
    }

    if let Some(uuid) = headers.get("x-tensorzero-endpoint-id").and_then(|v| v.to_str().ok()).and_then(|s| Uuid::parse_str(s).ok()) {
        record.endpoint_id = Some(uuid);
        tracing::debug!("Captured endpoint_id for analytics: {}", uuid);
    }

    if let Some(uuid) = headers.get("x-tensorzero-evaluation-id").and_then(|v| v.to_str().ok()).and_then(|s| Uuid::parse_str(s).ok()) {
        record.evaluation_id = Some(uuid);
        tracing::debug!("Captured evaluation_id for analytics: {}", uuid);
    }
}


/// Write analytics record to ClickHouse
async fn write_analytics_to_clickhouse(
clickhouse: &ClickHouseConnectionInfo,
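The ordering note above is the subtle part: on the request path the analytics middleware runs before auth, so auth-injected headers are not yet present when the request is recorded; endpoint handlers copy them onto the response, where the analytics middleware finally reads them. A toy Python simulation of that handoff (dicts stand in for header maps; none of this is the gateway's actual code):

# Toy model of the header handoff; dicts stand in for axum's HeaderMap.

AUTH_HEADERS = [
    "x-tensorzero-api-key-id",
    "x-tensorzero-user-id",
    "x-tensorzero-project-id",
    "x-tensorzero-endpoint-id",
    "x-tensorzero-evaluation-id",
]

def auth_middleware(request: dict) -> None:
    # Resolves the API key and injects metadata into *request* headers.
    request["headers"]["x-tensorzero-evaluation-id"] = "example-uuid"

def endpoint_handler(request: dict) -> dict:
    # Copies auth metadata onto the *response* so analytics can see it later.
    response = {"headers": {}}
    for name in AUTH_HEADERS:
        if name in request["headers"]:
            response["headers"][name] = request["headers"][name]
    return response

def analytics_middleware(request: dict) -> dict:
    record = {}  # request is recorded here, before auth metadata exists
    auth_middleware(request)            # auth runs after analytics on the way in
    response = endpoint_handler(request)
    # On the way out, auth metadata is read from the response headers.
    record["evaluation_id"] = response["headers"].get("x-tensorzero-evaluation-id")
    return record

print(analytics_middleware({"headers": {}}))  # {'evaluation_id': 'example-uuid'}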
8 changes: 8 additions & 0 deletions services/budgateway/tensorzero-internal/src/auth.rs
@@ -27,6 +27,7 @@ pub struct AuthMetadata {
pub api_key_id: Option<String>,
pub user_id: Option<String>,
pub api_key_project_id: Option<String>,
pub evaluation_id: Option<String>,
}

pub type APIConfig = HashMap<String, ApiKeyMetadata>;
@@ -342,6 +343,13 @@ pub async fn require_api_key(
.insert("x-tensorzero-api-key-project-id", header_value);
}
}
if let Some(evaluation_id) = auth_meta.evaluation_id {
if let Ok(header_value) = evaluation_id.parse() {
request
.headers_mut()
.insert("x-tensorzero-evaluation-id", header_value);
}
}
}
}

@@ -0,0 +1,123 @@
use crate::clickhouse::migration_manager::migration_trait::Migration;
use crate::clickhouse::ClickHouseConnectionInfo;
use crate::error::{Error, ErrorDetails};
use async_trait::async_trait;

use super::{check_column_exists, check_table_exists};

/// This migration adds evaluation_id column to the ModelInferenceDetails and GatewayAnalytics tables
/// to track inferences associated with specific evaluations.
pub struct Migration0036<'a> {
pub clickhouse: &'a ClickHouseConnectionInfo,
}

const MIGRATION_ID: &str = "0036";

#[async_trait]
impl Migration for Migration0036<'_> {
async fn can_apply(&self) -> Result<(), Error> {
// Check if ModelInferenceDetails table exists
let model_inference_details_exists =
check_table_exists(self.clickhouse, "ModelInferenceDetails", MIGRATION_ID).await?;

if !model_inference_details_exists {
return Err(Error::new(ErrorDetails::ClickHouseMigration {
id: MIGRATION_ID.to_string(),
message: "ModelInferenceDetails table does not exist".to_string(),
}));
}

Ok(())
}

async fn should_apply(&self) -> Result<bool, Error> {
// Check if evaluation_id column already exists in ModelInferenceDetails
let evaluation_id_exists_mid = check_column_exists(
self.clickhouse,
"ModelInferenceDetails",
"evaluation_id",
MIGRATION_ID,
)
.await?;

// Check if evaluation_id column already exists in GatewayAnalytics
let gateway_analytics_exists =
check_table_exists(self.clickhouse, "GatewayAnalytics", MIGRATION_ID).await?;

let evaluation_id_exists_ga = if gateway_analytics_exists {
check_column_exists(
self.clickhouse,
"GatewayAnalytics",
"evaluation_id",
MIGRATION_ID,
)
.await?
} else {
true // If table doesn't exist, consider it as "already done" for this column
};

// Apply if any column is missing
Ok(!evaluation_id_exists_mid || !evaluation_id_exists_ga)
}

async fn apply(&self, _clean_start: bool) -> Result<(), Error> {
// Add evaluation_id column to ModelInferenceDetails
let query = r#"
ALTER TABLE ModelInferenceDetails
ADD COLUMN IF NOT EXISTS evaluation_id Nullable(UUID)
"#;
self.clickhouse
.run_query_synchronous(query.to_string(), None)
.await?;

// Add index for evaluation_id in ModelInferenceDetails
let index_query = r#"
ALTER TABLE ModelInferenceDetails
ADD INDEX IF NOT EXISTS idx_evaluation_id evaluation_id TYPE bloom_filter(0.01) GRANULARITY 4
"#;
self.clickhouse
.run_query_synchronous(index_query.to_string(), None)
.await?;

// Check if GatewayAnalytics table exists before adding column
let gateway_analytics_exists =
check_table_exists(self.clickhouse, "GatewayAnalytics", MIGRATION_ID).await?;

if gateway_analytics_exists {
// Add evaluation_id column to GatewayAnalytics
let ga_query = r#"
ALTER TABLE GatewayAnalytics
ADD COLUMN IF NOT EXISTS evaluation_id Nullable(UUID)
"#;
self.clickhouse
.run_query_synchronous(ga_query.to_string(), None)
.await?;

// Add index for evaluation_id in GatewayAnalytics
let ga_index_query = r#"
ALTER TABLE GatewayAnalytics
ADD INDEX IF NOT EXISTS idx_evaluation_id evaluation_id TYPE bloom_filter(0.01) GRANULARITY 4
"#;
self.clickhouse
.run_query_synchronous(ga_index_query.to_string(), None)
.await?;
}

Ok(())
}

fn rollback_instructions(&self) -> String {
r#"
ALTER TABLE ModelInferenceDetails DROP INDEX IF EXISTS idx_evaluation_id;
ALTER TABLE ModelInferenceDetails DROP COLUMN IF EXISTS evaluation_id;
ALTER TABLE GatewayAnalytics DROP INDEX IF EXISTS idx_evaluation_id;
ALTER TABLE GatewayAnalytics DROP COLUMN IF EXISTS evaluation_id;
"#
.to_string()
}

async fn has_succeeded(&self) -> Result<bool, Error> {
let should_apply = self.should_apply().await?;
Ok(!should_apply)
}
}
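With the column and index in place, per-evaluation metrics can be aggregated directly from ClickHouse. A sketch using the clickhouse-connect client — the token columns on ModelInferenceDetails are hypothetical, since this migration only defines evaluation_id:

import clickhouse_connect  # assumes the clickhouse-connect package

client = clickhouse_connect.get_client(host="localhost", port=8123)

def evaluation_usage(evaluation_id: str) -> dict:
    """Aggregate request count and (hypothetical) token usage for one evaluation."""
    result = client.query(
        """
        SELECT
            count() AS request_count,
            sum(input_tokens) AS input_tokens,   -- hypothetical column
            sum(output_tokens) AS output_tokens  -- hypothetical column
        FROM ModelInferenceDetails
        WHERE evaluation_id = {eval_id:UUID}
        """,
        parameters={"eval_id": evaluation_id},
    )
    row = result.result_rows[0]
    return {
        "request_count": row[0],
        "input_tokens": row[1],
        "output_tokens": row[2],
    }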
@@ -33,6 +33,7 @@ pub mod migration_0032;
pub mod migration_0033;
pub mod migration_0034;
pub mod migration_0035;
pub mod migration_0036;

/// Returns true if the table exists, false if it does not
/// Errors if the query fails
@@ -33,11 +33,12 @@ use migrations::migration_0032::Migration0032;
use migrations::migration_0033::Migration0033;
use migrations::migration_0034::Migration0034;
use migrations::migration_0035::Migration0035;
use migrations::migration_0036::Migration0036;

/// This must match the number of migrations returned by `make_all_migrations` - the tests
/// will panic if they don't match.
/// Note: We have 36 total migrations (0-35), but 7 are banned (0001, 0007, 0010, 0012, 0013, 0014, 0023)
pub const NUM_MIGRATIONS: usize = 29;
/// Note: We have 37 total migrations (0-36), but 7 are banned (0001, 0007, 0010, 0012, 0013, 0014, 0023)
pub const NUM_MIGRATIONS: usize = 30;

/// Constructs (but does not run) a vector of all our database migrations.
/// This is the single source of truth for all migration - it's used during startup to migrate
@@ -88,6 +89,7 @@ pub fn make_all_migrations<'a>(
Box::new(Migration0033 { clickhouse }),
Box::new(Migration0034 { clickhouse }),
Box::new(Migration0035 { clickhouse }),
Box::new(Migration0036 { clickhouse }),
];
assert_eq!(
migrations.len(),
@@ -640,6 +640,7 @@ pub struct ModelInferenceDetailsInsert {
pub api_key_id: Option<uuid::Uuid>,
pub user_id: Option<uuid::Uuid>,
pub api_key_project_id: Option<uuid::Uuid>,
pub evaluation_id: Option<uuid::Uuid>,
pub error_code: Option<String>,
pub error_message: Option<String>,
pub error_type: Option<String>,