Skip to content

chore(trace ai queries): Parallelize attribute values requests #95013

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 51 additions & 68 deletions src/sentry/api/endpoints/seer_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import hmac
import logging
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any

import sentry_sdk
Expand Down Expand Up @@ -280,68 +281,6 @@ def get_attribute_names(*, org_id: int, project_ids: list[int], stats_period: st
return {"fields": fields}


def get_attribute_values(
    *,
    fields: list[str],
    org_id: int,
    project_ids: list[int],
    stats_period: str,
    limit: int = 100,
    sampled: bool = True,
) -> dict:
    """Return observed values for each string-typed span attribute.

    For every field in ``fields`` that resolves to a string attribute, issues a
    ``TraceItemAttributeValuesRequest`` over the given projects and time window
    and collects up to ``limit`` values per field.

    Args:
        fields: Attribute names to look up (non-string attributes are skipped).
        org_id: Organization the query is scoped to.
        project_ids: Projects the query is scoped to.
        stats_period: Relative period string (e.g. "7d"); falls back to 7 days
            when it cannot be parsed.
        limit: Maximum number of values returned per field.
        sampled: When True use normal (downsampled) storage; when False request
            highest-accuracy storage.

    Returns:
        ``{"values": {field: [value, ...]}}`` for the string-typed fields.
    """
    window = parse_stats_period(stats_period)
    if window is None:
        window = datetime.timedelta(days=7)

    end = datetime.datetime.now()
    start = end - window

    start_time_proto = ProtobufTimestamp()
    start_time_proto.FromDatetime(start)
    end_time_proto = ProtobufTimestamp()
    end_time_proto.FromDatetime(end)

    # Trade accuracy for speed unless the caller explicitly opts out.
    if sampled:
        sampling_mode = DownsampledStorageConfig.MODE_NORMAL
    else:
        sampling_mode = DownsampledStorageConfig.MODE_HIGHEST_ACCURACY

    resolver = SearchResolver(
        params=SnubaParams(
            start=start,
            end=end,
        ),
        config=SearchResolverConfig(),
        definitions=SPAN_DEFINITIONS,
    )

    values: dict = {}
    for field in fields:
        resolved_field, _ = resolver.resolve_attribute(field)
        # Only string attributes have enumerable values; skip everything else.
        if resolved_field.proto_definition.type != AttributeKey.Type.TYPE_STRING:
            continue

        req = TraceItemAttributeValuesRequest(
            meta=RequestMeta(
                organization_id=org_id,
                cogs_category="events_analytics_platform",
                referrer=Referrer.SEER_RPC.value,
                project_ids=project_ids,
                start_timestamp=start_time_proto,
                end_timestamp=end_time_proto,
                trace_item_type=TraceItemType.TRACE_ITEM_TYPE_SPAN,
                downsampled_storage_config=DownsampledStorageConfig(mode=sampling_mode),
            ),
            key=resolved_field.proto_definition,
            limit=limit,
        )

        response = snuba_rpc.attribute_values_rpc(req)
        values[field] = list(response.values)

    return {"values": values}


def get_attribute_values_with_substring(
*,
org_id: int,
Expand All @@ -356,9 +295,14 @@ def get_attribute_values_with_substring(
Note: The RPC is guaranteed to not return duplicate values for the same field.
ie: if span.description is requested with both null and "payment" substrings,
the RPC will return the set of values for span.description to avoid duplicates.

TODO: Replace with batch attribute values RPC once available
"""
values: dict[str, set[str]] = {}

if not fields_with_substrings:
return {"values": values}

period = parse_stats_period(stats_period)
if period is None:
period = datetime.timedelta(days=7)
Expand Down Expand Up @@ -386,7 +330,10 @@ def get_attribute_values_with_substring(
definitions=SPAN_DEFINITIONS,
)

for field_with_substring in fields_with_substrings:
def process_field_with_substring(
field_with_substring: dict[str, str],
) -> tuple[str, set[str]] | None:
"""Helper function to process a single field_with_substring request."""
field = field_with_substring["field"]
substring = field_with_substring["substring"]

Expand All @@ -409,10 +356,47 @@ def get_attribute_values_with_substring(
)

values_response = snuba_rpc.attribute_values_rpc(req)
if field in values:
values[field].update({value for value in values_response.values if value})
else:
values[field] = {value for value in values_response.values if value}
return field, {value for value in values_response.values if value}
return None

timeout_seconds = 1.0

with ThreadPoolExecutor(max_workers=min(len(fields_with_substrings), 10)) as executor:
future_to_field = {
executor.submit(
process_field_with_substring, field_with_substring
): field_with_substring
for field_with_substring in fields_with_substrings
}

try:
for future in as_completed(future_to_field, timeout=timeout_seconds):
field_with_substring = future_to_field[future]

try:
result = future.result()
if result is not None:
field, field_values = result
if field in values:
values[field].update(field_values)
else:
values[field] = field_values
except TimeoutError:
logger.warning(
"RPC call timed out after %s seconds for field %s, skipping",
timeout_seconds,
field_with_substring.get("field", "unknown"),
)
except Exception as e:
logger.warning(
"RPC call failed for field %s: %s",
field_with_substring.get("field", "unknown"),
str(e),
)
except TimeoutError:
for future in future_to_field:
future.cancel()
logger.warning("Overall timeout exceeded, cancelled remaining RPC calls")

return {"values": values}

Expand Down Expand Up @@ -496,7 +480,6 @@ def get_attributes_and_values(
"get_error_event_details": get_error_event_details,
"get_profile_details": get_profile_details,
"get_attribute_names": get_attribute_names,
"get_attribute_values": get_attribute_values,
"get_attribute_values_with_substring": get_attribute_values_with_substring,
"get_attributes_and_values": get_attributes_and_values,
}
Expand Down
213 changes: 166 additions & 47 deletions tests/snuba/api/endpoints/test_seer_attributes.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from concurrent.futures import TimeoutError
from unittest.mock import Mock, patch
from uuid import uuid4

from sentry.api.endpoints.seer_rpc import (
get_attribute_names,
get_attribute_values,
get_attribute_values_with_substring,
get_attributes_and_values,
)
Expand Down Expand Up @@ -47,52 +48,6 @@ def test_get_attribute_names(self):
},
}

def test_get_attribute_values(self):
    """Store three segments and verify each string attribute reports its
    distinct values, with the synthetic ``project`` field staying empty."""
    for name in ("foo", "bar", "baz"):
        self.store_segment(
            self.project.id,
            uuid4().hex,
            uuid4().hex,
            span_id=uuid4().hex[:16],
            organization_id=self.organization.id,
            parent_span_id=None,
            timestamp=before_now(days=0, minutes=10).replace(microsecond=0),
            transaction=name,
            duration=100,
            exclusive_time=100,
            is_eap=True,
        )

    attribute_names = get_attribute_names(
        org_id=self.organization.id,
        project_ids=[self.project.id],
        stats_period="7d",
    )

    result = get_attribute_values(
        fields=attribute_names["fields"]["string"],
        org_id=self.organization.id,
        project_ids=[self.project.id],
        stats_period="7d",
        sampled=False,
    )

    # Values come back in sorted order regardless of insertion order above.
    expected_values = ["bar", "baz", "foo"]
    assert result == {
        "values": {
            "span.description": expected_values,
            "transaction": expected_values,
            "project": [],
        }
    }

def test_get_attribute_values_with_substring(self):
for transaction in ["foo", "bar", "baz"]:
self.store_segment(
Expand Down Expand Up @@ -185,3 +140,167 @@ def test_get_attributes_and_values(self):
},
]
}

def test_get_attribute_values_with_substring_empty_field_list(self):
    """An empty fields_with_substrings input short-circuits to an empty values map."""
    expected: dict = {"values": {}}

    assert (
        get_attribute_values_with_substring(
            org_id=self.organization.id,
            project_ids=[self.project.id],
            stats_period="7d",
            fields_with_substrings=[],
        )
        == expected
    )

def test_get_attribute_values_with_substring_async_success_and_partial_failures(self):
"""Test concurrent execution with successful results, timeouts, and exceptions"""
# Seed real segments so the endpoint under test has data to resolve against.
for transaction in ["foo", "bar"]:
self.store_segment(
self.project.id,
uuid4().hex,
uuid4().hex,
span_id=uuid4().hex[:16],
organization_id=self.organization.id,
parent_span_id=None,
timestamp=before_now(days=0, minutes=10).replace(microsecond=0),
transaction=transaction,
duration=100,
exclusive_time=100,
is_eap=True,
)

# Patch the executor inside seer_rpc so no real worker threads are spawned.
with patch("sentry.api.endpoints.seer_rpc.ThreadPoolExecutor") as mock_executor:
mock_executor_instance = Mock()
mock_executor.return_value.__enter__.return_value = mock_executor_instance

# Three canned futures model the three per-RPC outcomes:
# success, per-future timeout, and an arbitrary exception.
mock_future_success = Mock()
mock_future_timeout = Mock()
mock_future_exception = Mock()

mock_future_success.result.return_value = ("transaction", {"foo", "bar"})
mock_future_timeout.result.side_effect = TimeoutError("Individual timeout")
mock_future_exception.result.side_effect = Exception("RPC failed")

# submit() returns the canned futures in order, one per field below.
mock_executor_instance.submit.side_effect = [
mock_future_success,
mock_future_timeout,
mock_future_exception,
]

fields_with_substrings = [
{"field": "transaction", "substring": "fo"},
{"field": "span.description", "substring": "timeout_field"},
{"field": "span.status", "substring": "error_field"},
]

# as_completed is patched too, so completion order is deterministic.
with patch("sentry.api.endpoints.seer_rpc.as_completed") as mock_as_completed:

def as_completed_side_effect(future_to_field_dict, timeout):
return [mock_future_success, mock_future_timeout, mock_future_exception]

mock_as_completed.side_effect = as_completed_side_effect

result = get_attribute_values_with_substring(
org_id=self.organization.id,
project_ids=[self.project.id],
stats_period="7d",
fields_with_substrings=fields_with_substrings,
sampled=False,
)

# Only the successful future contributes values; the timed-out and
# failed futures are logged and skipped without aborting the request.
assert result == {
"values": {
"transaction": {"foo", "bar"},
}
}

# All three fields were submitted and as_completed drove the loop once.
assert mock_executor_instance.submit.call_count == 3
mock_as_completed.assert_called_once()

def test_get_attribute_values_with_substring_overall_timeout(self):
"""Test overall timeout handling with future cancellation"""
# One real segment so the endpoint has data to resolve field names against.
self.store_segment(
self.project.id,
uuid4().hex,
uuid4().hex,
span_id=uuid4().hex[:16],
organization_id=self.organization.id,
parent_span_id=None,
timestamp=before_now(days=0, minutes=10).replace(microsecond=0),
transaction="foo",
duration=100,
exclusive_time=100,
is_eap=True,
)

# Simulate the whole as_completed() iteration timing out, not a single RPC.
with patch("sentry.api.endpoints.seer_rpc.as_completed") as mock_as_completed:
mock_as_completed.side_effect = TimeoutError("Overall timeout")

with patch("sentry.api.endpoints.seer_rpc.ThreadPoolExecutor") as mock_executor:
mock_executor_instance = Mock()
mock_executor.return_value.__enter__.return_value = mock_executor_instance

# Two pending futures, one per submitted field below.
mock_future1 = Mock()
mock_future2 = Mock()
mock_executor_instance.submit.side_effect = [mock_future1, mock_future2]

result = get_attribute_values_with_substring(
org_id=self.organization.id,
project_ids=[self.project.id],
stats_period="7d",
fields_with_substrings=[
{"field": "transaction", "substring": "fo"},
{"field": "span.description", "substring": "desc"},
],
sampled=False,
)

# Overall timeout yields an empty result rather than raising.
assert result == {"values": {}}

# Every outstanding future must be cancelled exactly once on timeout.
mock_future1.cancel.assert_called_once()
mock_future2.cancel.assert_called_once()

def test_get_attribute_values_with_substring_max_workers_limit(self):
"""Test that ThreadPoolExecutor is limited to max 10 workers even with more fields"""
# One real segment so the endpoint has data to resolve field names against.
self.store_segment(
self.project.id,
uuid4().hex,
uuid4().hex,
span_id=uuid4().hex[:16],
organization_id=self.organization.id,
parent_span_id=None,
timestamp=before_now(days=0, minutes=10).replace(microsecond=0),
transaction="foo",
duration=100,
exclusive_time=100,
is_eap=True,
)

# 15 fields: more than the expected worker cap of 10.
fields_with_substrings = [
{"field": "transaction", "substring": f"field_{i}"} for i in range(15)
]

with patch("sentry.api.endpoints.seer_rpc.ThreadPoolExecutor") as mock_executor:
mock_executor_instance = Mock()
mock_executor.return_value.__enter__.return_value = mock_executor_instance

# One successful canned future per field; results themselves are not
# asserted here — only the pool sizing and submission count matter.
mock_futures = [Mock() for _ in range(15)]
for i, future in enumerate(mock_futures):
future.result.return_value = (f"transaction_{i}", {f"value_{i}"})

mock_executor_instance.submit.side_effect = mock_futures

with patch("sentry.api.endpoints.seer_rpc.as_completed") as mock_as_completed:
mock_as_completed.return_value = mock_futures

get_attribute_values_with_substring(
org_id=self.organization.id,
project_ids=[self.project.id],
stats_period="7d",
fields_with_substrings=fields_with_substrings,
sampled=False,
)

# The pool is capped at 10 workers even though all 15 fields are submitted.
mock_executor.assert_called_once_with(max_workers=10)
assert mock_executor_instance.submit.call_count == 15
Loading