Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

from torchao.prototype.moe_training.kernels.float8_rowwise import (
triton_fp8_rowwise_3d_transpose_rhs,
triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
)
from torchao.prototype.moe_training.utils import (
torch_to_3d_rowwise_float8_transpose_rhs,
Expand All @@ -37,9 +38,11 @@ class ExperimentConfig:
# Immutable record of the measurements collected for one benchmark configuration.
# NOTE(review): this view interleaves fields from before and after the change
# (`triton_*` vs. `triton_atomic_*` / `triton_reduction_*`) — confirm which set
# the final file keeps.
@dataclass(frozen=True)
class ExperimentResult:
# Measured run times in microseconds, one per implementation.
torch_time_us: float
triton_time_us: float
triton_atomic_time_us: float
triton_reduction_time_us: float
# Memory bandwidth in GB/s, derived in this file as
# (2 * read_bytes + write_bytes) / 1e9 divided by the matching time in seconds.
torch_mem_bw_gbps: float
triton_mem_bw_gbps: float
triton_atomic_mem_bw_gbps: float
triton_reduction_mem_bw_gbps: float


@dataclass(frozen=True)
Expand All @@ -59,7 +62,7 @@ def get_configs() -> List[ExperimentConfig]:
(128, 5120, 8192), # w2
]
high_precision_dtypes = [torch.bfloat16]
power_of_2_scales = [True, False]
power_of_2_scales = [True]
configs = []
for input_shape, high_precision_dtype, power_of_2_scale in itertools.product(
input_shapes, high_precision_dtypes, power_of_2_scales
Expand Down Expand Up @@ -94,14 +97,22 @@ def run_torch(input_tensor: torch.Tensor):
)
return out

def run_triton(input_tensor: torch.Tensor):
def run_triton_atomic(input_tensor: torch.Tensor):
    """Run `triton_fp8_rowwise_3d_transpose_rhs` on ``input_tensor``.

    Output dtype is float8_e4m3fn; power-of-2 scale rounding follows the
    enclosing benchmark ``config``. Returns whatever the kernel wrapper returns.
    """
    return triton_fp8_rowwise_3d_transpose_rhs(
        input_tensor,
        output_dtype=torch.float8_e4m3fn,
        round_scales_to_power_of_2=config.power_of_2_scales,
    )

def run_triton_reduction(input_tensor: torch.Tensor):
    """Run `triton_fp8_rowwise_3d_transpose_rhs_fused_reduction` on ``input_tensor``.

    Output dtype is float8_e4m3fn; power-of-2 scale rounding follows the
    enclosing benchmark ``config``. Returns whatever the kernel wrapper returns.
    """
    return triton_fp8_rowwise_3d_transpose_rhs_fused_reduction(
        input_tensor,
        output_dtype=torch.float8_e4m3fn,
        round_scales_to_power_of_2=config.power_of_2_scales,
    )

# bench torch
compiled_run_torch = torch.compile(run_torch)
warmup(run_torch, input_tensor)
Expand All @@ -110,10 +121,19 @@ def run_triton(input_tensor: torch.Tensor):
input_tensor,
)

# bench triton
warmup(run_triton, input_tensor)
triton_time_us = benchmark_cuda_function_in_microseconds(
run_triton,
# bench triton atomic method
run_triton_atomic_c = torch.compile(run_triton_atomic)
warmup(run_triton_atomic_c, input_tensor)
triton_atomic_time_us = benchmark_cuda_function_in_microseconds(
run_triton_atomic_c,
input_tensor,
)

# bench triton reduction method
run_triton_reduction_c = torch.compile(run_triton_reduction)
warmup(run_triton_reduction_c, input_tensor)
triton_reduction_time_us = benchmark_cuda_function_in_microseconds(
run_triton_reduction_c,
input_tensor,
)

Expand All @@ -129,13 +149,20 @@ def run_triton(input_tensor: torch.Tensor):
# Both torch.compile codegen and the triton kernel read the input tensor twice
# (once for scale calculations, once for scaling + casting).
torch_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (torch_time_us / 1e6)
triton_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (triton_time_us / 1e6)
triton_atomic_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (
triton_atomic_time_us / 1e6
)
triton_reduction_mem_bw_gbps = ((read_bytes * 2 + write_bytes) / 1e9) / (
triton_reduction_time_us / 1e6
)

return ExperimentResult(
torch_time_us=torch_time_us,
triton_time_us=triton_time_us,
triton_atomic_time_us=triton_atomic_time_us,
triton_reduction_time_us=triton_reduction_time_us,
torch_mem_bw_gbps=torch_mem_bw_gbps,
triton_mem_bw_gbps=triton_mem_bw_gbps,
triton_atomic_mem_bw_gbps=triton_atomic_mem_bw_gbps,
triton_reduction_mem_bw_gbps=triton_reduction_mem_bw_gbps,
)


Expand All @@ -144,10 +171,13 @@ def print_results(experiments: List[Experiment]):
"input_shape",
"power_of_2_scales",
"torch_time_us",
"triton_time_us",
"triton_atomic_time_us",
"triton_reduction_time_us",
"torch_mem_bw_gbps",
"triton_mem_bw_gbps",
"triton_speedup",
"triton_atomic_mem_bw_gbps",
"triton_reduction_mem_bw_gbps",
"triton_atomic_speedup",
"triton_reduction_speedup",
]
rows = []
for experiment in experiments:
Expand All @@ -157,10 +187,13 @@ def print_results(experiments: List[Experiment]):
input_shape,
experiment.config.power_of_2_scales,
experiment.result.torch_time_us,
experiment.result.triton_time_us,
experiment.result.triton_atomic_time_us,
experiment.result.triton_reduction_time_us,
round(experiment.result.torch_mem_bw_gbps, 3),
round(experiment.result.triton_mem_bw_gbps, 3),
f"{experiment.result.torch_time_us / experiment.result.triton_time_us:.2f}x",
round(experiment.result.triton_atomic_mem_bw_gbps, 3),
round(experiment.result.triton_reduction_mem_bw_gbps, 3),
f"{experiment.result.torch_time_us / experiment.result.triton_atomic_time_us:.2f}x",
f"{experiment.result.torch_time_us / experiment.result.triton_reduction_time_us:.2f}x",
]
)
print(tabulate(rows, headers=headers))
Expand Down
38 changes: 37 additions & 1 deletion test/prototype/moe_training/test_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from torchao.prototype.moe_training.kernels.float8_rowwise import (
triton_fp8_rowwise_3d_transpose_rhs,
triton_fp8_rowwise_3d_transpose_rhs_fused_reduction,
)
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
triton_fp8_per_group_colwise_scales,
Expand Down Expand Up @@ -128,7 +129,7 @@ def test_column_major_with_jagged_colwise_scales(round_scales_to_power_of_2: boo

@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
def test_fp8_rowwise_3d_transpose_rhs_atomic(round_scales_to_power_of_2: bool):
device = "cuda"
experts, n, k = 8, 4 * 5120, 5120

Expand Down Expand Up @@ -159,3 +160,38 @@ def test_fp8_rowwise_3d_transpose_rhs(round_scales_to_power_of_2: bool):
assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
assert torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0), "fp8 data not equal"


@skip_if_rocm("ROCm not supported")
@pytest.mark.parametrize("round_scales_to_power_of_2", [True, False])
def test_fp8_rowwise_3d_transpose_rhs_reduction(round_scales_to_power_of_2: bool):
    """Compare the fused-reduction Triton kernel with the torch reference.

    Scales and fp8 data must agree with the reference in shape, stride and
    value, using zero tolerance.
    """
    num_experts = 8
    k = 5120
    n = 4 * k

    # Seed first so the random expert weights are reproducible.
    torch.manual_seed(0)
    weights = torch.randn(
        (num_experts, n, k),
        dtype=torch.bfloat16,
        device="cuda",
    )
    # Example expert weights as they come into forward: transposed.
    x = weights.transpose(-2, -1)

    # Reference result from the torch implementation.
    ref_fp8, ref_scales = torch_to_3d_rowwise_float8_transpose_rhs(
        x,
        target_dtype=torch.float8_e4m3fn,
        round_scales_to_power_of_2=round_scales_to_power_of_2,
    )
    # The torch impl keeps the empty scaled dim; squeeze it out so the scales
    # are laid out consistently with the triton impl.
    ref_scales = ref_scales.squeeze(1)

    # Result from the kernel under test.
    triton_fp8, triton_scales = triton_fp8_rowwise_3d_transpose_rhs_fused_reduction(
        x,
        output_dtype=torch.float8_e4m3fn,
        round_scales_to_power_of_2=round_scales_to_power_of_2,
    )

    # Scales: shape, stride, then exact values.
    assert ref_scales.shape == triton_scales.shape, "scale shapes not equal"
    assert ref_scales.stride() == triton_scales.stride(), "scale strides not equal"
    scales_match = torch.allclose(ref_scales, triton_scales, rtol=0, atol=0)
    assert scales_match, "scales not equal"

    # Quantized data: shape, stride, then exact values.
    assert ref_fp8.shape == triton_fp8.shape, "output shapes not equal"
    assert ref_fp8.stride() == triton_fp8.stride(), "output strides not equal"
    data_match = torch.allclose(ref_fp8, triton_fp8, rtol=0, atol=0)
    assert data_match, "fp8 data not equal"
Loading
Loading