
Commit 36e6bc7

[mxfp8 moe training] refactor all var names with suffix _mx to _fp8 for clarity
Parent: 0dd17b5

File tree: 1 file changed (+56 −56 lines)

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 56 additions & 56 deletions
@@ -291,13 +291,13 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated

-        # A_mx shape: (M, K)
+        # A_fp8 shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

-        # B_mx shape: (E, N, K)
+        # B_fp8 shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -311,9 +311,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_fp8,
             A_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             block_size=block_size,
@@ -328,15 +328,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated

-        # grad_out_mx shape: (M, N)
+        # grad_out_fp8 shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_fp8 = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )

-        # B_mx shape: (E, K, N)
+        # B_fp8 shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -350,43 +350,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_fp8,
             grad_out_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )

-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_fp8 shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_fp8 shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_fp8 = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )

-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_fp8 shape = (M, K)
+        A_fp8 = A_t_fp8.transpose(-2, -1)

         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)

         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_fp8,
             grad_out_t_scales,
-            A_mx,
+            A_fp8,
             A_scales,
             offs=offs,
         )
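
Note: in both the forward and backward hunks above, to_mx returns a (scales, fp8_data) pair for a given block size, with one scale per block of block_size values along the last dimension. Below is a minimal sketch of that call pattern with toy shapes; the import path is an assumption, since the import statement sits outside this diff.

# Editor's sketch (not part of the commit): the to_mx call pattern used above.
# Assumption: to_mx is importable from torchao.prototype.mx_formats.mx_tensor.
import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx

block_size = 32
A = torch.randn(128, 256, dtype=torch.bfloat16)

# Scales come first, then the fp8 payload, matching "A_scale, A_fp8 = to_mx(...)" above.
# Shapes mirror the diff comments: A_fp8 is (M, K), A_scale is (M, K // block_size).
A_scale, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
assert A_fp8.shape == (128, 256)
assert A_scale.shape == (128, 256 // block_size)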
@@ -398,64 +398,64 @@ def backward(ctx, grad_out: torch.Tensor):


 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_fp8: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_fp8: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_fp8.ndim == 2, f"A must be 2D, got {A_fp8.ndim}"
+    assert B_fp8.ndim == 3, f"B must be 3D, got {B_fp8.ndim}"
+    assert A_scale.shape[0] == A_fp8.shape[0], (
+        f"A_scale must have same M dim as A_fp8, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_fp8.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_fp8.shape[0], (
+        f"B_scale must have same E dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_fp8.shape[1], (
+        f"B_scale must have same N dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_fp8.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )

     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_fp8 shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_fp8.shape

     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_fp8 shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_fp8 = A_fp8.reshape(*A_fp8.shape[:-1], A_fp8.shape[-1] // block_size, block_size)
     A_scale = A_scale.unsqueeze(-1)

     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_fp8.to(torch.bfloat16) * A_scale.to(torch.bfloat16)

     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)

     # Dequantize weights
     # Tranpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_fp8 shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_fp8.shape

     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_fp8 shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_fp8 = B_fp8.reshape(*B_fp8.shape[:-1], B_fp8.shape[-1] // block_size, block_size)
     B_scale = B_scale.unsqueeze(-1)

     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_fp8.to(torch.bfloat16) * B_scale.to(torch.bfloat16)

     # Reshape back to original shape
     # B shape: (E, K, N)
@@ -467,27 +467,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(


 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_fp8: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_fp8: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_fp8.ndim == 2, "A must be 2D"
+    assert B_fp8.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_fp8.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_fp8.device,
+        requires_grad=A_fp8.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_fp8.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_fp8.device,
+        requires_grad=B_fp8.requires_grad,
     )

     # Dequantize input per each scaling group
@@ -503,7 +503,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_fp8[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape

         # Get scales for this group.
@@ -528,7 +528,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(

         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_fp8[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape

         # Scales shape is (group_size//block_size, N)
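
The emulated kernels in this diff dequantize by viewing the fp8 payload as blocks of block_size values along the reduction dimension and multiplying each block by its scale in bfloat16. Below is a self-contained sketch of that reshape-and-rescale step with toy shapes and float scales; the real code consumes the scales produced by to_mx.

# Editor's sketch (not part of the commit): block-wise dequantization as done in
# _emulated_mxfp8_scaled_grouped_mm_2d_3d above, with toy shapes and float scales.
import torch

block_size = 32
M, K = 8, 128
A_fp8 = torch.randn(M, K).to(torch.float8_e4m3fn)  # fp8 payload, shape (M, K)
A_scale = torch.rand(M, K // block_size)            # one scale per block of 32 values

# Line each scale up with its block, rescale in bf16, then restore the (M, K) shape.
A_blocks = A_fp8.reshape(M, K // block_size, block_size)
A = A_blocks.to(torch.bfloat16) * A_scale.unsqueeze(-1).to(torch.bfloat16)
A = A.reshape(M, K)

The 2d-3d variant applies the same pattern to B of shape (E, N, K), and the 2d-2d variant does it per expert group, slicing A_fp8 and B_fp8 with the group offsets before rescaling.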
