Add tokenizer_kwargs argument to the text generation pipeline #40364

Open · Joshua-Chin wants to merge 5 commits into main from text-generation-pipeline-tokenizer-kwargs

Conversation

@Joshua-Chin (Author)

What does this PR do?

This PR adds a tokenizer_kwargs argument to the TextGenerationPipeline, allowing users to pass arbitrary arguments to the tokenizer during preprocessing. In particular, this lets users set chat template arguments, such as the enable_thinking flag for Qwen3 or SmolLM3.
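
A usage sketch of the intended behavior (not taken from the PR itself; the model name and exact forwarding semantics are assumptions based on the description above, and the argument was later renamed to `tokenizer_encode_kwargs` during review, see below):

```python
from transformers import pipeline

# Sketch only: assumes the pipeline forwards these kwargs to the tokenizer /
# chat template during preprocessing, as described in the PR summary.
generator = pipeline("text-generation", model="Qwen/Qwen3-0.6B")

messages = [{"role": "user", "content": "Briefly explain beam search."}]
outputs = generator(
    messages,
    max_new_tokens=64,
    tokenizer_kwargs={"enable_thinking": False},  # Qwen3 / SmolLM3 chat-template flag
)
print(outputs[0]["generated_text"])
```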

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

@Joshua-Chin (Author)

The test failure seems to be an unrelated flake:

self = <test_accelerate_examples.ExamplesTestsNoTrainer testMethod=test_run_swag_no_trainer>

    @mock.patch.dict(os.environ, {"WANDB_MODE": "offline", "DVCLIVE_TEST": "true"})
    def test_run_swag_no_trainer(self):
        tmp_dir = self.get_auto_remove_tmp_dir()
        testargs = f"""
            {self.examples_dir}/pytorch/multiple-choice/run_swag_no_trainer.py
            --model_name_or_path google-bert/bert-base-uncased
            --train_file tests/fixtures/tests_samples/swag/sample.json
            --validation_file tests/fixtures/tests_samples/swag/sample.json
            --output_dir {tmp_dir}
            --max_train_steps=20
            --num_warmup_steps=2
            --learning_rate=2e-4
            --per_device_train_batch_size=2
            --per_device_eval_batch_size=1
            --with_tracking
        """.split()
    
        run_command(self._launch_args + testargs)
        result = get_results(tmp_dir)
>       self.assertGreaterEqual(result["eval_accuracy"], 0.8)
E       AssertionError: 0.4 not greater than or equal to 0.8

examples/pytorch/test_accelerate_examples.py:225: AssertionError

Pushing an empty commit to re-run the CI.

@Joshua-Chin (Author)

A disjoint set of tests has failed in the re-run.

@Joshua-Chin (Author)

@Rocketknight1 Please review this PR when you have a chance. The CI failures seem to be caused by unrelated, flaky tests.

@Joshua-Chin force-pushed the text-generation-pipeline-tokenizer-kwargs branch from 92b4d49 to 2fe0979 on August 22, 2025 at 01:03
@Rocketknight1 (Member) left a comment

This LGTM! cc @gante just in case you have opinions about the max_length generate kwarg clash.
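
(For context, the clash presumably refers to `max_length` doubling as both a tokenizer truncation argument and a `generate()` argument; a hedged illustration, reusing the hypothetical `generator` from the sketch above:)

```python
# Assumed illustration of the potential clash, not a confirmed failure mode:
# `max_length` caps prompt truncation when given to the tokenizer, but caps
# total output length when given to generate(), so the two should not be mixed.
outputs = generator(
    "A very long prompt ...",
    max_length=128,  # picked up as a generate() kwarg
    tokenizer_kwargs={"truncation": True, "max_length": 512},  # tokenizer-side limit
)
```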

@Rocketknight1 (Member)

Also @Joshua-Chin you may need to rebase to fix some conflicts before we can merge the PR! That should also clear up the CI issues.

@gante (Member) left a comment

One question about variable names, otherwise lgtm :)

@@ -285,6 +289,9 @@ def __call__(self, text_inputs, **kwargs):
- `None` : default strategy where nothing in particular happens
- `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
truncate a lot of the prompt and not suitable when generation exceed the model capacity)
tokenizer_kwargs (`dict`, *optional*):
@gante (Member) commented on Aug 22, 2025:

perhaps tokenizer_encode_kwargs? There are also kwargs used at decode time, and we don't want to mix the two

cc @Rocketknight1

@Joshua-Chin (Author)

@gante I updated the argument to tokenizer_encode_kwargs. Please take another look when you have a chance.
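
For illustration, a rough sketch of how the renamed argument could be threaded through preprocessing (simplified and hypothetical; not the PR's actual diff):

```python
# Hypothetical, simplified sketch of TextGenerationPipeline.preprocess
# forwarding a `tokenizer_encode_kwargs` dict; not the PR's actual code.
def preprocess(self, prompt_text, tokenizer_encode_kwargs=None, **other_preprocess_params):
    tokenizer_encode_kwargs = tokenizer_encode_kwargs or {}
    if isinstance(prompt_text, list):
        # Chat-style input: forward the kwargs to the chat template,
        # e.g. {"enable_thinking": False} for Qwen3 / SmolLM3.
        inputs = self.tokenizer.apply_chat_template(
            prompt_text,
            add_generation_prompt=True,
            return_dict=True,
            return_tensors=self.framework,
            **tokenizer_encode_kwargs,
        )
    else:
        # Plain-string input: forward the kwargs to the regular encode call.
        inputs = self.tokenizer(prompt_text, return_tensors=self.framework, **tokenizer_encode_kwargs)
    inputs["prompt_text"] = prompt_text
    return inputs
```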

@Joshua-Chin force-pushed the text-generation-pipeline-tokenizer-kwargs branch from 2fe0979 to e484dbb on August 22, 2025 at 17:26
@Joshua-Chin (Author)

The CI is currently failing because of the following test, added by a recently merged change (HunYuan opensource #39606):

FAILED tests/models/hunyuan_v1_moe/test_modeling_hunyuan_v1_moe.py::HunYuanMoEV1ModelTest::test_generate_compile_model_forward_fullgraph - torch._dynamo.exc.Unsupported: Dynamic shape operator

@Joshua-Chin force-pushed the text-generation-pipeline-tokenizer-kwargs branch from d80a814 to f1d1dc1 on August 22, 2025 at 20:46