diff --git a/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py new file mode 100644 index 00000000000..514389ac216 --- /dev/null +++ b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py @@ -0,0 +1,61 @@ +"""Add support for custom tools + +Revision ID: 48d14957fe80 +Revises: b85f02ec1308 +Create Date: 2024-06-09 14:58:19.946509 + +""" +from alembic import op +import fastapi_users_db_sqlalchemy +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "48d14957fe80" +down_revision = "b85f02ec1308" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "tool", + sa.Column( + "openapi_schema", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "tool", + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=True, + ), + ) + op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"]) + + op.create_table( + "tool_call", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("tool_id", sa.Integer(), nullable=False), + sa.Column("tool_name", sa.String(), nullable=False), + sa.Column( + "tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.Column( + "tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.Column( + "message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False + ), + ) + + +def downgrade() -> None: + op.drop_table("tool_call") + + op.drop_constraint("tool_user_fk", "tool", type_="foreignkey") + op.drop_column("tool", "user_id") + op.drop_column("tool", "openapi_schema") diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 8fa5eecaeeb..7fc526a5cbe 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -106,12 +106,18 @@ class 
ImageGenerationDisplay(BaseModel): file_ids: list[str] +class CustomToolResponse(BaseModel): + response: dict + tool_name: str + + AnswerQuestionPossibleReturn = ( DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | ImageGenerationDisplay + | CustomToolResponse | StreamingError ) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 72381542176..7733bf523d6 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import create_chat_chain from danswer.chat.models import CitationInfo +from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import ImageGenerationDisplay from danswer.chat.models import LlmDoc @@ -31,6 +32,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager from danswer.db.llm import fetch_existing_llm_providers from danswer.db.models import SearchDoc as DbSearchDoc +from danswer.db.models import ToolCall from danswer.db.models import User from danswer.document_index.factory import get_default_document_index from danswer.file_store.models import ChatFileType @@ -54,7 +56,10 @@ from danswer.search.utils import drop_llm_indices from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import get_json_line -from danswer.tools.factory import get_tool_cls +from danswer.tools.built_in_tools import get_built_in_tool_by_id +from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema +from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID +from danswer.tools.custom.custom_tool import CustomToolCallSummary from danswer.tools.force import ForceUseTool from 
danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID from danswer.tools.images.image_generation_tool import ImageGenerationResponse @@ -65,6 +70,7 @@ from danswer.tools.search.search_tool import SearchTool from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse +from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.utils import compute_all_tool_tokens from danswer.tools.utils import explicit_tool_calling_supported from danswer.utils.logger import setup_logger @@ -162,7 +168,7 @@ def _check_should_force_search( args = {"query": new_msg_req.message} return ForceUseTool( - tool_name=SearchTool.name(), + tool_name=SearchTool.NAME, args=args, ) return None @@ -176,6 +182,7 @@ ChatPacket = ( | DanswerAnswerPiece | CitationInfo | ImageGenerationDisplay + | CustomToolResponse ) ChatPacketStream = Iterator[ChatPacket] @@ -389,61 +396,78 @@ def stream_chat_message_objects( ), ) - persona_tool_classes = [ - get_tool_cls(tool, db_session) for tool in persona.tools - ] + # find out what tools to use + search_tool: SearchTool | None = None + tool_dict: dict[int, list[Tool]] = {} # tool_id to tool + for db_tool_model in persona.tools: + # handle in-code tools specially + if db_tool_model.in_code_tool_id: + tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session) + if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files: + search_tool = SearchTool( + db_session=db_session, + user=user, + persona=persona, + retrieval_options=retrieval_options, + prompt_config=prompt_config, + llm=llm, + pruning_config=document_pruning_config, + selected_docs=selected_llm_docs, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, + ) + tool_dict[db_tool_model.id] = [search_tool] + elif tool_cls.__name__ == ImageGenerationTool.__name__: + dalle_key = None + if ( + llm + and 
llm.config.api_key + and llm.config.model_provider == "openai" + ): + dalle_key = llm.config.api_key + else: + llm_providers = fetch_existing_llm_providers(db_session) + openai_provider = next( + iter( + [ + llm_provider + for llm_provider in llm_providers + if llm_provider.provider == "openai" + ] + ), + None, + ) + if not openai_provider or not openai_provider.api_key: + raise ValueError( + "Image generation tool requires an OpenAI API key" + ) + dalle_key = openai_provider.api_key + tool_dict[db_tool_model.id] = [ + ImageGenerationTool(api_key=dalle_key) + ] + + continue + + # handle all custom tools + if db_tool_model.openapi_schema: + tool_dict[db_tool_model.id] = cast( + list[Tool], + build_custom_tools_from_openapi_schema( + db_tool_model.openapi_schema + ), + ) + + tools: list[Tool] = [] + for tool_list in tool_dict.values(): + tools.extend(tool_list) # factor in tool definition size when pruning - document_pruning_config.tool_num_tokens = compute_all_tool_tokens( - persona_tool_classes - ) + document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools) document_pruning_config.using_tool_message = explicit_tool_calling_supported( llm.config.model_provider, llm.config.model_name ) - # NOTE: for now, only support SearchTool and ImageGenerationTool - # in the future, will support arbitrary user-defined tools - search_tool: SearchTool | None = None - tools: list[Tool] = [] - for tool_cls in persona_tool_classes: - if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files: - search_tool = SearchTool( - db_session=db_session, - user=user, - persona=persona, - retrieval_options=retrieval_options, - prompt_config=prompt_config, - llm=llm, - pruning_config=document_pruning_config, - selected_docs=selected_llm_docs, - chunks_above=new_msg_req.chunks_above, - chunks_below=new_msg_req.chunks_below, - full_doc=new_msg_req.full_doc, - ) - tools.append(search_tool) - elif tool_cls.__name__ == ImageGenerationTool.__name__: - dalle_key = None - if 
llm and llm.config.api_key and llm.config.model_provider == "openai": - dalle_key = llm.config.api_key - else: - llm_providers = fetch_existing_llm_providers(db_session) - openai_provider = next( - iter( - [ - llm_provider - for llm_provider in llm_providers - if llm_provider.provider == "openai" - ] - ), - None, - ) - if not openai_provider or not openai_provider.api_key: - raise ValueError( - "Image generation tool requires an OpenAI API key" - ) - dalle_key = openai_provider.api_key - tools.append(ImageGenerationTool(api_key=dalle_key)) - # LLM prompt building, response capturing, etc. answer = Answer( question=final_msg.message, @@ -468,7 +492,9 @@ def stream_chat_message_objects( ], tools=tools, force_use_tool=( - _check_should_force_search(new_msg_req) if search_tool else None + _check_should_force_search(new_msg_req) + if search_tool and len(tools) == 1 + else None ), ) @@ -476,6 +502,7 @@ def stream_chat_message_objects( qa_docs_response = None ai_message_files = None # any files to associate with the AI message e.g. 
dall-e generated images dropped_indices = None + tool_result = None for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): if packet.id == SEARCH_RESPONSE_SUMMARY_ID: @@ -521,8 +548,16 @@ def stream_chat_message_objects( yield ImageGenerationDisplay( file_ids=[str(file_id) for file_id in file_ids] ) + elif packet.id == CUSTOM_TOOL_RESPONSE_ID: + custom_tool_response = cast(CustomToolCallSummary, packet.response) + yield CustomToolResponse( + response=custom_tool_response.tool_result, + tool_name=custom_tool_response.tool_name, + ) else: + if isinstance(packet, ToolCallFinalResult): + tool_result = packet yield cast(ChatPacket, packet) except Exception as e: @@ -551,6 +586,11 @@ def stream_chat_message_objects( ) # Saving Gen AI answer and responding with message info + tool_name_to_tool_id: dict[str, int] = {} + for tool_id, tool_list in tool_dict.items(): + for tool in tool_list: + tool_name_to_tool_id[tool.name()] = tool_id + gen_ai_response_message = partial_response( message=answer.llm_answer, rephrased_query=( @@ -561,6 +601,16 @@ def stream_chat_message_objects( token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, error=None, + tool_calls=[ + ToolCall( + tool_id=tool_name_to_tool_id[tool_result.tool_name], + tool_name=tool_result.tool_name, + tool_arguments=tool_result.tool_args, + tool_result=tool_result.tool_result, + ) + ] + if tool_result + else [], ) db_session.commit() # actually save user / assistant message diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 58779288022..5ee9dfb3f51 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -5,6 +5,7 @@ from sqlalchemy import nullsfirst from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.exc import MultipleResultsFound +from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from danswer.auth.schemas import UserRole @@ -16,6 +17,7 @@ from danswer.db.models 
import ChatSessionSharedStatus from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc +from danswer.db.models import ToolCall from danswer.db.models import User from danswer.file_store.models import FileDescriptor from danswer.llm.override_models import LLMOverride @@ -24,6 +26,7 @@ from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc from danswer.search.models import SearchDoc as ServerSearchDoc from danswer.server.query_and_chat.models import ChatMessageDetail +from danswer.tools.tool_runner import ToolCallFinalResult from danswer.utils.logger import setup_logger logger = setup_logger() @@ -185,6 +188,7 @@ def get_chat_messages_by_session( user_id: UUID | None, db_session: Session, skip_permission_check: bool = False, + prefetch_tool_calls: bool = False, ) -> list[ChatMessage]: if not skip_permission_check: get_chat_session_by_id( @@ -192,12 +196,18 @@ def get_chat_messages_by_session( ) stmt = ( - select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id) - # Start with the root message which has no parent + select(ChatMessage) + .where(ChatMessage.chat_session_id == chat_session_id) .order_by(nullsfirst(ChatMessage.parent_message)) ) - result = db_session.execute(stmt).scalars().all() + if prefetch_tool_calls: + stmt = stmt.options(joinedload(ChatMessage.tool_calls)) + + if prefetch_tool_calls: + result = db_session.scalars(stmt).unique().all() + else: + result = db_session.scalars(stmt).all() return list(result) @@ -251,6 +261,7 @@ def create_new_chat_message( reference_docs: list[DBSearchDoc] | None = None, # Maps the citation number [n] to the DB SearchDoc citations: dict[int, int] | None = None, + tool_calls: list[ToolCall] | None = None, commit: bool = True, ) -> ChatMessage: new_chat_message = ChatMessage( @@ -264,6 +275,7 @@ def create_new_chat_message( message_type=message_type, citations=citations, files=files, + 
tool_calls=tool_calls if tool_calls else [], error=error, ) @@ -459,6 +471,14 @@ def translate_db_message_to_chat_message_detail( time_sent=chat_message.time_sent, citations=chat_message.citations, files=chat_message.files or [], + tool_calls=[ + ToolCallFinalResult( + tool_name=tool_call.tool_name, + tool_args=tool_call.tool_arguments, + tool_result=tool_call.tool_result, + ) + for tool_call in chat_message.tool_calls + ], ) return chat_msg_detail diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3b4b67f0947..661eb04a8c8 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -133,6 +133,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base): prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user") # Personas owned by this user personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") + # Custom tools created by this user + custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user") class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): @@ -330,7 +332,6 @@ class Document(Base): primary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) - # Something like assignee or space owner secondary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) @@ -618,6 +619,26 @@ class SearchDoc(Base): ) +class ToolCall(Base): + """Represents a single tool call""" + + __tablename__ = "tool_call" + + id: Mapped[int] = mapped_column(primary_key=True) + # not a FK because we want to be able to delete the tool without deleting + # this entry + tool_id: Mapped[int] = mapped_column(Integer()) + tool_name: Mapped[str] = mapped_column(String()) + tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB()) + tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB()) + + message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + + message: 
Mapped["ChatMessage"] = relationship( + "ChatMessage", back_populates="tool_calls" + ) + + class ChatSession(Base): __tablename__ = "chat_session" @@ -723,6 +744,10 @@ class ChatMessage(Base): secondary="chat_message__search_doc", back_populates="chat_messages", ) + tool_calls: Mapped[list["ToolCall"]] = relationship( + "ToolCall", + back_populates="message", + ) class ChatFolder(Base): @@ -901,9 +926,18 @@ class Tool(Base): name: Mapped[str] = mapped_column(String, nullable=False) description: Mapped[str] = mapped_column(Text, nullable=True) # ID of the tool in the codebase, only applies for in-code tools. - # tools defiend via the UI will have this as None + # tools defined via the UI will have this as None in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True) + # OpenAPI scheme for the tool. Only applies to tools defined via the UI. + openapi_schema: Mapped[dict[str, Any] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) + + # user who created / owns the tool. Will be None for built-in tools. 
+ user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + + user: Mapped[User | None] = relationship("User", back_populates="custom_tools") # Relationship to Persona through the association table personas: Mapped[list["Persona"]] = relationship( "Persona", diff --git a/backend/danswer/db/tools.py b/backend/danswer/db/tools.py new file mode 100644 index 00000000000..1e75b1c4901 --- /dev/null +++ b/backend/danswer/db/tools.py @@ -0,0 +1,74 @@ +from typing import Any +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.models import Tool +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def get_tools(db_session: Session) -> list[Tool]: + return list(db_session.scalars(select(Tool)).all()) + + +def get_tool_by_id(tool_id: int, db_session: Session) -> Tool: + tool = db_session.scalar(select(Tool).where(Tool.id == tool_id)) + if not tool: + raise ValueError("Tool by specified id does not exist") + return tool + + +def create_tool( + name: str, + description: str | None, + openapi_schema: dict[str, Any] | None, + user_id: UUID | None, + db_session: Session, +) -> Tool: + new_tool = Tool( + name=name, + description=description, + in_code_tool_id=None, + openapi_schema=openapi_schema, + user_id=user_id, + ) + db_session.add(new_tool) + db_session.commit() + return new_tool + + +def update_tool( + tool_id: int, + name: str | None, + description: str | None, + openapi_schema: dict[str, Any] | None, + user_id: UUID | None, + db_session: Session, +) -> Tool: + tool = get_tool_by_id(tool_id, db_session) + if tool is None: + raise ValueError(f"Tool with ID {tool_id} does not exist") + + if name is not None: + tool.name = name + if description is not None: + tool.description = description + if openapi_schema is not None: + tool.openapi_schema = openapi_schema + if user_id is not None: + tool.user_id = user_id + db_session.commit() + + return tool + + +def 
delete_tool(tool_id: int, db_session: Session) -> None: + tool = get_tool_by_id(tool_id, db_session) + if tool is None: + raise ValueError(f"Tool with ID {tool_id} does not exist") + + db_session.delete(tool) + db_session.commit() diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index 82c027304ae..4b849f70d96 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -24,7 +24,7 @@ def load_chat_file( file_id=file_descriptor["id"], content=file_io.read(), file_type=file_descriptor["type"], - filename=file_descriptor["name"], + filename=file_descriptor.get("name"), ) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 41f8e109033..328bb7251da 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -4,6 +4,7 @@ from uuid import uuid4 from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk +from langchain_core.messages import HumanMessage from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import AnswerQuestionPossibleReturn @@ -33,6 +34,9 @@ from danswer.llm.answering.stream_processing.quotes_processing import ( from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import message_generator_to_string_generator +from danswer.tools.custom.custom_tool_prompt_builder import ( + build_user_message_for_custom_tool_for_non_tool_calling_llm, +) from danswer.tools.force import filter_tools_for_force_tool_use from danswer.tools.force import ForceUseTool from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID @@ -50,7 +54,8 @@ from danswer.tools.tool import ToolResponse from danswer.tools.tool_runner import ( check_which_tools_should_run_for_non_tool_calling_llm, ) -from danswer.tools.tool_runner import ToolRunKickoff +from 
danswer.tools.tool_runner import ToolCallFinalResult +from danswer.tools.tool_runner import ToolCallKickoff from danswer.tools.tool_runner import ToolRunner from danswer.tools.utils import explicit_tool_calling_supported @@ -72,7 +77,7 @@ def _get_answer_stream_processor( raise RuntimeError("Not implemented yet") -AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolRunKickoff | ToolResponse] +AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse] class Answer: @@ -125,7 +130,7 @@ class Answer: self._streamed_output: list[str] | None = None self._processed_stream: list[ - AnswerQuestionPossibleReturn | ToolResponse | ToolRunKickoff + AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff ] | None = None def _update_prompt_builder_for_search_tool( @@ -160,7 +165,7 @@ class Answer: def _raw_output_for_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolRunKickoff | ToolResponse]: + ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) tool_call_chunk: AIMessageChunk | None = None @@ -237,16 +242,18 @@ class Answer: ), ) - if tool.name() == SearchTool.name(): + if tool.name() == SearchTool.NAME: self._update_prompt_builder_for_search_tool(prompt_builder, []) - elif tool.name() == ImageGenerationTool.name(): + elif tool.name() == ImageGenerationTool.NAME: prompt_builder.update_user_prompt( build_image_generation_user_prompt( query=self.question, ) ) - prompt = prompt_builder.build(tool_call_summary=tool_call_summary) + yield tool_runner.tool_final_result() + + prompt = prompt_builder.build(tool_call_summary=tool_call_summary) yield from message_generator_to_string_generator( self.llm.stream( prompt=prompt, @@ -258,7 +265,7 @@ class Answer: def _raw_output_for_non_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolRunKickoff | ToolResponse]: + ) -> Iterator[str | ToolCallKickoff | ToolResponse | 
ToolCallFinalResult]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) chosen_tool_and_args: tuple[Tool, dict] | None = None @@ -324,7 +331,7 @@ class Answer: tool_runner = ToolRunner(tool, tool_args) yield tool_runner.kickoff() - if tool.name() == SearchTool.name(): + if tool.name() == SearchTool.NAME: final_context_documents = None for response in tool_runner.tool_responses(): if response.id == FINAL_CONTEXT_DOCUMENTS: @@ -337,7 +344,7 @@ class Answer: self._update_prompt_builder_for_search_tool( prompt_builder, final_context_documents ) - elif tool.name() == ImageGenerationTool.name(): + elif tool.name() == ImageGenerationTool.NAME: img_urls = [] for response in tool_runner.tool_responses(): if response.id == IMAGE_GENERATION_RESPONSE_ID: @@ -354,6 +361,18 @@ class Answer: img_urls=img_urls, ) ) + else: + prompt_builder.update_user_prompt( + HumanMessage( + content=build_user_message_for_custom_tool_for_non_tool_calling_llm( + self.question, + tool.name(), + *tool_runner.tool_responses(), + ) + ) + ) + + yield tool_runner.tool_final_result() prompt = prompt_builder.build() yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt)) @@ -374,7 +393,7 @@ class Answer: ) def _process_stream( - stream: Iterator[ToolRunKickoff | ToolResponse | str], + stream: Iterator[ToolCallKickoff | ToolResponse | str], ) -> AnswerStream: message = None @@ -387,7 +406,9 @@ class Answer: ] | None = None # processed docs to feed into the LLM for message in stream: - if isinstance(message, ToolRunKickoff): + if isinstance(message, ToolCallKickoff) or isinstance( + message, ToolCallFinalResult + ): yield message elif isinstance(message, ToolResponse): if message.id == SEARCH_RESPONSE_SUMMARY_ID: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 8757d6d5588..e8988606714 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -63,6 +63,7 @@ from danswer.server.features.folder.api import router as 
folder_router from danswer.server.features.persona.api import admin_router as admin_persona_router from danswer.server.features.persona.api import basic_router as persona_router from danswer.server.features.prompt.api import basic_router as prompt_router +from danswer.server.features.tool.api import admin_router as admin_tool_router from danswer.server.features.tool.api import router as tool_router from danswer.server.gpts.api import router as gpts_router from danswer.server.manage.administrative import router as admin_router @@ -277,6 +278,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, admin_persona_router) include_router_with_global_prefix_prepended(application, prompt_router) include_router_with_global_prefix_prepended(application, tool_router) + include_router_with_global_prefix_prepended(application, admin_tool_router) include_router_with_global_prefix_prepended(application, state_router) include_router_with_global_prefix_prepended(application, danswer_api_router) include_router_with_global_prefix_prepended(application, gpts_router) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 7c0c39544c2..8aff6676b88 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -51,7 +51,7 @@ from danswer.tools.search.search_tool import SearchResponseSummary from danswer.tools.search.search_tool import SearchTool from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID from danswer.tools.tool import ToolResponse -from danswer.tools.tool_runner import ToolRunKickoff +from danswer.tools.tool_runner import ToolCallKickoff from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -67,7 +67,7 @@ AnswerObjectIterator = Iterator[ | StreamingError | ChatMessageDetail | CitationInfo - | ToolRunKickoff + | ToolCallKickoff ] diff --git 
a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py index 0a9666646a4..b1f57a1a924 100644 --- a/backend/danswer/server/features/tool/api.py +++ b/backend/danswer/server/features/tool/api.py @@ -1,32 +1,132 @@ +from typing import Any + from fastapi import APIRouter from fastapi import Depends +from fastapi import HTTPException from pydantic import BaseModel -from sqlalchemy import select from sqlalchemy.orm import Session +from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.engine import get_session -from danswer.db.models import Tool from danswer.db.models import User - +from danswer.db.tools import create_tool +from danswer.db.tools import delete_tool +from danswer.db.tools import get_tool_by_id +from danswer.db.tools import get_tools +from danswer.db.tools import update_tool +from danswer.server.features.tool.models import ToolSnapshot +from danswer.tools.custom.openapi_parsing import MethodSpec +from danswer.tools.custom.openapi_parsing import openapi_to_method_specs +from danswer.tools.custom.openapi_parsing import validate_openapi_schema router = APIRouter(prefix="/tool") +admin_router = APIRouter(prefix="/admin/tool") -class ToolSnapshot(BaseModel): - id: int +class CustomToolCreate(BaseModel): name: str - description: str - in_code_tool_id: str | None + description: str | None + definition: dict[str, Any] - @classmethod - def from_model(cls, tool: Tool) -> "ToolSnapshot": - return cls( - id=tool.id, - name=tool.name, - description=tool.description, - in_code_tool_id=tool.in_code_tool_id, - ) + +class CustomToolUpdate(BaseModel): + name: str | None + description: str | None + definition: dict[str, Any] | None + + +def _validate_tool_definition(definition: dict[str, Any]) -> None: + try: + validate_openapi_schema(definition) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@admin_router.post("/custom") +def create_custom_tool( + 
tool_data: CustomToolCreate, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_admin_user), +) -> ToolSnapshot: + _validate_tool_definition(tool_data.definition) + tool = create_tool( + name=tool_data.name, + description=tool_data.description, + openapi_schema=tool_data.definition, + user_id=user.id if user else None, + db_session=db_session, + ) + return ToolSnapshot.from_model(tool) + + +@admin_router.put("/custom/{tool_id}") +def update_custom_tool( + tool_id: int, + tool_data: CustomToolUpdate, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_admin_user), +) -> ToolSnapshot: + if tool_data.definition: + _validate_tool_definition(tool_data.definition) + updated_tool = update_tool( + tool_id=tool_id, + name=tool_data.name, + description=tool_data.description, + openapi_schema=tool_data.definition, + user_id=user.id if user else None, + db_session=db_session, + ) + return ToolSnapshot.from_model(updated_tool) + + +@admin_router.delete("/custom/{tool_id}") +def delete_custom_tool( + tool_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> None: + try: + delete_tool(tool_id, db_session) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + # handles case where tool is still used by an Assistant + raise HTTPException(status_code=400, detail=str(e)) + + +class ValidateToolRequest(BaseModel): + definition: dict[str, Any] + + +class ValidateToolResponse(BaseModel): + methods: list[MethodSpec] + + +@admin_router.post("/custom/validate") +def validate_tool( + tool_data: ValidateToolRequest, + _: User | None = Depends(current_admin_user), +) -> ValidateToolResponse: + _validate_tool_definition(tool_data.definition) + method_specs = openapi_to_method_specs(tool_data.definition) + return ValidateToolResponse(methods=method_specs) + + +"""Endpoints for all""" + + +@router.get("/{tool_id}") +def 
get_custom_tool( + tool_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_user), +) -> ToolSnapshot: + try: + tool = get_tool_by_id(tool_id, db_session) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + return ToolSnapshot.from_model(tool) @router.get("") @@ -34,5 +134,5 @@ def list_tools( db_session: Session = Depends(get_session), _: User | None = Depends(current_user), ) -> list[ToolSnapshot]: - tools = db_session.execute(select(Tool)).scalars().all() + tools = get_tools(db_session) return [ToolSnapshot.from_model(tool) for tool in tools] diff --git a/backend/danswer/server/features/tool/models.py b/backend/danswer/server/features/tool/models.py new file mode 100644 index 00000000000..feb3ba68269 --- /dev/null +++ b/backend/danswer/server/features/tool/models.py @@ -0,0 +1,23 @@ +from typing import Any + +from pydantic import BaseModel + +from danswer.db.models import Tool + + +class ToolSnapshot(BaseModel): + id: int + name: str + description: str + definition: dict[str, Any] | None + in_code_tool_id: str | None + + @classmethod + def from_model(cls, tool: Tool) -> "ToolSnapshot": + return cls( + id=tool.id, + name=tool.name, + description=tool.description, + definition=tool.openapi_schema, + in_code_tool_id=tool.in_code_tool_id, + ) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index d20a4b11101..834453e6e2d 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -129,6 +129,8 @@ def get_chat_session( # we already did a permission check above with the call to # `get_chat_session_by_id`, so we can skip it here skip_permission_check=True, + # we need the tool call objs anyways, so just fetch them in a single call + prefetch_tool_calls=True, ) return ChatSessionDetailResponse( diff --git a/backend/danswer/server/query_and_chat/models.py 
b/backend/danswer/server/query_and_chat/models.py index 44e8ab84624..09561bf24f8 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -17,6 +17,7 @@ from danswer.search.models import ChunkContext from danswer.search.models import RetrievalDetails from danswer.search.models import SearchDoc from danswer.search.models import Tag +from danswer.tools.models import ToolCallFinalResult class SourceTag(Tag): @@ -176,6 +177,7 @@ class ChatMessageDetail(BaseModel): # Dict mapping citation number to db_doc_id citations: dict[int, int] | None files: list[FileDescriptor] + tool_calls: list[ToolCallFinalResult] def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore initial_dict = super().dict(*args, **kwargs) # type: ignore diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py new file mode 100644 index 00000000000..ea232fc5a74 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool.py @@ -0,0 +1,233 @@ +import json +from collections.abc import Generator +from typing import Any +from typing import cast + +import requests +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage +from pydantic import BaseModel + +from danswer.dynamic_configs.interface import JSON_ro +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM +from danswer.tools.custom.custom_tool_prompts import ( + SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, +) +from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT +from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT +from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT +from danswer.tools.custom.custom_tool_prompts import USE_TOOL +from danswer.tools.custom.openapi_parsing import MethodSpec +from danswer.tools.custom.openapi_parsing import 
openapi_to_method_specs +from danswer.tools.custom.openapi_parsing import openapi_to_url +from danswer.tools.custom.openapi_parsing import REQUEST_BODY +from danswer.tools.custom.openapi_parsing import validate_openapi_schema +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.utils.logger import setup_logger + +logger = setup_logger() + +CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response" + + +class CustomToolCallSummary(BaseModel): + tool_name: str + tool_result: dict + + +class CustomTool(Tool): + def __init__(self, method_spec: MethodSpec, base_url: str) -> None: + self._base_url = base_url + self._method_spec = method_spec + self._tool_definition = self._method_spec.to_tool_definition() + + self._name = self._method_spec.name + self.description = self._method_spec.summary + + def name(self) -> str: + return self._name + + """For LLMs which support explicit tool calling""" + + def tool_definition(self) -> dict: + return self._tool_definition + + def build_tool_message_content( + self, *args: ToolResponse + ) -> str | list[str | dict[str, Any]]: + response = cast(CustomToolCallSummary, args[0].response) + return json.dumps(response.tool_result) + + """For LLMs which do NOT support explicit tool calling""" + + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + if not force_run: + should_use_result = llm.invoke( + [ + SystemMessage(content=SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT), + HumanMessage( + content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format( + history=history, + query=query, + tool_name=self.name(), + tool_description=self.description, + ) + ), + ] + ) + if cast(str, should_use_result.content).strip() != USE_TOOL: + return None + + args_result = llm.invoke( + [ + SystemMessage(content=TOOL_ARG_SYSTEM_PROMPT), + HumanMessage( + content=TOOL_ARG_USER_PROMPT.format( + history=history, + query=query, + 
tool_name=self.name(), + tool_description=self.description, + tool_args=self.tool_definition()["function"]["parameters"], + ) + ), + ] + ) + args_result_str = cast(str, args_result.content) + + try: + return json.loads(args_result_str.strip()) + except json.JSONDecodeError: + pass + + # try removing ``` + try: + return json.loads(args_result_str.strip("```")) + except json.JSONDecodeError: + pass + + # try removing ```json + try: + return json.loads(args_result_str.strip("```").strip("json")) + except json.JSONDecodeError: + pass + + # pretend like nothing happened if not parse-able + logger.error( + f"Failed to parse args for '{self.name()}' tool. Received: {args_result_str}" + ) + return None + + """Actual execution of the tool""" + + def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + request_body = kwargs.get(REQUEST_BODY) + + path_params = {} + for path_param_schema in self._method_spec.get_path_param_schemas(): + path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]] + + query_params = {} + for query_param_schema in self._method_spec.get_query_param_schemas(): + if query_param_schema["name"] in kwargs: + query_params[query_param_schema["name"]] = kwargs[ + query_param_schema["name"] + ] + + url = self._method_spec.build_url(self._base_url, path_params, query_params) + method = self._method_spec.method + + response = requests.request(method, url, json=request_body) + + yield ToolResponse( + id=CUSTOM_TOOL_RESPONSE_ID, + response=CustomToolCallSummary( + tool_name=self._name, tool_result=response.json() + ), + ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + return cast(CustomToolCallSummary, args[0].response).tool_result + + +def build_custom_tools_from_openapi_schema( + openapi_schema: dict[str, Any] +) -> list[CustomTool]: + url = openapi_to_url(openapi_schema) + method_specs = openapi_to_method_specs(openapi_schema) + return [CustomTool(method_spec, url) for method_spec in method_specs] + + +if __name__
== "__main__": + import openai + + openapi_schema = { + "openapi": "3.0.0", + "info": { + "version": "1.0.0", + "title": "Assistants API", + "description": "An API for managing assistants", + }, + "servers": [ + {"url": "http://localhost:8080"}, + ], + "paths": { + "/assistant/{assistant_id}": { + "get": { + "summary": "Get a specific Assistant", + "operationId": "getAssistant", + "parameters": [ + { + "name": "assistant_id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + "post": { + "summary": "Create a new Assistant", + "operationId": "createAssistant", + "parameters": [ + { + "name": "assistant_id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "requestBody": { + "required": True, + "content": {"application/json": {"schema": {"type": "object"}}}, + }, + }, + } + }, + } + validate_openapi_schema(openapi_schema) + + tools = build_custom_tools_from_openapi_schema(openapi_schema) + + openai_client = openai.OpenAI() + response = openai_client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Can you fetch assistant with ID 10"}, + ], + tools=[tool.tool_definition() for tool in tools], # type: ignore + ) + choice = response.choices[0] + if choice.message.tool_calls: + print(choice.message.tool_calls) + for tool_response in tools[0].run( + **json.loads(choice.message.tool_calls[0].function.arguments) + ): + print(tool_response) diff --git a/backend/danswer/tools/custom/custom_tool_prompt_builder.py b/backend/danswer/tools/custom/custom_tool_prompt_builder.py new file mode 100644 index 00000000000..8016363acc9 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool_prompt_builder.py @@ -0,0 +1,21 @@ +from typing import cast + +from danswer.tools.custom.custom_tool import CustomToolCallSummary +from danswer.tools.models import ToolResponse + + +def 
build_user_message_for_custom_tool_for_non_tool_calling_llm( + query: str, + tool_name: str, + *args: ToolResponse, +) -> str: + tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result + return f""" +Here's the result from the {tool_name} tool: + +{tool_run_summary} + +Now respond to the following: + +{query} +""".strip() diff --git a/backend/danswer/tools/custom/custom_tool_prompts.py b/backend/danswer/tools/custom/custom_tool_prompts.py new file mode 100644 index 00000000000..14e8b007ef0 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool_prompts.py @@ -0,0 +1,57 @@ +from danswer.prompts.constants import GENERAL_SEP_PAT + +DONT_USE_TOOL = "Don't use tool" +USE_TOOL = "Use tool" + + +"""Prompts to determine if we should use a custom tool or not.""" + + +SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT = ( + "You are a large language model whose only job is to determine if the system should call an " + "external tool to be able to answer the user's last message." +).strip() + +SHOULD_USE_CUSTOM_TOOL_USER_PROMPT = f""" +Given the conversation history and a follow up query, determine if the system should use the \ +'{{tool_name}}' tool to answer the user's query. The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. + +Respond with "{USE_TOOL}" if you think the tool would be helpful in responding to the user's query. +Respond with "{DONT_USE_TOOL}" otherwise. + +Conversation History: +{GENERAL_SEP_PAT} +{{history}} +{GENERAL_SEP_PAT} + +If you are at all unsure, respond with {DONT_USE_TOOL}. +Respond with EXACTLY and ONLY "{DONT_USE_TOOL}" or "{USE_TOOL}" + +Follow up input: +{{query}} +""".strip() + + +"""Prompts to figure out the arguments to pass to a custom tool.""" + + +TOOL_ARG_SYSTEM_PROMPT = ( + "You are a large language model whose only job is to determine the arguments to pass to an " + "external tool."
+).strip() + + +TOOL_ARG_USER_PROMPT = f""" +Given the following conversation and a follow up input, generate a \ +dictionary of arguments to pass to the '{{tool_name}}' tool. \ +The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. \ +The expected arguments are: {{tool_args}}. + +Conversation: +{{history}} + +Follow up input: +{{query}} + +Respond with ONLY and EXACTLY a JSON object specifying the values of the arguments to pass to the tool. +""".strip() # noqa: F541 diff --git a/backend/danswer/tools/custom/openapi_parsing.py b/backend/danswer/tools/custom/openapi_parsing.py new file mode 100644 index 00000000000..40ed5544d8b --- /dev/null +++ b/backend/danswer/tools/custom/openapi_parsing.py @@ -0,0 +1,225 @@ +from typing import Any +from typing import cast + +from openai import BaseModel + +REQUEST_BODY = "requestBody" + + +class PathSpec(BaseModel): + path: str + methods: dict[str, Any] + + +class MethodSpec(BaseModel): + name: str + summary: str + path: str + method: str + spec: dict[str, Any] + + def get_request_body_schema(self) -> dict[str, Any]: + content = self.spec.get("requestBody", {}).get("content", {}) + if "application/json" in content: + return content["application/json"].get("schema") + + if content: + raise ValueError( + f"Unsupported content type: '{list(content.keys())[0]}'. " + f"Only 'application/json' is supported." 
+ ) + + return {} + + def get_query_param_schemas(self) -> list[dict[str, Any]]: + return [ + param + for param in self.spec.get("parameters", []) + if "schema" in param and "in" in param and param["in"] == "query" + ] + + def get_path_param_schemas(self) -> list[dict[str, Any]]: + return [ + param + for param in self.spec.get("parameters", []) + if "schema" in param and "in" in param and param["in"] == "path" + ] + + def build_url( + self, base_url: str, path_params: dict[str, str], query_params: dict[str, str] + ) -> str: + url = f"{base_url}{self.path}" + try: + url = url.format(**path_params) + except KeyError as e: + raise ValueError(f"Missing path parameter: {e}") + if query_params: + url += "?" + for param, value in query_params.items(): + url += f"{param}={value}&" + url = url[:-1] + return url + + def to_tool_definition(self) -> dict[str, Any]: + tool_definition: Any = { + "type": "function", + "function": { + "name": self.name, + "description": self.summary, + "parameters": {"type": "object", "properties": {}}, + }, + } + + request_body_schema = self.get_request_body_schema() + if request_body_schema: + tool_definition["function"]["parameters"]["properties"][ + REQUEST_BODY + ] = request_body_schema + + query_param_schemas = self.get_query_param_schemas() + if query_param_schemas: + tool_definition["function"]["parameters"]["properties"].update( + {param["name"]: param["schema"] for param in query_param_schemas} + ) + + path_param_schemas = self.get_path_param_schemas() + if path_param_schemas: + tool_definition["function"]["parameters"]["properties"].update( + {param["name"]: param["schema"] for param in path_param_schemas} + ) + return tool_definition + + def validate_spec(self) -> None: + # Validate url construction + path_param_schemas = self.get_path_param_schemas() + dummy_path_dict = {param["name"]: "value" for param in path_param_schemas} + query_param_schemas = self.get_query_param_schemas() + dummy_query_dict = {param["name"]: "value" for param 
in query_param_schemas} + self.build_url("", dummy_path_dict, dummy_query_dict) + + # Make sure request body doesn't throw an exception + self.get_request_body_schema() + + # Ensure the method is valid + if not self.method: + raise ValueError("HTTP method is not specified.") + if self.method.upper() not in ["GET", "POST", "PUT", "DELETE", "PATCH"]: + raise ValueError(f"HTTP method '{self.method}' is not supported.") + + +"""Path-level utils""" + + +def openapi_to_path_specs(openapi_spec: dict[str, Any]) -> list[PathSpec]: + path_specs = [] + + for path, methods in openapi_spec.get("paths", {}).items(): + path_specs.append(PathSpec(path=path, methods=methods)) + + return path_specs + + +"""Method-level utils""" + + +def openapi_to_method_specs(openapi_spec: dict[str, Any]) -> list[MethodSpec]: + path_specs = openapi_to_path_specs(openapi_spec) + + method_specs = [] + for path_spec in path_specs: + for method_name, method in path_spec.methods.items(): + name = method.get("operationId") + if not name: + raise ValueError( + f"Operation ID is not specified for {method_name.upper()} {path_spec.path}" + ) + + summary = method.get("summary") or method.get("description") + if not summary: + raise ValueError( + f"Summary is not specified for {method_name.upper()} {path_spec.path}" + ) + + method_specs.append( + MethodSpec( + name=name, + summary=summary, + path=path_spec.path, + method=method_name, + spec=method, + ) + ) + + if not method_specs: + raise ValueError("No methods found in OpenAPI schema") + + return method_specs + + +def openapi_to_url(openapi_schema: dict[str, dict | str]) -> str: + """ + Extract URLs from the servers section of an OpenAPI schema. + + Args: + openapi_schema (Dict[str, Union[Dict, str, List]]): The OpenAPI schema in dictionary format. + + Returns: + List[str]: A list of base URLs. 
+ """ + urls: list[str] = [] + + servers = cast(list[dict[str, Any]], openapi_schema.get("servers", [])) + for server in servers: + url = server.get("url") + if url: + urls.append(url) + + if len(urls) != 1: + raise ValueError( + f"Expected exactly one URL in OpenAPI schema, but found {urls}" + ) + + return urls[0] + + +def validate_openapi_schema(schema: dict[str, Any]) -> None: + """ + Validate the given JSON schema as an OpenAPI schema. + + Parameters: + - schema (dict): The JSON schema to validate. + + Returns: + - bool: True if the schema is valid, False otherwise. + """ + + # check basic structure + if "info" not in schema: + raise ValueError("`info` section is required in OpenAPI schema") + + info = schema["info"] + if "title" not in info: + raise ValueError("`title` is required in `info` section of OpenAPI schema") + if "description" not in info: + raise ValueError( + "`description` is required in `info` section of OpenAPI schema" + ) + + if "openapi" not in schema: + raise ValueError( + "`openapi` field which specifies OpenAPI schema version is required" + ) + openapi_version = schema["openapi"] + if not openapi_version.startswith("3."): + raise ValueError(f"OpenAPI version '{openapi_version}' is not supported") + + if "paths" not in schema: + raise ValueError("`paths` section is required in OpenAPI schema") + + url = openapi_to_url(schema) + if not url: + raise ValueError("OpenAPI schema does not contain a valid URL in `servers`") + + method_specs = openapi_to_method_specs(schema) + for method_spec in method_specs: + method_spec.validate_spec() diff --git a/backend/danswer/tools/factory.py b/backend/danswer/tools/factory.py deleted file mode 100644 index 197bdd6619a..00000000000 --- a/backend/danswer/tools/factory.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Type - -from sqlalchemy.orm import Session - -from danswer.db.models import Tool as ToolDBModel -from danswer.tools.built_in_tools import get_built_in_tool_by_id -from danswer.tools.tool 
import Tool - - -def get_tool_cls(tool: ToolDBModel, db_session: Session) -> Type[Tool]: - # Currently only support built-in tools - return get_built_in_tool_by_id(tool.id, db_session) diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index da66271322f..22aa40993b6 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from danswer.chat.chat_utils import combine_message_chain from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs @@ -53,6 +54,8 @@ class ImageGenerationResponse(BaseModel): class ImageGenerationTool(Tool): + NAME = "run_image_generation" + def __init__( self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2 ) -> None: @@ -60,16 +63,14 @@ class ImageGenerationTool(Tool): self.model = model self.num_imgs = num_imgs - @classmethod def name(self) -> str: - return "run_image_generation" + return self.NAME - @classmethod - def tool_definition(cls) -> dict: + def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": cls.name(), + "name": self.name(), "description": "Generate an image from a prompt", "parameters": { "type": "object", @@ -162,3 +163,12 @@ class ImageGenerationTool(Tool): id=IMAGE_GENERATION_RESPONSE_ID, response=results, ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + image_generation_responses = cast( + list[ImageGenerationResponse], args[0].response + ) + return [ + image_generation_response.dict() + for image_generation_response in image_generation_responses + ] diff --git a/backend/danswer/tools/models.py b/backend/danswer/tools/models.py new file mode 100644 index 
00000000000..53940dcea49 --- /dev/null +++ b/backend/danswer/tools/models.py @@ -0,0 +1,39 @@ +from typing import Any + +from pydantic import BaseModel +from pydantic import root_validator + + +class ToolResponse(BaseModel): + id: str | None = None + response: Any + + +class ToolCallKickoff(BaseModel): + tool_name: str + tool_args: dict[str, Any] + + +class ToolRunnerResponse(BaseModel): + tool_run_kickoff: ToolCallKickoff | None = None + tool_response: ToolResponse | None = None + tool_message_content: str | list[str | dict[str, Any]] | None = None + + @root_validator + def validate_tool_runner_response( + cls, values: dict[str, ToolResponse | str] + ) -> dict[str, ToolResponse | str]: + fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] + provided = sum(1 for field in fields if values.get(field) is not None) + + if provided != 1: + raise ValueError( + "Exactly one of 'tool_response', 'tool_message_content', " + "or 'tool_run_kickoff' must be provided" + ) + + return values + + +class ToolCallFinalResult(ToolCallKickoff): + tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index 30ec47d1664..b0b45bd8f40 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/search/search_tool.py @@ -10,6 +10,7 @@ from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import LlmDoc from danswer.db.models import Persona from danswer.db.models import User +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.doc_pruning import prune_documents from danswer.llm.answering.models import DocumentPruningConfig from danswer.llm.answering.models import PreviousMessage @@ -55,6 +56,8 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ class SearchTool(Tool): + NAME = "run_search" + def __init__( self, 
db_session: Session, @@ -87,18 +90,16 @@ class SearchTool(Tool): self.bypass_acl = bypass_acl self.db_session = db_session - @classmethod - def name(cls) -> str: - return "run_search" + def name(self) -> str: + return self.NAME """For explicit tool calling""" - @classmethod - def tool_definition(cls) -> dict: + def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": cls.name(), + "name": self.name(), "description": search_tool_description, "parameters": { "type": "object", @@ -241,3 +242,13 @@ class SearchTool(Tool): document_pruning_config=self.pruning_config, ) yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=final_context_documents) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + final_docs = cast( + list[LlmDoc], + next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS), + ) + # NOTE: need to do this json.loads(doc.json()) stuff because there are some + # subfields that are not serializable by default (datetime) + # this forces pydantic to make them JSON serializable for us + return [json.loads(doc.json()) for doc in final_docs] diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index dd443757e67..e335a049838 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -2,26 +2,19 @@ import abc from collections.abc import Generator from typing import Any -from pydantic import BaseModel - +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM - - -class ToolResponse(BaseModel): - id: str | None = None - response: Any +from danswer.tools.models import ToolResponse class Tool(abc.ABC): - @classmethod @abc.abstractmethod def name(self) -> str: raise NotImplementedError """For LLMs which support explicit tool calling""" - @classmethod @abc.abstractmethod def tool_definition(self) -> dict: raise NotImplementedError @@ -49,3 +42,11 @@ class Tool(abc.ABC): 
@abc.abstractmethod def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: raise NotImplementedError + + @abc.abstractmethod + def final_result(self, *args: ToolResponse) -> JSON_ro: + """ + This is the "final summary" result of the tool. + It is the result that will be stored in the database. + """ + raise NotImplementedError diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index 46f247b06dc..a4367d865d5 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -1,42 +1,15 @@ from collections.abc import Generator from typing import Any -from pydantic import BaseModel -from pydantic import root_validator - from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -class ToolRunKickoff(BaseModel): - tool_name: str - tool_args: dict[str, Any] - - -class ToolRunnerResponse(BaseModel): - tool_run_kickoff: ToolRunKickoff | None = None - tool_response: ToolResponse | None = None - tool_message_content: str | list[str | dict[str, Any]] | None = None - - @root_validator - def validate_tool_runner_response( - cls, values: dict[str, ToolResponse | str] - ) -> dict[str, ToolResponse | str]: - fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] - provided = sum(1 for field in fields if values.get(field) is not None) - - if provided != 1: - raise ValueError( - "Exactly one of 'tool_response', 'tool_message_content', " - "or 'tool_run_kickoff' must be provided" - ) - - return values - - class ToolRunner: def __init__(self, tool: Tool, args: dict[str, Any]): self.tool = tool @@ -44,8 +17,8 @@ class ToolRunner: self._tool_responses: list[ToolResponse] | None = None - def 
kickoff(self) -> ToolRunKickoff: - return ToolRunKickoff(tool_name=self.tool.name(), tool_args=self.args) + def kickoff(self) -> ToolCallKickoff: + return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args) def tool_responses(self) -> Generator[ToolResponse, None, None]: if self._tool_responses is not None: @@ -62,6 +35,13 @@ class ToolRunner: tool_responses = list(self.tool_responses()) return self.tool.build_tool_message_content(*tool_responses) + def tool_final_result(self) -> ToolCallFinalResult: + return ToolCallFinalResult( + tool_name=self.tool.name(), + tool_args=self.args, + tool_result=self.tool.final_result(*self.tool_responses()), + ) + def check_which_tools_should_run_for_non_tool_calling_llm( tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM diff --git a/backend/danswer/tools/utils.py b/backend/danswer/tools/utils.py index 831021cdab3..7fb2156df59 100644 --- a/backend/danswer/tools/utils.py +++ b/backend/danswer/tools/utils.py @@ -1,5 +1,4 @@ import json -from typing import Type from tiktoken import Encoding @@ -17,15 +16,13 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> boo return False -def compute_tool_tokens( - tool: Tool | Type[Tool], llm_tokenizer: Encoding | None = None -) -> int: +def compute_tool_tokens(tool: Tool, llm_tokenizer: Encoding | None = None) -> int: if not llm_tokenizer: llm_tokenizer = get_default_llm_tokenizer() return len(llm_tokenizer.encode(json.dumps(tool.tool_definition()))) def compute_all_tool_tokens( - tools: list[Tool] | list[Type[Tool]], llm_tokenizer: Encoding | None = None + tools: list[Tool], llm_tokenizer: Encoding | None = None ) -> int: return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 14b89bc1365..2d8745f3349 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -24,6 +24,7 @@ httpx[http2]==0.23.3 
httpx-oauth==0.11.2 huggingface-hub==0.20.1 jira==3.5.1 +jsonref==1.1.0 langchain==0.1.17 langchain-community==0.0.36 langchain-core==0.1.50 diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index d6427a67929..10b0ad586aa 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -147,40 +147,56 @@ export function AssistantEditor({ const imageGenerationTool = providerSupportingImageGenerationExists ? findImageGenerationTool(tools) : undefined; + const customTools = tools.filter( + (tool) => + tool.in_code_tool_id !== searchTool?.in_code_tool_id && + tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id + ); + + const availableTools = [ + ...customTools, + ...(searchTool ? [searchTool] : []), + ...(imageGenerationTool ? [imageGenerationTool] : []), + ]; + const enabledToolsMap: { [key: number]: boolean } = {}; + availableTools.forEach((tool) => { + enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id); + }); + + const initialValues = { + name: existingPersona?.name ?? "", + description: existingPersona?.description ?? "", + system_prompt: existingPrompt?.system_prompt ?? "", + task_prompt: existingPrompt?.task_prompt ?? "", + is_public: existingPersona?.is_public ?? defaultPublic, + document_set_ids: + existingPersona?.document_sets?.map((documentSet) => documentSet.id) ?? + ([] as number[]), + num_chunks: existingPersona?.num_chunks ?? null, + include_citations: existingPersona?.prompts[0]?.include_citations ?? true, + llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, + llm_model_provider_override: + existingPersona?.llm_model_provider_override ?? null, + llm_model_version_override: + existingPersona?.llm_model_version_override ?? null, + starter_messages: existingPersona?.starter_messages ?? [], + enabled_tools_map: enabledToolsMap, + // search_tool_enabled: existingPersona + // ? 
personaCurrentToolIds.includes(searchTool!.id) + // : ccPairs.length > 0, + // image_generation_tool_enabled: imageGenerationTool + // ? personaCurrentToolIds.includes(imageGenerationTool.id) + // : false, + // EE Only + groups: existingPersona?.groups ?? [], + }; return (
{popup} documentSet.id - ) ?? ([] as number[]), - num_chunks: existingPersona?.num_chunks ?? null, - include_citations: - existingPersona?.prompts[0]?.include_citations ?? true, - llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, - llm_model_provider_override: - existingPersona?.llm_model_provider_override ?? null, - llm_model_version_override: - existingPersona?.llm_model_version_override ?? null, - starter_messages: existingPersona?.starter_messages ?? [], - // EE Only - groups: existingPersona?.groups ?? [], - search_tool_enabled: existingPersona - ? personaCurrentToolIds.includes(searchTool!.id) - : ccPairs.length > 0, - image_generation_tool_enabled: imageGenerationTool - ? personaCurrentToolIds.includes(imageGenerationTool.id) - : false, - }} + initialValues={initialValues} validationSchema={Yup.object() .shape({ name: Yup.string().required("Must give the Assistant a name!"), @@ -205,8 +221,6 @@ export function AssistantEditor({ ), // EE Only groups: Yup.array().of(Yup.number()), - search_tool_enabled: Yup.boolean().required(), - image_generation_tool_enabled: Yup.boolean().required(), }) .test( "system-prompt-or-task-prompt", @@ -251,30 +265,36 @@ export function AssistantEditor({ formikHelpers.setSubmitting(true); - const tools = []; - if (values.search_tool_enabled && ccPairs.length > 0) { - tools.push(searchTool!.id); - } - if ( - values.image_generation_tool_enabled && - imageGenerationTool && - checkLLMSupportsImageInput( - providerDisplayNameToProviderName.get( - values.llm_model_provider_override || "" - ) || - defaultProviderName || - "", - values.llm_model_version_override || defaultModelName || "" - ) - ) { - tools.push(imageGenerationTool.id); + let enabledTools = Object.keys(values.enabled_tools_map) + .map((toolId) => Number(toolId)) + .filter((toolId) => values.enabled_tools_map[toolId]); + const searchToolEnabled = searchTool + ? 
enabledTools.includes(searchTool.id) + : false; + const imageGenerationToolEnabled = imageGenerationTool + ? enabledTools.includes(imageGenerationTool.id) + : false; + + if (imageGenerationToolEnabled) { + if ( + !checkLLMSupportsImageInput( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || defaultModelName || "" + ) + ) { + enabledTools = enabledTools.filter( + (toolId) => toolId !== imageGenerationTool!.id + ); + } } // if disable_retrieval is set, set num_chunks to 0 // to tell the backend to not fetch any documents - const numChunks = values.search_tool_enabled - ? values.num_chunks || 10 - : 0; + const numChunks = searchToolEnabled ? values.num_chunks || 10 : 0; // don't set groups if marked as public const groups = values.is_public ? [] : values.groups; @@ -290,7 +310,7 @@ export function AssistantEditor({ users: user && !checkUserIsNoAuthUser(user.id) ? [user.id] : undefined, groups, - tool_ids: tools, + tool_ids: enabledTools, }); } else { [promptResponse, personaResponse] = await createPersona({ @@ -299,7 +319,7 @@ export function AssistantEditor({ users: user && !checkUserIsNoAuthUser(user.id) ? [user.id] : undefined, groups, - tool_ids: tools, + tool_ids: enabledTools, }); } @@ -351,546 +371,581 @@ export function AssistantEditor({ } }} > - {({ isSubmitting, values, setFieldValue }) => ( -
-
- - <> - + {({ isSubmitting, values, setFieldValue }) => { + function toggleToolInValues(toolId: number) { + const updatedEnabledToolsMap = { + ...values.enabled_tools_map, + [toolId]: !values.enabled_tools_map[toolId], + }; + setFieldValue("enabled_tools_map", updatedEnabledToolsMap); + } - + function searchToolEnabled() { + return searchTool && values.enabled_tools_map[searchTool.id] + ? true + : false; + } - { - setFieldValue("system_prompt", e.target.value); - triggerFinalPromptUpdate( - e.target.value, - values.task_prompt, - values.search_tool_enabled - ); - }} - error={finalPromptError} - /> + return ( + +
+ + <> + - + + { + setFieldValue("system_prompt", e.target.value); + triggerFinalPromptUpdate( + e.target.value, + values.task_prompt, + searchToolEnabled() + ); + }} + error={finalPromptError} + /> + + { - setFieldValue("task_prompt", e.target.value); - triggerFinalPromptUpdate( - values.system_prompt, - e.target.value, - values.search_tool_enabled - ); - }} - error={finalPromptError} - /> + onChange={(e) => { + setFieldValue("task_prompt", e.target.value); + triggerFinalPromptUpdate( + values.system_prompt, + e.target.value, + searchToolEnabled() + ); + }} + error={finalPromptError} + /> - + - {finalPrompt ? ( -
-                      {finalPrompt}
-                    
- ) : ( - "-" - )} - -
+ {finalPrompt ? ( +
+                        {finalPrompt}
+                      
+ ) : ( + "-" + )} + + - + - - <> - {ccPairs.length > 0 && ( - <> - { - setFieldValue("num_chunks", null); - setFieldValue( - "search_tool_enabled", - e.target.checked - ); - }} - /> - - {values.search_tool_enabled && ( -
- {ccPairs.length > 0 && ( - <> - - -
- - <> - Select which{" "} - {!user || user.role === "admin" ? ( - - Document Sets - - ) : ( - "Document Sets" - )}{" "} - that this Assistant should search through. - If none are specified, the Assistant will - search through all available documents in - order to try and respond to queries. - - -
- - {documentSets.length > 0 ? ( - ( -
-
- {documentSets.map((documentSet) => { - const ind = - values.document_set_ids.indexOf( - documentSet.id - ); - let isSelected = ind !== -1; - return ( - { - if (isSelected) { - arrayHelpers.remove(ind); - } else { - arrayHelpers.push( - documentSet.id - ); - } - }} - /> - ); - })} -
-
- )} - /> - ) : ( - - No Document Sets available.{" "} - {user?.role !== "admin" && ( - <> - If this functionality would be useful, - reach out to the administrators of Danswer - for assistance. - - )} - - )} + + <> + {ccPairs.length > 0 && searchTool && ( + <> + { + setFieldValue("num_chunks", null); + toggleToolInValues(searchTool.id); + }} + /> + {searchToolEnabled() && ( +
+ {ccPairs.length > 0 && ( <> - - How many chunks should we feed into the - LLM when generating the final response? - Each chunk is ~400 words long. -
- } - onChange={(e) => { - const value = e.target.value; - // Allow only integer values - if ( - value === "" || - /^[0-9]+$/.test(value) - ) { - setFieldValue("num_chunks", value); + + +
+ + <> + Select which{" "} + {!user || user.role === "admin" ? ( + + Document Sets + + ) : ( + "Document Sets" + )}{" "} + that this Assistant should search through. + If none are specified, the Assistant will + search through all available documents in + order to try and respond to queries. + + +
+ + {documentSets.length > 0 ? ( + ( +
+
+ {documentSets.map((documentSet) => { + const ind = + values.document_set_ids.indexOf( + documentSet.id + ); + let isSelected = ind !== -1; + return ( + { + if (isSelected) { + arrayHelpers.remove(ind); + } else { + arrayHelpers.push( + documentSet.id + ); + } + }} + /> + ); + })} +
+
+ )} + /> + ) : ( + + No Document Sets available.{" "} + {user?.role !== "admin" && ( + <> + If this functionality would be useful, + reach out to the administrators of + Danswer for assistance. + + )} + + )} + + <> + + How many chunks should we feed into the + LLM when generating the final response? + Each chunk is ~400 words long. +
} - }} - /> + onChange={(e) => { + const value = e.target.value; + // Allow only integer values + if ( + value === "" || + /^[0-9]+$/.test(value) + ) { + setFieldValue("num_chunks", value); + } + }} + /> - + - + - + /> + - + )} +
+ )} + + )} + + {imageGenerationTool && + checkLLMSupportsImageInput( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || + defaultModelName || + "" + ) && ( + { + toggleToolInValues(imageGenerationTool.id); + }} + /> + )} + + {customTools.length > 0 && ( + <> + {customTools.map((tool) => ( + { + toggleToolInValues(tool.id); + }} + /> + ))} + + )} + +
+ + + + {llmProviders.length > 0 && ( + <> + + <> + + Pick which LLM to use for this Assistant. If left as + Default, will use{" "} + {defaultModelName} + . +
+
+ For more information on the different LLMs, checkout + the{" "} + + OpenAI docs + + . +
+ +
+
+ LLM Provider + ({ + name: llmProvider.name, + value: llmProvider.name, + }))} + includeDefault={true} + onSelect={(selected) => { + if ( + selected !== + values.llm_model_provider_override + ) { + setFieldValue( + "llm_model_version_override", + null + ); + } + setFieldValue( + "llm_model_provider_override", + selected + ); + }} + /> +
+ + {values.llm_model_provider_override && ( +
+ Model + +
)}
+ +
+ + + + )} + + + <> +
+ + Starter Messages help guide users to use this Assistant. + They are shown to the user as clickable options when + they select this Assistant. When selected, the specified + message is sent to the LLM as the initial user message. + +
+ + + ) => ( +
+ {values.starter_messages && + values.starter_messages.length > 0 && + values.starter_messages.map((_, index) => { + return ( +
+
+
+
+ + + Shows up as the "title" for + this Starter Message. For example, + "Write an email". + + + +
+ +
+ + + A description which tells the user + what they might want to use this + Starter Message for. For example + "to a client about a new + feature" + + + +
+ +
+ + + The actual message to be sent as the + initial user message if a user selects + this starter prompt. For example, + "Write me an email to a client + about a new billing feature we just + released." + + + +
+
+
+ + arrayHelpers.remove(index) + } + /> +
+
+
+ ); + })} + + +
)} + /> + +
+ + + + {EE_ENABLED && + userGroups && + (!user || user.role === "admin") && ( + <> + + <> + + + {userGroups && + userGroups.length > 0 && + !values.is_public && ( +
+ + Select which User Groups should have access to + this Assistant. + +
+ {userGroups.map((userGroup) => { + const isSelected = values.groups.includes( + userGroup.id + ); + return ( + { + if (isSelected) { + setFieldValue( + "groups", + values.groups.filter( + (id) => id !== userGroup.id + ) + ); + } else { + setFieldValue("groups", [ + ...values.groups, + userGroup.id, + ]); + } + }} + > +
+ +
+ {userGroup.name} +
+
+
+ ); + })} +
+
+ )} + +
+ )} - {imageGenerationTool && - checkLLMSupportsImageInput( - providerDisplayNameToProviderName.get( - values.llm_model_provider_override || "" - ) || - defaultProviderName || - "", - values.llm_model_version_override || - defaultModelName || - "" - ) && ( - { - setFieldValue( - "image_generation_tool_enabled", - e.target.checked - ); - }} - /> - )} - - - - - - {llmProviders.length > 0 && ( - <> - + -
- )} - /> - - - - - - {EE_ENABLED && userGroups && (!user || user.role === "admin") && ( - <> - - <> - - - {userGroups && - userGroups.length > 0 && - !values.is_public && ( -
- - Select which User Groups should have access to - this Assistant. - -
- {userGroups.map((userGroup) => { - const isSelected = values.groups.includes( - userGroup.id - ); - return ( - { - if (isSelected) { - setFieldValue( - "groups", - values.groups.filter( - (id) => id !== userGroup.id - ) - ); - } else { - setFieldValue("groups", [ - ...values.groups, - userGroup.id, - ]); - } - }} - > -
- -
- {userGroup.name} -
-
-
- ); - })} -
-
- )} - -
- - - )} - -
- + {isUpdate ? "Update!" : "Create!"} + +
- - - )} + + ); + }} ); diff --git a/web/src/app/admin/tools/ToolEditor.tsx b/web/src/app/admin/tools/ToolEditor.tsx new file mode 100644 index 00000000000..89046d21f1f --- /dev/null +++ b/web/src/app/admin/tools/ToolEditor.tsx @@ -0,0 +1,261 @@ +"use client"; + +import { useState, useEffect, useCallback } from "react"; +import { useRouter } from "next/navigation"; +import { Formik, Form, Field, ErrorMessage } from "formik"; +import * as Yup from "yup"; +import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces"; +import { TextFormField } from "@/components/admin/connectors/Field"; +import { Button, Divider } from "@tremor/react"; +import { + createCustomTool, + updateCustomTool, + validateToolDefinition, +} from "@/lib/tools/edit"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import debounce from "lodash/debounce"; + +function parseJsonWithTrailingCommas(jsonString: string) { + // Regular expression to remove trailing commas before } or ] + let cleanedJsonString = jsonString.replace(/,\s*([}\]])/g, "$1"); + // Replace True with true, False with false, and None with null + cleanedJsonString = cleanedJsonString + .replace(/\bTrue\b/g, "true") + .replace(/\bFalse\b/g, "false") + .replace(/\bNone\b/g, "null"); + // Now parse the cleaned JSON string + return JSON.parse(cleanedJsonString); +} + +function prettifyDefinition(definition: any) { + return JSON.stringify(definition, null, 2); +} + +function ToolForm({ + existingTool, + values, + setFieldValue, + isSubmitting, + definitionErrorState, + methodSpecsState, +}: { + existingTool?: ToolSnapshot; + values: ToolFormValues; + setFieldValue: (field: string, value: string) => void; + isSubmitting: boolean; + definitionErrorState: [ + string | null, + React.Dispatch>, + ]; + methodSpecsState: [ + MethodSpec[] | null, + React.Dispatch>, + ]; +}) { + const [definitionError, setDefinitionError] = definitionErrorState; + const [methodSpecs, setMethodSpecs] = methodSpecsState; + + const 
debouncedValidateDefinition = useCallback( + debounce(async (definition: string) => { + try { + const parsedDefinition = parseJsonWithTrailingCommas(definition); + const response = await validateToolDefinition({ + definition: parsedDefinition, + }); + if (response.error) { + setMethodSpecs(null); + setDefinitionError(response.error); + } else { + setMethodSpecs(response.data); + setDefinitionError(null); + } + } catch (error) { + console.log(error); + setMethodSpecs(null); + setDefinitionError("Invalid JSON format"); + } + }, 300), + [] + ); + + useEffect(() => { + if (values.definition) { + debouncedValidateDefinition(values.definition); + } + }, [values.definition, debouncedValidateDefinition]); + + return ( +
+
+ + +
+ {definitionError && ( +
{definitionError}
+ )} + + + {methodSpecs && methodSpecs.length > 0 && ( +
+

Available methods

+
+ + + + + + + + + + + {methodSpecs?.map((method: MethodSpec, index: number) => ( + + + + + + + ))} + +
NameSummaryMethodPath
{method.name}{method.summary} + {method.method.toUpperCase()} + {method.path}
+
+
+ )} + + +
+ +
+ + ); +} + +interface ToolFormValues { + definition: string; +} + +const ToolSchema = Yup.object().shape({ + definition: Yup.string().required("Tool definition is required"), +}); + +export function ToolEditor({ tool }: { tool?: ToolSnapshot }) { + const router = useRouter(); + const { popup, setPopup } = usePopup(); + const [definitionError, setDefinitionError] = useState(null); + const [methodSpecs, setMethodSpecs] = useState(null); + + const prettifiedDefinition = tool?.definition + ? prettifyDefinition(tool.definition) + : ""; + + return ( +
+ {popup} + { + let definition: any; + try { + definition = parseJsonWithTrailingCommas(values.definition); + } catch (error) { + setDefinitionError("Invalid JSON in tool definition"); + return; + } + + const name = definition?.info?.title; + const description = definition?.info?.description; + const toolData = { + name: name, + description: description || "", + definition: definition, + }; + let response; + if (tool) { + response = await updateCustomTool(tool.id, toolData); + } else { + response = await createCustomTool(toolData); + } + if (response.error) { + setPopup({ + message: "Failed to create tool - " + response.error, + type: "error", + }); + return; + } + router.push(`/admin/tools?u=${Date.now()}`); + }} + > + {({ isSubmitting, values, setFieldValue }) => { + return ( + + ); + }} + +
+ ); +} diff --git a/web/src/app/admin/tools/ToolsTable.tsx b/web/src/app/admin/tools/ToolsTable.tsx new file mode 100644 index 00000000000..8017a2431a0 --- /dev/null +++ b/web/src/app/admin/tools/ToolsTable.tsx @@ -0,0 +1,107 @@ +"use client"; + +import { + Text, + Table, + TableHead, + TableRow, + TableHeaderCell, + TableBody, + TableCell, +} from "@tremor/react"; +import { ToolSnapshot } from "@/lib/tools/interfaces"; +import { useRouter } from "next/navigation"; +import { usePopup } from "@/components/admin/connectors/Popup"; +import { FiCheckCircle, FiEdit, FiXCircle } from "react-icons/fi"; +import { TrashIcon } from "@/components/icons/icons"; +import { deleteCustomTool } from "@/lib/tools/edit"; + +export function ToolsTable({ tools }: { tools: ToolSnapshot[] }) { + const router = useRouter(); + const { popup, setPopup } = usePopup(); + + const sortedTools = [...tools]; + sortedTools.sort((a, b) => a.id - b.id); + + return ( +
+ {popup} + + + + + Name + Description + Built In? + Delete + + + + {sortedTools.map((tool) => ( + + +
+ {tool.in_code_tool_id === null && ( + + router.push( + `/admin/tools/edit/${tool.id}?u=${Date.now()}` + ) + } + /> + )} +

+ {tool.name} +

+
+
+ + {tool.description} + + + {tool.in_code_tool_id === null ? ( + + + No + + ) : ( + + + Yes + + )} + + +
+ {tool.in_code_tool_id === null ? ( +
+
{ + const response = await deleteCustomTool(tool.id); + if (response.data) { + router.refresh(); + } else { + setPopup({ + message: `Failed to delete tool - ${response.error}`, + type: "error", + }); + } + }} + > + +
+
+ ) : ( + "-" + )} +
+
+
+ ))} +
+
+
+ ); +} diff --git a/web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx b/web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx new file mode 100644 index 00000000000..c02e141b54a --- /dev/null +++ b/web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx @@ -0,0 +1,28 @@ +"use client"; + +import { Button } from "@tremor/react"; +import { FiTrash } from "react-icons/fi"; +import { deleteCustomTool } from "@/lib/tools/edit"; +import { useRouter } from "next/navigation"; + +export function DeleteToolButton({ toolId }: { toolId: number }) { + const router = useRouter(); + + return ( + + ); +} diff --git a/web/src/app/admin/tools/edit/[toolId]/page.tsx b/web/src/app/admin/tools/edit/[toolId]/page.tsx new file mode 100644 index 00000000000..8dd54be46b3 --- /dev/null +++ b/web/src/app/admin/tools/edit/[toolId]/page.tsx @@ -0,0 +1,55 @@ +import { ErrorCallout } from "@/components/ErrorCallout"; +import { Card, Text, Title } from "@tremor/react"; +import { ToolEditor } from "@/app/admin/tools/ToolEditor"; +import { fetchToolByIdSS } from "@/lib/tools/fetchTools"; +import { DeleteToolButton } from "./DeleteToolButton"; +import { FiTool } from "react-icons/fi"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { BackButton } from "@/components/BackButton"; + +export default async function Page({ params }: { params: { toolId: string } }) { + const tool = await fetchToolByIdSS(params.toolId); + + let body; + if (!tool) { + body = ( +
+ +
+ ); + } else { + body = ( +
+
+
+ + + + + Delete Tool + Click the button below to permanently delete this tool. +
+ +
+
+
+
+ ); + } + + return ( +
+ + + } + /> + + {body} +
+ ); +} diff --git a/web/src/app/admin/tools/new/page.tsx b/web/src/app/admin/tools/new/page.tsx new file mode 100644 index 00000000000..5d1723f96ac --- /dev/null +++ b/web/src/app/admin/tools/new/page.tsx @@ -0,0 +1,24 @@ +"use client"; + +import { ToolEditor } from "@/app/admin/tools/ToolEditor"; +import { BackButton } from "@/components/BackButton"; +import { AdminPageTitle } from "@/components/admin/Title"; +import { Card } from "@tremor/react"; +import { FiTool } from "react-icons/fi"; + +export default function NewToolPage() { + return ( +
+ + + } + /> + + + + +
+ ); +} diff --git a/web/src/app/admin/tools/page.tsx b/web/src/app/admin/tools/page.tsx new file mode 100644 index 00000000000..7b9edf7abe0 --- /dev/null +++ b/web/src/app/admin/tools/page.tsx @@ -0,0 +1,68 @@ +import { ToolsTable } from "./ToolsTable"; +import { ToolSnapshot } from "@/lib/tools/interfaces"; +import { FiPlusSquare, FiTool } from "react-icons/fi"; +import Link from "next/link"; +import { Divider, Text, Title } from "@tremor/react"; +import { fetchSS } from "@/lib/utilsSS"; +import { ErrorCallout } from "@/components/ErrorCallout"; +import { AdminPageTitle } from "@/components/admin/Title"; + +export default async function Page() { + const toolResponse = await fetchSS("/tool"); + + if (!toolResponse.ok) { + return ( + + ); + } + + const tools = (await toolResponse.json()) as ToolSnapshot[]; + + return ( +
+ } + title="Tools" + /> + + + Tools allow assistants to retrieve information or take actions. + + +
+ + + Create a Tool + +
+ + New Tool +
+ + + + + Existing Tools + +
+
+ ); +} diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 047b115c2d3..92e4384f912 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -12,7 +12,8 @@ import { Message, RetrievalType, StreamingError, - ToolRunKickoff, + ToolCallFinalResult, + ToolCallMetadata, } from "./interfaces"; import { ChatSidebar } from "./sessionSidebar/ChatSidebar"; import { Persona } from "../admin/assistants/interfaces"; @@ -265,6 +266,7 @@ export function ChatPage({ message: "", type: "system", files: [], + toolCalls: [], parentMessageId: null, childrenMessageIds: [firstMessageId], latestChildMessageId: firstMessageId, @@ -307,7 +309,6 @@ export function ChatPage({ return newCompleteMessageMap; }; const messageHistory = buildLatestMessageChain(completeMessageMap); - const [currentTool, setCurrentTool] = useState(null); const [isStreaming, setIsStreaming] = useState(false); // uploaded files @@ -535,6 +536,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, + toolCalls: [], parentMessageId: parentMessage?.messageId || null, }, ]; @@ -569,6 +571,7 @@ export function ChatPage({ let aiMessageImages: FileDescriptor[] | null = null; let error: string | null = null; let finalMessage: BackendMessage | null = null; + let toolCalls: ToolCallMetadata[] = []; try { const lastSuccessfulMessageId = getLastSuccessfulMessageId(currMessageHistory); @@ -627,7 +630,13 @@ export function ChatPage({ } ); } else if (Object.hasOwn(packet, "tool_name")) { - setCurrentTool((packet as ToolRunKickoff).tool_name); + toolCalls = [ + { + tool_name: (packet as ToolCallMetadata).tool_name, + tool_args: (packet as ToolCallMetadata).tool_args, + tool_result: (packet as ToolCallMetadata).tool_result, + }, + ]; } else if (Object.hasOwn(packet, "error")) { error = (packet as StreamingError).error; } else if (Object.hasOwn(packet, "message_id")) { @@ -657,6 +666,7 @@ export function ChatPage({ message: currMessage, type: 
"user", files: currentMessageFiles, + toolCalls: [], parentMessageId: parentMessage?.messageId || null, childrenMessageIds: [newAssistantMessageId], latestChildMessageId: newAssistantMessageId, @@ -670,6 +680,7 @@ export function ChatPage({ documents: finalMessage?.context_docs?.top_documents || documents, citations: finalMessage?.citations || {}, files: finalMessage?.files || aiMessageImages || [], + toolCalls: finalMessage?.tool_calls || toolCalls, parentMessageId: newUserMessageId, }, ]); @@ -687,6 +698,7 @@ export function ChatPage({ message: currMessage, type: "user", files: currentMessageFiles, + toolCalls: [], parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID, }, { @@ -694,6 +706,7 @@ export function ChatPage({ message: errorMsg, type: "error", files: aiMessageImages || [], + toolCalls: [], parentMessageId: TEMP_USER_MESSAGE_ID, }, ], @@ -1031,7 +1044,7 @@ export function ChatPage({ citedDocuments={getCitedDocumentsFromMessage( message )} - currentTool={currentTool} + toolCall={message.toolCalls[0]} isComplete={ i !== messageHistory.length - 1 || !isStreaming diff --git a/web/src/app/chat/interfaces.ts b/web/src/app/chat/interfaces.ts index 9f3d647340e..52121ed264b 100644 --- a/web/src/app/chat/interfaces.ts +++ b/web/src/app/chat/interfaces.ts @@ -34,6 +34,18 @@ export interface FileDescriptor { isUploading?: boolean; } +export interface ToolCallMetadata { + tool_name: string; + tool_args: Record; + tool_result?: Record; +} + +export interface ToolCallFinalResult { + tool_name: string; + tool_args: Record; + tool_result: Record; +} + export interface ChatSession { id: number; name: string; @@ -52,6 +64,7 @@ export interface Message { documents?: DanswerDocument[] | null; citations?: CitationMap; files: FileDescriptor[]; + toolCalls: ToolCallMetadata[]; // for rebuilding the message tree parentMessageId: number | null; childrenMessageIds?: number[]; @@ -79,6 +92,7 @@ export interface BackendMessage { time_sent: string; citations: CitationMap; 
files: FileDescriptor[]; + tool_calls: ToolCallFinalResult[]; } export interface DocumentsResponse { @@ -90,11 +104,6 @@ export interface ImageGenerationDisplay { file_ids: string[]; } -export interface ToolRunKickoff { - tool_name: string; - tool_args: Record; -} - export interface StreamingError { error: string; } diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 9d65e64ce41..b17413d9304 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -15,7 +15,7 @@ import { Message, RetrievalType, StreamingError, - ToolRunKickoff, + ToolCallMetadata, } from "./interfaces"; import { Persona } from "../admin/assistants/interfaces"; import { ReadonlyURLSearchParams } from "next/navigation"; @@ -138,7 +138,7 @@ export async function* sendMessage({ | DocumentsResponse | BackendMessage | ImageGenerationDisplay - | ToolRunKickoff + | ToolCallMetadata | StreamingError >(sendMessageResponse); } @@ -384,6 +384,7 @@ export function processRawChatHistory( citations: messageInfo?.citations || {}, } : {}), + toolCalls: messageInfo.tool_calls, parentMessageId: messageInfo.parent_message, childrenMessageIds: [], latestChildMessageId: messageInfo.latest_child_message, diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index b188c128ea7..3a7b9d6d8b4 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -9,6 +9,7 @@ import { FiEdit2, FiChevronRight, FiChevronLeft, + FiTool, } from "react-icons/fi"; import { FeedbackType } from "../types"; import { useEffect, useRef, useState } from "react"; @@ -20,9 +21,12 @@ import { ThreeDots } from "react-loader-spinner"; import { SkippedSearch } from "./SkippedSearch"; import remarkGfm from "remark-gfm"; import { CopyButton } from "@/components/CopyButton"; -import { ChatFileType, FileDescriptor } from "../interfaces"; -import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants"; -import { ToolRunningAnimation } from 
"../tools/ToolRunningAnimation"; +import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces"; +import { + IMAGE_GENERATION_TOOL_NAME, + SEARCH_TOOL_NAME, +} from "../tools/constants"; +import { ToolRunDisplay } from "../tools/ToolRunningAnimation"; import { Hoverable } from "@/components/Hoverable"; import { DocumentPreview } from "../files/documents/DocumentPreview"; import { InMessageImage } from "../files/images/InMessageImage"; @@ -35,6 +39,11 @@ import Prism from "prismjs"; import "prismjs/themes/prism-tomorrow.css"; import "./custom-code-styles.css"; +const TOOLS_WITH_CUSTOM_HANDLING = [ + SEARCH_TOOL_NAME, + IMAGE_GENERATION_TOOL_NAME, +]; + function FileDisplay({ files }: { files: FileDescriptor[] }) { const imageFiles = files.filter((file) => file.type === ChatFileType.IMAGE); const nonImgFiles = files.filter((file) => file.type !== ChatFileType.IMAGE); @@ -77,7 +86,7 @@ export const AIMessage = ({ query, personaName, citedDocuments, - currentTool, + toolCall, isComplete, hasDocs, handleFeedback, @@ -93,7 +102,7 @@ export const AIMessage = ({ query?: string; personaName?: string; citedDocuments?: [string, DanswerDocument][] | null; - currentTool?: string | null; + toolCall?: ToolCallMetadata; isComplete?: boolean; hasDocs?: boolean; handleFeedback?: (feedbackType: FeedbackType) => void; @@ -133,28 +142,23 @@ export const AIMessage = ({ content = trimIncompleteCodeSection(content); } - const loader = - currentTool === IMAGE_GENERATION_TOOL_NAME ? ( -
- } - /> -
- ) : ( -
- -
- ); + const shouldShowLoader = + !toolCall || + (toolCall.tool_name === SEARCH_TOOL_NAME && query === undefined); + const defaultLoader = shouldShowLoader ? ( +
+ +
+ ) : undefined; return (
@@ -189,28 +193,61 @@ export const AIMessage = ({
- {query !== undefined && - handleShowRetrieved !== undefined && - isCurrentlyShowingRetrieved !== undefined && - !retrievalDisabled && ( -
- + {query !== undefined && + handleShowRetrieved !== undefined && + isCurrentlyShowingRetrieved !== undefined && + !retrievalDisabled && ( +
+ +
+ )} + {handleForceSearch && + content && + query === undefined && + !hasDocs && + !retrievalDisabled && ( +
+ +
+ )} + + )} + + {toolCall && + !TOOLS_WITH_CUSTOM_HANDLING.includes(toolCall.tool_name) && ( +
+ } + isRunning={!toolCall.tool_result || !content} />
)} - {handleForceSearch && - content && - query === undefined && - !hasDocs && - !retrievalDisabled && ( -
- + + {toolCall && + toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME && + !toolCall.tool_result && ( +
+ } + isRunning={!toolCall.tool_result} + />
)} @@ -260,7 +297,7 @@ export const AIMessage = ({ )} ) : isComplete ? null : ( - loader + defaultLoader )} {citedDocuments && citedDocuments.length > 0 && (
diff --git a/web/src/app/chat/tools/ToolRunningAnimation.tsx b/web/src/app/chat/tools/ToolRunningAnimation.tsx index bd0414295fe..139c9e92151 100644 --- a/web/src/app/chat/tools/ToolRunningAnimation.tsx +++ b/web/src/app/chat/tools/ToolRunningAnimation.tsx @@ -1,16 +1,18 @@ import { LoadingAnimation } from "@/components/Loading"; -export function ToolRunningAnimation({ +export function ToolRunDisplay({ toolName, toolLogo, + isRunning, }: { toolName: string; - toolLogo: JSX.Element; + toolLogo?: JSX.Element; + isRunning: boolean; }) { return ( -
+
{toolLogo} - + {isRunning ? : toolName}
); } diff --git a/web/src/components/admin/Layout.tsx b/web/src/components/admin/Layout.tsx index e403b00a168..1411d6c14b6 100644 --- a/web/src/components/admin/Layout.tsx +++ b/web/src/components/admin/Layout.tsx @@ -2,15 +2,12 @@ import { Header } from "@/components/header/Header"; import { AdminSidebar } from "@/components/admin/connectors/AdminSidebar"; import { NotebookIcon, - KeyIcon, UsersIcon, ThumbsUpIcon, BookmarkIcon, - CPUIcon, ZoomInIcon, RobotIcon, ConnectorIcon, - SlackIcon, } from "@/components/icons/icons"; import { User } from "@/lib/types"; import { @@ -19,13 +16,7 @@ import { getCurrentUserSS, } from "@/lib/userSS"; import { redirect } from "next/navigation"; -import { - FiCpu, - FiLayers, - FiPackage, - FiSettings, - FiSlack, -} from "react-icons/fi"; +import { FiCpu, FiPackage, FiSettings, FiSlack, FiTool } from "react-icons/fi"; export async function Layout({ children }: { children: React.ReactNode }) { const tasks = [getAuthTypeMetadataSS(), getCurrentUserSS()]; @@ -142,6 +133,15 @@ export async function Layout({ children }: { children: React.ReactNode }) { ), link: "/admin/bot", }, + { + name: ( +
+ +
Tools
+
+ ), + link: "/admin/tools", + }, ], }, { diff --git a/web/src/components/admin/connectors/Field.tsx b/web/src/components/admin/connectors/Field.tsx index 88e482d77e8..563a5263396 100644 --- a/web/src/components/admin/connectors/Field.tsx +++ b/web/src/components/admin/connectors/Field.tsx @@ -43,6 +43,10 @@ export function TextFormField({ disabled = false, autoCompleteDisabled = true, error, + defaultHeight, + isCode = false, + fontSize, + hideError, }: { name: string; label: string; @@ -54,7 +58,16 @@ export function TextFormField({ disabled?: boolean; autoCompleteDisabled?: boolean; error?: string; + defaultHeight?: string; + isCode?: boolean; + fontSize?: "text-sm" | "text-base" | "text-lg"; + hideError?: boolean; }) { + let heightString = defaultHeight || ""; + if (isTextArea && !heightString) { + heightString = "h-28"; + } + return (
@@ -64,18 +77,19 @@ export function TextFormField({ type={type} name={name} id={name} - className={ - ` - border - border-border - rounded - w-full - py-2 - px-3 - mt-1 - ${isTextArea ? " h-28" : ""} - ` + (disabled ? " bg-background-strong" : " bg-background-emphasis") - } + className={` + border + border-border + rounded + w-full + py-2 + px-3 + mt-1 + ${heightString} + ${fontSize} + ${disabled ? " bg-background-strong" : " bg-background-emphasis"} + ${isCode ? " font-mono" : ""} + `} disabled={disabled} placeholder={placeholder} autoComplete={autoCompleteDisabled ? "off" : undefined} @@ -84,11 +98,13 @@ export function TextFormField({ {error ? ( {error} ) : ( - + !hideError && ( + + ) )}
); diff --git a/web/src/lib/tools/edit.ts b/web/src/lib/tools/edit.ts new file mode 100644 index 00000000000..841870a93e7 --- /dev/null +++ b/web/src/lib/tools/edit.ts @@ -0,0 +1,111 @@ +import { MethodSpec, ToolSnapshot } from "./interfaces"; + +interface ApiResponse { + data: T | null; + error: string | null; +} + +export async function createCustomTool(toolData: { + name: string; + description?: string; + definition: Record; +}): Promise> { + try { + const response = await fetch("/api/admin/tool/custom", { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(toolData), + }); + + if (!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: null, error: `Failed to create tool: ${errorDetail}` }; + } + + const tool: ToolSnapshot = await response.json(); + return { data: tool, error: null }; + } catch (error) { + console.error("Error creating tool:", error); + return { data: null, error: "Error creating tool" }; + } +} + +export async function updateCustomTool( + toolId: number, + toolData: { + name?: string; + description?: string; + definition?: Record; + } +): Promise> { + try { + const response = await fetch(`/api/admin/tool/custom/${toolId}`, { + method: "PUT", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(toolData), + }); + + if (!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: null, error: `Failed to update tool: ${errorDetail}` }; + } + + const updatedTool: ToolSnapshot = await response.json(); + return { data: updatedTool, error: null }; + } catch (error) { + console.error("Error updating tool:", error); + return { data: null, error: "Error updating tool" }; + } +} + +export async function deleteCustomTool( + toolId: number +): Promise> { + try { + const response = await fetch(`/api/admin/tool/custom/${toolId}`, { + method: "DELETE", + headers: { + "Content-Type": "application/json", + }, + }); + + if 
(!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: false, error: `Failed to delete tool: ${errorDetail}` }; + } + + return { data: true, error: null }; + } catch (error) { + console.error("Error deleting tool:", error); + return { data: false, error: "Error deleting tool" }; + } +} + +export async function validateToolDefinition(toolData: { + definition: Record; +}): Promise> { + try { + const response = await fetch(`/api/admin/tool/custom/validate`, { + method: "POST", + headers: { + "Content-Type": "application/json", + }, + body: JSON.stringify(toolData), + }); + + if (!response.ok) { + const errorDetail = (await response.json()).detail; + return { data: null, error: errorDetail }; + } + + const responseJson = await response.json(); + return { data: responseJson.methods, error: null }; + } catch (error) { + console.error("Error validating tool:", error); + return { data: null, error: "Unexpected error validating tool definition" }; + } +} diff --git a/web/src/lib/tools/fetchTools.ts b/web/src/lib/tools/fetchTools.ts index 51969c6db7f..3ea6cd31f73 100644 --- a/web/src/lib/tools/fetchTools.ts +++ b/web/src/lib/tools/fetchTools.ts @@ -14,3 +14,21 @@ export async function fetchToolsSS(): Promise { return null; } } + +export async function fetchToolByIdSS( + toolId: string +): Promise { + try { + const response = await fetchSS(`/tool/${toolId}`); + if (!response.ok) { + throw new Error( + `Failed to fetch tool with ID ${toolId}: ${await response.text()}` + ); + } + const tool: ToolSnapshot = await response.json(); + return tool; + } catch (error) { + console.error(`Error fetching tool with ID ${toolId}:`, error); + return null; + } +} diff --git a/web/src/lib/tools/interfaces.ts b/web/src/lib/tools/interfaces.ts index f8882e6bfdb..bcb5df50a2a 100644 --- a/web/src/lib/tools/interfaces.ts +++ b/web/src/lib/tools/interfaces.ts @@ -2,5 +2,21 @@ export interface ToolSnapshot { id: number; name: string; description: string; + + // only 
specified for Custom Tools. OpenAPI schema which represents + // the tool's API. + definition: Record | null; + + // only specified for Custom Tools. ID of the tool in the codebase. in_code_tool_id: string | null; } + +export interface MethodSpec { + /* Defines a single method that is part of a custom tool. Each method maps to a single + action that the LLM can choose to take. */ + name: string; + summary: string; + path: string; + method: string; + spec: Record; +}