diff --git a/backend/alembic/versions/3879338f8ba1_add_tool_table.py b/backend/alembic/versions/3879338f8ba1_add_tool_table.py new file mode 100644 index 000000000000..242eb6645b5b --- /dev/null +++ b/backend/alembic/versions/3879338f8ba1_add_tool_table.py @@ -0,0 +1,45 @@ +"""Add tool table + +Revision ID: 3879338f8ba1 +Revises: f1c6478c3fd8 +Create Date: 2024-05-11 16:11:23.718084 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "3879338f8ba1" +down_revision = "f1c6478c3fd8" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "tool", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("in_code_tool_id", sa.String(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + op.create_table( + "persona__tool", + sa.Column("persona_id", sa.Integer(), nullable=False), + sa.Column("tool_id", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["persona_id"], + ["persona.id"], + ), + sa.ForeignKeyConstraint( + ["tool_id"], + ["tool.id"], + ), + sa.PrimaryKeyConstraint("persona_id", "tool_id"), + ) + + +def downgrade() -> None: + op.drop_table("persona__tool") + op.drop_table("tool") diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 9ce76e0d3617..03e770bd5f5e 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -23,8 +23,8 @@ from fastapi_users.authentication import CookieTransport from fastapi_users.authentication import Strategy from fastapi_users.authentication.strategy.db import AccessTokenDatabase from fastapi_users.authentication.strategy.db import DatabaseStrategy -from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.openapi import OpenAPIResponseType +from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase from sqlalchemy.orm import Session from danswer.auth.schemas import UserCreate diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py index 66b0f37de5c2..f4b0b2e02c5e 100644 --- a/backend/danswer/chat/chat_utils.py +++ b/backend/danswer/chat/chat_utils.py @@ -1,5 +1,6 @@ import re from collections.abc import Sequence +from typing import cast from sqlalchemy.orm import Session @@ -7,6 +8,7 @@ from danswer.chat.models import CitationInfo from danswer.chat.models import LlmDoc from danswer.db.chat import get_chat_messages_by_session from danswer.db.models import ChatMessage +from danswer.llm.answering.models import PreviousMessage from danswer.search.models import InferenceChunk from danswer.search.models import InferenceSection from danswer.utils.logger import setup_logger @@ -88,7 +90,7 @@ def create_chat_chain( def combine_message_chain( - messages: list[ChatMessage], + messages: list[ChatMessage] | list[PreviousMessage], token_limit: int, msg_limit: int | None = None, ) -> str: @@ -99,7 +101,7 @@ def combine_message_chain( if msg_limit is not None: messages = messages[-msg_limit:] - for message in reversed(messages): + for message in cast(list[ChatMessage] | list[PreviousMessage], reversed(messages)): message_token_count = message.token_count if total_token_count + message_token_count > token_limit: diff --git a/backend/danswer/chat/models.py b/backend/danswer/chat/models.py index d2dd9f31fafb..8fa5eecaeebe 100644 --- a/backend/danswer/chat/models.py +++ b/backend/danswer/chat/models.py @@ -102,8 +102,17 @@ class QAResponse(SearchResponse, DanswerAnswer): error_msg: str | None = None +class ImageGenerationDisplay(BaseModel): + file_ids: list[str] + + AnswerQuestionPossibleReturn = ( - DanswerAnswerPiece | DanswerQuotes | CitationInfo | DanswerContexts | StreamingError + DanswerAnswerPiece + | DanswerQuotes + | CitationInfo + | DanswerContexts + | ImageGenerationDisplay + | StreamingError ) diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 922badf68d88..7ee92d79c375 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -6,9 +6,9 @@ from typing import cast from sqlalchemy.orm import Session from danswer.chat.chat_utils import create_chat_chain -from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece +from danswer.chat.models import ImageGenerationDisplay from danswer.chat.models import LlmDoc from danswer.chat.models import LLMRelevanceFilterResponse from danswer.chat.models import QADocsResponse @@ -27,11 +27,14 @@ from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session_context_manager +from danswer.db.llm import fetch_existing_llm_providers from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import User from danswer.document_index.factory import get_default_document_index from danswer.file_store.models import ChatFileType +from danswer.file_store.models import FileDescriptor from danswer.file_store.utils import load_all_chat_files +from danswer.file_store.utils import save_files_from_urls from danswer.llm.answering.answer import Answer from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import CitationConfig @@ -41,16 +44,25 @@ from danswer.llm.answering.models import PromptConfig from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_llm_for_persona from danswer.llm.utils import get_default_llm_tokenizer -from danswer.search.models import OptionalSearchSetting -from danswer.search.models import SearchRequest -from danswer.search.pipeline import SearchPipeline +from danswer.search.enums import OptionalSearchSetting from danswer.search.retrieval.search_runner import inference_documents_from_ids from danswer.search.utils import chunks_or_sections_to_search_docs -from danswer.secondary_llm_flows.choose_search import check_if_need_search -from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.utils import get_json_line +from danswer.tools.factory import get_tool_cls +from danswer.tools.force import ForceUseTool +from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID +from danswer.tools.images.image_generation_tool import ImageGenerationResponse +from danswer.tools.images.image_generation_tool import ImageGenerationTool +from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID +from danswer.tools.search.search_tool import SearchResponseSummary +from danswer.tools.search.search_tool import SearchTool +from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.tools.utils import compute_all_tool_tokens +from danswer.tools.utils import explicit_tool_calling_supported from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -77,14 +89,78 @@ def translate_citations( return citation_to_saved_doc_id_map -ChatPacketStream = Iterator[ +def _handle_search_tool_response_summary( + packet: ToolResponse, + db_session: Session, + selected_search_docs: list[DbSearchDoc] | None, +) -> tuple[QADocsResponse, list[DbSearchDoc]]: + response_sumary = cast(SearchResponseSummary, packet.response) + + if not selected_search_docs: + top_docs = chunks_or_sections_to_search_docs(response_sumary.top_sections) + reference_db_search_docs = [ + create_db_search_doc(server_search_doc=top_doc, db_session=db_session) + for top_doc in top_docs + ] + else: + reference_db_search_docs = selected_search_docs + + response_docs = [ + translate_db_search_doc_to_server_search_doc(db_search_doc) + for db_search_doc in reference_db_search_docs + ] + return ( + QADocsResponse( + rephrased_query=response_sumary.rephrased_query, + top_documents=response_docs, + predicted_flow=response_sumary.predicted_flow, + predicted_search=response_sumary.predicted_search, + applied_source_filters=response_sumary.final_filters.source_type, + applied_time_cutoff=response_sumary.final_filters.time_cutoff, + recency_bias_multiplier=response_sumary.recency_bias_multiplier, + ), + reference_db_search_docs, + ) + + +def _check_should_force_search( + new_msg_req: CreateChatMessageRequest, +) -> ForceUseTool | None: + if ( + new_msg_req.query_override + or ( + new_msg_req.retrieval_options + and new_msg_req.retrieval_options.run_search == OptionalSearchSetting.ALWAYS + ) + or new_msg_req.search_doc_ids + ): + args = ( + {"query": new_msg_req.query_override} + if new_msg_req.query_override + else None + ) + # if we are using selected docs, just put something here so the Tool doesn't need + # to build its own args via an LLM call + if new_msg_req.search_doc_ids: + args = {"query": new_msg_req.message} + + return ForceUseTool( + tool_name=SearchTool.name(), + args=args, + ) + return None + + +ChatPacket = ( StreamingError | QADocsResponse | LLMRelevanceFilterResponse | ChatMessageDetail | DanswerAnswerPiece | CitationInfo -] + | ImageGenerationDisplay +) +ChatPacketStream = Iterator[ChatPacket] def stream_chat_message_objects( @@ -123,11 +199,9 @@ def stream_chat_message_objects( reference_doc_ids = new_msg_req.search_doc_ids retrieval_options = new_msg_req.retrieval_options persona = chat_session.persona - query_override = new_msg_req.query_override - # After this section, no_ai_answer is represented by prompt being None prompt_id = new_msg_req.prompt_id - if prompt_id is None and persona.prompts and not new_msg_req.no_ai_answer: + if prompt_id is None and persona.prompts: prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id if reference_doc_ids is None and retrieval_options is None: @@ -140,7 +214,7 @@ def stream_chat_message_objects( persona, new_msg_req.llm_override or chat_session.llm_override ) except GenAIDisabledException: - llm = None + raise RuntimeError("LLM is disabled. Can't use chat flow without LLM.") llm_tokenizer = get_default_llm_tokenizer() llm_tokenizer_encode_func = cast( @@ -214,25 +288,8 @@ def stream_chat_message_objects( file for file in files if file.file_id in new_msg_req.file_ids ] - run_search = False - # Retrieval options are only None if reference_doc_ids are provided - # Also don't perform search if the user uploaded at least one file - just use the files - if ( - retrieval_options is not None - and persona.num_chunks != 0 - and not new_msg_req.file_ids - ): - if retrieval_options.run_search == OptionalSearchSetting.ALWAYS: - run_search = True - elif retrieval_options.run_search == OptionalSearchSetting.NEVER: - run_search = False - else: - run_search = check_if_need_search( - query_message=final_msg, history=history_msgs, llm=llm - ) - - rephrased_query = None - llm_relevance_list = None + selected_db_search_docs = None + selected_llm_docs: list[LlmDoc] | None = None if reference_doc_ids: identifier_tuples = get_doc_query_identifiers_from_model( search_doc_ids=reference_doc_ids, @@ -243,7 +300,7 @@ def stream_chat_message_objects( # Generates full documents currently # May extend to include chunk ranges - llm_docs: list[LlmDoc] = inference_documents_from_ids( + selected_llm_docs = inference_documents_from_ids( doc_identifiers=identifier_tuples, document_index=document_index, ) @@ -258,66 +315,11 @@ def stream_chat_message_objects( for doc_id in reference_doc_ids ] - reference_db_search_docs = [ + selected_db_search_docs = [ db_sd for db_sd in db_search_docs_or_none if db_sd ] - elif run_search: - rephrased_query = ( - history_based_query_rephrase( - query_message=final_msg, history=history_msgs, llm=llm - ) - if query_override is None - else query_override - ) - - search_pipeline = SearchPipeline( - search_request=SearchRequest( - query=rephrased_query, - human_selected_filters=retrieval_options.filters - if retrieval_options - else None, - persona=persona, - offset=retrieval_options.offset if retrieval_options else None, - limit=retrieval_options.limit if retrieval_options else None, - chunks_above=new_msg_req.chunks_above, - chunks_below=new_msg_req.chunks_below, - full_doc=new_msg_req.full_doc, - ), - user=user, - db_session=db_session, - ) - - top_sections = search_pipeline.reranked_sections - top_docs = chunks_or_sections_to_search_docs(top_sections) - - reference_db_search_docs = [ - create_db_search_doc(server_search_doc=top_doc, db_session=db_session) - for top_doc in top_docs - ] - - response_docs = [ - translate_db_search_doc_to_server_search_doc(db_search_doc) - for db_search_doc in reference_db_search_docs - ] - - initial_response = QADocsResponse( - rephrased_query=rephrased_query, - top_documents=response_docs, - predicted_flow=search_pipeline.predicted_flow, - predicted_search=search_pipeline.predicted_search_type, - applied_source_filters=search_pipeline.search_query.filters.source_type, - applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, - recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, - ) - yield initial_response - - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI - llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=search_pipeline.relevant_chunk_indices - ) - yield llm_relevance_filtering_response - + else: document_pruning_config = DocumentPruningConfig( max_chunks=int( persona.num_chunks @@ -325,19 +327,10 @@ def stream_chat_message_objects( else default_num_chunks ), max_window_percentage=max_document_percentage, - use_sections=search_pipeline.ran_merge_chunk, + use_sections=new_msg_req.chunks_above > 0 + or new_msg_req.chunks_below > 0, ) - llm_docs = [ - llm_doc_from_inference_section(section) for section in top_sections - ] - llm_relevance_list = search_pipeline.section_relevance_list - - else: - llm_docs = [] - reference_db_search_docs = None - document_pruning_config = DocumentPruningConfig() - # Cannot determine these without the LLM step or breaking out early partial_response = partial( create_new_chat_message, @@ -345,64 +338,139 @@ def stream_chat_message_objects( parent_message=final_msg, prompt_id=prompt_id, # message=, - rephrased_query=rephrased_query, + # rephrased_query=, # token_count=, message_type=MessageType.ASSISTANT, # error=, - reference_docs=reference_db_search_docs, + # reference_docs=, db_session=db_session, commit=True, ) - # If no prompt is provided, this is interpreted as not wanting an AI Answer - # Simply provide/save the retrieval results - if final_msg.prompt is None: - gen_ai_response_message = partial_response( - message="", - token_count=0, - citations=None, - error=None, - ) - msg_detail_response = translate_db_message_to_chat_message_detail( - gen_ai_response_message - ) + if not final_msg.prompt: + raise RuntimeError("No Prompt found") - yield msg_detail_response + prompt_config = PromptConfig.from_model( + final_msg.prompt, + prompt_override=( + new_msg_req.prompt_override or chat_session.prompt_override + ), + ) - # Stop here after saving message details, the above still needs to be sent for the - # message id to send the next follow-up message - return + persona_tool_classes = [ + get_tool_cls(tool, db_session) for tool in persona.tools + ] + + # factor in tool definition size when pruning + document_pruning_config.tool_num_tokens = compute_all_tool_tokens( + persona_tool_classes + ) + document_pruning_config.using_tool_message = explicit_tool_calling_supported( + llm.config.model_provider, llm.config.model_name + ) + + # NOTE: for now, only support SearchTool and ImageGenerationTool + # in the future, will support arbitrary user-defined tools + search_tool: SearchTool | None = None + tools: list[Tool] = [] + for tool_cls in persona_tool_classes: + if tool_cls.__name__ == SearchTool.__name__: + search_tool = SearchTool( + db_session=db_session, + user=user, + persona=persona, + retrieval_options=retrieval_options, + prompt_config=prompt_config, + llm_config=llm.config, + pruning_config=document_pruning_config, + selected_docs=selected_llm_docs, + chunks_above=new_msg_req.chunks_above, + chunks_below=new_msg_req.chunks_below, + full_doc=new_msg_req.full_doc, + ) + tools.append(search_tool) + elif tool_cls.__name__ == ImageGenerationTool.__name__: + dalle_key = None + if llm and llm.config.api_key and llm.config.model_provider == "openai": + dalle_key = llm.config.api_key + else: + llm_providers = fetch_existing_llm_providers(db_session) + openai_provider = next( + iter( + [ + llm_provider + for llm_provider in llm_providers + if llm_provider.provider == "openai" + ] + ), + None, + ) + if not openai_provider or not openai_provider.api_key: + raise ValueError( + "Image generation tool requires an OpenAI API key" + ) + dalle_key = openai_provider.api_key + tools.append(ImageGenerationTool(api_key=dalle_key)) # LLM prompt building, response capturing, etc. answer = Answer( question=final_msg.message, - docs=llm_docs, latest_query_files=latest_query_files, answer_style_config=AnswerStyleConfig( citation_config=CitationConfig( - all_docs_useful=reference_db_search_docs is not None + all_docs_useful=selected_db_search_docs is not None ), document_pruning_config=document_pruning_config, ), - prompt_config=PromptConfig.from_model( - final_msg.prompt, - prompt_override=( - new_msg_req.prompt_override or chat_session.prompt_override - ), - ), + prompt_config=prompt_config, llm=( llm or get_llm_for_persona( persona, new_msg_req.llm_override or chat_session.llm_override ) ), - doc_relevance_list=llm_relevance_list, message_history=[ PreviousMessage.from_chat_message(msg, files) for msg in history_msgs ], + tools=tools, + force_use_tool=_check_should_force_search(new_msg_req), ) - # generator will not include quotes, so we can cast - yield from cast(ChatPacketStream, answer.processed_streamed_output) + + reference_db_search_docs = None + qa_docs_response = None + ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images + for packet in answer.processed_streamed_output: + if isinstance(packet, ToolResponse): + if packet.id == SEARCH_RESPONSE_SUMMARY_ID: + ( + qa_docs_response, + reference_db_search_docs, + ) = _handle_search_tool_response_summary( + packet, db_session, selected_db_search_docs + ) + yield qa_docs_response + elif packet.id == SECTION_RELEVANCE_LIST_ID: + yield LLMRelevanceFilterResponse( + relevant_chunk_indices=packet.response + ) + elif packet.id == IMAGE_GENERATION_RESPONSE_ID: + img_generation_response = cast( + list[ImageGenerationResponse], packet.response + ) + + file_ids = save_files_from_urls( + [img.url for img in img_generation_response] + ) + ai_message_files = [ + FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE) + for file_id in file_ids + ] + yield ImageGenerationDisplay( + file_ids=[str(file_id) for file_id in file_ids] + ) + + else: + yield cast(ChatPacket, packet) except Exception as e: logger.exception(e) @@ -424,6 +492,11 @@ def stream_chat_message_objects( # Saving Gen AI answer and responding with message info gen_ai_response_message = partial_response( message=answer.llm_answer, + rephrased_query=( + qa_docs_response.rephrased_query if qa_docs_response else None + ), + reference_docs=reference_db_search_docs, + files=ai_message_files, token_count=len(llm_tokenizer_encode_func(answer.llm_answer)), citations=db_citations, error=None, diff --git a/backend/danswer/db/auth.py b/backend/danswer/db/auth.py index 1883e8abdda9..6d726c2f92a9 100644 --- a/backend/danswer/db/auth.py +++ b/backend/danswer/db/auth.py @@ -3,8 +3,8 @@ from typing import Any from typing import Dict from fastapi import Depends -from fastapi_users.db import SQLAlchemyUserDatabase from fastapi_users.models import UP +from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase from sqlalchemy import func from sqlalchemy.ext.asyncio import AsyncSession diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 42fa26ffd342..cef17916273a 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -28,6 +28,7 @@ from danswer.db.models import Prompt from danswer.db.models import SearchDoc from danswer.db.models import SearchDoc as DBSearchDoc from danswer.db.models import StarterMessage +from danswer.db.models import Tool from danswer.db.models import User from danswer.db.models import User__UserGroup from danswer.file_store.models import FileDescriptor @@ -508,6 +509,7 @@ def upsert_persona( starter_messages: list[StarterMessage] | None, is_public: bool, db_session: Session, + tool_ids: list[int] | None = None, persona_id: int | None = None, default_persona: bool = False, commit: bool = True, @@ -519,6 +521,13 @@ def upsert_persona( persona_name=name, user=user, db_session=db_session ) + # Fetch and attach tools by IDs + tools = None + if tool_ids is not None: + tools = db_session.query(Tool).filter(Tool.id.in_(tool_ids)).all() + if not tools and tool_ids: + raise ValueError("Tools not found") + if persona: if not default_persona and persona.default_persona: raise ValueError("Cannot update default persona with non-default.") @@ -546,6 +555,9 @@ def upsert_persona( persona.prompts.clear() persona.prompts = prompts + if tools is not None: + persona.tools = tools + else: persona = Persona( id=persona_id, @@ -563,6 +575,7 @@ def upsert_persona( llm_model_provider_override=llm_model_provider_override, llm_model_version_override=llm_model_version_override, starter_messages=starter_messages, + tools=tools or [], ) db_session.add(persona) diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 8eb04d57ecca..1991aee5a95b 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -2,7 +2,6 @@ import datetime import json from enum import Enum as PyEnum from typing import Any -from typing import List from typing import Literal from typing import NotRequired from typing import Optional @@ -102,24 +101,24 @@ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): class User(SQLAlchemyBaseUserTableUUID, Base): - oauth_accounts: Mapped[List[OAuthAccount]] = relationship( + oauth_accounts: Mapped[list[OAuthAccount]] = relationship( "OAuthAccount", lazy="joined" ) role: Mapped[UserRole] = mapped_column( Enum(UserRole, native_enum=False, default=UserRole.BASIC) ) - credentials: Mapped[List["Credential"]] = relationship( + credentials: Mapped[list["Credential"]] = relationship( "Credential", back_populates="user", lazy="joined" ) - chat_sessions: Mapped[List["ChatSession"]] = relationship( + chat_sessions: Mapped[list["ChatSession"]] = relationship( "ChatSession", back_populates="user" ) - chat_folders: Mapped[List["ChatFolder"]] = relationship( + chat_folders: Mapped[list["ChatFolder"]] = relationship( "ChatFolder", back_populates="user" ) - prompts: Mapped[List["Prompt"]] = relationship("Prompt", back_populates="user") + prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user") # Personas owned by this user - personas: Mapped[List["Persona"]] = relationship("Persona", back_populates="user") + personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): @@ -224,6 +223,13 @@ class Document__Tag(Base): tag_id: Mapped[int] = mapped_column(ForeignKey("tag.id"), primary_key=True) +class Persona__Tool(Base): + __tablename__ = "persona__tool" + + persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) + tool_id: Mapped[int] = mapped_column(ForeignKey("tool.id"), primary_key=True) + + """ Documents/Indexing Tables """ @@ -274,7 +280,7 @@ class ConnectorCredentialPair(Base): credential: Mapped["Credential"] = relationship( "Credential", back_populates="connectors" ) - document_sets: Mapped[List["DocumentSet"]] = relationship( + document_sets: Mapped[list["DocumentSet"]] = relationship( "DocumentSet", secondary=DocumentSet__ConnectorCredentialPair.__table__, back_populates="connector_credential_pairs", @@ -315,7 +321,7 @@ class Document(Base): ) # TODO if more sensitive data is added here for display, make sure to add user/group permission - retrieval_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship( + retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship( "DocumentRetrievalFeedback", back_populates="document" ) tags = relationship( @@ -369,15 +375,15 @@ class Connector(Base): ) disabled: Mapped[bool] = mapped_column(Boolean, default=False) - credentials: Mapped[List["ConnectorCredentialPair"]] = relationship( + credentials: Mapped[list["ConnectorCredentialPair"]] = relationship( "ConnectorCredentialPair", back_populates="connector", cascade="all, delete-orphan", ) documents_by_connector: Mapped[ - List["DocumentByConnectorCredentialPair"] + list["DocumentByConnectorCredentialPair"] ] = relationship("DocumentByConnectorCredentialPair", back_populates="connector") - index_attempts: Mapped[List["IndexAttempt"]] = relationship( + index_attempts: Mapped[list["IndexAttempt"]] = relationship( "IndexAttempt", back_populates="connector" ) @@ -397,15 +403,15 @@ class Credential(Base): DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) - connectors: Mapped[List["ConnectorCredentialPair"]] = relationship( + connectors: Mapped[list["ConnectorCredentialPair"]] = relationship( "ConnectorCredentialPair", back_populates="credential", cascade="all, delete-orphan", ) documents_by_credential: Mapped[ - List["DocumentByConnectorCredentialPair"] + list["DocumentByConnectorCredentialPair"] ] = relationship("DocumentByConnectorCredentialPair", back_populates="credential") - index_attempts: Mapped[List["IndexAttempt"]] = relationship( + index_attempts: Mapped[list["IndexAttempt"]] = relationship( "IndexAttempt", back_populates="credential" ) user: Mapped[User | None] = relationship("User", back_populates="credentials") @@ -425,7 +431,7 @@ class EmbeddingModel(Base): ) index_name: Mapped[str] = mapped_column(String) - index_attempts: Mapped[List["IndexAttempt"]] = relationship( + index_attempts: Mapped[list["IndexAttempt"]] = relationship( "IndexAttempt", back_populates="embedding_model" ) @@ -644,7 +650,7 @@ class ChatSession(Base): folder: Mapped["ChatFolder"] = relationship( "ChatFolder", back_populates="chat_sessions" ) - messages: Mapped[List["ChatMessage"]] = relationship( + messages: Mapped[list["ChatMessage"]] = relationship( "ChatMessage", back_populates="chat_session", cascade="delete" ) persona: Mapped["Persona"] = relationship("Persona") @@ -691,10 +697,10 @@ class ChatMessage(Base): chat_session: Mapped[ChatSession] = relationship("ChatSession") prompt: Mapped[Optional["Prompt"]] = relationship("Prompt") - chat_message_feedbacks: Mapped[List["ChatMessageFeedback"]] = relationship( + chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship( "ChatMessageFeedback", back_populates="chat_message" ) - document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship( + document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship( "DocumentRetrievalFeedback", back_populates="chat_message" ) search_docs: Mapped[list["SearchDoc"]] = relationship( @@ -716,7 +722,7 @@ class ChatFolder(Base): display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0) user: Mapped[User] = relationship("User", back_populates="chat_folders") - chat_sessions: Mapped[List["ChatSession"]] = relationship( + chat_sessions: Mapped[list["ChatSession"]] = relationship( "ChatSession", back_populates="folder" ) @@ -865,6 +871,24 @@ class Prompt(Base): ) +class Tool(Base): + __tablename__ = "tool" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + name: Mapped[str] = mapped_column(String, nullable=False) + description: Mapped[str] = mapped_column(Text, nullable=True) + # ID of the tool in the codebase, only applies for in-code tools. + # tools defiend via the UI will have this as None + in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True) + + # Relationship to Persona through the association table + personas: Mapped[list["Persona"]] = relationship( + "Persona", + secondary=Persona__Tool.__table__, + back_populates="tools", + ) + + class StarterMessage(TypedDict): """NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column in Postgres""" @@ -933,6 +957,11 @@ class Persona(Base): secondary=Persona__DocumentSet.__table__, back_populates="personas", ) + tools: Mapped[list[Tool]] = relationship( + "Tool", + secondary=Persona__Tool.__table__, + back_populates="personas", + ) # Owner user: Mapped[User | None] = relationship("User", back_populates="personas") # Other users with access diff --git a/backend/danswer/db/persona.py b/backend/danswer/db/persona.py index a2dece7fb4da..6fa65135dcaa 100644 --- a/backend/danswer/db/persona.py +++ b/backend/danswer/db/persona.py @@ -70,6 +70,7 @@ def create_update_persona( llm_filter_extraction=create_persona_request.llm_filter_extraction, recency_bias=create_persona_request.recency_bias, prompts=prompts, + tool_ids=create_persona_request.tool_ids, document_sets=document_sets, llm_model_provider_override=create_persona_request.llm_model_provider_override, llm_model_version_override=create_persona_request.llm_model_version_override, diff --git a/backend/danswer/file_store/utils.py b/backend/danswer/file_store/utils.py index 28d14db8af69..e487e8f7cb19 100644 --- a/backend/danswer/file_store/utils.py +++ b/backend/danswer/file_store/utils.py @@ -1,8 +1,12 @@ +from io import BytesIO from typing import cast from uuid import UUID +from uuid import uuid4 +import requests from sqlalchemy.orm import Session +from danswer.db.engine import get_session_context_manager from danswer.db.models import ChatMessage from danswer.file_store.file_store import get_default_file_store from danswer.file_store.models import InMemoryChatFile @@ -38,3 +42,25 @@ def load_all_chat_files( ), ) return files + + +def save_file_from_url(url: str) -> UUID: + """NOTE: using multiple sessions here, since this is often called + using multithreading. In practice, sharing a session has resulted in + weird errors.""" + with get_session_context_manager() as db_session: + response = requests.get(url) + response.raise_for_status() + + file_id = uuid4() + file_name = build_chat_file_name(file_id) + + file_io = BytesIO(response.content) + file_store = get_default_file_store(db_session) + file_store.save_file(file_name=file_name, content=file_io) + return file_id + + +def save_files_from_urls(urls: list[str]) -> list[UUID]: + funcs = [(save_file_from_url, (url,)) for url in urls] + return run_functions_tuples_in_parallel(funcs) diff --git a/backend/danswer/llm/answering/answer.py b/backend/danswer/llm/answering/answer.py index d5917bca3ebd..b3a88421e44d 100644 --- a/backend/danswer/llm/answering/answer.py +++ b/backend/danswer/llm/answering/answer.py @@ -1,24 +1,29 @@ from collections.abc import Iterator from typing import cast +from uuid import uuid4 from langchain.schema.messages import BaseMessage +from langchain_core.messages import AIMessageChunk +from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.models import AnswerQuestionPossibleReturn -from danswer.chat.models import AnswerQuestionStreamReturn from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE from danswer.file_store.utils import InMemoryChatFile -from danswer.llm.answering.doc_pruning import prune_documents from danswer.llm.answering.models import AnswerStyleConfig from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig from danswer.llm.answering.models import StreamProcessor -from danswer.llm.answering.prompts.citations_prompt import build_citations_prompt -from danswer.llm.answering.prompts.quotes_prompt import ( - build_quotes_prompt, +from danswer.llm.answering.prompts.build import AnswerPromptBuilder +from danswer.llm.answering.prompts.build import default_build_system_message +from danswer.llm.answering.prompts.build import default_build_user_message +from danswer.llm.answering.prompts.citations_prompt import ( + build_citations_system_message, ) +from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message +from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message from danswer.llm.answering.stream_processing.citation_processing import ( build_citation_processor, ) @@ -27,9 +32,30 @@ from danswer.llm.answering.stream_processing.quotes_processing import ( ) from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import message_generator_to_string_generator +from danswer.tools.force import filter_tools_for_force_tool_use +from danswer.tools.force import ForceUseTool +from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID +from danswer.tools.images.image_generation_tool import ImageGenerationResponse +from danswer.tools.images.image_generation_tool import ImageGenerationTool +from danswer.tools.images.prompt import build_image_generation_user_prompt +from danswer.tools.message import build_tool_message +from danswer.tools.message import ToolCallSummary +from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS +from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID +from danswer.tools.search.search_tool import SearchResponseSummary +from danswer.tools.search.search_tool import SearchTool +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.tools.tool_runner import ( + check_which_tools_should_run_for_non_tool_calling_llm, +) +from danswer.tools.tool_runner import ToolRunKickoff +from danswer.tools.tool_runner import ToolRunner +from danswer.tools.utils import explicit_tool_calling_supported -def _get_stream_processor( +def _get_answer_stream_processor( context_docs: list[LlmDoc], search_order_docs: list[LlmDoc], answer_style_configs: AnswerStyleConfig, @@ -46,21 +72,29 @@ def _get_stream_processor( raise RuntimeError("Not implemented yet") +AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolRunKickoff | ToolResponse] + + class Answer: def __init__( self, question: str, - docs: list[LlmDoc], answer_style_config: AnswerStyleConfig, llm: LLM, prompt_config: PromptConfig, # must be the same length as `docs`. If None, all docs are considered "relevant" - doc_relevance_list: list[bool] | None = None, message_history: list[PreviousMessage] | None = None, single_message_history: str | None = None, # newly passed in files to include as part of this question latest_query_files: list[InMemoryChatFile] | None = None, files: list[InMemoryChatFile] | None = None, + tools: list[Tool] | None = None, + # if specified, tells the LLM to always this tool + # NOTE: for native tool-calling, this is only supported by OpenAI atm, + # but we only support them anyways + force_use_tool: ForceUseTool | None = None, + # if set to True, then never use the LLMs provided tool-calling functonality + skip_explicit_tool_calling: bool = False, ) -> None: if single_message_history and message_history: raise ValueError( @@ -68,12 +102,14 @@ class Answer: ) self.question = question - self.docs = docs self.latest_query_files = latest_query_files or [] self.file_id_to_file = {file.file_id: file for file in (files or [])} - self.doc_relevance_list = doc_relevance_list + self.tools = tools or [] + self.force_use_tool = force_use_tool + self.skip_explicit_tool_calling = skip_explicit_tool_calling + self.message_history = message_history or [] # used for QA flow where we only want to send a single message self.single_message_history = single_message_history @@ -86,83 +122,304 @@ class Answer: self._final_prompt: list[BaseMessage] | None = None - self._pruned_docs: list[LlmDoc] | None = None - self._streamed_output: list[str] | None = None - self._processed_stream: list[AnswerQuestionPossibleReturn] | None = None - - @property - def pruned_docs(self) -> list[LlmDoc]: - if self._pruned_docs is not None: - return self._pruned_docs - - self._pruned_docs = prune_documents( - docs=self.docs, - doc_relevance_list=self.doc_relevance_list, - prompt_config=self.prompt_config, - llm_config=self.llm.config, - question=self.question, - document_pruning_config=self.answer_style_config.document_pruning_config, - ) - return self._pruned_docs - - @property - def final_prompt(self) -> list[BaseMessage]: - if self._final_prompt is not None: - return self._final_prompt + self._processed_stream: list[ + AnswerQuestionPossibleReturn | ToolResponse | ToolRunKickoff + ] | None = None + def _update_prompt_builder_for_search_tool( + self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc] + ) -> None: if self.answer_style_config.citation_config: - self._final_prompt = build_citations_prompt( - question=self.question, - message_history=self.message_history, - llm_config=self.llm.config, - prompt_config=self.prompt_config, - context_docs=self.pruned_docs, - latest_query_files=self.latest_query_files, - all_doc_useful=self.answer_style_config.citation_config.all_docs_useful, - llm_tokenizer_encode_func=self.llm_tokenizer.encode, - history_message=self.single_message_history or "", + prompt_builder.update_system_prompt( + build_citations_system_message(self.prompt_config) + ) + prompt_builder.update_user_prompt( + build_citations_user_message( + question=self.question, + prompt_config=self.prompt_config, + context_docs=final_context_documents, + files=self.latest_query_files, + all_doc_useful=( + self.answer_style_config.citation_config.all_docs_useful + if self.answer_style_config.citation_config + else False + ), + ) ) elif self.answer_style_config.quotes_config: - # NOTE: quotes prompt doesn't currently support files - # this is okay for now, since the search UI (which uses this) - # doesn't support image upload - self._final_prompt = build_quotes_prompt( - question=self.question, - context_docs=self.pruned_docs, - history_str=self.single_message_history or "", - prompt=self.prompt_config, + prompt_builder.update_user_prompt( + build_quotes_user_message( + question=self.question, + context_docs=final_context_documents, + history_str=self.single_message_history or "", + prompt=self.prompt_config, + ) ) - return cast(list[BaseMessage], self._final_prompt) + def _raw_output_for_explicit_tool_calling_llms( + self, + ) -> Iterator[str | ToolRunKickoff | ToolResponse]: + prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) + + tool_call_chunk: AIMessageChunk | None = None + if self.force_use_tool and self.force_use_tool.args is not None: + # if we are forcing a tool WITH args specified, we don't need to check which tools to run + # / need to generate the args + tool_call_chunk = AIMessageChunk( + content="", + ) + tool_call_chunk.tool_calls = [ + { + "name": self.force_use_tool.tool_name, + "args": self.force_use_tool.args, + "id": str(uuid4()), + } + ] + else: + # if tool calling is supported, first try the raw message + # to see if we don't need to use any tools + prompt_builder.update_system_prompt( + default_build_system_message(self.prompt_config) + ) + prompt_builder.update_user_prompt( + default_build_user_message( + self.question, self.prompt_config, self.latest_query_files + ) + ) + prompt = prompt_builder.build() + final_tool_definitions = [ + tool.tool_definition() + for tool in filter_tools_for_force_tool_use( + self.tools, self.force_use_tool + ) + ] + for message in self.llm.stream( + prompt=prompt, + tools=final_tool_definitions if final_tool_definitions else None, + tool_choice="required" if self.force_use_tool else None, + ): + if isinstance(message, AIMessageChunk) and ( + message.tool_call_chunks or message.tool_calls + ): + if tool_call_chunk is None: + tool_call_chunk = message + else: + tool_call_chunk += message # type: ignore + else: + if message.content: + yield cast(str, message.content) + + if not tool_call_chunk: + return # no tool call needed + + # if we have a tool call, we need to call the tool + tool_call_requests = tool_call_chunk.tool_calls + for tool_call_request in tool_call_requests: + tool = [ + tool for tool in self.tools if tool.name() == tool_call_request["name"] + ][0] + tool_args = ( + self.force_use_tool.args + if self.force_use_tool and self.force_use_tool.args + else tool_call_request["args"] + ) + + tool_runner = ToolRunner(tool, tool_args) + yield tool_runner.kickoff() + yield from tool_runner.tool_responses() + + tool_call_summary = ToolCallSummary( + tool_call_request=tool_call_chunk, + tool_call_result=build_tool_message( + tool_call_request, tool_runner.tool_message_content() + ), + ) + + if tool.name() == SearchTool.name(): + self._update_prompt_builder_for_search_tool(prompt_builder, []) + elif tool.name() == ImageGenerationTool.name(): + prompt_builder.update_user_prompt( + build_image_generation_user_prompt( + query=self.question, + ) + ) + prompt = prompt_builder.build(tool_call_summary=tool_call_summary) + + yield from message_generator_to_string_generator( + self.llm.stream( + prompt=prompt, + tools=[tool.tool_definition() for tool in self.tools], + ) + ) - @property - def raw_streamed_output(self) -> Iterator[str]: - if self._streamed_output is not None: - yield from self._streamed_output return - streamed_output = [] - for message in self.llm.stream(self.final_prompt): - streamed_output.append(message) - yield message + def _raw_output_for_non_explicit_tool_calling_llms( + self, + ) -> Iterator[str | ToolRunKickoff | ToolResponse]: + prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config) + chosen_tool_and_args: tuple[Tool, dict] | None = None - self._streamed_output = streamed_output + if self.force_use_tool: + # if we are forcing a tool, we don't need to check which tools to run + tool = next( + iter( + [ + tool + for tool in self.tools + if tool.name() == self.force_use_tool.tool_name + ] + ), + None, + ) + if not tool: + raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found") + + tool_args = ( + self.force_use_tool.args + if self.force_use_tool.args + else tool.get_args_for_non_tool_calling_llm( + query=self.question, + history=self.message_history, + llm=self.llm, + force_run=True, + ) + ) + + if tool_args is None: + raise RuntimeError(f"Tool '{tool.name()}' did not return args") + + chosen_tool_and_args = (tool, tool_args) + else: + all_tool_args = check_which_tools_should_run_for_non_tool_calling_llm( + tools=self.tools, + query=self.question, + history=self.message_history, + llm=self.llm, + ) + for ind, args in enumerate(all_tool_args): + if args is not None: + chosen_tool_and_args = (self.tools[ind], args) + # for now, just pick the first tool selected + break + + if not chosen_tool_and_args: + prompt_builder.update_system_prompt( + default_build_system_message(self.prompt_config) + ) + prompt_builder.update_user_prompt( + default_build_user_message( + self.question, self.prompt_config, self.latest_query_files + ) + ) + prompt = prompt_builder.build() + yield from message_generator_to_string_generator( + self.llm.stream(prompt=prompt) + ) + return + + tool, tool_args = chosen_tool_and_args + tool_runner = ToolRunner(tool, tool_args) + yield tool_runner.kickoff() + + if tool.name() == SearchTool.name(): + final_context_documents = None + for response in tool_runner.tool_responses(): + if response.id == FINAL_CONTEXT_DOCUMENTS: + final_context_documents = cast(list[LlmDoc], response.response) + yield response + + if final_context_documents is None: + raise RuntimeError("SearchTool did not return final context documents") + + self._update_prompt_builder_for_search_tool( + prompt_builder, final_context_documents + ) + elif tool.name() == ImageGenerationTool.name(): + img_urls = [] + for response in tool_runner.tool_responses(): + if response.id == IMAGE_GENERATION_RESPONSE_ID: + img_generation_response = cast( + list[ImageGenerationResponse], response.response + ) + img_urls = [img.url for img in img_generation_response] + break + yield response + + prompt_builder.update_user_prompt( + build_image_generation_user_prompt( + query=self.question, + img_urls=img_urls, + ) + ) + + prompt = prompt_builder.build() + yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt)) @property - def processed_streamed_output(self) -> AnswerQuestionStreamReturn: + def processed_streamed_output(self) -> AnswerStream: if self._processed_stream is not None: yield from self._processed_stream return - process_stream_fn = _get_stream_processor( - context_docs=self.pruned_docs, - search_order_docs=self.docs, - answer_style_configs=self.answer_style_config, + output_generator = ( + self._raw_output_for_explicit_tool_calling_llms() + if explicit_tool_calling_supported( + self.llm.config.model_provider, self.llm.config.model_name + ) + and not self.skip_explicit_tool_calling + else self._raw_output_for_non_explicit_tool_calling_llms() ) + def _process_stream( + stream: Iterator[ToolRunKickoff | ToolResponse | str], + ) -> AnswerStream: + message = None + + # special things we need to keep track of for the SearchTool + search_results: list[ + LlmDoc + ] | None = None # raw results that will be displayed to the user + final_context_docs: list[ + LlmDoc + ] | None = None # processed docs to feed into the LLM + + for message in stream: + if isinstance(message, ToolRunKickoff): + yield message + elif isinstance(message, ToolResponse): + if message.id == SEARCH_RESPONSE_SUMMARY_ID: + search_results = [ + llm_doc_from_inference_section(section) + for section in cast( + SearchResponseSummary, message.response + ).top_sections + ] + elif message.id == FINAL_CONTEXT_DOCUMENTS: + final_context_docs = cast(list[LlmDoc], message.response) + yield message + else: + # assumes all tool responses will come first, then the final answer + break + + process_answer_stream_fn = _get_answer_stream_processor( + context_docs=final_context_docs or [], + # if doc selection is enabled, then search_results will be None, + # so we need to use the final_context_docs + search_order_docs=search_results or final_context_docs or [], + answer_style_configs=self.answer_style_config, + ) + + def _stream() -> Iterator[str]: + if message: + yield cast(str, message) + yield from cast(Iterator[str], stream) + + yield from process_answer_stream_fn(_stream()) + processed_stream = [] - for processed_packet in process_stream_fn(self.raw_streamed_output): + for processed_packet in _process_stream(output_generator): processed_stream.append(processed_packet) yield processed_packet diff --git a/backend/danswer/llm/answering/doc_pruning.py b/backend/danswer/llm/answering/doc_pruning.py index fa243895a0c4..5a43ab3c6fd2 100644 --- a/backend/danswer/llm/answering/doc_pruning.py +++ b/backend/danswer/llm/answering/doc_pruning.py @@ -1,3 +1,4 @@ +import json from copy import deepcopy from typing import TypeVar @@ -14,6 +15,7 @@ from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import tokenizer_trim_content from danswer.prompts.prompt_utils import build_doc_context_str from danswer.search.models import InferenceChunk +from danswer.tools.search.search_utils import llm_doc_to_dict from danswer.utils.logger import setup_logger @@ -35,9 +37,13 @@ def _compute_limit( max_chunks: int | None, max_window_percentage: float | None, max_tokens: int | None, + tool_token_count: int, ) -> int: llm_max_document_tokens = compute_max_document_tokens( - prompt_config=prompt_config, llm_config=llm_config, actual_user_input=question + prompt_config=prompt_config, + llm_config=llm_config, + tool_token_count=tool_token_count, + actual_user_input=question, ) window_percentage_based_limit = ( @@ -88,6 +94,7 @@ def _apply_pruning( token_limit: int, is_manually_selected_docs: bool, use_sections: bool, + using_tool_message: bool, ) -> list[LlmDoc]: llm_tokenizer = get_default_llm_tokenizer() docs = deepcopy(docs) # don't modify in place @@ -101,18 +108,20 @@ def _apply_pruning( final_doc_ind = None total_tokens = 0 for ind, llm_doc in enumerate(docs): - doc_tokens = len( - llm_tokenizer.encode( - build_doc_context_str( - semantic_identifier=llm_doc.semantic_identifier, - source_type=llm_doc.source_type, - content=llm_doc.content, - metadata_dict=llm_doc.metadata, - updated_at=llm_doc.updated_at, - ind=ind, - ) + doc_str = ( + json.dumps(llm_doc_to_dict(llm_doc, ind)) + if using_tool_message + else build_doc_context_str( + semantic_identifier=llm_doc.semantic_identifier, + source_type=llm_doc.source_type, + content=llm_doc.content, + metadata_dict=llm_doc.metadata, + updated_at=llm_doc.updated_at, + ind=ind, ) ) + + doc_tokens = len(llm_tokenizer.encode(doc_str)) # if chunks, truncate chunks that are way too long # this can happen if the embedding model tokenizer is different # than the LLM tokenizer @@ -152,12 +161,12 @@ def _apply_pruning( "LLM context window exceeded. Please de-select some documents or shorten your query." ) - final_doc_desired_length = tokens_per_doc[final_doc_ind] - ( - total_tokens - token_limit - ) - final_doc_content_length = ( - final_doc_desired_length - _METADATA_TOKEN_ESTIMATE - ) + amount_to_truncate = total_tokens - token_limit + # NOTE: need to recalculate the length here, since the previous calculation included + # overhead from JSON-fying the doc / the metadata + final_doc_content_length = len( + llm_tokenizer.encode(docs[final_doc_ind].content) + ) - (amount_to_truncate) # this could occur if we only have space for the title / metadata # not ideal, but it's the most reasonable thing to do # NOTE: the frontend prevents documents from being selected if @@ -209,6 +218,7 @@ def prune_documents( max_chunks=document_pruning_config.max_chunks, max_window_percentage=document_pruning_config.max_window_percentage, max_tokens=document_pruning_config.max_tokens, + tool_token_count=document_pruning_config.tool_num_tokens, ) return _apply_pruning( docs=docs, @@ -216,4 +226,5 @@ def prune_documents( token_limit=doc_token_limit, is_manually_selected_docs=document_pruning_config.is_manually_selected_docs, use_sections=document_pruning_config.use_sections, + using_tool_message=document_pruning_config.using_tool_message, ) diff --git a/backend/danswer/llm/answering/models.py b/backend/danswer/llm/answering/models.py index cdeeb1eb69e7..a5248fac27a6 100644 --- a/backend/danswer/llm/answering/models.py +++ b/backend/danswer/llm/answering/models.py @@ -3,6 +3,10 @@ from collections.abc import Iterator from typing import Any from typing import TYPE_CHECKING +from langchain.schema.messages import AIMessage +from langchain.schema.messages import BaseMessage +from langchain.schema.messages import HumanMessage +from langchain.schema.messages import SystemMessage from pydantic import BaseModel from pydantic import Field from pydantic import root_validator @@ -11,6 +15,7 @@ from danswer.chat.models import AnswerQuestionStreamReturn from danswer.configs.constants import MessageType from danswer.file_store.models import InMemoryChatFile from danswer.llm.override_models import PromptOverride +from danswer.llm.utils import build_content_with_imgs if TYPE_CHECKING: from danswer.db.models import ChatMessage @@ -46,6 +51,15 @@ class PreviousMessage(BaseModel): ], ) + def to_langchain_msg(self) -> BaseMessage: + content = build_content_with_imgs(self.message, self.files) + if self.message_type == MessageType.USER: + return HumanMessage(content=content) + elif self.message_type == MessageType.ASSISTANT: + return AIMessage(content=content) + else: + return SystemMessage(content=content) + class DocumentPruningConfig(BaseModel): max_chunks: int | None = None @@ -59,6 +73,11 @@ class DocumentPruningConfig(BaseModel): # If user specifies to include additional context chunks for each match, then different pruning # is used. As many Sections as possible are included, and the last Section is truncated use_sections: bool = False + # If using tools, then we need to consider the tool length + tool_num_tokens: int = 0 + # If using a tool message to represent the docs, then we have to JSON serialize + # the document content, which adds to the token count. + using_tool_message: bool = False class CitationConfig(BaseModel): diff --git a/backend/danswer/llm/answering/prompts/build.py b/backend/danswer/llm/answering/prompts/build.py new file mode 100644 index 000000000000..2c459731b42e --- /dev/null +++ b/backend/danswer/llm/answering/prompts/build.py @@ -0,0 +1,125 @@ +from collections.abc import Callable +from typing import cast + +from langchain_core.messages import BaseMessage +from langchain_core.messages import HumanMessage +from langchain_core.messages import SystemMessage + +from danswer.file_store.models import InMemoryChatFile +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig +from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens +from danswer.llm.interfaces import LLMConfig +from danswer.llm.utils import build_content_with_imgs +from danswer.llm.utils import check_message_tokens +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.llm.utils import translate_history_to_basemessages +from danswer.prompts.chat_prompts import ADDITIONAL_INFO +from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT +from danswer.prompts.prompt_utils import drop_messages_history_overflow +from danswer.prompts.prompt_utils import get_current_llm_day_time +from danswer.tools.message import ToolCallSummary + + +def default_build_system_message( + prompt_config: PromptConfig, +) -> SystemMessage | None: + system_prompt = prompt_config.system_prompt.strip() + if prompt_config.datetime_aware: + if system_prompt: + system_prompt += ADDITIONAL_INFO.format( + datetime_info=get_current_llm_day_time() + ) + else: + system_prompt = get_current_llm_day_time() + + if not system_prompt: + return None + + system_msg = SystemMessage(content=system_prompt) + + return system_msg + + +def default_build_user_message( + user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = [] +) -> HumanMessage: + user_prompt = ( + CHAT_USER_CONTEXT_FREE_PROMPT.format( + task_prompt=prompt_config.task_prompt, user_query=user_query + ) + if prompt_config.task_prompt + else user_query + ) + user_prompt = user_prompt.strip() + user_msg = HumanMessage( + content=build_content_with_imgs(user_prompt, files) if files else user_prompt + ) + return user_msg + + +class AnswerPromptBuilder: + def __init__( + self, message_history: list[PreviousMessage], llm_config: LLMConfig + ) -> None: + self.max_tokens = compute_max_llm_input_tokens(llm_config) + + ( + self.message_history, + self.history_token_cnts, + ) = translate_history_to_basemessages(message_history) + + self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = None + self.user_message_and_token_cnt: tuple[HumanMessage, int] | None = None + + llm_tokenizer = get_default_llm_tokenizer() + self.llm_tokenizer_encode_func = cast( + Callable[[str], list[int]], llm_tokenizer.encode + ) + + def update_system_prompt(self, system_message: SystemMessage | None) -> None: + if not system_message: + self.system_message_and_token_cnt = None + return + + self.system_message_and_token_cnt = ( + system_message, + check_message_tokens(system_message, self.llm_tokenizer_encode_func), + ) + + def update_user_prompt(self, user_message: HumanMessage) -> None: + if not user_message: + self.user_message_and_token_cnt = None + return + + self.user_message_and_token_cnt = ( + user_message, + check_message_tokens(user_message, self.llm_tokenizer_encode_func), + ) + + def build( + self, tool_call_summary: ToolCallSummary | None = None + ) -> list[BaseMessage]: + if not self.user_message_and_token_cnt: + raise ValueError("User message must be set before building prompt") + + final_messages_with_tokens: list[tuple[BaseMessage, int]] = [] + if self.system_message_and_token_cnt: + final_messages_with_tokens.append(self.system_message_and_token_cnt) + + final_messages_with_tokens.extend( + [ + (self.message_history[i], self.history_token_cnts[i]) + for i in range(len(self.message_history)) + ] + ) + + final_messages_with_tokens.append(self.user_message_and_token_cnt) + + if tool_call_summary: + final_messages_with_tokens.append((tool_call_summary.tool_call_request, 0)) + final_messages_with_tokens.append((tool_call_summary.tool_call_result, 0)) + + return drop_messages_history_overflow( + final_messages_with_tokens, self.max_tokens + ) diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py index 716ace320645..10bf81aebab2 100644 --- a/backend/danswer/llm/answering/prompts/citations_prompt.py +++ b/backend/danswer/llm/answering/prompts/citations_prompt.py @@ -1,8 +1,5 @@ -from collections.abc import Callable from functools import lru_cache -from typing import cast -from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage @@ -12,23 +9,17 @@ from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MA from danswer.db.chat import get_default_prompt from danswer.db.models import Persona from danswer.file_store.utils import InMemoryChatFile -from danswer.llm.answering.models import PreviousMessage from danswer.llm.answering.models import PromptConfig from danswer.llm.factory import get_llm_for_persona from danswer.llm.interfaces import LLMConfig from danswer.llm.utils import build_content_with_imgs from danswer.llm.utils import check_number_of_tokens -from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import get_max_input_tokens -from danswer.llm.utils import translate_history_to_basemessages from danswer.prompts.chat_prompts import ADDITIONAL_INFO -from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT -from danswer.prompts.chat_prompts import NO_CITATION_STATEMENT from danswer.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT from danswer.prompts.constants import DEFAULT_IGNORE_STATEMENT -from danswer.prompts.direct_qa_prompts import ( - CITATIONS_PROMPT, -) +from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT +from danswer.prompts.direct_qa_prompts import CITATIONS_PROMPT_FOR_TOOL_CALLING from danswer.prompts.prompt_utils import build_complete_context_str from danswer.prompts.prompt_utils import build_task_prompt_reminders from danswer.prompts.prompt_utils import get_current_llm_day_time @@ -42,59 +33,6 @@ from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT from danswer.search.models import InferenceChunk -_PER_MESSAGE_TOKEN_BUFFER = 7 - - -def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: - """From the back, find the index of the last element to include - before the list exceeds the maximum""" - running_sum = 0 - - last_ind = 0 - for i in range(len(lst) - 1, -1, -1): - running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER - if running_sum > max_prompt_tokens: - last_ind = i + 1 - break - if last_ind >= len(lst): - raise ValueError("Last message alone is too large!") - return last_ind - - -def _drop_messages_history_overflow( - system_msg: BaseMessage | None, - system_token_count: int, - history_msgs: list[BaseMessage], - history_token_counts: list[int], - final_msg: BaseMessage, - final_msg_token_count: int, - max_allowed_tokens: int, -) -> list[BaseMessage]: - """As message history grows, messages need to be dropped starting from the furthest in the past. - The System message should be kept if at all possible and the latest user input which is inserted in the - prompt template must be included""" - if len(history_msgs) != len(history_token_counts): - # This should never happen - raise ValueError("Need exactly 1 token count per message for tracking overflow") - - prompt: list[BaseMessage] = [] - - # Start dropping from the history if necessary - all_tokens = history_token_counts + [system_token_count, final_msg_token_count] - ind_prev_msg_start = find_last_index( - all_tokens, max_prompt_tokens=max_allowed_tokens - ) - - if system_msg and ind_prev_msg_start <= len(history_msgs): - prompt.append(system_msg) - - prompt.extend(history_msgs[ind_prev_msg_start:]) - - prompt.append(final_msg) - - return prompt - - def get_prompt_tokens(prompt_config: PromptConfig) -> int: # Note: currently custom prompts do not allow datetime aware, only default prompts return ( @@ -117,6 +55,7 @@ def compute_max_document_tokens( prompt_config: PromptConfig, llm_config: LLMConfig, actual_user_input: str | None = None, + tool_token_count: int = 0, max_llm_token_override: int | None = None, ) -> int: """Estimates the number of tokens available for context documents. Formula is roughly: @@ -146,7 +85,13 @@ def compute_max_document_tokens( else GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS ) - return max_input_tokens - prompt_tokens - user_input_tokens - _MISC_BUFFER + return ( + max_input_tokens + - prompt_tokens + - user_input_tokens + - tool_token_count + - _MISC_BUFFER + ) def compute_max_document_tokens_for_persona( @@ -173,19 +118,11 @@ def compute_max_llm_input_tokens(llm_config: LLMConfig) -> int: @lru_cache() -def _build_system_message( +def build_citations_system_message( prompt_config: PromptConfig, - context_exists: bool, - llm_tokenizer_encode_func: Callable, - citation_line: str = REQUIRE_CITATION_STATEMENT, - no_citation_line: str = NO_CITATION_STATEMENT, -) -> tuple[SystemMessage | None, int]: +) -> SystemMessage: system_prompt = prompt_config.system_prompt.strip() - if prompt_config.include_citations: - if context_exists: - system_prompt += citation_line - else: - system_prompt += no_citation_line + system_prompt += REQUIRE_CITATION_STATEMENT if prompt_config.datetime_aware: if system_prompt: system_prompt += ADDITIONAL_INFO.format( @@ -194,108 +131,40 @@ def _build_system_message( else: system_prompt = get_current_llm_day_time() - if not system_prompt: - return None, 0 - - token_count = len(llm_tokenizer_encode_func(system_prompt)) - system_msg = SystemMessage(content=system_prompt) - - return system_msg, token_count + return SystemMessage(content=system_prompt) -def _build_user_message( +def build_citations_user_message( question: str, prompt_config: PromptConfig, context_docs: list[LlmDoc] | list[InferenceChunk], files: list[InMemoryChatFile], all_doc_useful: bool, - history_message: str, -) -> tuple[HumanMessage, int]: - llm_tokenizer = get_default_llm_tokenizer() - llm_tokenizer_encode_func = cast(Callable[[str], list[int]], llm_tokenizer.encode) - - if not context_docs: - # Simpler prompt for cases where there is no context - user_prompt = ( - CHAT_USER_CONTEXT_FREE_PROMPT.format( - task_prompt=prompt_config.task_prompt, user_query=question - ) - if prompt_config.task_prompt - else question - ) - user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) - user_msg = HumanMessage( - content=build_content_with_imgs(user_prompt, files) - if files - else user_prompt - ) - return user_msg, token_count - - context_docs_str = build_complete_context_str(context_docs) - optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT - + history_message: str = "", +) -> HumanMessage: task_prompt_with_reminder = build_task_prompt_reminders(prompt_config) - user_prompt = CITATIONS_PROMPT.format( - optional_ignore_statement=optional_ignore, - context_docs_str=context_docs_str, - task_prompt=task_prompt_with_reminder, - user_query=question, - history_block=history_message, - ) + if context_docs: + context_docs_str = build_complete_context_str(context_docs) + optional_ignore = "" if all_doc_useful else DEFAULT_IGNORE_STATEMENT + + user_prompt = CITATIONS_PROMPT.format( + optional_ignore_statement=optional_ignore, + context_docs_str=context_docs_str, + task_prompt=task_prompt_with_reminder, + user_query=question, + history_block=history_message, + ) + else: + # if no context docs provided, assume we're in the tool calling flow + user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format( + task_prompt=task_prompt_with_reminder, + user_query=question, + ) user_prompt = user_prompt.strip() - token_count = len(llm_tokenizer_encode_func(user_prompt)) user_msg = HumanMessage( content=build_content_with_imgs(user_prompt, files) if files else user_prompt ) - return user_msg, token_count - - -def build_citations_prompt( - question: str, - message_history: list[PreviousMessage], - prompt_config: PromptConfig, - llm_config: LLMConfig, - context_docs: list[LlmDoc] | list[InferenceChunk], - latest_query_files: list[InMemoryChatFile], - all_doc_useful: bool, - history_message: str, - llm_tokenizer_encode_func: Callable, -) -> list[BaseMessage]: - context_exists = len(context_docs) > 0 - - system_message_or_none, system_tokens = _build_system_message( - prompt_config=prompt_config, - context_exists=context_exists, - llm_tokenizer_encode_func=llm_tokenizer_encode_func, - ) - - history_basemessages, history_token_counts = translate_history_to_basemessages( - message_history - ) - - # Be sure the context_docs passed to build_chat_user_message - # Is the same as passed in later for extracting citations - user_message, user_tokens = _build_user_message( - question=question, - prompt_config=prompt_config, - context_docs=context_docs, - files=latest_query_files, - all_doc_useful=all_doc_useful, - history_message=history_message, - ) - - final_prompt_msgs = _drop_messages_history_overflow( - system_msg=system_message_or_none, - system_token_count=system_tokens, - history_msgs=history_basemessages, - history_token_counts=history_token_counts, - final_msg=user_message, - final_msg_token_count=user_tokens, - max_allowed_tokens=compute_max_llm_input_tokens(llm_config), - ) - - return final_prompt_msgs + return user_msg diff --git a/backend/danswer/llm/answering/prompts/quotes_prompt.py b/backend/danswer/llm/answering/prompts/quotes_prompt.py index 841f3b5d5687..c0a36b10ec98 100644 --- a/backend/danswer/llm/answering/prompts/quotes_prompt.py +++ b/backend/danswer/llm/answering/prompts/quotes_prompt.py @@ -1,4 +1,3 @@ -from langchain.schema.messages import BaseMessage from langchain.schema.messages import HumanMessage from danswer.chat.models import LlmDoc @@ -20,7 +19,7 @@ def _build_weak_llm_quotes_prompt( history_str: str, prompt: PromptConfig, use_language_hint: bool, -) -> list[BaseMessage]: +) -> HumanMessage: """Since Danswer supports a variety of LLMs, this less demanding prompt is provided as an option to use with weaker LLMs such as small version, low float precision, quantized, or distilled models. It only uses one context document and has very weak requirements of @@ -36,7 +35,7 @@ def _build_weak_llm_quotes_prompt( task_prompt=prompt.task_prompt, user_query=question, ) - return [HumanMessage(content=prompt_str)] + return HumanMessage(content=prompt_str) def _build_strong_llm_quotes_prompt( @@ -45,7 +44,7 @@ def _build_strong_llm_quotes_prompt( history_str: str, prompt: PromptConfig, use_language_hint: bool, -) -> list[BaseMessage]: +) -> HumanMessage: context_block = "" if context_docs: context_docs_str = build_complete_context_str(context_docs) @@ -63,16 +62,38 @@ def _build_strong_llm_quotes_prompt( user_query=question, language_hint_or_none=LANGUAGE_HINT.strip() if use_language_hint else "", ).strip() - return [HumanMessage(content=full_prompt)] + return HumanMessage(content=full_prompt) -def build_quotes_prompt( +def build_quotes_user_message( question: str, context_docs: list[LlmDoc] | list[InferenceChunk], history_str: str, prompt: PromptConfig, use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), -) -> list[BaseMessage]: +) -> HumanMessage: + prompt_builder = ( + _build_weak_llm_quotes_prompt + if QA_PROMPT_OVERRIDE == "weak" + else _build_strong_llm_quotes_prompt + ) + + return prompt_builder( + question=question, + context_docs=context_docs, + history_str=history_str, + prompt=prompt, + use_language_hint=use_language_hint, + ) + + +def build_quotes_prompt( + question: str, + context_docs: list[LlmDoc] | list[InferenceChunk], + history_str: str, + prompt: PromptConfig, + use_language_hint: bool = bool(MULTILINGUAL_QUERY_EXPANSION), +) -> HumanMessage: prompt_builder = ( _build_weak_llm_quotes_prompt if QA_PROMPT_OVERRIDE == "weak" diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index 42d4214f761d..326650bcc442 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -1,12 +1,25 @@ -import abc +import json import os from collections.abc import Iterator from typing import Any +from typing import cast -import litellm # type:ignore -from langchain.chat_models.base import BaseChatModel +import litellm # type: ignore from langchain.schema.language_model import LanguageModelInput -from langchain_community.chat_models import ChatLiteLLM +from langchain_core.messages import AIMessage +from langchain_core.messages import AIMessageChunk +from langchain_core.messages import BaseMessage +from langchain_core.messages import BaseMessageChunk +from langchain_core.messages import ChatMessage +from langchain_core.messages import ChatMessageChunk +from langchain_core.messages import FunctionMessage +from langchain_core.messages import FunctionMessageChunk +from langchain_core.messages import HumanMessage +from langchain_core.messages import HumanMessageChunk +from langchain_core.messages import SystemMessage +from langchain_core.messages import SystemMessageChunk +from langchain_core.messages.tool import ToolCallChunk +from langchain_core.messages.tool import ToolMessage from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING @@ -17,8 +30,7 @@ from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.llm.interfaces import LLM from danswer.llm.interfaces import LLMConfig -from danswer.llm.utils import message_generator_to_string_generator -from danswer.llm.utils import should_be_verbose +from danswer.llm.interfaces import ToolChoiceOptions from danswer.utils.logger import setup_logger @@ -30,63 +42,144 @@ litellm.drop_params = True litellm.telemetry = False -class LangChainChatLLM(LLM, abc.ABC): - @property - @abc.abstractmethod - def llm(self) -> BaseChatModel: - raise NotImplementedError +def _base_msg_to_role(msg: BaseMessage) -> str: + if isinstance(msg, HumanMessage) or isinstance(msg, HumanMessageChunk): + return "user" + if isinstance(msg, AIMessage) or isinstance(msg, AIMessageChunk): + return "assistant" + if isinstance(msg, SystemMessage) or isinstance(msg, SystemMessageChunk): + return "system" + if isinstance(msg, FunctionMessage) or isinstance(msg, FunctionMessageChunk): + return "function" + return "unknown" - @staticmethod - def _log_prompt(prompt: LanguageModelInput) -> None: - if isinstance(prompt, list): - for ind, msg in enumerate(prompt): - logger.debug(f"Message {ind}:\n{msg.content}") - if isinstance(prompt, str): - logger.debug(f"Prompt:\n{prompt}") - def log_model_configs(self) -> None: - llm_dict = {k: v for k, v in self.llm.__dict__.items() if v} - llm_dict.pop("client") - logger.info( - f"LLM Model Class: {self.llm.__class__.__name__}, Model Config: {llm_dict}" +def _convert_litellm_message_to_langchain_message( + litellm_message: litellm.Message, +) -> BaseMessage: + # Extracting the basic attributes from the litellm message + content = litellm_message.content + role = litellm_message.role + + # Handling function calls and tool calls if present + tool_calls = ( + cast( + list[litellm.utils.ChatCompletionMessageToolCall], + litellm_message.tool_calls, ) + if hasattr(litellm_message, "tool_calls") + else [] + ) - def invoke(self, prompt: LanguageModelInput) -> str: - if LOG_ALL_MODEL_INTERACTIONS: - self._log_prompt(prompt) + # Create the appropriate langchain message based on the role + if role == "user": + return HumanMessage(content=content) + elif role == "assistant": + return AIMessage( + content=content, + tool_calls=[ + { + "name": tool_call.function.name or "", + "args": json.loads(tool_call.function.arguments), + "id": tool_call.id, + } + for tool_call in tool_calls + ], + ) + elif role == "system": + return SystemMessage(content=content) + else: + raise ValueError(f"Unknown role type received: {role}") - model_raw = self.llm.invoke(prompt).content - if LOG_ALL_MODEL_INTERACTIONS: - logger.debug(f"Raw Model Output:\n{model_raw}") - if not isinstance(model_raw, str): - raise RuntimeError( - "Model output inconsistent with expected type, " - "is this related to a library upgrade?" +def _convert_message_to_dict(message: BaseMessage) -> dict: + """Adapted from langchain_community.chat_models.litellm._convert_message_to_dict""" + if isinstance(message, ChatMessage): + message_dict = {"role": message.role, "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + if message.tool_calls: + message_dict["tool_calls"] = [ + { + "id": tool_call.get("id"), + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["args"]), + }, + "type": "function", + "index": 0, # only support a single tool call atm + } + for tool_call in message.tool_calls + ] + if "function_call" in message.additional_kwargs: + message_dict["function_call"] = message.additional_kwargs["function_call"] + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, FunctionMessage): + message_dict = { + "role": "function", + "content": message.content, + "name": message.name, + } + elif isinstance(message, ToolMessage): + message_dict = { + "tool_call_id": message.tool_call_id, + "role": "tool", + "name": message.name or "", + "content": message.content, + } + else: + raise ValueError(f"Got unknown type {message}") + if "name" in message.additional_kwargs: + message_dict["name"] = message.additional_kwargs["name"] + return message_dict + + +def _convert_delta_to_message_chunk( + _dict: dict[str, Any], curr_msg: BaseMessage | None +) -> BaseMessageChunk: + """Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk""" + role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None) + content = _dict.get("content") or "" + additional_kwargs = {} + if _dict.get("function_call"): + additional_kwargs.update({"function_call": dict(_dict["function_call"])}) + tool_calls = cast( + list[litellm.utils.ChatCompletionDeltaToolCall] | None, _dict.get("tool_calls") + ) + + if role == "user": + return HumanMessageChunk(content=content) + elif role == "assistant": + if tool_calls: + tool_call = tool_calls[0] + tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or "" + + tool_call_chunk = ToolCallChunk( + name=tool_name, + id=tool_call.id, + args=tool_call.function.arguments, + index=0, # only support a single tool call atm ) + return AIMessageChunk( + content=content, + additional_kwargs=additional_kwargs, + tool_call_chunks=[tool_call_chunk], + ) + return AIMessageChunk(content=content, additional_kwargs=additional_kwargs) + elif role == "system": + return SystemMessageChunk(content=content) + elif role == "function": + return FunctionMessageChunk(content=content, name=_dict["name"]) + elif role: + return ChatMessageChunk(content=content, role=role) - return model_raw - - def stream(self, prompt: LanguageModelInput) -> Iterator[str]: - if LOG_ALL_MODEL_INTERACTIONS: - self.log_model_configs() - self._log_prompt(prompt) - - if DISABLE_LITELLM_STREAMING: - yield self.invoke(prompt) - return - - output_tokens = [] - for token in message_generator_to_string_generator(self.llm.stream(prompt)): - output_tokens.append(token) - yield token - - full_output = "".join(output_tokens) - if LOG_ALL_MODEL_INTERACTIONS: - logger.debug(f"Raw Model Output:\n{full_output}") + raise ValueError(f"Unknown role: {role}") -class DefaultMultiLLM(LangChainChatLLM): +class DefaultMultiLLM(LLM): """Uses Litellm library to allow easy configuration to use a multitude of LLMs See https://python.langchain.com/docs/integrations/chat/litellm""" @@ -109,14 +202,16 @@ class DefaultMultiLLM(LangChainChatLLM): custom_config: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None, ): + self._timeout = timeout self._model_provider = model_provider self._model_version = model_name self._temperature = temperature - - # Litellm Langchain integration currently doesn't take in the api key param - # Can place this in the call below once integration is in - litellm.api_key = api_key or "dummy-key" - litellm.api_version = api_version + self._api_key = api_key + self._api_base = api_base + self._api_version = api_version + self._custom_llm_provider = custom_llm_provider + self._max_output_tokens = max_output_tokens + self._custom_config = custom_config # NOTE: have to set these as environment variables for Litellm since # not all are able to passed in but they always support them set as env @@ -128,25 +223,60 @@ class DefaultMultiLLM(LangChainChatLLM): model_kwargs = ( DefaultMultiLLM.DEFAULT_MODEL_PARAMS if model_provider == "openai" else {} ) - if extra_headers: model_kwargs.update({"extra_headers": extra_headers}) - self._llm = ChatLiteLLM( # type: ignore - model=( - model_name if custom_llm_provider else f"{model_provider}/{model_name}" - ), - api_base=api_base, - custom_llm_provider=custom_llm_provider, - max_tokens=max_output_tokens, - temperature=temperature, - request_timeout=timeout, - # LiteLLM and some model providers don't handle these params well - # only turning it on for OpenAI - model_kwargs=model_kwargs, - verbose=should_be_verbose(), - max_retries=0, # retries are handled outside of langchain - ) + self._model_kwargs = model_kwargs + + @staticmethod + def _log_prompt(prompt: LanguageModelInput) -> None: + if isinstance(prompt, list): + for ind, msg in enumerate(prompt): + logger.debug(f"Message {ind}:\n{msg.content}") + if isinstance(prompt, str): + logger.debug(f"Prompt:\n{prompt}") + + def log_model_configs(self) -> None: + logger.info(f"Config: {self.config}") + + def _completion( + self, + prompt: LanguageModelInput, + tools: list[dict] | None, + tool_choice: ToolChoiceOptions | None, + stream: bool, + ) -> litellm.ModelResponse | litellm.CustomStreamWrapper: + if isinstance(prompt, list): + prompt = [ + _convert_message_to_dict(msg) if isinstance(msg, BaseMessage) else msg + for msg in prompt + ] + elif isinstance(prompt, str): + prompt = [_convert_message_to_dict(HumanMessage(content=prompt))] + + try: + return litellm.completion( + # model choice + model=f"{self.config.model_provider}/{self.config.model_name}", + api_key=self._api_key, + base_url=self._api_base, + api_version=self._api_version, + custom_llm_provider=self._custom_llm_provider, + # actual input + messages=prompt, + tools=tools, + tool_choice=tool_choice, + # streaming choice + stream=stream, + # model params + temperature=self._temperature, + max_tokens=self._max_output_tokens, + timeout=self._timeout, + **self._model_kwargs, + ) + except Exception as e: + # for break pointing + raise e @property def config(self) -> LLMConfig: @@ -154,8 +284,54 @@ class DefaultMultiLLM(LangChainChatLLM): model_provider=self._model_provider, model_name=self._model_version, temperature=self._temperature, + api_key=self._api_key, ) - @property - def llm(self) -> ChatLiteLLM: - return self._llm + def invoke( + self, + prompt: LanguageModelInput, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> BaseMessage: + if LOG_ALL_MODEL_INTERACTIONS: + self.log_model_configs() + self._log_prompt(prompt) + + response = cast( + litellm.ModelResponse, self._completion(prompt, tools, tool_choice, False) + ) + return _convert_litellm_message_to_langchain_message( + response.choices[0].message + ) + + def stream( + self, + prompt: LanguageModelInput, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> Iterator[BaseMessage]: + if LOG_ALL_MODEL_INTERACTIONS: + self.log_model_configs() + self._log_prompt(prompt) + + if DISABLE_LITELLM_STREAMING: + yield self.invoke(prompt) + return + + output = None + response = self._completion(prompt, tools, tool_choice, True) + for part in response: + if len(part["choices"]) == 0: + continue + delta = part["choices"][0]["delta"] + message_chunk = _convert_delta_to_message_chunk(delta, output) + if output is None: + output = message_chunk + else: + output += message_chunk + + yield message_chunk + + full_output = output.content if output else "" + if LOG_ALL_MODEL_INTERACTIONS: + logger.debug(f"Raw Model Output:\n{full_output}") diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index 4c11a29a4284..2c4c029aa2dd 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -3,11 +3,14 @@ from collections.abc import Iterator import requests from langchain.schema.language_model import LanguageModelInput +from langchain_core.messages import AIMessage +from langchain_core.messages import BaseMessage from requests import Timeout from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.llm.interfaces import LLM +from danswer.llm.interfaces import ToolChoiceOptions from danswer.llm.utils import convert_lm_input_to_basic_string from danswer.utils.logger import setup_logger @@ -47,7 +50,7 @@ class CustomModelServer(LLM): self._max_output_tokens = max_output_tokens self._timeout = timeout - def _execute(self, input: LanguageModelInput) -> str: + def _execute(self, input: LanguageModelInput) -> AIMessage: headers = { "Content-Type": "application/json", } @@ -67,13 +70,24 @@ class CustomModelServer(LLM): raise Timeout(f"Model inference to {self._endpoint} timed out") from error response.raise_for_status() - return json.loads(response.content).get("generated_text", "") + response_content = json.loads(response.content).get("generated_text", "") + return AIMessage(content=response_content) def log_model_configs(self) -> None: logger.debug(f"Custom model at: {self._endpoint}") - def invoke(self, prompt: LanguageModelInput) -> str: + def invoke( + self, + prompt: LanguageModelInput, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> BaseMessage: return self._execute(prompt) - def stream(self, prompt: LanguageModelInput) -> Iterator[str]: + def stream( + self, + prompt: LanguageModelInput, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> Iterator[BaseMessage]: yield self._execute(prompt) diff --git a/backend/danswer/llm/gpt_4_all.py b/backend/danswer/llm/gpt_4_all.py deleted file mode 100644 index c7cf6a61557c..000000000000 --- a/backend/danswer/llm/gpt_4_all.py +++ /dev/null @@ -1,77 +0,0 @@ -from collections.abc import Iterator -from typing import Any - -from langchain.schema.language_model import LanguageModelInput - -from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS -from danswer.configs.model_configs import GEN_AI_TEMPERATURE -from danswer.llm.interfaces import LLM -from danswer.llm.utils import convert_lm_input_to_basic_string -from danswer.utils.logger import setup_logger - - -logger = setup_logger() - - -class DummyGPT4All: - """In the case of import failure due to architectural incompatibilities, - this module does not raise exceptions during server startup, - as long as the module isn't actually used""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - raise RuntimeError("GPT4All library not installed.") - - -try: - from gpt4all import GPT4All # type:ignore -except ImportError: - # Setting a low log level because users get scared when they see this - logger.debug( - "GPT4All library not installed. " - "If you wish to run GPT4ALL (in memory) to power Danswer's " - "Generative AI features, please install gpt4all==2.0.2." - ) - GPT4All = DummyGPT4All - - -class DanswerGPT4All(LLM): - """Option to run an LLM locally, however this is significantly slower and - answers tend to be much worse - - NOTE: currently unused, but kept for future reference / if we want to add this back. - """ - - @property - def requires_warm_up(self) -> bool: - """GPT4All models are lazy loaded, load them on server start so that the - first inference isn't extremely delayed""" - return True - - @property - def requires_api_key(self) -> bool: - return False - - def __init__( - self, - timeout: int, - model_version: str, - max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, - temperature: float = GEN_AI_TEMPERATURE, - ): - self.timeout = timeout - self.max_output_tokens = max_output_tokens - self.temperature = temperature - self.gpt4all_model = GPT4All(model_version) - - def log_model_configs(self) -> None: - logger.debug( - f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}" - ) - - def invoke(self, prompt: LanguageModelInput) -> str: - prompt_basic = convert_lm_input_to_basic_string(prompt) - return self.gpt4all_model.generate(prompt_basic) - - def stream(self, prompt: LanguageModelInput) -> Iterator[str]: - prompt_basic = convert_lm_input_to_basic_string(prompt) - return self.gpt4all_model.generate(prompt_basic, streaming=True) diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index c1cbe6253756..1f99383fae80 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -1,7 +1,9 @@ import abc from collections.abc import Iterator +from typing import Literal from langchain.schema.language_model import LanguageModelInput +from langchain_core.messages import BaseMessage from pydantic import BaseModel from danswer.utils.logger import setup_logger @@ -9,11 +11,14 @@ from danswer.utils.logger import setup_logger logger = setup_logger() +ToolChoiceOptions = Literal["required"] | Literal["auto"] | Literal["none"] + class LLMConfig(BaseModel): model_provider: str model_name: str temperature: float + api_key: str | None class LLM(abc.ABC): @@ -39,9 +44,19 @@ class LLM(abc.ABC): raise NotImplementedError @abc.abstractmethod - def invoke(self, prompt: LanguageModelInput) -> str: + def invoke( + self, + prompt: LanguageModelInput, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> BaseMessage: raise NotImplementedError @abc.abstractmethod - def stream(self, prompt: LanguageModelInput) -> Iterator[str]: + def stream( + self, + prompt: LanguageModelInput, + tools: list[dict] | None = None, + tool_choice: ToolChoiceOptions | None = None, + ) -> Iterator[BaseMessage]: raise NotImplementedError diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py index 4007e4b8ccd0..22c096146c4e 100644 --- a/backend/danswer/llm/utils.py +++ b/backend/danswer/llm/utils.py @@ -14,7 +14,6 @@ from langchain.schema import PromptValue from langchain.schema.language_model import LanguageModelInput from langchain.schema.messages import AIMessage from langchain.schema.messages import BaseMessage -from langchain.schema.messages import BaseMessageChunk from langchain.schema.messages import HumanMessage from langchain.schema.messages import SystemMessage from tiktoken.core import Encoding @@ -115,13 +114,18 @@ def translate_history_to_basemessages( def build_content_with_imgs( - message: str, files: list[InMemoryChatFile] -) -> str | list[str | dict]: # matching Langchain's BaseMessage content type - if not files: + message: str, + files: list[InMemoryChatFile] | None = None, + img_urls: list[str] | None = None, +) -> str | list[str | dict[str, Any]]: # matching Langchain's BaseMessage content type + if not files and not img_urls: return message + files = files or [] + img_urls = img_urls or [] + return cast( - list[str | dict], + list[str | dict[str, Any]], [ { "type": "text", @@ -137,6 +141,15 @@ def build_content_with_imgs( } for file in files if file.file_type == "image" + ] + + [ + { + "type": "image_url", + "image_url": { + "url": url, + }, + } + for url in img_urls ], ) @@ -188,20 +201,50 @@ def convert_lm_input_to_basic_string(lm_input: LanguageModelInput) -> str: return prompt_value.to_string() +def message_to_string(message: BaseMessage) -> str: + if not isinstance(message.content, str): + raise RuntimeError("LLM message not in expected format.") + + return message.content + + def message_generator_to_string_generator( - messages: Iterator[BaseMessageChunk], + messages: Iterator[BaseMessage], ) -> Iterator[str]: for message in messages: - if not isinstance(message.content, str): - raise RuntimeError("LLM message not in expected format.") - - yield message.content + yield message_to_string(message) def should_be_verbose() -> bool: return LOG_LEVEL == "debug" +# estimate of the number of tokens in an image url +# is correct when downsampling is used. Is very wrong when OpenAI does not downsample +# TODO: improve this +_IMG_TOKENS = 85 + + +def check_message_tokens( + message: BaseMessage, encode_fn: Callable[[str], list] | None = None +) -> int: + if isinstance(message.content, str): + return check_number_of_tokens(message.content, encode_fn) + + total_tokens = 0 + for part in message.content: + if isinstance(part, str): + total_tokens += check_number_of_tokens(part, encode_fn) + continue + + if part["type"] == "text": + total_tokens += check_number_of_tokens(part["text"], encode_fn) + elif part["type"] == "image_url": + total_tokens += _IMG_TOKENS + + return total_tokens + + def check_number_of_tokens( text: str, encode_fn: Callable[[str], list] | None = None ) -> int: diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 81dcd9bb64c1..99e7e1c123ab 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -60,6 +60,7 @@ from danswer.server.features.folder.api import router as folder_router from danswer.server.features.persona.api import admin_router as admin_persona_router from danswer.server.features.persona.api import basic_router as persona_router from danswer.server.features.prompt.api import basic_router as prompt_router +from danswer.server.features.tool.api import router as tool_router from danswer.server.gpts.api import router as gpts_router from danswer.server.manage.administrative import router as admin_router from danswer.server.manage.get_state import router as state_router @@ -75,6 +76,9 @@ from danswer.server.query_and_chat.query_backend import ( from danswer.server.query_and_chat.query_backend import basic_router as query_router from danswer.server.settings.api import admin_router as settings_admin_router from danswer.server.settings.api import basic_router as settings_router +from danswer.tools.built_in_tools import auto_add_search_tool_to_personas +from danswer.tools.built_in_tools import load_builtin_tools +from danswer.tools.built_in_tools import refresh_built_in_tools_cache from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry from danswer.utils.telemetry import RecordType @@ -199,6 +203,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: delete_old_default_personas(db_session) load_chat_yamls() + logger.info("Loading built-in tools") + load_builtin_tools(db_session) + refresh_built_in_tools_cache(db_session) + auto_add_search_tool_to_personas(db_session) + logger.info("Verifying Document Index(s) is/are available.") document_index = get_default_document_index( primary_index_name=db_embedding_model.index_name, @@ -257,6 +266,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, persona_router) include_router_with_global_prefix_prepended(application, admin_persona_router) include_router_with_global_prefix_prepended(application, prompt_router) + include_router_with_global_prefix_prepended(application, tool_router) include_router_with_global_prefix_prepended(application, state_router) include_router_with_global_prefix_prepended(application, danswer_api_router) include_router_with_global_prefix_prepended(application, gpts_router) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 4fd492e441e5..8d40f9abaa19 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -1,9 +1,9 @@ from collections.abc import Callable from collections.abc import Iterator +from typing import cast from sqlalchemy.orm import Session -from danswer.chat.chat_utils import llm_doc_from_inference_section from danswer.chat.chat_utils import reorganize_citations from danswer.chat.models import CitationInfo from danswer.chat.models import DanswerAnswerPiece @@ -38,13 +38,18 @@ from danswer.one_shot_answer.models import QueryRephrase from danswer.one_shot_answer.qa_utils import combine_message_thread from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer -from danswer.search.models import SearchRequest -from danswer.search.pipeline import SearchPipeline from danswer.search.utils import chunks_or_sections_to_search_docs from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase from danswer.server.query_and_chat.models import ChatMessageDetail from danswer.server.utils import get_json_line +from danswer.tools.force import ForceUseTool +from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID +from danswer.tools.search.search_tool import SearchResponseSummary +from danswer.tools.search.search_tool import SearchTool +from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID +from danswer.tools.tool import ToolResponse +from danswer.tools.tool_runner import ToolRunKickoff from danswer.utils.logger import setup_logger from danswer.utils.timing import log_generator_function_time @@ -60,6 +65,7 @@ AnswerObjectIterator = Iterator[ | StreamingError | ChatMessageDetail | CitationInfo + | ToolRunKickoff ] @@ -121,57 +127,6 @@ def stream_answer_objects( # In chat flow it's given back along with the documents yield QueryRephrase(rephrased_query=rephrased_query) - search_pipeline = SearchPipeline( - search_request=SearchRequest( - query=rephrased_query, - human_selected_filters=query_req.retrieval_options.filters, - persona=chat_session.persona, - offset=query_req.retrieval_options.offset, - limit=query_req.retrieval_options.limit, - skip_rerank=query_req.skip_rerank, - skip_llm_chunk_filter=query_req.skip_llm_chunk_filter, - chunks_above=query_req.chunks_above, - chunks_below=query_req.chunks_below, - full_doc=query_req.full_doc, - ), - user=user, - db_session=db_session, - bypass_acl=bypass_acl, - retrieval_metrics_callback=retrieval_metrics_callback, - rerank_metrics_callback=rerank_metrics_callback, - ) - - # First fetch and return the top chunks so the user can immediately see some results - top_sections = search_pipeline.reranked_sections - top_docs = chunks_or_sections_to_search_docs(top_sections) - - reference_db_search_docs = [ - create_db_search_doc(server_search_doc=top_doc, db_session=db_session) - for top_doc in top_docs - ] - - response_docs = [ - translate_db_search_doc_to_server_search_doc(db_search_doc) - for db_search_doc in reference_db_search_docs - ] - - initial_response = QADocsResponse( - rephrased_query=rephrased_query, - top_documents=response_docs, - predicted_flow=search_pipeline.predicted_flow, - predicted_search=search_pipeline.predicted_search_type, - applied_source_filters=search_pipeline.search_query.filters.source_type, - applied_time_cutoff=search_pipeline.search_query.filters.time_cutoff, - recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, - ) - yield initial_response - - # Yield the list of LLM selected chunks for showing the LLM selected icons in the UI - llm_relevance_filtering_response = LLMRelevanceFilterResponse( - relevant_chunk_indices=search_pipeline.relevant_chunk_indices - ) - yield llm_relevance_filtering_response - prompt = None if query_req.prompt_id is not None: prompt = get_prompt_by_id( @@ -196,34 +151,84 @@ def stream_answer_objects( commit=True, ) + llm = get_llm_for_persona(persona=chat_session.persona) + prompt_config = PromptConfig.from_model(prompt) + document_pruning_config = DocumentPruningConfig( + max_chunks=int( + chat_session.persona.num_chunks + if chat_session.persona.num_chunks is not None + else default_num_chunks + ), + max_tokens=max_document_tokens, + use_sections=query_req.chunks_above > 0 or query_req.chunks_below > 0, + ) + search_tool = SearchTool( + db_session=db_session, + user=user, + persona=chat_session.persona, + retrieval_options=query_req.retrieval_options, + prompt_config=prompt_config, + llm_config=llm.config, + pruning_config=document_pruning_config, + ) + answer_config = AnswerStyleConfig( citation_config=CitationConfig() if use_citations else None, quotes_config=QuotesConfig() if not use_citations else None, - document_pruning_config=DocumentPruningConfig( - max_chunks=int( - chat_session.persona.num_chunks - if chat_session.persona.num_chunks is not None - else default_num_chunks - ), - max_tokens=max_document_tokens, - use_sections=search_pipeline.ran_merge_chunk, - ), + document_pruning_config=document_pruning_config, ) answer = Answer( question=query_msg.message, - docs=[llm_doc_from_inference_section(section) for section in top_sections], answer_style_config=answer_config, prompt_config=PromptConfig.from_model(prompt), llm=get_llm_for_persona(persona=chat_session.persona), - doc_relevance_list=search_pipeline.section_relevance_list, single_message_history=history_str, + tools=[search_tool], + force_use_tool=ForceUseTool( + tool_name=search_tool.name(), + args={"query": rephrased_query}, + ), + # for now, don't use tool calling for this flow, as we haven't + # tested quotes with tool calling too much yet + skip_explicit_tool_calling=True, ) - yield from answer.processed_streamed_output + # won't be any ImageGenerationDisplay responses since that tool is never passed in + for packet in cast(AnswerObjectIterator, answer.processed_streamed_output): + # for one-shot flow, don't currently do anything with these + if isinstance(packet, ToolResponse): + if packet.id == SEARCH_RESPONSE_SUMMARY_ID: + search_response_summary = cast(SearchResponseSummary, packet.response) - reference_db_search_docs = [ - create_db_search_doc(server_search_doc=top_doc, db_session=db_session) - for top_doc in top_docs - ] + top_docs = chunks_or_sections_to_search_docs( + search_response_summary.top_sections + ) + + reference_db_search_docs = [ + create_db_search_doc( + server_search_doc=top_doc, db_session=db_session + ) + for top_doc in top_docs + ] + + response_docs = [ + translate_db_search_doc_to_server_search_doc(db_search_doc) + for db_search_doc in reference_db_search_docs + ] + + initial_response = QADocsResponse( + rephrased_query=rephrased_query, + top_documents=response_docs, + predicted_flow=search_response_summary.predicted_flow, + predicted_search=search_response_summary.predicted_search, + applied_source_filters=search_response_summary.final_filters.source_type, + applied_time_cutoff=search_response_summary.final_filters.time_cutoff, + recency_bias_multiplier=search_response_summary.recency_bias_multiplier, + ) + yield initial_response + elif packet.id == SECTION_RELEVANCE_LIST_ID: + yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response) + else: + yield packet # Saving Gen AI answer and responding with message info gen_ai_response_message = create_new_chat_message( diff --git a/backend/danswer/prompts/direct_qa_prompts.py b/backend/danswer/prompts/direct_qa_prompts.py index 6028ed89645d..ee1b492be8a5 100644 --- a/backend/danswer/prompts/direct_qa_prompts.py +++ b/backend/danswer/prompts/direct_qa_prompts.py @@ -112,6 +112,19 @@ CONTEXT: {{user_query}} """ +# with tool calling, the documents are in a separate "tool" message +# NOTE: need to add the extra line about "getting right to the point" since the +# tool calling models from OpenAI tend to be more verbose +CITATIONS_PROMPT_FOR_TOOL_CALLING = f""" +Refer to the provided context documents when responding to me.{DEFAULT_IGNORE_STATEMENT} \ +You should always get right to the point, and never use extraneous language. + +{{task_prompt}} + +{QUESTION_PAT.upper()} +{{user_query}} +""" + # For weak LLM which only takes one chunk and cannot output json # Also not requiring quotes as it tends to not work diff --git a/backend/danswer/prompts/prompt_utils.py b/backend/danswer/prompts/prompt_utils.py index 2f53a96a738c..af250c440c3c 100644 --- a/backend/danswer/prompts/prompt_utils.py +++ b/backend/danswer/prompts/prompt_utils.py @@ -1,5 +1,8 @@ from collections.abc import Sequence from datetime import datetime +from typing import cast + +from langchain_core.messages import BaseMessage from danswer.chat.models import LlmDoc from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION @@ -93,3 +96,65 @@ def build_complete_context_str( ) return context_str.strip() + + +_PER_MESSAGE_TOKEN_BUFFER = 7 + + +def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: + """From the back, find the index of the last element to include + before the list exceeds the maximum""" + running_sum = 0 + + last_ind = 0 + for i in range(len(lst) - 1, -1, -1): + running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER + if running_sum > max_prompt_tokens: + last_ind = i + 1 + break + if last_ind >= len(lst): + raise ValueError("Last message alone is too large!") + return last_ind + + +def drop_messages_history_overflow( + messages_with_token_cnts: list[tuple[BaseMessage, int]], + max_allowed_tokens: int, +) -> list[BaseMessage]: + """As message history grows, messages need to be dropped starting from the furthest in the past. + The System message should be kept if at all possible and the latest user input which is inserted in the + prompt template must be included""" + + final_messages: list[BaseMessage] = [] + messages, token_counts = cast( + tuple[list[BaseMessage], list[int]], zip(*messages_with_token_cnts) + ) + system_msg = ( + final_messages[0] + if final_messages and final_messages[0].type == "system" + else None + ) + + history_msgs = messages[:-1] + final_msg = messages[-1] + if final_msg.type != "human": + if final_msg.type != "tool": + raise ValueError("Last message must be user input OR a tool result") + else: + final_msgs = messages[-3:] + history_msgs = messages[:-3] + else: + final_msgs = [final_msg] + + # Start dropping from the history if necessary + ind_prev_msg_start = find_last_index( + token_counts, max_prompt_tokens=max_allowed_tokens + ) + + if system_msg and ind_prev_msg_start <= len(history_msgs): + final_messages.append(system_msg) + + final_messages.extend(history_msgs[ind_prev_msg_start:]) + final_messages.extend(final_msgs) + + return final_messages diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py index 88a153da4174..2ef3787c11b1 100644 --- a/backend/danswer/secondary_llm_flows/answer_validation.py +++ b/backend/danswer/secondary_llm_flows/answer_validation.py @@ -1,6 +1,7 @@ from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_to_string from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time @@ -52,7 +53,7 @@ def get_answer_validity( messages = _get_answer_validation_messages(query, answer) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = llm.invoke(filled_llm_prompt) + model_output = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(model_output) validity = _extract_validity(model_output) diff --git a/backend/danswer/secondary_llm_flows/chat_session_naming.py b/backend/danswer/secondary_llm_flows/chat_session_naming.py index aa604131bf30..5f4182e42946 100644 --- a/backend/danswer/secondary_llm_flows/chat_session_naming.py +++ b/backend/danswer/secondary_llm_flows/chat_session_naming.py @@ -5,6 +5,7 @@ from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.interfaces import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import CHAT_NAMING from danswer.utils.logger import setup_logger @@ -39,7 +40,7 @@ def get_renamed_conversation_name( prompt_msgs = get_chat_rename_messages(history_str) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs) - new_name_raw = llm.invoke(filled_llm_prompt) + new_name_raw = message_to_string(llm.invoke(filled_llm_prompt)) new_name = new_name_raw.strip().strip(' "') diff --git a/backend/danswer/secondary_llm_flows/choose_search.py b/backend/danswer/secondary_llm_flows/choose_search.py index 9e07bf647106..df3597d5641e 100644 --- a/backend/danswer/secondary_llm_flows/choose_search.py +++ b/backend/danswer/secondary_llm_flows/choose_search.py @@ -3,13 +3,12 @@ from langchain.schema import HumanMessage from langchain.schema import SystemMessage from danswer.chat.chat_utils import combine_message_chain -from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.db.models import ChatMessage -from danswer.llm.exceptions import GenAIDisabledException -from danswer.llm.factory import get_default_llm +from danswer.llm.answering.models import PreviousMessage from danswer.llm.interfaces import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_to_string from danswer.llm.utils import translate_danswer_msg_to_langchain from danswer.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE from danswer.prompts.chat_prompts import NO_SEARCH @@ -38,7 +37,7 @@ def check_if_need_search_multi_message( prompt_msgs.append(HumanMessage(content=f"{last_query}\n\n{REQUIRE_SEARCH_HINT}")) - model_out = llm.invoke(prompt_msgs) + model_out = message_to_string(llm.invoke(prompt_msgs)) if (NO_SEARCH.split()[0] + " ").lower() in model_out.lower(): return False @@ -47,10 +46,9 @@ def check_if_need_search_multi_message( def check_if_need_search( - query_message: ChatMessage, - history: list[ChatMessage], - llm: LLM | None = None, - disable_llm_check: bool = DISABLE_LLM_CHOOSE_SEARCH, + query: str, + history: list[PreviousMessage], + llm: LLM, ) -> bool: def _get_search_messages( question: str, @@ -67,27 +65,14 @@ def check_if_need_search( return messages - if disable_llm_check: - return True - - if llm is None: - try: - llm = get_default_llm() - except GenAIDisabledException: - # If Generative AI is turned off the always run Search as Danswer is being used - # as just a search engine - return True - history_str = combine_message_chain( messages=history, token_limit=GEN_AI_HISTORY_CUTOFF ) - prompt_msgs = _get_search_messages( - question=query_message.message, history_str=history_str - ) + prompt_msgs = _get_search_messages(question=query, history_str=history_str) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs) - require_search_output = llm.invoke(filled_llm_prompt) + require_search_output = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(f"Run search prediction: {require_search_output}") diff --git a/backend/danswer/secondary_llm_flows/chunk_usefulness.py b/backend/danswer/secondary_llm_flows/chunk_usefulness.py index 2db06bdbafe2..d37feb0c0b1c 100644 --- a/backend/danswer/secondary_llm_flows/chunk_usefulness.py +++ b/backend/danswer/secondary_llm_flows/chunk_usefulness.py @@ -3,6 +3,7 @@ from collections.abc import Callable from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_to_string from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT from danswer.prompts.llm_chunk_filter import NONUSEFUL_PAT from danswer.utils.logger import setup_logger @@ -44,7 +45,7 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool: # When running in a batch, it takes as long as the longest thread # And when running a large batch, one may fail and take the whole timeout # instead cap it to 5 seconds - model_output = llm.invoke(filled_llm_prompt) + model_output = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(model_output) return _extract_usefulness(model_output) diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py index 1dd88e3fa57c..2f221bfa9030 100644 --- a/backend/danswer/secondary_llm_flows/query_expansion.py +++ b/backend/danswer/secondary_llm_flows/query_expansion.py @@ -1,14 +1,15 @@ from collections.abc import Callable -from typing import cast from danswer.chat.chat_utils import combine_message_chain from danswer.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF from danswer.db.models import ChatMessage +from danswer.llm.answering.models import PreviousMessage from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.interfaces import LLM from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_to_string from danswer.prompts.chat_prompts import HISTORY_QUERY_REPHRASE from danswer.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT from danswer.utils.logger import setup_logger @@ -41,7 +42,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str: messages = _get_rephrase_messages() filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = llm.invoke(filled_llm_prompt) + model_output = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(model_output) return model_output @@ -87,54 +88,42 @@ def get_contextual_rephrase_messages( def history_based_query_rephrase( - query_message: ChatMessage, - history: list[ChatMessage], - llm: LLM | None = None, + query: str, + history: list[ChatMessage] | list[PreviousMessage], + llm: LLM, size_heuristic: int = 200, punctuation_heuristic: int = 10, skip_first_rephrase: bool = False, ) -> str: - user_query = cast(str, query_message.message) - # Globally disabled, just use the exact user query if DISABLE_LLM_QUERY_REPHRASE: - return user_query - - if not user_query: - raise ValueError("Can't rephrase/search an empty query") - - if llm is None: - try: - llm = get_default_llm() - except GenAIDisabledException: - # If Generative AI is turned off, just return the original query - return user_query + return query # For some use cases, the first query should be untouched. Later queries must be rephrased # due to needing context but the first query has no context. if skip_first_rephrase and not history: - return user_query + return query # If it's a very large query, assume it's a copy paste which we may want to find exactly # or at least very closely, so don't rephrase it - if len(user_query) >= size_heuristic: - return user_query + if len(query) >= size_heuristic: + return query # If there is an unusually high number of punctuations, it's probably not natural language # so don't rephrase it - if count_punctuation(user_query) >= punctuation_heuristic: - return user_query + if count_punctuation(query) >= punctuation_heuristic: + return query history_str = combine_message_chain( messages=history, token_limit=GEN_AI_HISTORY_CUTOFF ) prompt_msgs = get_contextual_rephrase_messages( - question=user_query, history_str=history_str + question=query, history_str=history_str ) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs) - rephrased_query = llm.invoke(filled_llm_prompt) + rephrased_query = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(f"Rephrased combined query: {rephrased_query}") @@ -169,7 +158,7 @@ def thread_based_query_rephrase( ) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(prompt_msgs) - rephrased_query = llm.invoke(filled_llm_prompt) + rephrased_query = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(f"Rephrased combined query: {rephrased_query}") diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 22ba49e68cd4..4130b7ee3560 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -7,6 +7,8 @@ from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_generator_to_string_generator +from danswer.llm.utils import message_to_string from danswer.prompts.constants import ANSWERABLE_PAT from danswer.prompts.constants import THOUGHT_PAT from danswer.prompts.query_validation import ANSWERABLE_PROMPT @@ -56,7 +58,7 @@ def get_query_answerability( messages = get_query_validation_messages(user_query) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = llm.invoke(filled_llm_prompt) + model_output = message_to_string(llm.invoke(filled_llm_prompt)) reasoning = extract_answerability_reasoning(model_output) answerable = extract_answerability_bool(model_output) @@ -86,11 +88,10 @@ def stream_query_answerability( ).dict() ) return - messages = get_query_validation_messages(user_query) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) try: - tokens = llm.stream(filled_llm_prompt) + tokens = message_generator_to_string_generator(llm.stream(filled_llm_prompt)) reasoning_pat_found = False model_output = "" hold_answerable = "" diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py index 969bd92829ed..6a27963ff9f9 100644 --- a/backend/danswer/secondary_llm_flows/source_filter.py +++ b/backend/danswer/secondary_llm_flows/source_filter.py @@ -9,6 +9,7 @@ from danswer.db.engine import get_sqlalchemy_engine from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_to_string from danswer.prompts.constants import SOURCES_KEY from danswer.prompts.filter_extration import FILE_SOURCE_WARNING from danswer.prompts.filter_extration import SOURCE_FILTER_PROMPT @@ -157,7 +158,7 @@ def extract_source_filter( messages = _get_source_filter_messages(query=query, valid_sources=valid_sources) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = llm.invoke(filled_llm_prompt) + model_output = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(model_output) return _extract_source_filters_from_llm_out(model_output) diff --git a/backend/danswer/secondary_llm_flows/time_filter.py b/backend/danswer/secondary_llm_flows/time_filter.py index be2799f8f4e3..9080dc1f90f0 100644 --- a/backend/danswer/secondary_llm_flows/time_filter.py +++ b/backend/danswer/secondary_llm_flows/time_filter.py @@ -8,6 +8,7 @@ from dateutil.parser import parse from danswer.llm.exceptions import GenAIDisabledException from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.llm.utils import message_to_string from danswer.prompts.filter_extration import TIME_FILTER_PROMPT from danswer.prompts.prompt_utils import get_current_llm_day_time from danswer.utils.logger import setup_logger @@ -153,7 +154,7 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]: messages = _get_time_filter_messages(query) filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) - model_output = llm.invoke(filled_llm_prompt) + model_output = message_to_string(llm.invoke(filled_llm_prompt)) logger.debug(model_output) return _extract_time_filter_from_llm_out(model_output) diff --git a/backend/danswer/server/features/persona/models.py b/backend/danswer/server/features/persona/models.py index d313e54d804a..ee6a129310d8 100644 --- a/backend/danswer/server/features/persona/models.py +++ b/backend/danswer/server/features/persona/models.py @@ -7,6 +7,7 @@ from danswer.db.models import StarterMessage from danswer.search.enums import RecencyBiasSetting from danswer.server.features.document_set.models import DocumentSet from danswer.server.features.prompt.models import PromptSnapshot +from danswer.server.features.tool.api import ToolSnapshot from danswer.server.models import MinimalUserSnapshot @@ -20,6 +21,8 @@ class CreatePersonaRequest(BaseModel): recency_bias: RecencyBiasSetting prompt_ids: list[int] document_set_ids: list[int] + # e.g. ID of SearchTool or ImageGenerationTool or + tool_ids: list[int] llm_model_provider_override: str | None = None llm_model_version_override: str | None = None starter_messages: list[StarterMessage] | None = None @@ -44,6 +47,7 @@ class PersonaSnapshot(BaseModel): starter_messages: list[StarterMessage] | None default_persona: bool prompts: list[PromptSnapshot] + tools: list[ToolSnapshot] document_sets: list[DocumentSet] users: list[UUID] groups: list[int] @@ -73,6 +77,7 @@ class PersonaSnapshot(BaseModel): starter_messages=persona.starter_messages, default_persona=persona.default_persona, prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts], + tools=[ToolSnapshot.from_model(tool) for tool in persona.tools], document_sets=[ DocumentSet.from_model(document_set_model) for document_set_model in persona.document_sets diff --git a/backend/danswer/server/features/tool/api.py b/backend/danswer/server/features/tool/api.py new file mode 100644 index 000000000000..0a9666646a4f --- /dev/null +++ b/backend/danswer/server/features/tool/api.py @@ -0,0 +1,38 @@ +from fastapi import APIRouter +from fastapi import Depends +from pydantic import BaseModel +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.db.engine import get_session +from danswer.db.models import Tool +from danswer.db.models import User + + +router = APIRouter(prefix="/tool") + + +class ToolSnapshot(BaseModel): + id: int + name: str + description: str + in_code_tool_id: str | None + + @classmethod + def from_model(cls, tool: Tool) -> "ToolSnapshot": + return cls( + id=tool.id, + name=tool.name, + description=tool.description, + in_code_tool_id=tool.in_code_tool_id, + ) + + +@router.get("") +def list_tools( + db_session: Session = Depends(get_session), + _: User | None = Depends(current_user), +) -> list[ToolSnapshot]: + tools = db_session.execute(select(Tool)).scalars().all() + return [ToolSnapshot.from_model(tool) for tool in tools] diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index b0840e570147..1f70af298846 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -96,7 +96,6 @@ class CreateChatMessageRequest(ChunkContext): # allows the caller to specify the exact search query they want to use # will disable Query Rewording if specified query_override: str | None = None - no_ai_answer: bool = False # allows the caller to override the Persona / Prompt llm_override: LLMOverride | None = None diff --git a/backend/danswer/tools/built_in_tools.py b/backend/danswer/tools/built_in_tools.py new file mode 100644 index 000000000000..94ffb085748c --- /dev/null +++ b/backend/danswer/tools/built_in_tools.py @@ -0,0 +1,168 @@ +from typing import Type +from typing import TypedDict + +from sqlalchemy import not_ +from sqlalchemy import or_ +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.db.models import Persona +from danswer.db.models import Tool as ToolDBModel +from danswer.tools.images.image_generation_tool import ImageGenerationTool +from danswer.tools.search.search_tool import SearchTool +from danswer.tools.tool import Tool +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +class InCodeToolInfo(TypedDict): + cls: Type[Tool] + description: str + in_code_tool_id: str + + +BUILT_IN_TOOLS: list[InCodeToolInfo] = [ + { + "cls": SearchTool, + "description": "The Search Tool allows the Assistant to search through connected knowledge to help build an answer.", + "in_code_tool_id": SearchTool.__name__, + }, + { + "cls": ImageGenerationTool, + "description": ( + "The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. " + "The tool will be used when the user asks the assistant to generate an image." + ), + "in_code_tool_id": ImageGenerationTool.__name__, + }, +] + + +def load_builtin_tools(db_session: Session) -> None: + existing_in_code_tools = db_session.scalars( + select(ToolDBModel).where(not_(ToolDBModel.in_code_tool_id.is_(None))) + ).all() + in_code_tool_id_to_tool = { + tool.in_code_tool_id: tool for tool in existing_in_code_tools + } + + # Add or update existing tools + for tool_info in BUILT_IN_TOOLS: + tool_name = tool_info["cls"].__name__ + tool = in_code_tool_id_to_tool.get(tool_info["in_code_tool_id"]) + if tool: + # Update existing tool + tool.name = tool_name + tool.description = tool_info["description"] + logger.info(f"Updated tool: {tool_name}") + else: + # Add new tool + new_tool = ToolDBModel( + name=tool_name, + description=tool_info["description"], + in_code_tool_id=tool_info["in_code_tool_id"], + ) + db_session.add(new_tool) + logger.info(f"Added new tool: {tool_name}") + + # Remove tools that are no longer in BUILT_IN_TOOLS + built_in_ids = {tool_info["in_code_tool_id"] for tool_info in BUILT_IN_TOOLS} + for tool_id, tool in list(in_code_tool_id_to_tool.items()): + if tool_id not in built_in_ids: + db_session.delete(tool) + logger.info(f"Removed tool no longer in built-in list: {tool.name}") + + db_session.commit() + logger.info("All built-in tools are loaded/verified.") + + +def auto_add_search_tool_to_personas(db_session: Session) -> None: + """ + Automatically adds the SearchTool to all Persona objects in the database that have + `num_chunks` either unset or set to a value that isn't 0. This is done to migrate + Persona objects that were created before the concept of Tools were added. + """ + # Fetch the SearchTool from the database based on in_code_tool_id from BUILT_IN_TOOLS + search_tool_id = next( + ( + tool["in_code_tool_id"] + for tool in BUILT_IN_TOOLS + if tool["cls"].__name__ == SearchTool.__name__ + ), + None, + ) + if not search_tool_id: + raise RuntimeError("SearchTool not found in the BUILT_IN_TOOLS list.") + + search_tool = db_session.execute( + select(ToolDBModel).where(ToolDBModel.in_code_tool_id == search_tool_id) + ).scalar_one_or_none() + + if not search_tool: + raise RuntimeError("SearchTool not found in the database.") + + # Fetch all Personas that need the SearchTool added + personas_to_update = ( + db_session.execute( + select(Persona).where( + or_(Persona.num_chunks.is_(None), Persona.num_chunks != 0) + ) + ) + .scalars() + .all() + ) + + # Add the SearchTool to each relevant Persona + for persona in personas_to_update: + if search_tool not in persona.tools: + persona.tools.append(search_tool) + logger.info(f"Added SearchTool to Persona ID: {persona.id}") + + # Commit changes to the database + db_session.commit() + logger.info("Completed adding SearchTool to relevant Personas.") + + +_built_in_tools_cache: dict[int, Type[Tool]] | None = None + + +def refresh_built_in_tools_cache(db_session: Session) -> None: + global _built_in_tools_cache + _built_in_tools_cache = {} + all_tool_built_in_tools = ( + db_session.execute( + select(ToolDBModel).where(not_(ToolDBModel.in_code_tool_id.is_(None))) + ) + .scalars() + .all() + ) + for tool in all_tool_built_in_tools: + tool_info = next( + ( + item + for item in BUILT_IN_TOOLS + if item["in_code_tool_id"] == tool.in_code_tool_id + ), + None, + ) + if tool_info: + _built_in_tools_cache[tool.id] = tool_info["cls"] + + +def get_built_in_tool_by_id( + tool_id: int, db_session: Session, force_refresh: bool = False +) -> Type[Tool]: + global _built_in_tools_cache + if _built_in_tools_cache is None or force_refresh: + refresh_built_in_tools_cache(db_session) + + if _built_in_tools_cache is None: + raise RuntimeError( + "Built-in tools cache is None despite being refreshed. Should never happen." + ) + + if tool_id in _built_in_tools_cache: + return _built_in_tools_cache[tool_id] + else: + raise ValueError(f"No built-in tool found in the cache with ID {tool_id}") diff --git a/backend/danswer/tools/factory.py b/backend/danswer/tools/factory.py new file mode 100644 index 000000000000..197bdd6619ab --- /dev/null +++ b/backend/danswer/tools/factory.py @@ -0,0 +1,12 @@ +from typing import Type + +from sqlalchemy.orm import Session + +from danswer.db.models import Tool as ToolDBModel +from danswer.tools.built_in_tools import get_built_in_tool_by_id +from danswer.tools.tool import Tool + + +def get_tool_cls(tool: ToolDBModel, db_session: Session) -> Type[Tool]: + # Currently only support built-in tools + return get_built_in_tool_by_id(tool.id, db_session) diff --git a/backend/danswer/tools/force.py b/backend/danswer/tools/force.py new file mode 100644 index 000000000000..1c3f0a220edd --- /dev/null +++ b/backend/danswer/tools/force.py @@ -0,0 +1,40 @@ +from typing import Any + +from langchain_core.messages import AIMessage +from langchain_core.messages import BaseMessage +from pydantic import BaseModel + +from danswer.tools.tool import Tool + + +class ForceUseTool(BaseModel): + tool_name: str + args: dict[str, Any] | None = None + + def build_openai_tool_choice_dict(self) -> dict[str, Any]: + """Build dict in the format that OpenAI expects which tells them to use this tool.""" + return {"type": "function", "function": {"name": self.tool_name}} + + +def modify_message_chain_for_force_use_tool( + messages: list[BaseMessage], force_use_tool: ForceUseTool | None = None +) -> list[BaseMessage]: + """NOTE: modifies `messages` in place.""" + if not force_use_tool: + return messages + + for message in messages: + if isinstance(message, AIMessage) and message.tool_calls: + for tool_call in message.tool_calls: + tool_call["args"] = force_use_tool.args or {} + + return messages + + +def filter_tools_for_force_tool_use( + tools: list[Tool], force_use_tool: ForceUseTool | None = None +) -> list[Tool]: + if not force_use_tool: + return tools + + return [tool for tool in tools if tool.name() == force_use_tool.tool_name] diff --git a/backend/danswer/tools/images/image_generation_tool.py b/backend/danswer/tools/images/image_generation_tool.py new file mode 100644 index 000000000000..da66271322fb --- /dev/null +++ b/backend/danswer/tools/images/image_generation_tool.py @@ -0,0 +1,164 @@ +import json +from collections.abc import Generator +from typing import Any +from typing import cast + +from litellm import image_generation # type: ignore +from pydantic import BaseModel + +from danswer.chat.chat_utils import combine_message_chain +from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM +from danswer.llm.utils import build_content_with_imgs +from danswer.llm.utils import message_to_string +from danswer.prompts.constants import GENERAL_SEP_PAT +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel + +logger = setup_logger() + + +IMAGE_GENERATION_RESPONSE_ID = "image_generation_response" + +YES_IMAGE_GENERATION = "Yes Image Generation" +SKIP_IMAGE_GENERATION = "Skip Image Generation" + +IMAGE_GENERATION_TEMPLATE = f""" +Given the conversation history and a follow up query, determine if the system should call \ +an external image generation tool to better answer the latest user input. +Your default response is {SKIP_IMAGE_GENERATION}. + +Respond "{YES_IMAGE_GENERATION}" if: +- The user is asking for an image to be generated. + +Conversation History: +{GENERAL_SEP_PAT} +{{chat_history}} +{GENERAL_SEP_PAT} + +If you are at all unsure, respond with {SKIP_IMAGE_GENERATION}. +Respond with EXACTLY and ONLY "{YES_IMAGE_GENERATION}" or "{SKIP_IMAGE_GENERATION}" + +Follow Up Input: +{{final_query}} +""".strip() + + +class ImageGenerationResponse(BaseModel): + revised_prompt: str + url: str + + +class ImageGenerationTool(Tool): + def __init__( + self, api_key: str, model: str = "dall-e-3", num_imgs: int = 2 + ) -> None: + self.api_key = api_key + self.model = model + self.num_imgs = num_imgs + + @classmethod + def name(self) -> str: + return "run_image_generation" + + @classmethod + def tool_definition(cls) -> dict: + return { + "type": "function", + "function": { + "name": cls.name(), + "description": "Generate an image from a prompt", + "parameters": { + "type": "object", + "properties": { + "prompt": { + "type": "string", + "description": "Prompt used to generate the image", + }, + }, + "required": ["prompt"], + }, + }, + } + + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + args = {"prompt": query} + if force_run: + return args + + history_str = combine_message_chain( + messages=history, token_limit=GEN_AI_HISTORY_CUTOFF + ) + prompt = IMAGE_GENERATION_TEMPLATE.format( + chat_history=history_str, + final_query=query, + ) + use_image_generation_tool_output = message_to_string(llm.invoke(prompt)) + + logger.debug( + f"Evaluated if should use ImageGenerationTool: {use_image_generation_tool_output}" + ) + if ( + YES_IMAGE_GENERATION.split()[0] + ).lower() in use_image_generation_tool_output.lower(): + return args + + return None + + def build_tool_message_content( + self, *args: ToolResponse + ) -> str | list[str | dict[str, Any]]: + generation_response = args[0] + image_generations = cast( + list[ImageGenerationResponse], generation_response.response + ) + + return build_content_with_imgs( + json.dumps( + [ + { + "revised_prompt": image_generation.revised_prompt, + "url": image_generation.url, + } + for image_generation in image_generations + ] + ), + img_urls=[image_generation.url for image_generation in image_generations], + ) + + def _generate_image(self, prompt: str) -> ImageGenerationResponse: + response = image_generation( + prompt=prompt, + model=self.model, + api_key=self.api_key, + n=1, + ) + return ImageGenerationResponse( + revised_prompt=response.data[0]["revised_prompt"], + url=response.data[0]["url"], + ) + + def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + prompt = cast(str, kwargs["prompt"]) + + # dalle3 only supports 1 image at a time, which is why we have to + # parallelize this via threading + results = cast( + list[ImageGenerationResponse], + run_functions_tuples_in_parallel( + [(self._generate_image, (prompt,)) for _ in range(self.num_imgs)] + ), + ) + yield ToolResponse( + id=IMAGE_GENERATION_RESPONSE_ID, + response=results, + ) diff --git a/backend/danswer/tools/images/prompt.py b/backend/danswer/tools/images/prompt.py new file mode 100644 index 000000000000..dee28b49c846 --- /dev/null +++ b/backend/danswer/tools/images/prompt.py @@ -0,0 +1,33 @@ +from langchain_core.messages import HumanMessage + +from danswer.llm.utils import build_content_with_imgs + + +NON_TOOL_CALLING_PROMPT = """ +You have just created the attached images in response to the following query: "{{query}}". + +Can you please summarize them in a sentence or two? +""" + +TOOL_CALLING_PROMPT = """ +Can you please summarize the two images you generate in a sentence or two? +""" + + +def build_image_generation_user_prompt( + query: str, img_urls: list[str] | None = None +) -> HumanMessage: + if img_urls: + return HumanMessage( + content=build_content_with_imgs( + message=NON_TOOL_CALLING_PROMPT.format(query=query).strip(), + img_urls=img_urls, + ) + ) + + return HumanMessage( + content=build_content_with_imgs( + message=TOOL_CALLING_PROMPT.strip(), + img_urls=img_urls, + ) + ) diff --git a/backend/danswer/tools/message.py b/backend/danswer/tools/message.py new file mode 100644 index 000000000000..cdf86a23b05d --- /dev/null +++ b/backend/danswer/tools/message.py @@ -0,0 +1,39 @@ +import json +from typing import Any + +from langchain_core.messages.ai import AIMessage +from langchain_core.messages.tool import ToolCall +from langchain_core.messages.tool import ToolMessage +from pydantic import BaseModel + +from danswer.llm.utils import get_default_llm_tokenizer + + +def build_tool_message( + tool_call: ToolCall, tool_content: str | list[str | dict[str, Any]] +) -> ToolMessage: + return ToolMessage( + tool_call_id=tool_call["id"] or "", + name=tool_call["name"], + content=tool_content, + ) + + +class ToolCallSummary(BaseModel): + tool_call_request: AIMessage + tool_call_result: ToolMessage + + +def tool_call_tokens(tool_call_summary: ToolCallSummary) -> int: + llm_tokenizer = get_default_llm_tokenizer() + + request_tokens = len( + llm_tokenizer.encode( + json.dumps(tool_call_summary.tool_call_request.tool_calls[0]["args"]) + ) + ) + result_tokens = len( + llm_tokenizer.encode(json.dumps(tool_call_summary.tool_call_result.content)) + ) + + return request_tokens + result_tokens diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py new file mode 100644 index 000000000000..968c17f5a28f --- /dev/null +++ b/backend/danswer/tools/search/search_tool.py @@ -0,0 +1,240 @@ +import json +from collections.abc import Generator +from typing import Any +from typing import cast + +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from danswer.chat.chat_utils import llm_doc_from_inference_section +from danswer.chat.models import LlmDoc +from danswer.db.models import Persona +from danswer.db.models import User +from danswer.llm.answering.doc_pruning import prune_documents +from danswer.llm.answering.models import DocumentPruningConfig +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.answering.models import PromptConfig +from danswer.llm.interfaces import LLM +from danswer.llm.interfaces import LLMConfig +from danswer.search.enums import QueryFlow +from danswer.search.enums import SearchType +from danswer.search.models import IndexFilters +from danswer.search.models import InferenceSection +from danswer.search.models import RetrievalDetails +from danswer.search.models import SearchRequest +from danswer.search.pipeline import SearchPipeline +from danswer.secondary_llm_flows.choose_search import check_if_need_search +from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase +from danswer.tools.search.search_utils import llm_doc_to_dict +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse + +SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary" +SECTION_RELEVANCE_LIST_ID = "section_relevance_list" +FINAL_CONTEXT_DOCUMENTS = "final_context_documents" + + +class SearchResponseSummary(BaseModel): + top_sections: list[InferenceSection] + rephrased_query: str | None = None + predicted_flow: QueryFlow | None + predicted_search: SearchType | None + final_filters: IndexFilters + recency_bias_multiplier: float + + +search_tool_description = """ +Runs a semantic search over the user's knowledge base. The default behavior is to use this tool. \ +The only scenario where you should not use this tool is if: + +- There is sufficient information in chat history to FULLY and ACCURATELY answer the query AND \ +additional information or details would provide little or no value. +- The query is some form of request that does not require additional information to handle. + +HINT: if you are unfamiliar with the user input OR think the user input is a typo, use this tool. +""" + + +class SearchTool(Tool): + def __init__( + self, + db_session: Session, + user: User | None, + persona: Persona, + retrieval_options: RetrievalDetails | None, + prompt_config: PromptConfig, + llm_config: LLMConfig, + pruning_config: DocumentPruningConfig, + # if specified, will not actually run a search and will instead return these + # sections. Used when the user selects specific docs to talk to + selected_docs: list[LlmDoc] | None = None, + chunks_above: int = 0, + chunks_below: int = 0, + full_doc: bool = False, + ) -> None: + self.user = user + self.persona = persona + self.retrieval_options = retrieval_options + self.prompt_config = prompt_config + self.llm_config = llm_config + self.pruning_config = pruning_config + + self.selected_docs = selected_docs + + self.chunks_above = chunks_above + self.chunks_below = chunks_below + self.full_doc = full_doc + self.db_session = db_session + + @classmethod + def name(cls) -> str: + return "run_search" + + """For explicit tool calling""" + + @classmethod + def tool_definition(cls) -> dict: + return { + "type": "function", + "function": { + "name": cls.name(), + "description": search_tool_description, + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "What to search for", + }, + }, + "required": ["query"], + }, + }, + } + + def build_tool_message_content( + self, *args: ToolResponse + ) -> str | list[str | dict[str, Any]]: + final_context_docs_response = args[2] + final_context_docs = cast(list[LlmDoc], final_context_docs_response.response) + + return json.dumps( + { + "search_results": [ + llm_doc_to_dict(doc, ind) + for ind, doc in enumerate(final_context_docs) + ] + } + ) + + """For LLMs that don't support tool calling""" + + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + if not force_run and not check_if_need_search( + query=query, history=history, llm=llm + ): + return None + + rephrased_query = history_based_query_rephrase( + query=query, history=history, llm=llm + ) + return {"query": rephrased_query} + + """Actual tool execution""" + + def _build_response_for_specified_sections( + self, query: str + ) -> Generator[ToolResponse, None, None]: + if self.selected_docs is None: + raise ValueError("sections must be specified") + + yield ToolResponse( + id=SEARCH_RESPONSE_SUMMARY_ID, + response=SearchResponseSummary( + rephrased_query=None, + top_sections=[], + predicted_flow=None, + predicted_search=None, + final_filters=IndexFilters(access_control_list=None), # dummy filters + recency_bias_multiplier=1.0, + ), + ) + yield ToolResponse( + id=SECTION_RELEVANCE_LIST_ID, + response=[i for i in range(len(self.selected_docs))], + ) + yield ToolResponse( + id=FINAL_CONTEXT_DOCUMENTS, + response=prune_documents( + docs=self.selected_docs, + doc_relevance_list=None, + prompt_config=self.prompt_config, + llm_config=self.llm_config, + question=query, + document_pruning_config=self.pruning_config, + ), + ) + + def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]: + query = cast(str, kwargs["query"]) + + if self.selected_docs: + yield from self._build_response_for_specified_sections(query) + return + + search_pipeline = SearchPipeline( + search_request=SearchRequest( + query=query, + human_selected_filters=self.retrieval_options.filters + if self.retrieval_options + else None, + persona=self.persona, + offset=self.retrieval_options.offset + if self.retrieval_options + else None, + limit=self.retrieval_options.limit if self.retrieval_options else None, + chunks_above=self.chunks_above, + chunks_below=self.chunks_below, + full_doc=self.full_doc, + ), + user=self.user, + db_session=self.db_session, + ) + yield ToolResponse( + id=SEARCH_RESPONSE_SUMMARY_ID, + response=SearchResponseSummary( + rephrased_query=query, + top_sections=search_pipeline.reranked_sections, + predicted_flow=search_pipeline.predicted_flow, + predicted_search=search_pipeline.predicted_search_type, + final_filters=search_pipeline.search_query.filters, + recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier, + ), + ) + yield ToolResponse( + id=SECTION_RELEVANCE_LIST_ID, + response=search_pipeline.relevant_chunk_indices, + ) + + llm_docs = [ + llm_doc_from_inference_section(section) + for section in search_pipeline.reranked_sections + ] + final_context_documents = prune_documents( + docs=llm_docs, + doc_relevance_list=[ + True if ind in search_pipeline.relevant_chunk_indices else False + for ind in range(len(llm_docs)) + ], + prompt_config=self.prompt_config, + llm_config=self.llm_config, + question=query, + document_pruning_config=self.pruning_config, + ) + yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=final_context_documents) diff --git a/backend/danswer/tools/search/search_utils.py b/backend/danswer/tools/search/search_utils.py new file mode 100644 index 000000000000..7e5151bb582d --- /dev/null +++ b/backend/danswer/tools/search/search_utils.py @@ -0,0 +1,15 @@ +from danswer.chat.models import LlmDoc +from danswer.prompts.prompt_utils import clean_up_source + + +def llm_doc_to_dict(llm_doc: LlmDoc, doc_num: int) -> dict: + doc_dict = { + "document_number": doc_num + 1, + "title": llm_doc.semantic_identifier, + "content": llm_doc.content, + "source": clean_up_source(llm_doc.source_type), + "metadata": llm_doc.metadata, + } + if llm_doc.updated_at: + doc_dict["updated_at"] = llm_doc.updated_at.strftime("%B %d, %Y %H:%M") + return doc_dict diff --git a/backend/danswer/tools/tool.py b/backend/danswer/tools/tool.py new file mode 100644 index 000000000000..dd443757e676 --- /dev/null +++ b/backend/danswer/tools/tool.py @@ -0,0 +1,51 @@ +import abc +from collections.abc import Generator +from typing import Any + +from pydantic import BaseModel + +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM + + +class ToolResponse(BaseModel): + id: str | None = None + response: Any + + +class Tool(abc.ABC): + @classmethod + @abc.abstractmethod + def name(self) -> str: + raise NotImplementedError + + """For LLMs which support explicit tool calling""" + + @classmethod + @abc.abstractmethod + def tool_definition(self) -> dict: + raise NotImplementedError + + @abc.abstractmethod + def build_tool_message_content( + self, *args: ToolResponse + ) -> str | list[str | dict[str, Any]]: + raise NotImplementedError + + """For LLMs which do NOT support explicit tool calling""" + + @abc.abstractmethod + def get_args_for_non_tool_calling_llm( + self, + query: str, + history: list[PreviousMessage], + llm: LLM, + force_run: bool = False, + ) -> dict[str, Any] | None: + raise NotImplementedError + + """Actual execution of the tool""" + + @abc.abstractmethod + def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]: + raise NotImplementedError diff --git a/backend/danswer/tools/tool_runner.py b/backend/danswer/tools/tool_runner.py new file mode 100644 index 000000000000..46f247b06dca --- /dev/null +++ b/backend/danswer/tools/tool_runner.py @@ -0,0 +1,73 @@ +from collections.abc import Generator +from typing import Any + +from pydantic import BaseModel +from pydantic import root_validator + +from danswer.llm.answering.models import PreviousMessage +from danswer.llm.interfaces import LLM +from danswer.tools.tool import Tool +from danswer.tools.tool import ToolResponse +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel + + +class ToolRunKickoff(BaseModel): + tool_name: str + tool_args: dict[str, Any] + + +class ToolRunnerResponse(BaseModel): + tool_run_kickoff: ToolRunKickoff | None = None + tool_response: ToolResponse | None = None + tool_message_content: str | list[str | dict[str, Any]] | None = None + + @root_validator + def validate_tool_runner_response( + cls, values: dict[str, ToolResponse | str] + ) -> dict[str, ToolResponse | str]: + fields = ["tool_response", "tool_message_content", "tool_run_kickoff"] + provided = sum(1 for field in fields if values.get(field) is not None) + + if provided != 1: + raise ValueError( + "Exactly one of 'tool_response', 'tool_message_content', " + "or 'tool_run_kickoff' must be provided" + ) + + return values + + +class ToolRunner: + def __init__(self, tool: Tool, args: dict[str, Any]): + self.tool = tool + self.args = args + + self._tool_responses: list[ToolResponse] | None = None + + def kickoff(self) -> ToolRunKickoff: + return ToolRunKickoff(tool_name=self.tool.name(), tool_args=self.args) + + def tool_responses(self) -> Generator[ToolResponse, None, None]: + if self._tool_responses is not None: + yield from self._tool_responses + + tool_responses: list[ToolResponse] = [] + for tool_response in self.tool.run(**self.args): + yield tool_response + tool_responses.append(tool_response) + + self._tool_responses = tool_responses + + def tool_message_content(self) -> str | list[str | dict[str, Any]]: + tool_responses = list(self.tool_responses()) + return self.tool.build_tool_message_content(*tool_responses) + + +def check_which_tools_should_run_for_non_tool_calling_llm( + tools: list[Tool], query: str, history: list[PreviousMessage], llm: LLM +) -> list[dict[str, Any] | None]: + tool_args_list = [ + (tool.get_args_for_non_tool_calling_llm, (query, history, llm)) + for tool in tools + ] + return run_functions_tuples_in_parallel(tool_args_list) diff --git a/backend/danswer/tools/utils.py b/backend/danswer/tools/utils.py new file mode 100644 index 000000000000..831021cdab3a --- /dev/null +++ b/backend/danswer/tools/utils.py @@ -0,0 +1,31 @@ +import json +from typing import Type + +from tiktoken import Encoding + +from danswer.llm.utils import get_default_llm_tokenizer +from danswer.tools.tool import Tool + + +OPEN_AI_TOOL_CALLING_MODELS = {"gpt-3.5-turbo", "gpt-4-turbo", "gpt-4"} + + +def explicit_tool_calling_supported(model_provider: str, model_name: str) -> bool: + if model_provider == "openai" and model_name in OPEN_AI_TOOL_CALLING_MODELS: + return True + + return False + + +def compute_tool_tokens( + tool: Tool | Type[Tool], llm_tokenizer: Encoding | None = None +) -> int: + if not llm_tokenizer: + llm_tokenizer = get_default_llm_tokenizer() + return len(llm_tokenizer.encode(json.dumps(tool.tool_definition()))) + + +def compute_all_tool_tokens( + tools: list[Tool] | list[Type[Tool]], llm_tokenizer: Encoding | None = None +) -> int: + return sum(compute_tool_tokens(tool, llm_tokenizer) for tool in tools) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index b51d1bb5340f..00f8c03f931e 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -24,8 +24,11 @@ httpx[http2]==0.23.3 httpx-oauth==0.11.2 huggingface-hub==0.20.1 jira==3.5.1 -langchain==0.1.9 -litellm==1.34.21 +langchain==0.1.17 +langchain-community==0.0.36 +langchain-core==0.1.50 +langchain-text-splitters==0.0.1 +litellm==1.35.31 llama-index==0.9.45 Mako==1.2.4 msal==1.26.0 diff --git a/web/src/app/admin/assistants/AssistantEditor.tsx b/web/src/app/admin/assistants/AssistantEditor.tsx index 90f4a3281235..a1f9e5a04a29 100644 --- a/web/src/app/admin/assistants/AssistantEditor.tsx +++ b/web/src/app/admin/assistants/AssistantEditor.tsx @@ -33,6 +33,21 @@ import { SuccessfulPersonaUpdateRedirectType } from "./enums"; import { DocumentSetSelectable } from "@/components/documentSet/DocumentSetSelectable"; import { FullLLMProvider } from "../models/llm/interfaces"; import { Option } from "@/components/Dropdown"; +import { ToolSnapshot } from "@/lib/tools/interfaces"; + +function findSearchTool(tools: ToolSnapshot[]) { + return tools.find((tool) => tool.in_code_tool_id === "SearchTool"); +} + +function findImageGenerationTool(tools: ToolSnapshot[]) { + return tools.find((tool) => tool.in_code_tool_id === "ImageGenerationTool"); +} + +function checkLLMSupportsImageGeneration(provider: string, model: string) { + console.log(provider); + console.log(model); + return provider === "openai" && model === "gpt-4-turbo"; +} function Label({ children }: { children: string | JSX.Element }) { return ( @@ -52,6 +67,7 @@ export function AssistantEditor({ defaultPublic, redirectType, llmProviders, + tools, }: { existingPersona?: Persona | null; ccPairs: CCPairBasicInfo[]; @@ -60,6 +76,7 @@ export function AssistantEditor({ defaultPublic: boolean; redirectType: SuccessfulPersonaUpdateRedirectType; llmProviders: FullLLMProvider[]; + tools: ToolSnapshot[]; }) { const router = useRouter(); const { popup, setPopup } = usePopup(); @@ -98,9 +115,18 @@ export function AssistantEditor({ } }, []); - const defaultLLM = llmProviders.find( + const defaultProvider = llmProviders.find( (llmProvider) => llmProvider.is_default_provider - )?.default_model_name; + ); + const defaultProviderName = defaultProvider?.provider; + const defaultModelName = defaultProvider?.default_model_name; + const providerDisplayNameToProviderName = new Map(); + llmProviders.forEach((llmProvider) => { + providerDisplayNameToProviderName.set( + llmProvider.name, + llmProvider.provider + ); + }); const modelOptionsByProvider = new Map[]>(); llmProviders.forEach((llmProvider) => { @@ -112,6 +138,16 @@ export function AssistantEditor({ }); modelOptionsByProvider.set(llmProvider.name, providerOptions); }); + const providerSupportingImageGenerationExists = llmProviders.some( + (provider) => provider.provider === "openai" + ); + + const personaCurrentToolIds = + existingPersona?.tools.map((tool) => tool.id) || []; + const searchTool = findSearchTool(tools); + const imageGenerationTool = providerSupportingImageGenerationExists + ? findImageGenerationTool(tools) + : undefined; return (
@@ -123,7 +159,6 @@ export function AssistantEditor({ description: existingPersona?.description ?? "", system_prompt: existingPrompt?.system_prompt ?? "", task_prompt: existingPrompt?.task_prompt ?? "", - disable_retrieval: (existingPersona?.num_chunks ?? 10) === 0, is_public: existingPersona?.is_public ?? defaultPublic, document_set_ids: existingPersona?.document_sets?.map( @@ -140,6 +175,10 @@ export function AssistantEditor({ starter_messages: existingPersona?.starter_messages ?? [], // EE Only groups: existingPersona?.groups ?? [], + search_tool_enabled: personaCurrentToolIds.includes(searchTool!.id), + image_generation_tool_enabled: imageGenerationTool + ? personaCurrentToolIds.includes(imageGenerationTool.id) + : false, }} validationSchema={Yup.object() .shape({ @@ -149,7 +188,6 @@ export function AssistantEditor({ ), system_prompt: Yup.string(), task_prompt: Yup.string(), - disable_retrieval: Yup.boolean().required(), is_public: Yup.boolean().required(), document_set_ids: Yup.array().of(Yup.number()), num_chunks: Yup.number().max(20).nullable(), @@ -166,6 +204,8 @@ export function AssistantEditor({ ), // EE Only groups: Yup.array().of(Yup.number()), + search_tool_enabled: Yup.boolean().required(), + image_generation_tool_enabled: Yup.boolean().required(), }) .test( "system-prompt-or-task-prompt", @@ -210,11 +250,30 @@ export function AssistantEditor({ formikHelpers.setSubmitting(true); + const tools = []; + if (values.search_tool_enabled) { + tools.push(searchTool!.id); + } + if ( + values.image_generation_tool_enabled && + imageGenerationTool && + checkLLMSupportsImageGeneration( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || defaultModelName || "" + ) + ) { + tools.push(imageGenerationTool.id); + } + // if disable_retrieval is set, set num_chunks to 0 // to tell the backend to not fetch any documents - const numChunks = values.disable_retrieval - ? 0 - : values.num_chunks || 10; + const numChunks = values.search_tool_enabled + ? values.num_chunks || 10 + : 0; // don't set groups if marked as public const groups = values.is_public ? [] : values.groups; @@ -229,6 +288,7 @@ export function AssistantEditor({ num_chunks: numChunks, users: user ? [user.id] : undefined, groups, + tool_ids: tools, }); } else { [promptResponse, personaResponse] = await createPersona({ @@ -236,6 +296,7 @@ export function AssistantEditor({ num_chunks: numChunks, users: user ? [user.id] : undefined, groups, + tool_ids: tools, }); } @@ -296,7 +357,7 @@ export function AssistantEditor({ triggerFinalPromptUpdate( e.target.value, values.task_prompt, - values.disable_retrieval + values.search_tool_enabled ); }} error={finalPromptError} @@ -314,7 +375,7 @@ export function AssistantEditor({ triggerFinalPromptUpdate( values.system_prompt, e.target.value, - values.disable_retrieval + values.search_tool_enabled ); }} error={finalPromptError} @@ -334,32 +395,24 @@ export function AssistantEditor({ - {ccPairs.length > 0 && ( + <> - - <> - { - setFieldValue("disable_retrieval", e.target.checked); - triggerFinalPromptUpdate( - values.system_prompt, - values.task_prompt, - e.target.checked - ); - }} - /> + { + setFieldValue("num_chunks", null); + setFieldValue("search_tool_enabled", e.target.checked); + }} + /> - {!values.disable_retrieval && ( + {values.search_tool_enabled && ( +
+ {ccPairs.length > 0 && ( <> + +
<> @@ -426,37 +479,80 @@ export function AssistantEditor({ )} )} + + <> + + How many chunks should we feed into the LLM + when generating the final response? Each chunk + is ~400 words long. +
+ } + onChange={(e) => { + const value = e.target.value; + // Allow only integer values + if (value === "" || /^[0-9]+$/.test(value)) { + setFieldValue("num_chunks", value); + } + }} + /> + + + + + + + )} - - +
+ )} - - - )} - - {!values.disable_retrieval && ( - <> - - <> + {imageGenerationTool && + checkLLMSupportsImageGeneration( + providerDisplayNameToProviderName.get( + values.llm_model_provider_override || "" + ) || + defaultProviderName || + "", + values.llm_model_version_override || + defaultModelName || + "" + ) && ( { + setFieldValue( + "image_generation_tool_enabled", + e.target.checked + ); + }} /> - - - - + )} - )} +
+ + {llmProviders.length > 0 && ( <> @@ -467,7 +563,8 @@ export function AssistantEditor({ <> Pick which LLM to use for this Assistant. If left as - Default, will use {defaultLLM} + Default, will use{" "} + {defaultModelName} .

@@ -531,49 +628,6 @@ export function AssistantEditor({ )} - {!values.disable_retrieval && ( - <> - - <> - - How many chunks should we feed into the LLM when - generating the final response? Each chunk is ~400 - words long. -
-
- If unspecified, will use 10 chunks. -
- } - onChange={(e) => { - const value = e.target.value; - // Allow only integer values - if (value === "" || /^[0-9]+$/.test(value)) { - setFieldValue("num_chunks", value); - } - }} - /> - - - - - - - - )} - ([]); + const [currentTool, setCurrentTool] = useState(null); const [isStreaming, setIsStreaming] = useState(false); // uploaded files @@ -456,6 +459,7 @@ export function ChatPage({ ? RetrievalType.SelectedDocs : RetrievalType.None; let documents: DanswerDocument[] = selectedDocuments; + let aiMessageImages: FileDescriptor[] | null = null; let error: string | null = null; let finalMessage: BackendMessage | null = null; try { @@ -502,6 +506,17 @@ export function ChatPage({ // we have to use -1) setSelectedMessageForDocDisplay(-1); } + } else if (Object.hasOwn(packet, "file_ids")) { + aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map( + (fileId) => { + return { + id: fileId, + type: "image", + }; + } + ); + } else if (Object.hasOwn(packet, "tool_name")) { + setCurrentTool((packet as ToolRunKickoff).tool_name); } else if (Object.hasOwn(packet, "error")) { error = (packet as StreamingError).error; } else if (Object.hasOwn(packet, "message_id")) { @@ -524,7 +539,7 @@ export function ChatPage({ query: finalMessage?.rephrased_query || query, documents: finalMessage?.context_docs?.top_documents || documents, citations: finalMessage?.citations || {}, - files: finalMessage?.files || [], + files: finalMessage?.files || aiMessageImages || [], }, ]); if (isCancelledRef.current) { @@ -546,7 +561,7 @@ export function ChatPage({ messageId: null, message: errorMsg, type: "error", - files: [], + files: aiMessageImages || [], }, ]); } @@ -796,11 +811,13 @@ export function ChatPage({ ; +} + export interface StreamingError { error: string; } diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx index 37e4dba605d7..a63901c1ec76 100644 --- a/web/src/app/chat/lib.tsx +++ b/web/src/app/chat/lib.tsx @@ -10,9 +10,11 @@ import { BackendMessage, ChatSession, DocumentsResponse, + ImageGenerationDisplay, Message, RetrievalType, StreamingError, + ToolRunKickoff, } from "./interfaces"; import { Persona } from "../admin/assistants/interfaces"; import { ReadonlyURLSearchParams } from "next/navigation"; @@ -128,7 +130,12 @@ export async function* sendMessage({ } yield* handleStream< - AnswerPiecePacket | DocumentsResponse | BackendMessage | StreamingError + | AnswerPiecePacket + | DocumentsResponse + | BackendMessage + | ImageGenerationDisplay + | ToolRunKickoff + | StreamingError >(sendMessageResponse); } diff --git a/web/src/app/chat/message/Messages.tsx b/web/src/app/chat/message/Messages.tsx index bfed7a26fedf..6e8bca90a414 100644 --- a/web/src/app/chat/message/Messages.tsx +++ b/web/src/app/chat/message/Messages.tsx @@ -2,8 +2,10 @@ import { FiCheck, FiCopy, FiCpu, + FiImage, FiThumbsDown, FiThumbsUp, + FiTool, FiUser, } from "react-icons/fi"; import { FeedbackType } from "../types"; @@ -18,6 +20,11 @@ import remarkGfm from "remark-gfm"; import { CopyButton } from "@/components/CopyButton"; import { FileDescriptor } from "../interfaces"; import { InMessageImage } from "../images/InMessageImage"; +import { + IMAGE_GENERATION_TOOL_NAME, + SEARCH_TOOL_NAME, +} from "../tools/constants"; +import { ToolRunningAnimation } from "../tools/ToolRunningAnimation"; export const Hoverable: React.FC<{ children: JSX.Element; @@ -36,9 +43,11 @@ export const Hoverable: React.FC<{ export const AIMessage = ({ messageId, content, + files, query, personaName, citedDocuments, + currentTool, isComplete, hasDocs, handleFeedback, @@ -50,9 +59,11 @@ export const AIMessage = ({ }: { messageId: number | null; content: string | JSX.Element; + files?: FileDescriptor[]; query?: string; personaName?: string; citedDocuments?: [string, DanswerDocument][] | null; + currentTool?: string | null; isComplete?: boolean; hasDocs?: boolean; handleFeedback?: (feedbackType: FeedbackType) => void; @@ -62,7 +73,29 @@ export const AIMessage = ({ handleForceSearch?: () => void; retrievalDisabled?: boolean; }) => { - const [copyClicked, setCopyClicked] = useState(false); + const loader = + currentTool === IMAGE_GENERATION_TOOL_NAME ? ( +
+ } + /> +
+ ) : ( +
+ +
+ ); + return (
@@ -123,6 +156,17 @@ export const AIMessage = ({ {content ? ( <> + {files && files.length > 0 && ( +
+
+ {files.map((file) => { + return ( + + ); + })} +
+
+ )} {typeof content === "string" ? ( ) : isComplete ? null : ( -
- -
+ loader )} {citedDocuments && citedDocuments.length > 0 && (
diff --git a/web/src/app/chat/tools/ToolRunningAnimation.tsx b/web/src/app/chat/tools/ToolRunningAnimation.tsx new file mode 100644 index 000000000000..bd0414295fed --- /dev/null +++ b/web/src/app/chat/tools/ToolRunningAnimation.tsx @@ -0,0 +1,16 @@ +import { LoadingAnimation } from "@/components/Loading"; + +export function ToolRunningAnimation({ + toolName, + toolLogo, +}: { + toolName: string; + toolLogo: JSX.Element; +}) { + return ( +
+ {toolLogo} + +
+ ); +} diff --git a/web/src/app/chat/tools/constants.ts b/web/src/app/chat/tools/constants.ts new file mode 100644 index 000000000000..576829201e41 --- /dev/null +++ b/web/src/app/chat/tools/constants.ts @@ -0,0 +1,2 @@ +export const SEARCH_TOOL_NAME = "run_search"; +export const IMAGE_GENERATION_TOOL_NAME = "run_image_generation"; diff --git a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts index 614bcb81ff05..b98564bee3df 100644 --- a/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts +++ b/web/src/lib/assistants/fetchPersonaEditorInfoSS.ts @@ -3,6 +3,8 @@ import { CCPairBasicInfo, DocumentSet, User } from "../types"; import { getCurrentUserSS } from "../userSS"; import { fetchSS } from "../utilsSS"; import { FullLLMProvider } from "@/app/admin/models/llm/interfaces"; +import { ToolSnapshot } from "../tools/interfaces"; +import { fetchToolsSS } from "../tools/fetchTools"; export async function fetchAssistantEditorInfoSS( personaId?: number | string @@ -14,6 +16,7 @@ export async function fetchAssistantEditorInfoSS( llmProviders: FullLLMProvider[]; user: User | null; existingPersona: Persona | null; + tools: ToolSnapshot[]; }, null, ] @@ -26,6 +29,7 @@ export async function fetchAssistantEditorInfoSS( // duplicate fetch, but shouldn't be too big of a deal // this page is not a high traffic page getCurrentUserSS(), + fetchToolsSS(), ]; if (personaId) { tasks.push(fetchSS(`/persona/${personaId}`)); @@ -38,12 +42,14 @@ export async function fetchAssistantEditorInfoSS( documentSetsResponse, llmProvidersResponse, user, + toolsResponse, personaResponse, ] = (await Promise.all(tasks)) as [ Response, Response, Response, User | null, + ToolSnapshot[] | null, Response | null, ]; @@ -63,6 +69,10 @@ export async function fetchAssistantEditorInfoSS( } const documentSets = (await documentSetsResponse.json()) as DocumentSet[]; + if (!toolsResponse) { + return [null, `Failed to fetch tools`]; + } + if (!llmProvidersResponse.ok) { return [ null, @@ -85,6 +95,7 @@ export async function fetchAssistantEditorInfoSS( llmProviders, user, existingPersona, + tools: toolsResponse, }, null, ]; diff --git a/web/src/lib/tools/fetchTools.ts b/web/src/lib/tools/fetchTools.ts new file mode 100644 index 000000000000..51969c6db7fb --- /dev/null +++ b/web/src/lib/tools/fetchTools.ts @@ -0,0 +1,16 @@ +import { ToolSnapshot } from "./interfaces"; +import { fetchSS } from "../utilsSS"; + +export async function fetchToolsSS(): Promise { + try { + const response = await fetchSS("/tool"); + if (!response.ok) { + throw new Error(`Failed to fetch tools: ${await response.text()}`); + } + const tools: ToolSnapshot[] = await response.json(); + return tools; + } catch (error) { + console.error("Error fetching tools:", error); + return null; + } +} diff --git a/web/src/lib/tools/interfaces.ts b/web/src/lib/tools/interfaces.ts new file mode 100644 index 000000000000..f8882e6bfdb1 --- /dev/null +++ b/web/src/lib/tools/interfaces.ts @@ -0,0 +1,6 @@ +export interface ToolSnapshot { + id: number; + name: string; + description: string; + in_code_tool_id: string | null; +}