diff --git a/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py
new file mode 100644
index 000000000..791d7e42e
--- /dev/null
+++ b/backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py
@@ -0,0 +1,40 @@
+"""Add overrides to the chat session
+
+Revision ID: ecab2b3f1a3b
+Revises: 38eda64af7fe
+Create Date: 2024-04-01 19:08:21.359102
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "ecab2b3f1a3b"
+down_revision = "38eda64af7fe"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "chat_session",
+        sa.Column(
+            "llm_override",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+    )
+    op.add_column(
+        "chat_session",
+        sa.Column(
+            "prompt_override",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("chat_session", "prompt_override")
+    op.drop_column("chat_session", "llm_override")
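The migration is additive and nullable-only, so it can be applied in place. For anyone who wants to exercise it outside the CLI, here is a minimal sketch using Alembic's command API; it assumes it is run from the backend directory that holds `alembic.ini` (`alembic upgrade head` is the equivalent shell command):

```python
from alembic import command
from alembic.config import Config

# Apply this revision (and any earlier pending ones) to add the two
# nullable JSONB columns to chat_session.
alembic_cfg = Config("alembic.ini")
command.upgrade(alembic_cfg, "head")

# Roll back just this migration if needed:
# command.downgrade(alembic_cfg, "-1")
```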
diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index d9e7f9b6c..f904f4963 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -95,6 +95,10 @@ def stream_chat_message_objects(
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
+    # if specified, uses the last user message and does not create a new user message based
+    # on `new_msg_req.message`. Currently, this requires a state where the last message is a
+    # user message (i.e. this can only be used for the chat-seeding flow).
+    use_existing_user_message: bool = False,
 ) -> ChatPacketStream:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -161,33 +165,43 @@ def stream_chat_message_objects(
     else:
         parent_message = root_message
 
-    # Create new message at the right place in the tree and update the parent's child pointer
-    # Don't commit yet until we verify the chat message chain
-    new_user_message = create_new_chat_message(
-        chat_session_id=chat_session_id,
-        parent_message=parent_message,
-        prompt_id=prompt_id,
-        message=message_text,
-        token_count=len(llm_tokenizer_encode_func(message_text)),
-        message_type=MessageType.USER,
-        db_session=db_session,
-        commit=False,
-    )
-
-    # Create linear history of messages
-    final_msg, history_msgs = create_chat_chain(
-        chat_session_id=chat_session_id, db_session=db_session
-    )
-
-    if final_msg.id != new_user_message.id:
-        db_session.rollback()
-        raise RuntimeError(
-            "The new message was not on the mainline. "
-            "Be sure to update the chat pointers before calling this."
+    if not use_existing_user_message:
+        # Create new message at the right place in the tree and update the parent's child pointer
+        # Don't commit yet until we verify the chat message chain
+        user_message = create_new_chat_message(
+            chat_session_id=chat_session_id,
+            parent_message=parent_message,
+            prompt_id=prompt_id,
+            message=message_text,
+            token_count=len(llm_tokenizer_encode_func(message_text)),
+            message_type=MessageType.USER,
+            db_session=db_session,
+            commit=False,
         )
+        # re-create linear history of messages
+        final_msg, history_msgs = create_chat_chain(
+            chat_session_id=chat_session_id, db_session=db_session
+        )
+        if final_msg.id != user_message.id:
+            db_session.rollback()
+            raise RuntimeError(
+                "The new message was not on the mainline. "
+                "Be sure to update the chat pointers before calling this."
+            )
 
-    # Save now to save the latest chat message
-    db_session.commit()
+        # Save now to save the latest chat message
+        db_session.commit()
+    else:
+        # re-create linear history of messages
+        final_msg, history_msgs = create_chat_chain(
+            chat_session_id=chat_session_id, db_session=db_session
+        )
+        if final_msg.message_type != MessageType.USER:
+            raise RuntimeError(
+                "The last message was not a user message. Cannot call "
+                "`stream_chat_message_objects` with `use_existing_user_message=True` "
+                "when the last message is not a user message."
+            )
 
     run_search = False
     # Retrieval options are only None if reference_doc_ids are provided
@@ -304,7 +318,7 @@
     partial_response = partial(
         create_new_chat_message,
         chat_session_id=chat_session_id,
-        parent_message=new_user_message,
+        parent_message=final_msg,
         prompt_id=prompt_id,
         # message=,
         rephrased_query=rephrased_query,
@@ -346,10 +360,14 @@
             document_pruning_config=document_pruning_config,
         ),
         prompt_config=PromptConfig.from_model(
-            final_msg.prompt, prompt_override=new_msg_req.prompt_override
+            final_msg.prompt,
+            prompt_override=(
+                new_msg_req.prompt_override or chat_session.prompt_override
+            ),
         ),
         llm_config=LLMConfig.from_persona(
-            persona, llm_override=new_msg_req.llm_override
+            persona,
+            llm_override=(new_msg_req.llm_override or chat_session.llm_override),
         ),
         message_history=[
             PreviousMessage.from_chat_message(msg) for msg in history_msgs
@@ -399,12 +417,14 @@
 def stream_chat_message(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
+    use_existing_user_message: bool = False,
 ) -> Iterator[str]:
     with get_session_context_manager() as db_session:
         objects = stream_chat_message_objects(
             new_msg_req=new_msg_req,
             user=user,
             db_session=db_session,
+            use_existing_user_message=use_existing_user_message,
         )
         for obj in objects:
             yield get_json_line(obj.dict())
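For review purposes, here is a condensed sketch of the two entry states the new branch handles. The dict-based chain is a stand-in for the real DB-backed message tree, not Danswer's actual API:

```python
from enum import Enum


class MessageType(str, Enum):
    USER = "user"
    ASSISTANT = "assistant"


def resolve_final_message(
    chain: list[dict], new_text: str, use_existing_user_message: bool
) -> dict:
    if not use_existing_user_message:
        # normal flow: append a new USER message to the mainline chain
        message = {"message": new_text, "message_type": MessageType.USER}
        chain.append(message)
        return message
    # seeded flow: the chain must already end in the pre-seeded USER message
    final = chain[-1]
    if final["message_type"] != MessageType.USER:
        raise RuntimeError("The last message was not a user message.")
    return final
```

Either way, `final_msg` ends up being the user message to answer, which is why `partial_response` can now uniformly use `parent_message=final_msg`.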
diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py
index eb2d49b4c..738d02a16 100644
--- a/backend/danswer/db/chat.py
+++ b/backend/danswer/db/chat.py
@@ -28,6 +28,8 @@ from danswer.db.models import SearchDoc
 from danswer.db.models import SearchDoc as DBSearchDoc
 from danswer.db.models import StarterMessage
 from danswer.db.models import User__UserGroup
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import RetrievalDocs
 from danswer.search.models import SavedSearchDoc
@@ -53,7 +55,9 @@ def get_chat_session_by_id(
     # if user_id is None, assume this is an admin who should be able
     # to view all chat sessions
     if user_id is not None:
-        stmt = stmt.where(ChatSession.user_id == user_id)
+        stmt = stmt.where(
+            or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
+        )
 
     result = db_session.execute(stmt)
     chat_session = result.scalar_one_or_none()
@@ -92,12 +96,16 @@ def create_chat_session(
     description: str,
     user_id: UUID | None,
     persona_id: int | None = None,
+    llm_override: LLMOverride | None = None,
+    prompt_override: PromptOverride | None = None,
     one_shot: bool = False,
 ) -> ChatSession:
     chat_session = ChatSession(
         user_id=user_id,
         persona_id=persona_id,
         description=description,
+        llm_override=llm_override,
+        prompt_override=prompt_override,
         one_shot=one_shot,
     )
diff --git a/backend/danswer/db/enums.py b/backend/danswer/db/enums.py
new file mode 100644
index 000000000..2a02e078c
--- /dev/null
+++ b/backend/danswer/db/enums.py
@@ -0,0 +1,35 @@
+from enum import Enum as PyEnum
+
+
+class IndexingStatus(str, PyEnum):
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    SUCCESS = "success"
+    FAILED = "failed"
+
+
+# these may differ in the future, which is why we're okay with this duplication
+class DeletionStatus(str, PyEnum):
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    SUCCESS = "success"
+    FAILED = "failed"
+
+
+# Consistent with Celery task statuses
+class TaskStatus(str, PyEnum):
+    PENDING = "PENDING"
+    STARTED = "STARTED"
+    SUCCESS = "SUCCESS"
+    FAILURE = "FAILURE"
+
+
+class IndexModelStatus(str, PyEnum):
+    PAST = "PAST"
+    PRESENT = "PRESENT"
+    FUTURE = "FUTURE"
+
+
+class ChatSessionSharedStatus(str, PyEnum):
+    PUBLIC = "public"
+    PRIVATE = "private"
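The widened `where` clause is what lets a logged-in user fetch an "unassigned" seeded session, which is then claimed in `chat_backend.py` below. A self-contained sketch of the predicate, using a simplified stand-in model rather than the real `ChatSession`:

```python
from sqlalchemy import String, or_, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class ChatSession(Base):  # simplified stand-in for the real model
    __tablename__ = "chat_session"
    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[str | None] = mapped_column(String, nullable=True)  # really a UUID FK


stmt = select(ChatSession).where(
    or_(ChatSession.user_id == "some-user-id", ChatSession.user_id.is_(None))
)
# Renders roughly as:
#   SELECT ... FROM chat_session
#   WHERE chat_session.user_id = :user_id_1 OR chat_session.user_id IS NULL
print(stmt)
```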
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index 4a44882f2..7fb6bbaa7 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -35,45 +35,18 @@ from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.connectors.models import InputType
+from danswer.db.enums import ChatSessionSharedStatus
+from danswer.db.enums import IndexingStatus
+from danswer.db.enums import IndexModelStatus
+from danswer.db.enums import TaskStatus
+from danswer.db.pydantic_type import PydanticType
 from danswer.dynamic_configs.interface import JSON_ro
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.search.enums import RecencyBiasSetting
 from danswer.search.enums import SearchType
 
 
-class IndexingStatus(str, PyEnum):
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
-# these may differ in the future, which is why we're okay with this duplication
-class DeletionStatus(str, PyEnum):
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
-# Consistent with Celery task statuses
-class TaskStatus(str, PyEnum):
-    PENDING = "PENDING"
-    STARTED = "STARTED"
-    SUCCESS = "SUCCESS"
-    FAILURE = "FAILURE"
-
-
-class IndexModelStatus(str, PyEnum):
-    PAST = "PAST"
-    PRESENT = "PRESENT"
-    FUTURE = "FUTURE"
-
-
-class ChatSessionSharedStatus(str, PyEnum):
-    PUBLIC = "public"
-    PRIVATE = "private"
-
-
 class Base(DeclarativeBase):
     pass
 
@@ -596,6 +569,20 @@ class ChatSession(Base):
         Enum(ChatSessionSharedStatus, native_enum=False),
         default=ChatSessionSharedStatus.PRIVATE,
     )
+
+    # the latest "overrides" specified by the user. These take precedence over
+    # the attached persona. However, overrides specified directly in the
+    # `send-message` call take precedence over these.
+    # NOTE: currently only used by the chat-seeding flow; this will also be
+    # used once we allow users to override default values via the Chat UI
+    # itself
+    llm_override: Mapped[LLMOverride | None] = mapped_column(
+        PydanticType(LLMOverride), nullable=True
+    )
+    prompt_override: Mapped[PromptOverride | None] = mapped_column(
+        PydanticType(PromptOverride), nullable=True
+    )
+
     time_updated: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True),
         server_default=func.now(),
diff --git a/backend/danswer/db/pydantic_type.py b/backend/danswer/db/pydantic_type.py
new file mode 100644
index 000000000..1f37152a8
--- /dev/null
+++ b/backend/danswer/db/pydantic_type.py
@@ -0,0 +1,32 @@
+import json
+from typing import Any
+from typing import Optional
+from typing import Type
+
+from pydantic import BaseModel
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.types import TypeDecorator
+
+
+class PydanticType(TypeDecorator):
+    impl = JSONB
+
+    def __init__(
+        self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.pydantic_model = pydantic_model
+
+    def process_bind_param(
+        self, value: Optional[BaseModel], dialect: Any
+    ) -> Optional[dict]:
+        if value is not None:
+            return json.loads(value.json())
+        return None
+
+    def process_result_value(
+        self, value: Optional[dict], dialect: Any
+    ) -> Optional[BaseModel]:
+        if value is not None:
+            return self.pydantic_model.parse_obj(value)
+        return None
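A quick round-trip sanity check for `PydanticType`, calling the bind/result processors directly rather than going through a real session. It leans on the Pydantic v1 API (`.json()` / `parse_obj`) exactly as the decorator itself does:

```python
from danswer.db.pydantic_type import PydanticType
from danswer.llm.override_models import LLMOverride

pt = PydanticType(LLMOverride)

# model -> JSONB-ready dict on the way into Postgres; unset fields are kept as None
stored = pt.process_bind_param(
    LLMOverride(model_version="gpt-4", temperature=0.2), dialect=None
)
assert stored == {"model_provider": None, "model_version": "gpt-4", "temperature": 0.2}

# dict -> model on the way back out
loaded = pt.process_result_value(stored, dialect=None)
assert isinstance(loaded, LLMOverride)
assert loaded.model_version == "gpt-4" and loaded.temperature == 0.2
```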
+""" +from pydantic import BaseModel + + +class LLMOverride(BaseModel): + model_provider: str | None = None + model_version: str | None = None + temperature: float | None = None + + +class PromptOverride(BaseModel): + system_prompt: str | None = None + task_prompt: str | None = None diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py index 19b7db23b..52d6414e2 100644 --- a/backend/danswer/server/query_and_chat/chat_backend.py +++ b/backend/danswer/server/query_and_chat/chat_backend.py @@ -8,12 +8,16 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_user from danswer.chat.chat_utils import create_chat_chain from danswer.chat.process_message import stream_chat_message +from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.constants import MessageType from danswer.db.chat import create_chat_session +from danswer.db.chat import create_new_chat_message from danswer.db.chat import delete_chat_session from danswer.db.chat import get_chat_message from danswer.db.chat import get_chat_messages_by_session from danswer.db.chat import get_chat_session_by_id from danswer.db.chat import get_chat_sessions_by_user +from danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_persona_by_id from danswer.db.chat import set_as_latest_chat_message from danswer.db.chat import translate_db_message_to_chat_message_detail @@ -27,6 +31,7 @@ from danswer.document_index.factory import get_default_document_index from danswer.llm.answering.prompts.citations_prompt import ( compute_max_document_tokens_for_persona, ) +from danswer.llm.utils import get_default_llm_tokenizer from danswer.secondary_llm_flows.chat_session_naming import ( get_renamed_conversation_name, ) @@ -40,6 +45,8 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse from danswer.server.query_and_chat.models import ChatSessionUpdateRequest from danswer.server.query_and_chat.models import CreateChatMessageRequest from danswer.server.query_and_chat.models import CreateChatSessionID +from danswer.server.query_and_chat.models import LLMOverride +from danswer.server.query_and_chat.models import PromptOverride from danswer.server.query_and_chat.models import RenameChatSessionResponse from danswer.server.query_and_chat.models import SearchFeedbackRequest from danswer.utils.logger import setup_logger @@ -93,6 +100,13 @@ def get_chat_session( except ValueError: raise ValueError("Chat session does not exist or has been deleted") + # for chat-seeding: if the session is unassigned, assign it now. 
diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py
index 19b7db23b..52d6414e2 100644
--- a/backend/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/danswer/server/query_and_chat/chat_backend.py
@@ -8,12 +8,16 @@ from sqlalchemy.orm import Session
 from danswer.auth.users import current_user
 from danswer.chat.chat_utils import create_chat_chain
 from danswer.chat.process_message import stream_chat_message
+from danswer.configs.app_configs import WEB_DOMAIN
+from danswer.configs.constants import MessageType
 from danswer.db.chat import create_chat_session
+from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import delete_chat_session
 from danswer.db.chat import get_chat_message
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.chat import get_chat_session_by_id
 from danswer.db.chat import get_chat_sessions_by_user
+from danswer.db.chat import get_or_create_root_message
 from danswer.db.chat import get_persona_by_id
 from danswer.db.chat import set_as_latest_chat_message
 from danswer.db.chat import translate_db_message_to_chat_message_detail
@@ -27,6 +31,7 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
+from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.secondary_llm_flows.chat_session_naming import (
     get_renamed_conversation_name,
 )
@@ -40,6 +45,8 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse
 from danswer.server.query_and_chat.models import ChatSessionUpdateRequest
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.query_and_chat.models import CreateChatSessionID
+from danswer.server.query_and_chat.models import LLMOverride
+from danswer.server.query_and_chat.models import PromptOverride
 from danswer.server.query_and_chat.models import RenameChatSessionResponse
 from danswer.server.query_and_chat.models import SearchFeedbackRequest
 from danswer.utils.logger import setup_logger
@@ -93,6 +100,13 @@ def get_chat_session(
     except ValueError:
         raise ValueError("Chat session does not exist or has been deleted")
 
+    # for chat-seeding: if the session is unassigned, assign it now. This is done
+    # here to avoid an extra FE -> BE round trip before starting the first
+    # message generation
+    if chat_session.user_id is None and user_id is not None:
+        chat_session.user_id = user_id
+        db_session.commit()
+
     session_messages = get_chat_messages_by_session(
         chat_session_id=session_id, user_id=user_id, db_session=db_session
     )
@@ -209,15 +223,24 @@ def handle_new_chat_message(
     - Sending a new message in the session
     - Regenerating a message in the session (just send the same one again)
     - Editing a message (similar to regenerating but sending a different message)
+    - Kicking off a seeded chat session (set `use_existing_user_message`)
 
     To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
     have already been set as latest"""
     logger.info(f"Received new chat message: {chat_message_req.message}")
 
-    if not chat_message_req.message and chat_message_req.prompt_id is not None:
+    if (
+        not chat_message_req.message
+        and chat_message_req.prompt_id is not None
+        and not chat_message_req.use_existing_user_message
+    ):
         raise HTTPException(status_code=400, detail="Empty chat message is invalid")
 
-    packets = stream_chat_message(new_msg_req=chat_message_req, user=user)
+    packets = stream_chat_message(
+        new_msg_req=chat_message_req,
+        user=user,
+        use_existing_user_message=chat_message_req.use_existing_user_message,
+    )
 
     return StreamingResponse(packets, media_type="application/json")
 
@@ -308,3 +331,71 @@ def get_max_document_tokens(
     return MaxSelectedDocumentTokens(
         max_tokens=compute_max_document_tokens_for_persona(persona),
     )
+
+
+"""Endpoints for chat seeding"""
+
+
+class ChatSeedRequest(BaseModel):
+    # standard chat session stuff
+    persona_id: int
+    prompt_id: int | None = None
+
+    # overrides / seeding
+    llm_override: LLMOverride | None = None
+    prompt_override: PromptOverride | None = None
+    description: str | None = None
+    message: str | None = None
+
+    # TODO: support this
+    # initial_message_retrieval_options: RetrievalDetails | None = None
+
+
+class ChatSeedResponse(BaseModel):
+    redirect_url: str
+
+
+@router.post("/seed-chat-session")
+def seed_chat(
+    chat_seed_request: ChatSeedRequest,
+    # NOTE: realistically, this will be an API key not an actual user
+    _: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> ChatSeedResponse:
+    try:
+        new_chat_session = create_chat_session(
+            db_session=db_session,
+            description=chat_seed_request.description or "",
+            user_id=None,  # this chat session is "unassigned" until a user visits the web UI
+            persona_id=chat_seed_request.persona_id,
+            llm_override=chat_seed_request.llm_override,
+            prompt_override=chat_seed_request.prompt_override,
+        )
+    except Exception as e:
+        logger.exception(e)
+        raise HTTPException(status_code=400, detail="Invalid Persona provided.")
+
+    if chat_seed_request.message is not None:
+        root_message = get_or_create_root_message(
+            chat_session_id=new_chat_session.id, db_session=db_session
+        )
+        create_new_chat_message(
+            chat_session_id=new_chat_session.id,
+            parent_message=root_message,
+            prompt_id=chat_seed_request.prompt_id
+            or (
+                new_chat_session.persona.prompts[0].id
+                if new_chat_session.persona.prompts
+                else None
+            ),
+            message=chat_seed_request.message,
+            token_count=len(
+                get_default_llm_tokenizer().encode(chat_seed_request.message)
+            ),
+            message_type=MessageType.USER,
+            db_session=db_session,
+        )
+
+    return ChatSeedResponse(
+        redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
+    )
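An example client call for the new endpoint. The base URL, API key, and IDs here are deployment-specific assumptions (not part of this PR), and the `/chat` prefix is assumed from the router this file registers against:

```python
import requests

API_BASE = "http://localhost:8080"          # assumed deployment URL
HEADERS = {"Authorization": "Bearer <api-key>"}  # assumed auth scheme

resp = requests.post(
    f"{API_BASE}/chat/seed-chat-session",
    headers=HEADERS,
    json={
        "persona_id": 0,  # assumed: the default persona
        "description": "Support handoff",
        "message": "Summarize the open incidents for me",  # pre-seeded USER message
        "llm_override": {"model_version": "gpt-4", "temperature": 0.0},
        "prompt_override": {"system_prompt": "You are a support assistant."},
    },
)
resp.raise_for_status()
# The caller sends the end user here; opening the URL claims the session and
# kicks off generation for the seeded message
print(resp.json()["redirect_url"])
```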
diff --git a/backend/danswer/server/query_and_chat/models.py b/backend/danswer/server/query_and_chat/models.py
index 8f8681477..90be759ad 100644
--- a/backend/danswer/server/query_and_chat/models.py
+++ b/backend/danswer/server/query_and_chat/models.py
@@ -8,7 +8,9 @@ from danswer.chat.models import RetrievalDocs
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
-from danswer.db.models import ChatSessionSharedStatus
+from danswer.db.enums import ChatSessionSharedStatus
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.search.models import BaseFilters
 from danswer.search.models import RetrievalDetails
 from danswer.search.models import SearchDoc
@@ -70,17 +72,6 @@ class DocumentSearchRequest(BaseModel):
     skip_llm_chunk_filter: bool | None = None
 
 
-class LLMOverride(BaseModel):
-    model_provider: str | None = None
-    model_version: str | None = None
-    temperature: float | None = None
-
-
-class PromptOverride(BaseModel):
-    system_prompt: str | None = None
-    task_prompt: str | None = None
-
-
 """
 Currently the different branches are generated by changing the search query
 
@@ -116,6 +107,9 @@ class CreateChatMessageRequest(BaseModel):
     llm_override: LLMOverride | None = None
     prompt_override: PromptOverride | None = None
 
+    # used for seeded chats to kick off the generation of an AI answer
+    use_existing_user_message: bool = False
+
     @root_validator
     def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict:
         search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get(
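For reference, the interesting slice of a send-message body when kicking off a seeded chat; the remaining required fields of `CreateChatMessageRequest` are omitted here since they are not shown in this diff:

```python
# Hypothetical kickoff fields, illustrating the seeded-chat fallback behavior
kickoff_fields = {
    "message": "",                      # ignored: the pre-seeded USER message is reused
    "llm_override": None,               # None -> falls back to chat_session.llm_override
    "prompt_override": None,            # None -> falls back to chat_session.prompt_override
    "use_existing_user_message": True,  # take the seeded flow branch
}
```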
diff --git a/web/src/app/chat/Chat.tsx b/web/src/app/chat/Chat.tsx
index cd2f1c4da..7c174f18d 100644
--- a/web/src/app/chat/Chat.tsx
+++ b/web/src/app/chat/Chat.tsx
@@ -133,7 +133,7 @@
         !submitOnLoadPerformed.current
       ) {
         submitOnLoadPerformed.current = true;
-        onSubmit();
+        await onSubmit();
       }
 
       return;
@@ -162,6 +162,21 @@
       setChatSessionSharedStatus(chatSession.shared_status);
 
       setIsFetchingChatMessages(false);
+
+      // if this is a seeded chat, then kick off the AI message generation
+      if (newMessageHistory.length === 1 && !submitOnLoadPerformed.current) {
+        submitOnLoadPerformed.current = true;
+        const seededMessage = newMessageHistory[0].message;
+        await onSubmit({
+          isSeededChat: true,
+          messageOverride: seededMessage,
+        });
+        // force re-name if the chat session doesn't have one
+        if (!chatSession.description) {
+          await nameChatSession(existingChatSessionId, seededMessage);
+          router.refresh(); // need to refresh to update name on sidebar
+        }
+      }
     }
 
     initialSessionFetch();
@@ -326,11 +341,13 @@
     messageOverride,
     queryOverride,
     forceSearch,
+    isSeededChat,
   }: {
     messageIdToResend?: number;
    messageOverride?: string;
    queryOverride?: string;
    forceSearch?: boolean;
+    isSeededChat?: boolean;
   } = {}) => {
     let currChatSessionId: number;
     let isNewSession = chatSessionId === null;
@@ -419,6 +436,7 @@
           undefined,
         systemPromptOverride:
          searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
+        useExistingUserMessage: isSeededChat,
       })) {
         for (const packet of packetBunch) {
           if (Object.hasOwn(packet, "answer_piece")) {
diff --git a/web/src/app/chat/lib.tsx b/web/src/app/chat/lib.tsx
index c5195abf8..29a90526c 100644
--- a/web/src/app/chat/lib.tsx
+++ b/web/src/app/chat/lib.tsx
@@ -57,6 +57,7 @@ export async function* sendMessage({
   modelVersion,
   temperature,
   systemPromptOverride,
+  useExistingUserMessage,
 }: {
   message: string;
   parentMessageId: number | null;
@@ -71,6 +72,9 @@
   temperature?: number;
   // prompt overrides
   systemPromptOverride?: string;
+  // if specified, will use the existing latest user message
+  // and will ignore the specified `message`
+  useExistingUserMessage?: boolean;
 }) {
   const documentsAreSelected =
     selectedDocumentIds && selectedDocumentIds.length > 0;
@@ -99,13 +103,19 @@
           }
         : null,
       query_override: queryOverride,
-      prompt_override: {
-        system_prompt: systemPromptOverride,
-      },
-      llm_override: {
-        temperature,
-        model_version: modelVersion,
-      },
+      prompt_override: systemPromptOverride
+        ? {
+            system_prompt: systemPromptOverride,
+          }
+        : null,
+      llm_override:
+        temperature || modelVersion
+          ? {
+              temperature,
+              model_version: modelVersion,
+            }
+          : null,
+      use_existing_user_message: useExistingUserMessage,
     }),
  });
  if (!sendMessageResponse.ok) {
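Why `sendMessage` now sends `null` instead of an always-present override object: a Pydantic model instance is truthy even when every field is unset, so an "empty" request override would mask the session-level override in the `or` fallback added in `process_message.py`. A minimal demonstration:

```python
from danswer.llm.override_models import LLMOverride

session_override = LLMOverride(model_version="gpt-4")

# old FE behavior: always sent an override object, even with nothing set;
# the parsed model is truthy despite all fields being None, so it wins the `or`
empty_request_override = LLMOverride()
assert (empty_request_override or session_override) is empty_request_override

# new FE behavior: sends null, so the backend sees None and falls back to the
# session-level override stored by the seeding flow
assert (None or session_override) is session_override
```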