Open
Labels: bug
Description
Bug first encountered
While working on examples for `outlines` after upgrading the `outlines-core` version it uses to 0.2.11, I ran into an error when calling the `guide.advance` method.
Code used when I first encountered the bug (current `main` branch of `outlines`, or v1.2.0):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
import outlines

model_name = "erwanf/gpt2-mini"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = outlines.from_transformers(
    model,
    tokenizer,
)

response = model("Hello there", list[str], max_new_tokens=100)
print(response)
```
Stacktrace:
```
  File "/Users/robin/outlines/.idea/b.py", line 26, in <module>
    gen()
  File "/Users/robin/outlines/.idea/b.py", line 15, in gen
    response = model("Hello there", list[str], max_new_tokens=100)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/outlines/models/base.py", line 122, in __call__
    return Generator(self, output_type, backend)(model_input, **inference_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/outlines/generator.py", line 297, in __call__
    return self.model.generate(
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/outlines/models/transformers.py", line 307, in generate
    generated_ids = self._generate_output_seq(
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/outlines/models/transformers.py", line 356, in _generate_output_seq
    output_ids = self.model.generate(
                 ^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/.venv/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/.venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 2465, in generate
    result = self._sample(
             ^^^^^^^^^^^^^
  File "/Users/robin/outlines/.venv/lib/python3.11/site-packages/transformers/generation/utils.py", line 3450, in _sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/.venv/lib/python3.11/site-packages/transformers/generation/logits_process.py", line 88, in __call__
    scores = processor(input_ids, scores)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/outlines/processors/base_logits_processor.py", line 123, in __call__
    processed_logits = self.process_logits(input_ids, logits)
                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/robin/outlines/outlines/backends/outlines_core.py", line 168, in process_logits
    self._guides[i].advance(
ValueError: No next state found for the current state: 448 with token ID: 0
```
The input ids up to the crash were: [15496, 612, 58, 8973, 198, 198, 1, 11, 0]
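The failing token ID is 0, which in the standard GPT-2 vocabulary is `!` (assuming this model uses the standard GPT-2 tokenizer), not a token the regex could accept after `,`. One plausible explanation, offered as a hypothesis and not checked against the outlines source: if state 448 has no outgoing transitions, the logits processor masks every token to `-inf`, greedy decoding then takes the first index of an all-`-inf` row (`torch.argmax` returns the first maximal index on ties), i.e. token 0, and `advance` rejects it. The sketch below replays that chain in plain Python; `allowed_tokens`, `mask_logits`, `greedy_pick` and `advance` are hypothetical stand-ins, not outlines or outlines-core functions, and the table only copies the `1792 -> {11: 448, 60: 480}` row printed further down in this issue.

```python
import math

# state -> {token_id: next_state}; only the 1792 row comes from this issue.
# State 448 deliberately has no entry, mirroring the KeyError reported below.
TRANSITIONS = {
    1792: {11: 448, 60: 480},
}


def allowed_tokens(state: int) -> set[int]:
    """Token IDs with an outgoing transition from `state` (empty if none)."""
    return set(TRANSITIONS.get(state, {}))


def mask_logits(logits: list[float], allowed: set[int]) -> list[float]:
    """Hypothetical masking step: every disallowed token gets -inf."""
    return [x if i in allowed else -math.inf for i, x in enumerate(logits)]


def greedy_pick(logits: list[float]) -> int:
    """Argmax returning the first index on ties, like torch.argmax."""
    return max(range(len(logits)), key=lambda i: logits[i])


def advance(state: int, token_id: int) -> int:
    """Hypothetical guide step that fails the way the traceback does."""
    next_state = TRANSITIONS.get(state, {}).get(token_id)
    if next_state is None:
        raise ValueError(
            f"No next state found for the current state: {state} "
            f"with token ID: {token_id}"
        )
    return next_state


logits = [0.3, 1.2, -0.5] + [0.0] * 97  # arbitrary scores for 100 tokens

state = advance(1792, 11)  # ',' moves the guide to state 448
masked = mask_logits(logits, allowed_tokens(state))
token = greedy_pick(masked)  # every entry is -inf, so index 0 wins
print(state, token)  # 448 0
try:
    advance(state, token)
except ValueError as err:
    print(err)  # No next state found for the current state: 448 with token ID: 0
```

If this reading is right, the `ValueError` is a downstream symptom and the real question is why state 448 has no transitions.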
Bug reproduction without using outlines
To understand where the bug comes from and make sure it is not caused by how `outlines` uses `outlines-core`, here is another reproduction that calls `outlines-core` directly, with some extra output to help pin the bug down:
```python
from typing import Dict

from transformers import AutoTokenizer
from outlines_core import Index, Vocabulary

tokenizer = AutoTokenizer.from_pretrained("erwanf/gpt2-mini")
vocabulary = tokenizer.get_vocab()
eos_token_id = tokenizer.eos_token_id
eos_token = tokenizer.eos_token


def create_outlines_core_vocabulary(
    vocab: Dict[str, int], eos_token_id: int, eos_token: str
) -> Vocabulary:
    # Map each token string to the list of its token IDs
    formatted_vocab = {}
    for token, token_id in vocab.items():
        formatted_vocab[token] = [token_id]
    # The EOS token is passed separately, so drop it from the token map
    formatted_vocab.pop(eos_token)
    return Vocabulary(eos_token_id, formatted_vocab)


vocabulary = create_outlines_core_vocabulary(vocabulary, eos_token_id, eos_token)

# This is the regex corresponding to the output type above
index = Index(r'\[("[^"]*")(,\ ("[^"]*"))*\]', vocabulary)

print(index.get_initial_state())  # 416
print(index.get_final_states())  # {480}

transitions = index.get_transitions()
print(transitions[1792])  # {11: 448, 60: 480}
# For instance, 448 is the target state when generating token 11 from state 1792

print(transitions[448])
# Traceback (most recent call last):
#   File "/home/robinpicard/outlines/.idea/error_outlines_core.py", line 48, in <module>
#     print(transitions[448])
#           ~~~~~~~~~~~^^^^^
# KeyError: 448
# There should be an entry for 448, since it is a target state as shown above.
# token 0 = "!"
# token 1 = '"'
# token 11 = ","
# token 60 = "]"
```
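The expectation above (every state that appears as a target should also appear as a key) can be checked over the whole index instead of one state at a time. A minimal sketch, assuming `get_transitions()` returns a `dict[int, dict[int, int]]` as the printed output suggests. Final states can be excluded through `ignore`, on the assumption (not documented outlines-core behavior) that a final state may legitimately have no outgoing edges. `missing_target_states` is a hypothetical helper, and the example table is made up apart from the 1792 row copied from this issue.

```python
def missing_target_states(
    transitions: dict[int, dict[int, int]],
    ignore: frozenset[int] = frozenset(),
) -> set[int]:
    """Return target states that have no entry of their own in `transitions`."""
    # Collect every state reachable through some token
    targets = {
        next_state
        for edges in transitions.values()
        for next_state in edges.values()
    }
    # Keep the ones that never appear as a source state
    return targets - set(transitions) - set(ignore)


example = {
    416: {58: 500},  # made-up edge
    500: {1: 1792},  # made-up edge
    1792: {11: 448, 60: 480},  # copied from the issue
}
print(missing_target_states(example, ignore=frozenset({480})))  # {448}
```

Against the real index this would read `missing_target_states(index.get_transitions(), frozenset(index.get_final_states()))`; that call has not been run against outlines-core 0.2.11 here.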
Or could the issue be the creation of the Vocabulary?
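On the `Vocabulary` question, one thing worth checking: GPT-2's byte-level BPE stores tokens in `get_vocab()` with raw bytes remapped to printable characters, so a leading space appears as `Ġ` (U+0120) and a newline as `Ċ` (U+010A). The regex requires a literal space right after the comma (`,\ `), so if the raw strings are passed to `Vocabulary` unchanged, it is possible that no token can start that step, which would leave state 448 without outgoing transitions. This is a hypothesis for the standalone reproduction only; it does not by itself explain the crash through `outlines`. The sketch rebuilds GPT-2's byte-to-unicode table (as in OpenAI's GPT-2 `encoder.py`) to decode token strings back to raw bytes. `decode_token` and `decode_vocab` are hypothetical helpers, and whether the installed outlines-core `Vocabulary` accepts `bytes` keys is an assumption to verify against its documentation.

```python
import re


def bytes_to_unicode() -> dict[int, str]:
    """GPT-2's reversible byte -> printable-character table."""
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(256):
        if b not in bs:
            # Unprintable bytes (space and newline included) are shifted above 255
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))


BYTE_DECODER = {ch: b for b, ch in bytes_to_unicode().items()}


def decode_token(token: str) -> bytes:
    """Turn a byte-level BPE token string back into the bytes it stands for.

    Raises KeyError for characters outside the byte-level alphabet.
    """
    return bytes(BYTE_DECODER[ch] for ch in token)


def decode_vocab(vocab: dict[str, int]) -> dict[bytes, list[int]]:
    """Hypothetical helper: decoded token bytes -> list of token IDs."""
    decoded: dict[bytes, list[int]] = {}
    for token, token_id in vocab.items():
        decoded.setdefault(decode_token(token), []).append(token_id)
    return decoded


print(decode_token('Ġ"'))  # b' "'
print(decode_token("Ċ"))  # b'\n'

# The regex only accepts a comma followed by a literal space
pattern = r'\[("[^"]*")(,\ ("[^"]*"))*\]'
print(bool(re.fullmatch(pattern, '["a", "b"]')))  # True
print(bool(re.fullmatch(pattern, '["a","b"]')))  # False
print(bool(re.fullmatch(pattern, '["a",Ġ"b"]')))  # False
```

In the reproduction above, `decode_vocab(tokenizer.get_vocab())` with the EOS entry removed would take the place of `formatted_vocab`, if byte-string keys are accepted.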