diff --git a/backend/alembic/versions/0568ccf46a6b_add_thread_specific_model_selection.py b/backend/alembic/versions/0568ccf46a6b_add_thread_specific_model_selection.py
new file mode 100644
index 000000000..8bb56d0f5
--- /dev/null
+++ b/backend/alembic/versions/0568ccf46a6b_add_thread_specific_model_selection.py
@@ -0,0 +1,31 @@
+"""Add thread specific model selection
+
+Revision ID: 0568ccf46a6b
+Revises: e209dc5a8156
+Create Date: 2024-06-19 14:25:36.376046
+
+"""
+
+from alembic import op
+import sqlalchemy as sa
+
+# revision identifiers, used by Alembic.
+revision = "0568ccf46a6b"
+down_revision = "e209dc5a8156"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.add_column(
+        "chat_session",
+        sa.Column("current_alternate_model", sa.String(), nullable=True),
+    )
+    # ### end Alembic commands ###
+
+
+def downgrade() -> None:
+    # ### commands auto generated by Alembic - please adjust! ###
+    op.drop_column("chat_session", "current_alternate_model")
+    # ### end Alembic commands ###
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 29ca74bfc..b17ccafa4 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -661,6 +661,8 @@ class ChatSession(Base):
         ForeignKey("chat_folder.id"), nullable=True
     )
 
+    current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
+
     # the latest "overrides" specified by the user. These take precedence over
     # the attached persona. However, overrides specified directly in the
     # `send-message` call will take precedence over these.
diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py
index af5501556..090c92012 100644
--- a/backend/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/danswer/server/query_and_chat/chat_backend.py
@@ -63,6 +63,7 @@ from danswer.server.query_and_chat.models import LLMOverride
 from danswer.server.query_and_chat.models import PromptOverride
 from danswer.server.query_and_chat.models import RenameChatSessionResponse
 from danswer.server.query_and_chat.models import SearchFeedbackRequest
+from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -77,9 +78,13 @@ def get_user_chat_sessions(
 ) -> ChatSessionsResponse:
     user_id = user.id if user is not None else None
 
-    chat_sessions = get_chat_sessions_by_user(
-        user_id=user_id, deleted=False, db_session=db_session
-    )
+    try:
+        chat_sessions = get_chat_sessions_by_user(
+            user_id=user_id, deleted=False, db_session=db_session
+        )
+
+    except ValueError:
+        raise ValueError("Chat session does not exist or has been deleted")
 
     return ChatSessionsResponse(
         sessions=[
@@ -90,12 +95,30 @@ def get_user_chat_sessions(
                 time_created=chat.time_created.isoformat(),
                 shared_status=chat.shared_status,
                 folder_id=chat.folder_id,
+                current_alternate_model=chat.current_alternate_model,
             )
             for chat in chat_sessions
         ]
     )
 
 
+@router.put("/update-chat-session-model")
+def update_chat_session_model(
+    update_thread_req: UpdateChatSessionThreadRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> None:
+    chat_session = get_chat_session_by_id(
+        chat_session_id=update_thread_req.chat_session_id,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
+    chat_session.current_alternate_model = update_thread_req.new_alternate_model
+
+    db_session.add(chat_session)
+    db_session.commit()
+
+
 @router.get("/get-chat-session/{session_id}")
 def get_chat_session(
     session_id: int,
@@ -138,6 +161,7 @@
         description=chat_session.description,
         persona_id=chat_session.persona_id,
         persona_name=chat_session.persona.name,
+        current_alternate_model=chat_session.current_alternate_model,
         messages=[
             translate_db_message_to_chat_message_detail(
                 msg, remove_doc_content=is_shared  # if shared, don't leak doc content
diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py
index 09561bf24..2636b9019 100644
--- a/backend/danswer/server/query_and_chat/models.py
+++ b/backend/danswer/server/query_and_chat/models.py
@@ -32,6 +32,12 @@ class SimpleQueryRequest(BaseModel):
     query: str
 
 
+class UpdateChatSessionThreadRequest(BaseModel):
+    # If not specified, use Danswer default persona
+    chat_session_id: int
+    new_alternate_model: str
+
+
 class ChatSessionCreationRequest(BaseModel):
     # If not specified, use Danswer default persona
     persona_id: int = 0
@@ -142,6 +148,7 @@ class ChatSessionDetails(BaseModel):
     time_created: str
     shared_status: ChatSessionSharedStatus
     folder_id: int | None
+    current_alternate_model: str | None = None
 
 
 class ChatSessionsResponse(BaseModel):
@@ -193,6 +200,7 @@ class ChatSessionDetailResponse(BaseModel):
     messages: list[ChatMessageDetail]
     time_created: datetime
     shared_status: ChatSessionSharedStatus
+    current_alternate_model: str | None
 
 
 class QueryValidationResponse(BaseModel):
diff --git a/web/src/app/chat/ChatPage.tsx b/web/src/app/chat/ChatPage.tsx
index 92e4384f9..78114c5df 100644
--- a/web/src/app/chat/ChatPage.tsx
+++ b/web/src/app/chat/ChatPage.tsx
@@ -34,6 +34,7 @@ import {
   removeMessage,
   sendMessage,
   setMessageAsLatest,
+  updateModelOverrideForChatSession,
   updateParentChildren,
   uploadFilesForChat,
 } from "./lib";
@@ -59,7 +60,12 @@ import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces";
 import { buildFilters } from "@/lib/search/utils";
 import { SettingsContext } from "@/components/settings/SettingsProvider";
 import Dropzone from "react-dropzone";
-import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
+import {
+  checkLLMSupportsImageInput,
+  destructureValue,
+  getFinalLLM,
+  structureValue,
+} from "@/lib/llm/utils";
 import { ChatInputBar } from "./input/ChatInputBar";
 import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
 import { useChatContext } from "@/components/context/ChatContext";
@@ -92,6 +98,7 @@ export function ChatPage({
     folders,
     openedFolders,
   } = useChatContext();
+
   const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
 
   const router = useRouter();
@@ -104,6 +111,9 @@ export function ChatPage({
   const selectedChatSession = chatSessions.find(
     (chatSession) => chatSession.id === existingChatSessionId
   );
+
+  const llmOverrideManager = useLlmOverride(selectedChatSession);
+
   const existingChatSessionPersonaId = selectedChatSession?.persona_id;
 
   // used to track whether or not the initial "submit on load" has been performed
@@ -124,25 +134,37 @@ export function ChatPage({
   // this is triggered every time the user switches which chat
   // session they are using
   useEffect(() => {
+    if (
+      chatSessionId &&
+      !urlChatSessionId.current &&
+      llmOverrideManager.llmOverride
+    ) {
+      updateModelOverrideForChatSession(
+        chatSessionId,
+        structureValue(
+          llmOverrideManager.llmOverride.name,
+          llmOverrideManager.llmOverride.provider,
+          llmOverrideManager.llmOverride.modelName
+        ) as string
+      );
+    }
     urlChatSessionId.current = existingChatSessionId;
-
     textAreaRef.current?.focus();

     // only clear things if we're going from one chat session to another
+
     if (chatSessionId !== null && existingChatSessionId !== chatSessionId) {
       // de-select documents
       clearSelectedDocuments();
       // reset all filters
+      filterManager.setSelectedDocumentSets([]);
       filterManager.setSelectedSources([]);
       filterManager.setSelectedTags([]);
       filterManager.setTimeRange(null);
-      // reset LLM overrides
-      llmOverrideManager.setLlmOverride({
-        name: "",
-        provider: "",
-        modelName: "",
-      });
+
+      // reset LLM overrides (based on chat session!)
+      llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
       llmOverrideManager.setTemperature(null);
       // remove uploaded files
       setCurrentMessageFiles([]);
@@ -177,7 +199,6 @@ export function ChatPage({
         submitOnLoadPerformed.current = true;
         await onSubmit();
       }
-
       return;
     }
@@ -186,6 +207,7 @@
         `/api/chat/get-chat-session/${existingChatSessionId}`
       );
       const chatSession = (await response.json()) as BackendChatSession;
+
       setSelectedPersona(
         filteredAssistants.find(
           (persona) => persona.id === chatSession.persona_id
@@ -386,8 +408,6 @@ export function ChatPage({
     availableDocumentSets,
   });

-  const llmOverrideManager = useLlmOverride();
-
   // state for cancelling streaming
   const [isCancelled, setIsCancelled] = useState(false);
   const isCancelledRef = useRef(isCancelled);
@@ -595,6 +615,7 @@ export function ChatPage({
          .map((document) => document.db_doc_id as number),
         queryOverride,
         forceSearch,
+        modelProvider: llmOverrideManager.llmOverride.name || undefined,
         modelVersion:
           llmOverrideManager.llmOverride.modelName ||
@@ -893,6 +914,7 @@
           )}
           <ConfigurationModal
+            chatSessionId={chatSessionId}
             onClose={() => setConfigModalActiveTab(null)}
@@ -1044,7 +1066,9 @@
                       citedDocuments={getCitedDocumentsFromMessage(
                         message
                       )}
-                      toolCall={message.toolCalls[0]}
+                      toolCall={
+                        message.toolCalls && message.toolCalls[0]
+                      }
                       isComplete={
                         i !== messageHistory.length - 1 || !isStreaming
@@ -1212,7 +1236,6 @@
             )}
           )}
-
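Note: `updateModelOverrideForChatSession` is imported from `./lib` above, but the hunk for `web/src/app/chat/lib.ts` is not part of this excerpt. A minimal sketch of what that helper presumably looks like, based on the PUT route and `UpdateChatSessionThreadRequest` fields added in `chat_backend.py` and `models.py` (the exact implementation and error handling in `lib.ts` are assumptions, not shown in this patch):

// Sketch only -- not part of this patch. Persists the per-thread model choice
// by calling the new backend route added in chat_backend.py.
export async function updateModelOverrideForChatSession(
  chatSessionId: number,
  newAlternateModel: string
): Promise<Response> {
  return fetch("/api/chat/update-chat-session-model", {
    method: "PUT",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify({
      // field names mirror UpdateChatSessionThreadRequest
      chat_session_id: chatSessionId,
      new_alternate_model: newAlternateModel,
    }),
  });
}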
diff --git a/web/src/app/chat/input/ChatInputBar.tsx b/web/src/app/chat/input/ChatInputBar.tsx
index 231fb03de..278acfabb 100644
--- a/web/src/app/chat/input/ChatInputBar.tsx
+++ b/web/src/app/chat/input/ChatInputBar.tsx
@@ -163,6 +163,7 @@ export function ChatInputBar({
                 icon={FaBrain}
                 onClick={() => setConfigModalActiveTab("assistants")}
               />
+
diff --git a/web/src/app/chat/modal/configuration/ConfigurationModal.tsx b/web/src/app/chat/modal/configuration/ConfigurationModal.tsx
--- a/web/src/app/chat/modal/configuration/ConfigurationModal.tsx
+++ b/web/src/app/chat/modal/configuration/ConfigurationModal.tsx
@@ -65,6 +66,7 @@ export function ConfigurationModal({
   filterManager: FilterManager;
   llmProviders: LLMProviderDescriptor[];
   llmOverrideManager: LlmOverrideManager;
+  chatSessionId?: number;
 }) {
   useEffect(() => {
     const handleKeyDown = (event: KeyboardEvent) => {
@@ -149,6 +151,7 @@
           {activeTab === "llms" && (
             <LlmTab
+              chatSessionId={chatSessionId}
diff --git a/web/src/app/chat/modal/configuration/LlmTab.tsx b/web/src/app/chat/modal/configuration/LlmTab.tsx
index a4fd5259b..79dd5d82f 100644
--- a/web/src/app/chat/modal/configuration/LlmTab.tsx
+++ b/web/src/app/chat/modal/configuration/LlmTab.tsx
@@ -5,14 +5,18 @@ import { debounce } from "lodash";
 import { DefaultDropdown } from "@/components/Dropdown";
 import { Text } from "@tremor/react";
 import { Persona } from "@/app/admin/assistants/interfaces";
-import { getFinalLLM } from "@/lib/llm/utils";
+import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
+import { updateModelOverrideForChatSession } from "../../lib";
+import { PopupSpec } from "@/components/admin/connectors/Popup";

 export function LlmTab({
   llmOverrideManager,
   currentAssistant,
+  chatSessionId,
 }: {
   llmOverrideManager: LlmOverrideManager;
   currentAssistant: Persona;
+  chatSessionId?: number;
 }) {
   const { llmProviders } = useChatContext();
   const { llmOverride, setLlmOverride, temperature, setTemperature } =
@@ -37,21 +41,6 @@ export function LlmTab({
   const [_, defaultLlmName] = getFinalLLM(llmProviders, currentAssistant, null);
   const llmOptions: { name: string; value: string }[] = [];
-  const structureValue = (
-    name: string,
-    provider: string,
-    modelName: string
-  ) => {
-    return `${name}__${provider}__${modelName}`;
-  };
-  const destructureValue = (value: string): LlmOverride => {
-    const [displayName, provider, modelName] = value.split("__");
-    return {
-      name: displayName,
-      provider,
-      modelName,
-    };
-  };
   llmProviders.forEach((llmProvider) => {
     llmProvider.model_names.forEach((modelName) => {
       llmOptions.push({
@@ -76,6 +65,7 @@
         Default Model: {defaultLlmName}.
+
-        onSelect={(value) =>
-          setLlmOverride(destructureValue(value as string))
-        }
+        onSelect={(value) => {
+          setLlmOverride(destructureValue(value as string));
+          if (chatSessionId) {
+            updateModelOverrideForChatSession(chatSessionId, value as string);
+          }
+        }}
       />
diff --git a/web/src/lib/hooks.ts b/web/src/lib/hooks.ts
index d2cb538cc..18a859f88 100644
--- a/web/src/lib/hooks.ts
+++ b/web/src/lib/hooks.ts
@@ -12,6 +12,8 @@ import { useState } from "react";
 import { DateRangePickerValue } from "@tremor/react";
 import { SourceMetadata } from "./search/interfaces";
 import { EE_ENABLED } from "./constants";
+import { destructureValue } from "./llm/utils";
+import { ChatSession } from "@/app/chat/interfaces";

 const CREDENTIAL_URL = "/api/manage/admin/credential";

@@ -136,17 +138,38 @@ export interface LlmOverrideManager {
   setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
   temperature: number | null;
   setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
+  updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
 }

-export function useLlmOverride(): LlmOverrideManager {
-  const [llmOverride, setLlmOverride] = useState<LlmOverride>({
-    name: "",
-    provider: "",
-    modelName: "",
-  });
+export function useLlmOverride(
+  currentChatSession?: ChatSession
+): LlmOverrideManager {
+  const [llmOverride, setLlmOverride] = useState<LlmOverride>(
+    currentChatSession && currentChatSession.current_alternate_model
+      ? destructureValue(currentChatSession.current_alternate_model)
+      : {
+          name: "",
+          provider: "",
+          modelName: "",
+        }
+  );
+
+  const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
+    setLlmOverride(
+      chatSession && chatSession.current_alternate_model
+        ? destructureValue(chatSession.current_alternate_model)
+        : {
+            name: "",
+            provider: "",
+            modelName: "",
+          }
+    );
+  };
+
   const [temperature, setTemperature] = useState<number | null>(null);

   return {
+    updateModelOverrideForChatSession,
     llmOverride,
     setLlmOverride,
     temperature,
diff --git a/web/src/lib/llm/utils.ts b/web/src/lib/llm/utils.ts
index c9c7a5b8c..7427ed36d 100644
--- a/web/src/lib/llm/utils.ts
+++ b/web/src/lib/llm/utils.ts
@@ -43,3 +43,20 @@ export function checkLLMSupportsImageInput(provider: string, model: string) {
     ([p, m]) => p === provider && m === model
   );
 }
+
+export const structureValue = (
+  name: string,
+  provider: string,
+  modelName: string
+) => {
+  return `${name}__${provider}__${modelName}`;
+};
+
+export const destructureValue = (value: string): LlmOverride => {
+  const [displayName, provider, modelName] = value.split("__");
+  return {
+    name: displayName,
+    provider,
+    modelName,
+  };
+};
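The `structureValue`/`destructureValue` helpers added to `web/src/lib/llm/utils.ts` round-trip a provider display name, provider, and model through a single `"__"`-delimited string, which is what ends up stored in `chat_session.current_alternate_model`. A small usage sketch follows; the provider and model values are placeholders rather than values from this patch, and the helpers assume none of the three parts contains `"__"`, since such a value would not split back cleanly:

import { destructureValue, structureValue } from "@/lib/llm/utils";

// Placeholder values for illustration only.
const stored = structureValue("OpenAI", "openai", "gpt-4o");
// stored === "OpenAI__openai__gpt-4o" -- the string persisted on the chat session

const override = destructureValue(stored);
// override === { name: "OpenAI", provider: "openai", modelName: "gpt-4o" }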