mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-09 12:47:13 +02:00
Pass headers into image generation (#1739)
This commit is contained in:
@@ -444,7 +444,10 @@ def stream_chat_message_objects(
|
||||
)
|
||||
dalle_key = openai_provider.api_key
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(api_key=dalle_key)
|
||||
ImageGenerationTool(
|
||||
api_key=dalle_key,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
]
|
||||
|
||||
continue
|
||||
|
@@ -1,13 +1,13 @@
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_default_provider
|
||||
from danswer.db.llm import fetch_provider
|
||||
from danswer.db.models import Persona
|
||||
from danswer.llm.chat_llm import DefaultMultiLLM
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.headers import build_llm_extra_headers
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
|
||||
@@ -84,12 +84,6 @@ def get_llm(
|
||||
timeout: int = QA_TIMEOUT,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
) -> LLM:
|
||||
extra_headers = {}
|
||||
if additional_headers:
|
||||
extra_headers.update(additional_headers)
|
||||
if LITELLM_EXTRA_HEADERS:
|
||||
extra_headers.update(LITELLM_EXTRA_HEADERS)
|
||||
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
@@ -99,5 +93,5 @@ def get_llm(
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
custom_config=custom_config,
|
||||
extra_headers=extra_headers,
|
||||
extra_headers=build_llm_extra_headers(additional_headers),
|
||||
)
|
||||
|
@@ -1,5 +1,6 @@
|
||||
from fastapi.datastructures import Headers
|
||||
|
||||
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
|
||||
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
|
||||
|
||||
@@ -20,3 +21,14 @@ def get_litellm_additional_request_headers(
|
||||
pass_through_headers[lowercase_key] = headers[lowercase_key]
|
||||
|
||||
return pass_through_headers
|
||||
|
||||
|
||||
def build_llm_extra_headers(
|
||||
additional_headers: dict[str, str] | None = None
|
||||
) -> dict[str, str]:
|
||||
extra_headers: dict[str, str] = {}
|
||||
if additional_headers:
|
||||
extra_headers.update(additional_headers)
|
||||
if LITELLM_EXTRA_HEADERS:
|
||||
extra_headers.update(LITELLM_EXTRA_HEADERS)
|
||||
return extra_headers
|
||||
|
@@ -10,6 +10,7 @@ from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.headers import build_llm_extra_headers
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
from danswer.llm.utils import message_to_string
|
||||
@@ -57,12 +58,18 @@ class ImageGenerationTool(Tool):
|
||||
NAME = "run_image_generation"
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2
|
||||
self,
|
||||
api_key: str,
|
||||
model: str = "dall-e-3",
|
||||
num_imgs: int = 2,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.model = model
|
||||
self.num_imgs = num_imgs
|
||||
|
||||
self.additional_headers = additional_headers
|
||||
|
||||
def name(self) -> str:
|
||||
return self.NAME
|
||||
|
||||
@@ -142,6 +149,7 @@ class ImageGenerationTool(Tool):
|
||||
model=self.model,
|
||||
api_key=self.api_key,
|
||||
n=1,
|
||||
extra_headers=build_llm_extra_headers(self.additional_headers),
|
||||
)
|
||||
return ImageGenerationResponse(
|
||||
revised_prompt=response.data[0]["revised_prompt"],
|
||||
|
Reference in New Issue
Block a user