Custom tools

This commit is contained in:
Weves 2024-06-09 14:57:39 -07:00 committed by Chris Weaver
parent c6d094b2ee
commit 7746375bfd
43 changed files with 2588 additions and 809 deletions

View File

@ -0,0 +1,61 @@
"""Add support for custom tools
Revision ID: 48d14957fe80
Revises: b85f02ec1308
Create Date: 2024-06-09 14:58:19.946509
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "48d14957fe80"
down_revision = "b85f02ec1308"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Apply the custom-tools schema changes.

    Adds ``openapi_schema`` and ``user_id`` columns to ``tool`` and creates
    the ``tool_call`` table for recording individual tool invocations.
    """
    # OpenAPI definition for UI-defined custom tools; NULL for built-in tools
    op.add_column(
        "tool",
        sa.Column(
            "openapi_schema",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )
    # owner of the tool; NULL for built-in tools
    op.add_column(
        "tool",
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
    )
    op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])
    # one row per tool call made while generating a chat message;
    # tool_id is deliberately not a FK so the tool can be deleted without
    # deleting these rows (matches the ToolCall ORM model's comment)
    op.create_table(
        "tool_call",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("tool_id", sa.Integer(), nullable=False),
        sa.Column("tool_name", sa.String(), nullable=False),
        sa.Column(
            "tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
        ),
        sa.Column(
            "tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
        ),
        sa.Column(
            "message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
        ),
    )
def downgrade() -> None:
    """Revert the custom-tools schema changes in reverse creation order."""
    # drop the dependent table first, then the FK and the added columns
    op.drop_table("tool_call")
    op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
    op.drop_column("tool", "user_id")
    op.drop_column("tool", "openapi_schema")

View File

@ -106,12 +106,18 @@ class ImageGenerationDisplay(BaseModel):
file_ids: list[str]
class CustomToolResponse(BaseModel):
    """Streamed packet carrying a custom tool's result to the client."""

    # JSON body returned by the custom tool's endpoint
    response: dict
    # name of the custom tool that produced the response
    tool_name: str
AnswerQuestionPossibleReturn = (
DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
)

View File

@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LlmDoc
@ -31,6 +32,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
@ -54,7 +56,10 @@ from danswer.search.utils import drop_llm_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.factory import get_tool_cls
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
@ -65,6 +70,7 @@ from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@ -162,7 +168,7 @@ def _check_should_force_search(
args = {"query": new_msg_req.message}
return ForceUseTool(
tool_name=SearchTool.name(),
tool_name=SearchTool.NAME,
args=args,
)
return None
@ -176,6 +182,7 @@ ChatPacket = (
| DanswerAnswerPiece
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
)
ChatPacketStream = Iterator[ChatPacket]
@ -389,61 +396,78 @@ def stream_chat_message_objects(
),
)
persona_tool_classes = [
get_tool_cls(tool, db_session) for tool in persona.tools
]
# find out what tools to use
search_tool: SearchTool | None = None
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tool_dict[db_tool_model.id] = [search_tool]
elif tool_cls.__name__ == ImageGenerationTool.__name__:
dalle_key = None
if (
llm
and llm.config.api_key
and llm.config.model_provider == "openai"
):
dalle_key = llm.config.api_key
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
dalle_key = openai_provider.api_key
tool_dict[db_tool_model.id] = [
ImageGenerationTool(api_key=dalle_key)
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
list[Tool],
build_custom_tools_from_openapi_schema(
db_tool_model.openapi_schema
),
)
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
persona_tool_classes
)
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
# NOTE: for now, only support SearchTool and ImageGenerationTool
# in the future, will support arbitrary user-defined tools
search_tool: SearchTool | None = None
tools: list[Tool] = []
for tool_cls in persona_tool_classes:
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
chunks_above=new_msg_req.chunks_above,
chunks_below=new_msg_req.chunks_below,
full_doc=new_msg_req.full_doc,
)
tools.append(search_tool)
elif tool_cls.__name__ == ImageGenerationTool.__name__:
dalle_key = None
if llm and llm.config.api_key and llm.config.model_provider == "openai":
dalle_key = llm.config.api_key
else:
llm_providers = fetch_existing_llm_providers(db_session)
openai_provider = next(
iter(
[
llm_provider
for llm_provider in llm_providers
if llm_provider.provider == "openai"
]
),
None,
)
if not openai_provider or not openai_provider.api_key:
raise ValueError(
"Image generation tool requires an OpenAI API key"
)
dalle_key = openai_provider.api_key
tools.append(ImageGenerationTool(api_key=dalle_key))
# LLM prompt building, response capturing, etc.
answer = Answer(
question=final_msg.message,
@ -468,7 +492,9 @@ def stream_chat_message_objects(
],
tools=tools,
force_use_tool=(
_check_should_force_search(new_msg_req) if search_tool else None
_check_should_force_search(new_msg_req)
if search_tool and len(tools) == 1
else None
),
)
@ -476,6 +502,7 @@ def stream_chat_message_objects(
qa_docs_response = None
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
@ -521,8 +548,16 @@ def stream_chat_message_objects(
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
except Exception as e:
@ -551,6 +586,11 @@ def stream_chat_message_objects(
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name()] = tool_id
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
@ -561,6 +601,16 @@ def stream_chat_message_objects(
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
)
db_session.commit() # actually save user / assistant message

View File

@ -5,6 +5,7 @@ from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
@ -16,6 +17,7 @@ from danswer.db.models import ChatSessionSharedStatus
from danswer.db.models import Prompt
from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.file_store.models import FileDescriptor
from danswer.llm.override_models import LLMOverride
@ -24,6 +26,7 @@ from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -185,6 +188,7 @@ def get_chat_messages_by_session(
user_id: UUID | None,
db_session: Session,
skip_permission_check: bool = False,
prefetch_tool_calls: bool = False,
) -> list[ChatMessage]:
if not skip_permission_check:
get_chat_session_by_id(
@ -192,12 +196,18 @@ def get_chat_messages_by_session(
)
stmt = (
select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
# Start with the root message which has no parent
select(ChatMessage)
.where(ChatMessage.chat_session_id == chat_session_id)
.order_by(nullsfirst(ChatMessage.parent_message))
)
result = db_session.execute(stmt).scalars().all()
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
if prefetch_tool_calls:
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
return list(result)
@ -251,6 +261,7 @@ def create_new_chat_message(
reference_docs: list[DBSearchDoc] | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
commit: bool = True,
) -> ChatMessage:
new_chat_message = ChatMessage(
@ -264,6 +275,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
)
@ -459,6 +471,14 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
)
return chat_msg_detail

View File

@ -133,6 +133,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
# Personas owned by this user
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
# Custom tools created by this user
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@ -330,7 +332,6 @@ class Document(Base):
primary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
# Something like assignee or space owner
secondary_owners: Mapped[list[str] | None] = mapped_column(
postgresql.ARRAY(String), nullable=True
)
@ -618,6 +619,26 @@ class SearchDoc(Base):
)
class ToolCall(Base):
    """Represents a single tool call"""

    __tablename__ = "tool_call"

    id: Mapped[int] = mapped_column(primary_key=True)

    # not a FK because we want to be able to delete the tool without deleting
    # this entry
    tool_id: Mapped[int] = mapped_column(Integer())
    # name snapshotted at call time (kept even if the tool row is deleted)
    tool_name: Mapped[str] = mapped_column(String())
    # arguments the tool was invoked with
    tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
    # raw result the tool produced
    tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())

    # chat message during whose generation this tool call happened
    message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))

    message: Mapped["ChatMessage"] = relationship(
        "ChatMessage", back_populates="tool_calls"
    )
class ChatSession(Base):
__tablename__ = "chat_session"
@ -723,6 +744,10 @@ class ChatMessage(Base):
secondary="chat_message__search_doc",
back_populates="chat_messages",
)
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
)
class ChatFolder(Base):
@ -901,9 +926,18 @@ class Tool(Base):
name: Mapped[str] = mapped_column(String, nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=True)
# ID of the tool in the codebase, only applies for in-code tools.
# tools defiend via the UI will have this as None
# tools defined via the UI will have this as None
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
# user who created / owns the tool. Will be None for built-in tools.
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
# Relationship to Persona through the association table
personas: Mapped[list["Persona"]] = relationship(
"Persona",

View File

@ -0,0 +1,74 @@
from typing import Any
from uuid import UUID
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.db.models import Tool
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_tools(db_session: Session) -> list[Tool]:
    """Fetch every tool row (built-in and custom) from the database."""
    all_tools = db_session.scalars(select(Tool)).all()
    return list(all_tools)
def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
    """Look up a single tool by primary key.

    Raises:
        ValueError: if no tool with ``tool_id`` exists.
    """
    stmt = select(Tool).where(Tool.id == tool_id)
    tool = db_session.scalar(stmt)
    if not tool:
        raise ValueError("Tool by specified id does not exist")
    return tool
def create_tool(
    name: str,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    user_id: UUID | None,
    db_session: Session,
) -> Tool:
    """Persist a new custom tool and return the committed row."""
    tool = Tool(
        name=name,
        description=description,
        # custom tools never map to an in-code implementation
        in_code_tool_id=None,
        openapi_schema=openapi_schema,
        user_id=user_id,
    )
    db_session.add(tool)
    db_session.commit()
    return tool
def update_tool(
    tool_id: int,
    name: str | None,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    user_id: UUID | None,
    db_session: Session,
) -> Tool:
    """Partially update a tool; fields passed as ``None`` are left untouched.

    Raises:
        ValueError: if no tool with ``tool_id`` exists (propagated from
            ``get_tool_by_id``).
    """
    # get_tool_by_id already raises ValueError when the tool is missing,
    # so no additional None check is needed here
    tool = get_tool_by_id(tool_id, db_session)

    if name is not None:
        tool.name = name
    if description is not None:
        tool.description = description
    if openapi_schema is not None:
        tool.openapi_schema = openapi_schema
    if user_id is not None:
        tool.user_id = user_id
    db_session.commit()

    return tool
def delete_tool(tool_id: int, db_session: Session) -> None:
    """Delete the tool with the given id.

    Raises:
        ValueError: if no tool with ``tool_id`` exists (propagated from
            ``get_tool_by_id``).
    """
    # get_tool_by_id already raises ValueError when the tool is missing,
    # so the previous `if tool is None` branch was unreachable
    tool = get_tool_by_id(tool_id, db_session)
    db_session.delete(tool)
    db_session.commit()

View File

@ -24,7 +24,7 @@ def load_chat_file(
file_id=file_descriptor["id"],
content=file_io.read(),
file_type=file_descriptor["type"],
filename=file_descriptor["name"],
filename=file_descriptor.get("name"),
)

View File

@ -4,6 +4,7 @@ from uuid import uuid4
from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import HumanMessage
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import AnswerQuestionPossibleReturn
@ -33,6 +34,9 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
from danswer.llm.interfaces import LLM
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import message_generator_to_string_generator
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
from danswer.tools.force import filter_tools_for_force_tool_use
from danswer.tools.force import ForceUseTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
@ -50,7 +54,8 @@ from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolRunKickoff
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.utils import explicit_tool_calling_supported
@ -72,7 +77,7 @@ def _get_answer_stream_processor(
raise RuntimeError("Not implemented yet")
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolRunKickoff | ToolResponse]
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
class Answer:
@ -125,7 +130,7 @@ class Answer:
self._streamed_output: list[str] | None = None
self._processed_stream: list[
AnswerQuestionPossibleReturn | ToolResponse | ToolRunKickoff
AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff
] | None = None
def _update_prompt_builder_for_search_tool(
@ -160,7 +165,7 @@ class Answer:
def _raw_output_for_explicit_tool_calling_llms(
self,
) -> Iterator[str | ToolRunKickoff | ToolResponse]:
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
@ -237,16 +242,18 @@ class Answer:
),
)
if tool.name() == SearchTool.name():
if tool.name() == SearchTool.NAME:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name() == ImageGenerationTool.name():
elif tool.name() == ImageGenerationTool.NAME:
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
)
)
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield tool_runner.tool_final_result()
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from message_generator_to_string_generator(
self.llm.stream(
prompt=prompt,
@ -258,7 +265,7 @@ class Answer:
def _raw_output_for_non_explicit_tool_calling_llms(
self,
) -> Iterator[str | ToolRunKickoff | ToolResponse]:
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
@ -324,7 +331,7 @@ class Answer:
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
if tool.name() == SearchTool.name():
if tool.name() == SearchTool.NAME:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS:
@ -337,7 +344,7 @@ class Answer:
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name() == ImageGenerationTool.name():
elif tool.name() == ImageGenerationTool.NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
@ -354,6 +361,18 @@ class Answer:
img_urls=img_urls,
)
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name(),
*tool_runner.tool_responses(),
)
)
)
yield tool_runner.tool_final_result()
prompt = prompt_builder.build()
yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
@ -374,7 +393,7 @@ class Answer:
)
def _process_stream(
stream: Iterator[ToolRunKickoff | ToolResponse | str],
stream: Iterator[ToolCallKickoff | ToolResponse | str],
) -> AnswerStream:
message = None
@ -387,7 +406,9 @@ class Answer:
] | None = None # processed docs to feed into the LLM
for message in stream:
if isinstance(message, ToolRunKickoff):
if isinstance(message, ToolCallKickoff) or isinstance(
message, ToolCallFinalResult
):
yield message
elif isinstance(message, ToolResponse):
if message.id == SEARCH_RESPONSE_SUMMARY_ID:

View File

@ -63,6 +63,7 @@ from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
from danswer.server.features.tool.api import admin_router as admin_tool_router
from danswer.server.features.tool.api import router as tool_router
from danswer.server.gpts.api import router as gpts_router
from danswer.server.manage.administrative import router as admin_router
@ -277,6 +278,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_persona_router)
include_router_with_global_prefix_prepended(application, prompt_router)
include_router_with_global_prefix_prepended(application, tool_router)
include_router_with_global_prefix_prepended(application, admin_tool_router)
include_router_with_global_prefix_prepended(application, state_router)
include_router_with_global_prefix_prepended(application, danswer_api_router)
include_router_with_global_prefix_prepended(application, gpts_router)

View File

@ -51,7 +51,7 @@ from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolRunKickoff
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@ -67,7 +67,7 @@ AnswerObjectIterator = Iterator[
| StreamingError
| ChatMessageDetail
| CitationInfo
| ToolRunKickoff
| ToolCallKickoff
]

View File

@ -1,32 +1,132 @@
from typing import Any
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.tools import create_tool
from danswer.db.tools import delete_tool
from danswer.db.tools import get_tool_by_id
from danswer.db.tools import get_tools
from danswer.db.tools import update_tool
from danswer.server.features.tool.models import ToolSnapshot
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
router = APIRouter(prefix="/tool")
admin_router = APIRouter(prefix="/admin/tool")
class ToolSnapshot(BaseModel):
id: int
class CustomToolCreate(BaseModel):
name: str
description: str
in_code_tool_id: str | None
description: str | None
definition: dict[str, Any]
@classmethod
def from_model(cls, tool: Tool) -> "ToolSnapshot":
return cls(
id=tool.id,
name=tool.name,
description=tool.description,
in_code_tool_id=tool.in_code_tool_id,
)
class CustomToolUpdate(BaseModel):
name: str | None
description: str | None
definition: dict[str, Any] | None
def _validate_tool_definition(definition: dict[str, Any]) -> None:
    """Raise a 400 ``HTTPException`` if the OpenAPI definition is invalid."""
    try:
        validate_openapi_schema(definition)
    except Exception as e:
        # chain the cause so the original validation error appears in logs
        raise HTTPException(status_code=400, detail=str(e)) from e
@admin_router.post("/custom")
def create_custom_tool(
    tool_data: CustomToolCreate,
    db_session: Session = Depends(get_session),
    user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
    """Admin endpoint: create a custom tool from an OpenAPI definition.

    Returns a snapshot of the newly-created tool; responds 400 if the
    definition fails OpenAPI validation.
    """
    # reject invalid schemas before touching the DB
    _validate_tool_definition(tool_data.definition)
    tool = create_tool(
        name=tool_data.name,
        description=tool_data.description,
        openapi_schema=tool_data.definition,
        # user can be None; ownership is then left unset
        user_id=user.id if user else None,
        db_session=db_session,
    )
    return ToolSnapshot.from_model(tool)
@admin_router.put("/custom/{tool_id}")
def update_custom_tool(
    tool_id: int,
    tool_data: CustomToolUpdate,
    db_session: Session = Depends(get_session),
    user: User | None = Depends(current_admin_user),
) -> ToolSnapshot:
    """Admin endpoint: partially update a custom tool (None fields ignored)."""
    # only validate when a new definition was actually supplied
    if tool_data.definition:
        _validate_tool_definition(tool_data.definition)
    # NOTE(review): passing the editing admin's id here reassigns tool
    # ownership on every update — confirm this is intended
    updated_tool = update_tool(
        tool_id=tool_id,
        name=tool_data.name,
        description=tool_data.description,
        openapi_schema=tool_data.definition,
        user_id=user.id if user else None,
        db_session=db_session,
    )
    return ToolSnapshot.from_model(updated_tool)
@admin_router.delete("/custom/{tool_id}")
def delete_custom_tool(
    tool_id: int,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_admin_user),
) -> None:
    """Admin endpoint: delete a custom tool by id.

    Responds 404 if the tool does not exist, 400 on any other failure
    (e.g. the tool is still referenced by an Assistant).
    """
    try:
        delete_tool(tool_id, db_session)
    except ValueError as e:
        # tool does not exist
        raise HTTPException(status_code=404, detail=str(e)) from e
    except Exception as e:
        # handles case where tool is still used by an Assistant
        raise HTTPException(status_code=400, detail=str(e)) from e
class ValidateToolRequest(BaseModel):
    """Request body for validating an OpenAPI definition."""

    # OpenAPI schema to validate
    definition: dict[str, Any]


class ValidateToolResponse(BaseModel):
    """Response listing the callable methods found in a valid definition."""

    # one spec per (path, HTTP method) pair in the schema
    methods: list[MethodSpec]
@admin_router.post("/custom/validate")
def validate_tool(
    tool_data: ValidateToolRequest,
    _: User | None = Depends(current_admin_user),
) -> ValidateToolResponse:
    """Admin endpoint: validate an OpenAPI definition without persisting it.

    Returns the method specs derived from the schema so the UI can preview
    the tool before creating it. Responds 400 on an invalid schema.
    """
    _validate_tool_definition(tool_data.definition)
    method_specs = openapi_to_method_specs(tool_data.definition)
    return ValidateToolResponse(methods=method_specs)
"""Endpoints for all"""
@router.get("/{tool_id}")
def get_custom_tool(
    tool_id: int,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> ToolSnapshot:
    """Fetch a single tool by id; responds 404 if it does not exist."""
    try:
        tool = get_tool_by_id(tool_id, db_session)
    except ValueError as e:
        # chain the cause so the original error appears in server logs
        raise HTTPException(status_code=404, detail=str(e)) from e
    return ToolSnapshot.from_model(tool)
@router.get("")
@ -34,5 +134,5 @@ def list_tools(
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> list[ToolSnapshot]:
tools = db_session.execute(select(Tool)).scalars().all()
tools = get_tools(db_session)
return [ToolSnapshot.from_model(tool) for tool in tools]

View File

@ -0,0 +1,23 @@
from typing import Any
from pydantic import BaseModel
from danswer.db.models import Tool
class ToolSnapshot(BaseModel):
    """API-facing representation of a ``Tool`` DB row."""

    id: int
    name: str
    description: str
    # OpenAPI schema; only set for UI-defined custom tools
    definition: dict[str, Any] | None
    # codebase identifier; only set for in-code (built-in) tools
    in_code_tool_id: str | None

    @classmethod
    def from_model(cls, tool: Tool) -> "ToolSnapshot":
        """Build a snapshot from the SQLAlchemy ``Tool`` model."""
        return cls(
            id=tool.id,
            name=tool.name,
            description=tool.description,
            definition=tool.openapi_schema,
            in_code_tool_id=tool.in_code_tool_id,
        )

View File

@ -129,6 +129,8 @@ def get_chat_session(
# we already did a permission check above with the call to
# `get_chat_session_by_id`, so we can skip it here
skip_permission_check=True,
# we need the tool call objs anyways, so just fetch them in a single call
prefetch_tool_calls=True,
)
return ChatSessionDetailResponse(

View File

@ -17,6 +17,7 @@ from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchDoc
from danswer.search.models import Tag
from danswer.tools.models import ToolCallFinalResult
class SourceTag(Tag):
@ -176,6 +177,7 @@ class ChatMessageDetail(BaseModel):
# Dict mapping citation number to db_doc_id
citations: dict[int, int] | None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore

View File

@ -0,0 +1,233 @@
import json
from collections.abc import Generator
from typing import Any
from typing import cast
import requests
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.custom.custom_tool_prompts import (
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
)
from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT
from danswer.tools.custom.custom_tool_prompts import USE_TOOL
from danswer.tools.custom.openapi_parsing import MethodSpec
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
from danswer.tools.custom.openapi_parsing import openapi_to_url
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
logger = setup_logger()
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"
class CustomToolCallSummary(BaseModel):
    """Result of a single custom-tool invocation."""

    # name of the tool that was called
    tool_name: str
    # parsed JSON body of the tool's HTTP response
    tool_result: dict
class CustomTool(Tool):
    """Tool generated from one method of a user-supplied OpenAPI schema.

    Each instance wraps a single (path, HTTP method) pair and executes it
    as an HTTP request against the schema's server URL when run.
    """

    def __init__(self, method_spec: MethodSpec, base_url: str) -> None:
        self._base_url = base_url
        self._method_spec = method_spec
        self._tool_definition = self._method_spec.to_tool_definition()

        self._name = self._method_spec.name
        self.description = self._method_spec.summary

    def name(self) -> str:
        return self._name

    """For LLMs which support explicit tool calling"""

    def tool_definition(self) -> dict:
        return self._tool_definition

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        # serialize the tool result so it can be fed back to the LLM
        response = cast(CustomToolCallSummary, args[0].response)
        return json.dumps(response.tool_result)

    """For LLMs which do NOT support explicit tool calling"""

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        """Ask the LLM whether and how to call this tool.

        Returns the parsed argument dict, or None when the LLM decides the
        tool should not be used or its argument output is not valid JSON.
        """
        if not force_run:
            should_use_result = llm.invoke(
                [
                    SystemMessage(content=SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT),
                    HumanMessage(
                        content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format(
                            history=history,
                            query=query,
                            tool_name=self.name(),
                            tool_description=self.description,
                        )
                    ),
                ]
            )
            if cast(str, should_use_result.content).strip() != USE_TOOL:
                return None

        args_result = llm.invoke(
            [
                SystemMessage(content=TOOL_ARG_SYSTEM_PROMPT),
                HumanMessage(
                    content=TOOL_ARG_USER_PROMPT.format(
                        history=history,
                        query=query,
                        tool_name=self.name(),
                        tool_description=self.description,
                        tool_args=self.tool_definition()["function"]["parameters"],
                    )
                ),
            ]
        )
        args_result_str = cast(str, args_result.content)

        try:
            return json.loads(args_result_str.strip())
        except json.JSONDecodeError:
            pass

        # try removing ```
        try:
            return json.loads(args_result_str.strip("```"))
        except json.JSONDecodeError:
            pass

        # try removing ```json
        # NOTE(review): str.strip removes a *set* of characters, not the
        # literal prefix — good enough for markdown fences, but not exact
        try:
            return json.loads(args_result_str.strip("```").strip("json"))
        except json.JSONDecodeError:
            pass

        # pretend like nothing happened if not parse-able
        logger.error(
            f"Failed to parse args for '{self.name()}' tool. Received: {args_result_str}"
        )
        return None

    """Actual execution of the tool"""

    def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
        """Execute the HTTP request described by the wrapped method spec."""
        request_body = kwargs.get(REQUEST_BODY)

        # path params are required — a missing one raises KeyError
        path_params = {}
        for path_param_schema in self._method_spec.get_path_param_schemas():
            path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]]

        # query params are optional — include only the ones provided
        query_params = {}
        for query_param_schema in self._method_spec.get_query_param_schemas():
            if query_param_schema["name"] in kwargs:
                query_params[query_param_schema["name"]] = kwargs[
                    query_param_schema["name"]
                ]

        url = self._method_spec.build_url(self._base_url, path_params, query_params)
        method = self._method_spec.method

        # NOTE(review): assumes the endpoint returns JSON — response.json()
        # raises on non-JSON bodies; confirm this is acceptable
        response = requests.request(method, url, json=request_body)

        yield ToolResponse(
            id=CUSTOM_TOOL_RESPONSE_ID,
            response=CustomToolCallSummary(
                tool_name=self._name, tool_result=response.json()
            ),
        )

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        return cast(CustomToolCallSummary, args[0].response).tool_result
def build_custom_tools_from_openapi_schema(
    openapi_schema: dict[str, Any]
) -> list[CustomTool]:
    """Create one CustomTool per operation defined in the OpenAPI schema."""
    base_url = openapi_to_url(openapi_schema)
    return [
        CustomTool(method_spec, base_url)
        for method_spec in openapi_to_method_specs(openapi_schema)
    ]
if __name__ == "__main__":
    # Manual smoke test: build tools from a sample OpenAPI schema and let
    # GPT pick one to call. Requires a live OpenAI API key and, for the
    # actual HTTP request, a server listening on localhost:8080.
    import openai

    # Minimal two-operation schema exercising path params and a request body.
    openapi_schema = {
        "openapi": "3.0.0",
        "info": {
            "version": "1.0.0",
            "title": "Assistants API",
            "description": "An API for managing assistants",
        },
        "servers": [
            {"url": "http://localhost:8080"},
        ],
        "paths": {
            "/assistant/{assistant_id}": {
                "get": {
                    "summary": "Get a specific Assistant",
                    "operationId": "getAssistant",
                    "parameters": [
                        {
                            "name": "assistant_id",
                            "in": "path",
                            "required": True,
                            "schema": {"type": "string"},
                        }
                    ],
                },
                "post": {
                    "summary": "Create a new Assistant",
                    "operationId": "createAssistant",
                    "parameters": [
                        {
                            "name": "assistant_id",
                            "in": "path",
                            "required": True,
                            "schema": {"type": "string"},
                        }
                    ],
                    "requestBody": {
                        "required": True,
                        "content": {"application/json": {"schema": {"type": "object"}}},
                    },
                },
            }
        },
    }
    # Raises ValueError if the schema is malformed.
    validate_openapi_schema(openapi_schema)
    tools = build_custom_tools_from_openapi_schema(openapi_schema)

    # Offer the generated tool definitions to the model and see if it calls one.
    openai_client = openai.OpenAI()
    response = openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Can you fetch assistant with ID 10"},
        ],
        tools=[tool.tool_definition() for tool in tools],  # type: ignore
    )
    choice = response.choices[0]
    if choice.message.tool_calls:
        print(choice.message.tool_calls)
        # Run the first tool with the model-provided args and print each packet.
        for tool_response in tools[0].run(
            **json.loads(choice.message.tool_calls[0].function.arguments)
        ):
            print(tool_response)

View File

@ -0,0 +1,21 @@
from typing import cast
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.models import ToolResponse
def build_user_message_for_custom_tool_for_non_tool_calling_llm(
    query: str,
    tool_name: str,
    *args: ToolResponse,
) -> str:
    """Wrap a custom tool's result and the user's query into one prompt.

    Used for LLMs without native tool calling: the tool has already run,
    so its result is injected into the conversation as plain text.
    """
    # args[0] is expected to carry the CustomToolCallSummary produced by the tool
    tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result
    return f"""
Here's the result from the {tool_name} tool:
{tool_run_summary}
Now respond to the following:
{query}
""".strip()

View File

@ -0,0 +1,57 @@
from danswer.prompts.constants import GENERAL_SEP_PAT
# Sentinel strings the tool-selection LLM is asked to respond with.
DONT_USE_TOOL = "Don't use tool"
USE_TOOL = "Use tool"

"""Prompts to determine if we should use a custom tool or not."""

SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT = (
    "You are a large language model whose only job is to determine if the system should call an "
    "external tool to be able to answer the user's last message."
).strip()

# NOTE: fixed typos vs. the original ("respnding" -> "responding",
# "the users query" -> "the user's query").
SHOULD_USE_CUSTOM_TOOL_USER_PROMPT = f"""
Given the conversation history and a follow up query, determine if the system should use the \
'{{tool_name}}' tool to answer the user's query. The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'.
Respond with "{USE_TOOL}" if you think the tool would be helpful in responding to the user's query.
Respond with "{DONT_USE_TOOL}" otherwise.
Conversation History:
{GENERAL_SEP_PAT}
{{history}}
{GENERAL_SEP_PAT}
If you are at all unsure, respond with {DONT_USE_TOOL}.
Respond with EXACTLY and ONLY "{DONT_USE_TOOL}" or "{USE_TOOL}"
Follow up input:
{{query}}
""".strip()

"""Prompts to figure out the arguments to pass to a custom tool."""

TOOL_ARG_SYSTEM_PROMPT = (
    "You are a large language model whose only job is to determine the arguments to pass to an "
    "external tool."
).strip()

TOOL_ARG_USER_PROMPT = f"""
Given the following conversation and a follow up input, generate a \
dictionary of arguments to pass to the '{{tool_name}}' tool. \
The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. \
The expected arguments are: {{tool_args}}.
Conversation:
{{history}}
Follow up input:
{{query}}
Respond with ONLY and EXACTLY a JSON object specifying the values of the arguments to pass to the tool.
""".strip()  # noqa: F541

View File

@ -0,0 +1,225 @@
from typing import Any
from typing import cast
from openai import BaseModel
REQUEST_BODY = "requestBody"
class PathSpec(BaseModel):
    """One `paths` entry of an OpenAPI schema: a path and its HTTP methods."""

    # path template, e.g. "/assistant/{assistant_id}"
    path: str
    # mapping of HTTP verb (e.g. "get") -> raw operation spec
    methods: dict[str, Any]
class MethodSpec(BaseModel):
    """A single operation (HTTP method + path) extracted from an OpenAPI schema."""

    name: str  # the OpenAPI operationId
    summary: str  # human-readable description, surfaced to the LLM
    path: str  # path template, e.g. "/assistant/{assistant_id}"
    method: str  # HTTP verb, e.g. "get"
    spec: dict[str, Any]  # raw OpenAPI spec for this operation

    def get_request_body_schema(self) -> dict[str, Any]:
        """Return the JSON request-body schema, or {} when there is no body.

        Raises:
            ValueError: if a request body exists but is not application/json.
        """
        content = self.spec.get("requestBody", {}).get("content", {})
        if "application/json" in content:
            return content["application/json"].get("schema")
        if content:
            raise ValueError(
                f"Unsupported content type: '{list(content.keys())[0]}'. "
                f"Only 'application/json' is supported."
            )
        return {}

    def get_query_param_schemas(self) -> list[dict[str, Any]]:
        """Parameter specs with `in == "query"` that carry a schema."""
        return [
            param
            for param in self.spec.get("parameters", [])
            if "schema" in param and "in" in param and param["in"] == "query"
        ]

    def get_path_param_schemas(self) -> list[dict[str, Any]]:
        """Parameter specs with `in == "path"` that carry a schema."""
        return [
            param
            for param in self.spec.get("parameters", [])
            if "schema" in param and "in" in param and param["in"] == "path"
        ]

    def build_url(
        self, base_url: str, path_params: dict[str, str], query_params: dict[str, str]
    ) -> str:
        """Fill the path template and append any query parameters.

        Raises:
            ValueError: if a required path parameter is missing.

        NOTE(review): query values are interpolated raw - they are not
        URL-encoded, so values containing '&', '=' or spaces will produce
        a malformed URL. Confirm before hardening with urllib.parse.
        """
        url = f"{base_url}{self.path}"
        try:
            url = url.format(**path_params)
        except KeyError as e:
            # chain the original error so the missing key stays traceable
            raise ValueError(f"Missing path parameter: {e}") from e
        if query_params:
            query_string = "&".join(
                f"{param}={value}" for param, value in query_params.items()
            )
            url = f"{url}?{query_string}"
        return url

    def to_tool_definition(self) -> dict[str, Any]:
        """Convert this operation into an OpenAI-style function definition."""
        tool_definition: Any = {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.summary,
                "parameters": {"type": "object", "properties": {}},
            },
        }
        properties = tool_definition["function"]["parameters"]["properties"]

        # The JSON body (if any) is exposed as a single "requestBody" argument.
        request_body_schema = self.get_request_body_schema()
        if request_body_schema:
            properties[REQUEST_BODY] = request_body_schema

        # Query and path parameters become individual top-level arguments.
        for param in self.get_query_param_schemas():
            properties[param["name"]] = param["schema"]
        for param in self.get_path_param_schemas():
            properties[param["name"]] = param["schema"]

        return tool_definition

    def validate_spec(self) -> None:
        """Raise ValueError if this operation cannot be turned into a tool."""
        # Validate url construction with dummy parameter values
        dummy_path_dict = {p["name"]: "value" for p in self.get_path_param_schemas()}
        dummy_query_dict = {p["name"]: "value" for p in self.get_query_param_schemas()}
        self.build_url("", dummy_path_dict, dummy_query_dict)

        # Make sure request body doesn't throw an exception
        self.get_request_body_schema()

        # Ensure the method is valid
        if not self.method:
            raise ValueError("HTTP method is not specified.")
        if self.method.upper() not in ["GET", "POST", "PUT", "DELETE", "PATCH"]:
            raise ValueError(f"HTTP method '{self.method}' is not supported.")
"""Path-level utils"""
def openapi_to_path_specs(openapi_spec: dict[str, Any]) -> list[PathSpec]:
    """Build one PathSpec per entry in the schema's `paths` section."""
    return [
        PathSpec(path=path, methods=methods)
        for path, methods in openapi_spec.get("paths", {}).items()
    ]
"""Method-level utils"""
def openapi_to_method_specs(openapi_spec: dict[str, Any]) -> list[MethodSpec]:
    """Flatten the schema into one MethodSpec per (path, HTTP method) pair.

    Raises:
        ValueError: when an operation is missing an operationId or a
            summary/description, or when no operations are defined at all.
    """
    method_specs: list[MethodSpec] = []
    for path_spec in openapi_to_path_specs(openapi_spec):
        for method_name, method in path_spec.methods.items():
            operation_id = method.get("operationId")
            if not operation_id:
                raise ValueError(
                    f"Operation ID is not specified for {method_name.upper()} {path_spec.path}"
                )

            # fall back to the long-form description if no summary is given
            summary = method.get("summary") or method.get("description")
            if not summary:
                raise ValueError(
                    f"Summary is not specified for {method_name.upper()} {path_spec.path}"
                )

            method_specs.append(
                MethodSpec(
                    name=operation_id,
                    summary=summary,
                    path=path_spec.path,
                    method=method_name,
                    spec=method,
                )
            )

    if not method_specs:
        raise ValueError("No methods found in OpenAPI schema")
    return method_specs
def openapi_to_url(openapi_schema: dict[str, dict | str]) -> str:
"""
Extract URLs from the servers section of an OpenAPI schema.
Args:
openapi_schema (Dict[str, Union[Dict, str, List]]): The OpenAPI schema in dictionary format.
Returns:
List[str]: A list of base URLs.
"""
urls: list[str] = []
servers = cast(list[dict[str, Any]], openapi_schema.get("servers", []))
for server in servers:
url = server.get("url")
if url:
urls.append(url)
if len(urls) != 1:
raise ValueError(
f"Expected exactly one URL in OpenAPI schema, but found {urls}"
)
return urls[0]
def validate_openapi_schema(schema: dict[str, Any]) -> None:
    """
    Validate the given JSON schema as an OpenAPI schema.

    Parameters:
    - schema (dict): The JSON schema to validate.

    Raises:
    - ValueError: if any required section (`info`, `openapi`, `paths`,
      `servers`) is missing or invalid, or any operation is malformed.
      Returns None when the schema is valid.

    NOTE: the previous docstring claimed a bool was returned; invalid
    schemas have always raised instead.
    """
    # check basic structure
    if "info" not in schema:
        raise ValueError("`info` section is required in OpenAPI schema")

    info = schema["info"]
    if "title" not in info:
        raise ValueError("`title` is required in `info` section of OpenAPI schema")
    if "description" not in info:
        raise ValueError(
            "`description` is required in `info` section of OpenAPI schema"
        )

    if "openapi" not in schema:
        raise ValueError(
            "`openapi` field which specifies OpenAPI schema version is required"
        )
    # only OpenAPI 3.x is supported
    openapi_version = schema["openapi"]
    if not openapi_version.startswith("3."):
        raise ValueError(f"OpenAPI version '{openapi_version}' is not supported")

    if "paths" not in schema:
        raise ValueError("`paths` section is required in OpenAPI schema")

    # raises if there is not exactly one server URL
    url = openapi_to_url(schema)
    if not url:
        raise ValueError("OpenAPI schema does not contain a valid URL in `servers`")

    # validate every operation individually
    method_specs = openapi_to_method_specs(schema)
    for method_spec in method_specs:
        method_spec.validate_spec()

View File

@ -1,12 +0,0 @@
from typing import Type
from sqlalchemy.orm import Session
from danswer.db.models import Tool as ToolDBModel
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.tool import Tool
def get_tool_cls(tool: ToolDBModel, db_session: Session) -> Type[Tool]:
# Currently only support built-in tools
return get_built_in_tool_by_id(tool.id, db_session)

View File

@ -8,6 +8,7 @@ from pydantic import BaseModel
from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.utils import build_content_with_imgs
@ -53,6 +54,8 @@ class ImageGenerationResponse(BaseModel):
class ImageGenerationTool(Tool):
NAME = "run_image_generation"
def __init__(
self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2
) -> None:
@ -60,16 +63,14 @@ class ImageGenerationTool(Tool):
self.model = model
self.num_imgs = num_imgs
@classmethod
def name(self) -> str:
return "run_image_generation"
return self.NAME
@classmethod
def tool_definition(cls) -> dict:
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": cls.name(),
"name": self.name(),
"description": "Generate an image from a prompt",
"parameters": {
"type": "object",
@ -162,3 +163,12 @@ class ImageGenerationTool(Tool):
id=IMAGE_GENERATION_RESPONSE_ID,
response=results,
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
image_generation_responses = cast(
list[ImageGenerationResponse], args[0].response
)
return [
image_generation_response.dict()
for image_generation_response in image_generation_responses
]

View File

@ -0,0 +1,39 @@
from typing import Any
from pydantic import BaseModel
from pydantic import root_validator
class ToolResponse(BaseModel):
    """A single packet of output emitted while a tool runs."""

    # identifies which kind of packet this is (e.g. final docs vs. summary)
    id: str | None = None
    # arbitrary payload; shape depends on the tool that produced it
    response: Any
class ToolCallKickoff(BaseModel):
    """Announces that a tool is about to run, with the args it was given."""

    tool_name: str
    tool_args: dict[str, Any]
class ToolRunnerResponse(BaseModel):
    """Union-style wrapper: exactly one of the three fields is populated."""

    tool_run_kickoff: ToolCallKickoff | None = None
    tool_response: ToolResponse | None = None
    tool_message_content: str | list[str | dict[str, Any]] | None = None

    @root_validator
    def validate_tool_runner_response(
        cls, values: dict[str, ToolResponse | str]
    ) -> dict[str, ToolResponse | str]:
        # enforce the exactly-one-field-set invariant at construction time
        fields = ["tool_response", "tool_message_content", "tool_run_kickoff"]
        provided = sum(1 for field in fields if values.get(field) is not None)

        if provided != 1:
            raise ValueError(
                "Exactly one of 'tool_response', 'tool_message_content', "
                "or 'tool_run_kickoff' must be provided"
            )

        return values
class ToolCallFinalResult(ToolCallKickoff):
    """Kickoff info plus the tool's final result, as persisted to the DB."""

    tool_result: Any  # we would like to use JSON_ro, but can't due to its recursive nature

View File

@ -10,6 +10,7 @@ from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import LlmDoc
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.doc_pruning import prune_documents
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
@ -55,6 +56,8 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
class SearchTool(Tool):
NAME = "run_search"
def __init__(
self,
db_session: Session,
@ -87,18 +90,16 @@ class SearchTool(Tool):
self.bypass_acl = bypass_acl
self.db_session = db_session
@classmethod
def name(cls) -> str:
return "run_search"
def name(self) -> str:
return self.NAME
"""For explicit tool calling"""
@classmethod
def tool_definition(cls) -> dict:
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": cls.name(),
"name": self.name(),
"description": search_tool_description,
"parameters": {
"type": "object",
@ -241,3 +242,13 @@ class SearchTool(Tool):
document_pruning_config=self.pruning_config,
)
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=final_context_documents)
def final_result(self, *args: ToolResponse) -> JSON_ro:
final_docs = cast(
list[LlmDoc],
next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS),
)
# NOTE: need to do this json.loads(doc.json()) stuff because there are some
# subfields that are not serializable by default (datetime)
# this forces pydantic to make them JSON serializable for us
return [json.loads(doc.json()) for doc in final_docs]

View File

@ -2,26 +2,19 @@ import abc
from collections.abc import Generator
from typing import Any
from pydantic import BaseModel
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
class ToolResponse(BaseModel):
id: str | None = None
response: Any
from danswer.tools.models import ToolResponse
class Tool(abc.ABC):
@classmethod
@abc.abstractmethod
def name(self) -> str:
raise NotImplementedError
"""For LLMs which support explicit tool calling"""
@classmethod
@abc.abstractmethod
def tool_definition(self) -> dict:
raise NotImplementedError
@ -49,3 +42,11 @@ class Tool(abc.ABC):
@abc.abstractmethod
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
raise NotImplementedError
@abc.abstractmethod
def final_result(self, *args: ToolResponse) -> JSON_ro:
"""
This is the "final summary" result of the tool.
It is the result that will be stored in the database.
"""
raise NotImplementedError

View File

@ -1,42 +1,15 @@
from collections.abc import Generator
from typing import Any
from pydantic import BaseModel
from pydantic import root_validator
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
class ToolRunKickoff(BaseModel):
tool_name: str
tool_args: dict[str, Any]
class ToolRunnerResponse(BaseModel):
tool_run_kickoff: ToolRunKickoff | None = None
tool_response: ToolResponse | None = None
tool_message_content: str | list[str | dict[str, Any]] | None = None
@root_validator
def validate_tool_runner_response(
cls, values: dict[str, ToolResponse | str]
) -> dict[str, ToolResponse | str]:
fields = ["tool_response", "tool_message_content", "tool_run_kickoff"]
provided = sum(1 for field in fields if values.get(field) is not None)
if provided != 1:
raise ValueError(
"Exactly one of 'tool_response', 'tool_message_content', "
"or 'tool_run_kickoff' must be provided"
)
return values
class ToolRunner:
def __init__(self, tool: Tool, args: dict[str, Any]):
self.tool = tool
@ -44,8 +17,8 @@ class ToolRunner:
self._tool_responses: list[ToolResponse] | None = None
def kickoff(self) -> ToolRunKickoff:
return ToolRunKickoff(tool_name=self.tool.name(), tool_args=self.args)
def kickoff(self) -> ToolCallKickoff:
return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args)
def tool_responses(self) -> Generator[ToolResponse, None, None]:
if self._tool_responses is not None:
@ -62,6 +35,13 @@ class ToolRunner:
tool_responses = list(self.tool_responses())
return self.tool.build_tool_message_content(*tool_responses)
def tool_final_result(self) -> ToolCallFinalResult:
return ToolCallFinalResult(
tool_name=self.tool.name(),
tool_args=self.args,
tool_result=self.tool.final_result(*self.tool_responses()),
)
def check_which_tools_should_run_for_non_tool_calling_llm(
tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM

View File

@ -1,5 +1,4 @@
import json
from typing import Type
from tiktoken import Encoding
@ -17,15 +16,13 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> boo
return False
def compute_tool_tokens(
tool: Tool | Type[Tool], llm_tokenizer: Encoding | None = None
) -> int:
def compute_tool_tokens(tool: Tool, llm_tokenizer: Encoding | None = None) -> int:
if not llm_tokenizer:
llm_tokenizer = get_default_llm_tokenizer()
return len(llm_tokenizer.encode(json.dumps(tool.tool_definition())))
def compute_all_tool_tokens(
tools: list[Tool] | list[Type[Tool]], llm_tokenizer: Encoding | None = None
tools: list[Tool], llm_tokenizer: Encoding | None = None
) -> int:
return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools)

View File

@ -24,6 +24,7 @@ httpx[http2]==0.23.3
httpx-oauth==0.11.2
huggingface-hub==0.20.1
jira==3.5.1
jsonref==1.1.0
langchain==0.1.17
langchain-community==0.0.36
langchain-core==0.1.50

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,261 @@
"use client";
import { useState, useEffect, useCallback } from "react";
import { useRouter } from "next/navigation";
import { Formik, Form, Field, ErrorMessage } from "formik";
import * as Yup from "yup";
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Button, Divider } from "@tremor/react";
import {
createCustomTool,
updateCustomTool,
validateToolDefinition,
} from "@/lib/tools/edit";
import { usePopup } from "@/components/admin/connectors/Popup";
import debounce from "lodash/debounce";
// Best-effort parser for user-pasted "JSON-ish" schemas: strips trailing
// commas and converts Python literals (True/False/None) before JSON.parse.
// NOTE(review): the regex replacements also apply inside string literals,
// so values containing e.g. ",}" or the word "True" can be mangled - a
// correct fix would require a real tokenizer; confirm whether this matters.
function parseJsonWithTrailingCommas(jsonString: string) {
  // Regular expression to remove trailing commas before } or ]
  let cleanedJsonString = jsonString.replace(/,\s*([}\]])/g, "$1");
  // Replace True with true, False with false, and None with null
  cleanedJsonString = cleanedJsonString
    .replace(/\bTrue\b/g, "true")
    .replace(/\bFalse\b/g, "false")
    .replace(/\bNone\b/g, "null");
  // Now parse the cleaned JSON string
  return JSON.parse(cleanedJsonString);
}
// Pretty-print a parsed schema with 2-space indentation for the editor.
function prettifyDefinition(definition: any) {
  return JSON.stringify(definition, null, 2);
}
// Inner Formik form for creating/editing a custom tool. Validates the
// pasted OpenAPI definition against the backend as the user types
// (debounced) and previews the methods that would be exposed as tools.
function ToolForm({
  existingTool,
  values,
  setFieldValue,
  isSubmitting,
  definitionErrorState,
  methodSpecsState,
}: {
  existingTool?: ToolSnapshot;
  values: ToolFormValues;
  setFieldValue: (field: string, value: string) => void;
  isSubmitting: boolean;
  // [state, setState] pairs lifted to the parent so submit can be gated
  // on validation results
  definitionErrorState: [
    string | null,
    React.Dispatch<React.SetStateAction<string | null>>,
  ];
  methodSpecsState: [
    MethodSpec[] | null,
    React.Dispatch<React.SetStateAction<MethodSpec[] | null>>,
  ];
}) {
  const [definitionError, setDefinitionError] = definitionErrorState;
  const [methodSpecs, setMethodSpecs] = methodSpecsState;

  // Server-side schema validation, debounced so we don't issue a request
  // per keystroke. The empty dep array keeps a single debounce instance
  // for the component's lifetime.
  const debouncedValidateDefinition = useCallback(
    debounce(async (definition: string) => {
      try {
        const parsedDefinition = parseJsonWithTrailingCommas(definition);
        const response = await validateToolDefinition({
          definition: parsedDefinition,
        });
        if (response.error) {
          setMethodSpecs(null);
          setDefinitionError(response.error);
        } else {
          setMethodSpecs(response.data);
          setDefinitionError(null);
        }
      } catch (error) {
        console.log(error);
        setMethodSpecs(null);
        setDefinitionError("Invalid JSON format");
      }
    }, 300),
    []
  );

  // Re-validate whenever the definition text changes.
  useEffect(() => {
    if (values.definition) {
      debouncedValidateDefinition(values.definition);
    }
  }, [values.definition, debouncedValidateDefinition]);

  return (
    <Form>
      <div className="relative">
        <TextFormField
          name="definition"
          label="Definition"
          subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this tool."
          placeholder="Enter your OpenAPI schema here"
          isTextArea={true}
          defaultHeight="h-96"
          fontSize="text-sm"
          isCode
          hideError
        />
        {/* "Format" button: pretty-prints the JSON in place, if parseable */}
        <button
          type="button"
          className="
            absolute
            bottom-4
            right-4
            border-border
            border
            bg-background
            rounded
            py-1
            px-3
            text-sm
            hover:bg-hover-light
          "
          onClick={() => {
            const definition = values.definition;
            if (definition) {
              try {
                const formatted = prettifyDefinition(
                  parseJsonWithTrailingCommas(definition)
                );
                setFieldValue("definition", formatted);
              } catch (error) {
                alert("Invalid JSON format");
              }
            }
          }}
        >
          Format
        </button>
      </div>
      {definitionError && (
        <div className="text-error text-sm">{definitionError}</div>
      )}
      <ErrorMessage
        name="definition"
        component="div"
        className="text-error text-sm"
      />

      {/* Preview of the operations the backend extracted from the schema */}
      {methodSpecs && methodSpecs.length > 0 && (
        <div className="mt-4">
          <h3 className="text-base font-semibold mb-2">Available methods</h3>
          <div className="overflow-x-auto">
            <table className="min-w-full bg-white border border-gray-200">
              <thead>
                <tr>
                  <th className="px-4 py-2 border-b">Name</th>
                  <th className="px-4 py-2 border-b">Summary</th>
                  <th className="px-4 py-2 border-b">Method</th>
                  <th className="px-4 py-2 border-b">Path</th>
                </tr>
              </thead>
              <tbody>
                {methodSpecs?.map((method: MethodSpec, index: number) => (
                  <tr key={index} className="text-sm">
                    <td className="px-4 py-2 border-b">{method.name}</td>
                    <td className="px-4 py-2 border-b">{method.summary}</td>
                    <td className="px-4 py-2 border-b">
                      {method.method.toUpperCase()}
                    </td>
                    <td className="px-4 py-2 border-b">{method.path}</td>
                  </tr>
                ))}
              </tbody>
            </table>
          </div>
        </div>
      )}

      <Divider />

      {/* Submit is blocked while validation is failing */}
      <div className="flex">
        <Button
          className="mx-auto"
          color="green"
          size="md"
          type="submit"
          disabled={isSubmitting || !!definitionError}
        >
          {existingTool ? "Update Tool" : "Create Tool"}
        </Button>
      </div>
    </Form>
  );
}
// Formik value shape: the raw OpenAPI definition text typed by the user.
interface ToolFormValues {
  definition: string;
}

// Only requirement enforced client-side is that a definition is present;
// deeper validation happens server-side via validateToolDefinition.
const ToolSchema = Yup.object().shape({
  definition: Yup.string().required("Tool definition is required"),
});
// Page-level editor for a custom tool. When `tool` is provided the form
// edits it in place; otherwise submitting creates a new tool.
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
  const router = useRouter();
  const { popup, setPopup } = usePopup();
  const [definitionError, setDefinitionError] = useState<string | null>(null);
  const [methodSpecs, setMethodSpecs] = useState<MethodSpec[] | null>(null);

  // Seed the textarea with the existing definition, pretty-printed.
  const prettifiedDefinition = tool?.definition
    ? prettifyDefinition(tool.definition)
    : "";

  return (
    <div>
      {popup}
      <Formik
        initialValues={{
          definition: prettifiedDefinition,
        }}
        validationSchema={ToolSchema}
        onSubmit={async (values: ToolFormValues) => {
          // parse the raw text; abort with an inline error if it's invalid
          let definition: any;
          try {
            definition = parseJsonWithTrailingCommas(values.definition);
          } catch (error) {
            setDefinitionError("Invalid JSON in tool definition");
            return;
          }

          // tool name/description come from the schema's own info section
          const name = definition?.info?.title;
          const description = definition?.info?.description;
          const toolData = {
            name: name,
            description: description || "",
            definition: definition,
          };
          let response;
          if (tool) {
            response = await updateCustomTool(tool.id, toolData);
          } else {
            response = await createCustomTool(toolData);
          }
          if (response.error) {
            setPopup({
              message: "Failed to create tool - " + response.error,
              type: "error",
            });
            return;
          }
          // `u` query param busts any client-side cache of the tools list
          router.push(`/admin/tools?u=${Date.now()}`);
        }}
      >
        {({ isSubmitting, values, setFieldValue }) => {
          return (
            <ToolForm
              existingTool={tool}
              values={values}
              setFieldValue={setFieldValue}
              isSubmitting={isSubmitting}
              definitionErrorState={[definitionError, setDefinitionError]}
              methodSpecsState={[methodSpecs, setMethodSpecs]}
            />
          );
        }}
      </Formik>
    </div>
  );
}

View File

@ -0,0 +1,107 @@
"use client";
import {
Text,
Table,
TableHead,
TableRow,
TableHeaderCell,
TableBody,
TableCell,
} from "@tremor/react";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { useRouter } from "next/navigation";
import { usePopup } from "@/components/admin/connectors/Popup";
import { FiCheckCircle, FiEdit, FiXCircle } from "react-icons/fi";
import { TrashIcon } from "@/components/icons/icons";
import { deleteCustomTool } from "@/lib/tools/edit";
// Admin table listing all tools. Built-in tools (in_code_tool_id !== null)
// are read-only; custom tools get edit and delete affordances.
export function ToolsTable({ tools }: { tools: ToolSnapshot[] }) {
  const router = useRouter();
  const { popup, setPopup } = usePopup();

  // sort a copy by id so the prop array is not mutated
  const sortedTools = [...tools];
  sortedTools.sort((a, b) => a.id - b.id);

  return (
    <div>
      {popup}

      <Table>
        <TableHead>
          <TableRow>
            <TableHeaderCell>Name</TableHeaderCell>
            <TableHeaderCell>Description</TableHeaderCell>
            <TableHeaderCell>Built In?</TableHeaderCell>
            <TableHeaderCell>Delete</TableHeaderCell>
          </TableRow>
        </TableHead>
        <TableBody>
          {sortedTools.map((tool) => (
            <TableRow key={tool.id.toString()}>
              <TableCell>
                <div className="flex">
                  {/* only custom tools are editable */}
                  {tool.in_code_tool_id === null && (
                    <FiEdit
                      className="mr-1 my-auto cursor-pointer"
                      onClick={() =>
                        router.push(
                          `/admin/tools/edit/${tool.id}?u=${Date.now()}`
                        )
                      }
                    />
                  )}
                  <p className="text font-medium whitespace-normal break-none">
                    {tool.name}
                  </p>
                </div>
              </TableCell>
              <TableCell className="whitespace-normal break-all max-w-2xl">
                {tool.description}
              </TableCell>
              <TableCell className="whitespace-nowrap">
                {tool.in_code_tool_id === null ? (
                  <span>
                    <FiXCircle className="inline-block mr-1 my-auto" />
                    No
                  </span>
                ) : (
                  <span>
                    <FiCheckCircle className="inline-block mr-1 my-auto" />
                    Yes
                  </span>
                )}
              </TableCell>
              <TableCell className="whitespace-nowrap">
                <div className="flex">
                  {/* only custom tools can be deleted */}
                  {tool.in_code_tool_id === null ? (
                    <div className="my-auto">
                      <div
                        className="hover:bg-hover rounded p-1 cursor-pointer"
                        onClick={async () => {
                          const response = await deleteCustomTool(tool.id);
                          if (response.data) {
                            router.refresh();
                          } else {
                            setPopup({
                              message: `Failed to delete tool - ${response.error}`,
                              type: "error",
                            });
                          }
                        }}
                      >
                        <TrashIcon />
                      </div>
                    </div>
                  ) : (
                    "-"
                  )}
                </div>
              </TableCell>
            </TableRow>
          ))}
        </TableBody>
      </Table>
    </div>
  );
}

View File

@ -0,0 +1,28 @@
"use client";
import { Button } from "@tremor/react";
import { FiTrash } from "react-icons/fi";
import { deleteCustomTool } from "@/lib/tools/edit";
import { useRouter } from "next/navigation";
// Button that permanently deletes a custom tool, then returns to the
// tools list (the `u` query param busts any client-side cache).
export function DeleteToolButton({ toolId }: { toolId: number }) {
  const router = useRouter();

  return (
    <Button
      size="xs"
      color="red"
      onClick={async () => {
        const response = await deleteCustomTool(toolId);
        if (response.data) {
          router.push(`/admin/tools?u=${Date.now()}`);
        } else {
          alert(`Failed to delete tool - ${response.error}`);
        }
      }}
      icon={FiTrash}
    >
      Delete
    </Button>
  );
}

View File

@ -0,0 +1,55 @@
import { ErrorCallout } from "@/components/ErrorCallout";
import { Card, Text, Title } from "@tremor/react";
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
import { fetchToolByIdSS } from "@/lib/tools/fetchTools";
import { DeleteToolButton } from "./DeleteToolButton";
import { FiTool } from "react-icons/fi";
import { AdminPageTitle } from "@/components/admin/Title";
import { BackButton } from "@/components/BackButton";
// Server component for the "Edit Tool" admin page. Loads the tool by id;
// shows an error callout when it does not exist.
export default async function Page({ params }: { params: { toolId: string } }) {
  const tool = await fetchToolByIdSS(params.toolId);

  let body;
  if (!tool) {
    body = (
      <div>
        <ErrorCallout
          errorTitle="Something went wrong :("
          errorMsg="Tool not found"
        />
      </div>
    );
  } else {
    body = (
      <div className="w-full my-8">
        <div>
          <div>
            <Card>
              <ToolEditor tool={tool} />
            </Card>

            {/* destructive action kept separate from the edit form */}
            <Title className="mt-12">Delete Tool</Title>
            <Text>Click the button below to permanently delete this tool.</Text>
            <div className="flex mt-6">
              <DeleteToolButton toolId={tool.id} />
            </div>
          </div>
        </div>
      </div>
    );
  }

  return (
    <div className="mx-auto container">
      <BackButton />

      <AdminPageTitle
        title="Edit Tool"
        icon={<FiTool size={32} className="my-auto" />}
      />

      {body}
    </div>
  );
}

View File

@ -0,0 +1,24 @@
"use client";
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
import { BackButton } from "@/components/BackButton";
import { AdminPageTitle } from "@/components/admin/Title";
import { Card } from "@tremor/react";
import { FiTool } from "react-icons/fi";
// "Create Tool" admin page: renders an empty ToolEditor (no `tool` prop
// means the editor operates in create mode).
export default function NewToolPage() {
  return (
    <div className="mx-auto container">
      <BackButton />

      <AdminPageTitle
        title="Create Tool"
        icon={<FiTool size={32} className="my-auto" />}
      />

      <Card>
        <ToolEditor />
      </Card>
    </div>
  );
}

View File

@ -0,0 +1,68 @@
import { ToolsTable } from "./ToolsTable";
import { ToolSnapshot } from "@/lib/tools/interfaces";
import { FiPlusSquare, FiTool } from "react-icons/fi";
import Link from "next/link";
import { Divider, Text, Title } from "@tremor/react";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import { AdminPageTitle } from "@/components/admin/Title";
// Server component for the tools overview admin page: fetches all tools
// and renders the create link plus the existing-tools table.
export default async function Page() {
  const toolResponse = await fetchSS("/tool");

  // surface fetch failures inline rather than crashing the page
  if (!toolResponse.ok) {
    return (
      <ErrorCallout
        errorTitle="Something went wrong :("
        errorMsg={`Failed to fetch tools - ${await toolResponse.text()}`}
      />
    );
  }

  const tools = (await toolResponse.json()) as ToolSnapshot[];

  return (
    <div className="mx-auto container">
      <AdminPageTitle
        icon={<FiTool size={32} className="my-auto" />}
        title="Tools"
      />

      <Text className="mb-2">
        Tools allow assistants to retrieve information or take actions.
      </Text>

      <div>
        <Divider />

        <Title>Create a Tool</Title>
        <Link
          href="/admin/tools/new"
          className="
            flex
            py-2
            px-4
            mt-2
            border
            border-border
            h-fit
            cursor-pointer
            hover:bg-hover
            text-sm
            w-40
          "
        >
          <div className="mx-auto flex">
            <FiPlusSquare className="my-auto mr-2" />
            New Tool
          </div>
        </Link>

        <Divider />

        <Title>Existing Tools</Title>
        <ToolsTable tools={tools} />
      </div>
    </div>
  );
}

View File

@ -12,7 +12,8 @@ import {
Message,
RetrievalType,
StreamingError,
ToolRunKickoff,
ToolCallFinalResult,
ToolCallMetadata,
} from "./interfaces";
import { ChatSidebar } from "./sessionSidebar/ChatSidebar";
import { Persona } from "../admin/assistants/interfaces";
@ -265,6 +266,7 @@ export function ChatPage({
message: "",
type: "system",
files: [],
toolCalls: [],
parentMessageId: null,
childrenMessageIds: [firstMessageId],
latestChildMessageId: firstMessageId,
@ -307,7 +309,6 @@ export function ChatPage({
return newCompleteMessageMap;
};
const messageHistory = buildLatestMessageChain(completeMessageMap);
const [currentTool, setCurrentTool] = useState<string | null>(null);
const [isStreaming, setIsStreaming] = useState(false);
// uploaded files
@ -535,6 +536,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: parentMessage?.messageId || null,
},
];
@ -569,6 +571,7 @@ export function ChatPage({
let aiMessageImages: FileDescriptor[] | null = null;
let error: string | null = null;
let finalMessage: BackendMessage | null = null;
let toolCalls: ToolCallMetadata[] = [];
try {
const lastSuccessfulMessageId =
getLastSuccessfulMessageId(currMessageHistory);
@ -627,7 +630,13 @@ export function ChatPage({
}
);
} else if (Object.hasOwn(packet, "tool_name")) {
setCurrentTool((packet as ToolRunKickoff).tool_name);
toolCalls = [
{
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
},
];
} else if (Object.hasOwn(packet, "error")) {
error = (packet as StreamingError).error;
} else if (Object.hasOwn(packet, "message_id")) {
@ -657,6 +666,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: parentMessage?.messageId || null,
childrenMessageIds: [newAssistantMessageId],
latestChildMessageId: newAssistantMessageId,
@ -670,6 +680,7 @@ export function ChatPage({
documents: finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: newUserMessageId,
},
]);
@ -687,6 +698,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
{
@ -694,6 +706,7 @@ export function ChatPage({
message: errorMsg,
type: "error",
files: aiMessageImages || [],
toolCalls: [],
parentMessageId: TEMP_USER_MESSAGE_ID,
},
],
@ -1031,7 +1044,7 @@ export function ChatPage({
citedDocuments={getCitedDocumentsFromMessage(
message
)}
currentTool={currentTool}
toolCall={message.toolCalls[0]}
isComplete={
i !== messageHistory.length - 1 ||
!isStreaming

View File

@ -34,6 +34,18 @@ export interface FileDescriptor {
isUploading?: boolean;
}
// Tool-call info streamed from the backend while a message is generating;
// `tool_result` is optional because packets can arrive before the tool
// has finished running.
export interface ToolCallMetadata {
  tool_name: string;
  tool_args: Record<string, any>;
  tool_result?: Record<string, any>;
}
// Completed tool call as attached to a BackendMessage (`tool_calls`) —
// the result is always present here.
export interface ToolCallFinalResult {
  tool_name: string;
  tool_args: Record<string, any>;
  tool_result: Record<string, any>;
}
export interface ChatSession {
id: number;
name: string;
@ -52,6 +64,7 @@ export interface Message {
documents?: DanswerDocument[] | null;
citations?: CitationMap;
files: FileDescriptor[];
toolCalls: ToolCallMetadata[];
// for rebuilding the message tree
parentMessageId: number | null;
childrenMessageIds?: number[];
@ -79,6 +92,7 @@ export interface BackendMessage {
time_sent: string;
citations: CitationMap;
files: FileDescriptor[];
tool_calls: ToolCallFinalResult[];
}
export interface DocumentsResponse {
@ -90,11 +104,6 @@ export interface ImageGenerationDisplay {
file_ids: string[];
}
export interface ToolRunKickoff {
tool_name: string;
tool_args: Record<string, any>;
}
// Streamed packet emitted when the backend hits an error mid-response.
export interface StreamingError {
  error: string;
}

View File

@ -15,7 +15,7 @@ import {
Message,
RetrievalType,
StreamingError,
ToolRunKickoff,
ToolCallMetadata,
} from "./interfaces";
import { Persona } from "../admin/assistants/interfaces";
import { ReadonlyURLSearchParams } from "next/navigation";
@ -138,7 +138,7 @@ export async function* sendMessage({
| DocumentsResponse
| BackendMessage
| ImageGenerationDisplay
| ToolRunKickoff
| ToolCallMetadata
| StreamingError
>(sendMessageResponse);
}
@ -384,6 +384,7 @@ export function processRawChatHistory(
citations: messageInfo?.citations || {},
}
: {}),
toolCalls: messageInfo.tool_calls,
parentMessageId: messageInfo.parent_message,
childrenMessageIds: [],
latestChildMessageId: messageInfo.latest_child_message,

View File

@ -9,6 +9,7 @@ import {
FiEdit2,
FiChevronRight,
FiChevronLeft,
FiTool,
} from "react-icons/fi";
import { FeedbackType } from "../types";
import { useEffect, useRef, useState } from "react";
@ -20,9 +21,12 @@ import { ThreeDots } from "react-loader-spinner";
import { SkippedSearch } from "./SkippedSearch";
import remarkGfm from "remark-gfm";
import { CopyButton } from "@/components/CopyButton";
import { ChatFileType, FileDescriptor } from "../interfaces";
import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants";
import { ToolRunningAnimation } from "../tools/ToolRunningAnimation";
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
import {
IMAGE_GENERATION_TOOL_NAME,
SEARCH_TOOL_NAME,
} from "../tools/constants";
import { ToolRunDisplay } from "../tools/ToolRunningAnimation";
import { Hoverable } from "@/components/Hoverable";
import { DocumentPreview } from "../files/documents/DocumentPreview";
import { InMessageImage } from "../files/images/InMessageImage";
@ -35,6 +39,11 @@ import Prism from "prismjs";
import "prismjs/themes/prism-tomorrow.css";
import "./custom-code-styles.css";
// Tools that get bespoke UI treatment in this component (search summary,
// image-generation loader) and therefore skip the generic ToolRunDisplay.
const TOOLS_WITH_CUSTOM_HANDLING = [
  SEARCH_TOOL_NAME,
  IMAGE_GENERATION_TOOL_NAME,
];
function FileDisplay({ files }: { files: FileDescriptor[] }) {
const imageFiles = files.filter((file) => file.type === ChatFileType.IMAGE);
const nonImgFiles = files.filter((file) => file.type !== ChatFileType.IMAGE);
@ -77,7 +86,7 @@ export const AIMessage = ({
query,
personaName,
citedDocuments,
currentTool,
toolCall,
isComplete,
hasDocs,
handleFeedback,
@ -93,7 +102,7 @@ export const AIMessage = ({
query?: string;
personaName?: string;
citedDocuments?: [string, DanswerDocument][] | null;
currentTool?: string | null;
toolCall?: ToolCallMetadata;
isComplete?: boolean;
hasDocs?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;
@ -133,28 +142,23 @@ export const AIMessage = ({
content = trimIncompleteCodeSection(content);
}
const loader =
currentTool === IMAGE_GENERATION_TOOL_NAME ? (
<div className="text-sm my-auto">
<ToolRunningAnimation
toolName="Generating images"
toolLogo={<FiImage size={16} className="my-auto mr-1" />}
/>
</div>
) : (
<div className="text-sm my-auto">
<ThreeDots
height="30"
width="50"
color="#3b82f6"
ariaLabel="grid-loading"
radius="12.5"
wrapperStyle={{}}
wrapperClass=""
visible={true}
/>
</div>
);
const shouldShowLoader =
!toolCall ||
(toolCall.tool_name === SEARCH_TOOL_NAME && query === undefined);
const defaultLoader = shouldShowLoader ? (
<div className="text-sm my-auto">
<ThreeDots
height="30"
width="50"
color="#3b82f6"
ariaLabel="grid-loading"
radius="12.5"
wrapperStyle={{}}
wrapperClass=""
visible={true}
/>
</div>
) : undefined;
return (
<div className={"py-5 px-5 flex -mr-6 w-full"}>
@ -189,28 +193,61 @@ export const AIMessage = ({
</div>
<div className="w-message-xs 2xl:w-message-sm 3xl:w-message-default break-words mt-1 ml-8">
{query !== undefined &&
handleShowRetrieved !== undefined &&
isCurrentlyShowingRetrieved !== undefined &&
!retrievalDisabled && (
<div className="my-1">
<SearchSummary
query={query}
hasDocs={hasDocs || false}
messageId={messageId}
isCurrentlyShowingRetrieved={isCurrentlyShowingRetrieved}
handleShowRetrieved={handleShowRetrieved}
handleSearchQueryEdit={handleSearchQueryEdit}
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
<>
{query !== undefined &&
handleShowRetrieved !== undefined &&
isCurrentlyShowingRetrieved !== undefined &&
!retrievalDisabled && (
<div className="my-1">
<SearchSummary
query={query}
hasDocs={hasDocs || false}
messageId={messageId}
isCurrentlyShowingRetrieved={
isCurrentlyShowingRetrieved
}
handleShowRetrieved={handleShowRetrieved}
handleSearchQueryEdit={handleSearchQueryEdit}
/>
</div>
)}
{handleForceSearch &&
content &&
query === undefined &&
!hasDocs &&
!retrievalDisabled && (
<div className="my-1">
<SkippedSearch handleForceSearch={handleForceSearch} />
</div>
)}
</>
)}
{toolCall &&
!TOOLS_WITH_CUSTOM_HANDLING.includes(toolCall.tool_name) && (
<div className="my-2">
<ToolRunDisplay
toolName={
toolCall.tool_result && content
? `Used "${toolCall.tool_name}"`
: `Using "${toolCall.tool_name}"`
}
toolLogo={<FiTool size={15} className="my-auto mr-1" />}
isRunning={!toolCall.tool_result || !content}
/>
</div>
)}
{handleForceSearch &&
content &&
query === undefined &&
!hasDocs &&
!retrievalDisabled && (
<div className="my-1">
<SkippedSearch handleForceSearch={handleForceSearch} />
{toolCall &&
toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME &&
!toolCall.tool_result && (
<div className="my-2">
<ToolRunDisplay
toolName={`Generating images`}
toolLogo={<FiImage size={15} className="my-auto mr-1" />}
isRunning={!toolCall.tool_result}
/>
</div>
)}
@ -260,7 +297,7 @@ export const AIMessage = ({
)}
</>
) : isComplete ? null : (
loader
defaultLoader
)}
{citedDocuments && citedDocuments.length > 0 && (
<div className="mt-2">

View File

@ -1,16 +1,18 @@
import { LoadingAnimation } from "@/components/Loading";
// Generic inline indicator for a tool call: shows an optional logo plus
// either an animated "running" label or the static tool name once done.
// (Reconstructed: the source span interleaved old `ToolRunningAnimation`
// and new `ToolRunDisplay` diff lines; this is the post-change version.)
export function ToolRunDisplay({
  toolName,
  toolLogo,
  isRunning,
}: {
  toolName: string;
  toolLogo?: JSX.Element;
  isRunning: boolean;
}) {
  return (
    <div className="text-sm text-subtle my-auto flex">
      {toolLogo}
      {isRunning ? <LoadingAnimation text={toolName} /> : toolName}
    </div>
  );
}

View File

@ -2,15 +2,12 @@ import { Header } from "@/components/header/Header";
import { AdminSidebar } from "@/components/admin/connectors/AdminSidebar";
import {
NotebookIcon,
KeyIcon,
UsersIcon,
ThumbsUpIcon,
BookmarkIcon,
CPUIcon,
ZoomInIcon,
RobotIcon,
ConnectorIcon,
SlackIcon,
} from "@/components/icons/icons";
import { User } from "@/lib/types";
import {
@ -19,13 +16,7 @@ import {
getCurrentUserSS,
} from "@/lib/userSS";
import { redirect } from "next/navigation";
import {
FiCpu,
FiLayers,
FiPackage,
FiSettings,
FiSlack,
} from "react-icons/fi";
import { FiCpu, FiPackage, FiSettings, FiSlack, FiTool } from "react-icons/fi";
export async function Layout({ children }: { children: React.ReactNode }) {
const tasks = [getAuthTypeMetadataSS(), getCurrentUserSS()];
@ -142,6 +133,15 @@ export async function Layout({ children }: { children: React.ReactNode }) {
),
link: "/admin/bot",
},
{
name: (
<div className="flex">
<FiTool size={18} className="my-auto" />
<div className="ml-1">Tools</div>
</div>
),
link: "/admin/tools",
},
],
},
{

View File

@ -43,6 +43,10 @@ export function TextFormField({
disabled = false,
autoCompleteDisabled = true,
error,
defaultHeight,
isCode = false,
fontSize,
hideError,
}: {
name: string;
label: string;
@ -54,7 +58,16 @@ export function TextFormField({
disabled?: boolean;
autoCompleteDisabled?: boolean;
error?: string;
defaultHeight?: string;
isCode?: boolean;
fontSize?: "text-sm" | "text-base" | "text-lg";
hideError?: boolean;
}) {
let heightString = defaultHeight || "";
if (isTextArea && !heightString) {
heightString = "h-28";
}
return (
<div className="mb-4">
<Label>{label}</Label>
@ -64,18 +77,19 @@ export function TextFormField({
type={type}
name={name}
id={name}
className={
`
border
border-border
rounded
w-full
py-2
px-3
mt-1
${isTextArea ? " h-28" : ""}
` + (disabled ? " bg-background-strong" : " bg-background-emphasis")
}
className={`
border
border-border
rounded
w-full
py-2
px-3
mt-1
${heightString}
${fontSize}
${disabled ? " bg-background-strong" : " bg-background-emphasis"}
${isCode ? " font-mono" : ""}
`}
disabled={disabled}
placeholder={placeholder}
autoComplete={autoCompleteDisabled ? "off" : undefined}
@ -84,11 +98,13 @@ export function TextFormField({
{error ? (
<ManualErrorMessage>{error}</ManualErrorMessage>
) : (
<ErrorMessage
name={name}
component="div"
className="text-red-500 text-sm mt-1"
/>
!hideError && (
<ErrorMessage
name={name}
component="div"
className="text-red-500 text-sm mt-1"
/>
)
)}
</div>
);

111
web/src/lib/tools/edit.ts Normal file
View File

@ -0,0 +1,111 @@
import { MethodSpec, ToolSnapshot } from "./interfaces";
// Uniform result wrapper for the custom-tool CRUD helpers: `data` is
// populated on success, `error` holds a human-readable message on failure.
interface ApiResponse<T> {
  data: T | null;
  error: string | null;
}
export async function createCustomTool(toolData: {
  name: string;
  description?: string;
  definition: Record<string, any>;
}): Promise<ApiResponse<ToolSnapshot>> {
  // POST the new custom tool (with its OpenAPI definition) to the admin API.
  try {
    const response = await fetch("/api/admin/tool/custom", {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(toolData),
    });

    // Success path first; error responses carry a `detail` field.
    if (response.ok) {
      const tool = (await response.json()) as ToolSnapshot;
      return { data: tool, error: null };
    }

    const errorDetail = (await response.json()).detail;
    return { data: null, error: `Failed to create tool: ${errorDetail}` };
  } catch (error) {
    console.error("Error creating tool:", error);
    return { data: null, error: "Error creating tool" };
  }
}
export async function updateCustomTool(
  toolId: number,
  toolData: {
    name?: string;
    description?: string;
    definition?: Record<string, any>;
  }
): Promise<ApiResponse<ToolSnapshot>> {
  // PUT a partial update for an existing custom tool.
  try {
    const response = await fetch(`/api/admin/tool/custom/${toolId}`, {
      method: "PUT",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(toolData),
    });

    // Success path first; error responses carry a `detail` field.
    if (response.ok) {
      const updatedTool = (await response.json()) as ToolSnapshot;
      return { data: updatedTool, error: null };
    }

    const errorDetail = (await response.json()).detail;
    return { data: null, error: `Failed to update tool: ${errorDetail}` };
  } catch (error) {
    console.error("Error updating tool:", error);
    return { data: null, error: "Error updating tool" };
  }
}
export async function deleteCustomTool(
  toolId: number
): Promise<ApiResponse<boolean>> {
  // DELETE a custom tool; `data` is true on success, false otherwise.
  try {
    const response = await fetch(`/api/admin/tool/custom/${toolId}`, {
      method: "DELETE",
      headers: { "Content-Type": "application/json" },
    });

    if (response.ok) {
      return { data: true, error: null };
    }

    // Error responses carry a `detail` field with the reason.
    const errorDetail = (await response.json()).detail;
    return { data: false, error: `Failed to delete tool: ${errorDetail}` };
  } catch (error) {
    console.error("Error deleting tool:", error);
    return { data: false, error: "Error deleting tool" };
  }
}
export async function validateToolDefinition(toolData: {
  definition: Record<string, any>;
}): Promise<ApiResponse<MethodSpec[]>> {
  // Ask the backend to validate/parse the OpenAPI definition; on success
  // the response's `methods` field lists the callable method specs.
  try {
    const response = await fetch(`/api/admin/tool/custom/validate`, {
      method: "POST",
      headers: { "Content-Type": "application/json" },
      body: JSON.stringify(toolData),
    });

    if (response.ok) {
      const responseJson = await response.json();
      return { data: responseJson.methods, error: null };
    }

    // Validation failures return the raw `detail` as the error message.
    const errorDetail = (await response.json()).detail;
    return { data: null, error: errorDetail };
  } catch (error) {
    console.error("Error validating tool:", error);
    return { data: null, error: "Unexpected error validating tool definition" };
  }
}

View File

@ -14,3 +14,21 @@ export async function fetchToolsSS(): Promise<ToolSnapshot[] | null> {
return null;
}
}
export async function fetchToolByIdSS(
  toolId: string
): Promise<ToolSnapshot | null> {
  // Server-side fetch of a single tool; any failure is logged and
  // resolves to null rather than propagating.
  try {
    const response = await fetchSS(`/tool/${toolId}`);
    if (!response.ok) {
      throw new Error(
        `Failed to fetch tool with ID ${toolId}: ${await response.text()}`
      );
    }
    return (await response.json()) as ToolSnapshot;
  } catch (error) {
    console.error(`Error fetching tool with ID ${toolId}:`, error);
    return null;
  }
}

View File

@ -2,5 +2,21 @@ export interface ToolSnapshot {
id: number;
name: string;
description: string;
// only specified for Custom Tools. OpenAPI schema which represents
// the tool's API.
definition: Record<string, any> | null;
// only specified for Custom Tools. ID of the tool in the codebase.
in_code_tool_id: string | null;
}
export interface MethodSpec {
/* Defines a single method that is part of a custom tool. Each method maps to a single
action that the LLM can choose to take. */
name: string;
summary: string;
path: string;
method: string;
spec: Record<string, any>;
}