Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions patchwork/common/client/llm/aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,13 @@ def create_aio_client(inputs) -> "AioLlmClient" | None:
client = AnthropicLlmClient(anthropic_key)
clients.append(client)

litellm_key = inputs.get("litellm_api_key")
if litellm_key is not None:
from patchwork.common.client.llm.litellm_ import LiteLLMClient

client = LiteLLMClient(litellm_key)
clients.append(client)

if len(clients) == 0:
return None

Expand Down
145 changes: 145 additions & 0 deletions patchwork/common/client/llm/litellm_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

from pathlib import Path

import litellm
from openai.types.chat import (
ChatCompletion,
ChatCompletionMessageParam,
ChatCompletionToolChoiceOptionParam,
ChatCompletionToolParam,
completion_create_params,
)
from pydantic_ai.messages import ModelMessage, ModelResponse
from pydantic_ai.models import Model, ModelRequestParameters, StreamedResponse
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.settings import ModelSettings
from pydantic_ai.usage import Usage
from typing_extensions import AsyncIterator, Dict, Iterable, List, Optional, Union

from patchwork.common.client.llm.protocol import NOT_GIVEN, LlmClient, NotGiven
from patchwork.logger import logger


class LiteLLMClient(LlmClient):
    """``LlmClient`` implementation that talks to models through LiteLLM.

    LiteLLM is a gateway exposing many providers (OpenAI, Anthropic, Google,
    Azure, AWS Bedrock, Ollama, Groq, Mistral, ...) behind one interface.
    Models are addressed with LiteLLM-style strings such as
    ``anthropic/claude-sonnet-4-6`` or ``openai/gpt-4o``.

    ``chat_completion`` delegates to ``litellm.completion``; ``request`` and
    ``request_stream`` delegate to a pydantic-ai ``OpenAIModel`` built from
    the ``"model"`` entry of the supplied model settings.
    """

    def __init__(self, api_key: str, **kwargs):
        """Remember the API key and any extra keyword arguments.

        Args:
            api_key: Key attached to every ``chat_completion`` call and used
                when constructing the pydantic-ai model.
            **kwargs: Extra options kept on the instance.
        """
        # NOTE(review): these kwargs are stored but never read anywhere in
        # this class -- confirm whether they were meant to be forwarded.
        self.__kwargs = kwargs
        self.__api_key = api_key

    def __get_pydantic_model(self, model_settings: ModelSettings | None) -> Model:
        """Build the pydantic-ai model named by ``model_settings["model"]``.

        Raises:
            ValueError: If ``model_settings`` is ``None`` or carries no
                ``"model"`` entry.
        """
        if model_settings is None:
            raise ValueError("Model settings cannot be None")

        requested_model = model_settings.get("model")
        if requested_model is not None:
            # NOTE(review): this path builds an OpenAIModel and does not go
            # through litellm -- confirm LiteLLM-style model strings work here.
            return OpenAIModel(requested_model, api_key=self.__api_key)

        raise ValueError("Model must be set cannot be None")

    async def request(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> tuple[ModelResponse, Usage]:
        """Await a single response from the model selected by the settings.

        Raises:
            ValueError: If no model can be resolved from ``model_settings``.
        """
        backend = self.__get_pydantic_model(model_settings)
        outcome = await backend.request(messages, model_settings, model_request_parameters)
        return outcome

    async def request_stream(
        self,
        messages: list[ModelMessage],
        model_settings: ModelSettings | None,
        model_request_parameters: ModelRequestParameters,
    ) -> AsyncIterator[StreamedResponse]:
        """Yield, exactly once, whatever the model's ``request_stream`` returns.

        Raises:
            ValueError: If no model can be resolved from ``model_settings``.
        """
        backend = self.__get_pydantic_model(model_settings)
        stream = backend.request_stream(messages, model_settings, model_request_parameters)
        yield stream

    @property
    def model_name(self) -> str:
        """Fixed placeholder name; the real model is chosen per request."""
        return "Undetermined"

    @property
    def system(self) -> str | None:
        """Identifier of the backing system, always ``"litellm"``."""
        return "litellm"

    def test(self):
        """No-op: performs no check and returns ``None``."""
        return None

    def is_model_supported(self, model: str) -> bool:
        """Report every model string as supported, without inspecting it."""
        return True

    def is_prompt_supported(
        self,
        messages: Iterable[ChatCompletionMessageParam],
        model: str,
        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
        n: Optional[int] | NotGiven = NOT_GIVEN,
        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
        temperature: Optional[float] | NotGiven = NOT_GIVEN,
        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
        top_p: Optional[float] | NotGiven = NOT_GIVEN,
        file: Path | NotGiven = NOT_GIVEN,
    ) -> int:
        """Return the constant ``1`` for any prompt; no argument is examined."""
        return 1

    def truncate_messages(
        self, messages: Iterable[ChatCompletionMessageParam], model: str
    ) -> Iterable[ChatCompletionMessageParam]:
        """Delegate truncation to ``_truncate_messages``, passing this client."""
        shortened = self._truncate_messages(self, messages, model)
        return shortened

    def chat_completion(
        self,
        messages: Iterable[ChatCompletionMessageParam],
        model: str,
        frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN,
        logprobs: Optional[bool] | NotGiven = NOT_GIVEN,
        max_tokens: Optional[int] | NotGiven = NOT_GIVEN,
        n: Optional[int] | NotGiven = NOT_GIVEN,
        presence_penalty: Optional[float] | NotGiven = NOT_GIVEN,
        response_format: dict | completion_create_params.ResponseFormat | NotGiven = NOT_GIVEN,
        stop: Union[Optional[str], List[str]] | NotGiven = NOT_GIVEN,
        temperature: Optional[float] | NotGiven = NOT_GIVEN,
        tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
        tool_choice: ChatCompletionToolChoiceOptionParam | NotGiven = NOT_GIVEN,
        top_logprobs: Optional[int] | NotGiven = NOT_GIVEN,
        top_p: Optional[float] | NotGiven = NOT_GIVEN,
        file: Path | NotGiven = NOT_GIVEN,
    ) -> ChatCompletion:
        """Run a chat completion via ``litellm.completion``.

        The arguments are filtered through ``NotGiven.remove_not_given``
        before the call. ``drop_params=True`` is always sent, and the stored
        API key is attached when it is truthy. The ``file`` argument is
        accepted but not forwarded.

        Returns:
            Whatever ``litellm.completion`` returns for the request.
        """
        request_kwargs = NotGiven.remove_not_given(
            {
                "messages": messages,
                "model": model,
                "frequency_penalty": frequency_penalty,
                "logit_bias": logit_bias,
                "logprobs": logprobs,
                "max_tokens": max_tokens,
                "n": n,
                "presence_penalty": presence_penalty,
                "response_format": response_format,
                "stop": stop,
                "temperature": temperature,
                "tools": tools,
                "tool_choice": tool_choice,
                "top_logprobs": top_logprobs,
                "top_p": top_p,
                "drop_params": True,
            }
        )

        # An empty key is left out so litellm receives no api_key argument.
        if self.__api_key:
            request_kwargs["api_key"] = self.__api_key

        return litellm.completion(**request_kwargs)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ requests = "~2.32.3"
chardet = "~5.2.0"
attrs = "~23.2.0"
anthropic = "^0.49.0"
litellm = ">=1.35.0"
rich = "~13.7.1"
chevron = "~0.14.0"
giturlparse = "~0.12.0"
Expand Down
44 changes: 44 additions & 0 deletions tests/common/client/llm/test_litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from unittest.mock import MagicMock, patch

from patchwork.common.client.llm.litellm_ import LiteLLMClient


def test_litellm_client_is_model_supported():
    """Every model string is reported as supported, whatever its prefix."""
    llm = LiteLLMClient(api_key="test-key")

    for model_string in ("anthropic/claude-sonnet-4-6", "openai/gpt-4o", "any-model"):
        assert llm.is_model_supported(model_string) is True


def test_litellm_client_chat_completion():
    """chat_completion sends model, drop_params and api_key, and returns the reply."""
    llm = LiteLLMClient(api_key="test-key")

    # Fake completion whose first choice carries the answer "4".
    fake_choice = MagicMock(message=MagicMock(content="4"))
    fake_response = MagicMock()
    fake_response.choices = [fake_choice]

    with patch("litellm.completion", return_value=fake_response) as completion_mock:
        result = llm.chat_completion(
            messages=[{"role": "user", "content": "What is 2+2?"}],
            model="anthropic/claude-sonnet-4-6",
        )

    completion_mock.assert_called_once()
    forwarded = completion_mock.call_args.kwargs
    assert forwarded["model"] == "anthropic/claude-sonnet-4-6"
    assert forwarded["drop_params"] is True
    assert forwarded["api_key"] == "test-key"
    assert result.choices[0].message.content == "4"


def test_litellm_client_drop_params_default():
    """drop_params=True is sent even though the caller never asks for it."""
    llm = LiteLLMClient(api_key="test-key")
    fake_response = MagicMock()

    with patch("litellm.completion", return_value=fake_response) as completion_mock:
        llm.chat_completion(
            messages=[{"role": "user", "content": "test"}],
            model="openai/gpt-4o",
        )

    forwarded = completion_mock.call_args.kwargs
    assert forwarded["drop_params"] is True