Add VLM support to RLOO trainer #4067
base: main
Conversation
- Add prepare_multimodal_messages import from data_utils
- Part of implementing VLM support in RLOO trainer
- Apply ruff formatting to resolve code quality issues
- No functional changes, only formatting fixes
Force-pushed from 0d57627 to c5270ab
…tributes
- Remove vllm_importance_sampling_correction and vllm_importance_sampling_cap assignments
- Remove use_liger_loss, loss_type, and scale_rewards assignments
- Remove the associated importance sampling calculation and metrics logging
- None of the removed attributes exist in RLOOConfig, only in GRPOConfig
- Fixes 29 failing RLOO trainer tests with AttributeError
…ion issues
- Add missing normalize_advantages assignment from RLOOConfig
- Fix old_per_token_logps computation when beta != 0.0 to prevent some NoneType errors
- This addresses some but not all of the 28 failing RLOO trainer tests
- Additional investigation may be needed for remaining failures
…ner compatibility
…token logps

The `_compute_loss` method expected `inputs['old_logps']`, but only `inputs['old_per_token_logps']` was provided. Added a computation to convert per-token log probabilities to sequence-level ones by summing over completion tokens with the mask.
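For context, that per-token-to-sequence-level conversion is just a masked sum. A minimal runnable sketch (the tensor values are illustrative, not taken from the PR):

```python
import torch

# Per-token log-probabilities, shape (batch, completion_len), and a 0/1 mask
# marking real completion tokens (padding positions are 0).
old_per_token_logps = torch.tensor([[-0.5, -1.0, -2.0], [-0.3, -0.7, -9.9]])
completion_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])  # second completion is shorter

# Sequence-level log-prob = sum of per-token log-probs over unmasked tokens.
old_logps = (old_per_token_logps * completion_mask).sum(dim=1)
print(old_logps)  # tensor([-3.5000, -1.0000])
```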
nice, a few corrections
trl/trainer/rloo_trainer.py
Outdated
```python
# If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of
# a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the
# samples may come from an earlier version of the model. In that case, we need to track old_per_token_logps
# for importance sampling. If the steps are aligned, importance sampling isn't necessary and we set
# old_per_token_logps to None.
# When using vLLM, we always compute old_per_token_logps for importance sampling, it was shown that the
# distribution mismatch between vLLM and the training model can be large and harm the training.
generate_every = self.args.steps_per_generation * self.num_iterations  # generation frequency
if self.args.gradient_accumulation_steps % generate_every != 0 or self.beta != 0.0:
    old_per_token_logps, _ = self._get_per_token_logps_and_entropies(
        self.model,
        prompt_completion_ids,
        attention_mask,
        logits_to_keep,
        batch_size,
```
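To make the alignment check concrete, here is a small self-contained sketch that restates the condition above with example numbers (values are illustrative, not from the PR):

```python
# Generation happens every `generate_every` micro-batches; an optimizer step
# completes every `gradient_accumulation_steps` micro-batches. When the two are
# aligned, sampled completions always come from the current weights, so stale
# (old) log-probs are only needed when beta != 0 forces computing them anyway.
steps_per_generation = 4
num_iterations = 2
gradient_accumulation_steps = 8
beta = 0.0

generate_every = steps_per_generation * num_iterations  # 8

needs_old_logps = gradient_accumulation_steps % generate_every != 0 or beta != 0.0
print(needs_old_logps)  # False: steps aligned and beta == 0, so old logps are skipped
```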
this should be reverted I think
I think it is still there.
It's here now, but the issue is that (check the failed tests):
=========================== short test summary info ============================
FAILED tests/test_rloo_trainer.py::RLOOTrainerTester::test_training_beta_zero - KeyError: 'old_logps'
= 1 failed, 802 passed, 168 skipped, 2 xfailed, 104 warnings, 5 rerun in 424.98s (0:07:04) =
make: *** [Makefile:9: test] Error 1
Looking at the condition on line 1296: `if self.args.gradient_accumulation_steps % generate_every != 0 or self.beta != 0.0:`. When `beta = 0.0`, the condition becomes `if self.args.gradient_accumulation_steps % generate_every != 0 or False:`, which means `old_per_token_logps` is computed only if `gradient_accumulation_steps % generate_every != 0`.

However, the `compute_loss` method always expects `old_logps` to be present. I think we need to ensure `old_logps` is always computed, even when `old_per_token_logps` is `None`.
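One defensive pattern would be to always derive a sequence-level value, falling back to the current policy's detached log-probs when no stale ones exist. This is a hypothetical sketch, not what the PR ended up doing (the reply below resolves it by reverting so that `old_per_token_logps` is always computed):

```python
from typing import Optional

import torch


def sequence_old_logps(
    per_token_logps: torch.Tensor,
    old_per_token_logps: Optional[torch.Tensor],
    completion_mask: torch.Tensor,
) -> torch.Tensor:
    """Hypothetical helper: always return sequence-level old log-probs.

    When no stale log-probs exist (fully on-policy case), fall back to the
    current policy's detached per-token log-probs, which makes the importance
    ratio exp(logps - old_logps) exactly 1, i.e. a no-op correction.
    """
    if old_per_token_logps is None:
        old_per_token_logps = per_token_logps.detach()
    return (old_per_token_logps * completion_mask).sum(1)
```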
Correct, that's why you should revert this change. In the previous version, `old_per_token_logps` is always computed.
trl/trainer/rloo_trainer.py
Outdated
```python
if old_per_token_logps is not None:
    output["old_per_token_logps"] = old_per_token_logps
    output["old_logps"] = (old_per_token_logps * completion_mask).sum(1)
if ref_per_token_logps is not None:
    output["ref_per_token_logps"] = ref_per_token_logps
```
And add `"old_logps": old_logps,` back.
> and add `"old_logps": old_logps,` back

Did not catch this part :)
Compute sequence-level old_logps from old_per_token_logps to resolve test failures. This ensures the compute_loss method receives the expected old_logps key in the inputs dictionary.
Apply ruff formatter fixes for proper indentation and spacing
Add computation of sequence-level old_logps from old_per_token_logps in the step output to resolve KeyError in compute_loss method. This ensures all required data is available for loss computation.
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
Ensure old_logps is available in step method output by computing it from old_per_token_logps. This fixes the KeyError in compute_loss method while preserving the importance sampling logic.
Add test coverage for Vision Language Models, including:
- Multi-model VLM training tests (Gemma3, LlavaNext, Qwen2.5-VL, Qwen2-VL)
- VLM training with non-zero beta (reference model usage)
- VLM training with PEFT (LoRA) support
- VLM training with importance sampling enabled
- Import AutoModelForImageTextToText for VLM support
- Add require_vision decorator for vision-dependent tests

These tests ensure VLM functionality works correctly with the RLOO trainer.
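For reference, one of those tests might look roughly like this. It is a hedged sketch: the tiny model ID, the dataset name, the toy reward function, and the exact RLOOTrainer/RLOOConfig arguments are assumptions, not code copied from the PR.

```python
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer


def dummy_reward(completions, **kwargs):
    # Toy reward for testing: longer completions score higher.
    return [float(len(str(c))) for c in completions]


def test_vlm_training(tmp_path):
    # Hypothetical tiny vision-language model and multimodal prompt dataset.
    dataset = load_dataset(
        "trl-internal-testing/zen-image", "conversational_prompt_only", split="train"
    )
    args = RLOOConfig(
        output_dir=str(tmp_path),
        per_device_train_batch_size=2,
        num_generations=2,
        max_completion_length=8,
        max_steps=2,
        report_to="none",
    )
    trainer = RLOOTrainer(
        model="trl-internal-testing/tiny-Qwen2VLForConditionalGeneration",
        reward_funcs=dummy_reward,
        args=args,
        train_dataset=dataset,
    )
    trainer.train()
```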
Apply ruff formatting fixes:
- Format transformers imports across multiple lines for readability
- Sort imports alphabetically (require_vision, require_vllm)
- Remove unnecessary blank lines
- Ensure consistent code style compliance
- Fix code formatting in test_rloo_vlm_functionality.py
- Ensure consistent style with project standards
- All 10 VLM tests pass with proper formatting
- Code quality checks pass without warnings
Add VLM (Vision Language Model) support to the RLOO trainer by importing the prepare_multimodal_messages function from data_utils. This enables multimodal data processing as the first step toward full VLM integration.
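As a rough illustration of what that import enables (the signature and output shape shown here are assumptions based on TRL's conversational data format, not confirmed by this PR):

```python
from trl.data_utils import prepare_multimodal_messages

# A plain conversational prompt. A VLM processor needs structured content that
# says where the image placeholders go, not a bare string.
messages = [{"role": "user", "content": "What color is the sky in this photo?"}]

# Assumed behavior: rewrite string contents into typed content lists and insert
# one {"type": "image"} placeholder per image.
multimodal = prepare_multimodal_messages(messages, num_images=1)
print(multimodal)
# e.g. [{"role": "user", "content": [{"type": "image"},
#        {"type": "text", "text": "What color is the sky in this photo?"}]}]
```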