issue/388 [BugFix](basic_llm_processor): prevent duplicate BOS token in Llama-3/3.1 chat#389
Open
rubik-hua wants to merge 1 commit into
Open
issue/388 [BugFix](basic_llm_processor): prevent duplicate BOS token in Llama-3/3.1 chat#389rubik-hua wants to merge 1 commit into
rubik-hua wants to merge 1 commit into
Conversation
…chat When using `model.chat()` with Llama-3/3.1 models, the framework inadvertently prepends two `<|begin_of_text|>` (BOS, token ID 128000) tokens to the prompt_token_ids. This shifts the RoPE positional encodings by 1, causing the greedy decoding output to diverge significantly from HuggingFace. Root cause: The Llama-3/3.1 chat template explicitly includes `<|begin_of_text|>` at the start of the rendered string. Later, `BasicLLMProcessor.__call__` passes this string to `self.tokenizer(prompt)`, which defaults to `add_special_tokens=True`. Since `LlamaTokenizerFast` initializes with `add_bos_token=True` by default, the tokenizer automatically prepends a second BOS token via its Rust backend PostProcessor. Fix: Explicitly pass `add_special_tokens=False` to the tokenizer calls in `BasicLLMProcessor.__call__`. Since the chat template is already responsible for adding necessary special tokens, the tokenizer should only perform pure text-to-ID mapping.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
When using
model.chat()with Llama-3/3.1 models, the framework inadvertently prepends two<|begin_of_text|>(BOS, token ID 128000) tokens to the prompt_token_ids. This shifts the RoPE positional encodings by 1, causing the greedy decoding output to diverge significantly from HuggingFace.Root cause:
The Llama-3/3.1 chat template explicitly includes
<|begin_of_text|>at the start of the rendered string. Later,BasicLLMProcessor.__call__passes this string toself.tokenizer(prompt), which defaults toadd_special_tokens=True. SinceLlamaTokenizerFastinitializes withadd_bos_token=Trueby default, the tokenizer automatically prepends a second BOS token via its Rust backend PostProcessor.Fix:
Explicitly pass
add_special_tokens=Falseto the tokenizer calls inBasicLLMProcessor.__call__. Since the chat template is already responsible for adding necessary special tokens, the tokenizer should only perform pure text-to-ID mapping.修复前,input_ids有两个BOS符,实际推理时看起来也没有什么影响。当想使用贪心解码来验证RotaryEmbedding的时候就会有问题,此时一般会去跟HF原生的输出对比输出token以及logprobs,双重BOS导致输出天然不一致,于是就想把这个隐藏的bug修复掉。

修复后,double bos消失

重新跑一遍已有的模型推理就能验证对原有推理逻辑没影响。