@@ -370,7 +370,6 @@ def clean_output(self, output: str, prompt: str) -> str:
 
 
 class InstructConfig(InferenceConfig):
-
     def __init__(self, prompted : bool = False, instruction_tag : str = "### Instruction", response_tag : str = "### Response"):
         super().__init__(prompted=prompted)
         self.instruction_tag = instruction_tag
@@ -401,6 +400,63 @@ def format_prompt(self, prompt : str) -> str:
     def clean_output(self, output: str, prompt: str) -> str:
         return clean_instruct_output(output, prompt, self.response_tag)
 
+class QwenConfig(InferenceConfig):
+    def __init__(self, prompted : bool = False):
+        super().__init__(prompted=prompted)
+
+    def get_dtype(self):
+        return torch.float16
+
+    def init_padding(self, tokenizer):
+        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
+        tokenizer.padding_side = "left"  # for decoder-only models
+
+    def get_pad_token_id(self, tokenizer) -> int:
+        return tokenizer.eos_token_id
+
+    def get_eos_token_id(self, tokenizer) -> int:
+        return None
+
+    def trust_remote_code(self) -> bool:
+        return False
+
+    def format_prompt(self, prompt : str) -> str:
+        if self.prompted:
+            return f"// filename: solutions/solution_1.cpp\n// here is the correct implementation of the coding exercise\n\n{prompt}"
+        return prompt.strip()
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_output(output, prompt)
+
+class ChatMLConfig(InferenceConfig):
+    def __init__(self, prompted : bool = False):
+        super().__init__(prompted=prompted)
+
+    def get_dtype(self):
+        return torch.bfloat16
+
+    def init_padding(self, tokenizer):
+        tokenizer.pad_token_id = tokenizer.eos_token_id  # for batching
+        tokenizer.padding_side = "left"  # for decoder-only models
+
+    def get_pad_token_id(self, tokenizer) -> int:
+        return tokenizer.pad_token_id
+
+    def get_eos_token_id(self, tokenizer) -> int:
+        return tokenizer.eos_token_id
+
+    def trust_remote_code(self) -> bool:
+        return False
+
+    def format_prompt(self, prompt : str) -> str:
+        function_name = get_function_name(prompt, "cuda" if "__global__" in prompt else "serial")
+        prompt = f"Complete the following c++ function.\n```c++{prompt.strip()}\nWrite only the function {function_name}"
+        prompt = f"<|im_start|>system\nYou are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.<|im_end|>\n<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n"
+        return prompt
+
+    def clean_output(self, output: str, prompt: str) -> str:
+        return clean_instruct_output(output, prompt, "<|im_start|>assistant\n")
+
 def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
     if model_name == "bigcode/starcoderbase":
         return StarCoderConfig(**kwargs)
@@ -422,6 +478,12 @@ def get_inference_config(model_name : str, **kwargs) -> InferenceConfig:
         return InstructConfig(instruction_tag='Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:', response_tag='### Response:', **kwargs)
     elif model_name.startswith('hpcgroup/rlpf'):
         return InstructConfig(instruction_tag='### Instruction', response_tag='### Response', **kwargs)
+    elif model_name.startswith('Qwen/Qwen2.5') and 'Instruct' in model_name:
+        return ChatMLConfig(**kwargs)
+    elif model_name.startswith('Qwen/Qwen3'):
+        return ChatMLConfig(**kwargs)
+    elif model_name.startswith('Qwen/Qwen2.5'):
+        return QwenConfig(**kwargs)
     else:
         raise ValueError(f"Unknown model name: {model_name}")
 
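For context, a minimal sketch of how these configs are exercised downstream. The generation loop below is illustrative and not part of this diff; it assumes the standard Hugging Face `AutoTokenizer`/`AutoModelForCausalLM` API, and the model name is just an example that routes to `ChatMLConfig`.

```python
# Illustrative driver (not part of the diff): assumes get_inference_config and
# the config classes above are importable, plus Hugging Face transformers.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "Qwen/Qwen2.5-Coder-7B-Instruct"  # example; matches the ChatMLConfig branch
config = get_inference_config(model_name, prompted=True)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=config.trust_remote_code())
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=config.get_dtype(),
    trust_remote_code=config.trust_remote_code(),
)
config.init_padding(tokenizer)

prompt = config.format_prompt("float sum(const std::vector<float> &v) {")
inputs = tokenizer(prompt, return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=256,
    pad_token_id=config.get_pad_token_id(tokenizer),
    eos_token_id=config.get_eos_token_id(tokenizer),
)
# Keep special tokens so clean_output can locate the "<|im_start|>assistant\n" tag.
text = tokenizer.decode(outputs[0], skip_special_tokens=False)
print(config.clean_output(text, prompt))
```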