Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 94 additions & 0 deletions llama_cpp/llama_chat_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -3229,6 +3229,62 @@ def from_pretrained(
)


class GemmaChatHandler(Llava15ChatHandler):
    """Vision chat handler that renders conversations in Gemma's turn format.

    Unlike the default LLaVA ``USER:``/``ASSISTANT:`` layout, Gemma-family
    multimodal models (PaliGemma, MedGemma, ...) delimit turns with the
    ``<start_of_turn>``/``<end_of_turn>`` control tokens and name the
    assistant role ``model``. Registering this handler ensures the right
    template wins when ``chat_handler`` takes precedence over ``chat_format``.

    See: https://ai.google.dev/gemma/docs/formatting
    """

    # Gemma has no native system role, so no default system prompt is injected.
    DEFAULT_SYSTEM_MESSAGE = None

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # system: Gemma lacks a system role, so render it as a user turn
        "{% if message.role == 'system' %}"
        "<start_of_turn>user\n{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        # user: either a plain string or a multimodal content list
        # (images are emitted first, then the text parts)
        "{% if message.role == 'user' %}"
        "<start_of_turn>user\n"
        "{% if message.content is string %}{{ message.content }}{% endif %}"
        # `is not string` matters: Jinja strings are themselves iterable
        "{% if message.content is iterable and message.content is not string %}"
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}{{ content.text }}{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "<end_of_turn>\n"
        "{% endif %}"
        # assistant: Gemma names this role 'model'
        "{% if message.role == 'assistant' and message.content is not none %}"
        "<start_of_turn>model\n{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        "{% endfor %}"
        # prime the model to produce its reply
        "{% if add_generation_prompt %}"
        "<start_of_turn>model\n"
        "{% endif %}"
    )
class ObsidianChatHandler(Llava15ChatHandler):
# Prompt Format
    # The model followed ChatML format. However, with ### as the separator
Expand Down Expand Up @@ -3581,6 +3637,44 @@ def __call__(self, **kwargs):
return super().__call__(**kwargs)


class MultimodalGemmaChatHandler(Llava15ChatHandler):
    """Chat handler applying Gemma's turn format to multimodal conversations.

    Renders messages with Gemma's ``<start_of_turn>``/``<end_of_turn>``
    control tokens (assistant role is named ``model``). User content may be
    a plain string or a multimodal content list; for lists, image parts are
    emitted before text parts.

    See: https://ai.google.dev/gemma/docs/formatting
    """

    # Gemma has no native system role, so no default system prompt is injected.
    DEFAULT_SYSTEM_MESSAGE: Optional[str] = None

    CHAT_FORMAT = (
        "{% for message in messages %}"
        # Fold system prompts into a user turn (Gemma has no system role).
        # Without this branch, system messages were silently dropped.
        "{% if message.role == 'system' %}"
        "<start_of_turn>user\n{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        "{% if message.role == 'user' %}"
        "<start_of_turn>user\n"
        "{% if message.content is string %}"
        "{{ message.content }}"
        "{% endif %}"
        # `is not string` guard: Jinja strings are iterable, so a plain-string
        # content would otherwise also be walked char-by-char in this branch.
        "{% if message.content is iterable and message.content is not string %}"
        # Emit image tokens first...
        "{% for content in message.content %}"
        "{% if content.type == 'image_url' and content.image_url is string %}"
        "{{ content.image_url }}"
        "{% endif %}"
        "{% if content.type == 'image_url' and content.image_url is mapping %}"
        "{{ content.image_url.url }}"
        "{% endif %}"
        "{% endfor %}"
        # ...then the text parts.
        "{% for content in message.content %}"
        "{% if content.type == 'text' %}"
        "{{ content.text }}"
        "{% endif %}"
        "{% endfor %}"
        "{% endif %}"
        "<end_of_turn>\n"
        "{% endif %}"
        # Assistant turns use Gemma's 'model' role name.
        "{% if message.role == 'assistant' and message.content is not none %}"
        "<start_of_turn>model\n"
        "{{ message.content }}<end_of_turn>\n"
        "{% endif %}"
        "{% endfor %}"
        # Prime the model to produce its reply.
        "{% if add_generation_prompt %}"
        "<start_of_turn>model\n"
        "{% endif %}"
    )


@register_chat_completion_handler("chatml-function-calling")
def chatml_function_calling(
llama: llama.Llama,
Expand Down