import json
from collections.abc import Generator
from enum import Enum
from typing import Any
from typing import cast
import requests
from litellm import image_generation # type: ignore
from pydantic import BaseModel
from onyx.chat.chat_utils import combine_message_chain
from onyx.chat.prompt_builder.build import AnswerPromptBuilder
from onyx.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from onyx.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import message_to_string
from onyx.prompts.constants import GENERAL_SEP_PAT
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.images.prompt import (
build_image_generation_user_prompt,
)
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
IMAGE_GENERATION_RESPONSE_ID = "image_generation_response"
YES_IMAGE_GENERATION = "Yes Image Generation"
SKIP_IMAGE_GENERATION = "Skip Image Generation"
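# Prompt used with non-tool-calling LLMs to decide whether the image generation tool should run.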
IMAGE_GENERATION_TEMPLATE = f"""
Given the conversation history and a follow-up query, determine if the system should call \
an external image generation tool to better answer the latest user input.
Your default response is {SKIP_IMAGE_GENERATION}.
Respond "{YES_IMAGE_GENERATION}" if:
- The user is asking for an image to be generated.
Conversation History:
{GENERAL_SEP_PAT}
{{chat_history}}
{GENERAL_SEP_PAT}
If you are at all unsure, respond with {SKIP_IMAGE_GENERATION}.
Respond with EXACTLY and ONLY "{YES_IMAGE_GENERATION}" or "{SKIP_IMAGE_GENERATION}"
Follow Up Input:
{{final_query}}
""".strip()
class ImageFormat(str, Enum):
URL = "url"
BASE64 = "b64_json"
_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT)
class ImageGenerationResponse(BaseModel):
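    """A single generated image: the provider-revised prompt plus either a URL or base64 image data."""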
revised_prompt: str
url: str | None
image_data: str | None
class ImageShape(str, Enum):
SQUARE = "square"
PORTRAIT = "portrait"
LANDSCAPE = "landscape"
class ImageGenerationTool(Tool):
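    """Tool that generates images from a natural-language prompt via litellm's
    image_generation API (DALL-E 3 by default) and yields the results as
    ToolResponse objects."""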
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
_DISPLAY_NAME = "Image Generation Tool"
def __init__(
self,
api_key: str,
api_base: str | None,
api_version: str | None,
model: str = "dall-e-3",
num_imgs: int = 2,
additional_headers: dict[str, str] | None = None,
output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
) -> None:
self.api_key = api_key
self.api_base = api_base
self.api_version = api_version
self.model = model
self.num_imgs = num_imgs
self.additional_headers = additional_headers
self.output_format = output_format
@property
def name(self) -> str:
return self._NAME
@property
def description(self) -> str:
return self._DESCRIPTION
@property
def display_name(self) -> str:
return self._DISPLAY_NAME
def tool_definition(self) -> dict:
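        # OpenAI-compatible function schema advertised to tool-calling LLMs.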
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
"prompt": {
"type": "string",
"description": "Prompt used to generate the image",
},
"shape": {
"type": "string",
"description": (
"Optional - only specify if you want a specific shape."
" Image shape: 'square', 'portrait', or 'landscape'."
),
"enum": [shape.value for shape in ImageShape],
},
},
"required": ["prompt"],
},
},
}
def get_args_for_non_tool_calling_llm(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
force_run: bool = False,
) -> dict[str, Any] | None:
args = {"prompt": query}
if force_run:
return args
history_str = combine_message_chain(
messages=history, token_limit=GEN_AI_HISTORY_CUTOFF
)
prompt = IMAGE_GENERATION_TEMPLATE.format(
chat_history=history_str,
final_query=query,
)
use_image_generation_tool_output = message_to_string(llm.invoke(prompt))
logger.debug(
f"Evaluated if should use ImageGenerationTool: {use_image_generation_tool_output}"
)
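        # Loose check: match only the leading "Yes" token rather than the exact
        # sentinel string, since model output may not reproduce it verbatim.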
if (
YES_IMAGE_GENERATION.split()[0]
).lower() in use_image_generation_tool_output.lower():
return args
return None
def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
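        # Summarize each generation (revised prompt + URL) as JSON for the tool result message.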
generation_response = args[0]
image_generations = cast(
list[ImageGenerationResponse], generation_response.response
)
        return build_content_with_imgs(
            message=json.dumps(
                [
                    {
                        "revised_prompt": img_generation.revised_prompt,
                        "url": img_generation.url,
                    }
                    for img_generation in image_generations
                ]
            ),
        )
def _generate_image(
self, prompt: str, shape: ImageShape, format: ImageFormat
) -> ImageGenerationResponse:
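        # Map the requested shape to a concrete size (the three sizes supported
        # by DALL-E 3, the default model).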
if shape == ImageShape.LANDSCAPE:
size = "1792x1024"
elif shape == ImageShape.PORTRAIT:
size = "1024x1792"
else:
size = "1024x1024"
try:
response = image_generation(
prompt=prompt,
model=self.model,
api_key=self.api_key,
api_base=self.api_base or None,
api_version=self.api_version or None,
size=size,
n=1,
response_format=format,
extra_headers=build_llm_extra_headers(self.additional_headers),
)
if format == ImageFormat.URL:
url = response.data[0]["url"]
image_data = None
else:
url = None
image_data = response.data[0]["b64_json"]
return ImageGenerationResponse(
revised_prompt=response.data[0]["revised_prompt"],
url=url,
image_data=image_data,
)
except requests.RequestException as e:
logger.error(f"Error fetching or converting image: {e}")
raise ValueError("Failed to fetch or convert the generated image")
except Exception as e:
logger.debug(f"Error occurred during image generation: {e}")
error_message = str(e)
if "OpenAIException" in str(type(e)):
if (
"Your request was rejected as a result of our safety system"
in error_message
):
raise ValueError(
"The image generation request was rejected due to OpenAI's content policy. Please try a different prompt."
)
elif "Invalid image URL" in error_message:
raise ValueError("Invalid image URL provided for image generation.")
elif "invalid_request_error" in error_message:
raise ValueError(
"Invalid request for image generation. Please check your input."
)
raise ValueError(
"An error occurred during image generation. Please try again later."
)
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format
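        # Generate `num_imgs` images concurrently, all from the same prompt.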
results = cast(
list[ImageGenerationResponse],
run_functions_tuples_in_parallel(
[
(
self._generate_image,
(
prompt,
shape,
format,
),
)
for _ in range(self.num_imgs)
]
),
)
yield ToolResponse(
id=IMAGE_GENERATION_RESPONSE_ID,
response=results,
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
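        # Flatten the pydantic models into JSON-serializable dicts for the final result payload.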
image_generation_responses = cast(
list[ImageGenerationResponse], args[0].response
)
return [
image_generation_response.model_dump()
for image_generation_response in image_generation_responses
]
def build_next_prompt(
self,
prompt_builder: AnswerPromptBuilder,
tool_call_summary: ToolCallSummary,
tool_responses: list[ToolResponse],
using_tool_calling_llm: bool,
) -> AnswerPromptBuilder:
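        # Locate this tool's response among the tool results so the generated images
        # can be attached to the user prompt for the follow-up LLM call.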
img_generation_response = cast(
list[ImageGenerationResponse] | None,
next(
(
response.response
for response in tool_responses
if response.id == IMAGE_GENERATION_RESPONSE_ID
),
None,
),
)
if img_generation_response is None:
raise ValueError("No image generation response found")
img_urls = [img.url for img in img_generation_response if img.url is not None]
b64_imgs = [
img.image_data
for img in img_generation_response
if img.image_data is not None
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
img_urls=img_urls,
b64_imgs=b64_imgs,
)
)
return prompt_builder
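# Illustrative usage sketch (assumes a valid OpenAI-compatible API key; the tool
# defaults to DALL-E 3 and two images per call):
#
#   tool = ImageGenerationTool(api_key="sk-...", api_base=None, api_version=None)
#   for tool_response in tool.run(prompt="a watercolor fox", shape="landscape"):
#       images = tool_response.response  # list[ImageGenerationResponse]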