From ed550986a68f35bbd3d993e78946bd869ae47475 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Fri, 28 Jun 2024 17:18:39 -0700 Subject: [PATCH] Feature/assistants (#1581) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * include alternate assisstant - migrate models - migrate db * functional alternate assistant selection * refactor chat components for persona API * functional assistants api * add full functionality- assistants * add functional assistants dropdown handler * refactor assistants for full compatability - hooks - track the live assistant for edge cases - UI updates * add assistant UI features - Autotab - Arrow selection - Icons - Proper @ detection - Info Popup prune unnecessary comments * functional search toggling for assistants * add functional cross-page assistants rebase with main * add proper interactivity for edge cases - click outside of input / text box - "force search" assistant consistency * refactor alt assistant consistency * update alembic versions * rebased * undo formatting changes * additional formatting * current processing * merge fixes * formatting * colors * 2 -> 1 * 1 -> 2 --------- Co-authored-by: “Pablo <“pablo@danswer.ai”> --- .vscode/env_template.txt | 2 +- ...add_alternate_assistant_to_chat_message.py | 38 ++++ backend/danswer/chat/process_message.py | 26 ++- backend/danswer/db/chat.py | 15 +- backend/danswer/db/models.py | 7 + .../danswer/server/query_and_chat/models.py | 4 + web/src/app/chat/ChatIntro.tsx | 3 - web/src/app/chat/ChatPage.tsx | 163 ++++++++++---- web/src/app/chat/StarterMessage.tsx | 4 +- web/src/app/chat/input/ChatInputBar.tsx | 202 +++++++++++++++++- web/src/app/chat/interfaces.ts | 2 + web/src/app/chat/lib.tsx | 31 +++ web/src/app/chat/message/Messages.tsx | 26 +-- web/src/app/chat/page.tsx | 2 - .../shared/[chatId]/SharedChatDisplay.tsx | 7 + web/src/app/search/page.tsx | 1 + .../components/assistants/AssistantIcon.tsx | 23 +- web/src/components/health/healthcheck.tsx | 4 +- web/src/components/tooltip/Tooltip.tsx | 3 +- web/tailwind-themes/tailwind.config.js | 3 + 20 files changed, 469 insertions(+), 97 deletions(-) create mode 100644 backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py diff --git a/.vscode/env_template.txt b/.vscode/env_template.txt index 5a79b6535..bff9b79e5 100644 --- a/.vscode/env_template.txt +++ b/.vscode/env_template.txt @@ -49,4 +49,4 @@ PYTHONUNBUFFERED=1 # Enable the full set of Danswer Enterprise Edition features # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development) -ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False +ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False \ No newline at end of file diff --git a/backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py b/backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py new file mode 100644 index 000000000..0514de56c --- /dev/null +++ b/backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py @@ -0,0 +1,38 @@ +"""add alternate assistant to chat message + +Revision ID: 3a7802814195 +Revises: 23957775e5f5 +Create Date: 2024-06-05 11:18:49.966333 + +""" + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = "3a7802814195" +down_revision = "23957775e5f5" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column( + "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True) + ) + op.create_foreign_key( + "fk_chat_message_persona", + "chat_message", + "persona", + ["alternate_assistant_id"], + ["id"], + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey") + op.drop_column("chat_message", "alternate_assistant_id") diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index 5d0a33695..58a575745 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -34,6 +34,7 @@ from danswer.db.llm import fetch_existing_llm_providers from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import ToolCall from danswer.db.models import User +from danswer.db.persona import get_persona_by_id from danswer.document_index.factory import get_default_document_index from danswer.file_store.models import ChatFileType from danswer.file_store.models import FileDescriptor @@ -223,7 +224,15 @@ def stream_chat_message_objects( parent_id = new_msg_req.parent_message_id reference_doc_ids = new_msg_req.search_doc_ids retrieval_options = new_msg_req.retrieval_options - persona = chat_session.persona + alternate_assistant_id = new_msg_req.alternate_assistant_id + + # use alternate persona if alternative assistant id is passed in + if alternate_assistant_id is not None: + persona = get_persona_by_id( + alternate_assistant_id, user=user, db_session=db_session + ) + else: + persona = chat_session.persona prompt_id = new_msg_req.prompt_id if prompt_id is None and persona.prompts: @@ -380,6 +389,7 @@ def stream_chat_message_objects( # rephrased_query=, # token_count=, message_type=MessageType.ASSISTANT, + alternate_assistant_id=new_msg_req.alternate_assistant_id, # error=, # reference_docs=, db_session=db_session, @@ -389,11 +399,15 @@ def stream_chat_message_objects( if not final_msg.prompt: raise RuntimeError("No Prompt found") - prompt_config = PromptConfig.from_model( - final_msg.prompt, - prompt_override=( - new_msg_req.prompt_override or chat_session.prompt_override - ), + prompt_config = ( + PromptConfig.from_model( + final_msg.prompt, + prompt_override=( + new_msg_req.prompt_override or chat_session.prompt_override + ), + ) + if not persona + else PromptConfig.from_model(persona.prompts[0]) ) # find out what tools to use diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 61e42bde6..11a4cfec0 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -316,6 +316,7 @@ def create_new_chat_message( rephrased_query: str | None = None, error: str | None = None, reference_docs: list[DBSearchDoc] | None = None, + alternate_assistant_id: int | None = None, # Maps the citation number [n] to the DB SearchDoc citations: dict[int, int] | None = None, tool_calls: list[ToolCall] | None = None, @@ -334,6 +335,7 @@ def create_new_chat_message( files=files, tool_calls=tool_calls if tool_calls else [], error=error, + alternate_assistant_id=alternate_assistant_id, ) # SQL Alchemy will propagate this to update the reference_docs' foreign keys @@ -497,14 +499,14 @@ def translate_db_search_doc_to_server_search_doc( hidden=db_search_doc.hidden, metadata=db_search_doc.doc_metadata if not remove_doc_content else {}, score=db_search_doc.score, - match_highlights=db_search_doc.match_highlights - if not remove_doc_content - else [], + match_highlights=( + db_search_doc.match_highlights if not remove_doc_content else [] + ), updated_at=db_search_doc.updated_at if not remove_doc_content else None, primary_owners=db_search_doc.primary_owners if not remove_doc_content else [], - secondary_owners=db_search_doc.secondary_owners - if not remove_doc_content - else [], + secondary_owners=( + db_search_doc.secondary_owners if not remove_doc_content else [] + ), ) @@ -545,6 +547,7 @@ def translate_db_message_to_chat_message_detail( ) for tool_call in chat_message.tool_calls ], + alternate_assistant_id=chat_message.alternate_assistant_id, ) return chat_msg_detail diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index bfed0a03e..909236a97 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -708,6 +708,11 @@ class ChatMessage(Base): id: Mapped[int] = mapped_column(primary_key=True) chat_session_id: Mapped[int] = mapped_column(ForeignKey("chat_session.id")) + + alternate_assistant_id = mapped_column( + Integer, ForeignKey("persona.id"), nullable=True + ) + parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True) latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True) message: Mapped[str] = mapped_column(Text) @@ -736,10 +741,12 @@ class ChatMessage(Base): chat_session: Mapped[ChatSession] = relationship("ChatSession") prompt: Mapped[Optional["Prompt"]] = relationship("Prompt") + chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship( "ChatMessageFeedback", back_populates="chat_message", ) + document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship( "DocumentRetrievalFeedback", back_populates="chat_message", diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py index 2636b9019..ea1ce1ff6 100644 --- a/backend/danswer/server/query_and_chat/models.py +++ b/backend/danswer/server/query_and_chat/models.py @@ -107,6 +107,9 @@ class CreateChatMessageRequest(ChunkContext): llm_override: LLMOverride | None = None prompt_override: PromptOverride | None = None + # allow user to specify an alternate assistnat + alternate_assistant_id: int | None = None + # used for seeded chats to kick off the generation of an AI answer use_existing_user_message: bool = False @@ -181,6 +184,7 @@ class ChatMessageDetail(BaseModel): context_docs: RetrievalDocs | None message_type: MessageType time_sent: datetime + alternate_assistant_id: str | None # Dict mapping citation number to db_doc_id citations: dict[int, int] | None files: list[FileDescriptor] diff --git a/web/src/app/chat/ChatIntro.tsx b/web/src/app/chat/ChatIntro.tsx index d44be4da5..926656958 100644 --- a/web/src/app/chat/ChatIntro.tsx +++ b/web/src/app/chat/ChatIntro.tsx @@ -7,7 +7,6 @@ import { FiBookmark, FiCpu, FiInfo, FiX, FiZoomIn } from "react-icons/fi"; import { HoverPopup } from "@/components/HoverPopup"; import { Modal } from "@/components/Modal"; import { useState } from "react"; -import { FaCaretDown, FaCaretRight } from "react-icons/fa"; import { Logo } from "@/components/Logo"; const MAX_PERSONAS_TO_DISPLAY = 4; @@ -29,11 +28,9 @@ function HelperItemDisplay({ export function ChatIntro({ availableSources, - availablePersonas, selectedPersona, }: { availableSources: ValidSources[]; - availablePersonas: Persona[]; selectedPersona: Persona; }) { const availableSourceMetadata = getSourceMetadataForSources(availableSources); diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx index 8e9427dd0..765f7e26a 100644 --- a/web/src/app/chat/ChatPage.tsx +++ b/web/src/app/chat/ChatPage.tsx @@ -22,6 +22,7 @@ import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh"; import { buildChatUrl, buildLatestMessageChain, + checkAnyAssistantHasSearch, createChatSession, getCitedDocumentsFromMessage, getHumanAndAIMessageFromMessageNumber, @@ -62,7 +63,6 @@ import { SettingsContext } from "@/components/settings/SettingsProvider"; import Dropzone from "react-dropzone"; import { checkLLMSupportsImageInput, - destructureValue, getFinalLLM, structureValue, } from "@/lib/llm/utils"; @@ -78,9 +78,7 @@ import { TbLayoutSidebarRightExpand } from "react-icons/tb"; import { SIDEBAR_WIDTH_CONST } from "@/lib/constants"; import ResizableSection from "@/components/resizable/ResizableSection"; -import { Button } from "@tremor/react"; -const MAX_INPUT_HEIGHT = 200; const TEMP_USER_MESSAGE_ID = -1; const TEMP_ASSISTANT_MESSAGE_ID = -2; const SYSTEM_MESSAGE_ID = -3; @@ -108,6 +106,12 @@ export function ChatPage({ const filteredAssistants = orderAssistantsForUser(availablePersonas, user); + const [selectedAssistant, setSelectedAssistant] = useState( + null + ); + const [alternativeGeneratingAssistant, setAlternativeGeneratingAssistant] = + useState(null); + const router = useRouter(); const searchParams = useSearchParams(); const existingChatIdRaw = searchParams.get("chatId"); @@ -213,6 +217,7 @@ export function ChatPage({ const response = await fetch( `/api/chat/get-chat-session/${existingChatSessionId}` ); + const chatSession = (await response.json()) as BackendChatSession; setSelectedPersona( @@ -350,6 +355,7 @@ export function ChatPage({ setCompleteMessageMap(newCompleteMessageMap); return newCompleteMessageMap; }; + const messageHistory = buildLatestMessageChain(completeMessageMap); const [isStreaming, setIsStreaming] = useState(false); @@ -405,6 +411,7 @@ export function ChatPage({ // just choose a conservative default, this will be updated in the // background on initial load / on persona change const [maxTokens, setMaxTokens] = useState(4096); + // fetch # of allowed document tokens for the selected Persona useEffect(() => { async function fetchMaxTokens() { @@ -619,13 +626,17 @@ export function ChatPage({ queryOverride, forceSearch, isSeededChat, + alternativeAssistant = null, }: { messageIdToResend?: number; messageOverride?: string; queryOverride?: string; forceSearch?: boolean; isSeededChat?: boolean; + alternativeAssistant?: Persona | null; } = {}) => { + setAlternativeGeneratingAssistant(alternativeAssistant); + clientScrollToBottom(); let currChatSessionId: number; let isNewSession = chatSessionId === null; @@ -645,6 +656,7 @@ export function ChatPage({ const messageToResend = messageHistory.find( (message) => message.messageId === messageIdToResend ); + const messageToResendParent = messageToResend?.parentMessageId !== null && messageToResend?.parentMessageId !== undefined @@ -703,12 +715,19 @@ export function ChatPage({ const frozenCompleteMessageMap = upsertToCompleteMessageMap({ messages: messageUpdates, }); + // on initial message send, we insert a dummy system message // set this as the parent here if no parent is set if (!parentMessage && frozenCompleteMessageMap.size === 2) { parentMessage = frozenCompleteMessageMap.get(SYSTEM_MESSAGE_ID) || null; } + + const currentAssistantId = alternativeAssistant + ? alternativeAssistant.id + : selectedAssistant?.id; + resetInputBar(); + setIsStreaming(true); let answer = ""; let query: string | null = null; @@ -721,6 +740,7 @@ export function ChatPage({ let error: string | null = null; let finalMessage: BackendMessage | null = null; let toolCalls: ToolCallMetadata[] = []; + try { const lastSuccessfulMessageId = getLastSuccessfulMessageId(currMessageHistory); @@ -728,6 +748,7 @@ export function ChatPage({ const stack = new CurrentMessageFIFO(); updateCurrentMessageFIFO(stack, { message: currMessage, + alternateAssistantId: currentAssistantId, fileDescriptors: currentMessageFiles, parentMessageId: lastSuccessfulMessageId, chatSessionId: currChatSessionId, @@ -846,6 +867,7 @@ export function ChatPage({ files: finalMessage?.files || aiMessageImages || [], toolCalls: finalMessage?.tool_calls || toolCalls, parentMessageId: newUserMessageId, + alternateAssistantID: selectedAssistant?.id, }, ]); } @@ -905,6 +927,7 @@ export function ChatPage({ ) { setSelectedMessageForDocDisplay(finalMessage.message_id); } + setAlternativeGeneratingAssistant(null); }; const onFeedback = async ( @@ -1021,10 +1044,37 @@ export function ChatPage({ setShowDocSidebar((showDocSidebar) => !showDocSidebar); // Toggle the state which will in turn toggle the class }; - const retrievalDisabled = !personaIncludesRetrieval(livePersona); + useEffect(() => { + const includes = checkAnyAssistantHasSearch( + messageHistory, + availablePersonas, + livePersona + ); + setRetrievalEnabled(includes); + }, [messageHistory, availablePersonas, livePersona]); + + const [retrievalEnabled, setRetrievalEnabled] = useState(() => { + return checkAnyAssistantHasSearch( + messageHistory, + availablePersonas, + livePersona + ); + }); + const [editingRetrievalEnabled, setEditingRetrievalEnabled] = useState(false); const sidebarElementRef = useRef(null); const innerSidebarElementRef = useRef(null); + const currentPersona = selectedAssistant || livePersona; + + const updateSelectedAssistant = (newAssistant: Persona | null) => { + setSelectedAssistant(newAssistant); + if (newAssistant) { + setEditingRetrievalEnabled(personaIncludesRetrieval(newAssistant)); + } else { + setEditingRetrievalEnabled(false); + } + }; + return ( <> @@ -1094,7 +1144,7 @@ export function ChatPage({ <>
{/* */} +
- {!retrievalDisabled && !showDocSidebar && ( + {retrievalEnabled && !showDocSidebar && ( -
- )} +
+
{ + updateSelectedAssistant(alternativeAssistant); + }} + alternativeAssistant={selectedAssistant} + personas={filteredAssistants} message={message} setMessage={setMessage} onSubmit={onSubmit} isStreaming={isStreaming} setIsCancelled={setIsCancelled} - retrievalDisabled={retrievalDisabled} + retrievalDisabled={ + !personaIncludesRetrieval(currentPersona) + } filterManager={filterManager} llmOverrideManager={llmOverrideManager} selectedAssistant={livePersona} @@ -1468,7 +1539,7 @@ export function ChatPage({
- {!retrievalDisabled ? ( + {retrievalEnabled || editingRetrievalEnabled ? (
-

{starterMessage.name}

-

{starterMessage.description}

+

{starterMessage.name}

+

{starterMessage.description}

); } diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx index a2da13ff6..5664fb765 100644 --- a/web/src/app/chat/input/ChatInputBar.tsx +++ b/web/src/app/chat/input/ChatInputBar.tsx @@ -1,18 +1,36 @@ -import React, { EventHandler, useEffect, useRef } from "react"; -import { FiSend, FiFilter, FiPlusCircle, FiCpu } from "react-icons/fi"; +import React, { + Dispatch, + SetStateAction, + useEffect, + useRef, + useState, +} from "react"; +import { + FiSend, + FiFilter, + FiPlusCircle, + FiCpu, + FiX, + FiPlus, + FiInfo, +} from "react-icons/fi"; import ChatInputOption from "./ChatInputOption"; import { FaBrain } from "react-icons/fa"; import { Persona } from "@/app/admin/assistants/interfaces"; -import { FilterManager, LlmOverride, LlmOverrideManager } from "@/lib/hooks"; +import { FilterManager, LlmOverrideManager } from "@/lib/hooks"; import { SelectedFilterDisplay } from "./SelectedFilterDisplay"; import { useChatContext } from "@/components/context/ChatContext"; import { getFinalLLM } from "@/lib/llm/utils"; import { FileDescriptor } from "../interfaces"; import { InputBarPreview } from "../files/InputBarPreview"; - +import { RobotIcon } from "@/components/icons/icons"; +import { Hoverable } from "@/components/Hoverable"; +import { AssistantIcon } from "@/components/assistants/AssistantIcon"; +import { Tooltip } from "@/components/tooltip/Tooltip"; const MAX_INPUT_HEIGHT = 200; export function ChatInputBar({ + personas, message, setMessage, onSubmit, @@ -21,13 +39,17 @@ export function ChatInputBar({ retrievalDisabled, filterManager, llmOverrideManager, + onSetSelectedAssistant, selectedAssistant, files, setFiles, handleFileUpload, setConfigModalActiveTab, textAreaRef, + alternativeAssistant, }: { + onSetSelectedAssistant: (alternativeAssistant: Persona | null) => void; + personas: Persona[]; message: string; setMessage: (message: string) => void; onSubmit: () => void; @@ -37,6 +59,7 @@ export function ChatInputBar({ filterManager: FilterManager; llmOverrideManager: LlmOverrideManager; selectedAssistant: Persona; + alternativeAssistant: Persona | null; files: FileDescriptor[]; setFiles: (files: FileDescriptor[]) => void; handleFileUpload: (files: File[]) => void; @@ -75,6 +98,100 @@ export function ChatInputBar({ const { llmProviders } = useChatContext(); const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null); + const suggestionsRef = useRef(null); + const [showSuggestions, setShowSuggestions] = useState(false); + + const interactionsRef = useRef(null); + + const hideSuggestions = () => { + setShowSuggestions(false); + setAssistantIconIndex(0); + }; + + // Update selected persona + const updateCurrentPersona = (persona: Persona) => { + onSetSelectedAssistant(persona.id == selectedAssistant.id ? null : persona); + hideSuggestions(); + setMessage(""); + }; + + // Click out of assistant suggestions + useEffect(() => { + const handleClickOutside = (event: MouseEvent) => { + if ( + suggestionsRef.current && + !suggestionsRef.current.contains(event.target as Node) && + (!interactionsRef.current || + !interactionsRef.current.contains(event.target as Node)) + ) { + hideSuggestions(); + } + }; + document.addEventListener("mousedown", handleClickOutside); + return () => { + document.removeEventListener("mousedown", handleClickOutside); + }; + }, []); + + // Complete user input handling + const handleInputChange = (event: React.ChangeEvent) => { + const text = event.target.value; + setMessage(text); + + if (!text.startsWith("@")) { + hideSuggestions(); + return; + } + + // If looking for an assistant...fup + const match = text.match(/(?:\s|^)@(\w*)$/); + if (match) { + setShowSuggestions(true); + } else { + hideSuggestions(); + } + }; + + const filteredPersonas = personas.filter((persona) => + persona.name.toLowerCase().startsWith( + message + .slice(message.lastIndexOf("@") + 1) + .split(/\s/)[0] + .toLowerCase() + ) + ); + + const [assistantIconIndex, setAssistantIconIndex] = useState(0); + + const handleKeyDown = (e: React.KeyboardEvent) => { + if ( + showSuggestions && + filteredPersonas.length > 0 && + (e.key === "Tab" || e.key == "Enter") + ) { + e.preventDefault(); + if (assistantIconIndex == filteredPersonas.length) { + window.open("/assistants/new", "_blank"); + hideSuggestions(); + setMessage(""); + } else { + const option = + filteredPersonas[assistantIconIndex >= 0 ? assistantIconIndex : 0]; + updateCurrentPersona(option); + } + } else if (e.key === "ArrowDown") { + e.preventDefault(); + setAssistantIconIndex((assistantIconIndex) => + Math.min(assistantIconIndex + 1, filteredPersonas.length) + ); + } else if (e.key === "ArrowUp") { + e.preventDefault(); + setAssistantIconIndex((assistantIconIndex) => + Math.max(assistantIconIndex - 1, 0) + ); + } + }; + return (
@@ -90,9 +207,45 @@ export function ChatInputBar({ mx-auto " > + {showSuggestions && filteredPersonas.length > 0 && ( +
+
+ {filteredPersonas.map((currentPersona, index) => ( + + ))} + + +

Create a new assistant

+
+
+
+ )} +
+
+ {alternativeAssistant && ( +
+
+ +

+ {alternativeAssistant.name} +

+
+ + {alternativeAssistant.description} +

+ } + > + +
+ + onSetSelectedAssistant(null)} + /> +
+
+
+ )} + {files.length > 0 && (
{files.map((file) => ( @@ -128,8 +313,11 @@ export function ChatInputBar({ ))}
)} +