|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD 3-Clause license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +import unittest |
| 7 | +import warnings |
| 8 | + |
| 9 | +import torch |
| 10 | +from torch.testing._internal import common_utils |
| 11 | +from torch.testing._internal.common_utils import ( |
| 12 | + TestCase, |
| 13 | + run_tests, |
| 14 | +) |
| 15 | +from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig |
| 16 | + |
| 17 | +from torchao.utils import is_fbcode, is_sm_at_least_89, is_sm_at_least_90 |
| 18 | + |
| 19 | +# please check model card for how to generate these models |
| 20 | + |
| 21 | +_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [ |
| 22 | + # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev |
| 23 | + "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev" |
| 24 | +] |
| 25 | + |
| 26 | +_DEPRECATED_MODEL_INFO = [ |
| 27 | + # model card: https://huggingface.co/torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev |
| 28 | + ( |
| 29 | + "torchao-testing/opt-125m-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev", |
| 30 | + 1, |
| 31 | + "Float8DynamicActivationFloat8WeightConfig", |
| 32 | + ), |
| 33 | +] |
| 34 | + |
| 35 | +_SINGLE_LINEAR_MODEL_NAMES = [ |
| 36 | + # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev |
| 37 | + "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev", |
| 38 | + # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev |
| 39 | + "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev", |
| 40 | + # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev |
| 41 | + "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev", |
| 42 | +] |
| 43 | + |
| 44 | + |
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
@unittest.skipIf(
    is_fbcode(),
    "Skipping the test in fbcode for now, not sure how to download from transformers",
)
class TestLoadAndRunCheckpoint(TestCase):
    """Backward-compatibility tests for quantized checkpoints stored on the Hub.

    Each test downloads a previously saved checkpoint together with its saved
    inputs and reference outputs, runs the loaded model, and requires the
    result to be exactly equal to the reference.
    """

    @staticmethod
    def _load_hub_artifact(model_name, filename):
        """Download ``filename`` from the Hub repo ``model_name`` and ``torch.load`` it.

        Args:
            model_name: Hugging Face Hub repository id.
            filename: name of the file inside that repository.

        Returns:
            Whatever object ``torch.load`` deserializes from the file.
        """
        from huggingface_hub import hf_hub_download

        local_path = hf_hub_download(model_name, filename=filename)
        with open(local_path, "rb") as f:
            # NOTE(review): torch.load deserializes pickled data; only point
            # this helper at trusted repositories.
            return torch.load(f)

    def _test_single_linear_helper(self, model_name):
        """Load a single-linear checkpoint, run it, and compare with the saved output.

        Args:
            model_name: Hub repository id providing ``model.pt``,
                ``model_inputs.pt`` and ``model_output.pt``.
        """
        with torch.device("meta"):
            # 32 and 256 are the args used when the model was saved; see the
            # model card of the checkpoint being loaded.
            # No explicit ``device=`` here: an explicit device would override
            # the meta-device context and allocate real weights that are
            # discarded by ``load_state_dict(assign=True)`` right below.
            model = torch.nn.Sequential(
                torch.nn.Linear(32, 256, dtype=torch.bfloat16)
            )
        # assign=True swaps the loaded tensors in as the module's parameters
        # instead of copying them into the placeholder parameters.
        model.load_state_dict(
            self._load_hub_artifact(model_name, "model.pt"), assign=True
        )

        example_inputs = self._load_hub_artifact(model_name, "model_inputs.pt")
        ref_output = self._load_hub_artifact(model_name, "model_output.pt")

        output = model(*example_inputs)
        self.assertTrue(torch.equal(output, ref_output))

    @common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
    def test_deprecated_single_linear(self, model_name):
        """Deprecated single-linear checkpoints must still load and reproduce their output."""
        self._test_single_linear_helper(model_name)

    @common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
    def test_single_linear(self, model_name):
        """Test that we can load and run the quantized linear checkpoint with saved sample input
        and match the saved output, to make sure there is no BC breaking changes
        when we make changes to tensor subclass implementations
        """
        if (
            "Float8DynamicActivationFloat8WeightConfig" in model_name
            and not is_sm_at_least_90()
        ):
            # skipTest raises SkipTest so the run is reported as skipped;
            # returning ``unittest.skip(...)`` would only build a decorator
            # and let the test count as passed.
            self.skipTest("FP8 checkpoint is produced in SM90+")

        self._test_single_linear_helper(model_name)

    @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
    def test_deprecated_hf_models(self, model_info):
        """Test that we print correct warning message when loading a deprecated checkpoint
        and making sure the deprecated checkpoints can still be loaded
        """
        # Load and quantize model
        model_name, version, config_name = model_info
        with warnings.catch_warnings(record=True) as caught_warnings:
            # Record every warning, including ones already emitted earlier in
            # this process that the default filters would otherwise suppress.
            warnings.simplefilter("always")
            quantized_model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="bfloat16",
                device_map="cuda:0",
            )
        warning_messages = [str(w.message) for w in caught_warnings]

        self.assertTrue(
            any(
                "Stored version is not the same as current default version of the config"
                in message
                for message in warning_messages
            ),
            "Didn't get expected warning message for version mismatch",
        )
        self.assertTrue(
            any(
                f"Models quantized with version 1 of {config_name} is deprecated"
                in message
                for message in warning_messages
            ),
            "Didn't get expected warning message for deprecation",
        )
        self.assertIsInstance(quantized_model.config.quantization_config, TorchAoConfig)
        self.assertEqual(
            quantized_model.config.quantization_config.quant_type.version, version
        )

        tokenizer = AutoTokenizer.from_pretrained(model_name)
        prompt = self._load_hub_artifact(model_name, "model_prompt.pt")

        inputs = tokenizer(
            prompt,
            return_tensors="pt",
        ).to("cuda")
        generated_ids = quantized_model.generate(
            **inputs, max_new_tokens=128, temperature=0
        )

        ref_generated_ids = self._load_hub_artifact(model_name, "model_output.pt")
        self.assertTrue(torch.equal(generated_ids, ref_generated_ids))

        # make sure can successfully decode
        _ = tokenizer.batch_decode(
            generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
        )
| 155 | + |
| 156 | + |
# Turn the @parametrize-decorated methods of the test class into concrete
# test methods so the test runner can discover them.
common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)

# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()
0 commit comments