23 changes: 15 additions & 8 deletions llama_cpp/_internals.py
@@ -182,9 +182,12 @@ def tokenize(self, text: bytes, add_bos: bool, special: bool):
         return list(tokens[:n_tokens])

     def token_to_piece(self, token: int, special: bool = False) -> bytes:
-        buf = ctypes.create_string_buffer(32)
-        llama_cpp.llama_token_to_piece(self.vocab, token, buf, 32, 0, special)
-        return bytes(buf)
+        size = 32
+        buffer = (ctypes.c_char * size)()
+        n = llama_cpp.llama_token_to_piece(
+            self.vocab, llama_cpp.llama_token(token), buffer, size, 0, special
+        )
+        return bytes(buffer[:n])

     def detokenize(self, tokens: List[int], special: bool = False) -> bytes:
         output = b""
@@ -503,13 +506,17 @@ def reset(self):
     def set_batch(self, batch: Sequence[int], n_past: int, logits_all: bool):
         n_tokens = len(batch)
         self.batch.n_tokens = n_tokens
-        for i in range(n_tokens):
-            self.batch.token[i] = batch[i]
-            self.batch.pos[i] = n_past + i
-            self.batch.seq_id[i][0] = 0
-            self.batch.n_seq_id[i] = 1
-            self.batch.logits[i] = logits_all
-        self.batch.logits[n_tokens - 1] = True
+        token_arr = np.ctypeslib.as_array(self.batch.token, shape=(n_tokens,))
+        token_arr[:] = batch
+        pos_arr = np.ctypeslib.as_array(self.batch.pos, shape=(n_tokens,))
+        pos_arr[:] = np.arange(n_past, n_past + n_tokens, dtype=pos_arr.dtype)
+        n_seq_id_arr = np.ctypeslib.as_array(self.batch.n_seq_id, shape=(n_tokens,))
+        n_seq_id_arr[:] = 1
+        logits_arr = np.ctypeslib.as_array(self.batch.logits, shape=(n_tokens,))
+        logits_arr[:] = logits_all
+        logits_arr[n_tokens - 1] = True

     def add_sequence(self, batch: Sequence[int], seq_id: int, logits_all: bool):
         n_tokens = len(batch)
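Note: `np.ctypeslib.as_array` wraps a C pointer without copying, so the slice assignments above write straight into the `llama_batch` arrays in one vectorized pass instead of a per-token Python loop. A self-contained sketch of the pattern (buffer names invented for the demo):

```python
import ctypes
import numpy as np

n_tokens = 5
raw = (ctypes.c_int32 * n_tokens)()                     # stand-in for batch.token
ptr = ctypes.cast(raw, ctypes.POINTER(ctypes.c_int32))  # what the C struct exposes

# Zero-copy numpy view over the C memory; shape is required for pointers.
token_arr = np.ctypeslib.as_array(ptr, shape=(n_tokens,))
token_arr[:] = [17, 42, 7, 99, 3]  # one bulk write instead of a loop

print(list(raw))  # [17, 42, 7, 99, 3] -- the underlying C array was filled
```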
107 changes: 44 additions & 63 deletions llama_cpp/llama.py
@@ -553,6 +553,14 @@ def free_lora_adapter():

         self._sampler = None

+        # Cache model architecture flags to avoid repeated FFI calls
+        self._is_recurrent_model = llama_cpp.llama_model_is_recurrent(
+            self._model.model
+        ) or llama_cpp.llama_model_is_hybrid(self._model.model)
+        self._has_swa_model = llama_cpp.llama_model_n_swa(
+            self._model.model
+        ) > 0
+
     @property
     def ctx(self) -> llama_cpp.llama_context_p:
         return self._ctx.ctx
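Note: these are fixed properties of a loaded model, so querying them once in `__init__` is safe and keeps the per-step hot path free of FFI round trips. A generic sketch of the hoisting pattern (class and function names are invented for illustration):

```python
import time

def backend_is_recurrent(model: object) -> bool:
    """Stand-in for an FFI call; invented for this sketch."""
    time.sleep(0.001)
    return False

class ModelWrapper:
    def __init__(self, model: object):
        self._model = model
        # One "FFI" round trip at load time; the answer never changes.
        self._is_recurrent = backend_is_recurrent(model)

    def step(self) -> bool:
        # Hot path reads a cached attribute instead of crossing the FFI.
        return self._is_recurrent

w = ModelWrapper(object())
print(all(w.step() is False for _ in range(1000)))  # True, with a single lookup
```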
@@ -638,13 +646,12 @@ def reset(self):
         """Reset the model state."""
         self.n_tokens = 0
+        mem = llama_cpp.llama_get_memory(self._ctx.ctx)
+        if mem is not None:
+            llama_cpp.llama_memory_clear(mem, True)

     def eval(self, tokens: Sequence[int]):
-        """Evaluate a list of tokens.
-
-        Args:
-            tokens: The list of tokens to evaluate.
-        """
-        self._ctx.kv_cache_seq_rm(-1, self.n_tokens, -1)
+        """Evaluate a list of tokens."""
         for i in range(0, len(tokens), self.n_batch):
             batch = tokens[i : min(len(tokens), i + self.n_batch)]
             n_past = self.n_tokens
@@ -653,26 +660,12 @@ def eval(self, tokens: Sequence[int]):
                 batch=batch, n_past=n_past, logits_all=self._logits_all
             )
             self._ctx.decode(self._batch)
-            # Save tokens
             self.input_ids[n_past : n_past + n_tokens] = batch
-            # Save logits
             if self._logits_all:
-                rows = n_tokens
-                cols = self._n_vocab
                 logits = np.ctypeslib.as_array(
-                    self._ctx.get_logits(), shape=(rows * cols,)
+                    self._ctx.get_logits(), shape=(n_tokens, self._n_vocab)
                 )
-                self.scores[n_past : n_past + n_tokens, :].reshape(-1)[::] = logits
-            else:
-                # rows = 1
-                # cols = self._n_vocab
-                # logits = np.ctypeslib.as_array(
-                #     self._ctx.get_logits(), shape=(rows * cols,)
-                # )
-                # self.scores[n_past + n_tokens - 1, :].reshape(-1)[::] = logits
-                # NOTE: Now that sampling is done inside the sampler, logits are only needed for logprobs which requires logits_all
-                pass
-            # Update n_tokens
+                self.scores[n_past : n_past + n_tokens, :] = logits
             self.n_tokens += n_tokens

     def _init_sampler(
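Note: shaping the logits view as `(n_tokens, n_vocab)` lets a plain 2-D slice assignment replace the old flatten/reshape copy. A standalone sketch with invented buffer names (not the PR's code):

```python
import ctypes
import numpy as np

n_tokens, n_vocab = 3, 4
c_logits = (ctypes.c_float * (n_tokens * n_vocab))(*range(n_tokens * n_vocab))
ptr = ctypes.cast(c_logits, ctypes.POINTER(ctypes.c_float))

# 2-D zero-copy view over the flat C buffer, as in the new eval() code.
logits = np.ctypeslib.as_array(ptr, shape=(n_tokens, n_vocab))

scores = np.zeros((8, n_vocab), dtype=np.single)
n_past = 2
scores[n_past : n_past + n_tokens, :] = logits  # one bulk copy, no reshape

print(scores[2])  # [0. 1. 2. 3.]
```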
@@ -888,22 +881,20 @@ def generate(
         # Check for kv cache prefix match
         if reset and self.n_tokens > 0:
             longest_prefix = 0
-            for a, b in zip(self._input_ids, tokens[:-1]):
+            for a, b in zip(self._input_ids, tokens):
                 if a == b:
                     longest_prefix += 1
                 else:
                     break
+
+            if (self._is_recurrent_model or self._has_swa_model) and longest_prefix < self.n_tokens:
+                longest_prefix = 0
+
             if longest_prefix > 0:
-                reset = False
-                tokens = tokens[longest_prefix:]
-                self.n_tokens = longest_prefix
-                if self.verbose:
-                    print(
-                        f"Llama.generate: {longest_prefix} prefix-match hit, "
-                        f"remaining {len(tokens)} prompt tokens to eval",
-                        file=sys.stderr,
-                    )
+                if self._ctx.kv_cache_seq_rm(-1, longest_prefix, -1):
+                    reset = False
+                    tokens = tokens[longest_prefix:]
+                    self.n_tokens = longest_prefix
+                    if self.verbose:
+                        print(
+                            f"Llama.generate: {longest_prefix} prefix-match hit, "
+                            f"remaining {len(tokens)} prompt tokens to eval",
+                            file=sys.stderr,
+                        )
+                elif self.verbose:
+                    print(
+                        f"Llama.generate: {longest_prefix} prefix-match found "
@@ -1267,12 +1258,9 @@ def logit_bias_processor(
             input_ids: npt.NDArray[np.intc],
             scores: npt.NDArray[np.single],
         ) -> npt.NDArray[np.single]:
-            new_scores = np.copy(
-                scores
-            )  # Does it make sense to copy the whole array or can we just overwrite the original one?
-            for input_id, score in logit_bias_map.items():
-                new_scores[input_id] = score + scores[input_id]
-            return new_scores
+            for input_id, score in logit_bias_map.items():
+                scores[input_id] += score
+            return scores

         _logit_bias_processor = LogitsProcessorList([logit_bias_processor])
         if logits_processor is None:
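Note: this answers the question in the deleted comment by mutating the scores in place, avoiding a vocab-sized allocation per generated token; that is safe as long as callers do not rely on the input array surviving unchanged. Minimal demo (values invented):

```python
import numpy as np

scores = np.array([0.1, 0.5, -0.2, 0.9], dtype=np.single)
logit_bias_map = {1: 2.0, 3: -100.0}  # token id -> additive bias

for input_id, score in logit_bias_map.items():
    scores[input_id] += score  # in place, no np.copy

print(scores)  # [  0.1   2.5  -0.2 -99.1]
```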
@@ -1333,6 +1321,7 @@ def logit_bias_processor(

         finish_reason = "length"
         multibyte_fix = 0
+        accumulated_text = b""
         for token in self.generate(
             prompt_tokens,
             top_k=top_k,
@@ -1352,16 +1341,17 @@
             grammar=grammar,
         ):
             if llama_cpp.llama_vocab_is_eog(self._model.vocab, token):
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = accumulated_text
                 finish_reason = "stop"
                 break

             completion_tokens.append(token)

-            all_text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            new_text = self._model.token_to_piece(token)
+            accumulated_text += new_text

             # Contains multi-byte UTF8
-            for k, char in enumerate(all_text[-3:]):
+            for k, char in enumerate(accumulated_text[-3:]):
                 k = 3 - k
                 for num, pattern in [(2, 192), (3, 224), (4, 240)]:
                     # Bitwise AND check
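Note: the `(2, 192), (3, 224), (4, 240)` pairs are UTF-8 lead-byte masks (`0b110xxxxx`, `0b1110xxxx`, `0b11110xxx`) announcing 2-, 3-, and 4-byte characters; the loop holds back output while the trailing character is incomplete. Equivalent standalone logic (a sketch, not the PR's code):

```python
def pending_utf8_bytes(data: bytes) -> int:
    """How many continuation bytes are still missing at the end of data."""
    missing = 0
    tail = data[-3:]
    for i, byte in enumerate(tail):
        k = len(tail) - i  # bytes from this position through the end
        for need, lead in [(2, 0xC0), (3, 0xE0), (4, 0xF0)]:
            if need > k and byte & lead == lead:
                missing = need - k
    return missing

print(pending_utf8_bytes("é".encode()[:1]))   # 1 -- lead byte 0xC3 wants one more
print(pending_utf8_bytes("é".encode()))       # 0 -- complete character
print(pending_utf8_bytes("😀".encode()[:2]))  # 2 -- 4-byte emoji, two bytes short
```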
@@ -1373,19 +1363,16 @@
                 multibyte_fix -= 1
                 continue

-            any_stop = [s for s in stop_sequences if s in all_text]
+            any_stop = [s for s in stop_sequences if s in accumulated_text]
             if len(any_stop) > 0:
                 first_stop = any_stop[0]
-                text = all_text[: all_text.index(first_stop)]
+                text = accumulated_text[: accumulated_text.index(first_stop)]
                 finish_reason = "stop"
                 break

             if stream:
                 remaining_tokens = completion_tokens[returned_tokens:]
-                remaining_text = self.detokenize(
-                    remaining_tokens,
-                    prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
-                )
+                remaining_text = self._model.token_to_piece(token)
                 remaining_length = len(remaining_text)

             # We want to avoid yielding any characters from
@@ -1522,24 +1509,23 @@ def logit_bias_processor(
                 }

             if len(completion_tokens) >= max_tokens:
-                text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+                text = accumulated_text
                 finish_reason = "length"
                 break

         if stopping_criteria is not None and stopping_criteria(
             self._input_ids, self._scores[-1, :]
         ):
-            text = self.detokenize(completion_tokens, prev_tokens=prompt_tokens)
+            text = accumulated_text
             finish_reason = "stop"

         if self.verbose:
             self._ctx.print_timings()

         if stream:
             remaining_tokens = completion_tokens[returned_tokens:]
-            remaining_text = self.detokenize(
-                remaining_tokens,
-                prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
-            )
+            remaining_text = b"".join(
+                self._model.token_to_piece(t) for t in remaining_tokens
+            )
             any_stop = [s for s in stop_sequences if s in remaining_text]
             if len(any_stop) > 0:
@@ -1549,12 +1535,8 @@

             token_end_position = 0
             for token in remaining_tokens:
-                token_end_position += len(
-                    self.detokenize(
-                        [token],
-                        prev_tokens=prompt_tokens + completion_tokens[:returned_tokens],
-                    )
-                )
+                token_piece = self._model.token_to_piece(token)
+                token_end_position += len(token_piece)

                 logprobs_or_none: Optional[CompletionLogprobs] = None
                 if logprobs is not None:
@@ -1594,7 +1576,7 @@ def logit_bias_processor(
                     }

                 if token_end_position >= end:
-                    last_text = self.detokenize([token])
+                    last_text = token_piece
                     if token_end_position == end - 1:
                         break
                     returned_tokens += 1
@@ -1707,17 +1689,16 @@ def logit_bias_processor(
                     )
                 )
                 tokens.append(token_str)
-                sorted_logprobs = list(
-                    sorted(
-                        zip(logprobs_token, range(len(logprobs_token))), reverse=True
-                    )
-                )
+                top_k_indices = np.argpartition(logprobs_token, -logprobs)[-logprobs:]
+                top_k_indices = top_k_indices[
+                    np.argsort(logprobs_token[top_k_indices])
+                ][::-1]
                 token_logprobs.append(logprobs_token[int(token)])
                 top_logprob: Optional[Dict[str, float]] = {
-                    self.detokenize([i], prev_tokens=all_tokens[:idx]).decode(
+                    self.detokenize([int(i)], prev_tokens=all_tokens[:idx]).decode(
                         "utf-8", errors="ignore"
-                    ): logprob
-                    for logprob, i in sorted_logprobs[:logprobs]
+                    ): logprobs_token[int(i)]
+                    for i in top_k_indices
                 }
                 top_logprob.update({token_str: logprobs_token[int(token)]})
                 top_logprobs.append(top_logprob)
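Note: `np.argpartition` selects the `logprobs` largest entries in O(n) over the vocab-sized array, instead of the full O(n log n) sort the old `sorted(zip(...))` performed; the follow-up `argsort` only orders those few survivors. Standalone demo (values invented):

```python
import numpy as np

logprobs_token = np.array([-3.2, -0.1, -5.0, -0.7, -2.4], dtype=np.single)
k = 2

# Indices of the k largest values (unordered), found without a full sort...
top_k_indices = np.argpartition(logprobs_token, -k)[-k:]
# ...then ordered descending by value, matching the old sorted() output.
top_k_indices = top_k_indices[np.argsort(logprobs_token[top_k_indices])][::-1]

print(top_k_indices)                  # [1 3]
print(logprobs_token[top_k_indices])  # [-0.1 -0.7]
```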