Skip to content

Commit f02354d

Browse files
authored
Add tracking for new tensors, AQT and layouts (#2895)
Summary: Add API usage logging for these things (new tensors, AQT, and layouts) to understand which model checkpoints are using these APIs. Test Plan: internal queries. Reviewers: Subscribers: Tasks: Tags:
1 parent 4236656 commit f02354d

File tree

10 files changed

+15
-1
lines changed

10 files changed

+15
-1
lines changed

torchao/dtypes/affine_quantized_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ def __init__(
116116
dtype=None,
117117
strides=None,
118118
):
119+
torch._C._log_api_usage_once(str(type(self)))
119120
self.tensor_impl = tensor_impl
120121
self.block_size = block_size
121122
self.quant_min = quant_min

torchao/dtypes/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def __repr__(self):
6868
def extra_repr(self) -> str:
6969
return ""
7070

71+
def __post_init__(self):
72+
torch._C._log_api_usage_once(str(type(self)))
73+
7174

7275
@dataclass(frozen=True)
7376
class PlainLayout(Layout):

torchao/quantization/quantize_/workflows/float8/float8_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def __init__(
136136
kernel_preference: KernelPreference = KernelPreference.AUTO,
137137
dtype: Optional[torch.dtype] = None,
138138
):
139+
super().__init__()
139140
self.qdata = qdata
140141
self.scale = scale
141142
self.block_size = block_size

torchao/quantization/quantize_/workflows/int4/int4_marlin_sparse_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __new__(cls, qdata, scale, zero_point, meta, block_size, num_bits, shape):
3535
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
3636

3737
def __init__(self, qdata, scale, zero_point, meta, block_size, num_bits, shape):
38+
super().__init__()
3839
self.qdata = qdata
3940
self.scale = scale
4041
self.zero_point = zero_point

torchao/quantization/quantize_/workflows/int4/int4_opaque_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def __init__(
7070
block_size: List[int],
7171
shape: torch.Size,
7272
):
73+
super().__init__()
7374
self.qdata = qdata
7475
self.scale_and_zero = scale_and_zero
7576
self.block_size = block_size

torchao/quantization/quantize_/workflows/int4/int4_preshuffled_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
group_zero: Optional[torch.Tensor] = None,
103103
row_scale: Optional[torch.Tensor] = None,
104104
):
105+
super().__init__()
105106
# one and only one of group_scale and group_zero should be None
106107
assert group_zero is None or row_scale is None
107108
assert not (group_zero is not None and row_scale is not None)

torchao/quantization/quantize_/workflows/int4/int4_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ def __init__(
7575
shape: torch.Size,
7676
act_pre_scale: Optional[torch.Tensor] = None,
7777
):
78+
super().__init__()
7879
self.qdata = qdata
7980
self.scale = scale
8081
self.zero_point = zero_point

torchao/quantization/quantize_/workflows/intx/intx_opaque_tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class ComputeTarget(enum.Enum):
3737
ATEN = "aten"
3838

3939
"""
40-
This packs the tensor for TorchAO CPU kernels by selecting the best available kernel
40+
This packs the tensor for TorchAO CPU kernels by selecting the best available kernel
4141
based on the quantization scheme, either using KleidiAI kernels or lowbit kernels.
4242
It requires TorchAO C++ kernels to be installed.
4343
"""
@@ -112,6 +112,7 @@ def __init__(
112112
packed_weights_has_bias,
113113
compute_target,
114114
):
115+
super().__init__()
115116
assert packed_weights.device == torch.device("cpu")
116117
self.packed_weights = packed_weights
117118
self.bit_width = bit_width

torchao/quantization/quantize_/workflows/intx/intx_unpacked_to_int8_tensor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def __init__(
9393
dtype,
9494
apply_int8_act_asym_per_token_quant,
9595
):
96+
super().__init__()
9697
assert qdata.dtype == torch.int8, (
9798
f"qdata dtype must be int8, but got {qdata.dtype}"
9899
)

torchao/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,9 @@ def __init_subclass__(cls, **kwargs):
860860
get_tensor_impl_constructor = classmethod(_get_tensor_impl_constructor)
861861
_get_to_kwargs = _get_to_kwargs
862862

863+
def __init__(self, *args, **kwargs):
864+
torch._C._log_api_usage_once(str(type(self)))
865+
863866
def __tensor_flatten__(self):
864867
if hasattr(self, "tensor_data_names") and hasattr(
865868
self, "tensor_attribute_names"

0 commit comments

Comments
 (0)