|
2 | 2 |
|
3 | 3 | from together.resources.finetune import create_finetune_request |
4 | 4 | from together.types.finetune import ( |
5 | | - FinetuneTrainingLimits, |
6 | 5 | FinetuneFullTrainingLimits, |
7 | 6 | FinetuneLoraTrainingLimits, |
| 7 | + FinetuneTrainingLimits, |
8 | 8 | ) |
9 | 9 |
|
10 | 10 |
|
@@ -117,50 +117,36 @@ def test_no_from_checkpoint_no_model_name(): |
117 | 117 | ) |
118 | 118 |
|
119 | 119 |
|
120 | | -def test_batch_size_limit(): |
121 | | - with pytest.raises( |
122 | | - ValueError, |
123 | | - match="Requested batch size is higher that the maximum allowed value", |
124 | | - ): |
125 | | - _ = create_finetune_request( |
126 | | - model_limits=_MODEL_LIMITS, |
127 | | - model=_MODEL_NAME, |
128 | | - training_file=_TRAINING_FILE, |
129 | | - batch_size=128, |
130 | | - ) |
131 | | - |
132 | | - with pytest.raises( |
133 | | - ValueError, match="Requested batch size is lower that the minimum allowed value" |
134 | | - ): |
135 | | - _ = create_finetune_request( |
136 | | - model_limits=_MODEL_LIMITS, |
137 | | - model=_MODEL_NAME, |
138 | | - training_file=_TRAINING_FILE, |
139 | | - batch_size=1, |
140 | | - ) |
141 | | - |
142 | | - with pytest.raises( |
143 | | - ValueError, |
144 | | - match="Requested batch size is higher that the maximum allowed value", |
145 | | - ): |
146 | | - _ = create_finetune_request( |
147 | | - model_limits=_MODEL_LIMITS, |
148 | | - model=_MODEL_NAME, |
149 | | - training_file=_TRAINING_FILE, |
150 | | - batch_size=256, |
151 | | - lora=True, |
152 | | - ) |
153 | | - |
154 | | - with pytest.raises( |
155 | | - ValueError, match="Requested batch size is lower that the minimum allowed value" |
156 | | - ): |
157 | | - _ = create_finetune_request( |
158 | | - model_limits=_MODEL_LIMITS, |
159 | | - model=_MODEL_NAME, |
160 | | - training_file=_TRAINING_FILE, |
161 | | - batch_size=1, |
162 | | - lora=True, |
163 | | - ) |
@pytest.mark.parametrize("batch_size", [256, 1])
@pytest.mark.parametrize("use_lora", [False, True])
def test_batch_size_limit(batch_size, use_lora):
    """Out-of-range batch sizes are rejected for both full and LoRA training.

    Every parametrized ``batch_size`` is expected to fall outside the
    ``[min_batch_size, max_batch_size]`` window of the selected limits, so a
    ``ValueError`` with the matching message must be raised.
    """
    # Pick the limits object that applies to the requested training mode.
    limits = _MODEL_LIMITS.lora_training if use_lora else _MODEL_LIMITS.full_training
    max_batch_size = limits.max_batch_size
    min_batch_size = limits.min_batch_size

    # NOTE: the "higher that"/"lower that" wording mirrors the library's actual
    # error text (pytest.raises matches the real message, typo included).
    if batch_size > max_batch_size:
        error_message = f"Requested batch size of {batch_size} is higher that the maximum allowed value of {max_batch_size}"
    elif batch_size < min_batch_size:
        error_message = f"Requested batch size of {batch_size} is lower that the minimum allowed value of {min_batch_size}"
    else:
        # Guard against a vacuous pass: previously an in-range batch_size made
        # the test body a no-op and the test silently "passed" without ever
        # calling create_finetune_request.
        pytest.fail(
            f"batch_size={batch_size} lies within [{min_batch_size}, "
            f"{max_batch_size}]; this test requires an out-of-range value"
        )

    # Single call site instead of one duplicated per branch.
    with pytest.raises(ValueError, match=error_message):
        _ = create_finetune_request(
            model_limits=_MODEL_LIMITS,
            model=_MODEL_NAME,
            training_file=_TRAINING_FILE,
            batch_size=batch_size,
            lora=use_lora,
        )
164 | 150 |
|
165 | 151 |
|
166 | 152 | def test_non_lora_model(): |
|
0 commit comments