mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-19 04:20:57 +02:00
Add functional thread modification endpoints (#1668)
Makes it so if you change which LLM you are using in a given ChatSession, that is persisted and sticks around if you reload the page / come back to the ChatSession later
This commit is contained in:
parent
5cafc96cae
commit
8178d536b4
@ -0,0 +1,31 @@
|
|||||||
|
"""Add thread specific model selection
|
||||||
|
|
||||||
|
Revision ID: 0568ccf46a6b
|
||||||
|
Revises: e209dc5a8156
|
||||||
|
Create Date: 2024-06-19 14:25:36.376046
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "0568ccf46a6b"
|
||||||
|
down_revision = "e209dc5a8156"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.add_column(
|
||||||
|
"chat_session",
|
||||||
|
sa.Column("current_alternate_model", sa.String(), nullable=True),
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_column("chat_session", "current_alternate_model")
|
||||||
|
# ### end Alembic commands ###
|
@ -661,6 +661,8 @@ class ChatSession(Base):
|
|||||||
ForeignKey("chat_folder.id"), nullable=True
|
ForeignKey("chat_folder.id"), nullable=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)
|
||||||
|
|
||||||
# the latest "overrides" specified by the user. These take precedence over
|
# the latest "overrides" specified by the user. These take precedence over
|
||||||
# the attached persona. However, overrides specified directly in the
|
# the attached persona. However, overrides specified directly in the
|
||||||
# `send-message` call will take precedence over these.
|
# `send-message` call will take precedence over these.
|
||||||
|
@ -63,6 +63,7 @@ from danswer.server.query_and_chat.models import LLMOverride
|
|||||||
from danswer.server.query_and_chat.models import PromptOverride
|
from danswer.server.query_and_chat.models import PromptOverride
|
||||||
from danswer.server.query_and_chat.models import RenameChatSessionResponse
|
from danswer.server.query_and_chat.models import RenameChatSessionResponse
|
||||||
from danswer.server.query_and_chat.models import SearchFeedbackRequest
|
from danswer.server.query_and_chat.models import SearchFeedbackRequest
|
||||||
|
from danswer.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
@ -77,9 +78,13 @@ def get_user_chat_sessions(
|
|||||||
) -> ChatSessionsResponse:
|
) -> ChatSessionsResponse:
|
||||||
user_id = user.id if user is not None else None
|
user_id = user.id if user is not None else None
|
||||||
|
|
||||||
chat_sessions = get_chat_sessions_by_user(
|
try:
|
||||||
user_id=user_id, deleted=False, db_session=db_session
|
chat_sessions = get_chat_sessions_by_user(
|
||||||
)
|
user_id=user_id, deleted=False, db_session=db_session
|
||||||
|
)
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("Chat session does not exist or has been deleted")
|
||||||
|
|
||||||
return ChatSessionsResponse(
|
return ChatSessionsResponse(
|
||||||
sessions=[
|
sessions=[
|
||||||
@ -90,12 +95,30 @@ def get_user_chat_sessions(
|
|||||||
time_created=chat.time_created.isoformat(),
|
time_created=chat.time_created.isoformat(),
|
||||||
shared_status=chat.shared_status,
|
shared_status=chat.shared_status,
|
||||||
folder_id=chat.folder_id,
|
folder_id=chat.folder_id,
|
||||||
|
current_alternate_model=chat.current_alternate_model,
|
||||||
)
|
)
|
||||||
for chat in chat_sessions
|
for chat in chat_sessions
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/update-chat-session-model")
|
||||||
|
def update_chat_session_model(
|
||||||
|
update_thread_req: UpdateChatSessionThreadRequest,
|
||||||
|
user: User | None = Depends(current_user),
|
||||||
|
db_session: Session = Depends(get_session),
|
||||||
|
) -> None:
|
||||||
|
chat_session = get_chat_session_by_id(
|
||||||
|
chat_session_id=update_thread_req.chat_session_id,
|
||||||
|
user_id=user.id if user is not None else None,
|
||||||
|
db_session=db_session,
|
||||||
|
)
|
||||||
|
chat_session.current_alternate_model = update_thread_req.new_alternate_model
|
||||||
|
|
||||||
|
db_session.add(chat_session)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
|
||||||
@router.get("/get-chat-session/{session_id}")
|
@router.get("/get-chat-session/{session_id}")
|
||||||
def get_chat_session(
|
def get_chat_session(
|
||||||
session_id: int,
|
session_id: int,
|
||||||
@ -138,6 +161,7 @@ def get_chat_session(
|
|||||||
description=chat_session.description,
|
description=chat_session.description,
|
||||||
persona_id=chat_session.persona_id,
|
persona_id=chat_session.persona_id,
|
||||||
persona_name=chat_session.persona.name,
|
persona_name=chat_session.persona.name,
|
||||||
|
current_alternate_model=chat_session.current_alternate_model,
|
||||||
messages=[
|
messages=[
|
||||||
translate_db_message_to_chat_message_detail(
|
translate_db_message_to_chat_message_detail(
|
||||||
msg, remove_doc_content=is_shared # if shared, don't leak doc content
|
msg, remove_doc_content=is_shared # if shared, don't leak doc content
|
||||||
|
@ -32,6 +32,12 @@ class SimpleQueryRequest(BaseModel):
|
|||||||
query: str
|
query: str
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateChatSessionThreadRequest(BaseModel):
|
||||||
|
# If not specified, use Danswer default persona
|
||||||
|
chat_session_id: int
|
||||||
|
new_alternate_model: str
|
||||||
|
|
||||||
|
|
||||||
class ChatSessionCreationRequest(BaseModel):
|
class ChatSessionCreationRequest(BaseModel):
|
||||||
# If not specified, use Danswer default persona
|
# If not specified, use Danswer default persona
|
||||||
persona_id: int = 0
|
persona_id: int = 0
|
||||||
@ -142,6 +148,7 @@ class ChatSessionDetails(BaseModel):
|
|||||||
time_created: str
|
time_created: str
|
||||||
shared_status: ChatSessionSharedStatus
|
shared_status: ChatSessionSharedStatus
|
||||||
folder_id: int | None
|
folder_id: int | None
|
||||||
|
current_alternate_model: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ChatSessionsResponse(BaseModel):
|
class ChatSessionsResponse(BaseModel):
|
||||||
@ -193,6 +200,7 @@ class ChatSessionDetailResponse(BaseModel):
|
|||||||
messages: list[ChatMessageDetail]
|
messages: list[ChatMessageDetail]
|
||||||
time_created: datetime
|
time_created: datetime
|
||||||
shared_status: ChatSessionSharedStatus
|
shared_status: ChatSessionSharedStatus
|
||||||
|
current_alternate_model: str | None
|
||||||
|
|
||||||
|
|
||||||
class QueryValidationResponse(BaseModel):
|
class QueryValidationResponse(BaseModel):
|
||||||
|
@ -34,6 +34,7 @@ import {
|
|||||||
removeMessage,
|
removeMessage,
|
||||||
sendMessage,
|
sendMessage,
|
||||||
setMessageAsLatest,
|
setMessageAsLatest,
|
||||||
|
updateModelOverrideForChatSession,
|
||||||
updateParentChildren,
|
updateParentChildren,
|
||||||
uploadFilesForChat,
|
uploadFilesForChat,
|
||||||
} from "./lib";
|
} from "./lib";
|
||||||
@ -59,7 +60,12 @@ import { AnswerPiecePacket, DanswerDocument } from "@/lib/search/interfaces";
|
|||||||
import { buildFilters } from "@/lib/search/utils";
|
import { buildFilters } from "@/lib/search/utils";
|
||||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||||
import Dropzone from "react-dropzone";
|
import Dropzone from "react-dropzone";
|
||||||
import { checkLLMSupportsImageInput, getFinalLLM } from "@/lib/llm/utils";
|
import {
|
||||||
|
checkLLMSupportsImageInput,
|
||||||
|
destructureValue,
|
||||||
|
getFinalLLM,
|
||||||
|
structureValue,
|
||||||
|
} from "@/lib/llm/utils";
|
||||||
import { ChatInputBar } from "./input/ChatInputBar";
|
import { ChatInputBar } from "./input/ChatInputBar";
|
||||||
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
|
import { ConfigurationModal } from "./modal/configuration/ConfigurationModal";
|
||||||
import { useChatContext } from "@/components/context/ChatContext";
|
import { useChatContext } from "@/components/context/ChatContext";
|
||||||
@ -92,6 +98,7 @@ export function ChatPage({
|
|||||||
folders,
|
folders,
|
||||||
openedFolders,
|
openedFolders,
|
||||||
} = useChatContext();
|
} = useChatContext();
|
||||||
|
|
||||||
const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
|
const filteredAssistants = orderAssistantsForUser(availablePersonas, user);
|
||||||
|
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@ -104,6 +111,9 @@ export function ChatPage({
|
|||||||
const selectedChatSession = chatSessions.find(
|
const selectedChatSession = chatSessions.find(
|
||||||
(chatSession) => chatSession.id === existingChatSessionId
|
(chatSession) => chatSession.id === existingChatSessionId
|
||||||
);
|
);
|
||||||
|
|
||||||
|
const llmOverrideManager = useLlmOverride(selectedChatSession);
|
||||||
|
|
||||||
const existingChatSessionPersonaId = selectedChatSession?.persona_id;
|
const existingChatSessionPersonaId = selectedChatSession?.persona_id;
|
||||||
|
|
||||||
// used to track whether or not the initial "submit on load" has been performed
|
// used to track whether or not the initial "submit on load" has been performed
|
||||||
@ -124,25 +134,37 @@ export function ChatPage({
|
|||||||
// this is triggered every time the user switches which chat
|
// this is triggered every time the user switches which chat
|
||||||
// session they are using
|
// session they are using
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
|
if (
|
||||||
|
chatSessionId &&
|
||||||
|
!urlChatSessionId.current &&
|
||||||
|
llmOverrideManager.llmOverride
|
||||||
|
) {
|
||||||
|
updateModelOverrideForChatSession(
|
||||||
|
chatSessionId,
|
||||||
|
structureValue(
|
||||||
|
llmOverrideManager.llmOverride.name,
|
||||||
|
llmOverrideManager.llmOverride.provider,
|
||||||
|
llmOverrideManager.llmOverride.modelName
|
||||||
|
) as string
|
||||||
|
);
|
||||||
|
}
|
||||||
urlChatSessionId.current = existingChatSessionId;
|
urlChatSessionId.current = existingChatSessionId;
|
||||||
|
|
||||||
textAreaRef.current?.focus();
|
textAreaRef.current?.focus();
|
||||||
|
|
||||||
// only clear things if we're going from one chat session to another
|
// only clear things if we're going from one chat session to another
|
||||||
|
|
||||||
if (chatSessionId !== null && existingChatSessionId !== chatSessionId) {
|
if (chatSessionId !== null && existingChatSessionId !== chatSessionId) {
|
||||||
// de-select documents
|
// de-select documents
|
||||||
clearSelectedDocuments();
|
clearSelectedDocuments();
|
||||||
// reset all filters
|
// reset all filters
|
||||||
|
|
||||||
filterManager.setSelectedDocumentSets([]);
|
filterManager.setSelectedDocumentSets([]);
|
||||||
filterManager.setSelectedSources([]);
|
filterManager.setSelectedSources([]);
|
||||||
filterManager.setSelectedTags([]);
|
filterManager.setSelectedTags([]);
|
||||||
filterManager.setTimeRange(null);
|
filterManager.setTimeRange(null);
|
||||||
// reset LLM overrides
|
|
||||||
llmOverrideManager.setLlmOverride({
|
// reset LLM overrides (based on chat session!)
|
||||||
name: "",
|
llmOverrideManager.updateModelOverrideForChatSession(selectedChatSession);
|
||||||
provider: "",
|
|
||||||
modelName: "",
|
|
||||||
});
|
|
||||||
llmOverrideManager.setTemperature(null);
|
llmOverrideManager.setTemperature(null);
|
||||||
// remove uploaded files
|
// remove uploaded files
|
||||||
setCurrentMessageFiles([]);
|
setCurrentMessageFiles([]);
|
||||||
@ -177,7 +199,6 @@ export function ChatPage({
|
|||||||
submitOnLoadPerformed.current = true;
|
submitOnLoadPerformed.current = true;
|
||||||
await onSubmit();
|
await onSubmit();
|
||||||
}
|
}
|
||||||
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -186,6 +207,7 @@ export function ChatPage({
|
|||||||
`/api/chat/get-chat-session/${existingChatSessionId}`
|
`/api/chat/get-chat-session/${existingChatSessionId}`
|
||||||
);
|
);
|
||||||
const chatSession = (await response.json()) as BackendChatSession;
|
const chatSession = (await response.json()) as BackendChatSession;
|
||||||
|
|
||||||
setSelectedPersona(
|
setSelectedPersona(
|
||||||
filteredAssistants.find(
|
filteredAssistants.find(
|
||||||
(persona) => persona.id === chatSession.persona_id
|
(persona) => persona.id === chatSession.persona_id
|
||||||
@ -386,8 +408,6 @@ export function ChatPage({
|
|||||||
availableDocumentSets,
|
availableDocumentSets,
|
||||||
});
|
});
|
||||||
|
|
||||||
const llmOverrideManager = useLlmOverride();
|
|
||||||
|
|
||||||
// state for cancelling streaming
|
// state for cancelling streaming
|
||||||
const [isCancelled, setIsCancelled] = useState(false);
|
const [isCancelled, setIsCancelled] = useState(false);
|
||||||
const isCancelledRef = useRef(isCancelled);
|
const isCancelledRef = useRef(isCancelled);
|
||||||
@ -595,6 +615,7 @@ export function ChatPage({
|
|||||||
.map((document) => document.db_doc_id as number),
|
.map((document) => document.db_doc_id as number),
|
||||||
queryOverride,
|
queryOverride,
|
||||||
forceSearch,
|
forceSearch,
|
||||||
|
|
||||||
modelProvider: llmOverrideManager.llmOverride.name || undefined,
|
modelProvider: llmOverrideManager.llmOverride.name || undefined,
|
||||||
modelVersion:
|
modelVersion:
|
||||||
llmOverrideManager.llmOverride.modelName ||
|
llmOverrideManager.llmOverride.modelName ||
|
||||||
@ -893,6 +914,7 @@ export function ChatPage({
|
|||||||
)}
|
)}
|
||||||
|
|
||||||
<ConfigurationModal
|
<ConfigurationModal
|
||||||
|
chatSessionId={chatSessionId!}
|
||||||
activeTab={configModalActiveTab}
|
activeTab={configModalActiveTab}
|
||||||
setActiveTab={setConfigModalActiveTab}
|
setActiveTab={setConfigModalActiveTab}
|
||||||
onClose={() => setConfigModalActiveTab(null)}
|
onClose={() => setConfigModalActiveTab(null)}
|
||||||
@ -1044,7 +1066,9 @@ export function ChatPage({
|
|||||||
citedDocuments={getCitedDocumentsFromMessage(
|
citedDocuments={getCitedDocumentsFromMessage(
|
||||||
message
|
message
|
||||||
)}
|
)}
|
||||||
toolCall={message.toolCalls[0]}
|
toolCall={
|
||||||
|
message.toolCalls && message.toolCalls[0]
|
||||||
|
}
|
||||||
isComplete={
|
isComplete={
|
||||||
i !== messageHistory.length - 1 ||
|
i !== messageHistory.length - 1 ||
|
||||||
!isStreaming
|
!isStreaming
|
||||||
@ -1212,7 +1236,6 @@ export function ChatPage({
|
|||||||
)}
|
)}
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
|
|
||||||
<div ref={endDivRef} />
|
<div ref={endDivRef} />
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -163,6 +163,7 @@ export function ChatInputBar({
|
|||||||
icon={FaBrain}
|
icon={FaBrain}
|
||||||
onClick={() => setConfigModalActiveTab("assistants")}
|
onClick={() => setConfigModalActiveTab("assistants")}
|
||||||
/>
|
/>
|
||||||
|
|
||||||
<ChatInputOption
|
<ChatInputOption
|
||||||
name={
|
name={
|
||||||
llmOverrideManager.llmOverride.modelName ||
|
llmOverrideManager.llmOverride.modelName ||
|
||||||
|
@ -53,6 +53,7 @@ export interface ChatSession {
|
|||||||
time_created: string;
|
time_created: string;
|
||||||
shared_status: ChatSessionSharedStatus;
|
shared_status: ChatSessionSharedStatus;
|
||||||
folder_id: number | null;
|
folder_id: number | null;
|
||||||
|
current_alternate_model: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface Message {
|
export interface Message {
|
||||||
@ -79,6 +80,7 @@ export interface BackendChatSession {
|
|||||||
messages: BackendMessage[];
|
messages: BackendMessage[];
|
||||||
time_created: string;
|
time_created: string;
|
||||||
shared_status: ChatSessionSharedStatus;
|
shared_status: ChatSessionSharedStatus;
|
||||||
|
current_alternate_model?: string;
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface BackendMessage {
|
export interface BackendMessage {
|
||||||
|
@ -21,6 +21,23 @@ import { Persona } from "../admin/assistants/interfaces";
|
|||||||
import { ReadonlyURLSearchParams } from "next/navigation";
|
import { ReadonlyURLSearchParams } from "next/navigation";
|
||||||
import { SEARCH_PARAM_NAMES } from "./searchParams";
|
import { SEARCH_PARAM_NAMES } from "./searchParams";
|
||||||
|
|
||||||
|
export async function updateModelOverrideForChatSession(
|
||||||
|
chatSessionId: number,
|
||||||
|
newAlternateModel: string
|
||||||
|
) {
|
||||||
|
const response = await fetch("/api/chat/update-chat-session-model", {
|
||||||
|
method: "PUT",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
chat_session_id: chatSessionId,
|
||||||
|
new_alternate_model: newAlternateModel,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
export async function createChatSession(
|
export async function createChatSession(
|
||||||
personaId: number,
|
personaId: number,
|
||||||
description: string | null
|
description: string | null
|
||||||
|
@ -55,6 +55,7 @@ export function ConfigurationModal({
|
|||||||
filterManager,
|
filterManager,
|
||||||
llmProviders,
|
llmProviders,
|
||||||
llmOverrideManager,
|
llmOverrideManager,
|
||||||
|
chatSessionId,
|
||||||
}: {
|
}: {
|
||||||
activeTab: string | null;
|
activeTab: string | null;
|
||||||
setActiveTab: (tab: string | null) => void;
|
setActiveTab: (tab: string | null) => void;
|
||||||
@ -65,6 +66,7 @@ export function ConfigurationModal({
|
|||||||
filterManager: FilterManager;
|
filterManager: FilterManager;
|
||||||
llmProviders: LLMProviderDescriptor[];
|
llmProviders: LLMProviderDescriptor[];
|
||||||
llmOverrideManager: LlmOverrideManager;
|
llmOverrideManager: LlmOverrideManager;
|
||||||
|
chatSessionId?: number;
|
||||||
}) {
|
}) {
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const handleKeyDown = (event: KeyboardEvent) => {
|
const handleKeyDown = (event: KeyboardEvent) => {
|
||||||
@ -149,6 +151,7 @@ export function ConfigurationModal({
|
|||||||
|
|
||||||
{activeTab === "llms" && (
|
{activeTab === "llms" && (
|
||||||
<LlmTab
|
<LlmTab
|
||||||
|
chatSessionId={chatSessionId}
|
||||||
llmOverrideManager={llmOverrideManager}
|
llmOverrideManager={llmOverrideManager}
|
||||||
currentAssistant={selectedAssistant}
|
currentAssistant={selectedAssistant}
|
||||||
/>
|
/>
|
||||||
|
@ -5,14 +5,18 @@ import { debounce } from "lodash";
|
|||||||
import { DefaultDropdown } from "@/components/Dropdown";
|
import { DefaultDropdown } from "@/components/Dropdown";
|
||||||
import { Text } from "@tremor/react";
|
import { Text } from "@tremor/react";
|
||||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||||
import { getFinalLLM } from "@/lib/llm/utils";
|
import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
|
||||||
|
import { updateModelOverrideForChatSession } from "../../lib";
|
||||||
|
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||||
|
|
||||||
export function LlmTab({
|
export function LlmTab({
|
||||||
llmOverrideManager,
|
llmOverrideManager,
|
||||||
currentAssistant,
|
currentAssistant,
|
||||||
|
chatSessionId,
|
||||||
}: {
|
}: {
|
||||||
llmOverrideManager: LlmOverrideManager;
|
llmOverrideManager: LlmOverrideManager;
|
||||||
currentAssistant: Persona;
|
currentAssistant: Persona;
|
||||||
|
chatSessionId?: number;
|
||||||
}) {
|
}) {
|
||||||
const { llmProviders } = useChatContext();
|
const { llmProviders } = useChatContext();
|
||||||
const { llmOverride, setLlmOverride, temperature, setTemperature } =
|
const { llmOverride, setLlmOverride, temperature, setTemperature } =
|
||||||
@ -37,21 +41,6 @@ export function LlmTab({
|
|||||||
const [_, defaultLlmName] = getFinalLLM(llmProviders, currentAssistant, null);
|
const [_, defaultLlmName] = getFinalLLM(llmProviders, currentAssistant, null);
|
||||||
|
|
||||||
const llmOptions: { name: string; value: string }[] = [];
|
const llmOptions: { name: string; value: string }[] = [];
|
||||||
const structureValue = (
|
|
||||||
name: string,
|
|
||||||
provider: string,
|
|
||||||
modelName: string
|
|
||||||
) => {
|
|
||||||
return `${name}__${provider}__${modelName}`;
|
|
||||||
};
|
|
||||||
const destructureValue = (value: string): LlmOverride => {
|
|
||||||
const [displayName, provider, modelName] = value.split("__");
|
|
||||||
return {
|
|
||||||
name: displayName,
|
|
||||||
provider,
|
|
||||||
modelName,
|
|
||||||
};
|
|
||||||
};
|
|
||||||
llmProviders.forEach((llmProvider) => {
|
llmProviders.forEach((llmProvider) => {
|
||||||
llmProvider.model_names.forEach((modelName) => {
|
llmProvider.model_names.forEach((modelName) => {
|
||||||
llmOptions.push({
|
llmOptions.push({
|
||||||
@ -76,6 +65,7 @@ export function LlmTab({
|
|||||||
<Text className="mb-3">
|
<Text className="mb-3">
|
||||||
Default Model: <i className="font-medium">{defaultLlmName}</i>.
|
Default Model: <i className="font-medium">{defaultLlmName}</i>.
|
||||||
</Text>
|
</Text>
|
||||||
|
|
||||||
<div className="w-96">
|
<div className="w-96">
|
||||||
<DefaultDropdown
|
<DefaultDropdown
|
||||||
options={llmOptions}
|
options={llmOptions}
|
||||||
@ -84,9 +74,12 @@ export function LlmTab({
|
|||||||
llmOverride.provider,
|
llmOverride.provider,
|
||||||
llmOverride.modelName
|
llmOverride.modelName
|
||||||
)}
|
)}
|
||||||
onSelect={(value) =>
|
onSelect={(value) => {
|
||||||
setLlmOverride(destructureValue(value as string))
|
setLlmOverride(destructureValue(value as string));
|
||||||
}
|
if (chatSessionId) {
|
||||||
|
updateModelOverrideForChatSession(chatSessionId, value as string);
|
||||||
|
}
|
||||||
|
}}
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
|
@ -12,6 +12,8 @@ import { useState } from "react";
|
|||||||
import { DateRangePickerValue } from "@tremor/react";
|
import { DateRangePickerValue } from "@tremor/react";
|
||||||
import { SourceMetadata } from "./search/interfaces";
|
import { SourceMetadata } from "./search/interfaces";
|
||||||
import { EE_ENABLED } from "./constants";
|
import { EE_ENABLED } from "./constants";
|
||||||
|
import { destructureValue } from "./llm/utils";
|
||||||
|
import { ChatSession } from "@/app/chat/interfaces";
|
||||||
|
|
||||||
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
||||||
|
|
||||||
@ -136,17 +138,38 @@ export interface LlmOverrideManager {
|
|||||||
setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
|
setLlmOverride: React.Dispatch<React.SetStateAction<LlmOverride>>;
|
||||||
temperature: number | null;
|
temperature: number | null;
|
||||||
setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
|
setTemperature: React.Dispatch<React.SetStateAction<number | null>>;
|
||||||
|
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useLlmOverride(): LlmOverrideManager {
|
export function useLlmOverride(
|
||||||
const [llmOverride, setLlmOverride] = useState<LlmOverride>({
|
currentChatSession?: ChatSession
|
||||||
name: "",
|
): LlmOverrideManager {
|
||||||
provider: "",
|
const [llmOverride, setLlmOverride] = useState<LlmOverride>(
|
||||||
modelName: "",
|
currentChatSession && currentChatSession.current_alternate_model
|
||||||
});
|
? destructureValue(currentChatSession.current_alternate_model)
|
||||||
|
: {
|
||||||
|
name: "",
|
||||||
|
provider: "",
|
||||||
|
modelName: "",
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
const updateModelOverrideForChatSession = (chatSession?: ChatSession) => {
|
||||||
|
setLlmOverride(
|
||||||
|
chatSession && chatSession.current_alternate_model
|
||||||
|
? destructureValue(chatSession.current_alternate_model)
|
||||||
|
: {
|
||||||
|
name: "",
|
||||||
|
provider: "",
|
||||||
|
modelName: "",
|
||||||
|
}
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
const [temperature, setTemperature] = useState<number | null>(null);
|
const [temperature, setTemperature] = useState<number | null>(null);
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
updateModelOverrideForChatSession,
|
||||||
llmOverride,
|
llmOverride,
|
||||||
setLlmOverride,
|
setLlmOverride,
|
||||||
temperature,
|
temperature,
|
||||||
|
@ -43,3 +43,20 @@ export function checkLLMSupportsImageInput(provider: string, model: string) {
|
|||||||
([p, m]) => p === provider && m === model
|
([p, m]) => p === provider && m === model
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export const structureValue = (
|
||||||
|
name: string,
|
||||||
|
provider: string,
|
||||||
|
modelName: string
|
||||||
|
) => {
|
||||||
|
return `${name}__${provider}__${modelName}`;
|
||||||
|
};
|
||||||
|
|
||||||
|
export const destructureValue = (value: string): LlmOverride => {
|
||||||
|
const [displayName, provider, modelName] = value.split("__");
|
||||||
|
return {
|
||||||
|
name: displayName,
|
||||||
|
provider,
|
||||||
|
modelName,
|
||||||
|
};
|
||||||
|
};
|
||||||
|
Loading…
x
Reference in New Issue
Block a user