11 changes: 11 additions & 0 deletions .gitignore
@@ -1,3 +1,14 @@
# Environment files
.env
.env.*

# Local caches and artifacts
.cache/
.prompt_history/
**/bin*.json
**/bin*.json.lock
*.log

# Python stuff
*.egg-info
.ruff_cache
2 changes: 1 addition & 1 deletion safetytooling/apis/finetuning/openai/check.py
@@ -149,7 +149,7 @@ def openai_check_finetuning_data(
+ f"which is over the max context length of {context_length}."
+ "It will be truncated during fine-tuning"
)
prompts[0].pretty_print([], print_fn=lambda x, *_, **__: logger.warn(x))
prompts[0].pretty_print([], print_fn=lambda x, *_, **__: logger.warning(x))

# batch_size set to 0.2% of dataset size by default
# https://community.openai.com/t/why-is-the-default-batch-size-set-to-1-for-fine-tuning-the-chatgpt-turbo-model/513129
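Note on the recurring rename in this PR: Logger.warn is a deprecated alias for Logger.warning in the standard library (deprecated since Python 3.3, and removed outright in Python 3.13), which is presumably what motivates these call-site changes. A quick self-contained check; everything below is illustrative and not part of the diff:

import logging
import warnings

logger = logging.getLogger("demo")

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    try:
        logger.warn("old alias")  # emits DeprecationWarning on Python 3.3-3.12
    except AttributeError:
        print("Logger.warn was removed entirely (Python 3.13+)")
    else:
        assert any(issubclass(w.category, DeprecationWarning) for w in caught)

logger.warning("preferred spelling")  # works on every Python 3 version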
6 changes: 5 additions & 1 deletion safetytooling/apis/inference/anthropic.py
@@ -40,6 +40,7 @@ def __init__(
num_threads: int,
prompt_history_dir: Path | None = None,
anthropic_api_key: str | None = None,
progress_monitor: object | None = None,
):
self.num_threads = num_threads
self.prompt_history_dir = prompt_history_dir
@@ -50,6 +51,7 @@ def __init__(

self.available_requests = asyncio.BoundedSemaphore(int(self.num_threads))
self.kwarg_change_name = {"stop": "stop_sequences"}
self.progress_monitor = progress_monitor

async def __call__(
self,
@@ -107,7 +109,7 @@ async def __call__(
error_info = (
f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}"
)
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
error_list.append(error_info)
api_duration = time.time() - api_start
await asyncio.sleep(1.5**i)
@@ -208,6 +210,8 @@ async def __call__(
if print_prompt_and_response:
prompt.pretty_print(responses)

# Progress monitoring is handled centrally in InferenceAPI to avoid double counting.

return responses

def make_stream_api_call(
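This and the remaining warn → warning changes all sit inside the same retry loop that each client implements: catch the exception, log it with the full traceback, back off exponentially, try again. A condensed sketch of that shared shape, with the function name, signature, and max_attempts default invented here for illustration:

import asyncio
import logging
from traceback import format_exc

LOGGER = logging.getLogger(__name__)

async def call_with_backoff(make_request, max_attempts: int = 10):
    # The retry shape common to these clients: log a warning with the full
    # traceback, sleep 1.5**i seconds, then retry.
    for i in range(max_attempts):
        try:
            return await make_request()
        except Exception as e:
            error_info = (
                f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}"
            )
            LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
            await asyncio.sleep(1.5**i)
    raise RuntimeError(f"Failed to get a response from the API after {max_attempts} attempts.")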
61 changes: 57 additions & 4 deletions safetytooling/apis/inference/api.py
@@ -30,6 +30,8 @@

from .anthropic import ANTHROPIC_MODELS, AnthropicChatModel
from .cache_manager import BaseCacheManager, get_cache_manager
from .openai.base import OpenAIModel
from .progress_monitor import RateLimitProgressMonitor
from .gemini.genai import GeminiModel
from .gemini.vertexai import GeminiVertexAIModel
from .gray_swan import GRAYSWAN_MODELS, GraySwanChatModel
@@ -90,6 +92,7 @@ def __init__(
vllm_base_url: str = "http://localhost:8000/v1/chat/completions",
no_cache: bool = False,
oai_embedding_batch_size: int = 2048,
show_progress: bool = False,
):
"""
Set prompt_history_dir to None to disable saving prompt history.
@@ -153,18 +156,22 @@ def __init__(
self.cache_manager = get_cache_manager(self.cache_dir, self.use_redis)
print(f"{self.cache_manager=}")

self.progress_monitor = RateLimitProgressMonitor(disable=not show_progress)

self._openai_completion = OpenAICompletionModel(
frac_rate_limit=self.openai_fraction_rate_limit,
prompt_history_dir=self.prompt_history_dir,
base_url=self.openai_base_url,
openai_api_key=openai_api_key,
progress_monitor=self.progress_monitor,
)

self._openai_chat = OpenAIChatModel(
frac_rate_limit=self.openai_fraction_rate_limit,
prompt_history_dir=self.prompt_history_dir,
base_url=self.openai_base_url,
openai_api_key=openai_api_key,
progress_monitor=self.progress_monitor,
)

self._openai_moderation = OpenAIModerationModel()
@@ -176,30 +183,35 @@ def __init__(
num_threads=self.anthropic_num_threads,
prompt_history_dir=self.prompt_history_dir,
anthropic_api_key=anthropic_api_key,
progress_monitor=self.progress_monitor,
)

self._huggingface = HuggingFaceModel(
num_threads=self.huggingface_num_threads,
prompt_history_dir=self.prompt_history_dir,
token=os.environ.get("HF_TOKEN", None),
progress_monitor=self.progress_monitor,
)

self._gray_swan = GraySwanChatModel(
num_threads=self.gray_swan_num_threads,
prompt_history_dir=self.prompt_history_dir,
api_key=os.environ.get("GRAYSWAN_API_KEY", None),
api_key=(os.environ.get("GRAYSWAN_API_KEY") or None),
progress_monitor=self.progress_monitor,
)

self._together = TogetherChatModel(
num_threads=self.together_num_threads,
prompt_history_dir=self.prompt_history_dir,
api_key=os.environ.get("TOGETHER_API_KEY", None),
api_key=(os.environ.get("TOGETHER_API_KEY") or None),
progress_monitor=self.progress_monitor,
)

self._openrouter = OpenRouterChatModel(
num_threads=self.openrouter_num_threads,
prompt_history_dir=self.prompt_history_dir,
api_key=os.environ.get("OPENROUTER_API_KEY", None),
api_key=(os.environ.get("OPENROUTER_API_KEY") or None),
progress_monitor=self.progress_monitor,
)

self._gemini_vertex = GeminiVertexAIModel(prompt_history_dir=self.prompt_history_dir)
@@ -214,7 +226,8 @@ def __init__(
num_threads=vllm_num_threads,
prompt_history_dir=self.prompt_history_dir,
vllm_base_url=self.vllm_base_url,
runpod_api_key=os.environ.get("RUNPOD_API_KEY", None),
runpod_api_key=(os.environ.get("RUNPOD_API_KEY") or None),
progress_monitor=self.progress_monitor,
)

# DeepSeek uses the OpenAI API
@@ -223,6 +236,7 @@ def __init__(
prompt_history_dir=self.prompt_history_dir,
base_url=DEEPSEEK_BASE_URL,
openai_api_key=os.environ.get("DEEPSEEK_API_KEY", None),
progress_monitor=self.progress_monitor,
)

self._batch_models = {}
@@ -643,6 +657,45 @@ async def __call__(
self.model_timings.setdefault(response.model_id, []).append(response.api_duration)
self.model_wait_times.setdefault(response.model_id, []).append(response.duration - response.api_duration)

# Update progress monitor with usage info
if hasattr(self, "progress_monitor") and self.progress_monitor is not None:
Contributor: What is the reason more of this can't go in the class and happen when you run self.progress_monitor.update_openai_usage in each model class?

Author: Good catch. It's centralized now, and the progress update happens only in InferenceAPI.__call__.

(A minimal sketch of the monitor interface these call sites imply appears after this file's diff.)
try:
total_in_tokens = 0
total_out_tokens = 0
for response in candidate_responses:
usage = getattr(response, "usage", None)
if usage is not None:
total_in_tokens += int(getattr(usage, "input_tokens", 0) or 0)
total_out_tokens += int(getattr(usage, "output_tokens", 0) or 0)

# Register caps if OpenAI and resources are available
if isinstance(model_class, OpenAIModel):
req_res = getattr(model_class, "request_capacity", {}).get(model_id, None)
tok_res = getattr(model_class, "token_capacity", {}).get(model_id, None)
if req_res is not None and tok_res is not None:
self.progress_monitor.register_openai_model(model_id, req_res, tok_res)
else:
# requests-only when no resources are present
self.progress_monitor.register_generic_model(model_id, show_token_bars=False)
else:
# Only requests unless exact token usage is known
show_token_bars = total_in_tokens > 0 or total_out_tokens > 0
self.progress_monitor.register_generic_model(model_id, show_token_bars=show_token_bars)

# Increment requests by the number of real API calls made in this branch
request_increment = 1
if isinstance(model_class, (AnthropicChatModel, HuggingFaceModel, GeminiModel, GeminiVertexAIModel)):
request_increment = num_candidates

await self.progress_monitor.update_usage(
model_id=model_id,
input_tokens=(total_in_tokens if (total_in_tokens > 0) else None),
output_tokens=(total_out_tokens if (total_out_tokens > 0) else None),
request_increment=request_increment,
)
except Exception:
pass

return responses

async def ask_single_question(
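As promised above: the new progress_monitor.py module itself is not shown in this diff, so here is a minimal sketch of the interface the call sites in InferenceAPI.__call__ and OpenAIModel.add_model_id imply. The method names and signatures are read off those call sites; the bodies, the stats dict, and the lock are assumptions, and the real RateLimitProgressMonitor presumably also renders progress bars against the registered rate limits:

import asyncio

class ProgressMonitorSketch:
    def __init__(self, disable: bool = False):
        self.disable = disable
        self._stats: dict[str, dict[str, int]] = {}
        self._lock = asyncio.Lock()

    def _ensure(self, model_id: str) -> None:
        self._stats.setdefault(model_id, {"requests": 0, "input_tokens": 0, "output_tokens": 0})

    def register_openai_model(self, model_id: str, request_resource, token_resource) -> None:
        # OpenAI models expose capacity resources, so request and token bars
        # can be scaled against the known rate limits.
        self._ensure(model_id)

    def register_generic_model(self, model_id: str, show_token_bars: bool = False) -> None:
        # Other providers only get token bars when exact usage is known.
        self._ensure(model_id)

    async def update_usage(self, model_id: str, input_tokens=None, output_tokens=None, request_increment: int = 1) -> None:
        if self.disable:
            return
        async with self._lock:  # __call__ runs concurrently, so guard the shared stats
            self._ensure(model_id)
            stats = self._stats[model_id]
            stats["requests"] += request_increment
            stats["input_tokens"] += input_tokens or 0
            stats["output_tokens"] += output_tokens or 0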
2 changes: 1 addition & 1 deletion safetytooling/apis/inference/gemini/genai.py
@@ -304,7 +304,7 @@ async def attempt_api_call(model_id):
raise e
except Exception as e:
error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}"
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying in {1.5**i} seconds. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying in {1.5**i} seconds. (Attempt {i})")
await asyncio.sleep(1.5**i)
else:
break
2 changes: 1 addition & 1 deletion safetytooling/apis/inference/gemini/vertexai.py
@@ -254,7 +254,7 @@ async def attempt_api_call(model_id):

except Exception as e:
error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}"
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
await asyncio.sleep(1.5**i)
else:
break
4 changes: 3 additions & 1 deletion safetytooling/apis/inference/gray_swan.py
@@ -20,6 +20,7 @@ def __init__(
num_threads: int,
prompt_history_dir: Path | None = None,
api_key: str | None = None,
progress_monitor: object | None = None,
):
self.num_threads = num_threads
self.prompt_history_dir = prompt_history_dir
@@ -32,6 +33,7 @@
else:
self.aclient = None
self.available_requests = asyncio.BoundedSemaphore(int(self.num_threads))
self.progress_monitor = progress_monitor

async def __call__(
self,
@@ -71,7 +73,7 @@ async def __call__(
raise e
except Exception as e:
error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}"
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
error_list.append(error_info)
api_duration = time.time() - api_start
await asyncio.sleep(1.5**i)
8 changes: 6 additions & 2 deletions safetytooling/apis/inference/huggingface.py
@@ -34,6 +34,7 @@ def __init__(
num_threads: int,
token: str,
prompt_history_dir: Path | None = None,
progress_monitor: object | None = None,
):
self.num_threads = num_threads
self.prompt_history_dir = prompt_history_dir
@@ -56,6 +57,7 @@ def __init__(
}

self.tokenizers = {}
self.progress_monitor = progress_monitor

def count_tokens(self, text, model_name):
if model_name not in self.tokenizers:
@@ -169,10 +171,10 @@ async def __call__(
raise e

if "503 Service Unavailable" in str(e):
LOGGER.warn(f"503 Service Unavailable error. Waiting 60 seconds before retrying. (Attempt {i})")
LOGGER.warning(f"503 Service Unavailable error. Waiting 60 seconds before retrying. (Attempt {i})")
await asyncio.sleep(60)
else:
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
await asyncio.sleep(1.5**i)
else:
break
@@ -208,4 +210,6 @@ async def __call__(
if print_prompt_and_response:
prompt.pretty_print(responses)

# Progress monitoring is handled centrally in InferenceAPI to avoid double counting.

return responses
15 changes: 14 additions & 1 deletion safetytooling/apis/inference/openai/base.py
@@ -61,12 +61,14 @@ def __init__(
prompt_history_dir: Path | None = None,
base_url: str | None = None,
openai_api_key: str | None = None,
progress_monitor: object | None = None,
):
self.frac_rate_limit = frac_rate_limit
self.prompt_history_dir = prompt_history_dir
self.model_ids = set()
self.base_url = base_url
self.openai_api_key = openai_api_key
self.progress_monitor = progress_monitor

if openai_api_key:
self.aclient = openai.AsyncClient(api_key=openai_api_key, base_url=self.base_url)
@@ -125,6 +127,17 @@ async def add_model_id(self, model_id: str):
self.token_capacity[model_id] = token_capacity
self.request_capacity[model_id] = request_capacity

# Register model with progress monitor
if self.progress_monitor is not None:
try:
self.progress_monitor.register_openai_model(
model_id=model_id,
request_resource=request_capacity,
token_resource=token_capacity,
)
except Exception:
pass

async def __call__(
self,
model_id: str,
@@ -168,7 +181,7 @@ async def attempt_api_call():
raise e
except Exception as e:
error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}"
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
await asyncio.sleep(1.5**i)
else:
break
2 changes: 1 addition & 1 deletion safetytooling/apis/inference/openai/embedding.py
@@ -69,7 +69,7 @@ async def embed(
error_info = (
f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {traceback.format_exc()}"
)
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
time.sleep(1.5**i)

raise RuntimeError(f"Failed to get a response from the API after {max_attempts} attempts.")
2 changes: 1 addition & 1 deletion safetytooling/apis/inference/openai/moderation.py
@@ -53,7 +53,7 @@ async def _single_moderation_request(
return response
except Exception as e:
error_info = f"Exception Type: {type(e).__name__}, Error Details: {str(e)}, Traceback: {format_exc()}"
LOGGER.warn(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
LOGGER.warning(f"Encountered API error: {error_info}.\nRetrying now. (Attempt {i})")
time.sleep(1.5**i)

raise RuntimeError(f"Failed to get a response from the API after {max_attempts} attempts.")
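Since the new show_progress flag defaults to False, existing callers are unaffected and progress display is strictly opt-in. A hedged usage sketch, assuming the import paths and Prompt construction that safetytooling's README-style examples use; the __call__ argument names follow the hunks above and may drift:

import asyncio

from safetytooling.apis import InferenceAPI
from safetytooling.data_models import ChatMessage, MessageRole, Prompt
from safetytooling.utils import utils

async def main():
    utils.setup_environment()  # loads keys from the .env files ignored by the new .gitignore rules
    api = InferenceAPI(show_progress=True)  # opt in to the RateLimitProgressMonitor
    prompt = Prompt(messages=[ChatMessage(content="What is your name?", role=MessageRole.user)])
    responses = await api(model_id="gpt-4o-mini", prompt=prompt)
    print(responses[0].completion)

asyncio.run(main())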