mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-04 17:00:24 +02:00
214 lines
6.9 KiB
Python
214 lines
6.9 KiB
Python
import os
|
|
from typing import Type
|
|
from typing_extensions import TypedDict # noreorder
|
|
|
|
from sqlalchemy import not_
|
|
from sqlalchemy import or_
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.db.models import Persona
|
|
from onyx.db.models import Tool as ToolDBModel
|
|
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
|
ImageGenerationTool,
|
|
)
|
|
from onyx.tools.tool_implementations.internet_search.internet_search_tool import (
|
|
InternetSearchTool,
|
|
)
|
|
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
|
from onyx.tools.tool import Tool
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
# Module-level logger, created through the project's `setup_logger` helper.
logger = setup_logger()
|
|
|
|
|
|
class InCodeToolInfo(TypedDict):
    """Metadata describing one tool that is implemented in code.

    Entries of this type make up the ``BUILT_IN_TOOLS`` list.
    """

    # The Tool subclass that implements the tool.
    cls: Type[Tool]
    # Human-readable explanation of what the tool does.
    description: str
    # Identifier written to / matched against the `in_code_tool_id` column
    # of the tool's database row.
    in_code_tool_id: str
    # Name of the tool intended for display.
    display_name: str
|
|
|
|
|
|
# Tools that ship with the codebase. The search and image generation tools are
# always registered; the internet search tool is appended below only when it
# has been configured.
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
    InCodeToolInfo(
        cls=SearchTool,
        description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
        in_code_tool_id=SearchTool.__name__,
        display_name=SearchTool._DISPLAY_NAME,
    ),
    InCodeToolInfo(
        cls=ImageGenerationTool,
        description=(
            "The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
            "The tool will be used when the user asks the assistant to generate an image."
        ),
        in_code_tool_id=ImageGenerationTool.__name__,
        display_name=ImageGenerationTool._DISPLAY_NAME,
    ),
]

# don't show the InternetSearchTool as an option if BING_API_KEY is not available
if os.environ.get("BING_API_KEY"):
    BUILT_IN_TOOLS.append(
        InCodeToolInfo(
            cls=InternetSearchTool,
            description=(
                "The Internet Search Tool allows the assistant "
                "to perform internet searches for up-to-date information."
            ),
            in_code_tool_id=InternetSearchTool.__name__,
            display_name=InternetSearchTool._DISPLAY_NAME,
        )
    )
|
|
|
|
|
|
def load_builtin_tools(db_session: Session) -> None:
    """Bring the tool rows in the database in line with ``BUILT_IN_TOOLS``.

    Only rows whose ``in_code_tool_id`` is set are considered. For every entry
    in ``BUILT_IN_TOOLS`` the matching row is updated (name, description,
    display name) or, if there is none, a new row is added. Rows whose
    ``in_code_tool_id`` no longer appears in ``BUILT_IN_TOOLS`` are deleted.
    All changes are committed in a single commit at the end.

    Args:
        db_session: Session used for the query, the modifications and the
            final commit.
    """
    has_in_code_id = not_(ToolDBModel.in_code_tool_id.is_(None))
    db_rows = db_session.scalars(select(ToolDBModel).where(has_in_code_id)).all()
    rows_by_in_code_id = {row.in_code_tool_id: row for row in db_rows}

    # Add or update existing tools
    for info in BUILT_IN_TOOLS:
        name = info["cls"].__name__
        existing = rows_by_in_code_id.get(info["in_code_tool_id"])

        if not existing:
            # No row yet for this in-code tool: create one.
            db_session.add(
                ToolDBModel(
                    name=name,
                    description=info["description"],
                    display_name=info["display_name"],
                    in_code_tool_id=info["in_code_tool_id"],
                )
            )
            logger.notice(f"Added new tool: {name}")
            continue

        # Row already present: overwrite its descriptive fields.
        existing.name = name
        existing.description = info["description"]
        existing.display_name = info["display_name"]
        logger.notice(f"Updated tool: {name}")

    # Remove tools that are no longer in BUILT_IN_TOOLS
    built_in_ids = {info["in_code_tool_id"] for info in BUILT_IN_TOOLS}
    for in_code_id, row in rows_by_in_code_id.items():
        if in_code_id in built_in_ids:
            continue
        db_session.delete(row)
        logger.notice(f"Removed tool no longer in built-in list: {row.name}")

    db_session.commit()
    logger.notice("All built-in tools are loaded/verified.")
|
|
|
|
|
|
def get_search_tool(db_session: Session) -> ToolDBModel | None:
    """Fetch the database row that represents the SearchTool.

    The SearchTool's ``in_code_tool_id`` is first looked up in
    ``BUILT_IN_TOOLS``; that id is then used to query the tool table.

    Args:
        db_session: Session used to run the query.

    Returns:
        The matching tool row, or ``None`` if the database has no such row.

    Raises:
        RuntimeError: If ``BUILT_IN_TOOLS`` has no usable SearchTool entry.
    """
    search_tool_id = None
    for entry in BUILT_IN_TOOLS:
        if entry["cls"].__name__ == SearchTool.__name__:
            search_tool_id = entry["in_code_tool_id"]
            break

    if not search_tool_id:
        raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.")

    query = select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id)
    return db_session.execute(query).scalar_one_or_none()
|
|
|
|
|
|
def auto_add_search_tool_to_personas(db_session: Session) -> None:
    """Attach the SearchTool to every Persona that should have it.

    A Persona qualifies when its ``num_chunks`` is unset (NULL) or set to any
    value other than 0. This migrates Persona objects that were created before
    the concept of Tools existed. Personas that already have the SearchTool
    are left unchanged.

    Args:
        db_session: Session used for the lookups and the final commit.

    Raises:
        RuntimeError: If the SearchTool row is missing from the database (or,
            via ``get_search_tool``, missing from ``BUILT_IN_TOOLS``).
    """
    # Resolve the SearchTool row via its in_code_tool_id from BUILT_IN_TOOLS
    search_tool = get_search_tool(db_session)
    if not search_tool:
        raise RuntimeError("SearchTool not found in the database.")

    # Select the Personas that are candidates for receiving the SearchTool
    needs_search = or_(Persona.num_chunks.is_(None), Persona.num_chunks != 0)
    candidates = db_session.execute(select(Persona).where(needs_search)).scalars().all()

    for persona in candidates:
        if search_tool in persona.tools:
            # Already attached; nothing to do for this Persona.
            continue
        persona.tools.append(search_tool)
        logger.notice(f"Added SearchTool to Persona ID: {persona.id}")

    # Persist all additions at once
    db_session.commit()
    logger.notice("Completed adding SearchTool to relevant Personas.")
|
|
|
|
|
|
# Module-level cache mapping in_code_tool_id -> Tool class. It starts as None
# (not built yet) and is (re)built by `refresh_built_in_tools_cache`.
_built_in_tools_cache: dict[str, Type[Tool]] | None = None
|
|
|
|
|
|
def refresh_built_in_tools_cache(db_session: Session) -> None:
    """Rebuild the module-level cache of built-in tool classes.

    The cache is reset to an empty dict and then filled with one entry per
    database tool row that has a non-empty ``in_code_tool_id`` matching an
    entry in ``BUILT_IN_TOOLS``. Rows without a match are skipped.

    Args:
        db_session: Session used to read the tool rows.
    """
    global _built_in_tools_cache

    # Reset the global first so it is a dict (not None) from here on, then
    # fill that same dict object in place.
    cache: dict[str, Type[Tool]] = {}
    _built_in_tools_cache = cache

    query = select(ToolDBModel).where(not_(ToolDBModel.in_code_tool_id.is_(None)))
    for db_tool in db_session.execute(query).scalars().all():
        in_code_id = db_tool.in_code_tool_id
        if not in_code_id:
            continue
        # The first BUILT_IN_TOOLS entry with this id wins.
        for info in BUILT_IN_TOOLS:
            if info["in_code_tool_id"] == in_code_id:
                cache[in_code_id] = info["cls"]
                break
|
|
|
|
|
|
def get_built_in_tool_by_id(
    in_code_tool_id: str, db_session: Session, force_refresh: bool = False
) -> Type[Tool]:
    """Return the Tool class registered under ``in_code_tool_id``.

    The lookup is served from the module-level cache. The cache is rebuilt
    from the database first when it has never been built, when
    ``force_refresh`` is true, or when the requested id is not in it.

    Args:
        in_code_tool_id: Identifier of the built-in tool to look up.
        db_session: Session passed to ``refresh_built_in_tools_cache`` when a
            refresh is needed.
        force_refresh: Rebuild the cache even if the id is already cached.

    Returns:
        The Tool class cached for ``in_code_tool_id``.

    Raises:
        RuntimeError: If the cache is still ``None`` after the refresh step.
        ValueError: If the id is absent from the cache after the refresh step.
    """
    # The cache can be used as-is only if it exists, no refresh was requested
    # and it already knows the id; otherwise refresh it once.
    cache_usable = (
        _built_in_tools_cache is not None
        and not force_refresh
        and in_code_tool_id in _built_in_tools_cache
    )
    if not cache_usable:
        refresh_built_in_tools_cache(db_session)

    # Read the global again: the refresh rebinds it to a new dict.
    cache = _built_in_tools_cache
    if cache is None:
        raise RuntimeError(
            "Built-in tools cache is None despite being refreshed. Should never happen."
        )

    if in_code_tool_id in cache:
        return cache[in_code_tool_id]

    raise ValueError(
        f"No built-in tool found in the cache with ID {in_code_tool_id}"
    )
|