diff --git a/services/budapp/budapp/credential_ops/services.py b/services/budapp/budapp/credential_ops/services.py
index af64e480e..51efd6860 100644
--- a/services/budapp/budapp/credential_ops/services.py
+++ b/services/budapp/budapp/credential_ops/services.py
@@ -254,7 +254,11 @@ async def add_or_generate_credential(self, request: CredentialRequest, user_id:
         return db_credential

     async def update_proxy_cache(
-        self, project_id: UUID, api_key: Optional[str] = None, expiry: Optional[datetime] = None
+        self,
+        project_id: UUID,
+        api_key: Optional[str] = None,
+        expiry: Optional[datetime] = None,
+        evaluation_id: Optional[UUID] = None,
     ):
         """Update the proxy cache in Redis with the latest endpoints and adapters for a given project.
@@ -263,8 +267,12 @@ async def update_proxy_cache(
         the Redis cache with this information. Now includes authentication metadata for API usage tracking.

         Args:
-            api_key (str): The API key to associate with the project and its models.
             project_id (UUID): The unique identifier of the project whose endpoints and adapters are to be cached.
+            api_key (str): The API key to associate with the project and its models.
+            expiry (datetime): Optional expiry time for the API key.
+            evaluation_id (UUID): Optional evaluation ID to associate with requests made using this API key.
+                When provided, all inference requests using this API key will be tagged with this evaluation_id
+                for tracking evaluation metrics (tokens used, request count, etc.).

         Returns:
             None
@@ -363,6 +371,7 @@ async def update_proxy_cache(
                 "api_key_id": str(key_info["credential_id"]) if key_info.get("credential_id") else None,
                 "user_id": str(key_info["user_id"]) if key_info.get("user_id") else None,
                 "api_key_project_id": str(project_id),  # project_id is always available
+                "evaluation_id": str(evaluation_id) if evaluation_id else None,
             }

             ttl = None
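Reviewer note: for orientation, the cache entry that `update_proxy_cache` writes now carries the optional `evaluation_id` alongside the existing auth metadata. A minimal sketch of the stored payload, assuming a plain `redis` client and an illustrative cache-key format (the real key naming and TTL handling live in `CredentialService`):

```python
import json
import redis

r = redis.Redis()

# Illustrative only: the real cache-key format is defined by CredentialService.
cache_key = "api_key_metadata:<hashed-api-key>"
metadata = {
    "api_key_id": "c0ffee00-0000-4000-8000-000000000001",   # credential_id, when present
    "user_id": "c0ffee00-0000-4000-8000-000000000002",      # owning user, when present
    "api_key_project_id": "c0ffee00-0000-4000-8000-000000000003",  # always set
    "evaluation_id": "c0ffee00-0000-4000-8000-000000000004",  # None for normal keys
}
# TTL mirrors the key's expiry so evaluation keys clean themselves up.
r.set(cache_key, json.dumps(metadata), ex=24 * 3600)
```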
diff --git a/services/budapp/budapp/eval_ops/services.py b/services/budapp/budapp/eval_ops/services.py
index 70e7c4d44..0b6d7e9bc 100644
--- a/services/budapp/budapp/eval_ops/services.py
+++ b/services/budapp/budapp/eval_ops/services.py
@@ -2426,14 +2426,19 @@ async def _generate_temporary_evaluation_key(
         self,
         project_id: uuid.UUID,
         experiment_id: uuid.UUID,
+        evaluation_id: uuid.UUID,
     ) -> str:
         """Generate temporary API key for evaluation (no DB storage).

         The key is only cached in Redis with 24-hour TTL for automatic cleanup.
+        The evaluation_id is stored in the API key metadata so that all inference
+        requests made using this key are tagged with the evaluation_id for tracking
+        metrics like total tokens used, request count, etc.

         Args:
             project_id: Project ID to associate the credential with
             experiment_id: Experiment ID for logging purposes
+            evaluation_id: Evaluation ID to tag all inference requests made with this key

         Returns:
             The generated API key string
@@ -2451,14 +2456,17 @@ async def _generate_temporary_evaluation_key(
         expiry = datetime.now() + timedelta(hours=24)

         # Update Redis cache directly (no DB storage)
+        # Include evaluation_id in metadata for tracking inference requests
         await CredentialService(self.session).update_proxy_cache(
             project_id=project_id,
             api_key=api_key,
             expiry=expiry,
+            evaluation_id=evaluation_id,
         )

         logger.info(
-            f"Generated temporary evaluation key for experiment {experiment_id}, valid until {expiry.isoformat()}"
+            f"Generated temporary evaluation key for experiment {experiment_id}, "
+            f"evaluation {evaluation_id}, valid until {expiry.isoformat()}"
         )

         return api_key
@@ -4939,9 +4947,11 @@ async def _trigger_evaluations_for_experiment_and_get_response(
         experiment_service = ExperimentService(self.session)
         project_id = await experiment_service.get_first_active_project()

-        # Generate temporary evaluation credential
+        # Generate temporary evaluation credential with evaluation_id for tracking
         _api_key = await experiment_service._generate_temporary_evaluation_key(
-            project_id=project_id, experiment_id=experiment_id
+            project_id=project_id,
+            experiment_id=experiment_id,
+            evaluation_id=evaluation_id,
         )

         # Build evaluation request with dynamic values
diff --git a/services/budgateway/tensorzero-internal/src/analytics.rs b/services/budgateway/tensorzero-internal/src/analytics.rs
index 3b87a0810..bf2e59d29 100644
--- a/services/budgateway/tensorzero-internal/src/analytics.rs
+++ b/services/budgateway/tensorzero-internal/src/analytics.rs
@@ -58,6 +58,7 @@ pub struct GatewayAnalyticsDatabaseInsert {
    pub user_id: Option<String>,
    pub project_id: Option<Uuid>,
    pub endpoint_id: Option<Uuid>,
+   pub evaluation_id: Option<Uuid>, // Evaluation ID for tracking evaluation-related requests

    /// Performance metrics
    #[serde(serialize_with = "serialize_datetime")]
@@ -123,6 +124,7 @@ impl GatewayAnalyticsDatabaseInsert {
            user_id: None,
            project_id: None,
            endpoint_id: None,
+           evaluation_id: None,
            request_timestamp: now,
            response_timestamp: now,
            gateway_processing_ms: 0,
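Reviewer note: putting the budapp half together, the trigger path now threads `evaluation_id` from the evaluation run down into the key's Redis metadata. A condensed sketch of the call chain, under the same wiring the diff assumes (ExperimentService/CredentialService as in budapp):

```python
import uuid

async def start_tracked_evaluation(experiment_service, project_id: uuid.UUID,
                                   experiment_id: uuid.UUID, evaluation_id: uuid.UUID) -> str:
    """Condensed from the diff: mint a 24-hour key whose metadata tags inferences."""
    api_key = await experiment_service._generate_temporary_evaluation_key(
        project_id=project_id,
        experiment_id=experiment_id,
        evaluation_id=evaluation_id,  # every inference made with this key is tagged
    )
    # Internally this calls update_proxy_cache(..., expiry=now + 24h,
    # evaluation_id=evaluation_id) -- Redis-only, nothing is persisted to the DB.
    return api_key
```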
diff --git a/services/budgateway/tensorzero-internal/src/analytics_middleware.rs b/services/budgateway/tensorzero-internal/src/analytics_middleware.rs
index f6e282b59..7f8aa1ca3 100644
--- a/services/budgateway/tensorzero-internal/src/analytics_middleware.rs
+++ b/services/budgateway/tensorzero-internal/src/analytics_middleware.rs
@@ -63,6 +63,9 @@ pub async fn analytics_middleware(
    // Extract selected request headers
    record.request_headers = extract_important_headers(&headers);

+   // Note: Auth metadata headers are extracted from response headers (after endpoint handlers copy them)
+   // because the auth middleware runs after this middleware on the request path
+
    // Store analytics in request extensions
    let analytics_arc = Arc::new(tokio::sync::Mutex::new(analytics));
    request.extensions_mut().insert(analytics_arc.clone());
@@ -130,6 +133,9 @@ pub async fn analytics_middleware(
    // Extract selected response headers
    analytics.record.response_headers = extract_important_headers(response.headers());

+   // Extract auth metadata from response headers (copied by endpoint handlers)
+   extract_auth_metadata(response.headers(), &mut analytics.record);
+
    // Check if request was blocked
    if response.status() == StatusCode::FORBIDDEN
        || response.status() == StatusCode::TOO_MANY_REQUESTS
@@ -443,6 +449,55 @@ fn extract_important_headers(headers: &HeaderMap) -> HashMap<String, String> {
    result
}

+/// Extract auth metadata headers (injected into the request by the auth middleware
+/// and copied onto the response by endpoint handlers)
+fn extract_auth_metadata(headers: &HeaderMap, record: &mut GatewayAnalyticsDatabaseInsert) {
+    // Extract api_key_id
+    if let Some(api_key_id) = headers.get("x-tensorzero-api-key-id") {
+        if let Ok(id_str) = api_key_id.to_str() {
+            record.api_key_id = Some(id_str.to_string());
+            tracing::debug!("Captured api_key_id for analytics: {}", id_str);
+        }
+    }
+
+    // Extract user_id
+    if let Some(user_id) = headers.get("x-tensorzero-user-id") {
+        if let Ok(id_str) = user_id.to_str() {
+            record.user_id = Some(id_str.to_string());
+            tracing::debug!("Captured user_id for analytics: {}", id_str);
+        }
+    }
+
+    // Extract project_id
+    if let Some(project_id) = headers.get("x-tensorzero-project-id") {
+        if let Ok(id_str) = project_id.to_str() {
+            if let Ok(uuid) = Uuid::parse_str(id_str) {
+                record.project_id = Some(uuid);
+                tracing::debug!("Captured project_id for analytics: {}", id_str);
+            }
+        }
+    }
+
+    // Extract endpoint_id
+    if let Some(endpoint_id) = headers.get("x-tensorzero-endpoint-id") {
+        if let Ok(id_str) = endpoint_id.to_str() {
+            if let Ok(uuid) = Uuid::parse_str(id_str) {
+                record.endpoint_id = Some(uuid);
+                tracing::debug!("Captured endpoint_id for analytics: {}", id_str);
+            }
+        }
+    }
+
+    // Extract evaluation_id
+    if let Some(evaluation_id) = headers.get("x-tensorzero-evaluation-id") {
+        if let Ok(id_str) = evaluation_id.to_str() {
+            if let Ok(uuid) = Uuid::parse_str(id_str) {
+                record.evaluation_id = Some(uuid);
+                tracing::debug!("Captured evaluation_id for analytics: {}", id_str);
+            }
+        }
+    }
+}
+
/// Write analytics record to ClickHouse
async fn write_analytics_to_clickhouse(
    clickhouse: &ClickHouseConnectionInfo,
diff --git a/services/budgateway/tensorzero-internal/src/auth.rs b/services/budgateway/tensorzero-internal/src/auth.rs
index 2fdba9a33..0471e9a68 100644
--- a/services/budgateway/tensorzero-internal/src/auth.rs
+++ b/services/budgateway/tensorzero-internal/src/auth.rs
@@ -27,6 +27,7 @@ pub struct AuthMetadata {
    pub api_key_id: Option<String>,
    pub user_id: Option<String>,
    pub api_key_project_id: Option<String>,
+   pub evaluation_id: Option<String>,
}

pub type APIConfig = HashMap;
@@ -342,6 +343,13 @@ pub async fn require_api_key(
                    .insert("x-tensorzero-api-key-project-id", header_value);
            }
        }
+       if let Some(evaluation_id) = auth_meta.evaluation_id {
+           if let Ok(header_value) = evaluation_id.parse() {
+               request
+                   .headers_mut()
+                   .insert("x-tensorzero-evaluation-id", header_value);
+           }
+       }
    }
}
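Reviewer note: after `require_api_key` resolves the key's metadata, the request carries these internal headers, which `extract_auth_metadata` above reads back off the response; `x-tensorzero-evaluation-id` is the one added here. Clients never send these themselves (values illustrative):

```python
# Internal x-tensorzero-* headers present on the request after auth
# (values illustrative; the gateway derives them from the key's Redis metadata).
injected_headers = {
    "x-tensorzero-api-key-id": "c0ffee00-0000-4000-8000-000000000001",
    "x-tensorzero-user-id": "c0ffee00-0000-4000-8000-000000000002",
    "x-tensorzero-api-key-project-id": "c0ffee00-0000-4000-8000-000000000003",
    "x-tensorzero-evaluation-id": "c0ffee00-0000-4000-8000-000000000004",  # new
}
```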
diff --git a/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/migrations/migration_0036.rs b/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/migrations/migration_0036.rs
new file mode 100644
index 000000000..691628199
--- /dev/null
+++ b/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/migrations/migration_0036.rs
@@ -0,0 +1,123 @@
+use crate::clickhouse::migration_manager::migration_trait::Migration;
+use crate::clickhouse::ClickHouseConnectionInfo;
+use crate::error::{Error, ErrorDetails};
+use async_trait::async_trait;
+
+use super::{check_column_exists, check_table_exists};
+
+/// This migration adds an evaluation_id column to the ModelInferenceDetails and GatewayAnalytics
+/// tables to track inferences associated with specific evaluations.
+pub struct Migration0036<'a> {
+    pub clickhouse: &'a ClickHouseConnectionInfo,
+}
+
+const MIGRATION_ID: &str = "0036";
+
+#[async_trait]
+impl Migration for Migration0036<'_> {
+    async fn can_apply(&self) -> Result<(), Error> {
+        // Check if ModelInferenceDetails table exists
+        let model_inference_details_exists =
+            check_table_exists(self.clickhouse, "ModelInferenceDetails", MIGRATION_ID).await?;
+
+        if !model_inference_details_exists {
+            return Err(Error::new(ErrorDetails::ClickHouseMigration {
+                id: MIGRATION_ID.to_string(),
+                message: "ModelInferenceDetails table does not exist".to_string(),
+            }));
+        }
+
+        Ok(())
+    }
+
+    async fn should_apply(&self) -> Result<bool, Error> {
+        // Check if evaluation_id column already exists in ModelInferenceDetails
+        let evaluation_id_exists_mid = check_column_exists(
+            self.clickhouse,
+            "ModelInferenceDetails",
+            "evaluation_id",
+            MIGRATION_ID,
+        )
+        .await?;
+
+        // Check if evaluation_id column already exists in GatewayAnalytics
+        let gateway_analytics_exists =
+            check_table_exists(self.clickhouse, "GatewayAnalytics", MIGRATION_ID).await?;
+
+        let evaluation_id_exists_ga = if gateway_analytics_exists {
+            check_column_exists(
+                self.clickhouse,
+                "GatewayAnalytics",
+                "evaluation_id",
+                MIGRATION_ID,
+            )
+            .await?
+        } else {
+            true // If table doesn't exist, consider it as "already done" for this column
+        };
+
+        // Apply if any column is missing
+        Ok(!evaluation_id_exists_mid || !evaluation_id_exists_ga)
+    }
+
+    async fn apply(&self, _clean_start: bool) -> Result<(), Error> {
+        // Add evaluation_id column to ModelInferenceDetails
+        let query = r#"
+            ALTER TABLE ModelInferenceDetails
+            ADD COLUMN IF NOT EXISTS evaluation_id Nullable(UUID)
+        "#;
+        self.clickhouse
+            .run_query_synchronous(query.to_string(), None)
+            .await?;
+
+        // Add index for evaluation_id in ModelInferenceDetails
+        let index_query = r#"
+            ALTER TABLE ModelInferenceDetails
+            ADD INDEX IF NOT EXISTS idx_evaluation_id evaluation_id TYPE bloom_filter(0.01) GRANULARITY 4
+        "#;
+        self.clickhouse
+            .run_query_synchronous(index_query.to_string(), None)
+            .await?;
+
+        // Check if GatewayAnalytics table exists before adding column
+        let gateway_analytics_exists =
+            check_table_exists(self.clickhouse, "GatewayAnalytics", MIGRATION_ID).await?;
+
+        if gateway_analytics_exists {
+            // Add evaluation_id column to GatewayAnalytics
+            let ga_query = r#"
+                ALTER TABLE GatewayAnalytics
+                ADD COLUMN IF NOT EXISTS evaluation_id Nullable(UUID)
+            "#;
+            self.clickhouse
+                .run_query_synchronous(ga_query.to_string(), None)
+                .await?;
+
+            // Add index for evaluation_id in GatewayAnalytics
+            let ga_index_query = r#"
+                ALTER TABLE GatewayAnalytics
+                ADD INDEX IF NOT EXISTS idx_evaluation_id evaluation_id TYPE bloom_filter(0.01) GRANULARITY 4
+            "#;
+            self.clickhouse
+                .run_query_synchronous(ga_index_query.to_string(), None)
+                .await?;
+        }
+
+        Ok(())
+    }
+
+    fn rollback_instructions(&self) -> String {
+        r#"
+        ALTER TABLE ModelInferenceDetails DROP INDEX IF EXISTS idx_evaluation_id;
+        ALTER TABLE ModelInferenceDetails DROP COLUMN IF EXISTS evaluation_id;
+        ALTER TABLE GatewayAnalytics DROP INDEX IF EXISTS idx_evaluation_id;
+        ALTER TABLE GatewayAnalytics DROP COLUMN IF EXISTS evaluation_id;
+        "#
+        .to_string()
+    }
+
+    async fn has_succeeded(&self) -> Result<bool, Error> {
+        let should_apply = self.should_apply().await?;
+        Ok(!should_apply)
+    }
+}
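Reviewer note: a quick way to check the migration's effect from outside the gateway, mirroring its own `should_apply`/`has_succeeded` logic. A sketch assuming the clickhouse-connect client and local defaults:

```python
import clickhouse_connect  # assumed client library; any ClickHouse client works

client = clickhouse_connect.get_client(host="localhost")

# Migration 0036 is complete when evaluation_id exists on ModelInferenceDetails
# (and on GatewayAnalytics, if that table is present).
rows = client.query(
    """
    SELECT table, count() AS has_column
    FROM system.columns
    WHERE database = currentDatabase()
      AND table IN ('ModelInferenceDetails', 'GatewayAnalytics')
      AND name = 'evaluation_id'
    GROUP BY table
    """
).result_rows
print(dict(rows))  # e.g. {'ModelInferenceDetails': 1, 'GatewayAnalytics': 1}
```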
diff --git a/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/migrations/mod.rs b/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/migrations/mod.rs
index 8cf5b5a1f..5870e37c7 100644
--- a/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/migrations/mod.rs
+++ b/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/migrations/mod.rs
@@ -33,6 +33,7 @@ pub mod migration_0032;
pub mod migration_0033;
pub mod migration_0034;
pub mod migration_0035;
+pub mod migration_0036;

/// Returns true if the table exists, false if it does not
/// Errors if the query fails
diff --git a/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/mod.rs b/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/mod.rs
index 31ca705f3..11a5851d1 100644
--- a/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/mod.rs
+++ b/services/budgateway/tensorzero-internal/src/clickhouse/migration_manager/mod.rs
@@ -33,11 +33,12 @@ use migrations::migration_0032::Migration0032;
use migrations::migration_0033::Migration0033;
use migrations::migration_0034::Migration0034;
use migrations::migration_0035::Migration0035;
+use migrations::migration_0036::Migration0036;

/// This must match the number of migrations returned by `make_all_migrations` - the tests
/// will panic if they don't match.
-/// Note: We have 36 total migrations (0-35), but 7 are banned (0001, 0007, 0010, 0012, 0013, 0014, 0023)
-pub const NUM_MIGRATIONS: usize = 29;
+/// Note: We have 37 total migrations (0-36), but 7 are banned (0001, 0007, 0010, 0012, 0013, 0014, 0023)
+pub const NUM_MIGRATIONS: usize = 30;

/// Constructs (but does not run) a vector of all our database migrations.
/// This is the single source of truth for all migration - it's used during startup to migrate
@@ -88,6 +89,7 @@ pub fn make_all_migrations<'a>(
        Box::new(Migration0033 { clickhouse }),
        Box::new(Migration0034 { clickhouse }),
        Box::new(Migration0035 { clickhouse }),
+       Box::new(Migration0036 { clickhouse }),
    ];
    assert_eq!(
        migrations.len(),
diff --git a/services/budgateway/tensorzero-internal/src/clickhouse/mod.rs b/services/budgateway/tensorzero-internal/src/clickhouse/mod.rs
index 65b77c750..826edf6e5 100644
--- a/services/budgateway/tensorzero-internal/src/clickhouse/mod.rs
+++ b/services/budgateway/tensorzero-internal/src/clickhouse/mod.rs
@@ -640,6 +640,7 @@ pub struct ModelInferenceDetailsInsert {
    pub api_key_id: Option<String>,
    pub user_id: Option<String>,
    pub api_key_project_id: Option<String>,
+   pub evaluation_id: Option<String>,
    pub error_code: Option<String>,
    pub error_message: Option<String>,
    pub error_type: Option<String>,
diff --git a/services/budgateway/tensorzero-internal/src/endpoints/inference.rs b/services/budgateway/tensorzero-internal/src/endpoints/inference.rs
index 06865cd51..227316238 100644
--- a/services/budgateway/tensorzero-internal/src/endpoints/inference.rs
+++ b/services/budgateway/tensorzero-internal/src/endpoints/inference.rs
@@ -127,6 +127,8 @@ pub struct ObservabilityMetadata {
    pub api_key_id: Option<String>,
    pub user_id: Option<String>,
    pub api_key_project_id: Option<String>,
+   // Evaluation tracking metadata
+   pub evaluation_id: Option<String>,
}

#[derive(Clone, Debug)]
@@ -198,6 +200,11 @@ pub async fn inference_handler(
            .get("x-tensorzero-api-key-project-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
+       // Extract evaluation tracking metadata from headers
+       let evaluation_id = headers
+           .get("x-tensorzero-evaluation-id")
+           .and_then(|v| v.to_str().ok())
+           .map(|s| s.to_string());

        Some(ObservabilityMetadata {
            project_id: project_id.to_string(),
@@ -206,6 +213,7 @@
            api_key_id,
            user_id,
            api_key_project_id,
+           evaluation_id,
        })
    } else {
        None
@@ -320,14 +328,22 @@ pub async fn inference_handler(
                model_latency_ms.to_string().parse().unwrap(),
            );

+           // Add auth metadata headers to response for analytics middleware
+           super::add_auth_metadata_to_response(&mut http_response, &headers);
+
            Ok(http_response)
        }
        Ok(InferenceOutput::Streaming(stream)) => {
            let event_stream = prepare_serialized_events(stream);
-           Ok(Sse::new(event_stream)
+           let mut response = Sse::new(event_stream)
                .keep_alive(axum::response::sse::KeepAlive::new())
-               .into_response())
+               .into_response();
+
+           // Add auth metadata headers to response for analytics middleware
+           super::add_auth_metadata_to_response(&mut response, &headers);
+
+           Ok(response)
        }
        Err(error) => {
            // The inference function already sends failure events internally for AllVariantsFailed
@@ -346,6 +362,9 @@ pub async fn inference_handler(
                }
            }

+           // Add auth metadata headers to response for analytics middleware
+           super::add_auth_metadata_to_response(&mut error_response, &headers);
+
            Ok(error_response)
        }
    }
@@ -1367,7 +1386,7 @@ pub async fn write_inference(
    };

    // Use observability metadata if available, otherwise fall back to function/variant names
-   let (project_id, endpoint_id, obs_model_id, api_key_id, user_id, api_key_project_id) =
+   let (project_id, endpoint_id, obs_model_id, api_key_id, user_id, api_key_project_id, evaluation_id) =
        if let Some(obs_metadata) = observability_metadata {
            (
                obs_metadata.project_id.clone(),
@@ -1376,6 +1395,7 @@ pub async fn write_inference(
                obs_metadata.api_key_id.clone(),
                obs_metadata.user_id.clone(),
                obs_metadata.api_key_project_id.clone(),
+               obs_metadata.evaluation_id.clone(),
            )
        } else {
            (
@@ -1385,6 +1405,7 @@ pub async fn write_inference(
                None,
                None,
                None,
+               None,
            )
        };
@@ -1402,6 +1423,7 @@ pub async fn write_inference(
        api_key_id,
        user_id,
        api_key_project_id,
+       evaluation_id,
        error_code: None, // No error for successful inferences
        error_message: None,
        error_type: None,
@@ -1493,7 +1515,7 @@ async fn send_failure_event(
    };

    // Use observability metadata if available, otherwise use function/model names
-   let (project_id, endpoint_id, model_id, api_key_id, user_id, api_key_project_id) =
+   let (project_id, endpoint_id, model_id, api_key_id, user_id, api_key_project_id, evaluation_id) =
        if let Some(obs_metadata) = &observability_metadata {
            (
                obs_metadata.project_id.clone(),
@@ -1502,6 +1524,7 @@ async fn send_failure_event(
                obs_metadata.api_key_id.clone(),
                obs_metadata.user_id.clone(),
                obs_metadata.api_key_project_id.clone(),
+               obs_metadata.evaluation_id.clone(),
            )
        } else {
            // Fallback to function/model names
@@ -1514,6 +1537,7 @@ async fn send_failure_event(
                None,
                None,
                None,
+               None,
            )
        };
@@ -1531,6 +1555,7 @@ async fn send_failure_event(
        api_key_id: api_key_id.clone(),
        user_id: user_id.clone(),
        api_key_project_id: api_key_project_id.clone(),
+       evaluation_id: evaluation_id.clone(),
        error_code: Some(format!("{:?}", status_code)),
        error_message: Some(error_message.clone()),
        error_type: Some(error_type.to_string()),
@@ -1615,6 +1640,7 @@ async fn send_failure_event(
        api_key_id: api_key_id.and_then(|id| uuid::Uuid::parse_str(&id).ok()),
        user_id: user_id.and_then(|id| uuid::Uuid::parse_str(&id).ok()),
        api_key_project_id: parsed_api_key_project_id,
+       evaluation_id: evaluation_id.and_then(|id| uuid::Uuid::parse_str(&id).ok()),
        error_code: Some(format!("{:?}", status_code)),
        error_message: Some(error_message),
        error_type: Some(error_type.to_string()),
diff --git a/services/budgateway/tensorzero-internal/src/endpoints/mod.rs b/services/budgateway/tensorzero-internal/src/endpoints/mod.rs
index 252fbe8d9..2cd8943c3 100644
--- a/services/budgateway/tensorzero-internal/src/endpoints/mod.rs
+++ b/services/budgateway/tensorzero-internal/src/endpoints/mod.rs
@@ -1,5 +1,7 @@
use std::collections::HashMap;

+use axum::http::{HeaderMap, HeaderValue, Response};
+use axum::body::Body;
use crate::error::{Error, ErrorDetails};

pub mod batch_inference;
@@ -14,6 +16,29 @@ pub mod object_storage;
pub mod openai_compatible;
pub mod status;

+/// Add auth metadata headers from request to response for analytics tracking.
+/// This copies relevant auth headers (evaluation_id, api_key_id, user_id, project_id, endpoint_id)
+/// from the request headers to the response headers so the analytics middleware can capture them.
+pub fn add_auth_metadata_to_response(response: &mut Response<Body>, request_headers: &HeaderMap) {
+    // List of auth metadata headers to copy from request to response
+    let auth_headers = [
+        "x-tensorzero-evaluation-id",
+        "x-tensorzero-api-key-id",
+        "x-tensorzero-user-id",
+        "x-tensorzero-project-id",
+        "x-tensorzero-endpoint-id",
+        "x-tensorzero-api-key-project-id",
+    ];
+
+    for header_name in &auth_headers {
+        if let Some(header_value) = request_headers.get(*header_name) {
+            if let Ok(cloned_value) = HeaderValue::from_bytes(header_value.as_bytes()) {
+                response.headers_mut().insert(*header_name, cloned_value);
+            }
+        }
+    }
+}
+
pub fn validate_tags(tags: &HashMap<String, String>, internal: bool) -> Result<(), Error> {
    if internal {
        return Ok(());
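Reviewer note: a side effect worth knowing is that these mirrored headers are visible to clients on the HTTP response (assuming nothing downstream strips `x-tensorzero-*`), which makes the tagging easy to verify end to end. A sketch with an illustrative gateway URL and model name:

```python
import requests

resp = requests.post(
    "http://localhost:3000/v1/chat/completions",  # illustrative gateway URL
    headers={"Authorization": "Bearer <temporary-evaluation-key>"},
    json={"model": "my-eval-model", "messages": [{"role": "user", "content": "ping"}]},
)
resp.raise_for_status()
# Mirrored by add_auth_metadata_to_response; None if a proxy strips the header.
print(resp.headers.get("x-tensorzero-evaluation-id"))
```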
diff --git a/services/budgateway/tensorzero-internal/src/endpoints/openai_compatible.rs b/services/budgateway/tensorzero-internal/src/endpoints/openai_compatible.rs
index 2d4e54072..a23be71c0 100644
--- a/services/budgateway/tensorzero-internal/src/endpoints/openai_compatible.rs
+++ b/services/budgateway/tensorzero-internal/src/endpoints/openai_compatible.rs
@@ -284,6 +284,10 @@ pub async fn inference_handler(
            .get("x-tensorzero-api-key-project-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
+       let evaluation_id = headers
+           .get("x-tensorzero-evaluation-id")
+           .and_then(|v| v.to_str().ok())
+           .map(|s| s.to_string());

        params.observability_metadata = Some(super::inference::ObservabilityMetadata {
            project_id: project_id.to_string(),
@@ -292,6 +296,7 @@ pub async fn inference_handler(
            api_key_id,
            user_id,
            api_key_project_id,
+           evaluation_id,
        });
    }
@@ -1006,6 +1011,9 @@
                model_latency_ms.to_string().parse().unwrap(),
            );

+           // Add auth metadata headers to response for analytics middleware
+           super::add_auth_metadata_to_response(&mut http_response, &headers);
+
            Ok(http_response)
        }
        InferenceOutput::Streaming(mut stream) => {
@@ -1063,6 +1071,9 @@
                );
            }

+           // Add auth metadata headers to response for analytics middleware
+           super::add_auth_metadata_to_response(&mut response, &headers);
+
            Ok(response)
        }
    }
@@ -3094,6 +3105,10 @@ pub async fn embedding_handler(
            .get("x-tensorzero-api-key-project-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
+       let evaluation_id = headers
+           .get("x-tensorzero-evaluation-id")
+           .and_then(|v| v.to_str().ok())
+           .map(|s| s.to_string());

        Some(super::inference::ObservabilityMetadata {
            project_id: project_id.to_string(),
@@ -3102,6 +3117,7 @@ pub async fn embedding_handler(
            api_key_id,
            user_id,
            api_key_project_id,
+           evaluation_id,
        })
    } else {
        None
@@ -3682,6 +3698,10 @@ pub async fn moderation_handler(
            .get("x-tensorzero-api-key-project-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
+       let evaluation_id = headers
+           .get("x-tensorzero-evaluation-id")
+           .and_then(|v| v.to_str().ok())
+           .map(|s| s.to_string());

        Some(super::inference::ObservabilityMetadata {
            project_id: project_id.to_string(),
@@ -3690,6 +3710,7 @@ pub async fn moderation_handler(
            api_key_id,
            user_id,
            api_key_project_id,
+           evaluation_id,
        })
    } else {
        None
@@ -4987,6 +5008,10 @@ pub async fn image_generation_handler(
            .get("x-tensorzero-api-key-project-id")
            .and_then(|v| v.to_str().ok())
            .map(|s| s.to_string());
+       let evaluation_id = headers
+           .get("x-tensorzero-evaluation-id")
+           .and_then(|v| v.to_str().ok())
+           .map(|s| s.to_string());

        Some(super::inference::ObservabilityMetadata {
            project_id: project_id.to_string(),
@@ -4995,6 +5020,7 @@ pub async fn image_generation_handler(
            api_key_id,
            user_id,
            api_key_project_id,
+           evaluation_id,
        })
    } else {
        None
diff --git a/services/budgateway/tensorzero-internal/src/kafka/buffer.rs b/services/budgateway/tensorzero-internal/src/kafka/buffer.rs
index d1c1fa748..3e3722d0f 100644
--- a/services/budgateway/tensorzero-internal/src/kafka/buffer.rs
+++ b/services/budgateway/tensorzero-internal/src/kafka/buffer.rs
@@ -151,6 +151,7 @@ mod tests {
            api_key_id: None,
            api_key_project_id: None,
            user_id: None,
+           evaluation_id: None,
            error_code: None,
            error_message: None,
            error_type: None,
diff --git a/services/budgateway/tensorzero-internal/src/kafka/cloudevents.rs b/services/budgateway/tensorzero-internal/src/kafka/cloudevents.rs
index 8214c7a43..adf9cd286 100644
--- a/services/budgateway/tensorzero-internal/src/kafka/cloudevents.rs
+++ b/services/budgateway/tensorzero-internal/src/kafka/cloudevents.rs
@@ -64,6 +64,10 @@ pub struct ObservabilityEvent {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key_project_id: Option<Uuid>,

+   // Evaluation tracking metadata
+   #[serde(skip_serializing_if = "Option::is_none")]
+   pub evaluation_id: Option<Uuid>,
+
    // Error information for failed inferences
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_code: Option<String>,
@@ -168,6 +172,7 @@ mod tests {
            api_key_id: None,
            user_id: None,
            api_key_project_id: None,
+           evaluation_id: None,
            error_code: None,
            error_message: None,
            error_type: None,
@@ -200,6 +205,7 @@ mod tests {
            api_key_id: None,
            user_id: None,
            api_key_project_id: None,
+           evaluation_id: None,
            error_code: None,
            error_message: None,
            error_type: None,
diff --git a/services/budgateway/tensorzero-internal/src/kafka/tests.rs b/services/budgateway/tensorzero-internal/src/kafka/tests.rs
index f8072afde..e544b4533 100644
--- a/services/budgateway/tensorzero-internal/src/kafka/tests.rs
+++ b/services/budgateway/tensorzero-internal/src/kafka/tests.rs
@@ -26,6 +26,7 @@ async fn test_mock_kafka_observability_event() {
        api_key_id: None,
        api_key_project_id: None,
        user_id: None,
+       evaluation_id: None,
        error_code: None,
        error_message: None,
        error_type: None,
@@ -82,6 +83,7 @@ async fn test_mock_kafka_buffer_and_batch() {
        api_key_id: None,
        api_key_project_id: None,
        user_id: None,
+       evaluation_id: None,
        error_code: None,
        error_message: None,
        error_type: None,
@@ -125,6 +127,7 @@ async fn test_kafka_disabled() {
        api_key_id: None,
        api_key_project_id: None,
        user_id: None,
+       evaluation_id: None,
        error_code: None,
        error_message: None,
        error_type: None,
@@ -158,6 +161,7 @@ async fn test_kafka_event_validation() {
        api_key_id: None,
        api_key_project_id: None,
        user_id: None,
+       evaluation_id: None,
        error_code: None,
        error_message: None,
        error_type: None,
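Reviewer note: on the Kafka path, `evaluation_id` serializes only when set (`skip_serializing_if = "Option::is_none"`). An abridged sketch of an event payload for an evaluation-tagged failure, limited to fields visible in this diff (values illustrative):

```python
import json

# Abridged ObservabilityEvent fragment; successful inferences omit the error_*
# fields entirely, and untagged requests omit evaluation_id.
event_fragment = {
    "api_key_project_id": "c0ffee00-0000-4000-8000-000000000003",
    "evaluation_id": "c0ffee00-0000-4000-8000-000000000004",
    "error_code": "503",
    "error_message": "upstream model unavailable",
    "error_type": "provider_error",
}
print(json.dumps(event_fragment, indent=2))
```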
diff --git a/services/budmetrics/scripts/migrate_clickhouse.py b/services/budmetrics/scripts/migrate_clickhouse.py
index fcc313b29..ce3b04074 100755
--- a/services/budmetrics/scripts/migrate_clickhouse.py
+++ b/services/budmetrics/scripts/migrate_clickhouse.py
@@ -205,6 +205,7 @@ async def create_model_inference_details_table(self):
            api_key_id Nullable(UUID),
            user_id Nullable(UUID),
            api_key_project_id Nullable(UUID),
+           evaluation_id Nullable(UUID),
            error_code Nullable(String),
            error_message Nullable(String),
            error_type Nullable(String),
@@ -230,6 +231,7 @@ async def create_model_inference_details_table(self):
            "ALTER TABLE ModelInferenceDetails ADD INDEX IF NOT EXISTS idx_api_key_id (api_key_id) TYPE minmax GRANULARITY 1",
            "ALTER TABLE ModelInferenceDetails ADD INDEX IF NOT EXISTS idx_user_id (user_id) TYPE minmax GRANULARITY 1",
            "ALTER TABLE ModelInferenceDetails ADD INDEX IF NOT EXISTS idx_api_key_project_id (api_key_project_id) TYPE minmax GRANULARITY 1",
+           "ALTER TABLE ModelInferenceDetails ADD INDEX IF NOT EXISTS idx_evaluation_id (evaluation_id) TYPE bloom_filter(0.01) GRANULARITY 4",
            "ALTER TABLE ModelInferenceDetails ADD INDEX IF NOT EXISTS idx_error_type (error_type) TYPE minmax GRANULARITY 1",
            "ALTER TABLE ModelInferenceDetails ADD INDEX IF NOT EXISTS idx_status_code (status_code) TYPE minmax GRANULARITY 1",
        ]
@@ -1036,6 +1038,64 @@ async def add_error_tracking_columns(self):

        logger.info("Error tracking columns migration completed successfully")

+    async def add_evaluation_id_column(self):
+        """Add evaluation_id column to ModelInferenceDetails table for evaluation tracking.
+
+        This migration adds an evaluation_id column to track inferences associated with specific evaluations.
+        """
+        logger.info("Adding evaluation_id column to ModelInferenceDetails table...")
+
+        # Check if the table exists first
+        try:
+            table_exists = await self.client.execute_query("EXISTS TABLE ModelInferenceDetails")
+            if not table_exists or not table_exists[0][0]:
+                logger.warning("ModelInferenceDetails table does not exist. Skipping evaluation_id migration.")
+                return
+        except Exception as e:
+            logger.error(f"Error checking if ModelInferenceDetails table exists: {e}")
+            return
+
+        # Check if evaluation_id column already exists
+        try:
+            check_column_query = """
+                SELECT COUNT(*)
+                FROM system.columns
+                WHERE table = 'ModelInferenceDetails'
+                AND database = currentDatabase()
+                AND name = 'evaluation_id'
+            """
+            result = await self.client.execute_query(check_column_query)
+            column_exists = result[0][0] > 0 if result else False
+
+            if not column_exists:
+                # Add the column
+                alter_query = """
+                    ALTER TABLE ModelInferenceDetails
+                    ADD COLUMN IF NOT EXISTS evaluation_id Nullable(UUID)
+                """
+                await self.client.execute_query(alter_query)
+                logger.info("Added evaluation_id column to ModelInferenceDetails table")
+            else:
+                logger.info("evaluation_id column already exists in ModelInferenceDetails table")
+
+        except Exception as e:
+            if "already exists" in str(e).lower():
+                logger.info("evaluation_id column already exists")
+            else:
+                logger.error(f"Error adding evaluation_id column: {e}")
+                return
+
+        # Add index for evaluation_id
+        try:
+            index_query = "ALTER TABLE ModelInferenceDetails ADD INDEX IF NOT EXISTS idx_evaluation_id (evaluation_id) TYPE bloom_filter(0.01) GRANULARITY 4"
+            await self.client.execute_query(index_query)
+            logger.info("Index idx_evaluation_id created or already exists")
+        except Exception as e:
+            if "already exists" not in str(e).lower():
+                logger.warning(f"Index creation warning for idx_evaluation_id: {e}")
+
+        logger.info("Evaluation ID column migration completed successfully")
+
    async def setup_cluster_metrics_materialized_views(self):
        """Set up materialized views for cluster metrics.
@@ -1183,6 +1243,7 @@ async def run_migration(self):
        await self.add_auth_metadata_columns()  # Add auth metadata columns migration
        await self.update_api_key_project_id()  # Update api_key_project_id where null
        await self.add_error_tracking_columns()  # Add error tracking columns for failed inferences
+       await self.add_evaluation_id_column()  # Add evaluation_id column for evaluation tracking
        await self.verify_tables()

        logger.info("Migration completed successfully!")
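Reviewer note: with the column and bloom-filter index in place on both stores, per-evaluation metrics can be read straight off ModelInferenceDetails. A sketch using only columns that appear in this migration script (token-usage columns, if present in your schema, would be summed the same way):

```python
import clickhouse_connect

client = clickhouse_connect.get_client(host="localhost")

# Per-evaluation request and failure counts; the bloom_filter index on
# evaluation_id keeps the filter cheap.
rows = client.query(
    """
    SELECT
        evaluation_id,
        count() AS requests,
        countIf(error_code IS NOT NULL) AS failed_requests
    FROM ModelInferenceDetails
    WHERE evaluation_id IS NOT NULL
    GROUP BY evaluation_id
    """
).result_rows
for evaluation_id, requests, failed_requests in rows:
    print(evaluation_id, requests, failed_requests)
```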