Use Q_ANNOTATION_KEY #12728
Merged: 1 commit, Jul 25, 2025
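This PR replaces every hard-coded "quantization_annotation" node-meta key in the ExecuTorch backend quantizers with Q_ANNOTATION_KEY, the constant exported by torchao from torchao.quantization.pt2e.quantizer.quantizer. Assuming torchao defines that constant as the same string, behavior is unchanged; the point is that the backends can no longer drift from the key torchao itself uses. The same mechanical substitution lands in all ten files below; a minimal before/after sketch of the pattern:

    from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
    from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

    # Before: each call site spelled the node.meta key out by hand.
    # node.meta["quantization_annotation"] = QuantizationAnnotation(_annotated=True)

    # After: every call site shares the constant torchao itself uses.
    # node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(_annotated=True)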
backends/arm/quantizer/arm_quantizer_utils.py (8 additions, 9 deletions)

@@ -18,22 +18,21 @@
 from torch.fx import GraphModule, Node

 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 def is_annotated(node: Node) -> bool:
     """Given a node return whether the node is annotated."""
     return (
-        "quantization_annotation" in node.meta
-        and cast(
-            QuantizationAnnotation, node.meta["quantization_annotation"]
-        )._annotated
+        Q_ANNOTATION_KEY in node.meta
+        and cast(QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY])._annotated
     )


 def is_output_annotated(node: Node) -> bool:
     """Given a node, return whether the output of the node is annotated."""
-    if "quantization_annotation" in node.meta:
-        annotation = cast(QuantizationAnnotation, node.meta["quantization_annotation"])
+    if Q_ANNOTATION_KEY in node.meta:
+        annotation = cast(QuantizationAnnotation, node.meta[Q_ANNOTATION_KEY])
         return annotation._annotated and annotation.output_qspec is not None
     else:
         return False
@@ -43,9 +42,9 @@ def mark_node_as_annotated(node: Node) -> None:
     """Marks node as annotated. If needed, an empty QuantizationAnnotation is added
     to the quantization_annotation node meta entry.
     """
-    if "quantization_annotation" not in node.meta:
-        node.meta["quantization_annotation"] = QuantizationAnnotation()
-    node.meta["quantization_annotation"]._annotated = True
+    if Q_ANNOTATION_KEY not in node.meta:
+        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation()
+    node.meta[Q_ANNOTATION_KEY]._annotated = True


 def is_ok_for_quantization(node: Node, gm: GraphModule):
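A minimal usage sketch of these three helpers; the executorch import path is inferred from the file location, and the toy module exists only to produce an FX node to annotate:

    import torch
    from torch.fx import symbolic_trace

    # Import path inferred from backends/arm/quantizer/arm_quantizer_utils.py.
    from executorch.backends.arm.quantizer.arm_quantizer_utils import (
        is_annotated,
        is_output_annotated,
        mark_node_as_annotated,
    )

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    gm = symbolic_trace(Toy())
    node = next(n for n in gm.graph.nodes if n.op == "call_function")

    assert not is_annotated(node)         # no Q_ANNOTATION_KEY entry in node.meta yet
    mark_node_as_annotated(node)          # inserts an empty QuantizationAnnotation
    assert is_annotated(node)             # _annotated is now True
    assert not is_output_annotated(node)  # ...but output_qspec is still None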
backends/cadence/aot/quantizer/quantizer.py (6 additions, 5 deletions)

@@ -42,6 +42,7 @@
     QuantizationSpec,
     Quantizer,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 act_qspec_asym8s = QuantizationSpec(
@@ -127,7 +128,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

         for output, *custom_spec in anchors.output:
             # pyre-ignore[16]: no attribute
-            output.meta["quantization_annotation"] = QuantizationAnnotation(
+            output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                 # pyre-ignore[6]: incompatible parameter type
                 output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                 _annotated=True,
@@ -143,7 +144,7 @@ def annotate_inputs(
             for node, idx, *custom_spec in inputs:
                 # pyre-ignore[16]: no attribute
                 annotation = node.meta.get(
-                    "quantization_annotation",
+                    Q_ANNOTATION_KEY,
                     QuantizationAnnotation(_annotated=True),
                 )
                 arg = (
@@ -157,21 +158,21 @@
                     custom_spec[0] if custom_spec else spec
                 )
                 # pyre-ignore[16]: no attribute
-                node.meta["quantization_annotation"] = annotation
+                node.meta[Q_ANNOTATION_KEY] = annotation

         def annotate_weights_or_biases(
             weights_or_biases: List[Tuple[fx.Node, int]],
             spec: Optional[QuantizationSpec],
         ) -> None:
             for node, idx, *custom_spec in weights_or_biases:
                 annotation = node.meta.get(
-                    "quantization_annotation",
+                    Q_ANNOTATION_KEY,
                     QuantizationAnnotation(_annotated=True),
                 )
                 annotation.input_qspec_map[node.args[idx]] = (
                     custom_spec[0] if custom_spec else spec
                 )
-                node.meta["quantization_annotation"] = annotation
+                node.meta[Q_ANNOTATION_KEY] = annotation

         # pyre-ignore[6]: incompatible parameter type
         annotate_inputs(anchors.inputs, input_act_qspec)
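Both nested helpers follow the same fetch, extend, write-back pattern. A self-contained sketch of that pattern; annotate_io is a hypothetical name, and the caller supplies the specs:

    from torch.fx import Node
    from torchao.quantization.pt2e.quantizer import (
        QuantizationAnnotation,
        QuantizationSpec,
    )
    from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

    def annotate_io(node: Node, input_spec: QuantizationSpec, output_spec: QuantizationSpec) -> None:
        # Fetch the node's existing annotation, or start a fresh one.
        annotation = node.meta.get(Q_ANNOTATION_KEY, QuantizationAnnotation(_annotated=True))
        # Request quantization of the first input and of the output.
        annotation.input_qspec_map[node.args[0]] = input_spec
        annotation.output_qspec = output_spec
        # Write it back under the shared key.
        node.meta[Q_ANNOTATION_KEY] = annotation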
backends/cadence/aot/quantizer/utils.py (2 additions, 2 deletions)

@@ -21,6 +21,7 @@
     SourcePartition,
 )
 from torchao.quantization.pt2e import ObserverOrFakeQuantize
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 def quantize_tensor_multiplier(
@@ -88,8 +89,7 @@ def is_annotated(nodes: List[fx.Node]) -> bool:
     annotated = False
     for node in nodes:
         annotated = annotated or (
-            "quantization_annotation" in node.meta
-            and node.meta["quantization_annotation"]._annotated
+            Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated
         )
     return annotated
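Incidentally, the loop reduces with `or` across the whole list, so this helper reports whether any node in `nodes` carries a live annotation. An equivalent any()-based formulation (a sketch, not part of the diff) that also exits on the first hit:

    from typing import List
    from torch import fx
    from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

    def is_annotated(nodes: List[fx.Node]) -> bool:
        # Same semantics as the or-accumulating loop above.
        return any(
            Q_ANNOTATION_KEY in n.meta and n.meta[Q_ANNOTATION_KEY]._annotated
            for n in nodes
        )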
backends/cortex_m/test/test_replace_quant_nodes.py (3 additions, 5 deletions)

@@ -25,6 +25,7 @@
     QuantizationSpec,
     Quantizer,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 @dataclass(eq=True, frozen=True)
@@ -67,18 +68,15 @@ def annotate(self, model: GraphModule):
             ]:
                 continue

-            if (
-                "quantization_annotation" in node.meta
-                and node.meta["quantization_annotation"]._annotated
-            ):
+            if Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated:
                 continue

             input_qspec_map = {
                 node.args[0]: config.input_activation,
                 node.args[1]: config.input_activation,
             }

-            node.meta["quantization_annotation"] = QuantizationAnnotation(
+            node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                 input_qspec_map=input_qspec_map,
                 output_qspec=config.output_activation,
                 _annotated=True,
backends/example/example_operators/utils.py (4 additions, 3 deletions)

@@ -5,11 +5,12 @@
 # LICENSE file in the root directory of this source tree.

 from torchao.quantization.pt2e.quantizer import QuantizationAnnotation
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 def _nodes_are_annotated(node_list):
     for node in node_list:
-        quantization_annotation = node.meta.get("quantization_annotation", None)
+        quantization_annotation = node.meta.get(Q_ANNOTATION_KEY, None)
         if not quantization_annotation:
             return False
         if quantization_annotation._annotated:
@@ -23,11 +24,11 @@ def _annotate_nodes(node_tuples, quant_spec, input_node=False):
     for node_tuple in node_tuples:
         node = node_tuple[0]
         quant_annotation = node.meta.get(
-            "quantization_annotation", QuantizationAnnotation(_annotated=True)
+            Q_ANNOTATION_KEY, QuantizationAnnotation(_annotated=True)
         )
         if input_node:
             input_node = node_tuple[1]
             quant_annotation.input_qspec_map[input_node] = quant_spec
         else:
             quant_annotation.output_qspec = quant_spec
-        node.meta["quantization_annotation"] = quant_annotation
+        node.meta[Q_ANNOTATION_KEY] = quant_annotation
backends/mediatek/quantizer/annotator.py (3 additions, 2 deletions)

@@ -21,6 +21,7 @@
     annotate_output_qspec as _annotate_output_qspec,
     QuantizationAnnotation,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

 from .qconfig import QuantizationConfig

@@ -57,12 +58,12 @@ def _is_annotated(node: Node):
     return True if any of the node
     is annotated, otherwise return False
     """
-    KEY = "quantization_annotation"
+    KEY = Q_ANNOTATION_KEY
     return KEY in node.meta and node.meta[KEY]._annotated


 def _mark_as_annotated(nodes: List[Node]):
-    KEY = "quantization_annotation"
+    KEY = Q_ANNOTATION_KEY
     for node in nodes:
         if KEY not in node.meta:
             node.meta[KEY] = QuantizationAnnotation()
backends/nxp/quantizer/neutron_quantizer.py (6 additions, 5 deletions)

@@ -51,6 +51,7 @@
     QuantizationSpec,
     Quantizer,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 class NeutronAtenQuantizer(Quantizer):
@@ -92,7 +93,7 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:

         for output, *custom_spec in anchors.output:
             # pyre-ignore[16]: no attribute
-            output.meta["quantization_annotation"] = QuantizationAnnotation(
+            output.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
                 # pyre-ignore[6]: incompatible parameter type
                 output_qspec=(custom_spec[0] if custom_spec else output_act_qspec),
                 _annotated=True,
@@ -108,7 +109,7 @@ def annotate_inputs(
             for node, idx, *custom_spec in inputs:
                 # pyre-ignore[16]: no attribute
                 annotation = node.meta.get(
-                    "quantization_annotation",
+                    Q_ANNOTATION_KEY,
                     QuantizationAnnotation(_annotated=True),
                 )
                 arg = (
@@ -122,21 +123,21 @@
                     custom_spec[0] if custom_spec else spec
                 )
                 # pyre-ignore[16]: no attribute
-                node.meta["quantization_annotation"] = annotation
+                node.meta[Q_ANNOTATION_KEY] = annotation

         def annotate_weights_or_biases(
             weights_or_biases: List[Tuple[fx.Node, int]],
             spec: Optional[QuantizationSpec],
         ) -> None:
             for node, idx, *custom_spec in weights_or_biases:
                 annotation = node.meta.get(
-                    "quantization_annotation",
+                    Q_ANNOTATION_KEY,
                     QuantizationAnnotation(_annotated=True),
                 )
                 annotation.input_qspec_map[node.args[idx]] = (
                     custom_spec[0] if custom_spec else spec
                 )
-                node.meta["quantization_annotation"] = annotation
+                node.meta[Q_ANNOTATION_KEY] = annotation

         # pyre-ignore[6]: incompatible parameter type
         annotate_inputs(anchors.inputs, input_act_qspec)
backends/nxp/quantizer/patterns.py (2 additions, 1 deletion)

@@ -19,6 +19,7 @@
     FixedQParamsQuantizationSpec,
     SharedQuantizationSpec,
 )
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 @dataclass
@@ -90,7 +91,7 @@ def get_anchors(
         prev_node = fused_partition[0].input_nodes[0]

         # Previous node was not quantized => we are not able to share q-params
-        if "quantization_annotation" not in prev_node.meta:
+        if Q_ANNOTATION_KEY not in prev_node.meta:
             return None

         qspec = SharedQuantizationSpec(prev_node)
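For context, SharedQuantizationSpec(prev_node) ties a pattern's quantization parameters to an already-annotated producer, which is why an unannotated producer forces the early return None above. A sketch of the idea with a hypothetical helper:

    from torch.fx import Node
    from torchao.quantization.pt2e.quantizer import (
        QuantizationAnnotation,
        SharedQuantizationSpec,
    )
    from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY

    def share_output_qparams(node: Node, prev_node: Node) -> bool:
        """Hypothetical helper: make `node`'s output reuse `prev_node`'s q-params."""
        # Producer was never annotated => no q-params to share (mirrors the
        # early return in get_anchors above).
        if Q_ANNOTATION_KEY not in prev_node.meta:
            return False
        node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(
            output_qspec=SharedQuantizationSpec(prev_node),
            _annotated=True,
        )
        return True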
backends/nxp/quantizer/utils.py (2 additions, 2 deletions)

@@ -19,14 +19,14 @@
     SourcePartition,
 )
 from torchao.quantization.pt2e import ObserverOrFakeQuantize
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 def is_annotated(nodes: List[fx.Node]) -> bool:
     annotated = False
     for node in nodes:
         annotated = annotated or (
-            "quantization_annotation" in node.meta
-            and node.meta["quantization_annotation"]._annotated
+            Q_ANNOTATION_KEY in node.meta and node.meta[Q_ANNOTATION_KEY]._annotated
         )
     return annotated
backends/openvino/quantizer/quantizer.py (3 additions, 4 deletions)

@@ -30,8 +30,7 @@
     Quantizer,
     SharedQuantizationSpec,
 )
-
-QUANT_ANNOTATION_KEY = "quantization_annotation"
+from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY


 class QuantizationMode(Enum):
@@ -174,8 +173,8 @@ def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
         self._fill_torch_ao_annotation(edge_or_node, qspec, annotation)

         for node, annotation in node_vs_torch_annotation.items():
-            assert QUANT_ANNOTATION_KEY not in node.meta
-            node.meta[QUANT_ANNOTATION_KEY] = annotation
+            assert Q_ANNOTATION_KEY not in node.meta
+            node.meta[Q_ANNOTATION_KEY] = annotation
         return model

     @staticmethod
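Note that the OpenVINO quantizer was the one backend that had already hoisted the string into a module-level constant; the diff deletes that local QUANT_ANNOTATION_KEY in favor of the constant torchao exports, so the assert and the write can no longer drift apart from torchao's own key.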