Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 147 additions & 0 deletions Middleware/llmapis/handlers/impl/litellm_api_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# Middleware/llmapis/handlers/impl/litellm_api_handler.py

"""
LiteLLM API handler using the LiteLLM Python SDK.

Routes requests through litellm.completion() which supports 100+ LLM providers
(Anthropic, Bedrock, Vertex, Gemini, Cohere, Mistral, etc.) natively without
requiring a separate proxy server.
"""

import json
import logging
from typing import Any, Dict, Generator, List, Optional, Union

from Middleware.llmapis.handlers.base.base_llm_api_handler import LlmApiHandler

logger = logging.getLogger(__name__)


class LiteLLMApiHandler(LlmApiHandler):
"""
Handles LLM interactions via the LiteLLM Python SDK.

Uses litellm.completion() directly, supporting any model string LiteLLM
understands (e.g. anthropic/claude-sonnet-4-6, bedrock/claude-3.5-sonnet,
gpt-4o, gemini/gemini-2.5-pro, etc.).

The api_key and base_url from WilmerAI endpoint config are forwarded to
litellm. If no api_key is set, litellm reads provider-specific env vars
(ANTHROPIC_API_KEY, OPENAI_API_KEY, etc.).
"""

def __init__(self, **kwargs):
super().__init__(**kwargs)
try:
import litellm
self._litellm = litellm
except ImportError as exc:
raise RuntimeError(
"litellm package is required for the litellmChatCompletion handler. "
"Install it with: pip install litellm"
) from exc

def _get_api_endpoint_url(self) -> str:
return self.base_url

def _prepare_payload(self, conversation: Optional[List[Dict[str, str]]], system_prompt: Optional[str],
prompt: Optional[str], *, tools: Optional[List[Dict]] = None,
tool_choice: Optional[Any] = None) -> Dict:
messages: List[Dict[str, str]] = []
if system_prompt:
messages.append({"role": "system", "content": system_prompt})
if conversation:
messages.extend(conversation)
if prompt:
messages.append({"role": "user", "content": prompt})

payload: Dict[str, Any] = {
"model": self.model_name,
"messages": messages,
"stream": self.stream,
"drop_params": True,
}
if self.max_tokens:
payload["max_tokens"] = self.max_tokens
if self.gen_input.get("temperature") is not None:
payload["temperature"] = self.gen_input["temperature"]
if self.api_key:
payload["api_key"] = self.api_key
if self.base_url:
payload["base_url"] = self.base_url
if tools:
payload["tools"] = tools
if tool_choice:
payload["tool_choice"] = tool_choice
return payload

def _process_stream_data(self, data_str: str) -> Optional[Dict[str, Any]]:
try:
data = json.loads(data_str)
except json.JSONDecodeError:
return None

choices = data.get("choices", [])
if not choices:
return None

choice = choices[0]
delta = choice.get("delta", {})
content = delta.get("content", "")
finish_reason = choice.get("finish_reason")

if content or finish_reason:
return {"token": content or "", "finish_reason": finish_reason}
return None

def _parse_non_stream_response(self, response_json: Dict) -> Union[str, Dict[str, Any]]:
choices = response_json.get("choices", [])
if not choices:
return ""

choice = choices[0]
message = choice.get("message", {})
content = message.get("content", "")
tool_calls = message.get("tool_calls")
finish_reason = choice.get("finish_reason")

if tool_calls:
return {
"content": content or "",
"tool_calls": tool_calls,
"finish_reason": finish_reason,
}
return content or ""

def handle_non_streaming(self, conversation=None, system_prompt=None, prompt=None,
request_id=None, tools=None, tool_choice=None):
payload = self._prepare_payload(conversation, system_prompt, prompt,
tools=tools, tool_choice=tool_choice)
model = payload.pop("model")
messages = payload.pop("messages")
stream = payload.pop("stream", False)

response = self._litellm.completion(model=model, messages=messages, stream=False, **payload)
response_json = response.model_dump()
return self._parse_non_stream_response(response_json)

def handle_streaming(self, conversation=None, system_prompt=None, prompt=None,
request_id=None, tools=None, tool_choice=None):
payload = self._prepare_payload(conversation, system_prompt, prompt,
tools=tools, tool_choice=tool_choice)
model = payload.pop("model")
messages = payload.pop("messages")
payload.pop("stream", None)

response = self._litellm.completion(model=model, messages=messages, stream=True, **payload)
for chunk in response:
data = chunk.model_dump()
choices = data.get("choices", [])
if not choices:
continue
choice = choices[0]
delta = choice.get("delta", {})
content = delta.get("content", "")
finish_reason = choice.get("finish_reason")
if content or finish_reason:
yield {"token": content or "", "finish_reason": finish_reason}
3 changes: 3 additions & 0 deletions Middleware/llmapis/llm_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from Middleware.llmapis.handlers.base.base_llm_api_handler import LlmApiHandler
from Middleware.llmapis.handlers.impl.claude_api_handler import ClaudeApiHandler
from Middleware.llmapis.handlers.impl.koboldcpp_api_handler import KoboldCppApiHandler
from Middleware.llmapis.handlers.impl.litellm_api_handler import LiteLLMApiHandler
from Middleware.llmapis.handlers.impl.ollama_chat_api_handler import OllamaChatHandler
from Middleware.llmapis.handlers.impl.ollama_generate_api_handler import OllamaGenerateApiHandler
from Middleware.llmapis.handlers.impl.openai_api_handler import OpenAiApiHandler
Expand Down Expand Up @@ -315,6 +316,8 @@ def create_api_handler(self) -> LlmApiHandler:
return OllamaChatHandler(**common_args)
elif self.llm_type == "ollamaApiGenerate":
return OllamaGenerateApiHandler(**common_args)
elif self.llm_type == "litellmChatCompletion":
return LiteLLMApiHandler(**common_args, dont_include_model=self.dont_include_model)
else:
raise ValueError(f"Unsupported LLM type: {self.llm_type}")

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ Pillow==12.2.0
eventlet==0.41.0
waitress==3.0.2
cryptography==48.0.1
litellm>=1.80.0,<1.87.0
mcp==1.27.2
PySocks==1.7.1