Pass through API base to ImageGenerationTool

This commit is contained in:
Weves 2024-07-02 23:21:39 -07:00 committed by Chris Weaver
parent 982b1b0c49
commit 7f1bb67e52
4 changed files with 27 additions and 4 deletions

View File

@ -18,6 +18,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.db.chat import attach_files_to_chat_message
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
@ -49,6 +50,7 @@ from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids
@ -435,13 +437,13 @@ def stream_chat_message_objects(
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
dalle_key = None
img_generation_llm_config: LLMConfig | None = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
dalle_key = llm.config.api_key
img_generation_llm_config = llm.config
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
@ -458,10 +460,19 @@ def stream_chat_message_objects(
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
dalle_key = openai_provider.api_key
img_generation_llm_config = LLMConfig(
model_provider=openai_provider.provider,
model_name=openai_provider.default_model_name,
temperature=GEN_AI_TEMPERATURE,
api_key=openai_provider.api_key,
api_base=openai_provider.api_base,
api_version=openai_provider.api_version,
)
tool_dict[db_tool_model.id] = [
ImageGenerationTool(
api_key=dalle_key,
api_key=cast(str, img_generation_llm_config.api_key),
api_base=img_generation_llm_config.api_base,
api_version=img_generation_llm_config.api_version,
additional_headers=litellm_additional_headers,
)
]

View File

@ -306,6 +306,8 @@ class DefaultMultiLLM(LLM):
model_name=self._model_version,
temperature=self._temperature,
api_key=self._api_key,
api_base=self._api_base,
api_version=self._api_version,
)
def invoke(

View File

@ -19,6 +19,8 @@ class LLMConfig(BaseModel):
model_name: str
temperature: float
api_key: str | None
api_base: str | None
api_version: str | None
class LLM(abc.ABC):

View File

@ -60,11 +60,16 @@ class ImageGenerationTool(Tool):
def __init__(
self,
api_key: str,
api_base: str | None,
api_version: str | None,
model: str = "dall-e-3",
num_imgs: int = 2,
additional_headers: dict[str, str] | None = None,
) -> None:
self.api_key = api_key
self.api_base = api_base
self.api_version = api_version
self.model = model
self.num_imgs = num_imgs
@ -148,6 +153,9 @@ class ImageGenerationTool(Tool):
prompt=prompt,
model=self.model,
api_key=self.api_key,
# need to pass in None rather than empty str
api_base=self.api_base or None,
api_version=self.api_version or None,
n=1,
extra_headers=build_llm_extra_headers(self.additional_headers),
)