Skip to content
38 changes: 36 additions & 2 deletions torchao/quantization/pt2e/reference_representation_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
from torch.fx import GraphModule
from torch.fx.passes.utils.matcher_with_name_node_map_utils import InternalMatch
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
from torch.fx.subgraph_rewriter import ReplacedPatterns, replace_pattern_with_filters

from torchao.quantization.pt2e.export_utils import WrapperModule
from torchao.quantization.pt2e.utils import (
Expand Down Expand Up @@ -455,6 +455,34 @@ def _filter_fn_for_dynamic_quantized_linear_4bit_groupwise(
return weight_is_int4 and act_quant_is_int8


def _port_metadata_for_dynamic_quantized_linear_4bit_groupwise(
replacement_pattern: ReplacedPatterns,
):
"""
Port metadata for dynamically quantized linear 4-bit groupwise operation.
It custom_op node's metadata with corresponding linear node's metadata.
"""
from torch.fx.traceback import NodeSource, NodeSourceAction

linear_node = None
int4_custom_op_node = None
for _, g_n in replacement_pattern.nodes_map.items():
if g_n.target == torch.ops.aten.linear.default:
linear_node = g_n
break
if len(replacement_pattern.replacements) > 0:
int4_custom_op_node = replacement_pattern.replacements[-1]
if linear_node is not None and int4_custom_op_node is not None:
int4_custom_op_node.meta = linear_node.meta.copy()
int4_custom_op_node.meta["from_node"] = [
NodeSource(
linear_node,
"ReplaceInt4DynamicQuantWithCustomOp",
NodeSourceAction.REPLACE,
)
]


def _qdq_quantized_conv2d(
x_i8,
x_scale,
Expand Down Expand Up @@ -883,6 +911,7 @@ class _RewriteInfo:
list[Callable[["InternalMatch", torch.fx.Graph, torch.fx.Graph], bool]]
] = None
ignore_literals: bool = False
port_metadata_fn: Optional[Callable[["ReplacedPatterns"], None]] = None


def reference_representation_rewrite(model: GraphModule) -> GraphModule:
Expand Down Expand Up @@ -1053,6 +1082,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
),
filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
ignore_literals=True,
port_metadata_fn=_port_metadata_for_dynamic_quantized_linear_4bit_groupwise,
),
_RewriteInfo(
_DYNAMIC_QUANTIZED_LINEAR_4BIT_GROUPWISE_EXAMPLE_INPUTS_2,
Expand All @@ -1074,6 +1104,7 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
),
filter_fn=[_filter_fn_for_dynamic_quantized_linear_4bit_groupwise],
ignore_literals=True,
port_metadata_fn=_port_metadata_for_dynamic_quantized_linear_4bit_groupwise,
),
_RewriteInfo(
_QUANTIZED_LINEAR_EXAMPLE_INPUTS,
Expand Down Expand Up @@ -1153,12 +1184,15 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
replacement = replacement_post_trans(replacement)
pattern.recompile() # type: ignore[attr-defined]
replacement.recompile() # type: ignore[attr-defined]
replace_pattern_with_filters(
matches = replace_pattern_with_filters(
model,
pattern,
replacement,
match_filters=rewrite_info.filter_fn,
ignore_literals=rewrite_info.ignore_literals,
) # type: ignore[arg-type]
if rewrite_info.port_metadata_fn:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reference rewrite is an optional pass, or at least that's how I was using it while debugging. In the reference rewrite, the rules for porting metadata are not as generic, so each pattern has to dictate what it wants to do.

for m in matches:
rewrite_info.port_metadata_fn(m) # type: ignore[arg-type]

return model
Loading