Add padding_side and pad_token_id in OrtBackend #705
base: main
Conversation
The shared functionality for input preparation has been detached from both `OrtBackend::embed` and `OrtBackend::predict` into separate functions (see the sketch below):
- `prepare_inputs` to prepare the inputs based on what the ONNX model expects, i.e. `input_ids`, `attention_mask`, etc.
- `prepare_ort_inputs` to go from those inputs to `ort::inputs!`

Since the input processing in both `OrtBackend::embed` and `OrtBackend::predict` defaulted to right-padding, as did the pooling and post-processing in `OrtBackend::embed`, the `PaddingSide` is now handled explicitly to ensure the proper methods are used taking the `padding_side` into consideration.
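A rough sketch of what the `prepare_inputs` side of that split could look like. The `PaddingSide` enum mirrors the one visible in the diff below, but the `pad_batch` helper, its signature, and its body are illustrative assumptions, not the PR's actual code:

```rust
use ndarray::Array2;

#[derive(Clone, Copy)]
enum PaddingSide {
    Left,
    Right,
}

/// Hypothetical stand-in for the `prepare_inputs` idea: pad every sequence
/// to the batch's max length with `pad_token_id`, honoring the padding side.
fn pad_batch(
    sequences: &[Vec<i64>],
    pad_token_id: i64,
    padding_side: PaddingSide,
) -> (Array2<i64>, Array2<i64>) {
    let max_length = sequences.iter().map(Vec::len).max().unwrap_or(0);
    // Pre-fill with the pad token / zero mask, then overwrite the real tokens.
    let mut input_ids = Array2::from_elem((sequences.len(), max_length), pad_token_id);
    let mut attention_mask = Array2::<i64>::zeros((sequences.len(), max_length));
    for (row, seq) in sequences.iter().enumerate() {
        // Real tokens go at the end of the row for left-padding,
        // at the start of the row for right-padding.
        let offset = match padding_side {
            PaddingSide::Left => max_length - seq.len(),
            PaddingSide::Right => 0,
        };
        for (i, &token) in seq.iter().enumerate() {
            input_ids[(row, i + offset)] = token;
            attention_mask[(row, i + offset)] = 1;
        }
    }
    (input_ids, attention_mask)
}
```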
Looks alright, but I think we can simplify further.
backends/ort/src/lib.rs (Outdated)
```rust
Pool::Cls => match self.padding_side {
    PaddingSide::Left => {
        if masking {
            // With left-padding, the CLS token of each sequence sits right
            // after its pad tokens, at offset `max_length - seq_length`.
            let mut cls_embeddings = Vec::new();
            for (batch_idx, &seq_length) in
                model_inputs.input_lengths.iter().enumerate()
            {
                let padding = max_length as f32 - seq_length;
                let cls_pos = padding as usize;
                cls_embeddings
                    .push(outputs.slice(s![batch_idx, cls_pos, ..]).to_owned());
            }
            ndarray::stack(
                Axis(0),
                &cls_embeddings.iter().map(|x| x.view()).collect::<Vec<_>>(),
            )
            .unwrap()
            .into_dyn()
        } else {
            // No padding in the batch: the CLS token is at position 0.
            outputs.slice(s![.., 0, ..]).into_owned().into_dyn()
        }
    }
    // With right-padding, the CLS token is always at position 0.
    PaddingSide::Right => outputs.slice(s![.., 0, ..]).into_owned().into_dyn(),
},
Pool::LastToken => match self.padding_side {
    // NOTE: when using left-padding, the last token is always in the last position
```
Feels like there's a lot of switching on the padding side. I haven't carefully looked at each line, but it seems to me that the code could be made significantly simpler by allocating everything up front and choosing a different insertion (overwriting) offset based on the padding side.
Something like:

```rust
// Left-padding writes the tokens at the end of the row, right-padding at the start.
let offset = if padding_side == Side::Left { max_length - length } else { 0 };
for (i, item) in elements.iter().enumerate() {
    input_ids[i + offset] = item;
}
```
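For example, with `max_length = 5` and `length = 3`, left-padding writes the three tokens at positions 2..5 while right-padding writes them at positions 0..3, with the remaining positions keeping whatever pad value the buffer was pre-filled with.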
Hmm yes, it could be the case for some. The issue is that, given the `padding_side`, we can apply the pooling in the most performant way: e.g. for last-token pooling with left-padding it's literally the last token in the sequence, whereas with right-padding we need to iterate over each sequence to identify where it ends and then capture the last token accordingly. This is why I kept one implementation per `padding_side`, but I'm happy to unify those into a single match even if either right or left ends up slightly less performant (as in requiring more ops), which I'm also fine with, as that'd simplify the code a bit.
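To illustrate the trade-off being described, a minimal sketch (not the PR's code) of where the last real token lives under each padding side:

```rust
enum PaddingSide {
    Left,
    Right,
}

/// Column index of the last real token for one row of a padded batch.
fn last_token_pos(padding_side: &PaddingSide, max_length: usize, seq_length: usize) -> usize {
    match padding_side {
        // Left-padding: every sequence ends at the final column,
        // no per-sequence bookkeeping needed.
        PaddingSide::Left => max_length - 1,
        // Right-padding: the end depends on each sequence's own length.
        PaddingSide::Right => seq_length - 1,
    }
}
```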
What does this PR do?
This PR adds the `padding_side` from the `tokenizer_config.json` if applicable, otherwise it defaults to `padding_side: "right"`, to handle the scenarios where the `padding_side` is other than `"right"`, as e.g. https://huggingface.co/onnx-community/Qwen3-Embedding-0.6B-ONNX, to ensure parity between the inputs and the outputs. It also reads the `pad_token_id` from the `config.json` instead of setting it to 0 by default: the `pad_token_id` is read if present, otherwise it falls back to the `eos_token_id`, and finally to 0 if neither is defined.

This PR also updates the input preparation and the pooling strategies accordingly, so that those are applied one way or another based on the padding side: the pooling should be padding-agnostic, but with the padding-side information we can apply the pooling strategies efficiently.
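A minimal sketch of the `pad_token_id` fallback chain described above; the function name and `Option`-based signature are illustrative assumptions, not the PR's actual code:

```rust
/// Resolve the pad token id: `pad_token_id` from `config.json` if present,
/// else `eos_token_id`, else 0.
fn resolve_pad_token_id(pad_token_id: Option<u32>, eos_token_id: Option<u32>) -> u32 {
    pad_token_id.or(eos_token_id).unwrap_or(0)
}
```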
Additionally, this PR fixes the last-token pooling for the `OrtBackend`, which was leading to issues (unrelated to the `padding_side`) such as #704.

As some other, smaller but still relevant changes, this PR adds (see the sketch after this list):
- `OrtBackend::prepare_inputs` to prepare the `ndarray`s for the `input_ids`, `attention_mask`, etc. within a function to be reused for both `OrtBackend::embed` and `OrtBackend::predict`, to prevent duplicating the code
- `OrtBackend::prepare_ort_inputs` to go from the `ndarray`s to `ort::inputs!`, for the same reason as the function above
- `ModelInputs` to capture all the inputs within the same struct so that they can be easily managed
- `Config` to read from the `config.json`, required for both the `pad_token_id` and the `past_key_values` configuration values
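For reference, a rough sketch of what such a `ModelInputs` struct could look like; only `input_lengths` is visible in the review snippet above, and the remaining fields are assumptions for illustration:

```rust
use ndarray::Array2;

/// Illustrative only: one struct grouping everything the ONNX model needs.
struct ModelInputs {
    input_ids: Array2<i64>,
    attention_mask: Array2<i64>,
    /// Optional, only for models that expect token type ids.
    token_type_ids: Option<Array2<i64>>,
    /// Unpadded length of each sequence, used by the pooling code.
    input_lengths: Vec<f32>,
}
```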
Before submitting
- Did you update the `insta` snapshots?

Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@Narsil