We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent e0e6446 commit 1c20cfbCopy full SHA for 1c20cfb
torchrec/sparse/jagged_tensor.py
@@ -1104,7 +1104,7 @@ def _maybe_compute_stride_kjt(
1104
stride_per_key_per_rank is not None and stride_per_key_per_rank.numel() > 0
1105
):
1106
# For VBE KJT, batch size should be based on inverse_indices when set.
1107
- if inverse_indices is not None:
+ if inverse_indices is not None and inverse_indices[1].numel() > 0:
1108
return inverse_indices[1].shape[-1]
1109
1110
s = stride_per_key_per_rank.sum(dim=1).max().item()
0 commit comments