
Conversation

Contributor

@namgyu-youn namgyu-youn commented Aug 11, 2025

Summary

Add SmoothQuantConfig as the base config and SmoothQuantObserver for computing the smoothing factor. Apply corresponding changes in other parts of the SmoothQuant API flow.

Test Plan

Unit tests and a real run of example.py with Llama-2-7b-chat-hf, covering both quantization and model saving.

Future Plan

Build a benchmark within the vLLM ecosystem for AWQ and SmoothQuant. See #2815 for more info

Summary:
- Added SmoothQuantConfig as a base config and made corresponding changes in other parts of the flow

Test Plan:
- Qwen3-8B with example.py and unit tests
- Additional test plans required

ETC
- Fix typo in README.md for SmoothQuant

pytorch-bot bot commented Aug 11, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2728

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 11, 2025
@namgyu-youn namgyu-youn marked this pull request as draft August 12, 2025 06:53
@namgyu-youn namgyu-youn marked this pull request as ready for review August 12, 2025 08:01
@namgyu-youn
Contributor Author

@jerryzh168 Could you please look into this PR? It was inspired by #2659 (comment), which called for a more generalized SmoothQuant API.

@jerryzh168
Contributor

jerryzh168 commented Aug 15, 2025

Thanks @namgyu-youn, this is a step towards that but not fully general yet; it seems to be a quick change to add it, though. Commented inline.

Also, it seems SmoothQuant is not very popular at the moment (https://huggingface.co/models?search=smoothquant), so I'd like to wait a bit before we invest more effort into it. Let me know if you are interested in contributing more to torchao; I think we have many higher-priority issues you could help with.

@jerryzh168 jerryzh168 self-requested a review August 15, 2025 17:53
@namgyu-youn
Contributor Author

> Thanks @namgyu-youn, this is a step towards that but not fully general yet; it seems to be a quick change to add it, though. Commented inline.
>
> Also, it seems SmoothQuant is not very popular at the moment (https://huggingface.co/models?search=smoothquant), so I'd like to wait a bit before we invest more effort into it. Let me know if you are interested in contributing more to torchao; I think we have many higher-priority issues you could help with.

Thanks for the kind info; I truly enjoyed your team's work after reviewing TorchAO: CodeML @ ICML 2025.

The recently updated contribution guide could be a great starting point for my next contribution, but personally I prefer the sparsity (pruning) module. Unfortunately, I heard the main POC (@jcaip) is on vacation, which makes it hard for me to make progress. The following are my recent activities related to the sparsity module:

  1. Since Wanda was already introduced, I recently introduced Wanda++ in feat: RGS for wanda++ #2537.
  2. Computation overhead seemed to be missing from your team's workshop (I'm not certain, given my limited knowledge), so I opened Missing benchmark for sparse24_sm90_sparsify overhead #2612.
  3. I am also interested in activation compression (Accelerate activation sparsity with activation compression #1920), but I have to learn more about it.

If there is no major progress on the sparsity module, quantization (new APIs or primitive ops) might be my next step. Let me know if there is a good second issue for it.

P.S. Could you please check #2644? It hasn't been merged yet even though it was approved (no CI breakage). Also, #2660 has been waiting for review (I am fine with closing it since it is low priority).

@namgyu-youn namgyu-youn marked this pull request as draft August 16, 2025 15:27
@namgyu-youn namgyu-youn marked this pull request as ready for review August 16, 2025 18:31
@namgyu-youn
Contributor Author

Test result (test_smoothquant.py):

$ python test/prototype/test_smoothquant.py
..............................................
----------------------------------------------------------------------
Ran 46 tests in 15.208s

OK

@namgyu-youn
Contributor Author

Hi @jerryzh168, I am happy to show a more generalized SmoothQuant API that uses the quantization API (torchao/quantization/quant_api.py) as of ba89d03. Could you review this PR?

"device", ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
"base_config",
[
int8_dynamic_activation_int8_weight(),
Contributor

nit: this API is deprecated, use Int8DynamicActivationInt8WeightConfig instead
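
For reference, the swap is a one-liner (a sketch against the parametrized test above):

from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

# deprecated callable-style API:
#   base_config = int8_dynamic_activation_int8_weight()
# preferred config-class API:
base_config = Int8DynamicActivationInt8WeightConfig()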

insert_smooth_quant_observer_(model, alpha, quant_mode)
# Step 1: Insert observers to find average magnitude and calculate scales
config = SmoothQuantConfig(
base_config=int8_dynamic_activation_int8_weight(),
Contributor

can generalize the example API to take quant type configs now, see

help="Quantization method. Options are either awq-int4wo-<group_size>, or int4wo-<group_size>.",

Contributor Author

@namgyu-youn namgyu-youn Aug 19, 2025

Thanks, but how about using Int8DynamicActivationInt8WeightConfig as the default here and dividing the PR? It would require checking which APIs are compatible with SmoothQuantConfig and building unit tests.

Furthermore, we could unify the commonly used utility functions in AWQ and SmoothQuant: get_calib_dataset, wiki2_eval, and quantize_and_eval.

Contributor

yeah sure

@namgyu-youn namgyu-youn requested a review from jerryzh168 August 19, 2025 11:03
print(f"time for convert: {time.time() - t0:.02f} seconds")

# Set up config for loading
quant_config.step = SmoothQuantStep.PREPARE_FOR_LOADING
Contributor

@jerryzh168 jerryzh168 Aug 19, 2025

does this work? you can check if it works by the following:

export MODEL=YOUR_SAVED_SMOOTHQUANT_MODEL
lm_eval --model hf --model_args pretrained=$MODEL --tasks $TASK --device cuda:0 --batch_size auto --limit 50

# vllm
export MODEL=YOUR_SAVED_SMOOTHQUANT_MODEL
python benchmarks/benchmark_latency.py --input-len 256 --output-len 256 --model $MODEL

Contributor Author

@namgyu-youn namgyu-youn Aug 19, 2025

I hoped so, since it works similarly to AWQ, but I just tested it with the following code for assurance and got the log below:

import tempfile
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep
from torchao.prototype.smoothquant.example import quantize_and_eval
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig

MODEL_NAME = "microsoft/DialoGPT-small"

# Step 1: Create quantized model
with tempfile.NamedTemporaryFile(suffix='.pt', delete=False) as f:
    model_path = f.name

quantize_and_eval(MODEL_NAME, 0.5, ['PPL'], 256, 5, 'cuda', torch.float32, False, model_path, None)

# Step 2: Test PREPARE_FOR_LOADING
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.pad_token = tokenizer.eos_token

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.float32).cuda()
quantize_(model, SmoothQuantConfig(
    base_config=Int8DynamicActivationInt8WeightConfig(),
    step=SmoothQuantStep.PREPARE_FOR_LOADING,
    alpha=0.5,
))

# Test inference
test_input = tokenizer('Hello world', return_tensors='pt').to('cuda')
with torch.no_grad():
    output = model(**test_input)
    generated = model.generate(**test_input, max_length=20, do_sample=False)

print(f"✓ Inference: {output.logits.shape}")
print(f"✓ Generation: {tokenizer.decode(generated[0], skip_special_tokens=True)}")
Loading model on cuda...
Time to load model: 1.86 seconds
running SmoothQuant prepare and calibrate
Repo card metadata block was not found. Setting CardData to empty.
Token indices sequence length is longer than the specified maximum sequence length for this model (1443 > 1024). Running this sequence through the model will result in indexing errors
time for prepare and calibration: 5.20 seconds
running SmoothQuant convert
time for convert: 0.04 seconds
Saving model to /tmp/tmpqeme5s1r.pt
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
✓ Inference: torch.Size([1, 4, 50257])
✓ Generation: TorchAO TorchAO

For sure, we should benchmark them following your suggestion, but I would carefully suggest splitting that into a separate PR.

Contributor

OK sounds good to divide the PR

insert_smooth_quant_observer_(model)
load_smooth_quant_recipe(model, "./smooth_quant_recipe.json")

# Step 3: Convert
Contributor

I think ideally we can add a tutorial doc for how to save transformer models for vllm/lm-eval as well:

if quant.startswith("awq-int4wo"):
group_size = int(quant.split("-")[2])
print(f"running {quant} quantization with group size {group_size}")
# TODO: this is temporary, we'll be using Int4WeightOnlyConfig soon
from torchao.quantization import FbgemmConfig
# use_hqq = True
# base_config = Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
base_config = FbgemmConfig(
input_dtype=torch.bfloat16,
weight_dtype=torch.int4,
output_dtype=torch.bfloat16,
block_size=[1, group_size],
preshuffle=False,
)
print(f"running {quant} prepare and calibrate")
t0 = time.time()
quant_config = AWQConfig(base_config, step="prepare")
quantize_(
model,
quant_config,
)
from torchao._models._eval import TransformerEvalWrapper
TransformerEvalWrapper(
model=model.to(device),
tokenizer=tokenizer,
max_seq_length=max_seq_length,
device=device,
).run_eval(
tasks=tasks,
limit=calibration_limit,
)
print(f"time for prepare and calibration: {time.time() - t0:.02f} seconds")
print(f"running {quant} convert")
t0 = time.time()
quant_config = AWQConfig(base_config, step="convert")
quantize_(model, quant_config)
print(f"time for convert: {time.time() - t0:.02f} seconds")
quant_config = AWQConfig(base_config, step="prepare_for_loading")
model.config.quantization_config = TorchAoConfig(quant_config)

Basically: prepare, convert, then manually set the config step to "prepare_for_loading" and upload the model.

After this, the model should be usable with vLLM and lm-eval.
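
To make that concrete, here is a rough SmoothQuant analog of the AWQ flow above. This is only a sketch: MODEL_NAME, YOUR_HF_REPO, and calibration_batches are placeholders, and the PREPARE/CONVERT member names of SmoothQuantStep are assumed to exist alongside PREPARE_FOR_LOADING.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int8DynamicActivationInt8WeightConfig
from torchao.prototype.smoothquant import SmoothQuantConfig
from torchao.prototype.smoothquant.core import SmoothQuantStep

MODEL_NAME = "YOUR_MODEL"  # placeholder
model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
base_config = Int8DynamicActivationInt8WeightConfig()

# prepare: swap in observed linears that record per-channel activation statistics
quantize_(model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.PREPARE))

# calibrate: run a few batches so the observers see representative activations
for batch in calibration_batches:  # placeholder iterable of tokenized inputs on cuda
    with torch.no_grad():
        model(**batch)

# convert: fold the smoothing factors into the weights and apply the base quantization
quantize_(model, SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.CONVERT))

# manually set the step to prepare_for_loading, attach the config, then upload
load_config = SmoothQuantConfig(base_config=base_config, step=SmoothQuantStep.PREPARE_FOR_LOADING)
model.config.quantization_config = TorchAoConfig(load_config)
model.push_to_hub("YOUR_HF_REPO")  # placeholder repo; the checkpoint can then be used with vllm/lm-eval
tokenizer.push_to_hub("YOUR_HF_REPO")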

Comment on lines 115 to 116
if config.set_inductor_config:
torchao.quantization.utils.recommended_inductor_config_setter()
Contributor

this should be removed I think

Contributor

please remove this before landing

Contributor Author

Thanks for the reminder. I will remove it.


# Get quantization parameters
if all(x is not None for x in (config.smoothing_factor, config.wei_scales)):
Contributor

when is this branch taken? how do people generate these parameters?

Contributor Author

@namgyu-youn namgyu-youn Aug 20, 2025

Right, we only use the smoothing factor here; wei_scales and act_scales should be removed entirely.
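
For reference, the smoothing factor by itself is cheap to derive from the observed per-channel ranges. Below is a minimal sketch following the SmoothQuant formulation s = max|X|^alpha / max|W|^(1-alpha); the names are illustrative, not the PR's exact code.

import torch

def compute_smoothing_factor(
    act_abs_max: torch.Tensor,  # per-input-channel max |activation|, shape [in_features]
    wei_abs_max: torch.Tensor,  # per-input-channel max |weight|, shape [in_features]
    alpha: float = 0.5,
    eps: float = 1e-5,
) -> torch.Tensor:
    # Larger alpha migrates more of the quantization difficulty from activations to weights.
    factor = act_abs_max.clamp(min=eps).pow(alpha) / wei_abs_max.clamp(min=eps).pow(1.0 - alpha)
    return factor.clamp(min=eps)

# At convert time the factor is folded in both directions so the matmul result is unchanged:
# activations are divided by the factor and weights are multiplied by it
# (matching `weight = observed_linear.weight * smoothing_factor` later in this PR).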

Comment on lines 27 to 29
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
eps: Optional[float] = None,
Contributor

can we remove these args as well, I think these are not needed

Contributor Author

yes they can be removed with AffineQuantizedMinMaxObserver, thanks.

Comment on lines 53 to 58
        self.act_ic_obs = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC, torch.int8, PerAxis(-1), eps=self.eps
        )
        self.wei_ic_obs = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC, torch.int8, PerAxis(-1), eps=self.eps
        )
Contributor

are these hardcoded to int8? what if we want to apply them to other types of quantization?

Contributor Author

I missed updating it after testing Int8DynamicActivationInt8WeightConfig. It should be fixed.

Comment on lines 76 to 78
wei_min_per_ic = self.wei_ic_obs.min_val
wei_max_per_ic = self.wei_ic_obs.max_val
act_min_per_ic = self.act_ic_obs.min_val
Contributor

so we only need min_val/max_val right? might be easier to just do this ourselves instead of relying on AffineQuantizedMinMaxObserver? we can copy over some of the main logic to record min_val/max_val based on granularity of scale (PerAxis) as well
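
A minimal sketch of what recording min/max ourselves could look like for the per-input-channel (PerAxis(-1)) case needed here; this is a hypothetical helper for illustration, not code from the PR.

import torch

class PerChannelMinMaxRecorder(torch.nn.Module):
    """Tracks running per-channel min/max over the last dimension (hypothetical helper)."""

    def __init__(self):
        super().__init__()
        self.register_buffer("min_val", None)
        self.register_buffer("max_val", None)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Flatten all leading dims so each column is one input channel, then reduce.
        flat = x.detach().reshape(-1, x.shape[-1])
        cur_min = flat.min(dim=0).values
        cur_max = flat.max(dim=0).values
        if self.min_val is None:
            self.min_val, self.max_val = cur_min, cur_max
        else:
            self.min_val = torch.minimum(self.min_val, cur_min)
            self.max_val = torch.maximum(self.max_val, cur_max)
        return x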

Contributor Author

Oh sorry, I misunderstood. All we need here are min_val/max_val, same as AWQ. The workflow should be updated accordingly.

print("Loading dataset")
t0 = time.time()
# TODO: Uniform this with torchao/prototype/awq/example.py and expand more tasks
def benchmark(model, tokenizer, max_seq_length=512, tasks=["PPL"], device="cuda"):
Contributor

I was hoping to remove these in the future and just rely on vllm for performance evaluation and lm-eval for model quality evaluation

Contributor Author

@namgyu-youn namgyu-youn Aug 20, 2025

Filed an issue at #2815 for testing more quantization APIs and benchmarks. These tasks sound good to me; let me take them on after this PR.

@namgyu-youn namgyu-youn requested a review from jerryzh168 August 20, 2025 06:57
)
weight = observed_linear.weight * smoothing_factor

# Create new linear layer
linear = torch.nn.Linear(
Contributor

nit: one trick we can use when creating these linear layers is to create them on the meta device: https://github.com/vllm-project/vllm/blob/c86af22f31838ee654c856279ac5110ae3fdb2cc/vllm/model_executor/layers/quantization/torchao.py#L159 to save memory, I think; similar for AWQ
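
A small sketch of that trick, using the names from the surrounding snippet (observed_linear and the smoothed weight computed above are assumed to be in scope):

import torch

# Build the module skeleton on the meta device so no real storage is allocated,
# then attach the already-materialized (smoothed) tensors directly.
linear = torch.nn.Linear(
    observed_linear.in_features,
    observed_linear.out_features,
    bias=observed_linear.bias is not None,
    device="meta",
)
linear.weight = torch.nn.Parameter(weight, requires_grad=False)  # `weight` from the snippet above
if observed_linear.bias is not None:
    linear.bias = torch.nn.Parameter(observed_linear.bias, requires_grad=False)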

Contributor Author

@namgyu-youn namgyu-youn Aug 21, 2025

Thanks for teaching me about the meta device. LGTM for memory saving when transformations are done before loading the actual data.



class SmoothQuantObserver(torch.nn.Module):
    def __init__(
        self,
        weight: torch.Tensor,
        bias: Optional[torch.Tensor],
        base_config: AOBaseConfig,
Contributor

actually seems like base_config is not used here?

Contributor Author

Since the observer only computes the smoothing factor, base_config can be removed, thanks.

set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
"""

base_config: AOBaseConfig
step: SmoothQuantStep
alpha: Optional[float] = 0.5
smoothing_factor: Optional[torch.Tensor] = None
Contributor

should this be removed as well? seems like not used?

Contributor Author

Removing it looks good because it can be computed without initialization, thanks.

smoothing_factor: Optional[torch.Tensor] = None
act_scales: Optional[torch.Tensor] = None
wei_scales: Optional[torch.Tensor] = None
set_inductor_config: bool = True
Contributor

also this flag, I don't think we need this

# Get quantization parameters
smoothing_factor = (
config.smoothing_factor
if config.smoothing_factor is not None
Contributor

probably just remove this arg, I don't see when we'll use it
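
Putting these suggestions together, the trimmed config would look roughly like the sketch below. The import path for AOBaseConfig and the docstring wording are assumptions, not the final code.

from dataclasses import dataclass
from typing import Optional

from torchao.core.config import AOBaseConfig  # path assumed
from torchao.prototype.smoothquant.core import SmoothQuantStep


@dataclass
class SmoothQuantConfig(AOBaseConfig):
    """SmoothQuant wrapper around a base quantization config.

    base_config: quantization applied after smoothing, e.g. Int8DynamicActivationInt8WeightConfig
    step: PREPARE (insert observers), CONVERT (fold smoothing factors and quantize),
        or PREPARE_FOR_LOADING (reconstruct quantized modules when loading a saved checkpoint)
    alpha: how much quantization difficulty is migrated from activations to weights
    """

    base_config: AOBaseConfig
    step: SmoothQuantStep
    alpha: Optional[float] = 0.5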

        self.obs = obs

    def forward(self, input: torch.Tensor):
        input = self.obs(input)
        return F.linear(input, self.weight, self.bias)

    @classmethod
    def from_float(cls, float_linear: torch.nn.Linear, obs: SmoothQuantObserver):
        observed_linear = cls(
Contributor

Contributor Author

Thanks for teaching me meta devices again!

@@ -68,29 +64,15 @@ Running the example with `torch.compile` on a NVIDIA A10G GPU.
Perplexity
| Quant Method | alpha=0.25 | alpha=0.5 | alpha=0.75 | alpha=None* |
|-|-|-|-|-|
| Dynamic | 8.1872 | 7.4257 | 7.2518 | 7.5509 |
| Static | 43.8051 | 11.2984 | 7.5791 | 19.5050 |
| SmoothQuant | - | - | - | - |
Contributor

I think we'd need to run some initial testing to make sure the refactor works? otherwise we might be landing non-working code

Contributor Author

This is the omitted section and will be updated once the refactor is finished. I will only add int8 dynamic in this PR, but I am definitely interested in expanding to vLLM benchmarks with more quantization APIs.

Contributor

@jerryzh168 jerryzh168 Aug 21, 2025

Do you plan to update this before this PR is merged, or later? I feel we should update it before the PR can be merged, to make sure this is working.

Performance benchmarks can come later, but we should have an accuracy test to make sure the SmoothQuant implementation is correct, I think.

Contributor Author

@namgyu-youn namgyu-youn Aug 22, 2025

Sorry, I didn't make that clear; it will be updated in this PR after the refactoring is finished. Actually, I have been checking its performance with TinyLlama/TinyLlama-1.1B-Chat-v1.0 for each commit. The architecture is quite different, so I didn't mention those results, although the accuracy (perplexity) was acceptable. The following is one of the experiment results:

[screenshot: perplexity results for TinyLlama-1.1B-Chat-v1.0]

Similar to AWQ, it will be updated using the Llama-2-7b-chat-hf model. Please feel free to commit directly for that, because my setup is quite different from your team's (1xA100 80GB SXM4 instance).

Contributor

OK please request review when you are ready

@namgyu-youn namgyu-youn requested a review from jerryzh168 August 21, 2025 15:17

Note*: Conventional quantization without SmoothQuant
Evaluation perplexity numbers were calculated using the script in `smoothquant/example.py`. For Llama-2-7b-chat-hf, performance benchmarks were calculated using the `torchao/_models/llama/generate.py` script and run on a 1xA100 48GB PCIe interface.
Contributor Author

@jerryzh168 Since my setup (1xA100 48GB PCIe) is not the same as your team's (1xA100 80GB SXM4 instance), the results can be quite different. Please feel free to update them if needed.

Contributor

@jerryzh168 jerryzh168 Aug 25, 2025

That's fine, I feel; we are mainly interested in the perplexity changes with and without AWQ, I think. Could you show that? You can use https://github.com/pytorch/ao/blob/main/torchao/_models/llama/eval.py to get the perplexity, or try adding a non-AWQ version in example.py.
