mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-10 13:15:18 +02:00
Pass headers into image generation (#1739)
This commit is contained in:
@@ -444,7 +444,10 @@ def stream_chat_message_objects(
|
|||||||
)
|
)
|
||||||
dalle_key = openai_provider.api_key
|
dalle_key = openai_provider.api_key
|
||||||
tool_dict[db_tool_model.id] = [
|
tool_dict[db_tool_model.id] = [
|
||||||
ImageGenerationTool(api_key=dalle_key)
|
ImageGenerationTool(
|
||||||
|
api_key=dalle_key,
|
||||||
|
additional_headers=litellm_additional_headers,
|
||||||
|
)
|
||||||
]
|
]
|
||||||
|
|
||||||
continue
|
continue
|
||||||
|
@@ -1,13 +1,13 @@
|
|||||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||||
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
|
|
||||||
from danswer.db.engine import get_session_context_manager
|
from danswer.db.engine import get_session_context_manager
|
||||||
from danswer.db.llm import fetch_default_provider
|
from danswer.db.llm import fetch_default_provider
|
||||||
from danswer.db.llm import fetch_provider
|
from danswer.db.llm import fetch_provider
|
||||||
from danswer.db.models import Persona
|
from danswer.db.models import Persona
|
||||||
from danswer.llm.chat_llm import DefaultMultiLLM
|
from danswer.llm.chat_llm import DefaultMultiLLM
|
||||||
from danswer.llm.exceptions import GenAIDisabledException
|
from danswer.llm.exceptions import GenAIDisabledException
|
||||||
|
from danswer.llm.headers import build_llm_extra_headers
|
||||||
from danswer.llm.interfaces import LLM
|
from danswer.llm.interfaces import LLM
|
||||||
from danswer.llm.override_models import LLMOverride
|
from danswer.llm.override_models import LLMOverride
|
||||||
|
|
||||||
@@ -84,12 +84,6 @@ def get_llm(
|
|||||||
timeout: int = QA_TIMEOUT,
|
timeout: int = QA_TIMEOUT,
|
||||||
additional_headers: dict[str, str] | None = None,
|
additional_headers: dict[str, str] | None = None,
|
||||||
) -> LLM:
|
) -> LLM:
|
||||||
extra_headers = {}
|
|
||||||
if additional_headers:
|
|
||||||
extra_headers.update(additional_headers)
|
|
||||||
if LITELLM_EXTRA_HEADERS:
|
|
||||||
extra_headers.update(LITELLM_EXTRA_HEADERS)
|
|
||||||
|
|
||||||
return DefaultMultiLLM(
|
return DefaultMultiLLM(
|
||||||
model_provider=provider,
|
model_provider=provider,
|
||||||
model_name=model,
|
model_name=model,
|
||||||
@@ -99,5 +93,5 @@ def get_llm(
|
|||||||
timeout=timeout,
|
timeout=timeout,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
custom_config=custom_config,
|
custom_config=custom_config,
|
||||||
extra_headers=extra_headers,
|
extra_headers=build_llm_extra_headers(additional_headers),
|
||||||
)
|
)
|
||||||
|
@@ -1,5 +1,6 @@
|
|||||||
from fastapi.datastructures import Headers
|
from fastapi.datastructures import Headers
|
||||||
|
|
||||||
|
from danswer.configs.model_configs import LITELLM_EXTRA_HEADERS
|
||||||
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
from danswer.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||||
|
|
||||||
|
|
||||||
@@ -20,3 +21,14 @@ def get_litellm_additional_request_headers(
|
|||||||
pass_through_headers[lowercase_key] = headers[lowercase_key]
|
pass_through_headers[lowercase_key] = headers[lowercase_key]
|
||||||
|
|
||||||
return pass_through_headers
|
return pass_through_headers
|
||||||
|
|
||||||
|
|
||||||
|
def build_llm_extra_headers(
|
||||||
|
additional_headers: dict[str, str] | None = None
|
||||||
|
) -> dict[str, str]:
|
||||||
|
extra_headers: dict[str, str] = {}
|
||||||
|
if additional_headers:
|
||||||
|
extra_headers.update(additional_headers)
|
||||||
|
if LITELLM_EXTRA_HEADERS:
|
||||||
|
extra_headers.update(LITELLM_EXTRA_HEADERS)
|
||||||
|
return extra_headers
|
||||||
|
@@ -10,6 +10,7 @@ from danswer.chat.chat_utils import combine_message_chain
|
|||||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||||
from danswer.dynamic_configs.interface import JSON_ro
|
from danswer.dynamic_configs.interface import JSON_ro
|
||||||
from danswer.llm.answering.models import PreviousMessage
|
from danswer.llm.answering.models import PreviousMessage
|
||||||
|
from danswer.llm.headers import build_llm_extra_headers
|
||||||
from danswer.llm.interfaces import LLM
|
from danswer.llm.interfaces import LLM
|
||||||
from danswer.llm.utils import build_content_with_imgs
|
from danswer.llm.utils import build_content_with_imgs
|
||||||
from danswer.llm.utils import message_to_string
|
from danswer.llm.utils import message_to_string
|
||||||
@@ -57,12 +58,18 @@ class ImageGenerationTool(Tool):
|
|||||||
NAME = "run_image_generation"
|
NAME = "run_image_generation"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2
|
self,
|
||||||
|
api_key: str,
|
||||||
|
model: str = "dall-e-3",
|
||||||
|
num_imgs: int = 2,
|
||||||
|
additional_headers: dict[str, str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
self.num_imgs = num_imgs
|
self.num_imgs = num_imgs
|
||||||
|
|
||||||
|
self.additional_headers = additional_headers
|
||||||
|
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self.NAME
|
return self.NAME
|
||||||
|
|
||||||
@@ -142,6 +149,7 @@ class ImageGenerationTool(Tool):
|
|||||||
model=self.model,
|
model=self.model,
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
n=1,
|
n=1,
|
||||||
|
extra_headers=build_llm_extra_headers(self.additional_headers),
|
||||||
)
|
)
|
||||||
return ImageGenerationResponse(
|
return ImageGenerationResponse(
|
||||||
revised_prompt=response.data[0]["revised_prompt"],
|
revised_prompt=response.data[0]["revised_prompt"],
|
||||||
|
Reference in New Issue
Block a user