5 changes: 5 additions & 0 deletions .gitignore
@@ -1 +1,6 @@
target/
__pycache__/
data/

.env
.python-version
20 changes: 20 additions & 0 deletions examples/rl/README.md
@@ -0,0 +1,20 @@
Make sure you have `sos serve` running!

Run training with:
```sh
uv run train --model <model-name>
```

Run benchmark with:
```sh
uv run benchmark
```

Define your model configs in `all_experiments.py`.
Environment variables:
```
EPHEMERAL=1 # Whether to remove the sandbox after a rollout. Set to 0 to keep sandboxes around for inspection.
MAX_TURNS=30 # Maximum number of turns per rollout.
MAX_MODEL_TOKENS=32000 # Maximum number of tokens for the model.
SHELLM=1 # Whether the model uses the SHELLM format (1) or the standard chat format (0).
```
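
The rollout code that consumes these variables is not part of this diff. Purely as a rough sketch (only the variable names come from the list above; the defaults and parsing are assumptions), they could be read like this:

```python
import os

# Hedged sketch: one way the rollout code might read the README's variables.
# Only the names come from the README; defaults and parsing are assumed.
EPHEMERAL = os.getenv("EPHEMERAL", "1") == "1"                  # tear down the sandbox after use
MAX_TURNS = int(os.getenv("MAX_TURNS", "30"))                   # cap on agent turns per rollout
MAX_MODEL_TOKENS = int(os.getenv("MAX_MODEL_TOKENS", "32000"))  # token budget for the model
SHELLM = os.getenv("SHELLM", "1") == "1"                        # SHELLM format vs. standard chat format
```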
54 changes: 54 additions & 0 deletions examples/rl/all_experiments.py
@@ -0,0 +1,54 @@
import art
from project_types import RunConfig


models: dict[str, art.TrainableModel[RunConfig]] = {}

models["run_1"] = art.TrainableModel[RunConfig](
base_model="deathbyknowledge/Qwen2.5-7B-Shell-SFT",
project="sos-shellm-agent",
name="run_1",
config=RunConfig(
groups_per_step=4,
rollouts_per_group=8,
difficulty=1,
num_epochs=2,
learning_rate=1e-5,
validation_frequency=10,
validation_num_scenarios=20,
training_num_scenarios=508,
),
)

models["run_2"] = art.TrainableModel[RunConfig](
base_model="deathbyknowledge/Qwen2.5-7B-Shell-SFT",
project="sos-shellm-agent",
name="run_2",
config=RunConfig(
groups_per_step=8,
rollouts_per_group=8,
difficulty=1,
num_epochs=2,
learning_rate=1e-5,
validation_frequency=20,
validation_num_scenarios=100,
training_num_scenarios=508,
),
)

# with h200
models["run_3"] = art.TrainableModel[RunConfig](
base_model="deathbyknowledge/Qwen2.5-7B-Shell-SFT",
project="sos-shellm-agent",
name="run_3",
config=RunConfig(
groups_per_step=12,
rollouts_per_group=8,
difficulty=None,
num_epochs=2,
learning_rate=1e-5,
validation_frequency=50,
validation_num_scenarios=100,
training_num_scenarios=5000,
),
)
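
The `train` entry point referenced in the README is not included in this diff. Purely as an assumed sketch, the `--model` flag could resolve to one of the entries above roughly like this (the argparse wiring and `main` function are hypothetical; only the `models` dict is defined in `all_experiments.py`):

```python
# Hypothetical sketch of how `uv run train --model <name>` might pick a config.
# Only the `models` dict comes from all_experiments.py; the CLI wiring is assumed.
import argparse

from all_experiments import models


def main() -> None:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, choices=sorted(models))
    args = parser.parse_args()

    model = models[args.model]  # e.g. "run_1"
    cfg = model.config          # the RunConfig defined above
    print(
        f"training {model.name}: lr={cfg.learning_rate}, "
        f"{cfg.groups_per_step} groups/step x {cfg.rollouts_per_group} rollouts/group"
    )


if __name__ == "__main__":
    main()
```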
56 changes: 56 additions & 0 deletions examples/rl/benchmark.py
@@ -0,0 +1,56 @@
import art
import asyncio
from rollout import RolloutConfig, rollout_and_score, ShellTrajectory
from load_scenarios import load_scenarios
from tqdm.asyncio import tqdm
import os


async def benchmark(
    model: art.Model, num_scenarios: int, difficulty: None | int = None
) -> tuple[list[ShellTrajectory], float, float]:
    scenarios = load_scenarios(limit=num_scenarios, split="test", difficulty=difficulty)
    results: list[ShellTrajectory] = await tqdm.gather(
        *[
            rollout_and_score(model, scenario, config=RolloutConfig(temperature=0.0))
            for scenario in scenarios
        ],
        desc=f"benchmarking {model.name}",
    )
    scores = [result.reward for result in results]
    accuracy = (
        sum(result.success_condition_passed for result in results) / len(results)
        if results
        else 0.0
    )
    return results, sum(scores) / len(scores) if scores else 0.0, accuracy


async def benchmark_all_models(
    num_scenarios: int, difficulty: None | int = None
) -> dict[str, dict[str, float]]:
    model_names = [
        "deathbyknowledge/Qwen2.5-7B-Shell-SFT",
    ]

    models = [
        art.Model(
            name=name,
            project="shell-agent-test",
            inference_api_key=os.getenv("INFERENCE_API_KEY", "FAKE_KEY"),
            inference_base_url=os.getenv(
                "INFERENCE_BASE_URL", "http://localhost:8000/v1"
            ),
        )
        for name in model_names
    ]
    results = await asyncio.gather(
        *[benchmark(model, num_scenarios, difficulty) for model in models]
    )
    return {
        model.name: {"score": score, "accuracy": accuracy}
        for model, (_results, score, accuracy) in zip(models, results)
    }


if __name__ == "__main__":
    results = asyncio.run(benchmark_all_models(num_scenarios=20))
    print(results)
43 changes: 43 additions & 0 deletions examples/rl/load_scenarios.py
@@ -0,0 +1,43 @@
from project_types import Scenario
from typing import Optional, Literal
from datasets import load_dataset, Dataset
import random

HF_REPO_ID = "deathbyknowledge/shell-tasks"

def load_scenarios(
    split: Literal["train", "test"] = "train",
    limit: Optional[int] = None,
    shuffle: bool = False,
    seed: Optional[int] = None,
    difficulty: Optional[Literal[1, 2, 3, 4]] = None,
):
    dataset: Dataset = load_dataset(HF_REPO_ID, split=split)

    if difficulty is not None:
        dataset = dataset.filter(lambda x: x["difficulty_level"] == difficulty)

    if shuffle or (seed is not None):
        if seed is not None:
            dataset = dataset.shuffle(seed=seed)
        else:
            dataset = dataset.shuffle()

    # Convert each row (dict) in the dataset to a Scenario object
    scenarios = [
        Scenario(
            id=row['task_id'],  # type: ignore
            task=row['task'],  # type: ignore
            setup_commands=row['setup_commands'],  # type: ignore
            success_condition=row['success_condition'],  # type: ignore
        )
        for row in dataset  # type: ignore
    ]

    if shuffle:
        if seed is not None:
            rng = random.Random(seed)
            rng.shuffle(scenarios)
        else:
            random.shuffle(scenarios)

    if limit is not None:
        return scenarios[:limit]
    else:
        return scenarios
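
For a quick sanity check of the loader above, a call like the following pulls a small, reproducible sample of easy test tasks (the field values are whatever the `deathbyknowledge/shell-tasks` dataset provides):

```python
# Sketch: load a small, reproducible sample of difficulty-1 test scenarios.
from load_scenarios import load_scenarios

scenarios = load_scenarios(split="test", difficulty=1, shuffle=True, seed=42, limit=5)
for s in scenarios:
    print(s.id, "-", s.task)
```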
23 changes: 23 additions & 0 deletions examples/rl/project_types.py
@@ -0,0 +1,23 @@
from typing import List, Optional, Literal
from pydantic import BaseModel, Field

class Message(BaseModel):
    role: Literal['assistant', 'user', 'system']
    content: str


class Scenario(BaseModel):
    id: str
    task: str
    setup_commands: List[str]
    success_condition: str


class RunConfig(BaseModel):
    num_epochs: int = 2
    groups_per_step: int = 12
    validation_frequency: int = 10
    validation_num_scenarios: int = 20
    training_num_scenarios: int = 1000
    rollouts_per_group: int = 8
    learning_rate: float = 1e-5
    difficulty: Optional[Literal[1, 2, 3, 4]] = None
19 changes: 19 additions & 0 deletions examples/rl/pyproject.toml
@@ -0,0 +1,19 @@
[project]
name = "rl"
version = "0.1.0"
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "asyncio>=4.0.0",
    "datasets>=4.0.0",
    "httpx>=0.28.1",
    "openai>=1.99.6",
    "openpipe-art[backend]>=0.4.4",
    "pydantic>=2.11.7",
    "rich>=14.1.0",
    "tenacity>=9.1.2",
    "transformers<4.54.0",
    "trl==0.19.1",
]
