Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-27 04:18:35 +02:00)
Allow all LLMs for image generation assistants (#3730)
* Allow all LLMs for image generation assistants
* ensure pushed
* update color + assistant -> model
* update prompt
* fix silly conditional
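At a high level, the change keys prompt construction off of whether the assistant's configured LLM can actually see images. A minimal sketch of that decision, using the model_supports_image_input helper added in this commit (the two string outcomes are illustrative, not the real onyx control flow):

from onyx.llm.utils import model_supports_image_input


def pick_summary_style(model_name: str, model_provider: str) -> str:
    # Vision-capable models get the generated images attached to the summary prompt;
    # text-only models are instead asked to describe the generator's revised prompts.
    if model_supports_image_input(model_name, model_provider):
        return "attach generated images and ask the LLM to summarize them"
    return "describe the images from the revised prompts, without attachments"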
@@ -15,6 +15,7 @@ from onyx.llm.models import PreviousMessage
 from onyx.llm.utils import build_content_with_imgs
 from onyx.llm.utils import check_message_tokens
 from onyx.llm.utils import message_to_prompt_and_imgs
+from onyx.llm.utils import model_supports_image_input
 from onyx.natural_language_processing.utils import get_tokenizer
 from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
@@ -90,6 +91,7 @@ class AnswerPromptBuilder:
             provider_type=llm_config.model_provider,
             model_name=llm_config.model_name,
         )
+        self.llm_config = llm_config
         self.llm_tokenizer_encode_func = cast(
             Callable[[str], list[int]], llm_tokenizer.encode
         )
@@ -98,12 +100,21 @@ class AnswerPromptBuilder:
         (
             self.message_history,
             self.history_token_cnts,
-        ) = translate_history_to_basemessages(message_history)
+        ) = translate_history_to_basemessages(
+            message_history,
+            exclude_images=not model_supports_image_input(
+                self.llm_config.model_name,
+                self.llm_config.model_provider,
+            ),
+        )
 
         self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None
         self.user_message_and_token_cnt = (
             user_message,
-            check_message_tokens(user_message, self.llm_tokenizer_encode_func),
+            check_message_tokens(
+                user_message,
+                self.llm_tokenizer_encode_func,
+            ),
         )
 
         self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
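The flag threaded into translate_history_to_basemessages above simply inverts the vision check. A small illustration (whether a given model reports supports_vision depends on litellm's model map, so the commented results are assumptions):

from onyx.llm.utils import model_supports_image_input

for provider, model in [("openai", "gpt-4o"), ("openai", "gpt-3.5-turbo")]:
    exclude_images = not model_supports_image_input(model, provider)
    print(f"{model}: exclude_images={exclude_images}")
# gpt-4o: exclude_images=False   (image parts in history are kept)
# gpt-3.5-turbo: exclude_images=True   (image parts are stripped from history)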
@@ -11,6 +11,7 @@ from onyx.llm.utils import build_content_with_imgs
 
 def translate_onyx_msg_to_langchain(
     msg: ChatMessage | PreviousMessage,
+    exclude_images: bool = False,
 ) -> BaseMessage:
     files: list[InMemoryChatFile] = []
 
@@ -18,7 +19,9 @@ def translate_onyx_msg_to_langchain(
     # attached. Just ignore them for now.
     if not isinstance(msg, ChatMessage):
         files = msg.files
-    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)
+    content = build_content_with_imgs(
+        msg.message, files, message_type=msg.message_type, exclude_images=exclude_images
+    )
 
     if msg.message_type == MessageType.SYSTEM:
         raise ValueError("System messages are not currently part of history")
@@ -32,9 +35,12 @@ def translate_onyx_msg_to_langchain(
 
 def translate_history_to_basemessages(
     history: list[ChatMessage] | list["PreviousMessage"],
+    exclude_images: bool = False,
 ) -> tuple[list[BaseMessage], list[int]]:
     history_basemessages = [
-        translate_onyx_msg_to_langchain(msg) for msg in history if msg.token_count != 0
+        translate_onyx_msg_to_langchain(msg, exclude_images)
+        for msg in history
+        if msg.token_count != 0
    ]
     history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
     return history_basemessages, history_token_counts
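A stand-alone sketch of the comprehension above, with plain dicts standing in for ChatMessage/PreviousMessage: zero-token messages are dropped, and every surviving message is translated with the same exclude_images flag.

def translate(msg: dict, exclude_images: bool) -> str:
    # stand-in for translate_onyx_msg_to_langchain
    keep_image = msg["has_image"] and not exclude_images
    return msg["text"] + (" [image part attached]" if keep_image else "")


history = [
    {"text": "hi", "has_image": False, "token_count": 1},
    {"text": "", "has_image": False, "token_count": 0},  # skipped, token_count == 0
    {"text": "what is in this image?", "has_image": True, "token_count": 6},
]

exclude_images = True  # e.g. the configured LLM has no vision support
basemessages = [translate(m, exclude_images) for m in history if m["token_count"] != 0]
token_counts = [m["token_count"] for m in history if m["token_count"] != 0]
print(basemessages)  # ['hi', 'what is in this image?']
print(token_counts)  # [1, 6]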
@@ -142,6 +142,7 @@ def build_content_with_imgs(
     img_urls: list[str] | None = None,
     b64_imgs: list[str] | None = None,
     message_type: MessageType = MessageType.USER,
+    exclude_images: bool = False,
 ) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
     files = files or []
 
@@ -157,7 +158,7 @@ def build_content_with_imgs(
 
     message_main_content = _build_content(message, files)
 
-    if not img_files and not img_urls:
+    if exclude_images or (not img_files and not img_urls):
         return message_main_content
 
     return cast(
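Usage sketch for the new exclude_images escape hatch (the URL is a placeholder; the exact shape of the multimodal parts follows Langchain's content convention):

from onyx.llm.utils import build_content_with_imgs

text_only = build_content_with_imgs(
    "What changed between these screenshots?",
    img_urls=["https://example.com/before.png"],
    exclude_images=True,
)
# -> plain string: the image URL is never turned into an image part

multimodal = build_content_with_imgs(
    "What changed between these screenshots?",
    img_urls=["https://example.com/before.png"],
)
# -> list of content parts that includes an image_url entry for the screenshot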
@@ -382,9 +383,19 @@ def _strip_colon_from_model_name(model_name: str) -> str:
     return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
 
 
-def _find_model_obj(
-    model_map: dict, provider: str, model_names: list[str | None]
-) -> dict | None:
+def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None:
+    stripped_model_name = _strip_extra_provider_from_model_name(model_name)
+
+    model_names = [
+        model_name,
+        # Remove leading extra provider. Usually for cases where user has a
+        # customer model proxy which appends another prefix
+        _strip_extra_provider_from_model_name(model_name),
+        # remove :XXXX from the end, if present. Needed for ollama.
+        _strip_colon_from_model_name(model_name),
+        _strip_colon_from_model_name(stripped_model_name),
+    ]
+
     # Filter out None values and deduplicate model names
     filtered_model_names = [name for name in model_names if name]
 
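The lookup now expands a single model name into the candidate keys internally instead of making every caller build that list. A self-contained sketch of the expansion (the prefix-stripping rule is an assumption about _strip_extra_provider_from_model_name; the colon rule mirrors the line shown above):

def candidate_model_names(model_name: str) -> list[str]:
    def strip_colon(name: str) -> str:
        # drop a trailing ":tag", as ollama-style names carry (e.g. "llama3:8b")
        return ":".join(name.split(":")[:-1]) if ":" in name else name

    def strip_extra_provider(name: str) -> str:
        # assumed behavior: drop a leading "proxy/" style prefix
        return name.split("/", 1)[1] if "/" in name else name

    stripped = strip_extra_provider(model_name)
    names = [model_name, stripped, strip_colon(model_name), strip_colon(stripped)]
    # filter out falsy values and deduplicate while preserving order
    return list(dict.fromkeys(n for n in names if n))


print(candidate_model_names("openrouter/llama3:8b"))
# ['openrouter/llama3:8b', 'llama3:8b', 'openrouter/llama3', 'llama3']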
@@ -417,21 +428,10 @@ def get_llm_max_tokens(
         return GEN_AI_MAX_TOKENS
 
     try:
-        extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
-            model_name
-        )
         model_obj = _find_model_obj(
             model_map,
             model_provider,
-            [
-                model_name,
-                # Remove leading extra provider. Usually for cases where user has a
-                # customer model proxy which appends another prefix
-                extra_provider_stripped_model_name,
-                # remove :XXXX from the end, if present. Needed for ollama.
-                _strip_colon_from_model_name(model_name),
-                _strip_colon_from_model_name(extra_provider_stripped_model_name),
-            ],
+            model_name,
         )
         if not model_obj:
             raise RuntimeError(
@@ -523,3 +523,23 @@ def get_max_input_tokens(
         raise RuntimeError("No tokens for input for the LLM given settings")
 
     return input_toks
+
+
+def model_supports_image_input(model_name: str, model_provider: str) -> bool:
+    model_map = get_model_map()
+    try:
+        model_obj = _find_model_obj(
+            model_map,
+            model_provider,
+            model_name,
+        )
+        if not model_obj:
+            raise RuntimeError(
+                f"No litellm entry found for {model_provider}/{model_name}"
+            )
+        return model_obj.get("supports_vision", False)
+    except Exception:
+        logger.exception(
+            f"Failed to get model object for {model_provider}/{model_name}"
+        )
+        return False
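Minimal usage sketch of the helper added above. Note that it fails closed: if the model cannot be found in litellm's model map, the exception is logged and False is returned, so callers fall back to text-only prompting. The second model/provider pair is deliberately made up.

from onyx.llm.utils import model_supports_image_input

print(model_supports_image_input("gpt-4o", "openai"))
# True, assuming litellm's map marks gpt-4o with supports_vision

print(model_supports_image_input("totally-unknown-model", "custom-proxy"))
# False, after logging "Failed to get model object for custom-proxy/totally-unknown-model"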
@@ -16,6 +16,7 @@ from onyx.llm.interfaces import LLM
 from onyx.llm.models import PreviousMessage
 from onyx.llm.utils import build_content_with_imgs
 from onyx.llm.utils import message_to_string
+from onyx.llm.utils import model_supports_image_input
 from onyx.prompts.constants import GENERAL_SEP_PAT
 from onyx.tools.message import ToolCallSummary
 from onyx.tools.models import ToolResponse
@@ -316,12 +317,22 @@ class ImageGenerationTool(Tool):
             for img in img_generation_response
             if img.image_data is not None
         ]
-        prompt_builder.update_user_prompt(
-            build_image_generation_user_prompt(
-                query=prompt_builder.get_user_message_content(),
-                img_urls=img_urls,
-                b64_imgs=b64_imgs,
-            )
-        )
+
+        user_prompt = build_image_generation_user_prompt(
+            query=prompt_builder.get_user_message_content(),
+            supports_image_input=model_supports_image_input(
+                prompt_builder.llm_config.model_name,
+                prompt_builder.llm_config.model_provider,
+            ),
+            prompts=[
+                prompt
+                for response in img_generation_response
+                for prompt in response.revised_prompt
+            ],
+            img_urls=img_urls,
+            b64_imgs=b64_imgs,
+        )
+
+        prompt_builder.update_user_prompt(user_prompt)
 
         return prompt_builder
@@ -9,16 +9,34 @@ You have just created the attached images in response to the following query: "{query}".
 
 Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
 """
 
+IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES = """
+You have generated images based on the following query: "{query}".
+The prompts used to create these images were: {prompts}
+
+Describe the two images you generated, summarizing the key elements and content in a sentence or two.
+Be specific about what was generated and respond as if you have seen them,
+without including any disclaimers or speculations.
+"""
+
 
 def build_image_generation_user_prompt(
     query: str,
+    supports_image_input: bool,
     img_urls: list[str] | None = None,
     b64_imgs: list[str] | None = None,
+    prompts: list[str] | None = None,
 ) -> HumanMessage:
-    return HumanMessage(
-        content=build_content_with_imgs(
-            message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
-            b64_imgs=b64_imgs,
-            img_urls=img_urls,
-        )
-    )
+    if supports_image_input:
+        return HumanMessage(
+            content=build_content_with_imgs(
+                message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
+                b64_imgs=b64_imgs,
+                img_urls=img_urls,
+            )
+        )
+    else:
+        return HumanMessage(
+            content=IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES.format(
+                query=query, prompts=prompts
+            ).strip()
+        )
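Usage sketch for the updated prompt builder, assuming build_image_generation_user_prompt and the prompt constants from the patched module are in scope (the diff does not show the module path); the query, prompts, and base64 placeholder are made-up values:

no_vision_msg = build_image_generation_user_prompt(
    query="a watercolor fox in the snow",
    supports_image_input=False,
    prompts=["A soft watercolor painting of a red fox curled up in fresh snow"],
)
# -> HumanMessage whose content is IMG_GENERATION_SUMMARY_PROMPT_NO_IMAGES with
#    {query} and {prompts} filled in; no image parts are attached

vision_msg = build_image_generation_user_prompt(
    query="a watercolor fox in the snow",
    supports_image_input=True,
    b64_imgs=["<base64-encoded image>"],
)
# -> HumanMessage carrying IMG_GENERATION_SUMMARY_PROMPT plus the image content parts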