Skip to content

Commit e484dbb

Browse files
committed
Rename tokenizer_kwargs to tokenizer_encode_kwargs for text generation pipeline
1 parent c4f6407 commit e484dbb

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/transformers/pipelines/text_generation.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ def _sanitize_parameters(
158158
max_length=None,
159159
continue_final_message=None,
160160
skip_special_tokens=None,
161-
tokenizer_kwargs=None,
161+
tokenizer_encode_kwargs=None,
162162
**generate_kwargs,
163163
):
164164
# preprocess kwargs
@@ -196,8 +196,8 @@ def _sanitize_parameters(
196196
if continue_final_message is not None:
197197
preprocess_params["continue_final_message"] = continue_final_message
198198

199-
if tokenizer_kwargs is not None:
200-
preprocess_params["tokenizer_kwargs"] = tokenizer_kwargs
199+
if tokenizer_encode_kwargs is not None:
200+
preprocess_params["tokenizer_encode_kwargs"] = tokenizer_encode_kwargs
201201

202202
preprocess_params.update(generate_kwargs)
203203

@@ -293,9 +293,9 @@ def __call__(self, text_inputs, **kwargs):
293293
- `None` : default strategy where nothing in particular happens
294294
- `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
295295
truncate a lot of the prompt and not suitable when generation exceed the model capacity)
296-
tokenizer_kwargs (`dict`, *optional*):
297-
Additional keyword arguments to pass along to the tokenizer. If the text input is a chat, it is passed
298-
to `apply_chat_template`. Otherwise, it is passed to `__call__`.
296+
tokenizer_encode_kwargs (`dict`, *optional*):
297+
Additional keyword arguments to pass along to the encoding step of the tokenizer. If the text input is a
298+
chat, it is passed to `apply_chat_template`. Otherwise, it is passed to `__call__`.
299299
generate_kwargs (`dict`, *optional*):
300300
Additional keyword arguments to pass along to the generate method of the model (see the generate method
301301
corresponding to your framework [here](./text_generation)).
@@ -341,18 +341,18 @@ def preprocess(
341341
padding=None,
342342
max_length=None,
343343
continue_final_message=None,
344-
tokenizer_kwargs=None,
344+
tokenizer_encode_kwargs=None,
345345
**generate_kwargs,
346346
):
347347
# Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
348-
base_tokenizer_kwargs = {
348+
tokenizer_kwargs = {
349349
"add_special_tokens": add_special_tokens,
350350
"truncation": truncation,
351351
"padding": padding,
352352
"max_length": max_length, # NOTE: `max_length` is also a `generate` arg. Use `tokenizer_kwargs` to avoid a name clash
353353
}
354-
base_tokenizer_kwargs = {key: value for key, value in base_tokenizer_kwargs.items() if value is not None}
355-
tokenizer_kwargs = {**base_tokenizer_kwargs, **(tokenizer_kwargs or {})}
354+
tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
355+
tokenizer_kwargs.update(tokenizer_encode_kwargs or {})
356356

357357
if isinstance(prompt_text, Chat):
358358
tokenizer_kwargs.pop("add_special_tokens", None) # ignore add_special_tokens on chats

tests/pipelines/test_pipelines_text_generation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ def test_forward_tokenizer_kwargs(self):
568568
tokenizer = text_generator.tokenizer
569569

570570
with patch.object(tokenizer, "apply_chat_template", wraps=tokenizer.apply_chat_template) as mock:
571-
text_generator(chat, tokenizer_kwargs={"enable_thinking": True})
571+
text_generator(chat, tokenizer_encode_kwargs={"enable_thinking": True})
572572
self.assertGreater(mock.call_count, 0)
573573
kw_call_args = mock.call_args[1]
574574
self.assertIn("enable_thinking", kw_call_args)

0 commit comments

Comments (0)