Skip to content

issue/388 [BugFix](basic_llm_processor): prevent duplicate BOS token in Llama-3/3.1 chat#389

Open
rubik-hua wants to merge 1 commit into
InfiniTensor:mainfrom
rubik-hua:double_bos
Open

issue/388 [BugFix](basic_llm_processor): prevent duplicate BOS token in Llama-3/3.1 chat#389
rubik-hua wants to merge 1 commit into
InfiniTensor:mainfrom
rubik-hua:double_bos

Conversation

@rubik-hua
Copy link
Copy Markdown

@rubik-hua rubik-hua commented May 20, 2026

When using model.chat() with Llama-3/3.1 models, the framework inadvertently prepends two <|begin_of_text|> (BOS, token ID 128000) tokens to the prompt_token_ids. This shifts the RoPE positional encodings by 1, causing the greedy decoding output to diverge significantly from HuggingFace.

Root cause:
The Llama-3/3.1 chat template explicitly includes <|begin_of_text|> at the start of the rendered string. Later, BasicLLMProcessor.__call__ passes this string to self.tokenizer(prompt), which defaults to add_special_tokens=True. Since LlamaTokenizerFast initializes with add_bos_token=True by default, the tokenizer automatically prepends a second BOS token via its Rust backend PostProcessor.

Fix:
Explicitly pass add_special_tokens=False to the tokenizer calls in BasicLLMProcessor.__call__. Since the chat template is already responsible for adding necessary special tokens, the tokenizer should only perform pure text-to-ID mapping.

修复前,input_ids有两个BOS符,实际推理时看起来也没有什么影响。当想使用贪心解码来验证RotaryEmbedding的时候就会有问题,此时一般会去跟HF原生的输出对比输出token以及logprobs,双重BOS导致输出天然不一致,于是就想把这个隐藏的bug修复掉。
image

修复后,double bos消失
image

重新跑一遍已有的模型推理就能验证对原有推理逻辑没影响。

image image image image image image image image

@rubik-hua rubik-hua requested a review from a team May 20, 2026 09:56
…chat

When using `model.chat()` with Llama-3/3.1 models, the framework
inadvertently prepends two `<|begin_of_text|>` (BOS, token ID 128000)
tokens to the prompt_token_ids. This shifts the RoPE positional
encodings by 1, causing the greedy decoding output to diverge
significantly from HuggingFace.

Root cause:
The Llama-3/3.1 chat template explicitly includes `<|begin_of_text|>`
at the start of the rendered string. Later, `BasicLLMProcessor.__call__`
passes this string to `self.tokenizer(prompt)`, which defaults to
`add_special_tokens=True`. Since `LlamaTokenizerFast` initializes with
`add_bos_token=True` by default, the tokenizer automatically prepends
a second BOS token via its Rust backend PostProcessor.

Fix:
Explicitly pass `add_special_tokens=False` to the tokenizer calls in
`BasicLLMProcessor.__call__`. Since the chat template is already
responsible for adding necessary special tokens, the tokenizer should
only perform pure text-to-ID mapping.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant