From e54fe5942ddb6433e06853a454e8dfeefc35dd3a Mon Sep 17 00:00:00 2001
From: Ralf Waldukat
Date: Mon, 13 Apr 2026 12:53:34 +0700
Subject: [PATCH] perf: vectorize hot-path ops, reduce Python overhead, fix SWA/ISWA KV cache corruption
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- set_batch(): numpy bulk writes replace per-token Python loop
- _create_completion: incremental token_to_piece() accumulation replaces
  O(n²) re-detokenization per generated token
- _create_completion: in-place logit_bias instead of full vocab copy
- _create_completion: np.argpartition for top-k logprobs (O(V) vs O(V log V))
- reset(): call llama_memory_clear() for proper KV cache state reset
- generate(): bypass prefix-match for recurrent/SWA models
- generate(): prefix-match against the full prompt (was tokens[:-1]), clamped
  so the final prompt token is always re-evaluated for fresh sampling logits
- eval(): remove unconditional kv_cache_seq_rm, simplify logits assignment
- token_to_piece(): return correct byte length via actual write count

Notes after the --- separator (dropped by git am) sketch the top-k
equivalence and a detokenization sanity check; a short note on the
batch-write mechanism follows the diff.
---
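Reviewer notes (kept below the --- separator, so git am ignores them):

The top-k logprob change relies on np.argpartition selecting the k largest
entries in O(V), after which only those k entries need an O(k log k) sort.
A minimal standalone sketch of the equivalence with the old full-sort path;
the vocab size and k below are arbitrary stand-ins:

    import numpy as np

    rng = np.random.default_rng(0)
    logprobs_token = rng.standard_normal(32000).astype(np.float32)  # stand-in logprob row
    k = 5

    # Old path: O(V log V) full sort over (logprob, token_id) pairs
    sorted_logprobs = sorted(
        zip(logprobs_token, range(len(logprobs_token))), reverse=True
    )[:k]

    # New path: O(V) partition, then sort only the k survivors
    top_k_indices = np.argpartition(logprobs_token, -k)[-k:]
    top_k_indices = top_k_indices[np.argsort(logprobs_token[top_k_indices])][::-1]

    # Same result up to ordering of exact ties (argpartition breaks ties arbitrarily)
    assert [i for _, i in sorted_logprobs] == top_k_indices.tolist()

The incremental detokenization assumes that concatenating token_to_piece()
output equals one detokenize() call over the same tokens. Some tokenizers
special-case the first token (e.g. a BOS-adjacent prefix space), so a one-off
check against the target model is cheap insurance; the model path here is a
placeholder:

    from llama_cpp import Llama

    llm = Llama(model_path="./model.gguf", verbose=False)  # placeholder path
    tokens = llm.tokenize(b"Hello world", add_bos=False)
    joined = b"".join(llm._model.token_to_piece(t) for t in tokens)
    assert joined == llm.detokenize(tokens)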
"""Evaluate a list of tokens. + mem = llama_cpp.llama_get_memory(self._ctx.ctx) + if mem is not None: + llama_cpp.llama_memory_clear(mem, True) - Args: - tokens: The list of tokens to evaluate. - """ - self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1) + def eval(self, tokens: Sequence[int]): + """Evaluate a list of tokens.""" for i in range(0, len(tokens), self.n_batch): batch = tokens[i : min(len(tokens), i + self.n_batch)] n_past = self.n_tokens @@ -653,26 +660,12 @@ def eval(self, tokens: Sequence[int]): batch=batch, n_past=n_past, logits_all=self._logits_all ) self._ctx.decode(self._batch) - # Save tokens self.input_ids[n_past : n_past + n_tokens] = batch - # Save logits if self._logits_all: - rows = n_tokens - cols = self._n_vocab logits = np.ctypeslib.as_array( - self._ctx.get_logits(), shape=(rows * cols,) + self._ctx.get_logits(), shape=(n_tokens, self._n_vocab) ) - self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits - else: - # rows = 1 - # cols = self._n_vocab - # logits = np.ctypeslib.as_array( - # self._ctx.get_logits(), shape=(rows * cols,) - # ) - # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits - # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all - pass - # Update n_tokens + self.scores[n_past : n_past + n_tokens, :] = logits self.n_tokens += n_tokens def _init_sampler( @@ -888,22 +881,20 @@ def generate( # Check for kv cache prefix match if reset and self.n_tokens > 0: longest_prefix = 0 - for a, b in zip(self._input_ids, tokens[:-1]): + for a, b in zip(self._input_ids, tokens): if a == b: longest_prefix += 1 else: break + + if (self._is_recurrent_model or self._has_swa_model) and longest_prefix < self.n_tokens: + longest_prefix = 0 + if longest_prefix > 0: if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1): reset = False tokens = tokens[longest_prefix:] self.n_tokens = longest_prefix - if self.verbose: - print( - f"Llama.generate: {longest_prefix} prefix-match hit, " - f"remaining {len(tokens)} prompt tokens to eval", - file=sys.stderr, - ) elif self.verbose: print( f"Llama.generate: {longest_prefix} prefix-match found " @@ -1267,12 +1258,9 @@ def logit_bias_processor( input_ids: npt.NDArray[np.intc], scores: npt.NDArray[np.single], ) -> npt.NDArray[np.single]: - new_scores = np.copy( - scores - ) # Does it make sense to copy the whole array or can we just overwrite the original one? 
             for input_id, score in logit_bias_map.items():
-                new_scores[input_id] = score + scores[input_id]
-            return new_scores
+                scores[input_id] += score
+            return scores
 
         _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
         if logits_processor is None:
@@ -1333,6 +1323,7 @@ def logit_bias_processor(
 
         finish_reason = "length"
         multibyte_fix = 0
+        accumulated_text = b""
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
@@ -1352,16 +1343,17 @@ def logit_bias_processor(
             grammar=grammar,
         ):
             if llama_cpp.llama_vocab_is_eog(self._model.vocab, token):
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = accumulated_text
                 finish_reason = "stop"
                 break
 
             completion_tokens.append(token)
 
-            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            new_text = self._model.token_to_piece(token)
+            accumulated_text += new_text
 
             # Contains multi-byte UTF8
-            for k, char in enumerate(all_text[-3:]):
+            for k, char in enumerate(accumulated_text[-3:]):
                 k = 3 - k
                 for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
@@ -1373,19 +1365,16 @@ def logit_bias_processor(
                 multibyte_fix -= 1
                 continue
 
-            any_stop = [s for s in stop_sequences if s in all_text]
+            any_stop = [s for s in stop_sequences if s in accumulated_text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
-                text = all_text[: all_text.index(first_stop)]
+                text = accumulated_text[: accumulated_text.index(first_stop)]
                 finish_reason = "stop"
                 break
 
             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_text = self.detokenize(
-                    remaining_tokens,
-                    prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
-                )
+                remaining_text = b"".join(self._model.token_to_piece(t) for t in remaining_tokens)
                 remaining_length = len(remaining_text)
 
                 # We want to avoid yielding any characters from
@@ -1522,14 +1511,14 @@ def logit_bias_processor(
                     }
 
             if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = accumulated_text
                 finish_reason = "length"
                 break
 
         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            text = accumulated_text
             finish_reason = "stop"
 
         if self.verbose:
@@ -1537,9 +1526,8 @@ def logit_bias_processor(
 
         if stream:
             remaining_tokens = completion_tokens[returned_tokens:]
-            remaining_text = self.detokenize(
-                remaining_tokens,
-                prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
+            remaining_text = b"".join(
+                self._model.token_to_piece(t) for t in remaining_tokens
             )
             any_stop = [s for s in stop_sequences if s in remaining_text]
             if len(any_stop) > 0:
@@ -1549,12 +1537,8 @@ def logit_bias_processor(
 
             token_end_position = 0
             for token in remaining_tokens:
-                token_end_position += len(
-                    self.detokenize(
-                        [token],
-                        prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
-                    )
-                )
+                token_piece = self._model.token_to_piece(token)
+                token_end_position += len(token_piece)
 
                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
@@ -1594,7 +1578,7 @@ def logit_bias_processor(
                     }
 
                 if token_end_position >= end:
-                    last_text = self.detokenize([token])
+                    last_text = token_piece
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
@@ -1707,17 +1691,16 @@ def logit_bias_processor(
                     )
                 )
                 tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
+                top_k_indices = np.argpartition(logprobs_token, -logprobs)[-logprobs:]
+                top_k_indices = top_k_indices[
+                    np.argsort(logprobs_token[top_k_indices])
+                ][::-1]
                 token_logprobs.append(logprobs_token[int(token)])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
+                    self.detokenize([int(i)], prev_tokens=all_tokens[:idx]).decode(
                         "utf-8", errors="ignore"
-                    ): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
+                    ): logprobs_token[int(i)]
+                    for i in top_k_indices
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
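
A standalone note on the set_batch() vectorization: np.ctypeslib.as_array
wraps the llama_batch ctypes buffers in zero-copy numpy views, so a single
slice assignment replaces n_tokens Python-level ctypes stores. A minimal
sketch of the mechanism, with c_int32 standing in for the token/pos fields:

    import ctypes
    import numpy as np

    n = 8
    buf = (ctypes.c_int32 * n)()      # stands in for the batch.token storage
    arr = np.ctypeslib.as_array(buf)  # zero-copy view over the ctypes array
    arr[:] = np.arange(100, 100 + n)  # one bulk write instead of n element stores
    assert list(buf) == list(range(100, 100 + n))

The per-token loop survives only for batch.seq_id, which is an array of
pointers and therefore has no contiguous numpy view.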