From 38da3128d84518213570e43eb0f7e2f096d5559f Mon Sep 17 00:00:00 2001 From: Chris Weaver <25087905+Weves@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:33:53 -0700 Subject: [PATCH] Pass headers into image generation (#1739) --- backend/danswer/chat/process_message.py | 5 ++++- backend/danswer/llm/factory.py | 10 ++-------- backend/danswer/llm/headers.py | 12 ++++++++++++ .../danswer/tools/images/image_generation_tool.py | 10 +++++++++- 4 files changed, 27 insertions(+), 10 deletions(-) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 7733bf523d60..5d0a33695bf8 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -444,7 +444,10 @@ def stream_chat_message_objects( ) dalle_key = openai_provider.api_key tool_dict[db_tool_model.id] = [ - ImageGenerationTool(api_key=dalle_key) + ImageGenerationTool( + api_key=dalle_key, + additional_headers=litellm_additional_headers, + ) ] continue diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py index d85dbdc9f616..edad6a295ad6 100644 --- a/backend/danswer/llm/factory.py +++ b/backend/danswer/llm/factory.py @@ -1,13 +1,13 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.chat_configs import QA_TIMEOUT from danswer.configs.model_configs import GEN_AI_TEMPERATURE -from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS from danswer.db.engine import get_session_context_manager from danswer.db.llm import fetch_default_provider from danswer.db.llm import fetch_provider from danswer.db.models import Persona from danswer.llm.chat_llm import DefaultMultiLLM from danswer.llm.exceptions import GenAIDisabledException +from danswer.llm.headers import build_llm_extra_headers from danswer.llm.interfaces import LLM from danswer.llm.override_models import LLMOverride @@ -84,12 +84,6 @@ def get_llm( timeout: int = QA_TIMEOUT, additional_headers: dict[str, str] | None = None, ) -> LLM: - extra_headers = {} - if additional_headers: - extra_headers.update(additional_headers) - if LITELLM_EXTRA_HEADERS: - extra_headers.update(LITELLM_EXTRA_HEADERS) - return DefaultMultiLLM( model_provider=provider, model_name=model, @@ -99,5 +93,5 @@ def get_llm( timeout=timeout, temperature=temperature, custom_config=custom_config, - extra_headers=extra_headers, + extra_headers=build_llm_extra_headers(additional_headers), ) diff --git a/backend/danswer/llm/headers.py b/backend/danswer/llm/headers.py index f7ae7436fb46..b43c83e141e8 100644 --- a/backend/danswer/llm/headers.py +++ b/backend/danswer/llm/headers.py @@ -1,5 +1,6 @@ from fastapi.datastructures import Headers +from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS @@ -20,3 +21,14 @@ def get_litellm_additional_request_headers( pass_through_headers[lowercase_key] = headers[lowercase_key] return pass_through_headers + + +def build_llm_extra_headers( + additional_headers: dict[str, str] | None = None +) -> dict[str, str]: + extra_headers: dict[str, str] = {} + if additional_headers: + extra_headers.update(additional_headers) + if LITELLM_EXTRA_HEADERS: + extra_headers.update(LITELLM_EXTRA_HEADERS) + return extra_headers diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index 22aa40993b6d..3785b549c8ed 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -10,6 +10,7 @@ from danswer.chat.chat_utils import combine_message_chain from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage +from danswer.llm.headers import build_llm_extra_headers from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import message_to_string @@ -57,12 +58,18 @@ class ImageGenerationTool(Tool): NAME = "run_image_generation" def __init__( - self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2 + self, + api_key: str, + model: str = "dall-e-3", + num_imgs: int = 2, + additional_headers: dict[str, str] | None = None, ) -> None: self.api_key = api_key self.model = model self.num_imgs = num_imgs + self.additional_headers = additional_headers + def name(self) -> str: return self.NAME @@ -142,6 +149,7 @@ class ImageGenerationTool(Tool): model=self.model, api_key=self.api_key, n=1, + extra_headers=build_llm_extra_headers(self.additional_headers), ) return ImageGenerationResponse( revised_prompt=response.data[0]["revised_prompt"],