
Commit 5359996

[WIP] Arbor Interface Update (#8837)

* Add GPU sharing
* Pass through GPUConfig
* Fix typecheck of typeddict
* Fix URL join
* Remove Arbor GRPO single GPU config
* Update arbor endpoints
* Update Arbor GRPO interface
* Update Arbor docs
* Fix import
* Fix typechecking
* Revert import
* Fix ruff errors
1 parent 50185ca commit 5359996

File tree

7 files changed, +68 −71 lines


docs/docs/tutorials/rl_multihop/index.ipynb

Lines changed: 14 additions & 29 deletions

@@ -11,21 +11,8 @@
 "For this tutorial, you will also need DSPy's Arbor RL server.\n",
 "\n",
 "```bash\n",
-"> pip install arbor-ai\n",
-"> python -m arbor.cli serve --arbor-config arbor.yaml\n",
-"```\n",
-"\n",
-"where you create `arbor.yaml` in your directory, containing a plan like:\n",
-"\n",
-"```text\n",
-"inference:\n",
-" gpu_ids: '0'\n",
-"\n",
-"training:\n",
-" gpu_ids: '1, 2'\n",
-"```\n",
-"\n",
-"which assigns GPU 0 for inference and GPUs 1 and 2 for training."
+"> pip install -U arbor-ai\n",
+"```"
 ]
 },
 {
@@ -37,14 +24,16 @@
 "import dspy\n",
 "from dspy.clients.lm_local_arbor import ArborProvider\n",
 "\n",
+"import arbor\n",
+"arbor_server_info = arbor.init() # Initialize the Arbor server in the background\n",
+"\n",
 "port = 7453\n",
 "local_lm_name = \"Qwen/Qwen2.5-7B-Instruct\"\n",
 "local_lm = dspy.LM(\n",
 " model=f\"openai/arbor:{local_lm_name}\",\n",
 " provider=ArborProvider(),\n",
 " temperature=0.7,\n",
-" api_base=f\"http://localhost:{port}/v1/\",\n",
-" api_key=\"arbor\",\n",
+" api_base=arbor_server_info[\"api_base\"],\n",
 ")\n",
 "\n",
 "dspy.configure(lm=local_lm)\n",
@@ -238,17 +227,18 @@
 "outputs": [],
 "source": [
 "from dspy.teleprompt.grpo import GRPO\n",
+"from dspy.clients.utils_finetune import MultiGPUConfig\n",
 "\n",
 "program = ResearchHop(num_docs=4, num_hops=2)\n",
 "program.set_lm(local_lm)\n",
 "\n",
 "# NOTE: Training on 6 GPUs.\n",
 "train_kwargs = {\n",
 " \"per_device_train_batch_size\": 2,\n",
-" \"gradient_accumulation_steps\": 4,\n",
-" \"temperature\": 0.7,\n",
+" \"gradient_accumulation_steps\": 8,\n",
+" \"temperature\": 1.0,\n",
 " \"beta\": 0.04,\n",
-" \"learning_rate\": 2e-5,\n",
+" \"learning_rate\": 1e-5,\n",
 " \"gradient_checkpointing\": True,\n",
 " \"gradient_checkpointing_kwargs\": {\"use_reentrant\": False},\n",
 " \"bf16\": True,\n",
@@ -262,16 +252,16 @@
 "\n",
 "compiler = GRPO(\n",
 " metric=recall,\n",
-" multitask=True,\n",
 " num_dspy_examples_per_grpo_step=6,\n",
-" num_samples_per_input=8,\n",
+" num_rollouts_per_grpo_step=4,\n",
 " exclude_demos=True,\n",
-" num_train_steps=500,\n",
-" num_threads=24,\n",
+" num_train_steps=100,\n",
+" num_threads=16,\n",
 " use_train_as_val=False,\n",
 " num_steps_for_val=10,\n",
 " train_kwargs=train_kwargs,\n",
 " report_train_scores=False,\n",
+" gpu_config=MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),\n",
 ")\n",
 "\n",
 "optimized_program = compiler.compile(\n",
@@ -304,11 +294,6 @@
 "source": [
 "In our preliminary experiments, training above for about 18 hours boosts the recall (devset) from 61.8% to 66.2%. This is _typically_ worse on cost/quality basis than you'd get from running prompt optimizers dspy.MIPROv2 or dspy.SIMBA, but it's still a very solid start for online RL over arbitrary LM programs for small LMs."
 ]
-},
-{
-"cell_type": "markdown",
-"metadata": {},
-"source": []
 }
 ],
 "metadata": {

docs/docs/tutorials/rl_papillon/index.ipynb

Lines changed: 10 additions & 19 deletions

@@ -15,21 +15,8 @@
 "For this tutorial, you will also need the Arbor RL server.\n",
 "\n",
 "```bash\n",
-"> pip install arbor-ai\n",
-"> python -m arbor.cli serve --arbor-config arbor.yaml\n",
-"```\n",
-"\n",
-"where you create `arbor.yaml` in your directory, containing a plan like:\n",
-"\n",
-"```text\n",
-"inference:\n",
-" gpu_ids: '0'\n",
-"\n",
-"training:\n",
-" gpu_ids: '1, 2'\n",
-"```\n",
-"\n",
-"which assigns GPU 0 for inference and GPUs 1 and 2 for training."
+"> pip install -U arbor-ai\n",
+"```"
 ]
 },
 {
@@ -41,14 +28,16 @@
 "import dspy\n",
 "from dspy.clients.lm_local_arbor import ArborProvider\n",
 "\n",
+"import arbor\n",
+"arbor_server_info = arbor.init() # Initialize the Arbor server in the background\n",
+"\n",
 "port = 7453\n",
-"local_lm_name = \"Qwen/Qwen3-1.7B\"\n",
+"local_lm_name = \"Qwen/Qwen2.5-7B-Instruct\"\n",
 "local_lm = dspy.LM(\n",
 " model=f\"openai/arbor:{local_lm_name}\",\n",
 " provider=ArborProvider(),\n",
 " temperature=0.7,\n",
-" api_base=f\"http://localhost:{port}/v1/\",\n",
-" api_key=\"arbor\",\n",
+" api_base=arbor_server_info[\"api_base\"],\n",
 ")\n",
 "\n",
 "dspy.configure(lm=local_lm)\n",
@@ -267,6 +256,7 @@
 "outputs": [],
 "source": [
 "from dspy.teleprompt.grpo import GRPO\n",
+"from dspy.clients.utils_finetune import MultiGPUConfig\n",
 "\n",
 "papillon = PAPILLON(untrusted_model=openai_lm)\n",
 "papillon.set_lm(local_lm)\n",
@@ -275,7 +265,7 @@
 "train_kwargs = {\n",
 " \"per_device_train_batch_size\": 8,\n",
 " \"gradient_accumulation_steps\": 4,\n",
-" \"temperature\": 0.7,\n",
+" \"temperature\": 1.0,\n",
 " \"beta\": 0.04,\n",
 " \"learning_rate\": 2e-6,\n",
 " \"gradient_checkpointing\": True,\n",
@@ -301,6 +291,7 @@
 " num_steps_for_val=10,\n",
 " train_kwargs=train_kwargs,\n",
 " report_train_scores=False,\n",
+" gpu_config=MultiGPUConfig(num_inference_gpus=2, num_training_gpus=2),\n",
 ")\n",
 "\n",
 "optimized_papillon = compiler.compile(\n",

dspy/clients/lm.py

Lines changed: 3 additions & 3 deletions

@@ -13,7 +13,7 @@
 from dspy.clients.cache import request_cache
 from dspy.clients.openai import OpenAIProvider
 from dspy.clients.provider import Provider, ReinforceJob, TrainingJob
-from dspy.clients.utils_finetune import TrainDataFormat
+from dspy.clients.utils_finetune import MultiGPUConfig, TrainDataFormat
 from dspy.dsp.utils.settings import settings
 from dspy.utils.callback import BaseCallback
 
@@ -237,14 +237,14 @@ def thread_function_wrapper():
 
         return job
 
-    def reinforce(self, train_kwargs) -> ReinforceJob:
+    def reinforce(self, train_kwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)) -> ReinforceJob:
         # TODO(GRPO Team): Should we return an initialized job here?
         from dspy import settings as settings
 
         err = f"Provider {self.provider} does not implement the reinforcement learning interface."
         assert self.provider.reinforceable, err
 
-        job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs)
+        job = self.provider.ReinforceJob(lm=self, train_kwargs=train_kwargs, gpu_config=gpu_config)
         job.initialize()
         return job
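The GPU plan is now threaded through `LM.reinforce` as a keyword. A minimal sketch of a direct call, assuming `lm` is backed by a reinforceable provider such as `ArborProvider` and that `train_kwargs` already carries the `num_generations` key the job requires:

```python
from dspy.clients.utils_finetune import MultiGPUConfig

# Sketch: `lm` and `train_kwargs` are assumed to exist already; the provider
# must advertise reinforceable=True, or the assert inside reinforce() fires.
job = lm.reinforce(
    train_kwargs=train_kwargs,
    gpu_config=MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
)
# `job` comes back initialized; step()/save_checkpoint()/terminate() then
# drive the GRPO loop against the server.
```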

dspy/clients/lm_local_arbor.py

Lines changed: 25 additions & 16 deletions

@@ -1,13 +1,14 @@
 import time
 from datetime import datetime
 from typing import TYPE_CHECKING, Any, TypedDict
+from urllib.parse import urljoin
 
 import openai
 import requests
 
 import dspy
 from dspy.clients.provider import Provider, ReinforceJob, TrainingJob
-from dspy.clients.utils_finetune import GRPOGroup, TrainDataFormat, TrainingStatus, save_data
+from dspy.clients.utils_finetune import GRPOGroup, MultiGPUConfig, TrainDataFormat, TrainingStatus, save_data
 
 if TYPE_CHECKING:
     from dspy.clients.lm import LM
@@ -70,7 +71,7 @@ class ArborReinforceJob(ReinforceJob):
         "lora": False,
     }
 
-    def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs):
+    def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
         # The teleprompter must ensure that this is set
         if "num_generations" not in train_kwargs:
             raise ValueError("num_generations must be set in the training kwargs")
@@ -80,6 +81,7 @@ def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs):
         self.provider_job_id = None
         self.checkpoints = {}
         self.last_checkpoint = None
+        self.gpu_config = gpu_config
 
     def initialize(self):
         # TODO(GRPO Team): Set provider job ID
@@ -118,6 +120,8 @@ def initialize(self):
         api_base = self.lm.kwargs["api_base"]
 
         finetune_model = ArborProvider._remove_provider_prefix(self.lm.model)
+        # Only multi-GPU is supported for now
+        gpu_config_type = "multi"
         data = {
             "model": finetune_model,
             "num_generations": num_generations,
@@ -140,8 +144,12 @@
             "logging_steps": logging_steps,
             "max_context_length": max_context_length,
             "lora": lora,
+            "gpu_config": {
+                "type": gpu_config_type,
+                gpu_config_type: self.gpu_config,
+            },
         }
-        url = f"{api_base}fine_tuning/grpo/initialize"
+        url = urljoin(api_base, "fine_tuning/grpo/initialize")
         headers = {"Content-Type": "application/json"}
         response = requests.post(url=url, headers=headers, json=data)
         assert response.status_code == 200, f"Failed to initialize GRPO: {response}"
@@ -158,7 +166,7 @@ def _run_grpo_step_one_group(
 
         finetune_model = ArborProvider._remove_provider_prefix(self.lm.model)
         data = {"job_id": self.provider_job_id, "model": finetune_model, "batch": train_group}
-        url = f"{api_base}fine_tuning/grpo/step"
+        url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/step")
         headers = {"Content-Type": "application/json"}
         response = requests.post(url, headers=headers, json=data)
         assert response.status_code == 200, f"Failed to run a GRPO step: {response.text}"
@@ -184,7 +192,7 @@ def step(self, train_data: list[GRPOGroup], train_data_format: TrainDataFormat |
 
     def save_checkpoint(self, checkpoint_name: str, score: float | None = None):
         api_base = self.lm.kwargs["api_base"]
-        url = f"{api_base}fine_tuning/grpo/checkpoint"
+        url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/checkpoint")
         headers = {"Content-Type": "application/json"}
         body = {"job_id": self.provider_job_id, "checkpoint_name": checkpoint_name}
         response = requests.post(url, headers=headers, json=body)
@@ -203,7 +211,7 @@ def save_checkpoint(self, checkpoint_name: str, score: float | None = None):
     def terminate(self):
         api_base = self.lm.kwargs["api_base"]
 
-        url = f"{api_base}fine_tuning/grpo/terminate"
+        url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/terminate")
         headers = {"Content-Type": "application/json"}
         body = {"job_id": self.provider_job_id}
         response = requests.post(url, headers=headers, json=body)
@@ -214,14 +222,15 @@ def terminate(self):
         self.lm.model = ArborProvider._add_provider_prefix(current_model)
 
     def cancel(self):
-        if ArborProvider.does_job_exist(self.provider_job_id):
-            status = self.status()
-            if ArborProvider.is_terminal_training_status(status):
-                err_msg = "Jobs that are complete cannot be canceled."
-                err_msg += f" Job with ID {self.provider_job_id} is done."
-                raise Exception(err_msg)
-            openai.fine_tuning.jobs.cancel(self.provider_job_id)
-            self.provider_job_id = None
+        if self.provider_job_id:
+            api_base = self.lm.kwargs["api_base"]
+            url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/cancel")
+            headers = {"Content-Type": "application/json"}
+            response = requests.post(url, headers=headers)
+            if response.status_code == 200:
+                self.provider_job_id = None
+            else:
+                raise Exception(f"Failed to cancel GRPO job: {response.text}")
 
     def status(self) -> TrainingStatus:
         status = ArborProvider.get_training_status(self.provider_job_id)
@@ -245,7 +254,7 @@ def launch(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
         launch_kwargs = launch_kwargs or lm.launch_kwargs
 
         # Make request to launch endpoint
-        response = requests.post(f"{api_base}chat/launch", json={"model": model, "launch_kwargs": launch_kwargs})
+        response = requests.post(urljoin(api_base, "chat/launch"), json={"model": model, "launch_kwargs": launch_kwargs})
 
         if response.status_code != 200:
             raise Exception(f"Failed to launch model. Status code: {response.status_code}, Response: {response.text}")
@@ -257,7 +266,7 @@ def kill(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
         api_base = lm.kwargs["api_base"]
 
         response = requests.post(
-            f"{api_base}chat/kill",
+            urljoin(api_base, "chat/kill"),
         )
 
         if response.status_code != 200:
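The move from f-string concatenation to `urljoin` (the "Fix URL join" item in the commit message) is worth a note: `urljoin` appends the endpoint cleanly when `api_base` ends in a slash, while the old f-strings produced a malformed path whenever it did not. A quick illustration of the stdlib behavior:

```python
from urllib.parse import urljoin

# With a trailing slash on the base, the relative path is appended as intended:
urljoin("http://localhost:7453/v1/", "fine_tuning/grpo/initialize")
# -> 'http://localhost:7453/v1/fine_tuning/grpo/initialize'

# Caveat: without the trailing slash, urljoin *replaces* the last segment:
urljoin("http://localhost:7453/v1", "fine_tuning/grpo/initialize")
# -> 'http://localhost:7453/fine_tuning/grpo/initialize'

# The old f-string form simply glued the strings together, so a base without
# a trailing slash yielded a malformed path:
# f"{api_base}fine_tuning/grpo/initialize"
# -> 'http://localhost:7453/v1fine_tuning/grpo/initialize'
```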

dspy/clients/provider.py

Lines changed: 5 additions & 2 deletions

@@ -3,7 +3,7 @@
 from threading import Thread
 from typing import TYPE_CHECKING, Any
 
-from dspy.clients.utils_finetune import TrainDataFormat
+from dspy.clients.utils_finetune import MultiGPUConfig, TrainDataFormat
 
 if TYPE_CHECKING:
     from dspy.clients.lm import LM
@@ -36,11 +36,14 @@ def status(self):
 
 
 class ReinforceJob:
-    def __init__(self, lm: "LM", train_kwargs: dict[str, Any] | None = None):
+    def __init__(self, lm: "LM", train_kwargs: dict[str, Any] | None = None, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
         self.lm = lm
         self.train_kwargs = train_kwargs or {}
+        self.gpu_config = gpu_config
         self.checkpoints = {}
         self.last_checkpoint = None
+        self.gpu_config = gpu_config
+
 
     @abstractmethod
     def initialize(self):
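One subtlety in the new signatures: because `MultiGPUConfig` is a `TypedDict`, the default `MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)` is a plain dict built once, when the function is defined, and shared by every call that omits `gpu_config`. That is harmless so long as no job mutates it. A minimal sketch of the semantics, using a hypothetical `make_job` stand-in:

```python
# Hypothetical stand-in illustrating Python's default-argument semantics.
def make_job(gpu_config: dict = {"num_inference_gpus": 1, "num_training_gpus": 1}):
    return gpu_config

# Both calls receive the *same* dict object, evaluated at definition time,
# so mutating it in one job would silently leak into the next.
assert make_job() is make_job()
```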

dspy/clients/utils_finetune.py

Lines changed: 7 additions & 0 deletions

@@ -43,6 +43,13 @@ class GRPOChatData(TypedDict):
 GRPOGroup = list[GRPOChatData]
 
 
+class MultiGPUConfig(TypedDict):
+    # Number of GPUs to use for inference
+    num_inference_gpus: int
+    # Number of GPUs to use for training
+    num_training_gpus: int
+
+
 def infer_data_format(adapter: Adapter) -> str:
     if isinstance(adapter, dspy.ChatAdapter):
         return TrainDataFormat.CHAT
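Since `MultiGPUConfig` is a `TypedDict`, an instance is an ordinary dict at runtime, which is what lets `ArborReinforceJob.initialize` drop it straight into the JSON request body under a type tag. A sketch of the resulting fragment:

```python
from dspy.clients.utils_finetune import MultiGPUConfig

cfg = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)
assert cfg == {"num_inference_gpus": 1, "num_training_gpus": 1}  # plain dict

# initialize() nests it under a type tag ("multi" is the only type for now),
# yielding a JSON-serializable fragment like:
payload_fragment = {"gpu_config": {"type": "multi", "multi": cfg}}
```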

dspy/teleprompt/grpo.py

Lines changed: 4 additions & 2 deletions

@@ -6,7 +6,7 @@
 from dspy.adapters.base import Adapter
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.clients.lm import LM
-from dspy.clients.utils_finetune import GRPOGroup, TrainDataFormat
+from dspy.clients.utils_finetune import GRPOGroup, MultiGPUConfig, TrainDataFormat
 from dspy.dsp.utils.settings import settings
 from dspy.evaluate.evaluate import Evaluate
 from dspy.primitives.example import Example
@@ -41,6 +41,7 @@ def __init__(
         format_failure_score: float = -1,
         variably_invoked_predictor_grouping_mode: Literal["truncate"] | Literal["fill"] | Literal["ragged"] = "truncate",
         variably_invoked_predictor_fill_strategy: Literal["randint"] | Literal["max"] | None = None,
+        gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
     ):
         super().__init__(train_kwargs=train_kwargs)
         self.metric = metric
@@ -57,6 +58,7 @@ def __init__(
         self.report_train_scores = report_train_scores
         self.failure_score = failure_score
         self.format_failure_score = format_failure_score
+        self.gpu_config = gpu_config
 
         assert failure_score > format_failure_score, "failure_score must be greater than format_failure_score since the range [format_failure_score, failure_score] is used to provide dspy formatting rewards"
 
@@ -332,7 +334,7 @@ def compile(
                 job_key = (pred.lm, data_key)
                 if job_key not in grpo_training_jobs:
                     train_kwargs = self.train_kwargs[pred.lm]
-                    job = pred.lm.reinforce(train_kwargs=train_kwargs)
+                    job = pred.lm.reinforce(train_kwargs=train_kwargs, gpu_config=self.gpu_config)
                     grpo_training_jobs[job_key] = job
 
         self.report_validation_metrics(
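End to end, the GPU plan now flows from the teleprompter down to the HTTP request: `GRPO(gpu_config=...)` stores it, `compile` forwards it through `LM.reinforce`, and `ArborReinforceJob` serializes it for `fine_tuning/grpo/initialize`. A compact sketch of the hand-off, with `my_metric` and `train_kwargs` as placeholders:

```python
from dspy.clients.utils_finetune import MultiGPUConfig
from dspy.teleprompt.grpo import GRPO

compiler = GRPO(
    metric=my_metric,           # placeholder metric
    train_kwargs=train_kwargs,  # placeholder training config
    gpu_config=MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1),
)
# Inside compile(), per predictor LM:
#   job = pred.lm.reinforce(train_kwargs=..., gpu_config=self.gpu_config)
# which constructs provider.ReinforceJob(..., gpu_config=...) and posts
# {"gpu_config": {"type": "multi", "multi": <the dict>}} to the Arbor server.
```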
