|
18 | 18 | min_learning_rate=1e-6, |
19 | 19 | full_training=FinetuneFullTrainingLimits( |
20 | 20 | max_batch_size=96, |
| 21 | + max_batch_size_dpo=48, |
21 | 22 | min_batch_size=8, |
22 | 23 | ), |
23 | 24 | lora_training=FinetuneLoraTrainingLimits( |
24 | 25 | max_batch_size=128, |
| 26 | + max_batch_size_dpo=64, |
25 | 27 | min_batch_size=8, |
26 | 28 | max_rank=64, |
27 | 29 | target_modules=["q", "k", "v", "o", "mlp"], |
@@ -83,6 +85,40 @@ def test_lora_request(): |
83 | 85 | assert request.batch_size == _MODEL_LIMITS.lora_training.max_batch_size |
84 | 86 |
|
85 | 87 |
|
def test_dpo_request_lora():
    """DPO fine-tuning with LoRA enabled should pick the LoRA-specific
    DPO batch-size cap and the maximal LoRA hyperparameters."""
    request = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
        training_method="dpo",
        lora=True,
    )

    lora_limits = _MODEL_LIMITS.lora_training
    training_type = request.training_type

    assert training_type.type == "Lora"
    assert training_type.lora_r == lora_limits.max_rank
    # Convention: alpha defaults to twice the rank.
    assert training_type.lora_alpha == lora_limits.max_rank * 2
    assert training_type.lora_dropout == 0.0
    assert training_type.lora_trainable_modules == "all-linear"
    # DPO uses its own (smaller) batch-size ceiling.
    assert request.batch_size == lora_limits.max_batch_size_dpo
| 104 | + |
def test_dpo_request():
    """DPO fine-tuning with ``lora=False`` should produce full-weight
    training capped at the full-training DPO batch size."""
    request = create_finetune_request(
        model_limits=_MODEL_LIMITS,
        model=_MODEL_NAME,
        training_file=_TRAINING_FILE,
        training_method="dpo",
        lora=False,
    )

    # lora=False requests full-weight training: the training type must be
    # "Full" and there are no LoRA hyperparameters to inspect.  (The previous
    # "Lora" assertions here were copy-pasted from test_dpo_request_lora and
    # contradicted the full_training batch-size check below.)
    assert request.training_type.type == "Full"
    assert request.batch_size == _MODEL_LIMITS.full_training.max_batch_size_dpo
| 121 | + |
86 | 122 | def test_from_checkpoint_request(): |
87 | 123 | request = create_finetune_request( |
88 | 124 | model_limits=_MODEL_LIMITS, |
|
0 commit comments