Skip to content

Commit 40bf04c

Browse files
committed
Add load and run tests for checkpoints that we want to have BC
Summary: Added load-and-run tests to make sure previously saved checkpoints continue to load and run. Includes FP8, INT4, and INT4 + preshuffled checkpoints, since these might reach a larger audience. Test Plan: python test/integration/test_load_and_run_checkpoint.py Reviewers: Subscribers: Tasks: Tags: stack-info: PR: #2792, branch: jerryzh168/stack/28
1 parent 751d7f6 commit 40bf04c

File tree

2 files changed

+160
-77
lines changed

2 files changed

+160
-77
lines changed
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
import unittest
7+
import warnings
8+
9+
import torch
10+
from torch.testing._internal import common_utils
11+
from torch.testing._internal.common_utils import (
12+
TestCase,
13+
run_tests,
14+
)
15+
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
16+
17+
from torchao.utils import is_fbcode, is_sm_at_least_89, is_sm_at_least_90
18+
19+
# please check model card for how to generate these models
20+
21+
_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [
22+
# model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
23+
"torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
24+
]
25+
26+
_DEPRECATED_MODEL_INFO = [
27+
# model card: https://huggingface.co/torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
28+
(
29+
"torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
30+
1,
31+
"Float8DynamicActivationFloat8WeightConfig",
32+
),
33+
]
34+
35+
_SINGLE_LINEAR_MODEL_NAMES = [
36+
# model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
37+
"torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
38+
# model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
39+
"torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
40+
# model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
41+
"torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
42+
]
43+
44+
45+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(
    is_fbcode(),
    "Skipping the test in fbcode for now, not sure how to download from transformers",
)
class TestLoadAndRunCheckpoint(TestCase):
    """Backward-compatibility tests for previously saved quantized checkpoints.

    Each test downloads a checkpoint (plus saved inputs and a saved reference
    output) from the Hugging Face Hub, loads it, runs it, and checks that the
    result is exactly equal to the saved reference.
    """

    def _test_single_linear_helper(self, model_name):
        """Load a single-linear checkpoint, run it, and compare to the reference.

        Args:
            model_name: Hugging Face Hub repo id that provides ``model.pt``
                (state dict), ``model_inputs.pt`` (example inputs) and
                ``model_output.pt`` (reference output).
        """
        from huggingface_hub import hf_hub_download

        downloaded_model = hf_hub_download(model_name, filename="model.pt")

        # Load model weights, example inputs and reference output,
        # run the loaded model and make sure the result matches reference output
        with torch.device("meta"):
            # 32 and 256 are the args we used when we save the model, see
            # model card:
            # https://huggingface.co/torchao-testing/single-linear-FP8-v2-0.13-dev
            # NOTE(review): the explicit device="cuda" appears to take
            # precedence over the surrounding "meta" device context -- confirm
            # which of the two is intended.
            model = torch.nn.Sequential(
                torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
            )
        with open(downloaded_model, "rb") as f:
            # assign=True makes the module adopt the loaded tensors instead of
            # copying their values into the freshly created parameters.
            model.load_state_dict(torch.load(f), assign=True)

        downloaded_example_inputs = hf_hub_download(
            model_name, filename="model_inputs.pt"
        )
        with open(downloaded_example_inputs, "rb") as f:
            example_inputs = torch.load(f)

        downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
        with open(downloaded_output, "rb") as f:
            ref_output = torch.load(f)

        output = model(*example_inputs)
        self.assertTrue(torch.equal(output, ref_output))

    @common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
    def test_deprecated_single_linear(self, model_name):
        """Test that deprecated single-linear checkpoints still load and run."""
        self._test_single_linear_helper(model_name)

    @common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
    def test_single_linear(self, model_name):
        """Test that a quantized linear checkpoint loads, runs on the saved
        sample input and matches the saved output, to make sure there are no
        BC-breaking changes when we change tensor subclass implementations.
        """
        if (
            "Float8DynamicActivationFloat8WeightConfig" in model_name
            and not is_sm_at_least_90()
        ):
            # `return unittest.skip(...)` would only build a decorator and
            # report the test as passed; skipTest() records a real skip.
            self.skipTest("FP8 checkpoint is produced in SM90+")

        self._test_single_linear_helper(model_name)

    @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
    def test_deprecated_hf_models(self, model_info):
        """Test that loading a deprecated checkpoint emits the expected warning
        messages and that the deprecated checkpoint can still be loaded and run.

        Args:
            model_info: ``(model_name, version, config_name)`` tuple.
        """
        from huggingface_hub import hf_hub_download

        model_name, version, config_name = model_info

        # Load the quantized model while recording every warning it emits.
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Without "always", a warning already raised once from the same
            # location may be filtered out and never recorded.
            warnings.simplefilter("always")
            quantized_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="bfloat16",
                device_map="cuda:0",
            )

        warning_messages = [str(w.message) for w in caught_warnings]
        self.assertTrue(
            any(
                "Stored version is not the same as current default version of the config"
                in message
                for message in warning_messages
            ),
            "Didn't get expected warning message for version mismatch",
        )
        self.assertTrue(
            any(
                f"Models quantized with version {version} of {config_name} is deprecated"
                in message
                for message in warning_messages
            ),
            "Didn't get expected warning message for deprecation",
        )

        quantization_config = quantized_model.config.quantization_config
        self.assertIsInstance(quantization_config, TorchAoConfig)
        self.assertEqual(quantization_config.quant_type.version, version)

        tokenizer = AutoTokenizer.from_pretrained(model_name)

        downloaded_example_inputs = hf_hub_download(
            model_name, filename="model_prompt.pt"
        )
        with open(downloaded_example_inputs, "rb") as f:
            prompt = torch.load(f)

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
        ).to("cuda")
        generated_ids = quantized_model.generate(
            **inputs, max_new_tokens=128, temperature=0
        )

        downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
        with open(downloaded_output, "rb") as f:
            ref_generated_ids = torch.load(f)

        self.assertTrue(torch.equal(generated_ids, ref_generated_ids))

        # make sure can successfully decode
        _ = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
156+
157+
# Generate the concrete test methods for every @common_utils.parametrize
# decorated method on the class; this must run after the class is defined.
common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()

test/integration/test_loading_deprecated_checkpoint.py

Lines changed: 0 additions & 77 deletions
This file was deleted.

0 commit comments

Comments
 (0)