diff --git a/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py new file mode 100644 index 00000000000..514389ac216 --- /dev/null +++ b/backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py @@ -0,0 +1,61 @@ +"""Add support for custom tools + +Revision ID: 48d14957fe80 +Revises: b85f02ec1308 +Create Date: 2024-06-09 14:58:19.946509 + +""" +from alembic import op +import fastapi_users_db_sqlalchemy +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "48d14957fe80" +down_revision = "b85f02ec1308" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "tool", + sa.Column( + "openapi_schema", + postgresql.JSONB(astext_type=sa.Text()), + nullable=True, + ), + ) + op.add_column( + "tool", + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=True, + ), + ) + op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"]) + + op.create_table( + "tool_call", + sa.Column("id", sa.Integer(), primary_key=True), + sa.Column("tool_id", sa.Integer(), nullable=False), + sa.Column("tool_name", sa.String(), nullable=False), + sa.Column( + "tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.Column( + "tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False + ), + sa.Column( + "message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False + ), + ) + + +def downgrade() -> None: + op.drop_table("tool_call") + + op.drop_constraint("tool_user_fk", "tool", type_="foreignkey") + op.drop_column("tool", "user_id") + op.drop_column("tool", "openapi_schema") diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index 8fa5eecaeeb..7fc526a5cbe 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -106,12 +106,18 @@ class ImageGenerationDisplay(BaseModel): file_ids: list[str] +class CustomToolResponse(BaseModel): + response: dict + tool_name: str + + AnswerQuestionPossibleReturn = ( DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | ImageGenerationDisplay + | CustomToolResponse | StreamingError ) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 72381542176..7733bf523d6 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -7,6 +7,7 @@ from sqlalchemy.orm import Session from danswer.chat.chat_utils import create_chat_chain from danswer.chat.models import CitationInfo +from danswer.chat.models import CustomToolResponse from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import ImageGenerationDisplay from danswer.chat.models import LlmDoc @@ -31,6 +32,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager from danswer.db.llm import fetch_existing_llm_providers from danswer.db.models import SearchDoc as DbSearchDoc +from danswer.db.models import ToolCall from danswer.db.models import User from danswer.document_index.factory import get_default_document_index from danswer.file_store.models import ChatFileType @@ -54,7 +56,10 @@ from danswer.search.utils import drop_llm_indices from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import 
get_json_line -from danswer.tools.factory import get_tool_cls +from danswer.tools.built_in_tools import get_built_in_tool_by_id +from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema +from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID +from danswer.tools.custom.custom_tool import CustomToolCallSummary from danswer.tools.force import ForceUseTool from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID from danswer.tools.images.image_generation_tool import ImageGenerationResponse @@ -65,6 +70,7 @@ from danswer.tools.search.search_tool import SearchTool from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse +from danswer.tools.tool_runner import ToolCallFinalResult from danswer.tools.utils import compute_all_tool_tokens from danswer.tools.utils import explicit_tool_calling_supported from danswer.utils.logger import setup_logger @@ -162,7 +168,7 @@ def _check_should_force_search( args = {"query": new_msg_req.message} return ForceUseTool( - tool_name=SearchTool.name(), + tool_name=SearchTool.NAME, args=args, ) return None @@ -176,6 +182,7 @@ ChatPacket = ( | DanswerAnswerPiece | CitationInfo | ImageGenerationDisplay + | CustomToolResponse ) ChatPacketStream = Iterator[ChatPacket] @@ -389,61 +396,78 @@ def stream_chat_message_objects( ), ) - persona_tool_classes = [ - get_tool_cls(tool, db_session) for tool in persona.tools - ] + # find out what tools to use + search_tool: SearchTool | None = None + tool_dict: dict[int, list[Tool]] = {} # tool_id to tool + for db_tool_model in persona.tools: + # handle in-code tools specially + if db_tool_model.in_code_tool_id: + tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session) + if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files: + search_tool = SearchTool( + db_session=db_session, + user=user, + persona=persona, + retrieval_options=retrieval_options, + prompt_config=prompt_config, + llm=llm, + pruning_config=document_pruning_config, + selected_docs=selected_llm_docs, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, + ) + tool_dict[db_tool_model.id] = [search_tool] + elif tool_cls.__name__ == ImageGenerationTool.__name__: + dalle_key = None + if ( + llm + and llm.config.api_key + and llm.config.model_provider == "openai" + ): + dalle_key = llm.config.api_key + else: + llm_providers = fetch_existing_llm_providers(db_session) + openai_provider = next( + iter( + [ + llm_provider + for llm_provider in llm_providers + if llm_provider.provider == "openai" + ] + ), + None, + ) + if not openai_provider or not openai_provider.api_key: + raise ValueError( + "Image generation tool requires an OpenAI API key" + ) + dalle_key = openai_provider.api_key + tool_dict[db_tool_model.id] = [ + ImageGenerationTool(api_key=dalle_key) + ] + + continue + + # handle all custom tools + if db_tool_model.openapi_schema: + tool_dict[db_tool_model.id] = cast( + list[Tool], + build_custom_tools_from_openapi_schema( + db_tool_model.openapi_schema + ), + ) + + tools: list[Tool] = [] + for tool_list in tool_dict.values(): + tools.extend(tool_list) # factor in tool definition size when pruning - document_pruning_config.tool_num_tokens = compute_all_tool_tokens( - persona_tool_classes - ) + document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools) document_pruning_config.using_tool_message = 
explicit_tool_calling_supported( llm.config.model_provider, llm.config.model_name ) - # NOTE: for now, only support SearchTool and ImageGenerationTool - # in the future, will support arbitrary user-defined tools - search_tool: SearchTool | None = None - tools: list[Tool] = [] - for tool_cls in persona_tool_classes: - if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files: - search_tool = SearchTool( - db_session=db_session, - user=user, - persona=persona, - retrieval_options=retrieval_options, - prompt_config=prompt_config, - llm=llm, - pruning_config=document_pruning_config, - selected_docs=selected_llm_docs, - chunks_above=new_msg_req.chunks_above, - chunks_below=new_msg_req.chunks_below, - full_doc=new_msg_req.full_doc, - ) - tools.append(search_tool) - elif tool_cls.__name__ == ImageGenerationTool.__name__: - dalle_key = None - if llm and llm.config.api_key and llm.config.model_provider == "openai": - dalle_key = llm.config.api_key - else: - llm_providers = fetch_existing_llm_providers(db_session) - openai_provider = next( - iter( - [ - llm_provider - for llm_provider in llm_providers - if llm_provider.provider == "openai" - ] - ), - None, - ) - if not openai_provider or not openai_provider.api_key: - raise ValueError( - "Image generation tool requires an OpenAI API key" - ) - dalle_key = openai_provider.api_key - tools.append(ImageGenerationTool(api_key=dalle_key)) - # LLM prompt building, response capturing, etc. answer = Answer( question=final_msg.message, @@ -468,7 +492,9 @@ def stream_chat_message_objects( ], tools=tools, force_use_tool=( - _check_should_force_search(new_msg_req) if search_tool else None + _check_should_force_search(new_msg_req) + if search_tool and len(tools) == 1 + else None ), ) @@ -476,6 +502,7 @@ def stream_chat_message_objects( qa_docs_response = None ai_message_files = None # any files to associate with the AI message e.g. 
dall-e generated images dropped_indices = None + tool_result = None for packet in answer.processed_streamed_output: if isinstance(packet, ToolResponse): if packet.id == SEARCH_RESPONSE_SUMMARY_ID: @@ -521,8 +548,16 @@ def stream_chat_message_objects( yield ImageGenerationDisplay( file_ids=[str(file_id) for file_id in file_ids] ) + elif packet.id == CUSTOM_TOOL_RESPONSE_ID: + custom_tool_response = cast(CustomToolCallSummary, packet.response) + yield CustomToolResponse( + response=custom_tool_response.tool_result, + tool_name=custom_tool_response.tool_name, + ) else: + if isinstance(packet, ToolCallFinalResult): + tool_result = packet yield cast(ChatPacket, packet) except Exception as e: @@ -551,6 +586,11 @@ def stream_chat_message_objects( ) # Saving Gen AI answer and responding with message info + tool_name_to_tool_id: dict[str, int] = {} + for tool_id, tool_list in tool_dict.items(): + for tool in tool_list: + tool_name_to_tool_id[tool.name()] = tool_id + gen_ai_response_message = partial_response( message=answer.llm_answer, rephrased_query=( @@ -561,6 +601,16 @@ def stream_chat_message_objects( token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, error=None, + tool_calls=[ + ToolCall( + tool_id=tool_name_to_tool_id[tool_result.tool_name], + tool_name=tool_result.tool_name, + tool_arguments=tool_result.tool_args, + tool_result=tool_result.tool_result, + ) + ] + if tool_result + else [], ) db_session.commit() # actually save user / assistant message diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 58779288022..5ee9dfb3f51 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -5,6 +5,7 @@ from sqlalchemy import nullsfirst from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.exc import MultipleResultsFound +from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from danswer.auth.schemas import UserRole @@ -16,6 +17,7 @@ from danswer.db.models import ChatSessionSharedStatus from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc +from danswer.db.models import ToolCall from danswer.db.models import User from danswer.file_store.models import FileDescriptor from danswer.llm.override_models import LLMOverride @@ -24,6 +26,7 @@ from danswer.search.models import RetrievalDocs from danswer.search.models import SavedSearchDoc from danswer.search.models import SearchDoc as ServerSearchDoc from danswer.server.query_and_chat.models import ChatMessageDetail +from danswer.tools.tool_runner import ToolCallFinalResult from danswer.utils.logger import setup_logger logger = setup_logger() @@ -185,6 +188,7 @@ def get_chat_messages_by_session( user_id: UUID | None, db_session: Session, skip_permission_check: bool = False, + prefetch_tool_calls: bool = False, ) -> list[ChatMessage]: if not skip_permission_check: get_chat_session_by_id( @@ -192,12 +196,18 @@ def get_chat_messages_by_session( ) stmt = ( - select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id) - # Start with the root message which has no parent + select(ChatMessage) + .where(ChatMessage.chat_session_id == chat_session_id) .order_by(nullsfirst(ChatMessage.parent_message)) ) - result = db_session.execute(stmt).scalars().all() + if prefetch_tool_calls: + stmt = stmt.options(joinedload(ChatMessage.tool_calls)) + + if prefetch_tool_calls: + result = db_session.scalars(stmt).unique().all() + else: + result = db_session.scalars(stmt).all() 
return list(result) @@ -251,6 +261,7 @@ def create_new_chat_message( reference_docs: list[DBSearchDoc] | None = None, # Maps the citation number [n] to the DB SearchDoc citations: dict[int, int] | None = None, + tool_calls: list[ToolCall] | None = None, commit: bool = True, ) -> ChatMessage: new_chat_message = ChatMessage( @@ -264,6 +275,7 @@ def create_new_chat_message( message_type=message_type, citations=citations, files=files, + tool_calls=tool_calls if tool_calls else [], error=error, ) @@ -459,6 +471,14 @@ def translate_db_message_to_chat_message_detail( time_sent=chat_message.time_sent, citations=chat_message.citations, files=chat_message.files or [], + tool_calls=[ + ToolCallFinalResult( + tool_name=tool_call.tool_name, + tool_args=tool_call.tool_arguments, + tool_result=tool_call.tool_result, + ) + for tool_call in chat_message.tool_calls + ], ) return chat_msg_detail diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 3b4b67f0947..661eb04a8c8 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -133,6 +133,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base): prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user") # Personas owned by this user personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") + # Custom tools created by this user + custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user") class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): @@ -330,7 +332,6 @@ class Document(Base): primary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) - # Something like assignee or space owner secondary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) @@ -618,6 +619,26 @@ class SearchDoc(Base): ) +class ToolCall(Base): + """Represents a single tool call""" + + __tablename__ = "tool_call" + + id: Mapped[int] = mapped_column(primary_key=True) + # not a FK because we want to be able to delete the tool without deleting + # this entry + tool_id: Mapped[int] = mapped_column(Integer()) + tool_name: Mapped[str] = mapped_column(String()) + tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB()) + tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB()) + + message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id")) + + message: Mapped["ChatMessage"] = relationship( + "ChatMessage", back_populates="tool_calls" + ) + + class ChatSession(Base): __tablename__ = "chat_session" @@ -723,6 +744,10 @@ class ChatMessage(Base): secondary="chat_message__search_doc", back_populates="chat_messages", ) + tool_calls: Mapped[list["ToolCall"]] = relationship( + "ToolCall", + back_populates="message", + ) class ChatFolder(Base): @@ -901,9 +926,18 @@ class Tool(Base): name: Mapped[str] = mapped_column(String, nullable=False) description: Mapped[str] = mapped_column(Text, nullable=True) # ID of the tool in the codebase, only applies for in-code tools. - # tools defiend via the UI will have this as None + # tools defined via the UI will have this as None in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True) + # OpenAPI scheme for the tool. Only applies to tools defined via the UI. + openapi_schema: Mapped[dict[str, Any] | None] = mapped_column( + postgresql.JSONB(), nullable=True + ) + + # user who created / owns the tool. Will be None for built-in tools. 
+ user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + + user: Mapped[User | None] = relationship("User", back_populates="custom_tools") # Relationship to Persona through the association table personas: Mapped[list["Persona"]] = relationship( "Persona", diff --git a/backend/danswer/db/tools.py b/backend/danswer/db/tools.py new file mode 100644 index 00000000000..1e75b1c4901 --- /dev/null +++ b/backend/danswer/db/tools.py @@ -0,0 +1,74 @@ +from typing import Any +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.models import Tool +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def get_tools(db_session: Session) -> list[Tool]: + return list(db_session.scalars(select(Tool)).all()) + + +def get_tool_by_id(tool_id: int, db_session: Session) -> Tool: + tool = db_session.scalar(select(Tool).where(Tool.id == tool_id)) + if not tool: + raise ValueError("Tool by specified id does not exist") + return tool + + +def create_tool( + name: str, + description: str | None, + openapi_schema: dict[str, Any] | None, + user_id: UUID | None, + db_session: Session, +) -> Tool: + new_tool = Tool( + name=name, + description=description, + in_code_tool_id=None, + openapi_schema=openapi_schema, + user_id=user_id, + ) + db_session.add(new_tool) + db_session.commit() + return new_tool + + +def update_tool( + tool_id: int, + name: str | None, + description: str | None, + openapi_schema: dict[str, Any] | None, + user_id: UUID | None, + db_session: Session, +) -> Tool: + tool = get_tool_by_id(tool_id, db_session) + if tool is None: + raise ValueError(f"Tool with ID {tool_id} does not exist") + + if name is not None: + tool.name = name + if description is not None: + tool.description = description + if openapi_schema is not None: + tool.openapi_schema = openapi_schema + if user_id is not None: + tool.user_id = user_id + db_session.commit() + + return tool + + +def delete_tool(tool_id: int, db_session: Session) -> None: + tool = get_tool_by_id(tool_id, db_session) + if tool is None: + raise ValueError(f"Tool with ID {tool_id} does not exist") + + db_session.delete(tool) + db_session.commit() diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index 82c027304ae..4b849f70d96 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -24,7 +24,7 @@ def load_chat_file( file_id=file_descriptor["id"], content=file_io.read(), file_type=file_descriptor["type"], - filename=file_descriptor["name"], + filename=file_descriptor.get("name"), ) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index 41f8e109033..328bb7251da 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -4,6 +4,7 @@ from uuid import uuid4 from langchain.schema.messages import BaseMessage from langchain_core.messages import AIMessageChunk +from langchain_core.messages import HumanMessage from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import AnswerQuestionPossibleReturn @@ -33,6 +34,9 @@ from danswer.llm.answering.stream_processing.quotes_processing import ( from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import message_generator_to_string_generator +from danswer.tools.custom.custom_tool_prompt_builder import ( + 
build_user_message_for_custom_tool_for_non_tool_calling_llm, +) from danswer.tools.force import filter_tools_for_force_tool_use from danswer.tools.force import ForceUseTool from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID @@ -50,7 +54,8 @@ from danswer.tools.tool import ToolResponse from danswer.tools.tool_runner import ( check_which_tools_should_run_for_non_tool_calling_llm, ) -from danswer.tools.tool_runner import ToolRunKickoff +from danswer.tools.tool_runner import ToolCallFinalResult +from danswer.tools.tool_runner import ToolCallKickoff from danswer.tools.tool_runner import ToolRunner from danswer.tools.utils import explicit_tool_calling_supported @@ -72,7 +77,7 @@ def _get_answer_stream_processor( raise RuntimeError("Not implemented yet") -AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolRunKickoff | ToolResponse] +AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse] class Answer: @@ -125,7 +130,7 @@ class Answer: self._streamed_output: list[str] | None = None self._processed_stream: list[ - AnswerQuestionPossibleReturn | ToolResponse | ToolRunKickoff + AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff ] | None = None def _update_prompt_builder_for_search_tool( @@ -160,7 +165,7 @@ class Answer: def _raw_output_for_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolRunKickoff | ToolResponse]: + ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) tool_call_chunk: AIMessageChunk | None = None @@ -237,16 +242,18 @@ class Answer: ), ) - if tool.name() == SearchTool.name(): + if tool.name() == SearchTool.NAME: self._update_prompt_builder_for_search_tool(prompt_builder, []) - elif tool.name() == ImageGenerationTool.name(): + elif tool.name() == ImageGenerationTool.NAME: prompt_builder.update_user_prompt( build_image_generation_user_prompt( query=self.question, ) ) - prompt = prompt_builder.build(tool_call_summary=tool_call_summary) + yield tool_runner.tool_final_result() + + prompt = prompt_builder.build(tool_call_summary=tool_call_summary) yield from message_generator_to_string_generator( self.llm.stream( prompt=prompt, @@ -258,7 +265,7 @@ class Answer: def _raw_output_for_non_explicit_tool_calling_llms( self, - ) -> Iterator[str | ToolRunKickoff | ToolResponse]: + ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]: prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) chosen_tool_and_args: tuple[Tool, dict] | None = None @@ -324,7 +331,7 @@ class Answer: tool_runner = ToolRunner(tool, tool_args) yield tool_runner.kickoff() - if tool.name() == SearchTool.name(): + if tool.name() == SearchTool.NAME: final_context_documents = None for response in tool_runner.tool_responses(): if response.id == FINAL_CONTEXT_DOCUMENTS: @@ -337,7 +344,7 @@ class Answer: self._update_prompt_builder_for_search_tool( prompt_builder, final_context_documents ) - elif tool.name() == ImageGenerationTool.name(): + elif tool.name() == ImageGenerationTool.NAME: img_urls = [] for response in tool_runner.tool_responses(): if response.id == IMAGE_GENERATION_RESPONSE_ID: @@ -354,6 +361,18 @@ class Answer: img_urls=img_urls, ) ) + else: + prompt_builder.update_user_prompt( + HumanMessage( + content=build_user_message_for_custom_tool_for_non_tool_calling_llm( + self.question, + tool.name(), + *tool_runner.tool_responses(), + ) + ) + ) + + yield tool_runner.tool_final_result() prompt = 
prompt_builder.build() yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt)) @@ -374,7 +393,7 @@ class Answer: ) def _process_stream( - stream: Iterator[ToolRunKickoff | ToolResponse | str], + stream: Iterator[ToolCallKickoff | ToolResponse | str], ) -> AnswerStream: message = None @@ -387,7 +406,9 @@ class Answer: ] | None = None # processed docs to feed into the LLM for message in stream: - if isinstance(message, ToolRunKickoff): + if isinstance(message, ToolCallKickoff) or isinstance( + message, ToolCallFinalResult + ): yield message elif isinstance(message, ToolResponse): if message.id == SEARCH_RESPONSE_SUMMARY_ID: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 8757d6d5588..e8988606714 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -63,6 +63,7 @@ from danswer.server.features.folder.api import router as folder_router from danswer.server.features.persona.api import admin_router as admin_persona_router from danswer.server.features.persona.api import basic_router as persona_router from danswer.server.features.prompt.api import basic_router as prompt_router +from danswer.server.features.tool.api import admin_router as admin_tool_router from danswer.server.features.tool.api import router as tool_router from danswer.server.gpts.api import router as gpts_router from danswer.server.manage.administrative import router as admin_router @@ -277,6 +278,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, admin_persona_router) include_router_with_global_prefix_prepended(application, prompt_router) include_router_with_global_prefix_prepended(application, tool_router) + include_router_with_global_prefix_prepended(application, admin_tool_router) include_router_with_global_prefix_prepended(application, state_router) include_router_with_global_prefix_prepended(application, danswer_api_router) include_router_with_global_prefix_prepended(application, gpts_router) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 7c0c39544c2..8aff6676b88 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -51,7 +51,7 @@ from danswer.tools.search.search_tool import SearchResponseSummary from danswer.tools.search.search_tool import SearchTool from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID from danswer.tools.tool import ToolResponse -from danswer.tools.tool_runner import ToolRunKickoff +from danswer.tools.tool_runner import ToolCallKickoff from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -67,7 +67,7 @@ AnswerObjectIterator = Iterator[ | StreamingError | ChatMessageDetail | CitationInfo - | ToolRunKickoff + | ToolCallKickoff ] diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py index 0a9666646a4..b1f57a1a924 100644 --- a/backend/danswer/server/features/tool/api.py +++ b/backend/danswer/server/features/tool/api.py @@ -1,32 +1,132 @@ +from typing import Any + from fastapi import APIRouter from fastapi import Depends +from fastapi import HTTPException from pydantic import BaseModel -from sqlalchemy import select from sqlalchemy.orm import Session +from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.db.engine import get_session -from danswer.db.models import Tool from danswer.db.models 
import User - +from danswer.db.tools import create_tool +from danswer.db.tools import delete_tool +from danswer.db.tools import get_tool_by_id +from danswer.db.tools import get_tools +from danswer.db.tools import update_tool +from danswer.server.features.tool.models import ToolSnapshot +from danswer.tools.custom.openapi_parsing import MethodSpec +from danswer.tools.custom.openapi_parsing import openapi_to_method_specs +from danswer.tools.custom.openapi_parsing import validate_openapi_schema router = APIRouter(prefix="/tool") +admin_router = APIRouter(prefix="/admin/tool") -class ToolSnapshot(BaseModel): - id: int +class CustomToolCreate(BaseModel): name: str - description: str - in_code_tool_id: str | None + description: str | None + definition: dict[str, Any] - @classmethod - def from_model(cls, tool: Tool) -> "ToolSnapshot": - return cls( - id=tool.id, - name=tool.name, - description=tool.description, - in_code_tool_id=tool.in_code_tool_id, - ) + +class CustomToolUpdate(BaseModel): + name: str | None + description: str | None + definition: dict[str, Any] | None + + +def _validate_tool_definition(definition: dict[str, Any]) -> None: + try: + validate_openapi_schema(definition) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@admin_router.post("/custom") +def create_custom_tool( + tool_data: CustomToolCreate, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_admin_user), +) -> ToolSnapshot: + _validate_tool_definition(tool_data.definition) + tool = create_tool( + name=tool_data.name, + description=tool_data.description, + openapi_schema=tool_data.definition, + user_id=user.id if user else None, + db_session=db_session, + ) + return ToolSnapshot.from_model(tool) + + +@admin_router.put("/custom/{tool_id}") +def update_custom_tool( + tool_id: int, + tool_data: CustomToolUpdate, + db_session: Session = Depends(get_session), + user: User | None = Depends(current_admin_user), +) -> ToolSnapshot: + if tool_data.definition: + _validate_tool_definition(tool_data.definition) + updated_tool = update_tool( + tool_id=tool_id, + name=tool_data.name, + description=tool_data.description, + openapi_schema=tool_data.definition, + user_id=user.id if user else None, + db_session=db_session, + ) + return ToolSnapshot.from_model(updated_tool) + + +@admin_router.delete("/custom/{tool_id}") +def delete_custom_tool( + tool_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_admin_user), +) -> None: + try: + delete_tool(tool_id, db_session) + except ValueError as e: + raise HTTPException(status_code=404, detail=str(e)) + except Exception as e: + # handles case where tool is still used by an Assistant + raise HTTPException(status_code=400, detail=str(e)) + + +class ValidateToolRequest(BaseModel): + definition: dict[str, Any] + + +class ValidateToolResponse(BaseModel): + methods: list[MethodSpec] + + +@admin_router.post("/custom/validate") +def validate_tool( + tool_data: ValidateToolRequest, + _: User | None = Depends(current_admin_user), +) -> ValidateToolResponse: + _validate_tool_definition(tool_data.definition) + method_specs = openapi_to_method_specs(tool_data.definition) + return ValidateToolResponse(methods=method_specs) + + +"""Endpoints for all""" + + +@router.get("/{tool_id}") +def get_custom_tool( + tool_id: int, + db_session: Session = Depends(get_session), + _: User | None = Depends(current_user), +) -> ToolSnapshot: + try: + tool = get_tool_by_id(tool_id, db_session) + except ValueError as e: + 
raise HTTPException(status_code=404, detail=str(e)) + return ToolSnapshot.from_model(tool) @router.get("") @@ -34,5 +134,5 @@ def list_tools( db_session: Session = Depends(get_session), _: User | None = Depends(current_user), ) -> list[ToolSnapshot]: - tools = db_session.execute(select(Tool)).scalars().all() + tools = get_tools(db_session) return [ToolSnapshot.from_model(tool) for tool in tools] diff --git a/backend/danswer/server/features/tool/models.py b/backend/danswer/server/features/tool/models.py new file mode 100644 index 00000000000..feb3ba68269 --- /dev/null +++ b/backend/danswer/server/features/tool/models.py @@ -0,0 +1,23 @@ +from typing import Any + +from pydantic import BaseModel + +from danswer.db.models import Tool + + +class ToolSnapshot(BaseModel): + id: int + name: str + description: str + definition: dict[str, Any] | None + in_code_tool_id: str | None + + @classmethod + def from_model(cls, tool: Tool) -> "ToolSnapshot": + return cls( + id=tool.id, + name=tool.name, + description=tool.description, + definition=tool.openapi_schema, + in_code_tool_id=tool.in_code_tool_id, + ) diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index d20a4b11101..834453e6e2d 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -129,6 +129,8 @@ def get_chat_session( # we already did a permission check above with the call to # `get_chat_session_by_id`, so we can skip it here skip_permission_check=True, + # we need the tool call objs anyways, so just fetch them in a single call + prefetch_tool_calls=True, ) return ChatSessionDetailResponse( diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 44e8ab84624..09561bf24f8 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -17,6 +17,7 @@ from danswer.search.models import ChunkContext from danswer.search.models import RetrievalDetails from danswer.search.models import SearchDoc from danswer.search.models import Tag +from danswer.tools.models import ToolCallFinalResult class SourceTag(Tag): @@ -176,6 +177,7 @@ class ChatMessageDetail(BaseModel): # Dict mapping citation number to db_doc_id citations: dict[int, int] | None files: list[FileDescriptor] + tool_calls: list[ToolCallFinalResult] def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore initial_dict = super().dict(*args, **kwargs) # type: ignore diff --git a/backend/danswer/tools/custom/custom_tool.py b/backend/danswer/tools/custom/custom_tool.py new file mode 100644 index 00000000000..ea232fc5a74 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool.py @@ -0,0 +1,233 @@ +import json +from collections.abc import Generator +from typing import Any +from typing import cast + +import requests +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage +from pydantic import BaseModel + +from danswer.dynamic_configs.interface import JSON_ro +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM +from danswer.tools.custom.custom_tool_prompts import ( + SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT, +) +from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT +from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT +from danswer.tools.custom.custom_tool_prompts import 
TOOL_ARG_USER_PROMPT +from danswer.tools.custom.custom_tool_prompts import USE_TOOL +from danswer.tools.custom.openapi_parsing import MethodSpec +from danswer.tools.custom.openapi_parsing import openapi_to_method_specs +from danswer.tools.custom.openapi_parsing import openapi_to_url +from danswer.tools.custom.openapi_parsing import REQUEST_BODY +from danswer.tools.custom.openapi_parsing import validate_openapi_schema +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.utils.logger import setup_logger + +logger = setup_logger() + +CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response" + + +class CustomToolCallSummary(BaseModel): + tool_name: str + tool_result: dict + + +class CustomTool(Tool): + def __init__(self, method_spec: MethodSpec, base_url: str) -> None: + self._base_url = base_url + self._method_spec = method_spec + self._tool_definition = self._method_spec.to_tool_definition() + + self._name = self._method_spec.name + self.description = self._method_spec.summary + + def name(self) -> str: + return self._name + + """For LLMs which support explicit tool calling""" + + def tool_definition(self) -> dict: + return self._tool_definition + + def build_tool_message_content( + self, *args: ToolResponse + ) -> str | list[str | dict[str, Any]]: + response = cast(CustomToolCallSummary, args[0].response) + return json.dumps(response.tool_result) + + """For LLMs which do NOT support explicit tool calling""" + + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + if not force_run: + should_use_result = llm.invoke( + [ + SystemMessage(content=SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT), + HumanMessage( + content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format( + history=history, + query=query, + tool_name=self.name(), + tool_description=self.description, + ) + ), + ] + ) + if cast(str, should_use_result.content).strip() != USE_TOOL: + return None + + args_result = llm.invoke( + [ + SystemMessage(content=TOOL_ARG_SYSTEM_PROMPT), + HumanMessage( + content=TOOL_ARG_USER_PROMPT.format( + history=history, + query=query, + tool_name=self.name(), + tool_description=self.description, + tool_args=self.tool_definition()["function"]["parameters"], + ) + ), + ] + ) + args_result_str = cast(str, args_result.content) + + try: + return json.loads(args_result_str.strip()) + except json.JSONDecodeError: + pass + + # try removing ``` + try: + return json.loads(args_result_str.strip("```")) + except json.JSONDecodeError: + pass + + # try removing ```json + try: + return json.loads(args_result_str.strip("```").strip("json")) + except json.JSONDecodeError: + pass + + # pretend like nothing happened if not parse-able + logger.error( + f"Failed to parse args for '{self.name()}' tool. 
Recieved: {args_result_str}" + ) + return None + + """Actual execution of the tool""" + + def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + request_body = kwargs.get(REQUEST_BODY) + + path_params = {} + for path_param_schema in self._method_spec.get_path_param_schemas(): + path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]] + + query_params = {} + for query_param_schema in self._method_spec.get_query_param_schemas(): + if query_param_schema["name"] in kwargs: + query_params[query_param_schema["name"]] = kwargs[ + query_param_schema["name"] + ] + + url = self._method_spec.build_url(self._base_url, path_params, query_params) + method = self._method_spec.method + + response = requests.request(method, url, json=request_body) + + yield ToolResponse( + id=CUSTOM_TOOL_RESPONSE_ID, + response=CustomToolCallSummary( + tool_name=self._name, tool_result=response.json() + ), + ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + return cast(CustomToolCallSummary, args[0].response).tool_result + + +def build_custom_tools_from_openapi_schema( + openapi_schema: dict[str, Any] +) -> list[CustomTool]: + url = openapi_to_url(openapi_schema) + method_specs = openapi_to_method_specs(openapi_schema) + return [CustomTool(method_spec, url) for method_spec in method_specs] + + +if __name__ == "__main__": + import openai + + openapi_schema = { + "openapi": "3.0.0", + "info": { + "version": "1.0.0", + "title": "Assistants API", + "description": "An API for managing assistants", + }, + "servers": [ + {"url": "http://localhost:8080"}, + ], + "paths": { + "/assistant/{assistant_id}": { + "get": { + "summary": "Get a specific Assistant", + "operationId": "getAssistant", + "parameters": [ + { + "name": "assistant_id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + }, + "post": { + "summary": "Create a new Assistant", + "operationId": "createAssistant", + "parameters": [ + { + "name": "assistant_id", + "in": "path", + "required": True, + "schema": {"type": "string"}, + } + ], + "requestBody": { + "required": True, + "content": {"application/json": {"schema": {"type": "object"}}}, + }, + }, + } + }, + } + validate_openapi_schema(openapi_schema) + + tools = build_custom_tools_from_openapi_schema(openapi_schema) + + openai_client = openai.OpenAI() + response = openai_client.chat.completions.create( + model="gpt-4o", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Can you fetch assistant with ID 10"}, + ], + tools=[tool.tool_definition() for tool in tools], # type: ignore + ) + choice = response.choices[0] + if choice.message.tool_calls: + print(choice.message.tool_calls) + for tool_response in tools[0].run( + **json.loads(choice.message.tool_calls[0].function.arguments) + ): + print(tool_response) diff --git a/backend/danswer/tools/custom/custom_tool_prompt_builder.py b/backend/danswer/tools/custom/custom_tool_prompt_builder.py new file mode 100644 index 00000000000..8016363acc9 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool_prompt_builder.py @@ -0,0 +1,21 @@ +from typing import cast + +from danswer.tools.custom.custom_tool import CustomToolCallSummary +from danswer.tools.models import ToolResponse + + +def build_user_message_for_custom_tool_for_non_tool_calling_llm( + query: str, + tool_name: str, + *args: ToolResponse, +) -> str: + tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result + return f""" +Here's the result from the {tool_name} 
tool: + +{tool_run_summary} + +Now respond to the following: + +{query} +""".strip() diff --git a/backend/danswer/tools/custom/custom_tool_prompts.py b/backend/danswer/tools/custom/custom_tool_prompts.py new file mode 100644 index 00000000000..14e8b007ef0 --- /dev/null +++ b/backend/danswer/tools/custom/custom_tool_prompts.py @@ -0,0 +1,57 @@ +from danswer.prompts.constants import GENERAL_SEP_PAT + +DONT_USE_TOOL = "Don't use tool" +USE_TOOL = "Use tool" + + +"""Prompts to determine if we should use a custom tool or not.""" + + +SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT = ( + "You are a large language model whose only job is to determine if the system should call an " + "external tool to be able to answer the user's last message." +).strip() + +SHOULD_USE_CUSTOM_TOOL_USER_PROMPT = f""" +Given the conversation history and a follow up query, determine if the system should use the \ +'{{tool_name}}' tool to answer the user's query. The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. + +Respond with "{USE_TOOL}" if you think the tool would be helpful in respnding to the users query. +Respond with "{DONT_USE_TOOL}" otherwise. + +Conversation History: +{GENERAL_SEP_PAT} +{{history}} +{GENERAL_SEP_PAT} + +If you are at all unsure, respond with {DONT_USE_TOOL}. +Respond with EXACTLY and ONLY "{DONT_USE_TOOL}" or "{USE_TOOL}" + +Follow up input: +{{query}} +""".strip() + + +"""Prompts to figure out the arguments to pass to a custom tool.""" + + +TOOL_ARG_SYSTEM_PROMPT = ( + "You are a large language model whose only job is to determine the arguments to pass to an " + "external tool." +).strip() + + +TOOL_ARG_USER_PROMPT = f""" +Given the following conversation and a follow up input, generate a \ +dictionary of arguments to pass to the '{{tool_name}}' tool. \ +The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. \ +The expected arguments are: {{tool_args}}. + +Conversation: +{{history}} + +Follow up input: +{{query}} + +Respond with ONLY and EXACTLY a JSON object specifying the values of the arguments to pass to the tool. +""".strip() # noqa: F541 diff --git a/backend/danswer/tools/custom/openapi_parsing.py b/backend/danswer/tools/custom/openapi_parsing.py new file mode 100644 index 00000000000..40ed5544d8b --- /dev/null +++ b/backend/danswer/tools/custom/openapi_parsing.py @@ -0,0 +1,225 @@ +from typing import Any +from typing import cast + +from openai import BaseModel + +REQUEST_BODY = "requestBody" + + +class PathSpec(BaseModel): + path: str + methods: dict[str, Any] + + +class MethodSpec(BaseModel): + name: str + summary: str + path: str + method: str + spec: dict[str, Any] + + def get_request_body_schema(self) -> dict[str, Any]: + content = self.spec.get("requestBody", {}).get("content", {}) + if "application/json" in content: + return content["application/json"].get("schema") + + if content: + raise ValueError( + f"Unsupported content type: '{list(content.keys())[0]}'. " + f"Only 'application/json' is supported." 
+ ) + + return {} + + def get_query_param_schemas(self) -> list[dict[str, Any]]: + return [ + param + for param in self.spec.get("parameters", []) + if "schema" in param and "in" in param and param["in"] == "query" + ] + + def get_path_param_schemas(self) -> list[dict[str, Any]]: + return [ + param + for param in self.spec.get("parameters", []) + if "schema" in param and "in" in param and param["in"] == "path" + ] + + def build_url( + self, base_url: str, path_params: dict[str, str], query_params: dict[str, str] + ) -> str: + url = f"{base_url}{self.path}" + try: + url = url.format(**path_params) + except KeyError as e: + raise ValueError(f"Missing path parameter: {e}") + if query_params: + url += "?" + for param, value in query_params.items(): + url += f"{param}={value}&" + url = url[:-1] + return url + + def to_tool_definition(self) -> dict[str, Any]: + tool_definition: Any = { + "type": "function", + "function": { + "name": self.name, + "description": self.summary, + "parameters": {"type": "object", "properties": {}}, + }, + } + + request_body_schema = self.get_request_body_schema() + if request_body_schema: + tool_definition["function"]["parameters"]["properties"][ + REQUEST_BODY + ] = request_body_schema + + query_param_schemas = self.get_query_param_schemas() + if query_param_schemas: + tool_definition["function"]["parameters"]["properties"].update( + {param["name"]: param["schema"] for param in query_param_schemas} + ) + + path_param_schemas = self.get_path_param_schemas() + if path_param_schemas: + tool_definition["function"]["parameters"]["properties"].update( + {param["name"]: param["schema"] for param in path_param_schemas} + ) + return tool_definition + + def validate_spec(self) -> None: + # Validate url construction + path_param_schemas = self.get_path_param_schemas() + dummy_path_dict = {param["name"]: "value" for param in path_param_schemas} + query_param_schemas = self.get_query_param_schemas() + dummy_query_dict = {param["name"]: "value" for param in query_param_schemas} + self.build_url("", dummy_path_dict, dummy_query_dict) + + # Make sure request body doesn't throw an exception + self.get_request_body_schema() + + # Ensure the method is valid + if not self.method: + raise ValueError("HTTP method is not specified.") + if self.method.upper() not in ["GET", "POST", "PUT", "DELETE", "PATCH"]: + raise ValueError(f"HTTP method '{self.method}' is not supported.") + + +"""Path-level utils""" + + +def openapi_to_path_specs(openapi_spec: dict[str, Any]) -> list[PathSpec]: + path_specs = [] + + for path, methods in openapi_spec.get("paths", {}).items(): + path_specs.append(PathSpec(path=path, methods=methods)) + + return path_specs + + +"""Method-level utils""" + + +def openapi_to_method_specs(openapi_spec: dict[str, Any]) -> list[MethodSpec]: + path_specs = openapi_to_path_specs(openapi_spec) + + method_specs = [] + for path_spec in path_specs: + for method_name, method in path_spec.methods.items(): + name = method.get("operationId") + if not name: + raise ValueError( + f"Operation ID is not specified for {method_name.upper()} {path_spec.path}" + ) + + summary = method.get("summary") or method.get("description") + if not summary: + raise ValueError( + f"Summary is not specified for {method_name.upper()} {path_spec.path}" + ) + + method_specs.append( + MethodSpec( + name=name, + summary=summary, + path=path_spec.path, + method=method_name, + spec=method, + ) + ) + + if not method_specs: + raise ValueError("No methods found in OpenAPI schema") + + return method_specs + + +def 
openapi_to_url(openapi_schema: dict[str, dict | str]) -> str: + """ + Extract URLs from the servers section of an OpenAPI schema. + + Args: + openapi_schema (Dict[str, Union[Dict, str, List]]): The OpenAPI schema in dictionary format. + + Returns: + List[str]: A list of base URLs. + """ + urls: list[str] = [] + + servers = cast(list[dict[str, Any]], openapi_schema.get("servers", [])) + for server in servers: + url = server.get("url") + if url: + urls.append(url) + + if len(urls) != 1: + raise ValueError( + f"Expected exactly one URL in OpenAPI schema, but found {urls}" + ) + + return urls[0] + + +def validate_openapi_schema(schema: dict[str, Any]) -> None: + """ + Validate the given JSON schema as an OpenAPI schema. + + Parameters: + - schema (dict): The JSON schema to validate. + + Returns: + - bool: True if the schema is valid, False otherwise. + """ + + # check basic structure + if "info" not in schema: + raise ValueError("`info` section is required in OpenAPI schema") + + info = schema["info"] + if "title" not in info: + raise ValueError("`title` is required in `info` section of OpenAPI schema") + if "description" not in info: + raise ValueError( + "`description` is required in `info` section of OpenAPI schema" + ) + + if "openapi" not in schema: + raise ValueError( + "`openapi` field which specifies OpenAPI schema version is required" + ) + openapi_version = schema["openapi"] + if not openapi_version.startswith("3."): + raise ValueError(f"OpenAPI version '{openapi_version}' is not supported") + + if "paths" not in schema: + raise ValueError("`paths` section is required in OpenAPI schema") + + url = openapi_to_url(schema) + if not url: + raise ValueError("OpenAPI schema does not contain a valid URL in `servers`") + + method_specs = openapi_to_method_specs(schema) + for method_spec in method_specs: + method_spec.validate_spec() diff --git a/backend/danswer/tools/factory.py b/backend/danswer/tools/factory.py deleted file mode 100644 index 197bdd6619a..00000000000 --- a/backend/danswer/tools/factory.py +++ /dev/null @@ -1,12 +0,0 @@ -from typing import Type - -from sqlalchemy.orm import Session - -from danswer.db.models import Tool as ToolDBModel -from danswer.tools.built_in_tools import get_built_in_tool_by_id -from danswer.tools.tool import Tool - - -def get_tool_cls(tool: ToolDBModel, db_session: Session) -> Type[Tool]: - # Currently only support built-in tools - return get_built_in_tool_by_id(tool.id, db_session) diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py index da66271322f..22aa40993b6 100644 --- a/backend/danswer/tools/images/image_generation_tool.py +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -8,6 +8,7 @@ from pydantic import BaseModel from danswer.chat.chat_utils import combine_message_chain from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.llm.utils import build_content_with_imgs @@ -53,6 +54,8 @@ class ImageGenerationResponse(BaseModel): class ImageGenerationTool(Tool): + NAME = "run_image_generation" + def __init__( self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2 ) -> None: @@ -60,16 +63,14 @@ class ImageGenerationTool(Tool): self.model = model self.num_imgs = num_imgs - @classmethod def name(self) -> str: - return "run_image_generation" + return self.NAME - @classmethod - def 
tool_definition(cls) -> dict: + def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": cls.name(), + "name": self.name(), "description": "Generate an image from a prompt", "parameters": { "type": "object", @@ -162,3 +163,12 @@ class ImageGenerationTool(Tool): id=IMAGE_GENERATION_RESPONSE_ID, response=results, ) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + image_generation_responses = cast( + list[ImageGenerationResponse], args[0].response + ) + return [ + image_generation_response.dict() + for image_generation_response in image_generation_responses + ] diff --git a/backend/danswer/tools/models.py b/backend/danswer/tools/models.py new file mode 100644 index 00000000000..53940dcea49 --- /dev/null +++ b/backend/danswer/tools/models.py @@ -0,0 +1,39 @@ +from typing import Any + +from pydantic import BaseModel +from pydantic import root_validator + + +class ToolResponse(BaseModel): + id: str | None = None + response: Any + + +class ToolCallKickoff(BaseModel): + tool_name: str + tool_args: dict[str, Any] + + +class ToolRunnerResponse(BaseModel): + tool_run_kickoff: ToolCallKickoff | None = None + tool_response: ToolResponse | None = None + tool_message_content: str | list[str | dict[str, Any]] | None = None + + @root_validator + def validate_tool_runner_response( + cls, values: dict[str, ToolResponse | str] + ) -> dict[str, ToolResponse | str]: + fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] + provided = sum(1 for field in fields if values.get(field) is not None) + + if provided != 1: + raise ValueError( + "Exactly one of 'tool_response', 'tool_message_content', " + "or 'tool_run_kickoff' must be provided" + ) + + return values + + +class ToolCallFinalResult(ToolCallKickoff): + tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py index 30ec47d1664..b0b45bd8f40 100644 --- a/backend/danswer/tools/search/search_tool.py +++ b/backend/danswer/tools/search/search_tool.py @@ -10,6 +10,7 @@ from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import LlmDoc from danswer.db.models import Persona from danswer.db.models import User +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.doc_pruning import prune_documents from danswer.llm.answering.models import DocumentPruningConfig from danswer.llm.answering.models import PreviousMessage @@ -55,6 +56,8 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ class SearchTool(Tool): + NAME = "run_search" + def __init__( self, db_session: Session, @@ -87,18 +90,16 @@ class SearchTool(Tool): self.bypass_acl = bypass_acl self.db_session = db_session - @classmethod - def name(cls) -> str: - return "run_search" + def name(self) -> str: + return self.NAME """For explicit tool calling""" - @classmethod - def tool_definition(cls) -> dict: + def tool_definition(self) -> dict: return { "type": "function", "function": { - "name": cls.name(), + "name": self.name(), "description": search_tool_description, "parameters": { "type": "object", @@ -241,3 +242,13 @@ class SearchTool(Tool): document_pruning_config=self.pruning_config, ) yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=final_context_documents) + + def final_result(self, *args: ToolResponse) -> JSON_ro: + final_docs = cast( + list[LlmDoc], + next(arg.response for arg in args if arg.id == 
FINAL_CONTEXT_DOCUMENTS), + ) + # NOTE: need to do this json.loads(doc.json()) stuff because there are some + # subfields that are not serializable by default (datetime) + # this forces pydantic to make them JSON serializable for us + return [json.loads(doc.json()) for doc in final_docs] diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py index dd443757e67..e335a049838 100644 --- a/backend/danswer/tools/tool.py +++ b/backend/danswer/tools/tool.py @@ -2,26 +2,19 @@ import abc from collections.abc import Generator from typing import Any -from pydantic import BaseModel - +from danswer.dynamic_configs.interface import JSON_ro from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM - - -class ToolResponse(BaseModel): - id: str | None = None - response: Any +from danswer.tools.models import ToolResponse class Tool(abc.ABC): - @classmethod @abc.abstractmethod def name(self) -> str: raise NotImplementedError """For LLMs which support explicit tool calling""" - @classmethod @abc.abstractmethod def tool_definition(self) -> dict: raise NotImplementedError @@ -49,3 +42,11 @@ class Tool(abc.ABC): @abc.abstractmethod def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: raise NotImplementedError + + @abc.abstractmethod + def final_result(self, *args: ToolResponse) -> JSON_ro: + """ + This is the "final summary" result of the tool. + It is the result that will be stored in the database. + """ + raise NotImplementedError diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py index 46f247b06dc..a4367d865d5 100644 --- a/backend/danswer/tools/tool_runner.py +++ b/backend/danswer/tools/tool_runner.py @@ -1,42 +1,15 @@ from collections.abc import Generator from typing import Any -from pydantic import BaseModel -from pydantic import root_validator - from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM +from danswer.tools.models import ToolCallFinalResult +from danswer.tools.models import ToolCallKickoff from danswer.tools.tool import Tool from danswer.tools.tool import ToolResponse from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel -class ToolRunKickoff(BaseModel): - tool_name: str - tool_args: dict[str, Any] - - -class ToolRunnerResponse(BaseModel): - tool_run_kickoff: ToolRunKickoff | None = None - tool_response: ToolResponse | None = None - tool_message_content: str | list[str | dict[str, Any]] | None = None - - @root_validator - def validate_tool_runner_response( - cls, values: dict[str, ToolResponse | str] - ) -> dict[str, ToolResponse | str]: - fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] - provided = sum(1 for field in fields if values.get(field) is not None) - - if provided != 1: - raise ValueError( - "Exactly one of 'tool_response', 'tool_message_content', " - "or 'tool_run_kickoff' must be provided" - ) - - return values - - class ToolRunner: def __init__(self, tool: Tool, args: dict[str, Any]): self.tool = tool @@ -44,8 +17,8 @@ class ToolRunner: self._tool_responses: list[ToolResponse] | None = None - def kickoff(self) -> ToolRunKickoff: - return ToolRunKickoff(tool_name=self.tool.name(), tool_args=self.args) + def kickoff(self) -> ToolCallKickoff: + return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args) def tool_responses(self) -> Generator[ToolResponse, None, None]: if self._tool_responses is not None: @@ -62,6 +35,13 @@ class ToolRunner: tool_responses = 
list(self.tool_responses()) return self.tool.build_tool_message_content(*tool_responses) + def tool_final_result(self) -> ToolCallFinalResult: + return ToolCallFinalResult( + tool_name=self.tool.name(), + tool_args=self.args, + tool_result=self.tool.final_result(*self.tool_responses()), + ) + def check_which_tools_should_run_for_non_tool_calling_llm( tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM diff --git a/backend/danswer/tools/utils.py b/backend/danswer/tools/utils.py index 831021cdab3..7fb2156df59 100644 --- a/backend/danswer/tools/utils.py +++ b/backend/danswer/tools/utils.py @@ -1,5 +1,4 @@ import json -from typing import Type from tiktoken import Encoding @@ -17,15 +16,13 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> boo return False -def compute_tool_tokens( - tool: Tool | Type[Tool], llm_tokenizer: Encoding | None = None -) -> int: +def compute_tool_tokens(tool: Tool, llm_tokenizer: Encoding | None = None) -> int: if not llm_tokenizer: llm_tokenizer = get_default_llm_tokenizer() return len(llm_tokenizer.encode(json.dumps(tool.tool_definition()))) def compute_all_tool_tokens( - tools: list[Tool] | list[Type[Tool]], llm_tokenizer: Encoding | None = None + tools: list[Tool], llm_tokenizer: Encoding | None = None ) -> int: return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 14b89bc1365..2d8745f3349 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -24,6 +24,7 @@ httpx[http2]==0.23.3 httpx-oauth==0.11.2 huggingface-hub==0.20.1 jira==3.5.1 +jsonref==1.1.0 langchain==0.1.17 langchain-community==0.0.36 langchain-core==0.1.50 diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index d6427a67929..10b0ad586aa 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -147,40 +147,56 @@ export function AssistantEditor({ const imageGenerationTool = providerSupportingImageGenerationExists ? findImageGenerationTool(tools) : undefined; + const customTools = tools.filter( + (tool) => + tool.in_code_tool_id !== searchTool?.in_code_tool_id && + tool.in_code_tool_id !== imageGenerationTool?.in_code_tool_id + ); + + const availableTools = [ + ...customTools, + ...(searchTool ? [searchTool] : []), + ...(imageGenerationTool ? [imageGenerationTool] : []), + ]; + const enabledToolsMap: { [key: number]: boolean } = {}; + availableTools.forEach((tool) => { + enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id); + }); + + const initialValues = { + name: existingPersona?.name ?? "", + description: existingPersona?.description ?? "", + system_prompt: existingPrompt?.system_prompt ?? "", + task_prompt: existingPrompt?.task_prompt ?? "", + is_public: existingPersona?.is_public ?? defaultPublic, + document_set_ids: + existingPersona?.document_sets?.map((documentSet) => documentSet.id) ?? + ([] as number[]), + num_chunks: existingPersona?.num_chunks ?? null, + include_citations: existingPersona?.prompts[0]?.include_citations ?? true, + llm_relevance_filter: existingPersona?.llm_relevance_filter ?? false, + llm_model_provider_override: + existingPersona?.llm_model_provider_override ?? null, + llm_model_version_override: + existingPersona?.llm_model_version_override ?? null, + starter_messages: existingPersona?.starter_messages ?? 
[], + enabled_tools_map: enabledToolsMap, + // search_tool_enabled: existingPersona + // ? personaCurrentToolIds.includes(searchTool!.id) + // : ccPairs.length > 0, + // image_generation_tool_enabled: imageGenerationTool + // ? personaCurrentToolIds.includes(imageGenerationTool.id) + // : false, + // EE Only + groups: existingPersona?.groups ?? [], + }; return (
| Name | Summary | Method | Path |
|---|---|---|---|
| {method.name} | {method.summary} | {method.method.toUpperCase()} | {method.path} |
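The patch above only adds the backend endpoints and frontend table; a minimal client-side sketch of exercising the new custom-tool admin API follows. The base URL, any global path prefix, and the admin session cookie are assumptions (not part of the patch) and must be adjusted for a real deployment; the OpenAPI definition reuses the example from the `__main__` block of custom_tool.py.

```python
# Sketch: validate and register a custom tool via the new /admin/tool/custom endpoints.
# Assumptions: the API is reachable at BASE_URL with no extra prefix, and ADMIN_COOKIES
# carries a valid admin session (cookie name/value here are placeholders).
import requests

BASE_URL = "http://localhost:8080"  # assumed deployment URL
ADMIN_COOKIES = {"fastapiusersauth": "<admin session token>"}  # assumed auth cookie

# OpenAPI 3.x definition with `info`, `servers`, and operationIds, as required by
# validate_openapi_schema() in openapi_parsing.py.
definition = {
    "openapi": "3.0.0",
    "info": {
        "version": "1.0.0",
        "title": "Assistants API",
        "description": "An API for managing assistants",
    },
    "servers": [{"url": "http://localhost:8080"}],
    "paths": {
        "/assistant/{assistant_id}": {
            "get": {
                "summary": "Get a specific Assistant",
                "operationId": "getAssistant",
                "parameters": [
                    {
                        "name": "assistant_id",
                        "in": "path",
                        "required": True,
                        "schema": {"type": "string"},
                    }
                ],
            }
        }
    },
}

# Dry-run validation: returns the parsed method specs without creating anything.
validate_resp = requests.post(
    f"{BASE_URL}/admin/tool/custom/validate",
    json={"definition": definition},
    cookies=ADMIN_COOKIES,
)
validate_resp.raise_for_status()
print(validate_resp.json()["methods"])

# Persist the tool; the response body is a ToolSnapshot including the generated id.
create_resp = requests.post(
    f"{BASE_URL}/admin/tool/custom",
    json={
        "name": "assistants_api",
        "description": "Fetch assistants by ID",
        "definition": definition,
    },
    cookies=ADMIN_COOKIES,
)
create_resp.raise_for_status()
print(create_resp.json()["id"])
```

Once created, the tool can be attached to a persona; stream_chat_message_objects builds it at runtime with build_custom_tools_from_openapi_schema and records each invocation in the new tool_call table.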