Skip to content

Commit c9ada4a

Browse files
committed
Format
Signed-off-by: Ann Kuruvilla <[email protected]>
1 parent 9a980c3 commit c9ada4a

File tree

4 files changed

+100
-21
lines changed

4 files changed

+100
-21
lines changed

QEfficient/finetune/utils/train_utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,6 @@ def train(
358358
logger.log_rank_zero(
359359
f"Epoch {epoch + 1}: Train epoch loss: {train_epoch_loss:.4f}, Train metric: {train_epoch_metric:.4f}, Epoch time {epoch_end_time:.2f} sec"
360360
)
361-
breakpoint()
362361
# Saving the results every epoch to plot later
363362
if train_config.save_metrics:
364363
save_to_json(

QEfficient/utils/constants.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
# Minimum value for causal mask
3232
MIN_MASKED_ATTENTION_VALUE = float("-inf")
3333

34+
# Finetuning
35+
LOSS_ATOL = 1e-3
36+
METRIC_ATOL = 1e-3
37+
3438

3539
# Store the qeff_models inside the ~/.cache directory or over-ride with an env variable.
3640
def get_models_dir():

tests/finetune/reference_data.py

Lines changed: 56 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Scenario 1: Single-device llama training on Alpaca dataset.
44
"llama_config_alpaca_single_device": {
55
"description": "Baseline for Llama on Alpaca single-device",
6-
"train_step_losses": [
6+
"train_step_losses": [
77
1.5112206935882568,
88
1.2211230993270874,
99
1.9942185878753662,
@@ -106,7 +106,7 @@
106106
1.4072850942611694,
107107
1.374159812927246,
108108
],
109-
"train_step_metrics": [
109+
"train_step_metrics": [
110110
9.490362167358398,
111111
10.207969665527344,
112112
6.944809913635254,
@@ -140,16 +140,62 @@
140140
3.951754093170166,
141141
],
142142
},
143-
144143
# Scenario 3: Single-device Bert training on IMDB dataset.
145144
"bert_config_imdb_single_device": {
146145
"description": "Baseline for BERT on IMDB single-device",
147146
"train_step_losses": [
148-
0.390625, 0.51220703125, 0.9208984375, 0.4052734375, 1.1640625, 0.6533203125, 0.5087890625, 0.76171875, 0.63525390625, 0.50146484375, 0.5439453125, 0.947265625, 0.89013671875, 0.80419921875, 0.6533203125, 0.4580078125, 0.92041015625, 0.7412109375, 0.7197265625
147+
0.390625,
148+
0.51220703125,
149+
0.9208984375,
150+
0.4052734375,
151+
1.1640625,
152+
0.6533203125,
153+
0.5087890625,
154+
0.76171875,
155+
0.63525390625,
156+
0.50146484375,
157+
0.5439453125,
158+
0.947265625,
159+
0.89013671875,
160+
0.80419921875,
161+
0.6533203125,
162+
0.4580078125,
163+
0.92041015625,
164+
0.7412109375,
165+
0.7197265625,
166+
],
167+
"eval_step_losses": [
168+
0.55126953125,
169+
0.7421875,
170+
0.86572265625,
171+
0.64501953125,
172+
0.65234375,
173+
0.60302734375,
174+
0.638671875,
175+
0.8232421875,
176+
0.6611328125,
177+
0.6240234375,
149178
],
150-
"eval_step_losses": [0.55126953125, 0.7421875, 0.86572265625, 0.64501953125, 0.65234375, 0.60302734375, 0.638671875, 0.8232421875, 0.6611328125, 0.6240234375],
151179
"train_step_metrics": [
152-
1.0, 1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.625, 0.625, 0.625, 0.5999755859375, 0.58331298828125, 0.5714111328125, 0.5714111328125, 0.5714111328125, 0.5625, 0.5555419921875, 0.5054931640625
180+
1.0,
181+
1.0,
182+
0.5,
183+
0.5,
184+
0.5,
185+
0.5,
186+
0.5,
187+
0.5,
188+
0.625,
189+
0.625,
190+
0.625,
191+
0.5999755859375,
192+
0.58331298828125,
193+
0.5714111328125,
194+
0.5714111328125,
195+
0.5714111328125,
196+
0.5625,
197+
0.5555419921875,
198+
0.5054931640625,
153199
],
154200
"eval_step_metrics": [1.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0],
155201
},
@@ -159,23 +205,15 @@
159205
"world_size": 2,
160206
"rank_data": {
161207
0: { # Data for Rank 0
162-
"train_step_losses": [
163-
164-
],
208+
"train_step_losses": [],
165209
"eval_step_losses": [],
166-
"train_step_metrics": [
167-
168-
],
210+
"train_step_metrics": [],
169211
"eval_step_metrics": [],
170212
},
171213
1: { # Data for Rank 1
172-
"train_step_losses": [
173-
174-
],
214+
"train_step_losses": [],
175215
"eval_step_losses": [],
176-
"train_step_metrics": [
177-
178-
],
216+
"train_step_metrics": [],
179217
"eval_step_metrics": [],
180218
},
181219
},

tests/finetune/test_finetune.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616

1717
import QEfficient
1818
import QEfficient.cloud.finetune
19-
from . import reference_data as ref_data
2019
from QEfficient.cloud.finetune import main as finetune
2120
from QEfficient.finetune.utils.helper import Device, Task_Mode
21+
from QEfficient.utils import constants as constant
2222

23-
LOSS_ATOL = 0.02
23+
from . import reference_data as ref_data
2424

2525
alpaca_json_path = os.path.join(os.getcwd(), "alpaca_data.json")
2626

@@ -210,6 +210,44 @@ def test_finetune(
210210
# "Eval metric is not matching."
211211
# )
212212

213+
# Assertions for step-level values using the helper function
214+
assert_list_close(
215+
ref_train_losses,
216+
results["train_step_loss"],
217+
constant.LOSS_ATOL,
218+
"Train Step Losses",
219+
scenario_key,
220+
current_world_size,
221+
current_rank,
222+
)
223+
assert_list_close(
224+
ref_eval_losses,
225+
results["eval_step_loss"],
226+
constant.LOSS_ATOL,
227+
"Eval Step Losses",
228+
scenario_key,
229+
current_world_size,
230+
current_rank,
231+
)
232+
assert_list_close(
233+
ref_train_metrics,
234+
results["train_step_metric"],
235+
constant.METRIC_ATOL,
236+
"Train Step Metrics",
237+
scenario_key,
238+
current_world_size,
239+
current_rank,
240+
)
241+
assert_list_close(
242+
ref_eval_metrics,
243+
results["eval_step_metric"],
244+
constant.METRIC_ATOL,
245+
"Eval Step Metrics",
246+
scenario_key,
247+
current_world_size,
248+
current_rank,
249+
)
250+
213251
assert results["avg_epoch_time"] < 60, "Training should complete within 60 seconds."
214252

215253
train_config_spy.assert_called_once()

0 commit comments

Comments
 (0)