diff --git a/torchao/prototype/moe_training/scaled_grouped_mm.py b/torchao/prototype/moe_training/scaled_grouped_mm.py
index fb76821601..225ee57842 100644
--- a/torchao/prototype/moe_training/scaled_grouped_mm.py
+++ b/torchao/prototype/moe_training/scaled_grouped_mm.py
@@ -122,7 +122,7 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         A_scaled = A.to(torch.float32) * A_scales
-        A_fp8_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
+        A_data_row_major = to_fp8_saturated(A_scaled, torch.float8_e4m3fn)
 
         # Convert B to float8, column-major for right operand of grouped GEMM.
         # B_t shape: (E, K, N)
@@ -136,7 +136,7 @@ def forward(
             round_scales_to_power_of_2=True,
         )
         B_t_scaled = B_t.to(torch.float32) * B_t_scales
-        B_t_fp8_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
+        B_t_data_col_major = to_fp8_saturated(B_t_scaled, torch.float8_e4m3fn)
 
         # Store what we need for backward.
         ctx.save_for_backward(A, B_t, offs)
@@ -144,10 +144,10 @@ def forward(
 
         # Perform scaled grouped GEMM and return result.
         # output shape: scaled grouped mm of (M,K) @ (B,K,N) = (M,N)
-        assert not _is_column_major(A_fp8_row_major), (
+        assert not _is_column_major(A_data_row_major), (
             "A must be row-major for output = A @ B"
         )
-        assert _is_column_major(B_t_fp8_col_major), (
+        assert _is_column_major(B_t_data_col_major), (
             "B must be column-major for output = A @ B"
         )
 
@@ -157,8 +157,8 @@ def forward(
         A_scales = A_scales.squeeze(-1)
         B_t_scales = B_t_scales.squeeze(1)
         return torch._scaled_grouped_mm(
-            A_fp8_row_major,
-            B_t_fp8_col_major,
+            A_data_row_major,
+            B_t_data_col_major,
             A_scales.reciprocal(),  # Reciprocals are needed for rescaling the output.
             B_t_scales.reciprocal(),
             offs,
@@ -184,13 +184,13 @@ def backward(ctx, grad_output: torch.Tensor):
             round_scales_to_power_of_2=True,
         )
         grad_output_scaled = grad_output.to(torch.float32) * grad_output_scales
-        grad_output_fp8_row_major = to_fp8_saturated(
+        grad_output_data_row_major = to_fp8_saturated(
             grad_output_scaled, torch.float8_e4m3fn
         )
         # Compute B fp8 column-major for right operand of grouped GEMM:
         # grad_A = grad_output @ B.
-        B_fp8_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
+        B_data_col_major, B_scales = triton_fp8_rowwise_3d_transpose_rhs(
             B_t._data if hasattr(B_t, "_data") else B_t,
             output_dtype=torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
@@ -199,10 +199,10 @@ def backward(ctx, grad_output: torch.Tensor):
         # Compute grad_A.
         # grad_A = grad_output @ B
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        assert not _is_column_major(grad_output_fp8_row_major), (
+        assert not _is_column_major(grad_output_data_row_major), (
             "grad_output must be row-major for grad_A = grad_output @ B"
         )
-        assert _is_column_major(B_fp8_col_major), (
+        assert _is_column_major(B_data_col_major), (
             "B must be column-major for grad_A = grad_output @ B"
         )
 
@@ -212,8 +212,8 @@ def backward(ctx, grad_output: torch.Tensor):
         grad_output_scales = grad_output_scales.squeeze(-1)
         B_scales = B_scales.squeeze(1)
         grad_A = torch._scaled_grouped_mm(
-            grad_output_fp8_row_major,
-            B_fp8_col_major,
+            grad_output_data_row_major,
+            B_data_col_major,
             grad_output_scales.reciprocal(),
             B_scales.reciprocal(),
             offs,
@@ -227,7 +227,7 @@ def backward(ctx, grad_output: torch.Tensor):
         # Convert transpose of grad_output to float8, row-major for left operand of grouped GEMM
         # needed for grad_B: grad_output_t @ A
         # Use transpose method to avoid uncoalesced memory accesses.
-        grad_out_fp8_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
+        grad_out_data_colwise, grad_out_scales = triton_fp8_per_group_colwise_scales(
             grad_output.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
@@ -235,10 +235,10 @@ def backward(ctx, grad_output: torch.Tensor):
             torch.float8_e4m3fn,
             round_scales_to_power_of_2=True,
         )
-        grad_output_t_fp8_row_major = grad_out_fp8_colwise.t()
+        grad_output_t_data_row_major = grad_out_data_colwise.t()
         grad_output_t_scales = grad_out_scales.t()
 
-        A_fp8_col_major, A_scales = triton_fp8_per_group_colwise_scales(
+        A_data_col_major, A_scales = triton_fp8_per_group_colwise_scales(
             A.t()
             .contiguous()
             .t(),  # Quantization is over 2x faster when input is col major, even with this transformation
@@ -249,10 +249,10 @@ def backward(ctx, grad_output: torch.Tensor):
 
         # Compute grad_B = grad_output_t @ A.
         # grad_B = grad_output_t @ A
-        assert not _is_column_major(grad_output_t_fp8_row_major), (
+        assert not _is_column_major(grad_output_t_data_row_major), (
            "grad_output_t must be row-major for grad_B = grad_output_t @ A"
         )
-        assert _is_column_major(A_fp8_col_major), (
+        assert _is_column_major(A_data_col_major), (
             "A must be column-major for grad_B = grad_output_t @ A"
         )
 
@@ -260,8 +260,8 @@ def backward(ctx, grad_output: torch.Tensor):
         # the empty dim like the scales computed via tensor_to_scale, so we need
         # don't need to squeeze here.
         grad_B = torch._scaled_grouped_mm(
-            grad_output_t_fp8_row_major,
-            A_fp8_col_major,
+            grad_output_t_data_row_major,
+            A_data_col_major,
             grad_output_t_scales.reciprocal(),
             A_scales.reciprocal(),
             offs,
@@ -295,13 +295,15 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
 
-        # A_mx shape: (M, K)
+        # A_data shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_data = to_mx(
+            A, elem_dtype=torch.float8_e4m3fn, block_size=block_size
+        )
 
-        # B_mx shape: (E, N, K)
+        # B_data shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -315,9 +317,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_data,
             A_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             block_size=block_size,
@@ -332,15 +334,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated
 
-        # grad_out_mx shape: (M, N)
+        # grad_out_data shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_data = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )
-        # B_mx shape: (E, K, N)
+        # B_data shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -354,17 +356,17 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_data,
             grad_out_scale,
-            B_mx,
+            B_data,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_data shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_data = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -372,25 +374,25 @@ def backward(ctx, grad_out: torch.Tensor):
         )
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_data shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_data = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_data shape = (M, K)
+        A_data = A_t_data.transpose(-2, -1)
 
         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)
 
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_data,
             grad_out_t_scales,
-            A_mx,
+            A_data,
             A_scales,
             offs=offs,
         )
@@ -402,45 +404,47 @@
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_data: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_data: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_data.ndim == 2, f"A must be 2D, got {A_data.ndim}"
+    assert B_data.ndim == 3, f"B must be 3D, got {B_data.ndim}"
+    assert A_scale.shape[0] == A_data.shape[0], (
+        f"A_scale must have same M dim as A_data, got A={A_data.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_data.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_data.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_data.shape[0], (
+        f"B_scale must have same E dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_data.shape[1], (
+        f"B_scale must have same N dim as B_data, got B={B_data.shape} and B_scale={B_scale.shape}"
    )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_data.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_data.shape} and B_scale={B_scale.shape}"
     )
 
     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_data shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_data shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_data = A_data.reshape(
+        *A_data.shape[:-1], A_data.shape[-1] // block_size, block_size
+    )
     A_scale = A_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_data.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # A shape: (M, K)
@@ -448,18 +452,20 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 
     # Dequantize weights
     # Tranpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_data shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_data.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_data shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_data = B_data.reshape(
+        *B_data.shape[:-1], B_data.shape[-1] // block_size, block_size
+    )
     B_scale = B_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_data.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # B shape: (E, K, N)
@@ -471,27 +477,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_data: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_data: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_data.ndim == 2, "A must be 2D"
+    assert B_data.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_data.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_data.device,
+        requires_grad=A_data.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_data.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_data.device,
+        requires_grad=B_data.requires_grad,
     )
 
     # Dequantize input per each scaling group
@@ -507,7 +513,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_data[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape
 
         # Get scales for this group.
@@ -532,7 +538,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_data[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape
 
         # Scales shape is (group_size//block_size, N)
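Reviewer note (not part of the patch): a minimal sketch for sanity-checking the dequantization pattern touched by this rename. It round-trips a tensor through to_mx and applies the same blockwise rescale used in _emulated_mxfp8_scaled_grouped_mm_2d_3d above. The to_mx import path and the device choice are assumptions, not something this diff pins down.

# Sketch only: assumes to_mx is importable from the torchao mx_formats prototype.
import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx  # assumed location

block_size = 32
M, K = 16, 128
device = "cuda" if torch.cuda.is_available() else "cpu"
A = torch.randn(M, K, dtype=torch.bfloat16, device=device)

# Same call pattern as the forward pass above: to_mx returns (scales, fp8 data).
A_scale, A_data = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

# Blockwise dequant, mirroring the emulated kernel:
# (M, K) -> (M, K//block_size, block_size), scale broadcast over the last dim.
A_dq = A_data.reshape(M, K // block_size, block_size).to(torch.bfloat16) * A_scale.reshape(
    M, K // block_size, 1
).to(torch.bfloat16)
A_dq = A_dq.reshape(M, K)

# Expect only fp8 quantization error between the original and the round trip.
print((A - A_dq).abs().max())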