danswer/backend/onyx/tools/built_in_tools.py
2025-01-22 09:57:21 -08:00

214 lines
6.9 KiB
Python

import os
from typing import Type
from typing_extensions import TypedDict # noreorder
from sqlalchemy import not_
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import Persona
from onyx.db.models import Tool as ToolDBModel
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
)
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
InternetSearchTool,
)
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool import Tool
from onyx.utils.logger import setup_logger
# Module-level logger, created by the project's `setup_logger` helper.
logger = setup_logger()
class InCodeToolInfo(TypedDict):
    """Metadata describing one tool that is defined in code.

    Entries of this shape make up ``BUILT_IN_TOOLS`` and are used to create
    and update the matching rows in the tool table.
    """

    # The Tool subclass that implements the tool.
    cls: Type[Tool]
    # Description stored on the tool's database row.
    description: str
    # Identifier linking the database row back to this in-code tool.
    in_code_tool_id: str
    # Display name stored on the tool's database row.
    display_name: str
# Registry of the tools that are defined in code. Each entry pairs a Tool
# class with the metadata that is written to its database row.
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
    InCodeToolInfo(
        cls=SearchTool,
        description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
        in_code_tool_id=SearchTool.__name__,
        display_name=SearchTool._DISPLAY_NAME,
    ),
    InCodeToolInfo(
        cls=ImageGenerationTool,
        description=(
            "The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
            "The tool will be used when the user asks the assistant to generate an image."
        ),
        in_code_tool_id=ImageGenerationTool.__name__,
        display_name=ImageGenerationTool._DISPLAY_NAME,
    ),
]

# don't show the InternetSearchTool as an option if BING_API_KEY is not available
if os.environ.get("BING_API_KEY"):
    BUILT_IN_TOOLS.append(
        InCodeToolInfo(
            cls=InternetSearchTool,
            description=(
                "The Internet Search Tool allows the assistant "
                "to perform internet searches for up-to-date information."
            ),
            in_code_tool_id=InternetSearchTool.__name__,
            display_name=InternetSearchTool._DISPLAY_NAME,
        )
    )
def load_builtin_tools(db_session: Session) -> None:
    """Bring the tool table in line with ``BUILT_IN_TOOLS``.

    Rows whose ``in_code_tool_id`` matches a registry entry have their name,
    description and display name overwritten; registry entries without a row
    are inserted; rows whose ``in_code_tool_id`` is no longer in the registry
    are deleted. All changes are committed in a single transaction at the end.

    Args:
        db_session: Session used for the query, the changes and the commit.
    """
    # Only rows that were created from in-code tools carry an in_code_tool_id.
    db_tools = db_session.scalars(
        select(ToolDBModel).where(not_(ToolDBModel.in_code_tool_id.is_(None)))
    ).all()
    db_tool_by_id = {db_tool.in_code_tool_id: db_tool for db_tool in db_tools}

    # Add or update existing tools
    for info in BUILT_IN_TOOLS:
        name = info["cls"].__name__
        db_tool = db_tool_by_id.get(info["in_code_tool_id"])

        if not db_tool:
            # No row yet for this in-code tool: create one.
            db_session.add(
                ToolDBModel(
                    name=name,
                    description=info["description"],
                    display_name=info["display_name"],
                    in_code_tool_id=info["in_code_tool_id"],
                )
            )
            logger.notice(f"Added new tool: {name}")
            continue

        # Row exists: overwrite its metadata with the in-code values.
        db_tool.name = name
        db_tool.description = info["description"]
        db_tool.display_name = info["display_name"]
        logger.notice(f"Updated tool: {name}")

    # Remove tools that are no longer in BUILT_IN_TOOLS
    registered_ids = {info["in_code_tool_id"] for info in BUILT_IN_TOOLS}
    for db_tool_id, db_tool in db_tool_by_id.items():
        if db_tool_id in registered_ids:
            continue
        db_session.delete(db_tool)
        logger.notice(f"Removed tool no longer in built-in list: {db_tool.name}")

    db_session.commit()
    logger.notice("All built-in tools are loaded/verified.")
def get_search_tool(db_session: Session) -> ToolDBModel | None:
    """Fetch the database row that corresponds to the built-in SearchTool.

    The ``in_code_tool_id`` is taken from the first ``BUILT_IN_TOOLS`` entry
    whose class name equals ``SearchTool.__name__``.

    Args:
        db_session: Session used to run the lookup query.

    Returns:
        The matching tool row, or ``None`` when no row has that
        ``in_code_tool_id``.

    Raises:
        RuntimeError: If ``BUILT_IN_TOOLS`` has no usable SearchTool entry.
    """
    search_tool_id = None
    for info in BUILT_IN_TOOLS:
        if info["cls"].__name__ == SearchTool.__name__:
            search_tool_id = info["in_code_tool_id"]
            break

    if not search_tool_id:
        raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")

    query = select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
    return db_session.execute(query).scalar_one_or_none()
def auto_add_search_tool_to_personas(db_session: Session) -> None:
    """Attach the SearchTool to every Persona that should have it.

    A Persona qualifies when its ``num_chunks`` is unset or is any value
    other than 0. This migrates Persona objects that were created before
    the concept of Tools was added. Personas that already have the tool are
    left untouched. The changes are committed once at the end.

    Args:
        db_session: Session used for the queries and the commit.

    Raises:
        RuntimeError: If the SearchTool has no row in the database (or is
            missing from ``BUILT_IN_TOOLS``, raised by ``get_search_tool``).
    """
    # Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS
    search_tool = get_search_tool(db_session)
    if not search_tool:
        raise RuntimeError("SearchTool not found in the database.")

    # Fetch all Personas that need the SearchTool added
    wants_search = or_(Persona.num_chunks.is_(None), Persona.num_chunks != 0)
    candidates = db_session.execute(select(Persona).where(wants_search)).scalars().all()

    # Add the SearchTool to each relevant Persona
    for persona in candidates:
        if search_tool in persona.tools:
            continue
        persona.tools.append(search_tool)
        logger.notice(f"Added SearchTool to Persona ID: {persona.id}")

    # Commit changes to the database
    db_session.commit()
    logger.notice("Completed adding SearchTool to relevant Personas.")
# Process-wide map from `in_code_tool_id` to the built-in Tool class.
# `None` means the cache has not been populated yet.
_built_in_tools_cache: dict[str, Type[Tool]] | None = None
def refresh_built_in_tools_cache(db_session: Session) -> None:
    """Rebuild the module-level map from ``in_code_tool_id`` to Tool class.

    Only tools that have a database row with a non-empty ``in_code_tool_id``
    AND a matching entry in ``BUILT_IN_TOOLS`` end up in the cache.

    The new mapping is assembled in a local dict and published with a single
    rebinding of the global at the very end. If the query raises, the
    previously cached mapping is left intact, and readers never observe an
    empty or half-filled cache while a refresh is in progress.

    Args:
        db_session: Session used to query the tool table.
    """
    global _built_in_tools_cache

    # Lookup table built once instead of scanning BUILT_IN_TOOLS per row.
    # `setdefault` keeps the first entry for an ID, matching a first-match scan.
    cls_by_tool_id: dict[str, Type[Tool]] = {}
    for tool_info in BUILT_IN_TOOLS:
        cls_by_tool_id.setdefault(tool_info["in_code_tool_id"], tool_info["cls"])

    db_built_in_tools = (
        db_session.execute(
            select(ToolDBModel).where(not_(ToolDBModel.in_code_tool_id.is_(None)))
        )
        .scalars()
        .all()
    )

    refreshed: dict[str, Type[Tool]] = {}
    for tool in db_built_in_tools:
        tool_id = tool.in_code_tool_id
        # Skip empty IDs and rows whose tool is not (or no longer) in code.
        if tool_id and tool_id in cls_by_tool_id:
            refreshed[tool_id] = cls_by_tool_id[tool_id]

    # Publish the fully built mapping in one step.
    _built_in_tools_cache = refreshed
def get_built_in_tool_by_id(
    in_code_tool_id: str, db_session: Session, force_refresh: bool = False
) -> Type[Tool]:
    """Return the Tool class registered under ``in_code_tool_id``.

    The lookup is served from the module-level cache. The cache is rebuilt
    first when it has never been populated, when ``force_refresh`` is set,
    or when the requested ID is not currently cached.

    Args:
        in_code_tool_id: Identifier of the built-in tool to look up.
        db_session: Session passed to the cache refresh when one is needed.
        force_refresh: Rebuild the cache even if the ID is already cached.

    Returns:
        The Tool class cached for ``in_code_tool_id``.

    Raises:
        RuntimeError: If the cache is still ``None`` after a refresh.
        ValueError: If the ID is not in the cache even after refreshing.
    """
    # A miss triggers one refresh before giving up on the ID.
    needs_refresh = (
        force_refresh
        or _built_in_tools_cache is None
        or in_code_tool_id not in _built_in_tools_cache
    )
    if needs_refresh:
        refresh_built_in_tools_cache(db_session)

    cache = _built_in_tools_cache
    if cache is None:
        raise RuntimeError(
            "Built-in tools cache is None despite being refreshed. Should never happen."
        )

    if in_code_tool_id not in cache:
        raise ValueError(
            f"No built-in tool found in the cache with ID {in_code_tool_id}"
        )

    return cache[in_code_tool_id]