Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hindsight-api/hindsight_api/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def create_app(
if mcp_api_enabled:
try:
from .mcp import create_mcp_app

mcp_app = create_mcp_app(memory=memory)
except ImportError as e:
logger.error(f"MCP server requested but dependencies not available: {e}")
Expand Down
64 changes: 58 additions & 6 deletions hindsight-api/hindsight_api/api/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

from fastapi import Depends, FastAPI, Header, HTTPException, Query

from hindsight_api.extensions import AuthenticationError


def _parse_metadata(metadata: Any) -> dict[str, Any]:
"""Parse metadata that may be a dict, JSON string, or None."""
Expand All @@ -35,7 +37,7 @@ def _parse_metadata(metadata: Any) -> dict[str, Any]:
from hindsight_api.engine.db_utils import acquire_with_retry
from hindsight_api.engine.memory_engine import Budget, fq_table
from hindsight_api.engine.response_models import VALID_RECALL_FACT_TYPES
from hindsight_api.extensions import HttpExtension, load_extension
from hindsight_api.extensions import HttpExtension, OperationValidationError, load_extension
from hindsight_api.metrics import create_metrics_collector, get_metrics_collector, initialize_metrics
from hindsight_api.models import RequestContext

Expand Down Expand Up @@ -989,6 +991,16 @@ def get_request_context(authorization: str | None = Header(default=None)) -> Req
api_key = authorization.strip()
return RequestContext(api_key=api_key)

# Global exception handler for authentication errors
@app.exception_handler(AuthenticationError)
async def authentication_error_handler(request, exc: AuthenticationError):
    """Turn an ``AuthenticationError`` into an HTTP 401 JSON response.

    Registered on ``app`` as the exception handler for ``AuthenticationError``.

    Args:
        request: The incoming request; not used by this handler.
        exc: The authentication error that was raised.

    Returns:
        A ``JSONResponse`` with status code 401 whose body is
        ``{"detail": str(exc)}``.
    """
    # Imported locally, as in the original, rather than at module level.
    from fastapi.responses import JSONResponse

    body = {"detail": str(exc)}
    return JSONResponse(status_code=401, content=body)

@app.get(
"/health",
summary="Health check endpoint",
Expand Down Expand Up @@ -1036,6 +1048,8 @@ async def api_graph(
try:
data = await app.state.memory.get_graph_data(bank_id, type, request_context=request_context)
return data
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1082,6 +1096,8 @@ async def api_list(
request_context=request_context,
)
return data
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1198,6 +1214,10 @@ async def api_recall(
)
except HTTPException:
raise
except OperationValidationError as e:
raise HTTPException(status_code=e.status_code, detail=e.reason)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1260,6 +1280,10 @@ async def api_reflect(
structured_output=core_result.structured_output,
)

except OperationValidationError as e:
raise HTTPException(status_code=e.status_code, detail=e.reason)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand All @@ -1280,6 +1304,8 @@ async def api_list_banks(request_context: RequestContext = Depends(get_request_c
try:
banks = await app.state.memory.list_banks(request_context=request_context)
return BankListResponse(banks=banks)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1403,6 +1429,8 @@ async def api_stats(bank_id: str):
failed_operations=failed_operations,
)

except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand All @@ -1427,6 +1455,8 @@ async def api_list_entities(
try:
entities = await app.state.memory.list_entities(bank_id, limit=limit, request_context=request_context)
return EntityListResponse(items=[EntityListItem(**e) for e in entities])
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1464,7 +1494,7 @@ async def api_get_entity(
for obs in entity["observations"]
],
)
except HTTPException:
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback
Expand Down Expand Up @@ -1517,7 +1547,7 @@ async def api_regenerate_entity_observations(
for obs in entity["observations"]
],
)
except HTTPException:
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback
Expand Down Expand Up @@ -1555,6 +1585,8 @@ async def api_list_documents(
bank_id=bank_id, search_query=q, limit=limit, offset=offset, request_context=request_context
)
return data
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1585,7 +1617,7 @@ async def api_get_document(
if not document:
raise HTTPException(status_code=404, detail="Document not found")
return document
except HTTPException:
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback
Expand Down Expand Up @@ -1614,7 +1646,7 @@ async def api_get_chunk(chunk_id: str, request_context: RequestContext = Depends
if not chunk:
raise HTTPException(status_code=404, detail="Chunk not found")
return chunk
except HTTPException:
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback
Expand Down Expand Up @@ -1658,7 +1690,7 @@ async def api_delete_document(
document_id=document_id,
memory_units_deleted=result["memory_units_deleted"],
)
except HTTPException:
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback
Expand All @@ -1683,6 +1715,8 @@ async def api_list_operations(bank_id: str, request_context: RequestContext = De
bank_id=bank_id,
operations=[OperationResponse(**op) for op in operations],
)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1713,6 +1747,8 @@ async def api_cancel_operation(
return CancelOperationResponse(**result)
except ValueError as e:
raise HTTPException(status_code=404, detail=str(e))
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1744,6 +1780,8 @@ async def api_get_bank_profile(bank_id: str, request_context: RequestContext = D
disposition=DispositionTraits(**disposition_dict),
background=profile["background"],
)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1782,6 +1820,8 @@ async def api_update_bank_disposition(
disposition=DispositionTraits(**disposition_dict),
background=profile["background"],
)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1811,6 +1851,8 @@ async def api_add_bank_background(
response.disposition = DispositionTraits(**result["disposition"])

return response
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1862,6 +1904,8 @@ async def api_create_or_update_bank(
disposition=DispositionTraits(**disposition_dict),
background=final_profile["background"],
)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1889,6 +1933,8 @@ async def api_delete_bank(bank_id: str, request_context: RequestContext = Depend
+ result.get("entities_deleted", 0)
+ result.get("documents_deleted", 0),
)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -1963,6 +2009,10 @@ async def api_retain(
return RetainResponse.model_validate(
{"success": True, "bank_id": bank_id, "items_count": len(contents), "async": False}
)
except OperationValidationError as e:
raise HTTPException(status_code=e.status_code, detail=e.reason)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down Expand Up @@ -2001,6 +2051,8 @@ async def api_clear_bank_memories(
await app.state.memory.delete_bank(bank_id, fact_type=type, request_context=request_context)

return DeleteResponse(success=True)
except (AuthenticationError, HTTPException):
raise
except Exception as e:
import traceback

Expand Down
4 changes: 3 additions & 1 deletion hindsight-api/hindsight_api/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,9 @@ def from_env(cls) -> "HindsightConfig":
lazy_reranker=os.getenv(ENV_LAZY_RERANKER, "false").lower() == "true",
# Observation thresholds
observation_min_facts=int(os.getenv(ENV_OBSERVATION_MIN_FACTS, str(DEFAULT_OBSERVATION_MIN_FACTS))),
observation_top_entities=int(os.getenv(ENV_OBSERVATION_TOP_ENTITIES, str(DEFAULT_OBSERVATION_TOP_ENTITIES))),
observation_top_entities=int(
os.getenv(ENV_OBSERVATION_TOP_ENTITIES, str(DEFAULT_OBSERVATION_TOP_ENTITIES))
),
)

def get_llm_base_url(self) -> str:
Expand Down
19 changes: 11 additions & 8 deletions hindsight-api/hindsight_api/engine/llm_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,7 @@ def __init__(
elif self.provider in ("ollama", "lmstudio"):
# Use dummy key if not provided for local
api_key = self.api_key or "local"
client_kwargs = {
"api_key": api_key,
"base_url": self.base_url,
"max_retries": 0
}
client_kwargs = {"api_key": api_key, "base_url": self.base_url, "max_retries": 0}
if self.timeout:
client_kwargs["timeout"] = self.timeout
self._client = AsyncOpenAI(**client_kwargs)
Expand Down Expand Up @@ -207,7 +203,14 @@ async def call(
# Handle Anthropic provider separately
if self.provider == "anthropic":
return await self._call_anthropic(
messages, response_format, max_completion_tokens, max_retries, initial_backoff, max_backoff, skip_validation, start_time
messages,
response_format,
max_completion_tokens,
max_retries,
initial_backoff,
max_backoff,
skip_validation,
start_time,
)

# Handle Ollama with native API for structured output (better schema enforcement)
Expand Down Expand Up @@ -297,8 +300,8 @@ async def call(
schema_msg + "\n\n" + call_params["messages"][0]["content"]
)
if self.provider not in ("lmstudio", "ollama"):
call_params["response_format"] = {"type": "json_object"}
call_params["response_format"] = {"type": "json_object"}

logger.debug(f"Sending request to {self.provider}/{self.model} (timeout={self.timeout})")
response = await self._client.chat.completions.create(**call_params)
logger.debug(f"Received response from {self.provider}/{self.model}")
Expand Down
30 changes: 23 additions & 7 deletions hindsight-api/hindsight_api/engine/memory_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,7 @@ async def _validate_operation(self, validation_coro) -> None:

result = await validation_coro
if not result.allowed:
raise OperationValidationError(result.reason or "Operation not allowed")
raise OperationValidationError(result.reason or "Operation not allowed", result.status_code)

async def _authenticate_tenant(self, request_context: "RequestContext | None") -> str:
"""
Expand All @@ -401,7 +401,9 @@ async def _authenticate_tenant(self, request_context: "RequestContext | None") -
if request_context is None:
raise AuthenticationError("RequestContext is required when tenant extension is configured")

# Let AuthenticationError propagate - HTTP layer will convert to 401
tenant_context = await self._tenant_extension.authenticate(request_context)

_current_schema.set(tenant_context.schema_name)
return tenant_context.schema_name

Expand Down Expand Up @@ -2827,13 +2829,16 @@ async def _handle_form_opinion(self, task_dict: dict[str, Any]):
Handler for form opinion tasks.

Args:
task_dict: Dict with keys: 'bank_id', 'answer_text', 'query'
task_dict: Dict with keys: 'bank_id', 'answer_text', 'query', 'tenant_id'
"""
bank_id = task_dict["bank_id"]
answer_text = task_dict["answer_text"]
query = task_dict["query"]
tenant_id = task_dict.get("tenant_id")

await self._extract_and_store_opinions_async(bank_id=bank_id, answer_text=answer_text, query=query)
await self._extract_and_store_opinions_async(
bank_id=bank_id, answer_text=answer_text, query=query, tenant_id=tenant_id
)

async def _handle_reinforce_opinion(self, task_dict: dict[str, Any]):
"""
Expand Down Expand Up @@ -3222,8 +3227,15 @@ def model_json_schema(self):
answer_text = result.strip()

# Submit form_opinion task for background processing
# Pass tenant_id from request context for internal authentication in background task
await self._task_backend.submit_task(
{"type": "form_opinion", "bank_id": bank_id, "answer_text": answer_text, "query": query}
{
"type": "form_opinion",
"bank_id": bank_id,
"answer_text": answer_text,
"query": query,
"tenant_id": getattr(request_context, "tenant_id", None) if request_context else None,
}
)

total_time = time.time() - reflect_start
Expand Down Expand Up @@ -3261,7 +3273,9 @@ def model_json_schema(self):

return result

async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str, query: str):
async def _extract_and_store_opinions_async(
self, bank_id: str, answer_text: str, query: str, tenant_id: str | None = None
):
"""
Background task to extract and store opinions from think response.

Expand All @@ -3271,6 +3285,7 @@ async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str
bank_id: Bank identifier
answer_text: The generated answer text
query: The original query
tenant_id: Tenant identifier for internal authentication
"""
try:
# Extract opinions from the answer
Expand All @@ -3281,10 +3296,11 @@ async def _extract_and_store_opinions_async(self, bank_id: str, answer_text: str
from datetime import datetime

current_time = datetime.now(UTC)
# Use internal request context for background tasks
# Use internal context with tenant_id for background authentication
# Extension can check internal=True to bypass normal auth
from hindsight_api.models import RequestContext

internal_context = RequestContext()
internal_context = RequestContext(tenant_id=tenant_id, internal=True)
for opinion in new_opinions:
await self.retain_async(
bank_id=bank_id,
Expand Down
Loading