
Conversation

@danielvegamyhre (Contributor) commented Aug 23, 2025

Stacked PRs:

  • [moe fp8 training] fused reduction kernel along dim1 for 3d expert weights in backward

Summary

  • This PR adds an experimental new kernel for quantizing 3d expert weights (non-transposed), used as the RHS operand in the backward pass GEMM grad_input = grad_output @ weight. It uses a single fused kernel with a reduction-based approach to calculate scales, rather than a two-stage approach with atomics (a minimal sketch of the approach follows this list).
  • Currently the atomics approach is still superior, but I think this is worth keeping because, if there are any tricks to further optimize this fused kernel, a single kernel should theoretically beat two separate dispatches.
  • I used TRITON_PRINT_AUTOTUNING=1 with a diverse set of autotuner configs, then removed the configs that were never selected, to keep compile times fast.
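
Below is a minimal sketch of the fused, reduction-based approach. This is not the exact kernel in this PR: the kernel/function names, config set, scale convention, and the simple two-pass-per-column structure are illustrative assumptions, and the power_of_2_scales variant from the benchmark table is not modeled here.

```python
import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config({"BLOCK_N": 128}),
        triton.Config({"BLOCK_N": 256}),
        triton.Config({"BLOCK_N": 512}),
    ],
    key=["N", "K"],
)
@triton.jit
def _dim1_cast_fp8_sketch_kernel(
    w_ptr, out_ptr, scale_ptr,
    N, K,
    stride_we, stride_wn, stride_wk,
    stride_oe, stride_on, stride_ok,
    EPS: tl.constexpr,
    FP8_MAX: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    # One program per (expert, column): no atomics and no second dispatch.
    e = tl.program_id(0)
    k = tl.program_id(1)
    col_ptr = w_ptr + e * stride_we + k * stride_wk

    # Pass 1: reduce |w| along dim1 (N) to get the per-column amax.
    amax_vec = tl.zeros((BLOCK_N,), dtype=tl.float32)
    for n0 in range(0, N, BLOCK_N):
        offs = n0 + tl.arange(0, BLOCK_N)
        mask = offs < N
        vals = tl.load(col_ptr + offs * stride_wn, mask=mask, other=0.0).to(tl.float32)
        amax_vec = tl.maximum(amax_vec, tl.abs(vals))
    amax = tl.max(amax_vec, axis=0)

    # Scale so that w * scale fits in fp8 (dequant would divide by this scale).
    scale = FP8_MAX / tl.maximum(amax, EPS)
    tl.store(scale_ptr + e * K + k, scale)

    # Pass 2: re-read the column, scale, saturate, and cast to fp8.
    out_col_ptr = out_ptr + e * stride_oe + k * stride_ok
    for n0 in range(0, N, BLOCK_N):
        offs = n0 + tl.arange(0, BLOCK_N)
        mask = offs < N
        vals = tl.load(col_ptr + offs * stride_wn, mask=mask, other=0.0).to(tl.float32)
        q = tl.minimum(tl.maximum(vals * scale, -FP8_MAX), FP8_MAX)
        tl.store(out_col_ptr + offs * stride_on, q.to(tl.float8e4nv), mask=mask)


def quantize_expert_weights_dim1_fp8(w: torch.Tensor):
    """w: (E, N, K) expert weights; returns fp8 data and (E, 1, K) fp32 scales."""
    E, N, K = w.shape
    out = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    scales = torch.empty((E, K), device=w.device, dtype=torch.float32)
    _dim1_cast_fp8_sketch_kernel[(E, K)](
        w, out, scales,
        N, K,
        w.stride(0), w.stride(1), w.stride(2),
        out.stride(0), out.stride(1), out.stride(2),
        EPS=1e-12,
        FP8_MAX=torch.finfo(torch.float8_e4m3fn).max,
    )
    return out, scales.reshape(E, 1, K)
```

For contrast, the two-stage atomics baseline first launches a kernel that atomically maxes partial amaxes into the scale buffer, then launches a second kernel that reads those scales and quantizes; the sketch above folds both steps into one dispatch at the cost of reading each column twice.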

Benchmarks

| input_shape | power_of_2_scales | torch_time_us | triton_atomic_time_us | triton_reduction_time_us | torch_mem_bw_gbps | triton_atomic_mem_bw_gbps | triton_reduction_mem_bw_gbps | triton_atomic_speedup | triton_reduction_speedup |
|---|---|---|---|---|---|---|---|---|---|
| (1, (8192, 5120)) | True | 113.152 | 118.784 | 454.752 | 1853.39 | 1765.52 | 461.164 | 0.95x | 0.25x |
| (1, (5120, 8192)) | True | 149.248 | 118.176 | 293.552 | 1405.15 | 1774.6 | 714.406 | 1.26x | 0.51x |
| (16, (8192, 5120)) | True | 2160.62 | 1654.64 | 1914.64 | 1553 | 2027.9 | 1752.52 | 1.31x | 1.13x |
| (16, (5120, 8192)) | True | 2073.41 | 1647.2 | 1825.01 | 1618.32 | 2037.06 | 1838.59 | 1.26x | 1.14x |
| (128, (8192, 5120)) | True | 21529.4 | 13617.8 | 14840.8 | 1246.83 | 1971.22 | 1808.77 | 1.58x | 1.45x |
| (128, (5120, 8192)) | True | 17985.5 | 13570.6 | 15235.6 | 1492.51 | 1978.07 | 1761.89 | 1.33x | 1.18x |
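
The derived columns can be reproduced from the raw times. The bandwidth figures are consistent with roughly 5 bytes moved per element (bf16 input read twice plus fp8 output written once), though that exact accounting is an assumption on my part; a quick check for the first row:

```python
# Sanity check of the derived columns for the first row of the table.
# Assumed accounting: bytes per element = 2*2 (bf16 input read twice) + 1 (fp8 output written once).
num_el = 1 * 8192 * 5120
bytes_moved = num_el * (2 * 2 + 1)  # 209,715,200 bytes

torch_us, atomic_us, reduction_us = 113.152, 118.784, 454.752

def mem_bw_gbps(time_us: float) -> float:
    # GB/s with GB = 1e9 bytes
    return bytes_moved / (time_us * 1e-6) / 1e9

print(f"{mem_bw_gbps(torch_us):.2f}")      # ~1853.40 (table: 1853.39)
print(f"{mem_bw_gbps(atomic_us):.2f}")     # ~1765.52
print(f"{mem_bw_gbps(reduction_us):.2f}")  # ~461.16
print(f"{torch_us / atomic_us:.2f}x")      # 0.95x  (speedup = torch_time / triton_time)
print(f"{torch_us / reduction_us:.2f}x")   # 0.25x
```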

pytorch-bot bot commented Aug 23, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2865

Note: Links to docs will display an error until the docs builds have been completed.

✅ You can merge normally! (3 Unrelated Failures)

As of commit 3beaedf with merge base 253d65a:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Aug 23, 2025
…ights in backward

stack-info: PR: #2865, branch: danielvegamyhre/stack/59
@danielvegamyhre force-pushed the danielvegamyhre/stack/58 branch from 3848c56 to cf93326 on August 23, 2025 23:42
@danielvegamyhre force-pushed the danielvegamyhre/stack/59 branch from 2cdfdfe to 1418745 on August 23, 2025 23:42
@meta-cla bot added the CLA Signed label on Aug 23, 2025
@danielvegamyhre changed the base branch from danielvegamyhre/stack/58 to main on August 24, 2025 00:13
danielvegamyhre added a commit that referenced this pull request Aug 24, 2025
…ights in backward

stack-info: PR: #2865, branch: danielvegamyhre/stack/59
@danielvegamyhre force-pushed the danielvegamyhre/stack/59 branch from 1418745 to 15fd988 on August 24, 2025 00:13
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/58 on August 24, 2025 00:14
@danielvegamyhre added the topic: not user facing label on Aug 24, 2025
@danielvegamyhre changed the title from "[moe fp8 training] fused reduction kernel along dim1 for 3d expert weights in backward" to "[moe fp8 training] single fused reduction based kernel for dim1 cast of 3d expert weights in backward" on Aug 24, 2025
…ights in backward

stack-info: PR: #2865, branch: danielvegamyhre/stack/59
@danielvegamyhre changed the base branch from danielvegamyhre/stack/58 to main on August 24, 2025 01:08
@danielvegamyhre force-pushed the danielvegamyhre/stack/59 branch from 15fd988 to 3beaedf on August 24, 2025 01:08
@danielvegamyhre changed the title from "[moe fp8 training] single fused reduction based kernel for dim1 cast of 3d expert weights in backward" to "[moe fp8 training] fused reduction kernel along dim1 for 3d expert weights in backward" on Aug 24, 2025
@danielvegamyhre changed the base branch from main to danielvegamyhre/stack/58 on August 24, 2025 01:08