@@ -291,13 +291,13 @@ def forward(
         ctx.out_dtype = out_dtype
         ctx.emulated = emulated
 
-        # A_mx shape: (M, K)
+        # A_fp8 shape: (M, K)
         # A_scale shape: (M, K//block_size)
-        A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+        A_scale, A_fp8 = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
 
-        # B_mx shape: (E, N, K)
+        # B_fp8 shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             B_t.transpose(-2, -1),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
@@ -311,9 +311,9 @@ def forward(
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         out = mxfp8_2d_3d_grouped_mm(
-            A_mx,
+            A_fp8,
             A_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             block_size=block_size,
@@ -328,15 +328,15 @@ def backward(ctx, grad_out: torch.Tensor):
         out_dtype = ctx.out_dtype
         emulated = ctx.emulated
 
-        # grad_out_mx shape: (M, N)
+        # grad_out_fp8 shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
-        grad_out_scale, grad_out_mx = to_mx(
+        grad_out_scale, grad_out_fp8 = to_mx(
             grad_out, elem_dtype=torch.float8_e4m3fn, block_size=block_size
         )
 
-        # B_mx shape: (E, K, N)
+        # B_fp8 shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_scales, B_mx = to_mx(
+        B_scales, B_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
@@ -350,43 +350,43 @@ def backward(ctx, grad_out: torch.Tensor):
             else fbgemm_mxfp8_grouped_mm_2d_3d
         )
         grad_A = mxfp8_2d_3d_grouped_mm(
-            grad_out_mx,
+            grad_out_fp8,
             grad_out_scale,
-            B_mx,
+            B_fp8,
             B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )
 
-        # grad_out_t_mx shape: (N, M)
+        # grad_out_t_fp8 shape: (N, M)
         # grad_out_t_scales shape: (N, M//block_size)
-        grad_out_t_scales, grad_out_t_mx = to_mx(
+        grad_out_t_scales, grad_out_t_fp8 = to_mx(
             # TODO: can we support non-contiguous input tensor in to_mx to eliminate this inefficiency?
             grad_out.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
         # Transpose A so we can scale along the M dimension, then un-transpose.
-        # A_t_mx shape: (K, M)
+        # A_t_fp8 shape: (K, M)
         # A_t_scales shape: (K, M//block_size)
-        A_t_scales, A_t_mx = to_mx(
+        A_t_scales, A_t_fp8 = to_mx(
             A.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
-        # A_mx shape = (M, K)
-        A_mx = A_t_mx.transpose(-2, -1)
+        # A_fp8 shape = (M, K)
+        A_fp8 = A_t_fp8.transpose(-2, -1)
 
         # A_scales shape = (M//block_size, K)
         A_scales = A_t_scales.transpose(-2, -1)
 
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-            grad_out_t_mx,
+            grad_out_t_fp8,
             grad_out_t_scales,
-            A_mx,
+            A_fp8,
             A_scales,
             offs=offs,
         )
@@ -398,64 +398,64 @@ def backward(ctx, grad_out: torch.Tensor):
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-    A_mx: torch.Tensor,
+    A_fp8: torch.Tensor,
     A_scale: torch.Tensor,
-    B_mx: torch.Tensor,
+    B_fp8: torch.Tensor,
     B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
-    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
-    assert A_scale.shape[0] == A_mx.shape[0], (
-        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_fp8.ndim == 2, f"A must be 2D, got {A_fp8.ndim}"
+    assert B_fp8.ndim == 3, f"B must be 3D, got {B_fp8.ndim}"
+    assert A_scale.shape[0] == A_fp8.shape[0], (
+        f"A_scale must have same M dim as A_fp8, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
-        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    assert A_scale.shape[1] == A_fp8.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_fp8.shape} and A_scale={A_scale.shape}"
     )
-    assert B_scale.shape[0] == B_mx.shape[0], (
-        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[0] == B_fp8.shape[0], (
+        f"B_scale must have same E dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[1] == B_mx.shape[1], (
-        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[1] == B_fp8.shape[1], (
+        f"B_scale must have same N dim as B_fp8, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
-    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
-        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    assert B_scale.shape[2] == B_fp8.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_fp8.shape} and B_scale={B_scale.shape}"
     )
 
     # Dequantize input
-    # A_mx shape: (M, K)
+    # A_fp8 shape: (M, K)
     # A_scale shape: (M, K//block_size)
-    A_orig_shape = A_mx.shape
+    A_orig_shape = A_fp8.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # A_mx shape: (M, K//block_size, block_size)
+    # A_fp8 shape: (M, K//block_size, block_size)
     # A_scale shape: (M, K//block_size, 1)
-    A_mx = A_mx.reshape(*A_mx.shape[:-1], A_mx.shape[-1] // block_size, block_size)
+    A_fp8 = A_fp8.reshape(*A_fp8.shape[:-1], A_fp8.shape[-1] // block_size, block_size)
     A_scale = A_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    A = A_mx.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
+    A = A_fp8.to(torch.bfloat16) * A_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # A shape: (M, K)
     A = A.reshape(A_orig_shape)
 
     # Dequantize weights
     # Tranpose to get block_size on rightmost dim
-    # B_mx shape: (E, N, K)
+    # B_fp8 shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    E, N, K = B_mx.shape
+    E, N, K = B_fp8.shape
 
     # Reshape to be able to do per-scaling group multiplication
-    # B_mx shape: (E, N, K//block_size, block_size)
+    # B_fp8 shape: (E, N, K//block_size, block_size)
     # B_scale shape: (E, N, K//block_size, 1)
-    B_mx = B_mx.reshape(*B_mx.shape[:-1], B_mx.shape[-1] // block_size, block_size)
+    B_fp8 = B_fp8.reshape(*B_fp8.shape[:-1], B_fp8.shape[-1] // block_size, block_size)
     B_scale = B_scale.unsqueeze(-1)
 
     # Rescale and cast to bfloat16
-    B = B_mx.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
+    B = B_fp8.to(torch.bfloat16) * B_scale.to(torch.bfloat16)
 
     # Reshape back to original shape
     # B shape: (E, K, N)
@@ -467,27 +467,27 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
 
 
 def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
-    A_mx: torch.Tensor,  # (M, K)
+    A_fp8: torch.Tensor,  # (M, K)
     A_scale: torch.Tensor,  # (M, K//block_size)
-    B_mx: torch.Tensor,  # (K, N)
+    B_fp8: torch.Tensor,  # (K, N)
     B_scale: torch.Tensor,  # (K//block_size, N)
     offs: torch.Tensor,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
-    assert A_mx.ndim == 2, "A must be 2D"
-    assert B_mx.ndim == 2, "B must be 2D"
+    assert A_fp8.ndim == 2, "A must be 2D"
+    assert B_fp8.ndim == 2, "B must be 2D"
     A = torch.zeros(
-        A_mx.shape,
+        A_fp8.shape,
         dtype=torch.bfloat16,
-        device=A_mx.device,
-        requires_grad=A_mx.requires_grad,
+        device=A_fp8.device,
+        requires_grad=A_fp8.requires_grad,
     )
     B = torch.zeros(
-        B_mx.shape,
+        B_fp8.shape,
         dtype=torch.bfloat16,
-        device=B_mx.device,
-        requires_grad=B_mx.requires_grad,
+        device=B_fp8.device,
+        requires_grad=B_fp8.requires_grad,
     )
 
     # Dequantize input per each scaling group
@@ -503,7 +503,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
         # -- Dequantize A tensor
         # A_group shape: (M, group_size)
         # A_scale shape: (M, group_size//block_size)
-        A_group = A_mx[:, group_start_idx:group_end_idx]
+        A_group = A_fp8[:, group_start_idx:group_end_idx]
         A_group_shape = A_group.shape
 
         # Get scales for this group.
@@ -528,7 +528,7 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_2d(
 
         # -- Dequantize B tensor
         # B_group shape is (group_size, N)
-        B_group = B_mx[group_start_idx:group_end_idx, :]
+        B_group = B_fp8[group_start_idx:group_end_idx, :]
         B_group_shape = B_group.shape
 
         # Scales shape is (group_size//block_size, N)
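For reference, the emulated paths in this diff all follow the same per-block dequantization pattern: reshape the fp8 data so each scaling block sits on the last dimension, multiply by its per-block scale, and cast to bfloat16. Below is a minimal standalone sketch of that pattern (not part of this PR and not torchao's API); the helper name is hypothetical and it assumes block_size evenly divides K.

import torch

def dequantize_per_block(x_fp8: torch.Tensor, x_scale: torch.Tensor, block_size: int = 32) -> torch.Tensor:
    # x_fp8: (M, K) fp8 data; x_scale: (M, K//block_size) per-block scales.
    # Illustrative helper only; assumes K % block_size == 0.
    M, K = x_fp8.shape
    assert K % block_size == 0, "K must be divisible by block_size"
    # Put each scaling block on the last dim: (M, K//block_size, block_size)
    x_blocks = x_fp8.reshape(M, K // block_size, block_size)
    # Rescale each block in bfloat16, then restore the original (M, K) shape.
    x = x_blocks.to(torch.bfloat16) * x_scale.unsqueeze(-1).to(torch.bfloat16)
    return x.reshape(M, K)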