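"""Image generation tool for Onyx.

Wraps litellm's ``image_generation`` API (DALL-E 3 by default) as a Tool that
can be invoked by tool-calling LLMs, or triggered via a yes/no classification
prompt for LLMs without native tool calling. Generated images are returned as
URLs or base64 data, depending on the configured output format.
"""
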
import json
from collections.abc import Generator
from enum import Enum
from typing import Any
from typing import cast

import requests
from litellm import image_generation  # type: ignore
from pydantic import BaseModel

from onyx.chat.chat_utils import combine_message_chain
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import message_to_string
from onyx.prompts.constants import GENERAL_SEP_PAT
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.prompt import (
    build_image_generation_user_prompt,
)
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


logger = setup_logger()


IMAGE_GENERATION_RESPONSE_ID = "image_generation_response"

YES_IMAGE_GENERATION = "Yes Image Generation"
SKIP_IMAGE_GENERATION = "Skip Image Generation"

IMAGE_GENERATION_TEMPLATE = f"""
Given the conversation history and a follow up query, determine if the system should call \
an external image generation tool to better answer the latest user input.
Your default response is {SKIP_IMAGE_GENERATION}.

Respond "{YES_IMAGE_GENERATION}" if:
- The user is asking for an image to be generated.

Conversation History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}

If you are at all unsure, respond with {SKIP_IMAGE_GENERATION}.
Respond with EXACTLY and ONLY "{YES_IMAGE_GENERATION}" or "{SKIP_IMAGE_GENERATION}"

Follow Up Input:
{{final_query}}
""".strip()
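# Note: the {chat_history} and {final_query} placeholders above are filled in by
# get_args_for_non_tool_calling_llm() before the prompt is sent to the LLM.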


class ImageFormat(str, Enum):
    URL = "url"
    BASE64 = "b64_json"


_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT)


class ImageGenerationResponse(BaseModel):
    revised_prompt: str
    url: str | None
    image_data: str | None

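# Illustrative example of a URL-format response (values are made up, not taken
# from the source):
#   ImageGenerationResponse(
#       revised_prompt="A photorealistic red fox standing in fresh snow",
#       url="https://example.com/generated/fox.png",
#       image_data=None,
#   )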

class ImageShape(str, Enum):
    SQUARE = "square"
    PORTRAIT = "portrait"
    LANDSCAPE = "landscape"


class ImageGenerationTool(Tool):
    """Tool that generates images from a natural-language prompt via litellm's
    image_generation API (DALL-E 3 by default) and returns them as URLs or
    base64-encoded data, depending on the configured output format."""

    _NAME = "run_image_generation"
    _DESCRIPTION = "Generate an image from a prompt."
    _DISPLAY_NAME = "Image Generation Tool"

    def __init__(
        self,
        api_key: str,
        api_base: str | None,
        api_version: str | None,
        model: str = "dall-e-3",
        num_imgs: int = 2,
        additional_headers: dict[str, str] | None = None,
        output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
    ) -> None:
        self.api_key = api_key
        self.api_base = api_base
        self.api_version = api_version

        self.model = model
        self.num_imgs = num_imgs

        self.additional_headers = additional_headers
        self.output_format = output_format

    @property
    def name(self) -> str:
        return self._NAME

    @property
    def description(self) -> str:
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        return self._DISPLAY_NAME

    def tool_definition(self) -> dict:
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "prompt": {
                            "type": "string",
                            "description": "Prompt used to generate the image",
                        },
                        "shape": {
                            "type": "string",
                            "description": (
                                "Optional - only specify if you want a specific shape."
                                " Image shape: 'square', 'portrait', or 'landscape'."
                            ),
                            "enum": [shape.value for shape in ImageShape],
                        },
                    },
                    "required": ["prompt"],
                },
            },
        }
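    # Illustrative example (not from the source): given the definition above, a
    # tool-calling LLM might emit a call such as
    #   {"name": "run_image_generation",
    #    "arguments": {"prompt": "a red fox in fresh snow", "shape": "landscape"}}
    # "prompt" is required; "shape" is optional and run() defaults it to square.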
    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        """For LLMs without native tool calling, ask the LLM whether the image
        generation tool should run; return the tool args if so, else None."""
        args = {"prompt": query}
        if force_run:
            return args

        history_str = combine_message_chain(
            messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
        )
        prompt = IMAGE_GENERATION_TEMPLATE.format(
            chat_history=history_str,
            final_query=query,
        )
        use_image_generation_tool_output = message_to_string(llm.invoke(prompt))

        logger.debug(
            f"Evaluated if should use ImageGenerationTool: {use_image_generation_tool_output}"
        )
        if (
            YES_IMAGE_GENERATION.split()[0]
        ).lower() in use_image_generation_tool_output.lower():
            return args

        return None
    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        generation_response = args[0]
        image_generations = cast(
            list[ImageGenerationResponse], generation_response.response
        )

        return build_content_with_imgs(
            message=json.dumps(
                [
                    {
                        "revised_prompt": image_generation.revised_prompt,
                        "url": image_generation.url,
                    }
                    for image_generation in image_generations
                ]
            ),
        )
    def _generate_image(
        self, prompt: str, shape: ImageShape, format: ImageFormat
    ) -> ImageGenerationResponse:
        """Generate a single image and normalize the provider response into an
        ImageGenerationResponse (URL or base64, depending on the format)."""
        # Map the requested shape onto the image sizes supported by DALL-E 3.
        if shape == ImageShape.LANDSCAPE:
            size = "1792x1024"
        elif shape == ImageShape.PORTRAIT:
            size = "1024x1792"
        else:
            size = "1024x1024"

        try:
            response = image_generation(
                prompt=prompt,
                model=self.model,
                api_key=self.api_key,
                api_base=self.api_base or None,
                api_version=self.api_version or None,
                size=size,
                n=1,
                response_format=format,
                extra_headers=build_llm_extra_headers(self.additional_headers),
            )

            if format == ImageFormat.URL:
                url = response.data[0]["url"]
                image_data = None
            else:
                url = None
                image_data = response.data[0]["b64_json"]

            return ImageGenerationResponse(
                revised_prompt=response.data[0]["revised_prompt"],
                url=url,
                image_data=image_data,
            )

        except requests.RequestException as e:
            logger.error(f"Error fetching or converting image: {e}")
            raise ValueError("Failed to fetch or convert the generated image")
        except Exception as e:
            logger.debug(f"Error occurred during image generation: {e}")

            error_message = str(e)
            if "OpenAIException" in str(type(e)):
                if (
                    "Your request was rejected as a result of our safety system"
                    in error_message
                ):
                    raise ValueError(
                        "The image generation request was rejected due to OpenAI's content policy. Please try a different prompt."
                    )
                elif "Invalid image URL" in error_message:
                    raise ValueError("Invalid image URL provided for image generation.")
                elif "invalid_request_error" in error_message:
                    raise ValueError(
                        "Invalid request for image generation. Please check your input."
                    )

            raise ValueError(
                "An error occurred during image generation. Please try again later."
            )
    def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
        """Generate `num_imgs` images in parallel and yield them as a single
        ToolResponse."""
        prompt = cast(str, kwargs["prompt"])
        shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
        format = self.output_format

        results = cast(
            list[ImageGenerationResponse],
            run_functions_tuples_in_parallel(
                [
                    (
                        self._generate_image,
                        (
                            prompt,
                            shape,
                            format,
                        ),
                    )
                    for _ in range(self.num_imgs)
                ]
            ),
        )
        yield ToolResponse(
            id=IMAGE_GENERATION_RESPONSE_ID,
            response=results,
        )
    def final_result(self, *args: ToolResponse) -> JSON_ro:
        image_generation_responses = cast(
            list[ImageGenerationResponse], args[0].response
        )
        return [
            image_generation_response.model_dump()
            for image_generation_response in image_generation_responses
        ]
    def build_next_prompt(
        self,
        prompt_builder: AnswerPromptBuilder,
        tool_call_summary: ToolCallSummary,
        tool_responses: list[ToolResponse],
        using_tool_calling_llm: bool,
    ) -> AnswerPromptBuilder:
        """Rebuild the user prompt so the LLM can respond about the generated
        images, passed along as URLs and/or base64 data."""
        img_generation_response = cast(
            list[ImageGenerationResponse] | None,
            next(
                (
                    response.response
                    for response in tool_responses
                    if response.id == IMAGE_GENERATION_RESPONSE_ID
                ),
                None,
            ),
        )
        if img_generation_response is None:
            raise ValueError("No image generation response found")

        img_urls = [img.url for img in img_generation_response if img.url is not None]
        b64_imgs = [
            img.image_data
            for img in img_generation_response
            if img.image_data is not None
        ]
        prompt_builder.update_user_prompt(
            build_image_generation_user_prompt(
                query=prompt_builder.get_user_message_content(),
                img_urls=img_urls,
                b64_imgs=b64_imgs,
            )
        )

        return prompt_builder
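

# Minimal usage sketch (illustrative only; assumes a valid OpenAI API key and a
# caller that consumes the yielded ToolResponse):
#
#   tool = ImageGenerationTool(api_key="sk-...", api_base=None, api_version=None)
#   for tool_response in tool.run(prompt="a watercolor lighthouse at dusk"):
#       images = cast(list[ImageGenerationResponse], tool_response.response)
#       urls = [img.url for img in images if img.url is not None]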