mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-08 03:48:14 +02:00
Pass through API base to ImageGenerationTool
This commit is contained in:
parent
982b1b0c49
commit
7f1bb67e52
@ -18,6 +18,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.chat import attach_files_to_chat_message
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
@ -49,6 +50,7 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.retrieval.search_runner import inference_documents_from_ids
|
||||
@ -435,13 +437,13 @@ def stream_chat_message_objects(
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
dalle_key = None
|
||||
img_generation_llm_config: LLMConfig | None = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
dalle_key = llm.config.api_key
|
||||
img_generation_llm_config = llm.config
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
@ -458,10 +460,19 @@ def stream_chat_message_objects(
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
dalle_key = openai_provider.api_key
|
||||
img_generation_llm_config = LLMConfig(
|
||||
model_provider=openai_provider.provider,
|
||||
model_name=openai_provider.default_model_name,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=openai_provider.api_key,
|
||||
api_base=openai_provider.api_base,
|
||||
api_version=openai_provider.api_version,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(
|
||||
api_key=dalle_key,
|
||||
api_key=cast(str, img_generation_llm_config.api_key),
|
||||
api_base=img_generation_llm_config.api_base,
|
||||
api_version=img_generation_llm_config.api_version,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
]
|
||||
|
@ -306,6 +306,8 @@ class DefaultMultiLLM(LLM):
|
||||
model_name=self._model_version,
|
||||
temperature=self._temperature,
|
||||
api_key=self._api_key,
|
||||
api_base=self._api_base,
|
||||
api_version=self._api_version,
|
||||
)
|
||||
|
||||
def invoke(
|
||||
|
@ -19,6 +19,8 @@ class LLMConfig(BaseModel):
|
||||
model_name: str
|
||||
temperature: float
|
||||
api_key: str | None
|
||||
api_base: str | None
|
||||
api_version: str | None
|
||||
|
||||
|
||||
class LLM(abc.ABC):
|
||||
|
@ -60,11 +60,16 @@ class ImageGenerationTool(Tool):
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
api_base: str | None,
|
||||
api_version: str | None,
|
||||
model: str = "dall-e-3",
|
||||
num_imgs: int = 2,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.api_base = api_base
|
||||
self.api_version = api_version
|
||||
|
||||
self.model = model
|
||||
self.num_imgs = num_imgs
|
||||
|
||||
@ -148,6 +153,9 @@ class ImageGenerationTool(Tool):
|
||||
prompt=prompt,
|
||||
model=self.model,
|
||||
api_key=self.api_key,
|
||||
# need to pass in None rather than empty str
|
||||
api_base=self.api_base or None,
|
||||
api_version=self.api_version or None,
|
||||
n=1,
|
||||
extra_headers=build_llm_extra_headers(self.additional_headers),
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user