-
Notifications
You must be signed in to change notification settings - Fork 51
[QEff. Finetuning]: Enhance test cases to match intermediate step level loss/metrics #531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
c9ada4a
to
084bb38
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, just address minor comments and we are good to merge.
d76ef7f
to
c31b88a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks Ann for making this change.
tests/finetune/test_finetune.py
Outdated
f"{name} length mismatch for scenario '{scenario_key}' (WS: {current_world_size}, Rank: {current_rank}). " | ||
f"Expected {len(ref_list)} elements, but got {len(actual_list)}." | ||
) | ||
max_diff = np.max(np.abs(np.array(ref_list) - np.array(actual_list))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In case of mismatch, it will report the max diff. It should instead report: 1) The step numbers at which deviation is happening, 2) diff in value at each of these steps. np.isclose() will help in getting the deviated indices. Before this, np.allclose() can be used to check if the assertion is passing or not.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added step wise details for deviation
REFERENCE_DATA = { | ||
# Scenario 1: Single-device llama 3.2-1B training on Alpaca dataset. | ||
"llama_config_alpaca_single_device": { | ||
"description": "Baseline for Llama on Alpaca single-device", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add the complete model ID here and in other configs as well.
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
Signed-off-by: Ann Kuruvilla <[email protected]>
4b6e21d
to
77106f6
Compare
Enable test cases for Intermediate step level loss/metric matching in single and DDP set up.
Nested dictionary structure for mapping the reference losses at different test scenarios. The test scenarios with the ref values are listed in a separate reference file.
The test scenarios at present include single device testing for below models:
Llama, Bert on Alpaca and GSM8k dataset.
REFERNCE DATA based on SDK - 1.21.0.23