Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 86 additions & 25 deletions src/transformers/models/granitemoe/modeling_granitemoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...utils import auto_docstring, is_torch_flex_attn_available, logging
from ...utils.deprecation import deprecate_kwarg
from .scattermoe import scattered_experts
from .configuration_granitemoe import GraniteMoeConfig


Expand Down Expand Up @@ -317,6 +318,17 @@ def forward(self, hidden_states):
return index_sorted_experts, batch_index, batch_gates, expert_size, logits


# TODO add support for combileable bincount in PyTorch directly
@torch.library.custom_op("transformers::bincount", mutates_args={})
def bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
return x.bincount(minlength=minlength).to(torch.uint32)


@bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
    """Fake (meta) implementation for tracing/compile: an uninitialized uint32 tensor of the declared output shape."""
    return x.new_empty((minlength,), dtype=torch.uint32)


class GraniteMoeMoE(nn.Module):
"""
A Sparsely gated mixture of experts layer with 1-layer Feed-Forward networks as experts.
Expand All @@ -341,36 +353,85 @@ def __init__(self, config: GraniteMoeConfig):
top_k=config.num_experts_per_tok,
)

def forward(self, layer_input):
"""
Forward pass of the mixture of experts layer.

Args:
layer_input (Tensor):
Input tensor.

Returns:
Tensor:
Output tensor.
Tensor:
Router logits.
"""
bsz, length, emb_size = layer_input.size()
layer_input = layer_input.reshape(-1, emb_size)
_, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)
self.num_experts = config.num_local_experts
self.top_k = config.num_experts_per_tok

# def forward(self, layer_input):
# """
# Forward pass of the mixture of experts layer.

# Args:
# layer_input (Tensor):
# Input tensor.

# Returns:
# Tensor:
# Output tensor.
# Tensor:
# Router logits.
# """
# bsz, length, emb_size = layer_input.size()
# layer_input = layer_input.reshape(-1, emb_size)
# _, batch_index, batch_gates, expert_size, router_logits = self.router(layer_input)

# expert_inputs = layer_input[batch_index]
# hidden_states = self.input_linear(expert_inputs, expert_size)
# chunked_hidden_states = hidden_states.chunk(2, dim=-1)
# hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]
# expert_outputs = self.output_linear(hidden_states, expert_size)

# expert_outputs = expert_outputs * batch_gates[:, None]

# zeros = torch.zeros((bsz * length, self.input_size), dtype=expert_outputs.dtype, device=expert_outputs.device)
# layer_output = zeros.index_add(0, batch_index, expert_outputs)
# layer_output = layer_output.view(bsz, length, self.input_size)
# return layer_output, router_logits

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    """
    Forward pass of the mixture of experts layer using fused scattermoe kernels.

    Args:
        hidden_states (torch.Tensor):
            Input tensor; flattened internally to (num_tokens, input_size).

    Returns:
        Tensor:
            Output tensor with the same shape as the input.
        Tensor:
            Router logits of shape (num_tokens, num_experts).
    """
    # NOTE: the stale lines of the previous dense-gather implementation
    # (expert_inputs/batch_index/expert_size and the duplicate return) were
    # left interleaved here and referenced undefined names; they are removed.
    original_shape = hidden_states.shape
    hidden_states = hidden_states.view(-1, self.input_size)

    router_logits = self.router.layer(hidden_states)
    router_weights, selected_experts = router_logits.topk(self.top_k, dim=-1)

    # Softmax over the selected top-k logits only, computed in fp32 for stability.
    router_weights = F.softmax(router_weights.float(), dim=-1)
    router_weights = router_weights.type_as(hidden_states)

    with torch.no_grad():
        # Sort token->expert assignments by expert id so each expert's tokens are contiguous.
        sorted_expert_idxs, sorted_scattered_idxs = selected_experts.flatten().sort()
        expert_frequency = bincount(x=sorted_expert_idxs, minlength=self.num_experts)
        expert_offsets = expert_frequency.cumsum(-1)

    # Up-projection: scatter tokens to their experts; outputs stay grouped per expert.
    hidden_states = scattered_experts(
        inputs=hidden_states,
        expert_weights=self.input_linear.weight.permute(0, 2, 1),
        k=self.top_k,
        sorted_expert_idxs=sorted_expert_idxs,
        sorted_scattered_idxs=sorted_scattered_idxs,
        expert_offsets=expert_offsets,
        gates=None,
        grouped_in=False,
        grouped_out=True,
    )

    # GLU-style gating: first half is activated and scales the second half.
    chunked_hidden_states = hidden_states.chunk(2, dim=-1)
    hidden_states = self.activation(chunked_hidden_states[0]) * chunked_hidden_states[1]

    # Down-projection: gather grouped expert outputs back to token order,
    # weighting each expert's contribution by its router weight.
    hidden_states = scattered_experts(
        inputs=hidden_states,
        expert_weights=self.output_linear.weight.permute(0, 2, 1),
        k=1,
        sorted_expert_idxs=sorted_expert_idxs,
        sorted_scattered_idxs=sorted_scattered_idxs,
        expert_offsets=expert_offsets,
        gates=router_weights,
        grouped_in=True,
        grouped_out=False,
    )

    hidden_states = hidden_states.view(original_shape)
    return hidden_states, router_logits


# Copied from transformers.models.granite.modeling_granite.repeat_kv with Granite->GraniteMoe
Expand Down
185 changes: 185 additions & 0 deletions src/transformers/models/granitemoe/scattermoe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# **************************************************
# Copyright (c) 2025, Mayank Mishra
# **************************************************

import torch

from .group_backward_kernel import group_bwd_W
from .group_kernel import group
from .scatter_kernel import scatter2scatter


class _ScatteredExperts(torch.autograd.Function):
    """Autograd wrapper around the scattermoe Triton kernels.

    Forward computes per-expert matmuls for tokens routed to each expert
    (optionally combining the top-k expert outputs with ``gates``); backward
    produces gradients for the inputs, expert weights, and gates.
    ``grouped_in``/``grouped_out`` indicate whether the input/output tensors
    are already sorted ("grouped") by expert id rather than in token order.
    """

    @staticmethod
    def forward(
        ctx,
        x,
        expert_weights,
        k,
        sorted_expert_idxs,
        sorted_scattered_idxs,
        expert_offsets,
        gates=None,
        grouped_in=False,
        grouped_out=False,
    ):
        # One output row per token-expert assignment (num_tokens * k rows).
        output = torch.empty(sorted_expert_idxs.size(0), expert_weights.size(-1), device=x.device, dtype=x.dtype)

        # Fused scatter -> expert matmul -> (optional) scatter back.
        scatter2scatter(
            X=x,
            W=expert_weights,
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            out=output,
            FAN_OUT=k,
            x_grouped=grouped_in,
            y_grouped=grouped_out,
        )

        if gates is None:
            output_expanded = None
        else:
            # Combine the k expert outputs per token with the router weights:
            # (tokens, 1, k) @ (tokens, k, dim) -> (tokens, dim).
            output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
            output = torch.bmm(gates.unsqueeze(1), output_expanded).squeeze(1)

        # output_expanded is saved so backward can reuse it both for the gates
        # gradient and as a scratch buffer for the grouped grad_out.
        ctx.save_for_backward(
            x,
            expert_weights,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            expert_offsets,
            gates,
            output_expanded,
        )

        ctx.grouped_in = grouped_in
        ctx.grouped_out = grouped_out
        ctx.k = k

        return output

    @staticmethod
    def backward(ctx, grad_out):
        """Compute gradients w.r.t. x, expert_weights, and gates (others are None)."""
        (
            x,
            expert_weights,
            sorted_expert_idxs,
            sorted_scattered_idxs,
            expert_offsets,
            gates,
            output_expanded,
        ) = ctx.saved_tensors
        k = ctx.k
        grouped_in = ctx.grouped_in
        grouped_out = ctx.grouped_out

        if gates is None:
            d_gates = None
            gates_flat = None
            gate_fan = 1
            grouped_grad_out = None
        else:
            # calculate gates gradient: (tokens, k, dim) @ (tokens, dim, 1) -> (tokens, k)
            d_gates = torch.bmm(output_expanded, grad_out.unsqueeze(2)).squeeze(-1)
            gates_flat = gates.flatten()
            gate_fan = gates.size(1)
            # Reuse the expanded forward buffer as scratch space for the
            # grouped grad_out below (its contents are overwritten by group()).
            grouped_grad_out = output_expanded.flatten(0, 1)  # reuse expanded buffer later

        if grouped_out:
            # grad_out already arrives grouped by expert; use it directly.
            grouped_grad_out = grad_out
        else:
            if grouped_grad_out is None:
                if gate_fan == 1:
                    grouped_grad_out = torch.empty_like(grad_out)
                else:
                    raise RuntimeError("Need to infer size")
            # Regroup grad_out by expert id, scaling by gates when present.
            # NOTE(review): sorted_scattered_idxs is passed as the
            # `sorted_expert_idxs` kernel argument — presumably the group
            # kernel expects the scatter permutation here; confirm against
            # the group_kernel implementation.
            group(
                A=grad_out,
                sorted_expert_idxs=sorted_scattered_idxs,
                out=grouped_grad_out,
                coeff=gates_flat,
                fan_out=gate_fan,
            )

        if grouped_in:
            # Input was already grouped: the input gradient needs its own buffer.
            grouped_x = x
            d_expanded_input = torch.empty(
                sorted_expert_idxs.size(0), expert_weights.size(1), device=x.device, dtype=x.dtype
            )
        else:
            # Group x by expert for the weight-gradient kernel, then reuse that
            # same buffer as the output of the input-gradient kernel below
            # (group_bwd_W consumes grouped_x before scatter2scatter overwrites it).
            grouped_x = torch.empty(sorted_scattered_idxs.size(0), x.size(1), dtype=x.dtype, device=x.device)
            group(
                A=x,
                sorted_expert_idxs=sorted_scattered_idxs,
                out=grouped_x,
                fan_out=k,
            )

            d_expanded_input = grouped_x

        d_weights = torch.zeros_like(expert_weights)

        # Weight gradient: DW[e] accumulates X_e^T @ DY_e per expert slice.
        group_bwd_W(
            DY=grouped_grad_out,
            X=grouped_x,
            expert_offsets=expert_offsets,
            DW=d_weights,
            E=expert_weights.size(0),
        )

        # Input gradient: grouped grad_out through the transposed expert weights.
        scatter2scatter(
            X=grouped_grad_out,
            W=expert_weights.permute(0, 2, 1),
            sorted_expert_idxs=sorted_expert_idxs,
            sorted_scattered_idxs=sorted_scattered_idxs,
            out=d_expanded_input,
            FAN_OUT=1,
            x_grouped=True,
            y_grouped=grouped_in,
        )

        if k == 1:
            d_input = d_expanded_input
        else:
            # Sum the k per-expert contributions back into one gradient per token.
            d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)

        # One gradient slot per forward argument (None for non-differentiable ones).
        return (
            # x, expert_weights, k,
            d_input,
            d_weights,
            None,
            # sorted_expert_idxs, sorted_scattered_idxs,
            None,
            None,
            # expert_offsets,
            None,
            # gates
            d_gates,
            None,
            None,
        )


def scattered_experts(
    inputs,
    expert_weights,
    k,
    sorted_expert_idxs,
    sorted_scattered_idxs,
    expert_offsets,
    gates=None,
    grouped_in=False,
    grouped_out=False,
):
    """Functional entry point for the scattered-experts autograd op.

    Thin wrapper around ``_ScatteredExperts.apply``; autograd.Function.apply
    only accepts positional arguments, hence the explicit forwarding tuple.
    """
    args = (
        inputs,
        expert_weights,
        k,
        sorted_expert_idxs,
        sorted_scattered_idxs,
        expert_offsets,
        gates,
        grouped_in,
        grouped_out,
    )
    return _ScatteredExperts.apply(*args)
Loading