3 changes: 3 additions & 0 deletions build/lib/safetytooling/apis/__init__.py
@@ -0,0 +1,3 @@
from .inference.api import InferenceAPI

__all__ = ["InferenceAPI"]
248 changes: 248 additions & 0 deletions build/lib/safetytooling/apis/batch_api.py
@@ -0,0 +1,248 @@
import asyncio
import logging
import os
from pathlib import Path
from typing import Literal

from safetytooling.apis.inference.anthropic import ANTHROPIC_MODELS, AnthropicModelBatch
from safetytooling.apis.inference.cache_manager import get_cache_manager
from safetytooling.apis.inference.openai.batch_api import OpenAIModelBatch
from safetytooling.apis.inference.openai.utils import GPT_CHAT_MODELS
from safetytooling.data_models import LLMParams, LLMResponse, Prompt

LOGGER = logging.getLogger(__name__)


def chunk_prompts_for_anthropic(prompts: list[Prompt], chunk_size: int = 100_000) -> list[list[Prompt]]:
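    """Split prompts into chunks that respect Anthropic's batch limits.

    A new chunk is started whenever adding a prompt would push the chunk past
    chunk_size prompts or past roughly 250 MB of serialized prompt data.
    """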
    # Calculate size per prompt in MB
    sizes = [len(str(prompt).encode("utf-8")) / (1024 * 1024) for prompt in prompts]

    chunks = []
    current_chunk = []
    current_size = 0
    current_count = 0

    for prompt, size in zip(prompts, sizes):
        # Check if adding this prompt would exceed either limit
        if (
            current_size + size > 250 or current_count + 1 > chunk_size
        ):  # 256 MB is the max size for Anthropic, but 250 MB is a safe limit
            if current_chunk:
                chunks.append(current_chunk)
                print(f"Chunked prompts for Anthropic: {len(current_chunk)} prompts, {current_size} MB")
            current_chunk = [prompt]
            current_size = size
            current_count = 1
        else:
            current_chunk.append(prompt)
            current_size += size
            current_count += 1

    if current_chunk:
        chunks.append(current_chunk)
    print(f"Total chunks: {len(chunks)}")
    return chunks
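
# For example, with prompts small enough that only the count limit matters (illustrative):
#   chunk_prompts_for_anthropic([p1, p2, p3, p4, p5], chunk_size=2) -> [[p1, p2], [p3, p4], [p5]]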


class BatchInferenceAPI:
"""A wrapper around batch APIs that can do a bunch of model calls at once."""

def __init__(
self,
log_dir: Path | Literal["default"] = "default",
prompt_history_dir: Path | Literal["default"] | None = "default",
cache_dir: Path | Literal["default"] | None = "default",
use_redis: bool = False,
anthropic_api_key: str | None = None,
openai_api_key: str | None = None,
no_cache: bool = False,
):
"""Initialize the BatchInferenceAPI.

Args:
log_dir: Directory to store prompts before passing them to the batchapi. Set to "default" to use your prompt_history_dir as log_dir.
prompt_history_dir: Directory to store prompt history. Set to None to disable.
Set to "default" to use the default directory.
cache_dir: Directory to store cache. Set to None to disable.
Set to "default" to use the default directory.
use_redis: Whether to use Redis for caching.
anthropic_api_key: Optional API key for Anthropic. If not provided, will try to load from environment.
openai_api_key: Optional API key for OpenAI. If not provided, will try to load from environment.
no_cache: If True, disable caching regardless of other settings.
chunk: Maximum number of prompts to send in a single batch. If None, send all prompts in one batch.
"""

        # Check NO_CACHE in os.environ
        self.no_cache = no_cache or os.environ.get("NO_CACHE", "false").lower() == "true"

        if log_dir == "default":
            if "PROMPT_HISTORY_DIR" in os.environ:
                self.log_dir = Path(os.environ["PROMPT_HISTORY_DIR"])
            else:
                self.log_dir = Path.home() / ".prompt_history"
        else:
            assert isinstance(log_dir, Path) or log_dir is None
            self.log_dir = log_dir

        if cache_dir == "default":
            if "CACHE_DIR" in os.environ:
                self.cache_dir = Path(os.environ["CACHE_DIR"])
            else:
                self.cache_dir = Path.home() / ".cache"
        else:
            assert isinstance(cache_dir, Path) or cache_dir is None
            self.cache_dir = cache_dir

        # Check REDIS_CACHE in os.environ
        self.use_redis = use_redis or os.environ.get("REDIS_CACHE", "false").lower() == "true"
        self.cache_manager = None
        if self.cache_dir is not None and not self.no_cache:
            self.cache_manager = get_cache_manager(self.cache_dir, self.use_redis)
        print(f"{self.cache_manager=}")

        self._anthropic_batch = AnthropicModelBatch(anthropic_api_key=anthropic_api_key)
        self._openai_batch = OpenAIModelBatch(openai_api_key=openai_api_key)

    async def __call__(
        self,
        model_id: str,
        prompts: list[Prompt],
        log_dir: Path | None = None,
        use_cache: bool = True,
        max_tokens: int | None = None,
        chunk: int | None = None,
        **kwargs,
    ) -> tuple[list[LLMResponse], str]:
        """Make batch API requests for the specified model and prompts.

        Args:
            model_id: The model to call. Must be a Claude or OpenAI model.
            prompts: List of prompts to send to the model.
            log_dir: Directory to store logs. If None, will use prompt_history_dir.
            use_cache: Whether to use caching. Ignored if no_cache was set during initialization.
            max_tokens: Maximum number of tokens to generate per prompt. Setting a different max_tokens for each prompt is not yet supported here, although it is in principle supported by the underlying API.
            chunk: Maximum number of prompts to send in a single batch. If None, a model-specific default is used (100,000 for Anthropic, 50,000 for OpenAI).
            **kwargs: Additional arguments to pass to the model. Pass "seed" to avoid reusing previously cached responses.

        Returns:
            Tuple of (list of responses, comma-separated batch ID(s))
        """
        if self.no_cache:
            use_cache = False

        if not prompts:
            raise ValueError("No prompts provided")

        if log_dir is None:
            log_dir = self.log_dir

        if log_dir is not None:
            log_dir.mkdir(parents=True, exist_ok=True)

        if model_id in ANTHROPIC_MODELS:
            assert max_tokens is not None, "max_tokens must be provided for Anthropic models"
            kwargs["max_tokens"] = max_tokens

        llm_params = LLMParams(model=model_id, **kwargs)
        cached_responses = []
        cached_indices = []
        uncached_prompts = []
        uncached_indices = []

        # Check cache for each prompt only if caching is enabled
        if self.cache_manager is not None and use_cache:
            if "n" in kwargs:
                assert kwargs["n"] == 1, "n must be 1 for batch inference + cache"
            for i, prompt in enumerate(prompts):
                cached_result = self.cache_manager.maybe_load_cache(prompt=prompt, params=llm_params)
                if cached_result is not None:
                    cached_responses.append(cached_result.responses[0])  # Assuming single response per prompt
                    cached_indices.append(i)
                else:
                    uncached_prompts.append(prompt)
                    uncached_indices.append(i)
            print(f"Retrieving {len(cached_responses)}/{len(prompts)} from cache")
        else:
            uncached_prompts = prompts
            uncached_indices = list(range(len(prompts)))

        # If all prompts are cached, return cached responses
        if not uncached_prompts:
            return cached_responses, "cached"

        if chunk is None:
            if model_id in ANTHROPIC_MODELS:
                chunk = 100_000
            elif model_id in GPT_CHAT_MODELS:
                chunk = 50_000
            else:
                raise ValueError(f"Model {model_id} is not supported. Only Claude and OpenAI models are supported.")

        if model_id in ANTHROPIC_MODELS:
            chunks = chunk_prompts_for_anthropic(uncached_prompts, chunk_size=chunk)
        else:
            chunks = [uncached_prompts[i : i + chunk] for i in range(0, len(uncached_prompts), chunk)]

        async def process_chunk(chunk_prompts) -> tuple[list[LLMResponse], str]:
            print(f"Batch kwargs: {kwargs}")
            if model_id in ANTHROPIC_MODELS:
                return await self._anthropic_batch(
                    model_id=model_id,
                    prompts=chunk_prompts,
                    max_tokens=kwargs.get("max_tokens", None),
                    **{k: v for k, v in kwargs.items() if k != "seed" and k != "max_tokens"},
                )
            elif model_id in GPT_CHAT_MODELS:
                return await self._openai_batch(
                    model_id=model_id,
                    prompts=chunk_prompts,
                    log_dir=log_dir,
                    max_tokens=kwargs.get("max_tokens", None),
                    **{k: v for k, v in kwargs.items() if k != "max_tokens"},
                )
            else:
                raise ValueError(f"Model {model_id} is not supported. Only Claude and OpenAI models are supported.")

        # Process uncached prompts in chunks if specified
        if len(chunks) > 1:
            all_responses = []
            all_batch_ids = []

            results = await asyncio.gather(*[process_chunk(chunk) for chunk in chunks])

            for chunk_prompts, (responses, batch_id) in zip(chunks, results):
                all_responses.extend(responses)
                all_batch_ids.append(batch_id)

                # Save responses to cache
                if self.cache_manager is not None and use_cache:
                    for prompt, response in zip(chunk_prompts, responses):
                        if response is not None:
                            self.cache_manager.save_cache(prompt=prompt, params=llm_params, responses=[response])

        else:
            # Process all uncached prompts in a single batch
            responses, batch_id = await process_chunk(uncached_prompts)
            all_responses = responses
            all_batch_ids = [batch_id]

            # Save responses to cache
            if self.cache_manager is not None and use_cache:
                for prompt, response in zip(uncached_prompts, responses):
                    if response is not None:
                        self.cache_manager.save_cache(prompt=prompt, params=llm_params, responses=[response])

        # Merge cached and new responses
        final_responses = [None] * len(prompts)
        for cached_response, idx in zip(cached_responses, cached_indices):
            final_responses[idx] = cached_response
        for idx, response in zip(uncached_indices, all_responses):
            final_responses[idx] = response

        if use_cache:
            print(
                f"{len(cached_responses)}/{len(prompts)} retrieved from cache, "
                f"{sum(response is None for response in final_responses)}/{len(prompts)} failed, "
                f"{sum(response is not None for response in final_responses)}/{len(prompts)} succeeded in batch API requests {all_batch_ids}"
            )

        return final_responses, ",".join(all_batch_ids)
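
A minimal usage sketch of the class above (illustrative, not part of the diff; it assumes the ChatMessage/MessageRole/Prompt data models and the LLMResponse.completion field from safetytooling.data_models, and an example Claude model id):

import asyncio
from pathlib import Path

from safetytooling.apis.batch_api import BatchInferenceAPI
from safetytooling.data_models import ChatMessage, MessageRole, Prompt


async def main():
    # Logs and cache locations are explicit here; "default" would fall back to env vars / home dirs.
    batch_api = BatchInferenceAPI(log_dir=Path("./batch_logs"), cache_dir=Path("./batch_cache"))

    prompts = [
        Prompt(messages=[ChatMessage(role=MessageRole.user, content=f"Summarise document {i}.")])
        for i in range(3)
    ]

    # max_tokens is required for Anthropic models; the model id is illustrative.
    responses, batch_id = await batch_api(
        model_id="claude-3-5-sonnet-20241022",
        prompts=prompts,
        max_tokens=256,
    )
    print(batch_id)
    print([r.completion if r is not None else None for r in responses])


asyncio.run(main())

Failed requests come back as None in the returned list, so downstream code should handle missing responses as above.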