Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions hindsight-api/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,13 @@ class RecallResponse(BaseModel):
chunks: dict[str, ChunkData] | None = Field(default=None, description="Chunks for facts, keyed by chunk_id")


class EntityInput(BaseModel):
    """Entity to associate with retained content."""

    # NOTE: the class docstring and the Field descriptions below are kept
    # verbatim on purpose -- pydantic surfaces them in the generated
    # JSON/OpenAPI schema, so rewording them changes the published API docs.

    # Required: the entity's name/text as supplied by the caller.
    text: str = Field(
        description="The entity name/text",
    )

    # Optional: may be omitted or null. This model applies no fallback itself;
    # in this file, api_retain substitutes "CONCEPT" when the value is unset.
    type: str | None = Field(
        description="Optional entity type (e.g., 'PERSON', 'ORG', 'CONCEPT')",
        default=None,
    )


class MemoryItem(BaseModel):
"""Single memory item for retain."""

Expand All @@ -292,6 +299,7 @@ class MemoryItem(BaseModel):
"context": "team meeting",
"metadata": {"source": "slack", "channel": "engineering"},
"document_id": "meeting_notes_2024_01_15",
"entities": [{"text": "Alice"}, {"text": "ML model", "type": "CONCEPT"}],
}
},
)
Expand All @@ -301,6 +309,10 @@ class MemoryItem(BaseModel):
context: str | None = None
metadata: dict[str, str] | None = None
document_id: str | None = Field(default=None, description="Optional document ID for this memory item.")
entities: list[EntityInput] | None = Field(
default=None,
description="Optional entities to combine with auto-extracted entities.",
)

@field_validator("timestamp", mode="before")
@classmethod
Expand Down Expand Up @@ -1986,6 +1998,10 @@ async def api_retain(
content_dict["metadata"] = item.metadata
if item.document_id:
content_dict["document_id"] = item.document_id
if item.entities:
content_dict["entities"] = [
{"text": e.text, "type": e.type or "CONCEPT"} for e in item.entities
]
contents.append(content_dict)

if request.async_:
Expand Down
43 changes: 35 additions & 8 deletions hindsight-api/hindsight_api/engine/retain/entity_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@


async def process_entities_batch(
entity_resolver, conn, bank_id: str, unit_ids: list[str], facts: list[ProcessedFact], log_buffer: list[str] = None
entity_resolver,
conn,
bank_id: str,
unit_ids: list[str],
facts: list[ProcessedFact],
log_buffer: list[str] = None,
user_entities_per_content: dict[int, list[dict]] = None,
) -> list[EntityLink]:
"""
Process entities for all facts and create entity links.

This function:
1. Extracts entity mentions from fact texts
2. Resolves entity names to canonical entities
3. Creates entity records in the database
4. Returns entity links ready for insertion
2. Merges user-provided entities with LLM-extracted entities
3. Resolves entity names to canonical entities
4. Creates entity records in the database
5. Returns entity links ready for insertion

Args:
entity_resolver: EntityResolver instance for entity resolution
Expand All @@ -31,6 +38,7 @@ async def process_entities_batch(
unit_ids: List of unit IDs (same length as facts)
facts: List of ProcessedFact objects
log_buffer: Optional buffer for detailed logging
user_entities_per_content: Dict mapping content_index to list of user-provided entities

Returns:
List of EntityLink objects for batch insertion
Expand All @@ -41,14 +49,33 @@ async def process_entities_batch(
if len(unit_ids) != len(facts):
raise ValueError(f"Mismatch between unit_ids ({len(unit_ids)}) and facts ({len(facts)})")

user_entities_per_content = user_entities_per_content or {}

# Extract data for link_utils function
fact_texts = [fact.fact_text for fact in facts]
# Use occurred_start if available, otherwise use mentioned_at for entity timestamps
fact_dates = [fact.occurred_start if fact.occurred_start is not None else fact.mentioned_at for fact in facts]
# Convert EntityRef objects to dict format expected by link_utils
entities_per_fact = [
[{"text": entity.name, "type": "CONCEPT"} for entity in (fact.entities or [])] for fact in facts
]

# Convert EntityRef objects to dict format and merge with user-provided entities
entities_per_fact = []
for fact in facts:
# Start with LLM-extracted entities
llm_entities = [{"text": entity.name, "type": "CONCEPT"} for entity in (fact.entities or [])]

# Get user entities for this content (use content_index from fact)
user_entities = user_entities_per_content.get(fact.content_index, [])

# Merge with case-insensitive deduplication
seen_texts = {e["text"].lower() for e in llm_entities}
for user_entity in user_entities:
if user_entity["text"].lower() not in seen_texts:
llm_entities.append({
"text": user_entity["text"],
"type": user_entity.get("type", "CONCEPT"),
})
seen_texts.add(user_entity["text"].lower())

entities_per_fact.append(llm_entities)

# Use existing link_utils function for entity processing
entity_links = await link_utils.extract_entities_batch_optimized(
Expand Down
13 changes: 12 additions & 1 deletion hindsight-api/hindsight_api/engine/retain/orchestrator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ async def retain_batch(
context=item.get("context", ""),
event_date=item.get("event_date") or utcnow(),
metadata=item.get("metadata", {}),
entities=item.get("entities", []),
)
contents.append(content)

Expand Down Expand Up @@ -352,8 +353,18 @@ async def retain_batch(

# Process entities
step_start = time.time()
# Build map of content_index -> user entities for merging
user_entities_per_content = {
idx: content.entities for idx, content in enumerate(contents) if content.entities
}
entity_links = await entity_processing.process_entities_batch(
entity_resolver, conn, bank_id, unit_ids, non_duplicate_facts, log_buffer
entity_resolver,
conn,
bank_id,
unit_ids,
non_duplicate_facts,
log_buffer,
user_entities_per_content=user_entities_per_content,
)
log_buffer.append(f"[6] Process entities: {len(entity_links)} links in {time.time() - step_start:.3f}s")

Expand Down
7 changes: 7 additions & 0 deletions hindsight-api/hindsight_api/engine/retain/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,15 @@ class RetainContentDict(TypedDict, total=False):
event_date: When the content occurred (optional, defaults to now)
metadata: Custom key-value metadata (optional)
document_id: Document ID for this content item (optional)
entities: User-provided entities to merge with extracted entities (optional)
"""

content: str # Required
context: str
event_date: datetime
metadata: dict[str, str]
document_id: str
entities: list[dict[str, str]] # [{"text": "...", "type": "..."}]


def _now_utc() -> datetime:
Expand All @@ -46,6 +48,7 @@ class RetainContent:
context: str = ""
event_date: datetime = field(default_factory=_now_utc)
metadata: dict[str, str] = field(default_factory=dict)
entities: list[dict[str, str]] = field(default_factory=list) # User-provided entities


@dataclass
Expand Down Expand Up @@ -152,6 +155,9 @@ class ProcessedFact:
# DB fields (set after insertion)
unit_id: UUID | None = None

# Track which content this fact came from (for user entity merging)
content_index: int = 0

@property
def is_duplicate(self) -> bool:
"""Check if this fact was marked as a duplicate."""
Expand Down Expand Up @@ -194,6 +200,7 @@ def from_extracted_fact(
entities=entities,
causal_relations=extracted_fact.causal_relations,
chunk_id=chunk_id,
content_index=extracted_fact.content_index,
)


Expand Down
72 changes: 48 additions & 24 deletions hindsight-clients/python/hindsight_client/hindsight_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ def retain(
context: Optional[str] = None,
document_id: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
entities: Optional[List[Dict[str, str]]] = None,
) -> RetainResponse:
"""
Store a single memory (simplified interface).
Expand All @@ -123,13 +124,14 @@ def retain(
context: Optional context description
document_id: Optional document ID for grouping
metadata: Optional user-defined metadata
entities: Optional list of entities [{"text": "...", "type": "..."}]

Returns:
RetainResponse with success status
"""
return self.retain_batch(
bank_id=bank_id,
items=[{"content": content, "timestamp": timestamp, "context": context, "metadata": metadata}],
items=[{"content": content, "timestamp": timestamp, "context": context, "metadata": metadata, "entities": entities}],
document_id=document_id,
)

Expand All @@ -145,24 +147,34 @@ def retain_batch(

Args:
bank_id: The memory bank ID
items: List of memory items with 'content' and optional 'timestamp', 'context', 'metadata', 'document_id'
items: List of memory items with 'content' and optional 'timestamp', 'context', 'metadata', 'document_id', 'entities'
document_id: Optional document ID for grouping memories (applied to items that don't have their own)
retain_async: If True, process asynchronously in background (default: False)

Returns:
RetainResponse with success status and item count
"""
memory_items = [
memory_item.MemoryItem(
content=item["content"],
timestamp=item.get("timestamp"),
context=item.get("context"),
metadata=item.get("metadata"),
# Use item's document_id if provided, otherwise fall back to batch-level document_id
document_id=item.get("document_id") or document_id,
from hindsight_client_api.models.entity_input import EntityInput

memory_items = []
for item in items:
entities = None
if item.get("entities"):
entities = [
EntityInput(text=e["text"], type=e.get("type"))
for e in item["entities"]
]
memory_items.append(
memory_item.MemoryItem(
content=item["content"],
timestamp=item.get("timestamp"),
context=item.get("context"),
metadata=item.get("metadata"),
# Use item's document_id if provided, otherwise fall back to batch-level document_id
document_id=item.get("document_id") or document_id,
entities=entities,
)
)
for item in items
]

request_obj = retain_request.RetainRequest(
items=memory_items,
Expand Down Expand Up @@ -312,24 +324,34 @@ async def aretain_batch(

Args:
bank_id: The memory bank ID
items: List of memory items with 'content' and optional 'timestamp', 'context', 'metadata', 'document_id'
items: List of memory items with 'content' and optional 'timestamp', 'context', 'metadata', 'document_id', 'entities'
document_id: Optional document ID for grouping memories (applied to items that don't have their own)
retain_async: If True, process asynchronously in background (default: False)

Returns:
RetainResponse with success status and item count
"""
memory_items = [
memory_item.MemoryItem(
content=item["content"],
timestamp=item.get("timestamp"),
context=item.get("context"),
metadata=item.get("metadata"),
# Use item's document_id if provided, otherwise fall back to batch-level document_id
document_id=item.get("document_id") or document_id,
from hindsight_client_api.models.entity_input import EntityInput

memory_items = []
for item in items:
entities = None
if item.get("entities"):
entities = [
EntityInput(text=e["text"], type=e.get("type"))
for e in item["entities"]
]
memory_items.append(
memory_item.MemoryItem(
content=item["content"],
timestamp=item.get("timestamp"),
context=item.get("context"),
metadata=item.get("metadata"),
# Use item's document_id if provided, otherwise fall back to batch-level document_id
document_id=item.get("document_id") or document_id,
entities=entities,
)
)
for item in items
]

request_obj = retain_request.RetainRequest(
items=memory_items,
Expand All @@ -346,6 +368,7 @@ async def aretain(
context: Optional[str] = None,
document_id: Optional[str] = None,
metadata: Optional[Dict[str, str]] = None,
entities: Optional[List[Dict[str, str]]] = None,
) -> RetainResponse:
"""
Store a single memory (async).
Expand All @@ -357,13 +380,14 @@ async def aretain(
context: Optional context description
document_id: Optional document ID for grouping
metadata: Optional user-defined metadata
entities: Optional list of entities [{"text": "...", "type": "..."}]

Returns:
RetainResponse with success status
"""
return await self.aretain_batch(
bank_id=bank_id,
items=[{"content": content, "timestamp": timestamp, "context": context, "metadata": metadata}],
items=[{"content": content, "timestamp": timestamp, "context": context, "metadata": metadata, "entities": entities}],
document_id=document_id,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from hindsight_client_api.models.document_response import DocumentResponse
from hindsight_client_api.models.entity_detail_response import EntityDetailResponse
from hindsight_client_api.models.entity_include_options import EntityIncludeOptions
from hindsight_client_api.models.entity_input import EntityInput
from hindsight_client_api.models.entity_list_item import EntityListItem
from hindsight_client_api.models.entity_list_response import EntityListResponse
from hindsight_client_api.models.entity_observation_response import EntityObservationResponse
Expand Down
Loading
Loading