Add functional thread modification endpoints (#1668)

Makes it so that if you change which LLM you are using in a given ChatSession, that choice is persisted and remains in effect if you reload the page or return to the ChatSession later.
This commit is contained in:
pablodanswer 2024-06-21 18:10:30 -07:00 committed by GitHub
parent 5cafc96cae
commit 8178d536b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 185 additions and 41 deletions

View File

@ -0,0 +1,31 @@
"""Add thread specific model selection
Revision ID: 0568ccf46a6b
Revises: e209dc5a8156
Create Date: 2024-06-19 14:25:36.376046
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0568ccf46a6b"
down_revision = "e209dc5a8156"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the nullable ``current_alternate_model`` column to ``chat_session``.

    Stores a per-thread LLM override so the user's model choice persists
    across page reloads; nullable because existing sessions have no override.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "chat_session",
        sa.Column("current_alternate_model", sa.String(), nullable=True),
    )
    # ### end Alembic commands ###
def downgrade() -> None:
    """Drop the ``current_alternate_model`` column added in this revision."""
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("chat_session", "current_alternate_model")
    # ### end Alembic commands ###

View File

@ -661,6 +661,8 @@ class ChatSession(Base):
ForeignKey("chat_folder.id"), nullable=True ForeignKey("chat_folder.id"), nullable=True
) )
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
# the latest "overrides" specified by the user. These take precedence over # the latest "overrides" specified by the user. These take precedence over
# the attached persona. However, overrides specified directly in the # the attached persona. However, overrides specified directly in the
# `send-message` call will take precedence over these. # `send-message` call will take precedence over these.

View File

@ -63,6 +63,7 @@ from danswer.server.query_and_chat.models import LLMOverride
from danswer.server.query_and_chat.models import PromptOverride from danswer.server.query_and_chat.models import PromptOverride
from danswer.server.query_and_chat.models import RenameChatSessionResponse from danswer.server.query_and_chat.models import RenameChatSessionResponse
from danswer.server.query_and_chat.models import SearchFeedbackRequest from danswer.server.query_and_chat.models import SearchFeedbackRequest
from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
@ -77,10 +78,14 @@ def get_user_chat_sessions(
) -> ChatSessionsResponse: ) -> ChatSessionsResponse:
user_id = user.id if user is not None else None user_id = user.id if user is not None else None
try:
chat_sessions = get_chat_sessions_by_user( chat_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session user_id=user_id, deleted=False, db_session=db_session
) )
except ValueError:
raise ValueError("Chat session does not exist or has been deleted")
return ChatSessionsResponse( return ChatSessionsResponse(
sessions=[ sessions=[
ChatSessionDetails( ChatSessionDetails(
@ -90,12 +95,30 @@ def get_user_chat_sessions(
time_created=chat.time_created.isoformat(), time_created=chat.time_created.isoformat(),
shared_status=chat.shared_status, shared_status=chat.shared_status,
folder_id=chat.folder_id, folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
) )
for chat in chat_sessions for chat in chat_sessions
] ]
) )
@router.put("/update-chat-session-model")
def update_chat_session_model(
    update_thread_req: UpdateChatSessionThreadRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Persist a thread-level LLM override on an existing chat session.

    The stored value is later returned via the chat-session endpoints
    (``current_alternate_model``) so the frontend can restore the user's
    model choice on reload.
    """
    # Scope the lookup to the requesting user (None when auth is disabled).
    # NOTE(review): presumably get_chat_session_by_id raises for a missing or
    # foreign session — confirm that surfaces as a 4xx rather than a 500.
    chat_session = get_chat_session_by_id(
        chat_session_id=update_thread_req.chat_session_id,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )
    chat_session.current_alternate_model = update_thread_req.new_alternate_model
    db_session.add(chat_session)
    db_session.commit()
@router.get("/get-chat-session/{session_id}") @router.get("/get-chat-session/{session_id}")
def get_chat_session( def get_chat_session(
session_id: int, session_id: int,
@ -138,6 +161,7 @@ def get_chat_session(
description=chat_session.description, description=chat_session.description,
persona_id=chat_session.persona_id, persona_id=chat_session.persona_id,
persona_name=chat_session.persona.name, persona_name=chat_session.persona.name,
current_alternate_model=chat_session.current_alternate_model,
messages=[ messages=[
translate_db_message_to_chat_message_detail( translate_db_message_to_chat_message_detail(
msg, remove_doc_content=is_shared # if shared, don't leak doc content msg, remove_doc_content=is_shared # if shared, don't leak doc content

View File

@ -32,6 +32,12 @@ class SimpleQueryRequest(BaseModel):
query: str query: str
class UpdateChatSessionThreadRequest(BaseModel):
    """Request body for persisting a thread-level LLM override."""

    # ID of the chat session whose model override should be updated.
    chat_session_id: int
    # Encoded model selection; the frontend sends a
    # "name__provider__modelName" string (see structureValue in the web app).
    new_alternate_model: str
class ChatSessionCreationRequest(BaseModel): class ChatSessionCreationRequest(BaseModel):
# If not specified, use Danswer default persona # If not specified, use Danswer default persona
persona_id: int = 0 persona_id: int = 0
@ -142,6 +148,7 @@ class ChatSessionDetails(BaseModel):
time_created: str time_created: str
shared_status: ChatSessionSharedStatus shared_status: ChatSessionSharedStatus
folder_id: int | None folder_id: int | None
current_alternate_model: str | None = None
class ChatSessionsResponse(BaseModel): class ChatSessionsResponse(BaseModel):
@ -193,6 +200,7 @@ class ChatSessionDetailResponse(BaseModel):
messages: list[ChatMessageDetail] messages: list[ChatMessageDetail]
time_created: datetime time_created: datetime
shared_status: ChatSessionSharedStatus shared_status: ChatSessionSharedStatus
current_alternate_model: str | None
class QueryValidationResponse(BaseModel): class QueryValidationResponse(BaseModel):

View File

@ -34,6 +34,7 @@ import {
removeMessage, removeMessage,
sendMessage, sendMessage,
setMessageAsLatest, setMessageAsLatest,
updateModelOverrideForChatSession,
updateParentChildren, updateParentChildren,
uploadFilesForChat, uploadFilesForChat,
} from "./lib"; } from "./lib";
@ -59,7 +60,12 @@ import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces";
import { buildFilters } from "@/lib/search/utils"; import { buildFilters } from "@/lib/search/utils";
import { SettingsContext } from "@/components/settings/SettingsProvider"; import { SettingsContext } from "@/components/settings/SettingsProvider";
import Dropzone from "react-dropzone"; import Dropzone from "react-dropzone";
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils"; import {
checkLLMSupportsImageInput,
destructureValue,
getFinalLLM,
structureValue,
} from "@/lib/llm/utils";
import { ChatInputBar } from "./input/ChatInputBar"; import { ChatInputBar } from "./input/ChatInputBar";
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal"; import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
import { useChatContext } from "@/components/context/ChatContext"; import { useChatContext } from "@/components/context/ChatContext";
@ -92,6 +98,7 @@ export function ChatPage({
folders, folders,
openedFolders, openedFolders,
} = useChatContext(); } = useChatContext();
const filteredAssistants = orderAssistantsForUser(availablePersonas, user); const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
const router = useRouter(); const router = useRouter();
@ -104,6 +111,9 @@ export function ChatPage({
const selectedChatSession = chatSessions.find( const selectedChatSession = chatSessions.find(
(chatSession) => chatSession.id === existingChatSessionId (chatSession) => chatSession.id === existingChatSessionId
); );
const llmOverrideManager = useLlmOverride(selectedChatSession);
const existingChatSessionPersonaId = selectedChatSession?.persona_id; const existingChatSessionPersonaId = selectedChatSession?.persona_id;
// used to track whether or not the initial "submit on load" has been performed // used to track whether or not the initial "submit on load" has been performed
@ -124,25 +134,37 @@ export function ChatPage({
// this is triggered every time the user switches which chat // this is triggered every time the user switches which chat
// session they are using // session they are using
useEffect(() => { useEffect(() => {
if (
chatSessionId &&
!urlChatSessionId.current &&
llmOverrideManager.llmOverride
) {
updateModelOverrideForChatSession(
chatSessionId,
structureValue(
llmOverrideManager.llmOverride.name,
llmOverrideManager.llmOverride.provider,
llmOverrideManager.llmOverride.modelName
) as string
);
}
urlChatSessionId.current = existingChatSessionId; urlChatSessionId.current = existingChatSessionId;
textAreaRef.current?.focus(); textAreaRef.current?.focus();
// only clear things if we're going from one chat session to another // only clear things if we're going from one chat session to another
if (chatSessionId !== null && existingChatSessionId !== chatSessionId) { if (chatSessionId !== null && existingChatSessionId !== chatSessionId) {
// de-select documents // de-select documents
clearSelectedDocuments(); clearSelectedDocuments();
// reset all filters // reset all filters
filterManager.setSelectedDocumentSets([]); filterManager.setSelectedDocumentSets([]);
filterManager.setSelectedSources([]); filterManager.setSelectedSources([]);
filterManager.setSelectedTags([]); filterManager.setSelectedTags([]);
filterManager.setTimeRange(null); filterManager.setTimeRange(null);
// reset LLM overrides
llmOverrideManager.setLlmOverride({ // reset LLM overrides (based on chat session!)
name: "", llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
provider: "",
modelName: "",
});
llmOverrideManager.setTemperature(null); llmOverrideManager.setTemperature(null);
// remove uploaded files // remove uploaded files
setCurrentMessageFiles([]); setCurrentMessageFiles([]);
@ -177,7 +199,6 @@ export function ChatPage({
submitOnLoadPerformed.current = true; submitOnLoadPerformed.current = true;
await onSubmit(); await onSubmit();
} }
return; return;
} }
@ -186,6 +207,7 @@ export function ChatPage({
`/api/chat/get-chat-session/${existingChatSessionId}` `/api/chat/get-chat-session/${existingChatSessionId}`
); );
const chatSession = (await response.json()) as BackendChatSession; const chatSession = (await response.json()) as BackendChatSession;
setSelectedPersona( setSelectedPersona(
filteredAssistants.find( filteredAssistants.find(
(persona) => persona.id === chatSession.persona_id (persona) => persona.id === chatSession.persona_id
@ -386,8 +408,6 @@ export function ChatPage({
availableDocumentSets, availableDocumentSets,
}); });
const llmOverrideManager = useLlmOverride();
// state for cancelling streaming // state for cancelling streaming
const [isCancelled, setIsCancelled] = useState(false); const [isCancelled, setIsCancelled] = useState(false);
const isCancelledRef = useRef(isCancelled); const isCancelledRef = useRef(isCancelled);
@ -595,6 +615,7 @@ export function ChatPage({
.map((document) => document.db_doc_id as number), .map((document) => document.db_doc_id as number),
queryOverride, queryOverride,
forceSearch, forceSearch,
modelProvider: llmOverrideManager.llmOverride.name || undefined, modelProvider: llmOverrideManager.llmOverride.name || undefined,
modelVersion: modelVersion:
llmOverrideManager.llmOverride.modelName || llmOverrideManager.llmOverride.modelName ||
@ -893,6 +914,7 @@ export function ChatPage({
)} )}
<ConfigurationModal <ConfigurationModal
chatSessionId={chatSessionId!}
activeTab={configModalActiveTab} activeTab={configModalActiveTab}
setActiveTab={setConfigModalActiveTab} setActiveTab={setConfigModalActiveTab}
onClose={() => setConfigModalActiveTab(null)} onClose={() => setConfigModalActiveTab(null)}
@ -1044,7 +1066,9 @@ export function ChatPage({
citedDocuments={getCitedDocumentsFromMessage( citedDocuments={getCitedDocumentsFromMessage(
message message
)} )}
toolCall={message.toolCalls[0]} toolCall={
message.toolCalls && message.toolCalls[0]
}
isComplete={ isComplete={
i !== messageHistory.length - 1 || i !== messageHistory.length - 1 ||
!isStreaming !isStreaming
@ -1212,7 +1236,6 @@ export function ChatPage({
)} )}
</div> </div>
)} )}
<div ref={endDivRef} /> <div ref={endDivRef} />
</div> </div>
</div> </div>

View File

@ -163,6 +163,7 @@ export function ChatInputBar({
icon={FaBrain} icon={FaBrain}
onClick={() => setConfigModalActiveTab("assistants")} onClick={() => setConfigModalActiveTab("assistants")}
/> />
<ChatInputOption <ChatInputOption
name={ name={
llmOverrideManager.llmOverride.modelName || llmOverrideManager.llmOverride.modelName ||

View File

@ -53,6 +53,7 @@ export interface ChatSession {
time_created: string; time_created: string;
shared_status: ChatSessionSharedStatus; shared_status: ChatSessionSharedStatus;
folder_id: number | null; folder_id: number | null;
current_alternate_model: string;
} }
export interface Message { export interface Message {
@ -79,6 +80,7 @@ export interface BackendChatSession {
messages: BackendMessage[]; messages: BackendMessage[];
time_created: string; time_created: string;
shared_status: ChatSessionSharedStatus; shared_status: ChatSessionSharedStatus;
current_alternate_model?: string;
} }
export interface BackendMessage { export interface BackendMessage {

View File

@ -21,6 +21,23 @@ import { Persona } from "../admin/assistants/interfaces";
import { ReadonlyURLSearchParams } from "next/navigation"; import { ReadonlyURLSearchParams } from "next/navigation";
import { SEARCH_PARAM_NAMES } from "./searchParams"; import { SEARCH_PARAM_NAMES } from "./searchParams";
// Persist the per-thread model override so it survives page reloads.
// Returns the raw Response for the caller to inspect.
export async function updateModelOverrideForChatSession(
  chatSessionId: number,
  newAlternateModel: string
) {
  const payload = {
    chat_session_id: chatSessionId,
    new_alternate_model: newAlternateModel,
  };
  return await fetch("/api/chat/update-chat-session-model", {
    method: "PUT",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify(payload),
  });
}
export async function createChatSession( export async function createChatSession(
personaId: number, personaId: number,
description: string | null description: string | null

View File

@ -55,6 +55,7 @@ export function ConfigurationModal({
filterManager, filterManager,
llmProviders, llmProviders,
llmOverrideManager, llmOverrideManager,
chatSessionId,
}: { }: {
activeTab: string | null; activeTab: string | null;
setActiveTab: (tab: string | null) => void; setActiveTab: (tab: string | null) => void;
@ -65,6 +66,7 @@ export function ConfigurationModal({
filterManager: FilterManager; filterManager: FilterManager;
llmProviders: LLMProviderDescriptor[]; llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager; llmOverrideManager: LlmOverrideManager;
chatSessionId?: number;
}) { }) {
useEffect(() => { useEffect(() => {
const handleKeyDown = (event: KeyboardEvent) => { const handleKeyDown = (event: KeyboardEvent) => {
@ -149,6 +151,7 @@ export function ConfigurationModal({
{activeTab === "llms" && ( {activeTab === "llms" && (
<LlmTab <LlmTab
chatSessionId={chatSessionId}
llmOverrideManager={llmOverrideManager} llmOverrideManager={llmOverrideManager}
currentAssistant={selectedAssistant} currentAssistant={selectedAssistant}
/> />

View File

@ -5,14 +5,18 @@ import { debounce } from "lodash";
import { DefaultDropdown } from "@/components/Dropdown"; import { DefaultDropdown } from "@/components/Dropdown";
import { Text } from "@tremor/react"; import { Text } from "@tremor/react";
import { Persona } from "@/app/admin/assistants/interfaces"; import { Persona } from "@/app/admin/assistants/interfaces";
import { getFinalLLM } from "@/lib/llm/utils"; import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
import { updateModelOverrideForChatSession } from "../../lib";
import { PopupSpec } from "@/components/admin/connectors/Popup";
export function LlmTab({ export function LlmTab({
llmOverrideManager, llmOverrideManager,
currentAssistant, currentAssistant,
chatSessionId,
}: { }: {
llmOverrideManager: LlmOverrideManager; llmOverrideManager: LlmOverrideManager;
currentAssistant: Persona; currentAssistant: Persona;
chatSessionId?: number;
}) { }) {
const { llmProviders } = useChatContext(); const { llmProviders } = useChatContext();
const { llmOverride, setLlmOverride, temperature, setTemperature } = const { llmOverride, setLlmOverride, temperature, setTemperature } =
@ -37,21 +41,6 @@ export function LlmTab({
const [_, defaultLlmName] = getFinalLLM(llmProviders, currentAssistant, null); const [_, defaultLlmName] = getFinalLLM(llmProviders, currentAssistant, null);
const llmOptions: { name: string; value: string }[] = []; const llmOptions: { name: string; value: string }[] = [];
const structureValue = (
name: string,
provider: string,
modelName: string
) => {
return `${name}__${provider}__${modelName}`;
};
const destructureValue = (value: string): LlmOverride => {
const [displayName, provider, modelName] = value.split("__");
return {
name: displayName,
provider,
modelName,
};
};
llmProviders.forEach((llmProvider) => { llmProviders.forEach((llmProvider) => {
llmProvider.model_names.forEach((modelName) => { llmProvider.model_names.forEach((modelName) => {
llmOptions.push({ llmOptions.push({
@ -76,6 +65,7 @@ export function LlmTab({
<Text className="mb-3"> <Text className="mb-3">
Default Model: <i className="font-medium">{defaultLlmName}</i>. Default Model: <i className="font-medium">{defaultLlmName}</i>.
</Text> </Text>
<div className="w-96"> <div className="w-96">
<DefaultDropdown <DefaultDropdown
options={llmOptions} options={llmOptions}
@ -84,9 +74,12 @@ export function LlmTab({
llmOverride.provider, llmOverride.provider,
llmOverride.modelName llmOverride.modelName
)} )}
onSelect={(value) => onSelect={(value) => {
setLlmOverride(destructureValue(value as string)) setLlmOverride(destructureValue(value as string));
if (chatSessionId) {
updateModelOverrideForChatSession(chatSessionId, value as string);
} }
}}
/> />
</div> </div>

View File

@ -12,6 +12,8 @@ import { useState } from "react";
import { DateRangePickerValue } from "@tremor/react"; import { DateRangePickerValue } from "@tremor/react";
import { SourceMetadata } from "./search/interfaces"; import { SourceMetadata } from "./search/interfaces";
import { EE_ENABLED } from "./constants"; import { EE_ENABLED } from "./constants";
import { destructureValue } from "./llm/utils";
import { ChatSession } from "@/app/chat/interfaces";
const CREDENTIAL_URL = "/api/manage/admin/credential"; const CREDENTIAL_URL = "/api/manage/admin/credential";
@ -136,17 +138,38 @@ export interface LlmOverrideManager {
setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>; setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
temperature: number | null; temperature: number | null;
setTemperature: React.Dispatch<React.SetStateAction<number | null>>; setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
} }
export function useLlmOverride(): LlmOverrideManager { export function useLlmOverride(
const [llmOverride, setLlmOverride] = useState<LlmOverride>({ currentChatSession?: ChatSession
): LlmOverrideManager {
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
currentChatSession && currentChatSession.current_alternate_model
? destructureValue(currentChatSession.current_alternate_model)
: {
name: "", name: "",
provider: "", provider: "",
modelName: "", modelName: "",
}); }
);
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
setLlmOverride(
chatSession && chatSession.current_alternate_model
? destructureValue(chatSession.current_alternate_model)
: {
name: "",
provider: "",
modelName: "",
}
);
};
const [temperature, setTemperature] = useState<number | null>(null); const [temperature, setTemperature] = useState<number | null>(null);
return { return {
updateModelOverrideForChatSession,
llmOverride, llmOverride,
setLlmOverride, setLlmOverride,
temperature, temperature,

View File

@ -43,3 +43,20 @@ export function checkLLMSupportsImageInput(provider: string, model: string) {
([p, m]) => p === provider && m === model ([p, m]) => p === provider && m === model
); );
} }
// Encode an LLM selection as a single "name__provider__modelName" string,
// suitable for dropdown values and for persisting on the backend.
export const structureValue = (
  name: string,
  provider: string,
  modelName: string
) => {
  return [name, provider, modelName].join("__");
};
// Decode a "name__provider__modelName" string (see structureValue)
// back into an LlmOverride object.
export const destructureValue = (value: string): LlmOverride => {
  const parts = value.split("__");
  return {
    name: parts[0],
    provider: parts[1],
    modelName: parts[2],
  };
};