diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 9dab81de4..8ed6e5c00 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -18,6 +18,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.constants import MessageType +from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.db.chat import attach_files_to_chat_message from danswer.db.chat import create_db_search_doc from danswer.db.chat import create_new_chat_message @@ -49,6 +50,7 @@ from danswer.llm.answering.models import PromptConfig from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_llms_for_persona from danswer.llm.factory import get_main_llm_from_tuple +from danswer.llm.interfaces import LLMConfig from danswer.llm.utils import get_default_llm_tokenizer from danswer.search.enums import OptionalSearchSetting from danswer.search.retrieval.search_runner import inference_documents_from_ids @@ -435,13 +437,13 @@ def stream_chat_message_objects( ) tool_dict[db_tool_model.id] = [search_tool] elif tool_cls.__name__ == ImageGenerationTool.__name__: - dalle_key = None + img_generation_llm_config: LLMConfig | None = None if ( llm and llm.config.api_key and llm.config.model_provider == "openai" ): - dalle_key = llm.config.api_key + img_generation_llm_config = llm.config else: llm_providers = fetch_existing_llm_providers(db_session) openai_provider = next( @@ -458,10 +460,19 @@ def stream_chat_message_objects( raise ValueError( "Image generation tool requires an OpenAI API key" ) - dalle_key = openai_provider.api_key + img_generation_llm_config = LLMConfig( + model_provider=openai_provider.provider, + model_name=openai_provider.default_model_name, + temperature=GEN_AI_TEMPERATURE, + api_key=openai_provider.api_key, + api_base=openai_provider.api_base, + api_version=openai_provider.api_version, + ) tool_dict[db_tool_model.id] = [ ImageGenerationTool( - api_key=dalle_key, + api_key=cast(str, img_generation_llm_config.api_key), + api_base=img_generation_llm_config.api_base, + api_version=img_generation_llm_config.api_version, additional_headers=litellm_additional_headers, ) ] diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 9ed02c14f..2f60708bb 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -306,6 +306,8 @@ class DefaultMultiLLM(LLM): model_name=self._model_version, temperature=self._temperature, api_key=self._api_key, + api_base=self._api_base, + api_version=self._api_version, ) def invoke( diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index 1f99383fa..e876403c4 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -19,6 +19,8 @@ class LLMConfig(BaseModel): model_name: str temperature: float api_key: str | None + api_base: str | None + api_version: str | None class LLM(abc.ABC): diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index 3785b549c..706f055ad 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -60,11 +60,16 @@ class ImageGenerationTool(Tool): def __init__( self, api_key: str, + api_base: str | None, + api_version: str | None, model: str = "dall-e-3", num_imgs: int = 2, additional_headers: dict[str, str] | None = None, ) -> None: self.api_key = api_key + self.api_base = api_base + self.api_version = api_version + self.model = model self.num_imgs = num_imgs @@ -148,6 +153,9 @@ class ImageGenerationTool(Tool): prompt=prompt, model=self.model, api_key=self.api_key, + # need to pass in None rather than empty str + api_base=self.api_base or None, + api_version=self.api_version or None, n=1, extra_headers=build_llm_extra_headers(self.additional_headers), )