Update "GPU Quantization with TorchAO" #3439

Merged: 4 commits on Jul 10, 2025
24 changes: 12 additions & 12 deletions prototype_source/gpu_quantization_torchao_tutorial.py
@@ -31,7 +31,7 @@
# > conda create -n myenv python=3.10
# > pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu121
# > pip install git+https://github.com/facebookresearch/segment-anything.git
-# > pip install git+https://github.com/pytorch-labs/ao.git
+# > pip install git+https://github.com/pytorch/ao.git
#
# Segment Anything Model checkpoint setup:
#
@@ -44,7 +44,7 @@
#

import torch
-from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int8_weight
+from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig
from torchao.utils import unwrap_tensor_subclass, TORCH_VERSION_AT_LEAST_2_5
from segment_anything import sam_model_registry
from torch.utils.benchmark import Timer
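Note: the import change above is the core of this PR, moving from the old function-style API (``int8_dynamic_activation_int8_weight``) to the new config-class API (``Int8DynamicActivationInt8WeightConfig``). As a quick standalone illustration (not part of the diff), a minimal sketch of the new API on a toy model might look like the following; the toy model, shapes, and dtype/device choices are illustrative, and it assumes a torchao release that ships the config classes:

```python
# Minimal sketch (illustrative, not part of this PR's diff): applying the new
# config-style torchao API to a toy linear model on CUDA in bfloat16.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import quantize_, Int8DynamicActivationInt8WeightConfig

toy = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))
toy = toy.to(device="cuda", dtype=torch.bfloat16)

# quantize_ mutates the module in place, swapping the Linear weights for
# int8-dynamic-activation / int8-weight quantized tensor subclasses.
quantize_(toy, Int8DynamicActivationInt8WeightConfig())

x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
out = toy(x)  # runs the quantized linears
```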
@@ -143,7 +143,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
# for improvements.
#
# Next, let's apply quantization. Quantization for GPUs comes in three main forms
-# in `torchao <https://github.com/pytorch-labs/ao>`_ which is just native
+# in `torchao <https://github.com/pytorch/ao>`_ which is just native
# pytorch+python code. This includes:
#
# * int8 dynamic quantization
@@ -157,9 +157,9 @@ def get_sam_model(only_one_block=False, batchsize=1):
# in memory bound situations where the benefit comes from loading less
# weight data, rather than doing less computation. The torchao APIs:
#
-# ``int8_dynamic_activation_int8_weight()``,
-# ``int8_weight_only()`` or
-# ``int4_weight_only()``
+# ``Int8DynamicActivationInt8WeightConfig()``,
+# ``Int8WeightOnlyConfig()`` or
+# ``Int4WeightOnlyConfig()``
#
# can be used to easily apply the desired quantization technique and then
# once the model is compiled with ``torch.compile`` with ``max-autotune``, quantization is
@@ -171,7 +171,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
# ``apply_weight_only_int8_quant`` instead as drop in replacement for the two
# above (no replacement for int4).
#
-# The difference between the two APIs is that ``int8_dynamic_activation`` API
+# The difference between the two APIs is that the ``Int8DynamicActivationInt8WeightConfig`` API
# alters the weight tensor of the linear module so instead of doing a
# normal linear, it does a quantized operation. This is helpful when you
# have non-standard linear ops that do more than one thing. The ``apply``
@@ -186,7 +186,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
model, image = get_sam_model(only_one_block, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
if not TORCH_VERSION_AT_LEAST_2_5:
# needed for subclass + compile to work on older versions of pytorch
unwrap_tensor_subclass(model)
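For context, the hunk above ends just before the tutorial compiles and benchmarks the quantized block; the prose earlier notes that quantization is fused into the kernels once the model is compiled with ``torch.compile`` in ``max-autotune`` mode. A rough sketch of that follow-up step, reusing the ``model`` and ``image`` names from the snippet and the ``Timer`` import shown earlier (the tutorial's exact benchmarking code may differ):

```python
# Sketch of the compile-and-benchmark step that follows quantization
# (names reused from the snippet above; illustrative, not the diff itself).
model_c = torch.compile(model, mode="max-autotune")
model_c(image)  # warm-up: triggers compilation and autotuning

timer = Timer(
    stmt="model_c(image)",
    globals={"model_c": model_c, "image": image},
)
print(f"{timer.timeit(50).mean * 1e3:.2f} ms per iteration")
```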
@@ -224,7 +224,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
torch._inductor.config.force_fuse_int_mm_with_mul = True
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
if not TORCH_VERSION_AT_LEAST_2_5:
# needed for subclass + compile to work on older versions of pytorch
unwrap_tensor_subclass(model)
@@ -258,7 +258,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.coordinate_descent_check_all_directions = True
torch._inductor.config.force_fuse_int_mm_with_mul = True
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
if not TORCH_VERSION_AT_LEAST_2_5:
# needed for subclass + compile to work on older versions of pytorch
unwrap_tensor_subclass(model)
@@ -290,7 +290,7 @@ def get_sam_model(only_one_block=False, batchsize=1):
model, image = get_sam_model(False, batchsize)
model = model.to(torch.bfloat16)
image = image.to(torch.bfloat16)
-quantize_(model, int8_dynamic_activation_int8_weight())
+quantize_(model, Int8DynamicActivationInt8WeightConfig())
if not TORCH_VERSION_AT_LEAST_2_5:
# needed for subclass + compile to work on older versions of pytorch
unwrap_tensor_subclass(model)
@@ -315,6 +315,6 @@ def get_sam_model(only_one_block=False, batchsize=1):
# the model. For example, this can be done with some form of flash attention.
#
# For more information visit
-# `torchao <https://github.com/pytorch-labs/ao>`_ and try it on your own
+# `torchao <https://github.com/pytorch/ao>`_ and try it on your own
# models.
#
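As a starting point for trying this on your own models, here is a hedged sketch of the weight-only variants named in the updated tutorial text (``Int8WeightOnlyConfig`` / ``Int4WeightOnlyConfig``). The model below is a placeholder, the import path mirrors the one used in this diff, and availability of int4 weight-only kernels (bfloat16, CUDA) depends on your torchao and PyTorch versions:

```python
# Illustrative sketch only: weight-only quantization with the config-style API.
# Replace MyModel with your own module; int4 weight-only generally expects
# bfloat16 weights on CUDA. Check your torchao version for supported options.
import torch
import torch.nn as nn
from torchao.quantization.quant_api import (
    quantize_,
    Int8WeightOnlyConfig,
    Int4WeightOnlyConfig,
)

class MyModel(nn.Module):  # placeholder model
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096)

    def forward(self, x):
        return self.proj(x)

model = MyModel().to(device="cuda", dtype=torch.bfloat16)

# Pick one: int8 weight-only loads half the weight bytes of bf16 and is widely
# supported; int4 weight-only shrinks weights further at some accuracy cost.
quantize_(model, Int8WeightOnlyConfig())
# quantize_(model, Int4WeightOnlyConfig())  # int4 alternative

model_c = torch.compile(model, mode="max-autotune")
out = model_c(torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16))
```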