From e0110e1d5d8fc907a7ea955b243e489c0c192780 Mon Sep 17 00:00:00 2001 From: rubik Date: Wed, 20 May 2026 17:12:09 +0800 Subject: [PATCH] fix(basic_llm_processor): prevent duplicate BOS token in Llama-3/3.1 chat When using `model.chat()` with Llama-3/3.1 models, the framework inadvertently prepends two `<|begin_of_text|>` (BOS, token ID 128000) tokens to the prompt_token_ids. This shifts the RoPE positional encodings by 1, causing the greedy decoding output to diverge significantly from HuggingFace. Root cause: The Llama-3/3.1 chat template explicitly includes `<|begin_of_text|>` at the start of the rendered string. Later, `BasicLLMProcessor.__call__` passes this string to `self.tokenizer(prompt)`, which defaults to `add_special_tokens=True`. Since `LlamaTokenizerFast` initializes with `add_bos_token=True` by default, the tokenizer automatically prepends a second BOS token via its Rust backend PostProcessor. Fix: Explicitly pass `add_special_tokens=False` to the tokenizer calls in `BasicLLMProcessor.__call__`. Since the chat template is already responsible for adding necessary special tokens, the tokenizer should only perform pure text-to-ID mapping. --- python/infinilm/processors/basic_llm_processor.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/python/infinilm/processors/basic_llm_processor.py b/python/infinilm/processors/basic_llm_processor.py index e341f3ba..6fafb55a 100644 --- a/python/infinilm/processors/basic_llm_processor.py +++ b/python/infinilm/processors/basic_llm_processor.py @@ -12,18 +12,25 @@ def __init__(self, model_dir_path: str): ) def __call__(self, prompt: str, return_tensors: str = None, **kwargs) -> dict: + # add_special_tokens=False Prevent duplicate BOS token for Llama-3/3.1 models. + # The `prompt` string here is already rendered by `apply_chat_template(tokenize=False)`, + # which explicitly includes the `<|begin_of_text|>` (BOS) token at the start. + # Since `LlamaTokenizerFast` defaults to `add_bos_token=True`, calling the tokenizer + # with the default `add_special_tokens=True` would prepend a second BOS token. + # This shifts the RoPE positional encodings by 1 and causes greedy decoding outputs + # to diverge significantly from HuggingFace. We must explicitly disable it. if return_tensors is None: - return self.tokenizer(prompt) + return self.tokenizer(prompt, add_special_tokens=False) elif return_tensors == "infini": import infinicore result = {} - for key, tensor in self.tokenizer(prompt, return_tensors="pt").items(): + for key, tensor in self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False).items(): result[key] = tensor.from_torch(tensor) return result # "pt" or "np" or "tf". - return self.tokenizer(prompt, return_tensors="pt") + return self.tokenizer(prompt, return_tensors="pt", add_special_tokens=False) def apply_chat_template( self,