11import time
22from datetime import datetime
33from typing import TYPE_CHECKING , Any , TypedDict
4+ from urllib.parse import urljoin
45
56import openai
67import requests
78
89import dspy
910from dspy .clients .provider import Provider , ReinforceJob , TrainingJob
10- from dspy.clients.utils_finetune import GRPOGroup, TrainDataFormat, TrainingStatus, save_data
11+ from dspy.clients.utils_finetune import GRPOGroup, MultiGPUConfig, TrainDataFormat, TrainingStatus, save_data
1112
1213if TYPE_CHECKING :
1314 from dspy .clients .lm import LM
@@ -70,7 +71,7 @@ class ArborReinforceJob(ReinforceJob):
7071 "lora" : False ,
7172 }
7273
73- def __init__ (self , lm : "LM" , train_kwargs : GRPOTrainKwargs ):
74+    def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs, gpu_config: MultiGPUConfig = MultiGPUConfig(num_inference_gpus=1, num_training_gpus=1)):
7475 # The teleprompter must ensure that this is set
7576 if "num_generations" not in train_kwargs :
7677 raise ValueError ("num_generations must be set in the training kwargs" )
@@ -80,6 +81,7 @@ def __init__(self, lm: "LM", train_kwargs: GRPOTrainKwargs):
8081 self .provider_job_id = None
8182 self .checkpoints = {}
8283 self .last_checkpoint = None
84+ self .gpu_config = gpu_config
8385
8486 def initialize (self ):
8587 # TODO(GRPO Team): Set provider job ID
@@ -118,6 +120,8 @@ def initialize(self):
118120 api_base = self .lm .kwargs ["api_base" ]
119121
120122 finetune_model = ArborProvider ._remove_provider_prefix (self .lm .model )
123+ # Only multi-GPU is supported for now
124+ gpu_config_type = "multi"
121125 data = {
122126 "model" : finetune_model ,
123127 "num_generations" : num_generations ,
@@ -140,8 +144,12 @@ def initialize(self):
140144 "logging_steps" : logging_steps ,
141145 "max_context_length" : max_context_length ,
142146 "lora" : lora ,
147+ "gpu_config" : {
148+ "type" : gpu_config_type ,
149+ gpu_config_type : self .gpu_config ,
150+ },
143151 }
144-        url = f"{api_base}fine_tuning/grpo/initialize"
152+        url = urljoin(api_base, "fine_tuning/grpo/initialize")
145153 headers = {"Content-Type" : "application/json" }
146154 response = requests .post (url = url , headers = headers , json = data )
147155 assert response .status_code == 200 , f"Failed to initialize GRPO: { response } "
@@ -158,7 +166,7 @@ def _run_grpo_step_one_group(
158166
159167 finetune_model = ArborProvider ._remove_provider_prefix (self .lm .model )
160168 data = {"job_id" : self .provider_job_id , "model" : finetune_model , "batch" : train_group }
161-        url = f"{api_base}fine_tuning/grpo/step"
169+        url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/step")
162170 headers = {"Content-Type" : "application/json" }
163171 response = requests .post (url , headers = headers , json = data )
164172 assert response .status_code == 200 , f"Failed to run a GRPO step: { response .text } "
@@ -184,7 +192,7 @@ def step(self, train_data: list[GRPOGroup], train_data_format: TrainDataFormat |
184192
185193 def save_checkpoint (self , checkpoint_name : str , score : float | None = None ):
186194 api_base = self .lm .kwargs ["api_base" ]
187-        url = f"{api_base}fine_tuning/grpo/checkpoint"
195+        url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/checkpoint")
188196 headers = {"Content-Type" : "application/json" }
189197 body = {"job_id" : self .provider_job_id , "checkpoint_name" : checkpoint_name }
190198 response = requests .post (url , headers = headers , json = body )
@@ -203,7 +211,7 @@ def save_checkpoint(self, checkpoint_name: str, score: float | None = None):
203211 def terminate (self ):
204212 api_base = self .lm .kwargs ["api_base" ]
205213
206-        url = f"{api_base}fine_tuning/grpo/terminate"
214+        url = urljoin(api_base, f"fine_tuning/grpo/{self.provider_job_id}/terminate")
207215 headers = {"Content-Type" : "application/json" }
208216 body = {"job_id" : self .provider_job_id }
209217 response = requests .post (url , headers = headers , json = body )
@@ -214,14 +222,15 @@ def terminate(self):
214222 self .lm .model = ArborProvider ._add_provider_prefix (current_model )
215223
216224 def cancel (self ):
217- if ArborProvider .does_job_exist (self .provider_job_id ):
218- status = self .status ()
219- if ArborProvider .is_terminal_training_status (status ):
220- err_msg = "Jobs that are complete cannot be canceled."
221- err_msg += f" Job with ID { self .provider_job_id } is done."
222- raise Exception (err_msg )
223- openai .fine_tuning .jobs .cancel (self .provider_job_id )
224- self .provider_job_id = None
225+ if self .provider_job_id :
226+ api_base = self .lm .kwargs ["api_base" ]
227+ url = urljoin (api_base , f"fine_tuning/grpo/{ self .provider_job_id } /cancel" )
228+ headers = {"Content-Type" : "application/json" }
229+ response = requests .post (url , headers = headers )
230+ if response .status_code == 200 :
231+ self .provider_job_id = None
232+ else :
233+ raise Exception (f"Failed to cancel GRPO job: { response .text } " )
225234
226235 def status (self ) -> TrainingStatus :
227236 status = ArborProvider .get_training_status (self .provider_job_id )
@@ -245,7 +254,7 @@ def launch(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
245254 launch_kwargs = launch_kwargs or lm .launch_kwargs
246255
247256 # Make request to launch endpoint
248-        response = requests.post(f"{api_base}chat/launch", json={"model": model, "launch_kwargs": launch_kwargs})
257+        response = requests.post(urljoin(api_base, "chat/launch"), json={"model": model, "launch_kwargs": launch_kwargs})
249258
250259 if response .status_code != 200 :
251260 raise Exception (f"Failed to launch model. Status code: { response .status_code } , Response: { response .text } " )
@@ -257,7 +266,7 @@ def kill(lm: "LM", launch_kwargs: dict[str, Any] | None = None):
257266 api_base = lm .kwargs ["api_base" ]
258267
259268 response = requests .post (
260-            f"{api_base}chat/kill",
269+            urljoin(api_base, "chat/kill"),
261270 )
262271
263272 if response .status_code != 200 :
0 commit comments