diff --git a/Middleware/llmapis/handlers/impl/litellm_api_handler.py b/Middleware/llmapis/handlers/impl/litellm_api_handler.py new file mode 100644 index 0000000..1c4b8d5 --- /dev/null +++ b/Middleware/llmapis/handlers/impl/litellm_api_handler.py @@ -0,0 +1,147 @@ +# Middleware/llmapis/handlers/impl/litellm_api_handler.py + +""" +LiteLLM API handler using the LiteLLM Python SDK. + +Routes requests through litellm.completion() which supports 100+ LLM providers +(Anthropic, Bedrock, Vertex, Gemini, Cohere, Mistral, etc.) natively without +requiring a separate proxy server. +""" + +import json +import logging +from typing import Any, Dict, Generator, List, Optional, Union + +from Middleware.llmapis.handlers.base.base_llm_api_handler import LlmApiHandler + +logger = logging.getLogger(__name__) + + +class LiteLLMApiHandler(LlmApiHandler): + """ + Handles LLM interactions via the LiteLLM Python SDK. + + Uses litellm.completion() directly, supporting any model string LiteLLM + understands (e.g. anthropic/claude-sonnet-4-6, bedrock/claude-3.5-sonnet, + gpt-4o, gemini/gemini-2.5-pro, etc.). + + The api_key and base_url from WilmerAI endpoint config are forwarded to + litellm. If no api_key is set, litellm reads provider-specific env vars + (ANTHROPIC_API_KEY, OPENAI_API_KEY, etc.). + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + try: + import litellm + self._litellm = litellm + except ImportError as exc: + raise RuntimeError( + "litellm package is required for the litellmChatCompletion handler. " + "Install it with: pip install litellm" + ) from exc + + def _get_api_endpoint_url(self) -> str: + return self.base_url + + def _prepare_payload(self, conversation: Optional[List[Dict[str, str]]], system_prompt: Optional[str], + prompt: Optional[str], *, tools: Optional[List[Dict]] = None, + tool_choice: Optional[Any] = None) -> Dict: + messages: List[Dict[str, str]] = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + if conversation: + messages.extend(conversation) + if prompt: + messages.append({"role": "user", "content": prompt}) + + payload: Dict[str, Any] = { + "model": self.model_name, + "messages": messages, + "stream": self.stream, + "drop_params": True, + } + if self.max_tokens: + payload["max_tokens"] = self.max_tokens + if self.gen_input.get("temperature") is not None: + payload["temperature"] = self.gen_input["temperature"] + if self.api_key: + payload["api_key"] = self.api_key + if self.base_url: + payload["base_url"] = self.base_url + if tools: + payload["tools"] = tools + if tool_choice: + payload["tool_choice"] = tool_choice + return payload + + def _process_stream_data(self, data_str: str) -> Optional[Dict[str, Any]]: + try: + data = json.loads(data_str) + except json.JSONDecodeError: + return None + + choices = data.get("choices", []) + if not choices: + return None + + choice = choices[0] + delta = choice.get("delta", {}) + content = delta.get("content", "") + finish_reason = choice.get("finish_reason") + + if content or finish_reason: + return {"token": content or "", "finish_reason": finish_reason} + return None + + def _parse_non_stream_response(self, response_json: Dict) -> Union[str, Dict[str, Any]]: + choices = response_json.get("choices", []) + if not choices: + return "" + + choice = choices[0] + message = choice.get("message", {}) + content = message.get("content", "") + tool_calls = message.get("tool_calls") + finish_reason = choice.get("finish_reason") + + if tool_calls: + return { + "content": content or "", + "tool_calls": tool_calls, + "finish_reason": finish_reason, + } + return content or "" + + def handle_non_streaming(self, conversation=None, system_prompt=None, prompt=None, + request_id=None, tools=None, tool_choice=None): + payload = self._prepare_payload(conversation, system_prompt, prompt, + tools=tools, tool_choice=tool_choice) + model = payload.pop("model") + messages = payload.pop("messages") + stream = payload.pop("stream", False) + + response = self._litellm.completion(model=model, messages=messages, stream=False, **payload) + response_json = response.model_dump() + return self._parse_non_stream_response(response_json) + + def handle_streaming(self, conversation=None, system_prompt=None, prompt=None, + request_id=None, tools=None, tool_choice=None): + payload = self._prepare_payload(conversation, system_prompt, prompt, + tools=tools, tool_choice=tool_choice) + model = payload.pop("model") + messages = payload.pop("messages") + payload.pop("stream", None) + + response = self._litellm.completion(model=model, messages=messages, stream=True, **payload) + for chunk in response: + data = chunk.model_dump() + choices = data.get("choices", []) + if not choices: + continue + choice = choices[0] + delta = choice.get("delta", {}) + content = delta.get("content", "") + finish_reason = choice.get("finish_reason") + if content or finish_reason: + yield {"token": content or "", "finish_reason": finish_reason} diff --git a/Middleware/llmapis/llm_api.py b/Middleware/llmapis/llm_api.py index 4b25b59..6aab9c6 100644 --- a/Middleware/llmapis/llm_api.py +++ b/Middleware/llmapis/llm_api.py @@ -14,6 +14,7 @@ from Middleware.llmapis.handlers.base.base_llm_api_handler import LlmApiHandler from Middleware.llmapis.handlers.impl.claude_api_handler import ClaudeApiHandler from Middleware.llmapis.handlers.impl.koboldcpp_api_handler import KoboldCppApiHandler +from Middleware.llmapis.handlers.impl.litellm_api_handler import LiteLLMApiHandler from Middleware.llmapis.handlers.impl.ollama_chat_api_handler import OllamaChatHandler from Middleware.llmapis.handlers.impl.ollama_generate_api_handler import OllamaGenerateApiHandler from Middleware.llmapis.handlers.impl.openai_api_handler import OpenAiApiHandler @@ -315,6 +316,8 @@ def create_api_handler(self) -> LlmApiHandler: return OllamaChatHandler(**common_args) elif self.llm_type == "ollamaApiGenerate": return OllamaGenerateApiHandler(**common_args) + elif self.llm_type == "litellmChatCompletion": + return LiteLLMApiHandler(**common_args, dont_include_model=self.dont_include_model) else: raise ValueError(f"Unsupported LLM type: {self.llm_type}") diff --git a/requirements.txt b/requirements.txt index 69cd148..7c1ba58 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,5 +7,6 @@ Pillow==12.2.0 eventlet==0.41.0 waitress==3.0.2 cryptography==48.0.1 +litellm>=1.80.0,<1.87.0 mcp==1.27.2 PySocks==1.7.1