Pass headers into image generation (#1739)

This commit is contained in:
Chris Weaver
2024-06-28 12:33:53 -07:00
committed by GitHub
parent e47da0d688
commit 38da3128d8
4 changed files with 27 additions and 10 deletions

View File

@@ -444,7 +444,10 @@ def stream_chat_message_objects(
)
dalle_key = openai_provider.api_key
tool_dict[db_tool_model.id] = [
ImageGenerationTool(api_key=dalle_key)
ImageGenerationTool(
api_key=dalle_key,
additional_headers=litellm_additional_headers,
)
]
continue

View File

@@ -1,13 +1,13 @@
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_default_provider
from danswer.db.llm import fetch_provider
from danswer.db.models import Persona
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
@@ -84,12 +84,6 @@ def get_llm(
timeout: int = QA_TIMEOUT,
additional_headers: dict[str, str] | None = None,
) -> LLM:
extra_headers = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return DefaultMultiLLM(
model_provider=provider,
model_name=model,
@@ -99,5 +93,5 @@ def get_llm(
timeout=timeout,
temperature=temperature,
custom_config=custom_config,
extra_headers=extra_headers,
extra_headers=build_llm_extra_headers(additional_headers),
)

View File

@@ -1,5 +1,6 @@
from fastapi.datastructures import Headers
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
@@ -20,3 +21,14 @@ def get_litellm_additional_request_headers(
pass_through_headers[lowercase_key] = headers[lowercase_key]
return pass_through_headers
def build_llm_extra_headers(
additional_headers: dict[str, str] | None = None
) -> dict[str, str]:
extra_headers: dict[str, str] = {}
if additional_headers:
extra_headers.update(additional_headers)
if LITELLM_EXTRA_HEADERS:
extra_headers.update(LITELLM_EXTRA_HEADERS)
return extra_headers

View File

@@ -10,6 +10,7 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.headers import build_llm_extra_headers
from danswer.llm.interfaces import LLM
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
@@ -57,12 +58,18 @@ class ImageGenerationTool(Tool):
NAME = "run_image_generation"
def __init__(
self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2
self,
api_key: str,
model: str = "dall-e-3",
num_imgs: int = 2,
additional_headers: dict[str, str] | None = None,
) -> None:
self.api_key = api_key
self.model = model
self.num_imgs = num_imgs
self.additional_headers = additional_headers
def name(self) -> str:
return self.NAME
@@ -142,6 +149,7 @@ class ImageGenerationTool(Tool):
model=self.model,
api_key=self.api_key,
n=1,
extra_headers=build_llm_extra_headers(self.additional_headers),
)
return ImageGenerationResponse(
revised_prompt=response.data[0]["revised_prompt"],