mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-31 10:10:21 +02:00
Custom tools
This commit is contained in:
parent
c6d094b2ee
commit
7746375bfd
@ -0,0 +1,61 @@
|
||||
"""Add support for custom tools
|
||||
|
||||
Revision ID: 48d14957fe80
|
||||
Revises: b85f02ec1308
|
||||
Create Date: 2024-06-09 14:58:19.946509
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "48d14957fe80"
|
||||
down_revision = "b85f02ec1308"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"openapi_schema",
|
||||
postgresql.JSONB(astext_type=sa.Text()),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"tool",
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"])
|
||||
|
||||
op.create_table(
|
||||
"tool_call",
|
||||
sa.Column("id", sa.Integer(), primary_key=True),
|
||||
sa.Column("tool_id", sa.Integer(), nullable=False),
|
||||
sa.Column("tool_name", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("tool_call")
|
||||
|
||||
op.drop_constraint("tool_user_fk", "tool", type_="foreignkey")
|
||||
op.drop_column("tool", "user_id")
|
||||
op.drop_column("tool", "openapi_schema")
|
@ -106,12 +106,18 @@ class ImageGenerationDisplay(BaseModel):
|
||||
file_ids: list[str]
|
||||
|
||||
|
||||
class CustomToolResponse(BaseModel):
|
||||
response: dict
|
||||
tool_name: str
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
)
|
||||
|
||||
|
@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import ImageGenerationDisplay
|
||||
from danswer.chat.models import LlmDoc
|
||||
@ -31,6 +32,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import SearchDoc as DbSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.file_store.models import ChatFileType
|
||||
@ -54,7 +56,10 @@ from danswer.search.utils import drop_llm_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.factory import get_tool_cls
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
@ -65,6 +70,7 @@ from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.utils import compute_all_tool_tokens
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -162,7 +168,7 @@ def _check_should_force_search(
|
||||
args = {"query": new_msg_req.message}
|
||||
|
||||
return ForceUseTool(
|
||||
tool_name=SearchTool.name(),
|
||||
tool_name=SearchTool.NAME,
|
||||
args=args,
|
||||
)
|
||||
return None
|
||||
@ -176,6 +182,7 @@ ChatPacket = (
|
||||
| DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
)
|
||||
ChatPacketStream = Iterator[ChatPacket]
|
||||
|
||||
@ -389,61 +396,78 @@ def stream_chat_message_objects(
|
||||
),
|
||||
)
|
||||
|
||||
persona_tool_classes = [
|
||||
get_tool_cls(tool, db_session) for tool in persona.tools
|
||||
]
|
||||
# find out what tools to use
|
||||
search_tool: SearchTool | None = None
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_docs=selected_llm_docs,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
dalle_key = None
|
||||
if (
|
||||
llm
|
||||
and llm.config.api_key
|
||||
and llm.config.model_provider == "openai"
|
||||
):
|
||||
dalle_key = llm.config.api_key
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
dalle_key = openai_provider.api_key
|
||||
tool_dict[db_tool_model.id] = [
|
||||
ImageGenerationTool(api_key=dalle_key)
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema(
|
||||
db_tool_model.openapi_schema
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
persona_tool_classes
|
||||
)
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(tools)
|
||||
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
|
||||
# NOTE: for now, only support SearchTool and ImageGenerationTool
|
||||
# in the future, will support arbitrary user-defined tools
|
||||
search_tool: SearchTool | None = None
|
||||
tools: list[Tool] = []
|
||||
for tool_cls in persona_tool_classes:
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
pruning_config=document_pruning_config,
|
||||
selected_docs=selected_llm_docs,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
)
|
||||
tools.append(search_tool)
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
dalle_key = None
|
||||
if llm and llm.config.api_key and llm.config.model_provider == "openai":
|
||||
dalle_key = llm.config.api_key
|
||||
else:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
openai_provider = next(
|
||||
iter(
|
||||
[
|
||||
llm_provider
|
||||
for llm_provider in llm_providers
|
||||
if llm_provider.provider == "openai"
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not openai_provider or not openai_provider.api_key:
|
||||
raise ValueError(
|
||||
"Image generation tool requires an OpenAI API key"
|
||||
)
|
||||
dalle_key = openai_provider.api_key
|
||||
tools.append(ImageGenerationTool(api_key=dalle_key))
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
question=final_msg.message,
|
||||
@ -468,7 +492,9 @@ def stream_chat_message_objects(
|
||||
],
|
||||
tools=tools,
|
||||
force_use_tool=(
|
||||
_check_should_force_search(new_msg_req) if search_tool else None
|
||||
_check_should_force_search(new_msg_req)
|
||||
if search_tool and len(tools) == 1
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
@ -476,6 +502,7 @@ def stream_chat_message_objects(
|
||||
qa_docs_response = None
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
@ -521,8 +548,16 @@ def stream_chat_message_objects(
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
except Exception as e:
|
||||
@ -551,6 +586,11 @@ def stream_chat_message_objects(
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name()] = tool_id
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=(
|
||||
@ -561,6 +601,16 @@ def stream_chat_message_objects(
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=db_citations,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
if tool_result
|
||||
else [],
|
||||
)
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
|
@ -5,6 +5,7 @@ from sqlalchemy import nullsfirst
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.schemas import UserRole
|
||||
@ -16,6 +17,7 @@ from danswer.db.models import ChatSessionSharedStatus
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import SearchDoc
|
||||
from danswer.db.models import SearchDoc as DBSearchDoc
|
||||
from danswer.db.models import ToolCall
|
||||
from danswer.db.models import User
|
||||
from danswer.file_store.models import FileDescriptor
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
@ -24,6 +26,7 @@ from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SavedSearchDoc
|
||||
from danswer.search.models import SearchDoc as ServerSearchDoc
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@ -185,6 +188,7 @@ def get_chat_messages_by_session(
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
skip_permission_check: bool = False,
|
||||
prefetch_tool_calls: bool = False,
|
||||
) -> list[ChatMessage]:
|
||||
if not skip_permission_check:
|
||||
get_chat_session_by_id(
|
||||
@ -192,12 +196,18 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
|
||||
# Start with the root message which has no parent
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.chat_session_id == chat_session_id)
|
||||
.order_by(nullsfirst(ChatMessage.parent_message))
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt).scalars().all()
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
|
||||
if prefetch_tool_calls:
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
|
||||
return list(result)
|
||||
|
||||
@ -251,6 +261,7 @@ def create_new_chat_message(
|
||||
reference_docs: list[DBSearchDoc] | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
commit: bool = True,
|
||||
) -> ChatMessage:
|
||||
new_chat_message = ChatMessage(
|
||||
@ -264,6 +275,7 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
error=error,
|
||||
)
|
||||
|
||||
@ -459,6 +471,14 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
@ -133,6 +133,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
|
||||
# Personas owned by this user
|
||||
personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
|
||||
# Custom tools created by this user
|
||||
custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user")
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
@ -330,7 +332,6 @@ class Document(Base):
|
||||
primary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
# Something like assignee or space owner
|
||||
secondary_owners: Mapped[list[str] | None] = mapped_column(
|
||||
postgresql.ARRAY(String), nullable=True
|
||||
)
|
||||
@ -618,6 +619,26 @@ class SearchDoc(Base):
|
||||
)
|
||||
|
||||
|
||||
class ToolCall(Base):
|
||||
"""Represents a single tool call"""
|
||||
|
||||
__tablename__ = "tool_call"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# not a FK because we want to be able to delete the tool without deleting
|
||||
# this entry
|
||||
tool_id: Mapped[int] = mapped_column(Integer())
|
||||
tool_name: Mapped[str] = mapped_column(String())
|
||||
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
|
||||
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
|
||||
message: Mapped["ChatMessage"] = relationship(
|
||||
"ChatMessage", back_populates="tool_calls"
|
||||
)
|
||||
|
||||
|
||||
class ChatSession(Base):
|
||||
__tablename__ = "chat_session"
|
||||
|
||||
@ -723,6 +744,10 @@ class ChatMessage(Base):
|
||||
secondary="chat_message__search_doc",
|
||||
back_populates="chat_messages",
|
||||
)
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
)
|
||||
|
||||
|
||||
class ChatFolder(Base):
|
||||
@ -901,9 +926,18 @@ class Tool(Base):
|
||||
name: Mapped[str] = mapped_column(String, nullable=False)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=True)
|
||||
# ID of the tool in the codebase, only applies for in-code tools.
|
||||
# tools defiend via the UI will have this as None
|
||||
# tools defined via the UI will have this as None
|
||||
in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
# OpenAPI scheme for the tool. Only applies to tools defined via the UI.
|
||||
openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
|
||||
# user who created / owns the tool. Will be None for built-in tools.
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
|
||||
user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
|
||||
# Relationship to Persona through the association table
|
||||
personas: Mapped[list["Persona"]] = relationship(
|
||||
"Persona",
|
||||
|
74
backend/danswer/db/tools.py
Normal file
74
backend/danswer/db/tools.py
Normal file
@ -0,0 +1,74 @@
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import Tool
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_tools(db_session: Session) -> list[Tool]:
|
||||
return list(db_session.scalars(select(Tool)).all())
|
||||
|
||||
|
||||
def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
|
||||
tool = db_session.scalar(select(Tool).where(Tool.id == tool_id))
|
||||
if not tool:
|
||||
raise ValueError("Tool by specified id does not exist")
|
||||
return tool
|
||||
|
||||
|
||||
def create_tool(
|
||||
name: str,
|
||||
description: str | None,
|
||||
openapi_schema: dict[str, Any] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> Tool:
|
||||
new_tool = Tool(
|
||||
name=name,
|
||||
description=description,
|
||||
in_code_tool_id=None,
|
||||
openapi_schema=openapi_schema,
|
||||
user_id=user_id,
|
||||
)
|
||||
db_session.add(new_tool)
|
||||
db_session.commit()
|
||||
return new_tool
|
||||
|
||||
|
||||
def update_tool(
|
||||
tool_id: int,
|
||||
name: str | None,
|
||||
description: str | None,
|
||||
openapi_schema: dict[str, Any] | None,
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> Tool:
|
||||
tool = get_tool_by_id(tool_id, db_session)
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool with ID {tool_id} does not exist")
|
||||
|
||||
if name is not None:
|
||||
tool.name = name
|
||||
if description is not None:
|
||||
tool.description = description
|
||||
if openapi_schema is not None:
|
||||
tool.openapi_schema = openapi_schema
|
||||
if user_id is not None:
|
||||
tool.user_id = user_id
|
||||
db_session.commit()
|
||||
|
||||
return tool
|
||||
|
||||
|
||||
def delete_tool(tool_id: int, db_session: Session) -> None:
|
||||
tool = get_tool_by_id(tool_id, db_session)
|
||||
if tool is None:
|
||||
raise ValueError(f"Tool with ID {tool_id} does not exist")
|
||||
|
||||
db_session.delete(tool)
|
||||
db_session.commit()
|
@ -24,7 +24,7 @@ def load_chat_file(
|
||||
file_id=file_descriptor["id"],
|
||||
content=file_io.read(),
|
||||
file_type=file_descriptor["type"],
|
||||
filename=file_descriptor["name"],
|
||||
filename=file_descriptor.get("name"),
|
||||
)
|
||||
|
||||
|
||||
|
@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain_core.messages import AIMessageChunk
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
@ -33,6 +34,9 @@ from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import get_default_llm_tokenizer
|
||||
from danswer.llm.utils import message_generator_to_string_generator
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.force import filter_tools_for_force_tool_use
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
@ -50,7 +54,8 @@ from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import (
|
||||
check_which_tools_should_run_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolRunKickoff
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
|
||||
@ -72,7 +77,7 @@ def _get_answer_stream_processor(
|
||||
raise RuntimeError("Not implemented yet")
|
||||
|
||||
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolRunKickoff | ToolResponse]
|
||||
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
|
||||
|
||||
|
||||
class Answer:
|
||||
@ -125,7 +130,7 @@ class Answer:
|
||||
|
||||
self._streamed_output: list[str] | None = None
|
||||
self._processed_stream: list[
|
||||
AnswerQuestionPossibleReturn | ToolResponse | ToolRunKickoff
|
||||
AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff
|
||||
] | None = None
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
@ -160,7 +165,7 @@ class Answer:
|
||||
|
||||
def _raw_output_for_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[str | ToolRunKickoff | ToolResponse]:
|
||||
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
@ -237,16 +242,18 @@ class Answer:
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name() == SearchTool.name():
|
||||
if tool.name() == SearchTool.NAME:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name() == ImageGenerationTool.name():
|
||||
elif tool.name() == ImageGenerationTool.NAME:
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
yield from message_generator_to_string_generator(
|
||||
self.llm.stream(
|
||||
prompt=prompt,
|
||||
@ -258,7 +265,7 @@ class Answer:
|
||||
|
||||
def _raw_output_for_non_explicit_tool_calling_llms(
|
||||
self,
|
||||
) -> Iterator[str | ToolRunKickoff | ToolResponse]:
|
||||
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
@ -324,7 +331,7 @@ class Answer:
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
|
||||
if tool.name() == SearchTool.name():
|
||||
if tool.name() == SearchTool.NAME:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS:
|
||||
@ -337,7 +344,7 @@ class Answer:
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name() == ImageGenerationTool.name():
|
||||
elif tool.name() == ImageGenerationTool.NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
@ -354,6 +361,18 @@ class Answer:
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name(),
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
prompt = prompt_builder.build()
|
||||
yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
|
||||
@ -374,7 +393,7 @@ class Answer:
|
||||
)
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolRunKickoff | ToolResponse | str],
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str],
|
||||
) -> AnswerStream:
|
||||
message = None
|
||||
|
||||
@ -387,7 +406,9 @@ class Answer:
|
||||
] | None = None # processed docs to feed into the LLM
|
||||
|
||||
for message in stream:
|
||||
if isinstance(message, ToolRunKickoff):
|
||||
if isinstance(message, ToolCallKickoff) or isinstance(
|
||||
message, ToolCallFinalResult
|
||||
):
|
||||
yield message
|
||||
elif isinstance(message, ToolResponse):
|
||||
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
|
@ -63,6 +63,7 @@ from danswer.server.features.folder.api import router as folder_router
|
||||
from danswer.server.features.persona.api import admin_router as admin_persona_router
|
||||
from danswer.server.features.persona.api import basic_router as persona_router
|
||||
from danswer.server.features.prompt.api import basic_router as prompt_router
|
||||
from danswer.server.features.tool.api import admin_router as admin_tool_router
|
||||
from danswer.server.features.tool.api import router as tool_router
|
||||
from danswer.server.gpts.api import router as gpts_router
|
||||
from danswer.server.manage.administrative import router as admin_router
|
||||
@ -277,6 +278,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, admin_persona_router)
|
||||
include_router_with_global_prefix_prepended(application, prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, tool_router)
|
||||
include_router_with_global_prefix_prepended(application, admin_tool_router)
|
||||
include_router_with_global_prefix_prepended(application, state_router)
|
||||
include_router_with_global_prefix_prepended(application, danswer_api_router)
|
||||
include_router_with_global_prefix_prepended(application, gpts_router)
|
||||
|
@ -51,7 +51,7 @@ from danswer.tools.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.tools.tool_runner import ToolRunKickoff
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@ -67,7 +67,7 @@ AnswerObjectIterator = Iterator[
|
||||
| StreamingError
|
||||
| ChatMessageDetail
|
||||
| CitationInfo
|
||||
| ToolRunKickoff
|
||||
| ToolCallKickoff
|
||||
]
|
||||
|
||||
|
||||
|
@ -1,32 +1,132 @@
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
|
||||
from danswer.db.tools import create_tool
|
||||
from danswer.db.tools import delete_tool
|
||||
from danswer.db.tools import get_tool_by_id
|
||||
from danswer.db.tools import get_tools
|
||||
from danswer.db.tools import update_tool
|
||||
from danswer.server.features.tool.models import ToolSnapshot
|
||||
from danswer.tools.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
||||
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
||||
|
||||
router = APIRouter(prefix="/tool")
|
||||
admin_router = APIRouter(prefix="/admin/tool")
|
||||
|
||||
|
||||
class ToolSnapshot(BaseModel):
|
||||
id: int
|
||||
class CustomToolCreate(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
in_code_tool_id: str | None
|
||||
description: str | None
|
||||
definition: dict[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, tool: Tool) -> "ToolSnapshot":
|
||||
return cls(
|
||||
id=tool.id,
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
in_code_tool_id=tool.in_code_tool_id,
|
||||
)
|
||||
|
||||
class CustomToolUpdate(BaseModel):
|
||||
name: str | None
|
||||
description: str | None
|
||||
definition: dict[str, Any] | None
|
||||
|
||||
|
||||
def _validate_tool_definition(definition: dict[str, Any]) -> None:
|
||||
try:
|
||||
validate_openapi_schema(definition)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.post("/custom")
|
||||
def create_custom_tool(
|
||||
tool_data: CustomToolCreate,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> ToolSnapshot:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
tool = create_tool(
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
openapi_schema=tool_data.definition,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ToolSnapshot.from_model(tool)
|
||||
|
||||
|
||||
@admin_router.put("/custom/{tool_id}")
|
||||
def update_custom_tool(
|
||||
tool_id: int,
|
||||
tool_data: CustomToolUpdate,
|
||||
db_session: Session = Depends(get_session),
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> ToolSnapshot:
|
||||
if tool_data.definition:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
updated_tool = update_tool(
|
||||
tool_id=tool_id,
|
||||
name=tool_data.name,
|
||||
description=tool_data.description,
|
||||
openapi_schema=tool_data.definition,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
return ToolSnapshot.from_model(updated_tool)
|
||||
|
||||
|
||||
@admin_router.delete("/custom/{tool_id}")
|
||||
def delete_custom_tool(
|
||||
tool_id: int,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
try:
|
||||
delete_tool(tool_id, db_session)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
except Exception as e:
|
||||
# handles case where tool is still used by an Assistant
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
class ValidateToolRequest(BaseModel):
|
||||
definition: dict[str, Any]
|
||||
|
||||
|
||||
class ValidateToolResponse(BaseModel):
|
||||
methods: list[MethodSpec]
|
||||
|
||||
|
||||
@admin_router.post("/custom/validate")
|
||||
def validate_tool(
|
||||
tool_data: ValidateToolRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> ValidateToolResponse:
|
||||
_validate_tool_definition(tool_data.definition)
|
||||
method_specs = openapi_to_method_specs(tool_data.definition)
|
||||
return ValidateToolResponse(methods=method_specs)
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
|
||||
|
||||
@router.get("/{tool_id}")
|
||||
def get_custom_tool(
|
||||
tool_id: int,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> ToolSnapshot:
|
||||
try:
|
||||
tool = get_tool_by_id(tool_id, db_session)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=404, detail=str(e))
|
||||
return ToolSnapshot.from_model(tool)
|
||||
|
||||
|
||||
@router.get("")
|
||||
@ -34,5 +134,5 @@ def list_tools(
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> list[ToolSnapshot]:
|
||||
tools = db_session.execute(select(Tool)).scalars().all()
|
||||
tools = get_tools(db_session)
|
||||
return [ToolSnapshot.from_model(tool) for tool in tools]
|
||||
|
23
backend/danswer/server/features/tool/models.py
Normal file
23
backend/danswer/server/features/tool/models.py
Normal file
@ -0,0 +1,23 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.db.models import Tool
|
||||
|
||||
|
||||
class ToolSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
definition: dict[str, Any] | None
|
||||
in_code_tool_id: str | None
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, tool: Tool) -> "ToolSnapshot":
|
||||
return cls(
|
||||
id=tool.id,
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
definition=tool.openapi_schema,
|
||||
in_code_tool_id=tool.in_code_tool_id,
|
||||
)
|
@ -129,6 +129,8 @@ def get_chat_session(
|
||||
# we already did a permission check above with the call to
|
||||
# `get_chat_session_by_id`, so we can skip it here
|
||||
skip_permission_check=True,
|
||||
# we need the tool call objs anyways, so just fetch them in a single call
|
||||
prefetch_tool_calls=True,
|
||||
)
|
||||
|
||||
return ChatSessionDetailResponse(
|
||||
|
@ -17,6 +17,7 @@ from danswer.search.models import ChunkContext
|
||||
from danswer.search.models import RetrievalDetails
|
||||
from danswer.search.models import SearchDoc
|
||||
from danswer.search.models import Tag
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
|
||||
|
||||
class SourceTag(Tag):
|
||||
@ -176,6 +177,7 @@ class ChatMessageDetail(BaseModel):
|
||||
# Dict mapping citation number to db_doc_id
|
||||
citations: dict[int, int] | None
|
||||
files: list[FileDescriptor]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
|
||||
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().dict(*args, **kwargs) # type: ignore
|
||||
|
233
backend/danswer/tools/custom/custom_tool.py
Normal file
233
backend/danswer/tools/custom/custom_tool.py
Normal file
@ -0,0 +1,233 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.custom.custom_tool_prompts import (
|
||||
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT,
|
||||
)
|
||||
from danswer.tools.custom.custom_tool_prompts import SHOULD_USE_CUSTOM_TOOL_USER_PROMPT
|
||||
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_SYSTEM_PROMPT
|
||||
from danswer.tools.custom.custom_tool_prompts import TOOL_ARG_USER_PROMPT
|
||||
from danswer.tools.custom.custom_tool_prompts import USE_TOOL
|
||||
from danswer.tools.custom.openapi_parsing import MethodSpec
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_method_specs
|
||||
from danswer.tools.custom.openapi_parsing import openapi_to_url
|
||||
from danswer.tools.custom.openapi_parsing import REQUEST_BODY
|
||||
from danswer.tools.custom.openapi_parsing import validate_openapi_schema
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
CUSTOM_TOOL_RESPONSE_ID = "custom_tool_response"
|
||||
|
||||
|
||||
class CustomToolCallSummary(BaseModel):
|
||||
tool_name: str
|
||||
tool_result: dict
|
||||
|
||||
|
||||
class CustomTool(Tool):
|
||||
def __init__(self, method_spec: MethodSpec, base_url: str) -> None:
|
||||
self._base_url = base_url
|
||||
self._method_spec = method_spec
|
||||
self._tool_definition = self._method_spec.to_tool_definition()
|
||||
|
||||
self._name = self._method_spec.name
|
||||
self.description = self._method_spec.summary
|
||||
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
"""For LLMs which support explicit tool calling"""
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return self._tool_definition
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
response = cast(CustomToolCallSummary, args[0].response)
|
||||
return json.dumps(response.tool_result)
|
||||
|
||||
"""For LLMs which do NOT support explicit tool calling"""
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
if not force_run:
|
||||
should_use_result = llm.invoke(
|
||||
[
|
||||
SystemMessage(content=SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT),
|
||||
HumanMessage(
|
||||
content=SHOULD_USE_CUSTOM_TOOL_USER_PROMPT.format(
|
||||
history=history,
|
||||
query=query,
|
||||
tool_name=self.name(),
|
||||
tool_description=self.description,
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
if cast(str, should_use_result.content).strip() != USE_TOOL:
|
||||
return None
|
||||
|
||||
args_result = llm.invoke(
|
||||
[
|
||||
SystemMessage(content=TOOL_ARG_SYSTEM_PROMPT),
|
||||
HumanMessage(
|
||||
content=TOOL_ARG_USER_PROMPT.format(
|
||||
history=history,
|
||||
query=query,
|
||||
tool_name=self.name(),
|
||||
tool_description=self.description,
|
||||
tool_args=self.tool_definition()["function"]["parameters"],
|
||||
)
|
||||
),
|
||||
]
|
||||
)
|
||||
args_result_str = cast(str, args_result.content)
|
||||
|
||||
try:
|
||||
return json.loads(args_result_str.strip())
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# try removing ```
|
||||
try:
|
||||
return json.loads(args_result_str.strip("```"))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# try removing ```json
|
||||
try:
|
||||
return json.loads(args_result_str.strip("```").strip("json"))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# pretend like nothing happened if not parse-able
|
||||
logger.error(
|
||||
f"Failed to parse args for '{self.name()}' tool. Recieved: {args_result_str}"
|
||||
)
|
||||
return None
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
request_body = kwargs.get(REQUEST_BODY)
|
||||
|
||||
path_params = {}
|
||||
for path_param_schema in self._method_spec.get_path_param_schemas():
|
||||
path_params[path_param_schema["name"]] = kwargs[path_param_schema["name"]]
|
||||
|
||||
query_params = {}
|
||||
for query_param_schema in self._method_spec.get_query_param_schemas():
|
||||
if query_param_schema["name"] in kwargs:
|
||||
query_params[query_param_schema["name"]] = kwargs[
|
||||
query_param_schema["name"]
|
||||
]
|
||||
|
||||
url = self._method_spec.build_url(self._base_url, path_params, query_params)
|
||||
method = self._method_spec.method
|
||||
|
||||
response = requests.request(method, url, json=request_body)
|
||||
|
||||
yield ToolResponse(
|
||||
id=CUSTOM_TOOL_RESPONSE_ID,
|
||||
response=CustomToolCallSummary(
|
||||
tool_name=self._name, tool_result=response.json()
|
||||
),
|
||||
)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
return cast(CustomToolCallSummary, args[0].response).tool_result
|
||||
|
||||
|
||||
def build_custom_tools_from_openapi_schema(
|
||||
openapi_schema: dict[str, Any]
|
||||
) -> list[CustomTool]:
|
||||
url = openapi_to_url(openapi_schema)
|
||||
method_specs = openapi_to_method_specs(openapi_schema)
|
||||
return [CustomTool(method_spec, url) for method_spec in method_specs]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import openai
|
||||
|
||||
openapi_schema = {
|
||||
"openapi": "3.0.0",
|
||||
"info": {
|
||||
"version": "1.0.0",
|
||||
"title": "Assistants API",
|
||||
"description": "An API for managing assistants",
|
||||
},
|
||||
"servers": [
|
||||
{"url": "http://localhost:8080"},
|
||||
],
|
||||
"paths": {
|
||||
"/assistant/{assistant_id}": {
|
||||
"get": {
|
||||
"summary": "Get a specific Assistant",
|
||||
"operationId": "getAssistant",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "assistant_id",
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"schema": {"type": "string"},
|
||||
}
|
||||
],
|
||||
},
|
||||
"post": {
|
||||
"summary": "Create a new Assistant",
|
||||
"operationId": "createAssistant",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "assistant_id",
|
||||
"in": "path",
|
||||
"required": True,
|
||||
"schema": {"type": "string"},
|
||||
}
|
||||
],
|
||||
"requestBody": {
|
||||
"required": True,
|
||||
"content": {"application/json": {"schema": {"type": "object"}}},
|
||||
},
|
||||
},
|
||||
}
|
||||
},
|
||||
}
|
||||
validate_openapi_schema(openapi_schema)
|
||||
|
||||
tools = build_custom_tools_from_openapi_schema(openapi_schema)
|
||||
|
||||
openai_client = openai.OpenAI()
|
||||
response = openai_client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Can you fetch assistant with ID 10"},
|
||||
],
|
||||
tools=[tool.tool_definition() for tool in tools], # type: ignore
|
||||
)
|
||||
choice = response.choices[0]
|
||||
if choice.message.tool_calls:
|
||||
print(choice.message.tool_calls)
|
||||
for tool_response in tools[0].run(
|
||||
**json.loads(choice.message.tool_calls[0].function.arguments)
|
||||
):
|
||||
print(tool_response)
|
21
backend/danswer/tools/custom/custom_tool_prompt_builder.py
Normal file
21
backend/danswer/tools/custom/custom_tool_prompt_builder.py
Normal file
@ -0,0 +1,21 @@
|
||||
from typing import cast
|
||||
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
def build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
query: str,
|
||||
tool_name: str,
|
||||
*args: ToolResponse,
|
||||
) -> str:
|
||||
tool_run_summary = cast(CustomToolCallSummary, args[0].response).tool_result
|
||||
return f"""
|
||||
Here's the result from the {tool_name} tool:
|
||||
|
||||
{tool_run_summary}
|
||||
|
||||
Now respond to the following:
|
||||
|
||||
{query}
|
||||
""".strip()
|
57
backend/danswer/tools/custom/custom_tool_prompts.py
Normal file
57
backend/danswer/tools/custom/custom_tool_prompts.py
Normal file
@ -0,0 +1,57 @@
|
||||
from danswer.prompts.constants import GENERAL_SEP_PAT
|
||||
|
||||
DONT_USE_TOOL = "Don't use tool"
|
||||
USE_TOOL = "Use tool"
|
||||
|
||||
|
||||
"""Prompts to determine if we should use a custom tool or not."""
|
||||
|
||||
|
||||
SHOULD_USE_CUSTOM_TOOL_SYSTEM_PROMPT = (
|
||||
"You are a large language model whose only job is to determine if the system should call an "
|
||||
"external tool to be able to answer the user's last message."
|
||||
).strip()
|
||||
|
||||
SHOULD_USE_CUSTOM_TOOL_USER_PROMPT = f"""
|
||||
Given the conversation history and a follow up query, determine if the system should use the \
|
||||
'{{tool_name}}' tool to answer the user's query. The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'.
|
||||
|
||||
Respond with "{USE_TOOL}" if you think the tool would be helpful in respnding to the users query.
|
||||
Respond with "{DONT_USE_TOOL}" otherwise.
|
||||
|
||||
Conversation History:
|
||||
{GENERAL_SEP_PAT}
|
||||
{{history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
|
||||
If you are at all unsure, respond with {DONT_USE_TOOL}.
|
||||
Respond with EXACTLY and ONLY "{DONT_USE_TOOL}" or "{USE_TOOL}"
|
||||
|
||||
Follow up input:
|
||||
{{query}}
|
||||
""".strip()
|
||||
|
||||
|
||||
"""Prompts to figure out the arguments to pass to a custom tool."""
|
||||
|
||||
|
||||
TOOL_ARG_SYSTEM_PROMPT = (
|
||||
"You are a large language model whose only job is to determine the arguments to pass to an "
|
||||
"external tool."
|
||||
).strip()
|
||||
|
||||
|
||||
TOOL_ARG_USER_PROMPT = f"""
|
||||
Given the following conversation and a follow up input, generate a \
|
||||
dictionary of arguments to pass to the '{{tool_name}}' tool. \
|
||||
The '{{tool_name}}' tool is a tool defined as: '{{tool_description}}'. \
|
||||
The expected arguments are: {{tool_args}}.
|
||||
|
||||
Conversation:
|
||||
{{history}}
|
||||
|
||||
Follow up input:
|
||||
{{query}}
|
||||
|
||||
Respond with ONLY and EXACTLY a JSON object specifying the values of the arguments to pass to the tool.
|
||||
""".strip() # noqa: F541
|
225
backend/danswer/tools/custom/openapi_parsing.py
Normal file
225
backend/danswer/tools/custom/openapi_parsing.py
Normal file
@ -0,0 +1,225 @@
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from openai import BaseModel
|
||||
|
||||
REQUEST_BODY = "requestBody"
|
||||
|
||||
|
||||
class PathSpec(BaseModel):
|
||||
path: str
|
||||
methods: dict[str, Any]
|
||||
|
||||
|
||||
class MethodSpec(BaseModel):
|
||||
name: str
|
||||
summary: str
|
||||
path: str
|
||||
method: str
|
||||
spec: dict[str, Any]
|
||||
|
||||
def get_request_body_schema(self) -> dict[str, Any]:
|
||||
content = self.spec.get("requestBody", {}).get("content", {})
|
||||
if "application/json" in content:
|
||||
return content["application/json"].get("schema")
|
||||
|
||||
if content:
|
||||
raise ValueError(
|
||||
f"Unsupported content type: '{list(content.keys())[0]}'. "
|
||||
f"Only 'application/json' is supported."
|
||||
)
|
||||
|
||||
return {}
|
||||
|
||||
def get_query_param_schemas(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
param
|
||||
for param in self.spec.get("parameters", [])
|
||||
if "schema" in param and "in" in param and param["in"] == "query"
|
||||
]
|
||||
|
||||
def get_path_param_schemas(self) -> list[dict[str, Any]]:
|
||||
return [
|
||||
param
|
||||
for param in self.spec.get("parameters", [])
|
||||
if "schema" in param and "in" in param and param["in"] == "path"
|
||||
]
|
||||
|
||||
def build_url(
|
||||
self, base_url: str, path_params: dict[str, str], query_params: dict[str, str]
|
||||
) -> str:
|
||||
url = f"{base_url}{self.path}"
|
||||
try:
|
||||
url = url.format(**path_params)
|
||||
except KeyError as e:
|
||||
raise ValueError(f"Missing path parameter: {e}")
|
||||
if query_params:
|
||||
url += "?"
|
||||
for param, value in query_params.items():
|
||||
url += f"{param}={value}&"
|
||||
url = url[:-1]
|
||||
return url
|
||||
|
||||
def to_tool_definition(self) -> dict[str, Any]:
|
||||
tool_definition: Any = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.summary,
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
|
||||
request_body_schema = self.get_request_body_schema()
|
||||
if request_body_schema:
|
||||
tool_definition["function"]["parameters"]["properties"][
|
||||
REQUEST_BODY
|
||||
] = request_body_schema
|
||||
|
||||
query_param_schemas = self.get_query_param_schemas()
|
||||
if query_param_schemas:
|
||||
tool_definition["function"]["parameters"]["properties"].update(
|
||||
{param["name"]: param["schema"] for param in query_param_schemas}
|
||||
)
|
||||
|
||||
path_param_schemas = self.get_path_param_schemas()
|
||||
if path_param_schemas:
|
||||
tool_definition["function"]["parameters"]["properties"].update(
|
||||
{param["name"]: param["schema"] for param in path_param_schemas}
|
||||
)
|
||||
return tool_definition
|
||||
|
||||
def validate_spec(self) -> None:
|
||||
# Validate url construction
|
||||
path_param_schemas = self.get_path_param_schemas()
|
||||
dummy_path_dict = {param["name"]: "value" for param in path_param_schemas}
|
||||
query_param_schemas = self.get_query_param_schemas()
|
||||
dummy_query_dict = {param["name"]: "value" for param in query_param_schemas}
|
||||
self.build_url("", dummy_path_dict, dummy_query_dict)
|
||||
|
||||
# Make sure request body doesn't throw an exception
|
||||
self.get_request_body_schema()
|
||||
|
||||
# Ensure the method is valid
|
||||
if not self.method:
|
||||
raise ValueError("HTTP method is not specified.")
|
||||
if self.method.upper() not in ["GET", "POST", "PUT", "DELETE", "PATCH"]:
|
||||
raise ValueError(f"HTTP method '{self.method}' is not supported.")
|
||||
|
||||
|
||||
"""Path-level utils"""
|
||||
|
||||
|
||||
def openapi_to_path_specs(openapi_spec: dict[str, Any]) -> list[PathSpec]:
|
||||
path_specs = []
|
||||
|
||||
for path, methods in openapi_spec.get("paths", {}).items():
|
||||
path_specs.append(PathSpec(path=path, methods=methods))
|
||||
|
||||
return path_specs
|
||||
|
||||
|
||||
"""Method-level utils"""
|
||||
|
||||
|
||||
def openapi_to_method_specs(openapi_spec: dict[str, Any]) -> list[MethodSpec]:
|
||||
path_specs = openapi_to_path_specs(openapi_spec)
|
||||
|
||||
method_specs = []
|
||||
for path_spec in path_specs:
|
||||
for method_name, method in path_spec.methods.items():
|
||||
name = method.get("operationId")
|
||||
if not name:
|
||||
raise ValueError(
|
||||
f"Operation ID is not specified for {method_name.upper()} {path_spec.path}"
|
||||
)
|
||||
|
||||
summary = method.get("summary") or method.get("description")
|
||||
if not summary:
|
||||
raise ValueError(
|
||||
f"Summary is not specified for {method_name.upper()} {path_spec.path}"
|
||||
)
|
||||
|
||||
method_specs.append(
|
||||
MethodSpec(
|
||||
name=name,
|
||||
summary=summary,
|
||||
path=path_spec.path,
|
||||
method=method_name,
|
||||
spec=method,
|
||||
)
|
||||
)
|
||||
|
||||
if not method_specs:
|
||||
raise ValueError("No methods found in OpenAPI schema")
|
||||
|
||||
return method_specs
|
||||
|
||||
|
||||
def openapi_to_url(openapi_schema: dict[str, dict | str]) -> str:
|
||||
"""
|
||||
Extract URLs from the servers section of an OpenAPI schema.
|
||||
|
||||
Args:
|
||||
openapi_schema (Dict[str, Union[Dict, str, List]]): The OpenAPI schema in dictionary format.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of base URLs.
|
||||
"""
|
||||
urls: list[str] = []
|
||||
|
||||
servers = cast(list[dict[str, Any]], openapi_schema.get("servers", []))
|
||||
for server in servers:
|
||||
url = server.get("url")
|
||||
if url:
|
||||
urls.append(url)
|
||||
|
||||
if len(urls) != 1:
|
||||
raise ValueError(
|
||||
f"Expected exactly one URL in OpenAPI schema, but found {urls}"
|
||||
)
|
||||
|
||||
return urls[0]
|
||||
|
||||
|
||||
def validate_openapi_schema(schema: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the given JSON schema as an OpenAPI schema.
|
||||
|
||||
Parameters:
|
||||
- schema (dict): The JSON schema to validate.
|
||||
|
||||
Returns:
|
||||
- bool: True if the schema is valid, False otherwise.
|
||||
"""
|
||||
|
||||
# check basic structure
|
||||
if "info" not in schema:
|
||||
raise ValueError("`info` section is required in OpenAPI schema")
|
||||
|
||||
info = schema["info"]
|
||||
if "title" not in info:
|
||||
raise ValueError("`title` is required in `info` section of OpenAPI schema")
|
||||
if "description" not in info:
|
||||
raise ValueError(
|
||||
"`description` is required in `info` section of OpenAPI schema"
|
||||
)
|
||||
|
||||
if "openapi" not in schema:
|
||||
raise ValueError(
|
||||
"`openapi` field which specifies OpenAPI schema version is required"
|
||||
)
|
||||
openapi_version = schema["openapi"]
|
||||
if not openapi_version.startswith("3."):
|
||||
raise ValueError(f"OpenAPI version '{openapi_version}' is not supported")
|
||||
|
||||
if "paths" not in schema:
|
||||
raise ValueError("`paths` section is required in OpenAPI schema")
|
||||
|
||||
url = openapi_to_url(schema)
|
||||
if not url:
|
||||
raise ValueError("OpenAPI schema does not contain a valid URL in `servers`")
|
||||
|
||||
method_specs = openapi_to_method_specs(schema)
|
||||
for method_spec in method_specs:
|
||||
method_spec.validate_spec()
|
@ -1,12 +0,0 @@
|
||||
from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
def get_tool_cls(tool: ToolDBModel, db_session: Session) -> Type[Tool]:
|
||||
# Currently only support built-in tools
|
||||
return get_built_in_tool_by_id(tool.id, db_session)
|
@ -8,6 +8,7 @@ from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import build_content_with_imgs
|
||||
@ -53,6 +54,8 @@ class ImageGenerationResponse(BaseModel):
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool):
|
||||
NAME = "run_image_generation"
|
||||
|
||||
def __init__(
|
||||
self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2
|
||||
) -> None:
|
||||
@ -60,16 +63,14 @@ class ImageGenerationTool(Tool):
|
||||
self.model = model
|
||||
self.num_imgs = num_imgs
|
||||
|
||||
@classmethod
|
||||
def name(self) -> str:
|
||||
return "run_image_generation"
|
||||
return self.NAME
|
||||
|
||||
@classmethod
|
||||
def tool_definition(cls) -> dict:
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": cls.name(),
|
||||
"name": self.name(),
|
||||
"description": "Generate an image from a prompt",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
@ -162,3 +163,12 @@ class ImageGenerationTool(Tool):
|
||||
id=IMAGE_GENERATION_RESPONSE_ID,
|
||||
response=results,
|
||||
)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
image_generation_responses = cast(
|
||||
list[ImageGenerationResponse], args[0].response
|
||||
)
|
||||
return [
|
||||
image_generation_response.dict()
|
||||
for image_generation_response in image_generation_responses
|
||||
]
|
||||
|
39
backend/danswer/tools/models.py
Normal file
39
backend/danswer/tools/models.py
Normal file
@ -0,0 +1,39 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import root_validator
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
id: str | None = None
|
||||
response: Any
|
||||
|
||||
|
||||
class ToolCallKickoff(BaseModel):
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class ToolRunnerResponse(BaseModel):
|
||||
tool_run_kickoff: ToolCallKickoff | None = None
|
||||
tool_response: ToolResponse | None = None
|
||||
tool_message_content: str | list[str | dict[str, Any]] | None = None
|
||||
|
||||
@root_validator
|
||||
def validate_tool_runner_response(
|
||||
cls, values: dict[str, ToolResponse | str]
|
||||
) -> dict[str, ToolResponse | str]:
|
||||
fields = ["tool_response", "tool_message_content", "tool_run_kickoff"]
|
||||
provided = sum(1 for field in fields if values.get(field) is not None)
|
||||
|
||||
if provided != 1:
|
||||
raise ValueError(
|
||||
"Exactly one of 'tool_response', 'tool_message_content', "
|
||||
"or 'tool_run_kickoff' must be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ToolCallFinalResult(ToolCallKickoff):
|
||||
tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature
|
@ -10,6 +10,7 @@ from danswer.chat.chat_utils import llm_doc_from_inference_section
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.llm.answering.doc_pruning import prune_documents
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
@ -55,6 +56,8 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
|
||||
|
||||
|
||||
class SearchTool(Tool):
|
||||
NAME = "run_search"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_session: Session,
|
||||
@ -87,18 +90,16 @@ class SearchTool(Tool):
|
||||
self.bypass_acl = bypass_acl
|
||||
self.db_session = db_session
|
||||
|
||||
@classmethod
|
||||
def name(cls) -> str:
|
||||
return "run_search"
|
||||
def name(self) -> str:
|
||||
return self.NAME
|
||||
|
||||
"""For explicit tool calling"""
|
||||
|
||||
@classmethod
|
||||
def tool_definition(cls) -> dict:
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": cls.name(),
|
||||
"name": self.name(),
|
||||
"description": search_tool_description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
@ -241,3 +242,13 @@ class SearchTool(Tool):
|
||||
document_pruning_config=self.pruning_config,
|
||||
)
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=final_context_documents)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
final_docs = cast(
|
||||
list[LlmDoc],
|
||||
next(arg.response for arg in args if arg.id == FINAL_CONTEXT_DOCUMENTS),
|
||||
)
|
||||
# NOTE: need to do this json.loads(doc.json()) stuff because there are some
|
||||
# subfields that are not serializable by default (datetime)
|
||||
# this forces pydantic to make them JSON serializable for us
|
||||
return [json.loads(doc.json()) for doc in final_docs]
|
||||
|
@ -2,26 +2,19 @@ import abc
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
id: str | None = None
|
||||
response: Any
|
||||
from danswer.tools.models import ToolResponse
|
||||
|
||||
|
||||
class Tool(abc.ABC):
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
"""For LLMs which support explicit tool calling"""
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def tool_definition(self) -> dict:
|
||||
raise NotImplementedError
|
||||
@ -49,3 +42,11 @@ class Tool(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
"""
|
||||
This is the "final summary" result of the tool.
|
||||
It is the result that will be stored in the database.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
@ -1,42 +1,15 @@
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import root_validator
|
||||
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
class ToolRunKickoff(BaseModel):
|
||||
tool_name: str
|
||||
tool_args: dict[str, Any]
|
||||
|
||||
|
||||
class ToolRunnerResponse(BaseModel):
|
||||
tool_run_kickoff: ToolRunKickoff | None = None
|
||||
tool_response: ToolResponse | None = None
|
||||
tool_message_content: str | list[str | dict[str, Any]] | None = None
|
||||
|
||||
@root_validator
|
||||
def validate_tool_runner_response(
|
||||
cls, values: dict[str, ToolResponse | str]
|
||||
) -> dict[str, ToolResponse | str]:
|
||||
fields = ["tool_response", "tool_message_content", "tool_run_kickoff"]
|
||||
provided = sum(1 for field in fields if values.get(field) is not None)
|
||||
|
||||
if provided != 1:
|
||||
raise ValueError(
|
||||
"Exactly one of 'tool_response', 'tool_message_content', "
|
||||
"or 'tool_run_kickoff' must be provided"
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class ToolRunner:
|
||||
def __init__(self, tool: Tool, args: dict[str, Any]):
|
||||
self.tool = tool
|
||||
@ -44,8 +17,8 @@ class ToolRunner:
|
||||
|
||||
self._tool_responses: list[ToolResponse] | None = None
|
||||
|
||||
def kickoff(self) -> ToolRunKickoff:
|
||||
return ToolRunKickoff(tool_name=self.tool.name(), tool_args=self.args)
|
||||
def kickoff(self) -> ToolCallKickoff:
|
||||
return ToolCallKickoff(tool_name=self.tool.name(), tool_args=self.args)
|
||||
|
||||
def tool_responses(self) -> Generator[ToolResponse, None, None]:
|
||||
if self._tool_responses is not None:
|
||||
@ -62,6 +35,13 @@ class ToolRunner:
|
||||
tool_responses = list(self.tool_responses())
|
||||
return self.tool.build_tool_message_content(*tool_responses)
|
||||
|
||||
def tool_final_result(self) -> ToolCallFinalResult:
|
||||
return ToolCallFinalResult(
|
||||
tool_name=self.tool.name(),
|
||||
tool_args=self.args,
|
||||
tool_result=self.tool.final_result(*self.tool_responses()),
|
||||
)
|
||||
|
||||
|
||||
def check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM
|
||||
|
@ -1,5 +1,4 @@
|
||||
import json
|
||||
from typing import Type
|
||||
|
||||
from tiktoken import Encoding
|
||||
|
||||
@ -17,15 +16,13 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> boo
|
||||
return False
|
||||
|
||||
|
||||
def compute_tool_tokens(
|
||||
tool: Tool | Type[Tool], llm_tokenizer: Encoding | None = None
|
||||
) -> int:
|
||||
def compute_tool_tokens(tool: Tool, llm_tokenizer: Encoding | None = None) -> int:
|
||||
if not llm_tokenizer:
|
||||
llm_tokenizer = get_default_llm_tokenizer()
|
||||
return len(llm_tokenizer.encode(json.dumps(tool.tool_definition())))
|
||||
|
||||
|
||||
def compute_all_tool_tokens(
|
||||
tools: list[Tool] | list[Type[Tool]], llm_tokenizer: Encoding | None = None
|
||||
tools: list[Tool], llm_tokenizer: Encoding | None = None
|
||||
) -> int:
|
||||
return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools)
|
||||
|
@ -24,6 +24,7 @@ httpx[http2]==0.23.3
|
||||
httpx-oauth==0.11.2
|
||||
huggingface-hub==0.20.1
|
||||
jira==3.5.1
|
||||
jsonref==1.1.0
|
||||
langchain==0.1.17
|
||||
langchain-community==0.0.36
|
||||
langchain-core==0.1.50
|
||||
|
File diff suppressed because it is too large
Load Diff
261
web/src/app/admin/tools/ToolEditor.tsx
Normal file
261
web/src/app/admin/tools/ToolEditor.tsx
Normal file
@ -0,0 +1,261 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { Formik, Form, Field, ErrorMessage } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { MethodSpec, ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { TextFormField } from "@/components/admin/connectors/Field";
|
||||
import { Button, Divider } from "@tremor/react";
|
||||
import {
|
||||
createCustomTool,
|
||||
updateCustomTool,
|
||||
validateToolDefinition,
|
||||
} from "@/lib/tools/edit";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import debounce from "lodash/debounce";
|
||||
|
||||
function parseJsonWithTrailingCommas(jsonString: string) {
|
||||
// Regular expression to remove trailing commas before } or ]
|
||||
let cleanedJsonString = jsonString.replace(/,\s*([}\]])/g, "$1");
|
||||
// Replace True with true, False with false, and None with null
|
||||
cleanedJsonString = cleanedJsonString
|
||||
.replace(/\bTrue\b/g, "true")
|
||||
.replace(/\bFalse\b/g, "false")
|
||||
.replace(/\bNone\b/g, "null");
|
||||
// Now parse the cleaned JSON string
|
||||
return JSON.parse(cleanedJsonString);
|
||||
}
|
||||
|
||||
function prettifyDefinition(definition: any) {
|
||||
return JSON.stringify(definition, null, 2);
|
||||
}
|
||||
|
||||
function ToolForm({
|
||||
existingTool,
|
||||
values,
|
||||
setFieldValue,
|
||||
isSubmitting,
|
||||
definitionErrorState,
|
||||
methodSpecsState,
|
||||
}: {
|
||||
existingTool?: ToolSnapshot;
|
||||
values: ToolFormValues;
|
||||
setFieldValue: (field: string, value: string) => void;
|
||||
isSubmitting: boolean;
|
||||
definitionErrorState: [
|
||||
string | null,
|
||||
React.Dispatch<React.SetStateAction<string | null>>,
|
||||
];
|
||||
methodSpecsState: [
|
||||
MethodSpec[] | null,
|
||||
React.Dispatch<React.SetStateAction<MethodSpec[] | null>>,
|
||||
];
|
||||
}) {
|
||||
const [definitionError, setDefinitionError] = definitionErrorState;
|
||||
const [methodSpecs, setMethodSpecs] = methodSpecsState;
|
||||
|
||||
const debouncedValidateDefinition = useCallback(
|
||||
debounce(async (definition: string) => {
|
||||
try {
|
||||
const parsedDefinition = parseJsonWithTrailingCommas(definition);
|
||||
const response = await validateToolDefinition({
|
||||
definition: parsedDefinition,
|
||||
});
|
||||
if (response.error) {
|
||||
setMethodSpecs(null);
|
||||
setDefinitionError(response.error);
|
||||
} else {
|
||||
setMethodSpecs(response.data);
|
||||
setDefinitionError(null);
|
||||
}
|
||||
} catch (error) {
|
||||
console.log(error);
|
||||
setMethodSpecs(null);
|
||||
setDefinitionError("Invalid JSON format");
|
||||
}
|
||||
}, 300),
|
||||
[]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
if (values.definition) {
|
||||
debouncedValidateDefinition(values.definition);
|
||||
}
|
||||
}, [values.definition, debouncedValidateDefinition]);
|
||||
|
||||
return (
|
||||
<Form>
|
||||
<div className="relative">
|
||||
<TextFormField
|
||||
name="definition"
|
||||
label="Definition"
|
||||
subtext="Specify an OpenAPI schema that defines the APIs you want to make available as part of this tool."
|
||||
placeholder="Enter your OpenAPI schema here"
|
||||
isTextArea={true}
|
||||
defaultHeight="h-96"
|
||||
fontSize="text-sm"
|
||||
isCode
|
||||
hideError
|
||||
/>
|
||||
<button
|
||||
type="button"
|
||||
className="
|
||||
absolute
|
||||
bottom-4
|
||||
right-4
|
||||
border-border
|
||||
border
|
||||
bg-background
|
||||
rounded
|
||||
py-1
|
||||
px-3
|
||||
text-sm
|
||||
hover:bg-hover-light
|
||||
"
|
||||
onClick={() => {
|
||||
const definition = values.definition;
|
||||
if (definition) {
|
||||
try {
|
||||
const formatted = prettifyDefinition(
|
||||
parseJsonWithTrailingCommas(definition)
|
||||
);
|
||||
setFieldValue("definition", formatted);
|
||||
} catch (error) {
|
||||
alert("Invalid JSON format");
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
Format
|
||||
</button>
|
||||
</div>
|
||||
{definitionError && (
|
||||
<div className="text-error text-sm">{definitionError}</div>
|
||||
)}
|
||||
<ErrorMessage
|
||||
name="definition"
|
||||
component="div"
|
||||
className="text-error text-sm"
|
||||
/>
|
||||
|
||||
{methodSpecs && methodSpecs.length > 0 && (
|
||||
<div className="mt-4">
|
||||
<h3 className="text-base font-semibold mb-2">Available methods</h3>
|
||||
<div className="overflow-x-auto">
|
||||
<table className="min-w-full bg-white border border-gray-200">
|
||||
<thead>
|
||||
<tr>
|
||||
<th className="px-4 py-2 border-b">Name</th>
|
||||
<th className="px-4 py-2 border-b">Summary</th>
|
||||
<th className="px-4 py-2 border-b">Method</th>
|
||||
<th className="px-4 py-2 border-b">Path</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{methodSpecs?.map((method: MethodSpec, index: number) => (
|
||||
<tr key={index} className="text-sm">
|
||||
<td className="px-4 py-2 border-b">{method.name}</td>
|
||||
<td className="px-4 py-2 border-b">{method.summary}</td>
|
||||
<td className="px-4 py-2 border-b">
|
||||
{method.method.toUpperCase()}
|
||||
</td>
|
||||
<td className="px-4 py-2 border-b">{method.path}</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Divider />
|
||||
<div className="flex">
|
||||
<Button
|
||||
className="mx-auto"
|
||||
color="green"
|
||||
size="md"
|
||||
type="submit"
|
||||
disabled={isSubmitting || !!definitionError}
|
||||
>
|
||||
{existingTool ? "Update Tool" : "Create Tool"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
);
|
||||
}
|
||||
|
||||
interface ToolFormValues {
|
||||
definition: string;
|
||||
}
|
||||
|
||||
const ToolSchema = Yup.object().shape({
|
||||
definition: Yup.string().required("Tool definition is required"),
|
||||
});
|
||||
|
||||
export function ToolEditor({ tool }: { tool?: ToolSnapshot }) {
|
||||
const router = useRouter();
|
||||
const { popup, setPopup } = usePopup();
|
||||
const [definitionError, setDefinitionError] = useState<string | null>(null);
|
||||
const [methodSpecs, setMethodSpecs] = useState<MethodSpec[] | null>(null);
|
||||
|
||||
const prettifiedDefinition = tool?.definition
|
||||
? prettifyDefinition(tool.definition)
|
||||
: "";
|
||||
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
<Formik
|
||||
initialValues={{
|
||||
definition: prettifiedDefinition,
|
||||
}}
|
||||
validationSchema={ToolSchema}
|
||||
onSubmit={async (values: ToolFormValues) => {
|
||||
let definition: any;
|
||||
try {
|
||||
definition = parseJsonWithTrailingCommas(values.definition);
|
||||
} catch (error) {
|
||||
setDefinitionError("Invalid JSON in tool definition");
|
||||
return;
|
||||
}
|
||||
|
||||
const name = definition?.info?.title;
|
||||
const description = definition?.info?.description;
|
||||
const toolData = {
|
||||
name: name,
|
||||
description: description || "",
|
||||
definition: definition,
|
||||
};
|
||||
let response;
|
||||
if (tool) {
|
||||
response = await updateCustomTool(tool.id, toolData);
|
||||
} else {
|
||||
response = await createCustomTool(toolData);
|
||||
}
|
||||
if (response.error) {
|
||||
setPopup({
|
||||
message: "Failed to create tool - " + response.error,
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
router.push(`/admin/tools?u=${Date.now()}`);
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, values, setFieldValue }) => {
|
||||
return (
|
||||
<ToolForm
|
||||
existingTool={tool}
|
||||
values={values}
|
||||
setFieldValue={setFieldValue}
|
||||
isSubmitting={isSubmitting}
|
||||
definitionErrorState={[definitionError, setDefinitionError]}
|
||||
methodSpecsState={[methodSpecs, setMethodSpecs]}
|
||||
/>
|
||||
);
|
||||
}}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
}
|
107
web/src/app/admin/tools/ToolsTable.tsx
Normal file
107
web/src/app/admin/tools/ToolsTable.tsx
Normal file
@ -0,0 +1,107 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Text,
|
||||
Table,
|
||||
TableHead,
|
||||
TableRow,
|
||||
TableHeaderCell,
|
||||
TableBody,
|
||||
TableCell,
|
||||
} from "@tremor/react";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { FiCheckCircle, FiEdit, FiXCircle } from "react-icons/fi";
|
||||
import { TrashIcon } from "@/components/icons/icons";
|
||||
import { deleteCustomTool } from "@/lib/tools/edit";
|
||||
|
||||
export function ToolsTable({ tools }: { tools: ToolSnapshot[] }) {
|
||||
const router = useRouter();
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const sortedTools = [...tools];
|
||||
sortedTools.sort((a, b) => a.id - b.id);
|
||||
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
|
||||
<Table>
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell>Name</TableHeaderCell>
|
||||
<TableHeaderCell>Description</TableHeaderCell>
|
||||
<TableHeaderCell>Built In?</TableHeaderCell>
|
||||
<TableHeaderCell>Delete</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{sortedTools.map((tool) => (
|
||||
<TableRow key={tool.id.toString()}>
|
||||
<TableCell>
|
||||
<div className="flex">
|
||||
{tool.in_code_tool_id === null && (
|
||||
<FiEdit
|
||||
className="mr-1 my-auto cursor-pointer"
|
||||
onClick={() =>
|
||||
router.push(
|
||||
`/admin/tools/edit/${tool.id}?u=${Date.now()}`
|
||||
)
|
||||
}
|
||||
/>
|
||||
)}
|
||||
<p className="text font-medium whitespace-normal break-none">
|
||||
{tool.name}
|
||||
</p>
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="whitespace-normal break-all max-w-2xl">
|
||||
{tool.description}
|
||||
</TableCell>
|
||||
<TableCell className="whitespace-nowrap">
|
||||
{tool.in_code_tool_id === null ? (
|
||||
<span>
|
||||
<FiXCircle className="inline-block mr-1 my-auto" />
|
||||
No
|
||||
</span>
|
||||
) : (
|
||||
<span>
|
||||
<FiCheckCircle className="inline-block mr-1 my-auto" />
|
||||
Yes
|
||||
</span>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell className="whitespace-nowrap">
|
||||
<div className="flex">
|
||||
{tool.in_code_tool_id === null ? (
|
||||
<div className="my-auto">
|
||||
<div
|
||||
className="hover:bg-hover rounded p-1 cursor-pointer"
|
||||
onClick={async () => {
|
||||
const response = await deleteCustomTool(tool.id);
|
||||
if (response.data) {
|
||||
router.refresh();
|
||||
} else {
|
||||
setPopup({
|
||||
message: `Failed to delete tool - ${response.error}`,
|
||||
type: "error",
|
||||
});
|
||||
}
|
||||
}}
|
||||
>
|
||||
<TrashIcon />
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
"-"
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
28
web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx
Normal file
28
web/src/app/admin/tools/edit/[toolId]/DeleteToolButton.tsx
Normal file
@ -0,0 +1,28 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@tremor/react";
|
||||
import { FiTrash } from "react-icons/fi";
|
||||
import { deleteCustomTool } from "@/lib/tools/edit";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function DeleteToolButton({ toolId }: { toolId: number }) {
|
||||
const router = useRouter();
|
||||
|
||||
return (
|
||||
<Button
|
||||
size="xs"
|
||||
color="red"
|
||||
onClick={async () => {
|
||||
const response = await deleteCustomTool(toolId);
|
||||
if (response.data) {
|
||||
router.push(`/admin/tools?u=${Date.now()}`);
|
||||
} else {
|
||||
alert(`Failed to delete tool - ${response.error}`);
|
||||
}
|
||||
}}
|
||||
icon={FiTrash}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
);
|
||||
}
|
55
web/src/app/admin/tools/edit/[toolId]/page.tsx
Normal file
55
web/src/app/admin/tools/edit/[toolId]/page.tsx
Normal file
@ -0,0 +1,55 @@
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { Card, Text, Title } from "@tremor/react";
|
||||
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
|
||||
import { fetchToolByIdSS } from "@/lib/tools/fetchTools";
|
||||
import { DeleteToolButton } from "./DeleteToolButton";
|
||||
import { FiTool } from "react-icons/fi";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { BackButton } from "@/components/BackButton";
|
||||
|
||||
export default async function Page({ params }: { params: { toolId: string } }) {
|
||||
const tool = await fetchToolByIdSS(params.toolId);
|
||||
|
||||
let body;
|
||||
if (!tool) {
|
||||
body = (
|
||||
<div>
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg="Tool not found"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
} else {
|
||||
body = (
|
||||
<div className="w-full my-8">
|
||||
<div>
|
||||
<div>
|
||||
<Card>
|
||||
<ToolEditor tool={tool} />
|
||||
</Card>
|
||||
|
||||
<Title className="mt-12">Delete Tool</Title>
|
||||
<Text>Click the button below to permanently delete this tool.</Text>
|
||||
<div className="flex mt-6">
|
||||
<DeleteToolButton toolId={tool.id} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mx-auto container">
|
||||
<BackButton />
|
||||
|
||||
<AdminPageTitle
|
||||
title="Edit Tool"
|
||||
icon={<FiTool size={32} className="my-auto" />}
|
||||
/>
|
||||
|
||||
{body}
|
||||
</div>
|
||||
);
|
||||
}
|
24
web/src/app/admin/tools/new/page.tsx
Normal file
24
web/src/app/admin/tools/new/page.tsx
Normal file
@ -0,0 +1,24 @@
|
||||
"use client";
|
||||
|
||||
import { ToolEditor } from "@/app/admin/tools/ToolEditor";
|
||||
import { BackButton } from "@/components/BackButton";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { Card } from "@tremor/react";
|
||||
import { FiTool } from "react-icons/fi";
|
||||
|
||||
export default function NewToolPage() {
|
||||
return (
|
||||
<div className="mx-auto container">
|
||||
<BackButton />
|
||||
|
||||
<AdminPageTitle
|
||||
title="Create Tool"
|
||||
icon={<FiTool size={32} className="my-auto" />}
|
||||
/>
|
||||
|
||||
<Card>
|
||||
<ToolEditor />
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
68
web/src/app/admin/tools/page.tsx
Normal file
68
web/src/app/admin/tools/page.tsx
Normal file
@ -0,0 +1,68 @@
|
||||
import { ToolsTable } from "./ToolsTable";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { FiPlusSquare, FiTool } from "react-icons/fi";
|
||||
import Link from "next/link";
|
||||
import { Divider, Text, Title } from "@tremor/react";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
|
||||
export default async function Page() {
|
||||
const toolResponse = await fetchSS("/tool");
|
||||
|
||||
if (!toolResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch tools - ${await toolResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const tools = (await toolResponse.json()) as ToolSnapshot[];
|
||||
|
||||
return (
|
||||
<div className="mx-auto container">
|
||||
<AdminPageTitle
|
||||
icon={<FiTool size={32} className="my-auto" />}
|
||||
title="Tools"
|
||||
/>
|
||||
|
||||
<Text className="mb-2">
|
||||
Tools allow assistants to retrieve information or take actions.
|
||||
</Text>
|
||||
|
||||
<div>
|
||||
<Divider />
|
||||
|
||||
<Title>Create a Tool</Title>
|
||||
<Link
|
||||
href="/admin/tools/new"
|
||||
className="
|
||||
flex
|
||||
py-2
|
||||
px-4
|
||||
mt-2
|
||||
border
|
||||
border-border
|
||||
h-fit
|
||||
cursor-pointer
|
||||
hover:bg-hover
|
||||
text-sm
|
||||
w-40
|
||||
"
|
||||
>
|
||||
<div className="mx-auto flex">
|
||||
<FiPlusSquare className="my-auto mr-2" />
|
||||
New Tool
|
||||
</div>
|
||||
</Link>
|
||||
|
||||
<Divider />
|
||||
|
||||
<Title>Existing Tools</Title>
|
||||
<ToolsTable tools={tools} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -12,7 +12,8 @@ import {
|
||||
Message,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
ToolRunKickoff,
|
||||
ToolCallFinalResult,
|
||||
ToolCallMetadata,
|
||||
} from "./interfaces";
|
||||
import { ChatSidebar } from "./sessionSidebar/ChatSidebar";
|
||||
import { Persona } from "../admin/assistants/interfaces";
|
||||
@ -265,6 +266,7 @@ export function ChatPage({
|
||||
message: "",
|
||||
type: "system",
|
||||
files: [],
|
||||
toolCalls: [],
|
||||
parentMessageId: null,
|
||||
childrenMessageIds: [firstMessageId],
|
||||
latestChildMessageId: firstMessageId,
|
||||
@ -307,7 +309,6 @@ export function ChatPage({
|
||||
return newCompleteMessageMap;
|
||||
};
|
||||
const messageHistory = buildLatestMessageChain(completeMessageMap);
|
||||
const [currentTool, setCurrentTool] = useState<string | null>(null);
|
||||
const [isStreaming, setIsStreaming] = useState(false);
|
||||
|
||||
// uploaded files
|
||||
@ -535,6 +536,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
parentMessageId: parentMessage?.messageId || null,
|
||||
},
|
||||
];
|
||||
@ -569,6 +571,7 @@ export function ChatPage({
|
||||
let aiMessageImages: FileDescriptor[] | null = null;
|
||||
let error: string | null = null;
|
||||
let finalMessage: BackendMessage | null = null;
|
||||
let toolCalls: ToolCallMetadata[] = [];
|
||||
try {
|
||||
const lastSuccessfulMessageId =
|
||||
getLastSuccessfulMessageId(currMessageHistory);
|
||||
@ -627,7 +630,13 @@ export function ChatPage({
|
||||
}
|
||||
);
|
||||
} else if (Object.hasOwn(packet, "tool_name")) {
|
||||
setCurrentTool((packet as ToolRunKickoff).tool_name);
|
||||
toolCalls = [
|
||||
{
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
},
|
||||
];
|
||||
} else if (Object.hasOwn(packet, "error")) {
|
||||
error = (packet as StreamingError).error;
|
||||
} else if (Object.hasOwn(packet, "message_id")) {
|
||||
@ -657,6 +666,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
parentMessageId: parentMessage?.messageId || null,
|
||||
childrenMessageIds: [newAssistantMessageId],
|
||||
latestChildMessageId: newAssistantMessageId,
|
||||
@ -670,6 +680,7 @@ export function ChatPage({
|
||||
documents: finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCalls: finalMessage?.tool_calls || toolCalls,
|
||||
parentMessageId: newUserMessageId,
|
||||
},
|
||||
]);
|
||||
@ -687,6 +698,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
{
|
||||
@ -694,6 +706,7 @@ export function ChatPage({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
files: aiMessageImages || [],
|
||||
toolCalls: [],
|
||||
parentMessageId: TEMP_USER_MESSAGE_ID,
|
||||
},
|
||||
],
|
||||
@ -1031,7 +1044,7 @@ export function ChatPage({
|
||||
citedDocuments={getCitedDocumentsFromMessage(
|
||||
message
|
||||
)}
|
||||
currentTool={currentTool}
|
||||
toolCall={message.toolCalls[0]}
|
||||
isComplete={
|
||||
i !== messageHistory.length - 1 ||
|
||||
!isStreaming
|
||||
|
@ -34,6 +34,18 @@ export interface FileDescriptor {
|
||||
isUploading?: boolean;
|
||||
}
|
||||
|
||||
export interface ToolCallMetadata {
|
||||
tool_name: string;
|
||||
tool_args: Record<string, any>;
|
||||
tool_result?: Record<string, any>;
|
||||
}
|
||||
|
||||
export interface ToolCallFinalResult {
|
||||
tool_name: string;
|
||||
tool_args: Record<string, any>;
|
||||
tool_result: Record<string, any>;
|
||||
}
|
||||
|
||||
export interface ChatSession {
|
||||
id: number;
|
||||
name: string;
|
||||
@ -52,6 +64,7 @@ export interface Message {
|
||||
documents?: DanswerDocument[] | null;
|
||||
citations?: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
toolCalls: ToolCallMetadata[];
|
||||
// for rebuilding the message tree
|
||||
parentMessageId: number | null;
|
||||
childrenMessageIds?: number[];
|
||||
@ -79,6 +92,7 @@ export interface BackendMessage {
|
||||
time_sent: string;
|
||||
citations: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
tool_calls: ToolCallFinalResult[];
|
||||
}
|
||||
|
||||
export interface DocumentsResponse {
|
||||
@ -90,11 +104,6 @@ export interface ImageGenerationDisplay {
|
||||
file_ids: string[];
|
||||
}
|
||||
|
||||
export interface ToolRunKickoff {
|
||||
tool_name: string;
|
||||
tool_args: Record<string, any>;
|
||||
}
|
||||
|
||||
export interface StreamingError {
|
||||
error: string;
|
||||
}
|
||||
|
@ -15,7 +15,7 @@ import {
|
||||
Message,
|
||||
RetrievalType,
|
||||
StreamingError,
|
||||
ToolRunKickoff,
|
||||
ToolCallMetadata,
|
||||
} from "./interfaces";
|
||||
import { Persona } from "../admin/assistants/interfaces";
|
||||
import { ReadonlyURLSearchParams } from "next/navigation";
|
||||
@ -138,7 +138,7 @@ export async function* sendMessage({
|
||||
| DocumentsResponse
|
||||
| BackendMessage
|
||||
| ImageGenerationDisplay
|
||||
| ToolRunKickoff
|
||||
| ToolCallMetadata
|
||||
| StreamingError
|
||||
>(sendMessageResponse);
|
||||
}
|
||||
@ -384,6 +384,7 @@ export function processRawChatHistory(
|
||||
citations: messageInfo?.citations || {},
|
||||
}
|
||||
: {}),
|
||||
toolCalls: messageInfo.tool_calls,
|
||||
parentMessageId: messageInfo.parent_message,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: messageInfo.latest_child_message,
|
||||
|
@ -9,6 +9,7 @@ import {
|
||||
FiEdit2,
|
||||
FiChevronRight,
|
||||
FiChevronLeft,
|
||||
FiTool,
|
||||
} from "react-icons/fi";
|
||||
import { FeedbackType } from "../types";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
@ -20,9 +21,12 @@ import { ThreeDots } from "react-loader-spinner";
|
||||
import { SkippedSearch } from "./SkippedSearch";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import { CopyButton } from "@/components/CopyButton";
|
||||
import { ChatFileType, FileDescriptor } from "../interfaces";
|
||||
import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants";
|
||||
import { ToolRunningAnimation } from "../tools/ToolRunningAnimation";
|
||||
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
|
||||
import {
|
||||
IMAGE_GENERATION_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
} from "../tools/constants";
|
||||
import { ToolRunDisplay } from "../tools/ToolRunningAnimation";
|
||||
import { Hoverable } from "@/components/Hoverable";
|
||||
import { DocumentPreview } from "../files/documents/DocumentPreview";
|
||||
import { InMessageImage } from "../files/images/InMessageImage";
|
||||
@ -35,6 +39,11 @@ import Prism from "prismjs";
|
||||
import "prismjs/themes/prism-tomorrow.css";
|
||||
import "./custom-code-styles.css";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
IMAGE_GENERATION_TOOL_NAME,
|
||||
];
|
||||
|
||||
function FileDisplay({ files }: { files: FileDescriptor[] }) {
|
||||
const imageFiles = files.filter((file) => file.type === ChatFileType.IMAGE);
|
||||
const nonImgFiles = files.filter((file) => file.type !== ChatFileType.IMAGE);
|
||||
@ -77,7 +86,7 @@ export const AIMessage = ({
|
||||
query,
|
||||
personaName,
|
||||
citedDocuments,
|
||||
currentTool,
|
||||
toolCall,
|
||||
isComplete,
|
||||
hasDocs,
|
||||
handleFeedback,
|
||||
@ -93,7 +102,7 @@ export const AIMessage = ({
|
||||
query?: string;
|
||||
personaName?: string;
|
||||
citedDocuments?: [string, DanswerDocument][] | null;
|
||||
currentTool?: string | null;
|
||||
toolCall?: ToolCallMetadata;
|
||||
isComplete?: boolean;
|
||||
hasDocs?: boolean;
|
||||
handleFeedback?: (feedbackType: FeedbackType) => void;
|
||||
@ -133,28 +142,23 @@ export const AIMessage = ({
|
||||
content = trimIncompleteCodeSection(content);
|
||||
}
|
||||
|
||||
const loader =
|
||||
currentTool === IMAGE_GENERATION_TOOL_NAME ? (
|
||||
<div className="text-sm my-auto">
|
||||
<ToolRunningAnimation
|
||||
toolName="Generating images"
|
||||
toolLogo={<FiImage size={16} className="my-auto mr-1" />}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className="text-sm my-auto">
|
||||
<ThreeDots
|
||||
height="30"
|
||||
width="50"
|
||||
color="#3b82f6"
|
||||
ariaLabel="grid-loading"
|
||||
radius="12.5"
|
||||
wrapperStyle={{}}
|
||||
wrapperClass=""
|
||||
visible={true}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
const shouldShowLoader =
|
||||
!toolCall ||
|
||||
(toolCall.tool_name === SEARCH_TOOL_NAME && query === undefined);
|
||||
const defaultLoader = shouldShowLoader ? (
|
||||
<div className="text-sm my-auto">
|
||||
<ThreeDots
|
||||
height="30"
|
||||
width="50"
|
||||
color="#3b82f6"
|
||||
ariaLabel="grid-loading"
|
||||
radius="12.5"
|
||||
wrapperStyle={{}}
|
||||
wrapperClass=""
|
||||
visible={true}
|
||||
/>
|
||||
</div>
|
||||
) : undefined;
|
||||
|
||||
return (
|
||||
<div className={"py-5 px-5 flex -mr-6 w-full"}>
|
||||
@ -189,28 +193,61 @@ export const AIMessage = ({
|
||||
</div>
|
||||
|
||||
<div className="w-message-xs 2xl:w-message-sm 3xl:w-message-default break-words mt-1 ml-8">
|
||||
{query !== undefined &&
|
||||
handleShowRetrieved !== undefined &&
|
||||
isCurrentlyShowingRetrieved !== undefined &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SearchSummary
|
||||
query={query}
|
||||
hasDocs={hasDocs || false}
|
||||
messageId={messageId}
|
||||
isCurrentlyShowingRetrieved={isCurrentlyShowingRetrieved}
|
||||
handleShowRetrieved={handleShowRetrieved}
|
||||
handleSearchQueryEdit={handleSearchQueryEdit}
|
||||
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
|
||||
<>
|
||||
{query !== undefined &&
|
||||
handleShowRetrieved !== undefined &&
|
||||
isCurrentlyShowingRetrieved !== undefined &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SearchSummary
|
||||
query={query}
|
||||
hasDocs={hasDocs || false}
|
||||
messageId={messageId}
|
||||
isCurrentlyShowingRetrieved={
|
||||
isCurrentlyShowingRetrieved
|
||||
}
|
||||
handleShowRetrieved={handleShowRetrieved}
|
||||
handleSearchQueryEdit={handleSearchQueryEdit}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{handleForceSearch &&
|
||||
content &&
|
||||
query === undefined &&
|
||||
!hasDocs &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SkippedSearch handleForceSearch={handleForceSearch} />
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
|
||||
{toolCall &&
|
||||
!TOOLS_WITH_CUSTOM_HANDLING.includes(toolCall.tool_name) && (
|
||||
<div className="my-2">
|
||||
<ToolRunDisplay
|
||||
toolName={
|
||||
toolCall.tool_result && content
|
||||
? `Used "${toolCall.tool_name}"`
|
||||
: `Using "${toolCall.tool_name}"`
|
||||
}
|
||||
toolLogo={<FiTool size={15} className="my-auto mr-1" />}
|
||||
isRunning={!toolCall.tool_result || !content}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{handleForceSearch &&
|
||||
content &&
|
||||
query === undefined &&
|
||||
!hasDocs &&
|
||||
!retrievalDisabled && (
|
||||
<div className="my-1">
|
||||
<SkippedSearch handleForceSearch={handleForceSearch} />
|
||||
|
||||
{toolCall &&
|
||||
toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME &&
|
||||
!toolCall.tool_result && (
|
||||
<div className="my-2">
|
||||
<ToolRunDisplay
|
||||
toolName={`Generating images`}
|
||||
toolLogo={<FiImage size={15} className="my-auto mr-1" />}
|
||||
isRunning={!toolCall.tool_result}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@ -260,7 +297,7 @@ export const AIMessage = ({
|
||||
)}
|
||||
</>
|
||||
) : isComplete ? null : (
|
||||
loader
|
||||
defaultLoader
|
||||
)}
|
||||
{citedDocuments && citedDocuments.length > 0 && (
|
||||
<div className="mt-2">
|
||||
|
@ -1,16 +1,18 @@
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
|
||||
export function ToolRunningAnimation({
|
||||
export function ToolRunDisplay({
|
||||
toolName,
|
||||
toolLogo,
|
||||
isRunning,
|
||||
}: {
|
||||
toolName: string;
|
||||
toolLogo: JSX.Element;
|
||||
toolLogo?: JSX.Element;
|
||||
isRunning: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div className="text-sm my-auto flex">
|
||||
<div className="text-sm text-subtle my-auto flex">
|
||||
{toolLogo}
|
||||
<LoadingAnimation text={toolName} />
|
||||
{isRunning ? <LoadingAnimation text={toolName} /> : toolName}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
@ -2,15 +2,12 @@ import { Header } from "@/components/header/Header";
|
||||
import { AdminSidebar } from "@/components/admin/connectors/AdminSidebar";
|
||||
import {
|
||||
NotebookIcon,
|
||||
KeyIcon,
|
||||
UsersIcon,
|
||||
ThumbsUpIcon,
|
||||
BookmarkIcon,
|
||||
CPUIcon,
|
||||
ZoomInIcon,
|
||||
RobotIcon,
|
||||
ConnectorIcon,
|
||||
SlackIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { User } from "@/lib/types";
|
||||
import {
|
||||
@ -19,13 +16,7 @@ import {
|
||||
getCurrentUserSS,
|
||||
} from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
import {
|
||||
FiCpu,
|
||||
FiLayers,
|
||||
FiPackage,
|
||||
FiSettings,
|
||||
FiSlack,
|
||||
} from "react-icons/fi";
|
||||
import { FiCpu, FiPackage, FiSettings, FiSlack, FiTool } from "react-icons/fi";
|
||||
|
||||
export async function Layout({ children }: { children: React.ReactNode }) {
|
||||
const tasks = [getAuthTypeMetadataSS(), getCurrentUserSS()];
|
||||
@ -142,6 +133,15 @@ export async function Layout({ children }: { children: React.ReactNode }) {
|
||||
),
|
||||
link: "/admin/bot",
|
||||
},
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<FiTool size={18} className="my-auto" />
|
||||
<div className="ml-1">Tools</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/tools",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
|
@ -43,6 +43,10 @@ export function TextFormField({
|
||||
disabled = false,
|
||||
autoCompleteDisabled = true,
|
||||
error,
|
||||
defaultHeight,
|
||||
isCode = false,
|
||||
fontSize,
|
||||
hideError,
|
||||
}: {
|
||||
name: string;
|
||||
label: string;
|
||||
@ -54,7 +58,16 @@ export function TextFormField({
|
||||
disabled?: boolean;
|
||||
autoCompleteDisabled?: boolean;
|
||||
error?: string;
|
||||
defaultHeight?: string;
|
||||
isCode?: boolean;
|
||||
fontSize?: "text-sm" | "text-base" | "text-lg";
|
||||
hideError?: boolean;
|
||||
}) {
|
||||
let heightString = defaultHeight || "";
|
||||
if (isTextArea && !heightString) {
|
||||
heightString = "h-28";
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<Label>{label}</Label>
|
||||
@ -64,18 +77,19 @@ export function TextFormField({
|
||||
type={type}
|
||||
name={name}
|
||||
id={name}
|
||||
className={
|
||||
`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mt-1
|
||||
${isTextArea ? " h-28" : ""}
|
||||
` + (disabled ? " bg-background-strong" : " bg-background-emphasis")
|
||||
}
|
||||
className={`
|
||||
border
|
||||
border-border
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mt-1
|
||||
${heightString}
|
||||
${fontSize}
|
||||
${disabled ? " bg-background-strong" : " bg-background-emphasis"}
|
||||
${isCode ? " font-mono" : ""}
|
||||
`}
|
||||
disabled={disabled}
|
||||
placeholder={placeholder}
|
||||
autoComplete={autoCompleteDisabled ? "off" : undefined}
|
||||
@ -84,11 +98,13 @@ export function TextFormField({
|
||||
{error ? (
|
||||
<ManualErrorMessage>{error}</ManualErrorMessage>
|
||||
) : (
|
||||
<ErrorMessage
|
||||
name={name}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
!hideError && (
|
||||
<ErrorMessage
|
||||
name={name}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
111
web/src/lib/tools/edit.ts
Normal file
111
web/src/lib/tools/edit.ts
Normal file
@ -0,0 +1,111 @@
|
||||
import { MethodSpec, ToolSnapshot } from "./interfaces";
|
||||
|
||||
interface ApiResponse<T> {
|
||||
data: T | null;
|
||||
error: string | null;
|
||||
}
|
||||
|
||||
export async function createCustomTool(toolData: {
|
||||
name: string;
|
||||
description?: string;
|
||||
definition: Record<string, any>;
|
||||
}): Promise<ApiResponse<ToolSnapshot>> {
|
||||
try {
|
||||
const response = await fetch("/api/admin/tool/custom", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(toolData),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorDetail = (await response.json()).detail;
|
||||
return { data: null, error: `Failed to create tool: ${errorDetail}` };
|
||||
}
|
||||
|
||||
const tool: ToolSnapshot = await response.json();
|
||||
return { data: tool, error: null };
|
||||
} catch (error) {
|
||||
console.error("Error creating tool:", error);
|
||||
return { data: null, error: "Error creating tool" };
|
||||
}
|
||||
}
|
||||
|
||||
export async function updateCustomTool(
|
||||
toolId: number,
|
||||
toolData: {
|
||||
name?: string;
|
||||
description?: string;
|
||||
definition?: Record<string, any>;
|
||||
}
|
||||
): Promise<ApiResponse<ToolSnapshot>> {
|
||||
try {
|
||||
const response = await fetch(`/api/admin/tool/custom/${toolId}`, {
|
||||
method: "PUT",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(toolData),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorDetail = (await response.json()).detail;
|
||||
return { data: null, error: `Failed to update tool: ${errorDetail}` };
|
||||
}
|
||||
|
||||
const updatedTool: ToolSnapshot = await response.json();
|
||||
return { data: updatedTool, error: null };
|
||||
} catch (error) {
|
||||
console.error("Error updating tool:", error);
|
||||
return { data: null, error: "Error updating tool" };
|
||||
}
|
||||
}
|
||||
|
||||
export async function deleteCustomTool(
|
||||
toolId: number
|
||||
): Promise<ApiResponse<boolean>> {
|
||||
try {
|
||||
const response = await fetch(`/api/admin/tool/custom/${toolId}`, {
|
||||
method: "DELETE",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorDetail = (await response.json()).detail;
|
||||
return { data: false, error: `Failed to delete tool: ${errorDetail}` };
|
||||
}
|
||||
|
||||
return { data: true, error: null };
|
||||
} catch (error) {
|
||||
console.error("Error deleting tool:", error);
|
||||
return { data: false, error: "Error deleting tool" };
|
||||
}
|
||||
}
|
||||
|
||||
export async function validateToolDefinition(toolData: {
|
||||
definition: Record<string, any>;
|
||||
}): Promise<ApiResponse<MethodSpec[]>> {
|
||||
try {
|
||||
const response = await fetch(`/api/admin/tool/custom/validate`, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(toolData),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorDetail = (await response.json()).detail;
|
||||
return { data: null, error: errorDetail };
|
||||
}
|
||||
|
||||
const responseJson = await response.json();
|
||||
return { data: responseJson.methods, error: null };
|
||||
} catch (error) {
|
||||
console.error("Error validating tool:", error);
|
||||
return { data: null, error: "Unexpected error validating tool definition" };
|
||||
}
|
||||
}
|
@ -14,3 +14,21 @@ export async function fetchToolsSS(): Promise<ToolSnapshot[] | null> {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
export async function fetchToolByIdSS(
|
||||
toolId: string
|
||||
): Promise<ToolSnapshot | null> {
|
||||
try {
|
||||
const response = await fetchSS(`/tool/${toolId}`);
|
||||
if (!response.ok) {
|
||||
throw new Error(
|
||||
`Failed to fetch tool with ID ${toolId}: ${await response.text()}`
|
||||
);
|
||||
}
|
||||
const tool: ToolSnapshot = await response.json();
|
||||
return tool;
|
||||
} catch (error) {
|
||||
console.error(`Error fetching tool with ID ${toolId}:`, error);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
@ -2,5 +2,21 @@ export interface ToolSnapshot {
|
||||
id: number;
|
||||
name: string;
|
||||
description: string;
|
||||
|
||||
// only specified for Custom Tools. OpenAPI schema which represents
|
||||
// the tool's API.
|
||||
definition: Record<string, any> | null;
|
||||
|
||||
// only specified for Custom Tools. ID of the tool in the codebase.
|
||||
in_code_tool_id: string | null;
|
||||
}
|
||||
|
||||
export interface MethodSpec {
|
||||
/* Defines a single method that is part of a custom tool. Each method maps to a single
|
||||
action that the LLM can choose to take. */
|
||||
name: string;
|
||||
summary: string;
|
||||
path: string;
|
||||
method: string;
|
||||
spec: Record<string, any>;
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user