Allow seeding of chat sessions via POST

Weves 2024-04-02 00:14:05 -07:00 committed by Chris Weaver
parent 33da86c802
commit 7ba7224929
12 changed files with 339 additions and 87 deletions

View File

@@ -0,0 +1,40 @@
"""Add overrides to the chat session
Revision ID: ecab2b3f1a3b
Revises: 38eda64af7fe
Create Date: 2024-04-01 19:08:21.359102
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "ecab2b3f1a3b"
down_revision = "38eda64af7fe"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session",
sa.Column(
"llm_override",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
op.add_column(
"chat_session",
sa.Column(
"prompt_override",
postgresql.JSONB(astext_type=sa.Text()),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("chat_session", "prompt_override")
op.drop_column("chat_session", "llm_override")

View File

@@ -95,6 +95,10 @@ def stream_chat_message_objects(
# For the search flow, don't include as many chunks as possible since we need to leave space
# for the chat history. For smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on `new_msg_req.message`. Currently requires that the last message in the chain is a
# user message (i.e. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -161,33 +165,43 @@
else:
parent_message = root_message
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
new_user_message = create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=parent_message,
prompt_id=prompt_id,
message=message_text,
token_count=len(llm_tokenizer_encode_func(message_text)),
message_type=MessageType.USER,
db_session=db_session,
commit=False,
)
# Create linear history of messages
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.id != new_user_message.id:
db_session.rollback()
raise RuntimeError(
"The new message was not on the mainline. "
"Be sure to update the chat pointers before calling this."
if not use_existing_user_message:
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
user_message = create_new_chat_message(
chat_session_id=chat_session_id,
parent_message=parent_message,
prompt_id=prompt_id,
message=message_text,
token_count=len(llm_tokenizer_encode_func(message_text)),
message_type=MessageType.USER,
db_session=db_session,
commit=False,
)
# re-create linear history of messages
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.id != user_message.id:
db_session.rollback()
raise RuntimeError(
"The new message was not on the mainline. "
"Be sure to update the chat pointers before calling this."
)
# Save now to save the latest chat message
db_session.commit()
# Save now to save the latest chat message
db_session.commit()
else:
# re-create linear history of messages
final_msg, history_msgs = create_chat_chain(
chat_session_id=chat_session_id, db_session=db_session
)
if final_msg.message_type != MessageType.USER:
raise RuntimeError(
"The last message was not a user message. Cannot call "
"`stream_chat_message_objects` with `is_regenerate=True` "
"when the last message is not a user message."
)
run_search = False
# Retrieval options are only None if reference_doc_ids are provided
@@ -304,7 +318,7 @@ def stream_chat_message_objects(
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=new_user_message,
parent_message=final_msg,
prompt_id=prompt_id,
# message=,
rephrased_query=rephrased_query,
@@ -346,10 +360,14 @@ def stream_chat_message_objects(
document_pruning_config=document_pruning_config,
),
prompt_config=PromptConfig.from_model(
final_msg.prompt, prompt_override=new_msg_req.prompt_override
final_msg.prompt,
prompt_override=(
new_msg_req.prompt_override or chat_session.prompt_override
),
),
llm_config=LLMConfig.from_persona(
persona, llm_override=new_msg_req.llm_override
persona,
llm_override=(new_msg_req.llm_override or chat_session.llm_override),
),
message_history=[
PreviousMessage.from_chat_message(msg) for msg in history_msgs
@@ -399,12 +417,14 @@
def stream_chat_message(
new_msg_req: CreateChatMessageRequest,
user: User | None,
use_existing_user_message: bool = False,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
new_msg_req=new_msg_req,
user=user,
db_session=db_session,
use_existing_user_message=use_existing_user_message,
)
for obj in objects:
yield get_json_line(obj.dict())
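
The override resolution above is a plain two-level fallback: an override passed on the individual message request wins, the override stored on the chat session (set during seeding) is the fallback, and the persona's defaults apply only when both are absent. A minimal sketch of that precedence using the `LLMOverride` model introduced in this commit (the variable names are illustrative):

from danswer.llm.override_models import LLMOverride

# override supplied on the send-message request body (absent in this example)
request_override: LLMOverride | None = None
# override stored on the ChatSession row by the seeding endpoint
session_override: LLMOverride | None = LLMOverride(temperature=0.0)

# mirrors `new_msg_req.llm_override or chat_session.llm_override` above:
# the per-message value wins and the session value is the fallback
effective = request_override or session_override
assert effective is not None and effective.temperature == 0.0

Note the fallback is whole-object rather than field-by-field: any non-None override on the request shadows the session's override entirely.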

View File

@@ -28,6 +28,8 @@ from danswer.db.models import SearchDoc
from danswer.db.models import SearchDoc as DBSearchDoc
from danswer.db.models import StarterMessage
from danswer.db.models import User__UserGroup
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
@@ -53,7 +55,9 @@ def get_chat_session_by_id(
# if user_id is None, assume this is an admin who should be able
# to view all chat sessions
if user_id is not None:
stmt = stmt.where(ChatSession.user_id == user_id)
stmt = stmt.where(
or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
)
result = db_session.execute(stmt)
chat_session = result.scalar_one_or_none()
@@ -92,12 +96,16 @@ def create_chat_session(
description: str,
user_id: UUID | None,
persona_id: int | None = None,
llm_override: LLMOverride | None = None,
prompt_override: PromptOverride | None = None,
one_shot: bool = False,
) -> ChatSession:
chat_session = ChatSession(
user_id=user_id,
persona_id=persona_id,
description=description,
llm_override=llm_override,
prompt_override=prompt_override,
one_shot=one_shot,
)

View File

@@ -0,0 +1,35 @@
from enum import Enum as PyEnum
class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
# Consistent with Celery task statuses
class TaskStatus(str, PyEnum):
PENDING = "PENDING"
STARTED = "STARTED"
SUCCESS = "SUCCESS"
FAILURE = "FAILURE"
class IndexModelStatus(str, PyEnum):
PAST = "PAST"
PRESENT = "PRESENT"
FUTURE = "FUTURE"
class ChatSessionSharedStatus(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"

View File

@@ -35,45 +35,18 @@ from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
from danswer.db.enums import ChatSessionSharedStatus
from danswer.db.enums import IndexingStatus
from danswer.db.enums import IndexModelStatus
from danswer.db.enums import TaskStatus
from danswer.db.pydantic_type import PydanticType
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
class IndexingStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
# these may differ in the future, which is why we're okay with this duplication
class DeletionStatus(str, PyEnum):
NOT_STARTED = "not_started"
IN_PROGRESS = "in_progress"
SUCCESS = "success"
FAILED = "failed"
# Consistent with Celery task statuses
class TaskStatus(str, PyEnum):
PENDING = "PENDING"
STARTED = "STARTED"
SUCCESS = "SUCCESS"
FAILURE = "FAILURE"
class IndexModelStatus(str, PyEnum):
PAST = "PAST"
PRESENT = "PRESENT"
FUTURE = "FUTURE"
class ChatSessionSharedStatus(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
class Base(DeclarativeBase):
pass
@@ -596,6 +569,20 @@ class ChatSession(Base):
Enum(ChatSessionSharedStatus, native_enum=False),
default=ChatSessionSharedStatus.PRIVATE,
)
# the latest "overrides" specified by the user. These take precedence over
# the attached persona. However, overrides specified directly in the
# `send-message` call will take precedence over these.
# NOTE: currently only used by the chat seeding flow, will be used in the
# future once we allow users to override default values via the Chat UI
# itself
llm_override: Mapped[LLMOverride | None] = mapped_column(
PydanticType(LLMOverride), nullable=True
)
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)
time_updated: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),

View File

@@ -0,0 +1,32 @@
import json
from typing import Any
from typing import Optional
from typing import Type
from pydantic import BaseModel
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import TypeDecorator
class PydanticType(TypeDecorator):
impl = JSONB
def __init__(
self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self.pydantic_model = pydantic_model
def process_bind_param(
self, value: Optional[BaseModel], dialect: Any
) -> Optional[dict]:
if value is not None:
return json.loads(value.json())
return None
def process_result_value(
self, value: Optional[dict], dialect: Any
) -> Optional[BaseModel]:
if value is not None:
return self.pydantic_model.parse_obj(value)
return None
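
`PydanticType` serializes a Pydantic model into the JSONB column on write and rehydrates it on read, which is what lets `ChatSession` store `LLMOverride` and `PromptOverride` directly. A minimal sketch of declaring a column with it, mirroring the `ChatSession` change above (the `ExampleSession` table is hypothetical, for illustration only):

from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from danswer.db.pydantic_type import PydanticType
from danswer.llm.override_models import LLMOverride


class Base(DeclarativeBase):
    pass


class ExampleSession(Base):  # hypothetical table, for illustration only
    __tablename__ = "example_session"

    id: Mapped[int] = mapped_column(primary_key=True)
    # written via process_bind_param (model -> JSON dict) and read back via
    # process_result_value (JSON dict -> LLMOverride.parse_obj)
    llm_override: Mapped[LLMOverride | None] = mapped_column(
        PydanticType(LLMOverride), nullable=True
    )

Since `impl = JSONB`, the type is Postgres-specific, matching the migration at the top of this commit.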

View File

@@ -10,9 +10,9 @@ from pydantic import root_validator
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import get_default_llm_version
from danswer.server.query_and_chat.models import LLMOverride
from danswer.server.query_and_chat.models import PromptOverride
if TYPE_CHECKING:
from danswer.db.models import ChatMessage

View File

@@ -0,0 +1,17 @@
"""Overrides sent over the wire / stored in the DB
NOTE: these models are used in many places, so have to be
kept in a separate file to avoid circular imports.
"""
from pydantic import BaseModel
class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
class PromptOverride(BaseModel):
system_prompt: str | None = None
task_prompt: str | None = None

View File

@@ -8,12 +8,16 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.process_message import stream_chat_message
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import MessageType
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
from danswer.db.chat import get_chat_message
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import get_persona_by_id
from danswer.db.chat import set_as_latest_chat_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
@@ -27,6 +31,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
get_renamed_conversation_name,
)
@@ -40,6 +45,8 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse
from danswer.server.query_and_chat.models import ChatSessionUpdateRequest
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.query_and_chat.models import CreateChatSessionID
from danswer.server.query_and_chat.models import LLMOverride
from danswer.server.query_and_chat.models import PromptOverride
from danswer.server.query_and_chat.models import RenameChatSessionResponse
from danswer.server.query_and_chat.models import SearchFeedbackRequest
from danswer.utils.logger import setup_logger
@@ -93,6 +100,13 @@ def get_chat_session(
except ValueError:
raise ValueError("Chat session does not exist or has been deleted")
# for chat-seeding: if the session is unassigned, assign it now. This is done here
# to avoid another round trip between the frontend and backend before starting the
# first message generation
if chat_session.user_id is None and user_id is not None:
chat_session.user_id = user_id
db_session.commit()
session_messages = get_chat_messages_by_session(
chat_session_id=session_id, user_id=user_id, db_session=db_session
)
@@ -209,15 +223,24 @@ def handle_new_chat_message(
- Sending a new message in the session
- Regenerating a message in the session (just send the same one again)
- Editing a message (similar to regenerating but sending a different message)
- Kicking off a seeded chat session (set `use_existing_user_message`)
To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
have already been set as latest"""
logger.info(f"Received new chat message: {chat_message_req.message}")
if not chat_message_req.message and chat_message_req.prompt_id is not None:
if (
not chat_message_req.message
and chat_message_req.prompt_id is not None
and not chat_message_req.use_existing_user_message
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
packets = stream_chat_message(new_msg_req=chat_message_req, user=user)
packets = stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
)
return StreamingResponse(packets, media_type="application/json")
@@ -308,3 +331,71 @@ def get_max_document_tokens(
return MaxSelectedDocumentTokens(
max_tokens=compute_max_document_tokens_for_persona(persona),
)
"""Endpoints for chat seeding"""
class ChatSeedRequest(BaseModel):
# standard chat session stuff
persona_id: int
prompt_id: int | None = None
# overrides / seeding
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
description: str | None = None
message: str | None = None
# TODO: support this
# initial_message_retrieval_options: RetrievalDetails | None = None
class ChatSeedResponse(BaseModel):
redirect_url: str
@router.post("/seed-chat-session")
def seed_chat(
chat_seed_request: ChatSeedRequest,
# NOTE: realistically, this will be an API key not an actual user
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatSeedResponse:
try:
new_chat_session = create_chat_session(
db_session=db_session,
description=chat_seed_request.description or "",
user_id=None, # this chat session is "unassigned" until a user visits the web UI
persona_id=chat_seed_request.persona_id,
llm_override=chat_seed_request.llm_override,
prompt_override=chat_seed_request.prompt_override,
)
except Exception as e:
logger.exception(e)
raise HTTPException(status_code=400, detail="Invalid Persona provided.")
if chat_seed_request.message is not None:
root_message = get_or_create_root_message(
chat_session_id=new_chat_session.id, db_session=db_session
)
create_new_chat_message(
chat_session_id=new_chat_session.id,
parent_message=root_message,
prompt_id=chat_seed_request.prompt_id
or (
new_chat_session.persona.prompts[0].id
if new_chat_session.persona.prompts
else None
),
message=chat_seed_request.message,
token_count=len(
get_default_llm_tokenizer().encode(chat_seed_request.message)
),
message_type=MessageType.USER,
db_session=db_session,
)
return ChatSeedResponse(
redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
)
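
End to end: an external system POSTs a seed request (realistically authenticated with an API key), receives a `redirect_url`, and sends the user there; when the browser loads that page, the unassigned session is claimed by the visiting user and the seeded message kicks off generation (see the frontend change below). A sketch of such a caller, assuming the router is mounted under a `/chat` prefix and that the API key is sent as a bearer token; both are assumptions about the deployment, not part of this diff:

import requests

DANSWER_URL = "http://localhost:8080"  # assumed base URL; adjust per deployment
API_KEY = "<api-key>"  # assumed auth scheme for `current_user`

response = requests.post(
    f"{DANSWER_URL}/chat/seed-chat-session",  # assumes the /chat router prefix
    headers={"Authorization": f"Bearer {API_KEY}"},
    json={
        "persona_id": 0,
        "description": "Seeded from an external tool",
        "message": "Summarize last week's open support tickets",
        "llm_override": {"temperature": 0.0},
        "prompt_override": {"system_prompt": "Answer concisely."},
    },
)
response.raise_for_status()

# hand this URL to the end user; visiting it assigns the unowned chat session
# to them, and the web UI auto-submits the seeded user message
print(response.json()["redirect_url"])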

View File

@@ -8,7 +8,9 @@ from danswer.chat.models import RetrievalDocs
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import SearchFeedbackType
from danswer.db.models import ChatSessionSharedStatus
from danswer.db.enums import ChatSessionSharedStatus
from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.search.models import BaseFilters
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchDoc
@@ -70,17 +72,6 @@ class DocumentSearchRequest(BaseModel):
skip_llm_chunk_filter: bool | None = None
class LLMOverride(BaseModel):
model_provider: str | None = None
model_version: str | None = None
temperature: float | None = None
class PromptOverride(BaseModel):
system_prompt: str | None = None
task_prompt: str | None = None
"""
Currently the different branches are generated by changing the search query
@@ -116,6 +107,9 @@ class CreateChatMessageRequest(BaseModel):
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
# used for seeded chats to kick off the generation of an AI answer
use_existing_user_message: bool = False
@root_validator
def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict:
search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get(

View File

@@ -133,7 +133,7 @@ export const Chat = ({
!submitOnLoadPerformed.current
) {
submitOnLoadPerformed.current = true;
onSubmit();
await onSubmit();
}
return;
@@ -162,6 +162,21 @@
setChatSessionSharedStatus(chatSession.shared_status);
setIsFetchingChatMessages(false);
// if this is a seeded chat, then kick off the AI message generation
if (newMessageHistory.length === 1 && !submitOnLoadPerformed.current) {
submitOnLoadPerformed.current = true;
const seededMessage = newMessageHistory[0].message;
await onSubmit({
isSeededChat: true,
messageOverride: seededMessage,
});
// force re-name if the chat session doesn't have one
if (!chatSession.description) {
await nameChatSession(existingChatSessionId, seededMessage);
router.refresh(); // need to refresh to update name on sidebar
}
}
}
initialSessionFetch();
@@ -326,11 +341,13 @@
messageOverride,
queryOverride,
forceSearch,
isSeededChat,
}: {
messageIdToResend?: number;
messageOverride?: string;
queryOverride?: string;
forceSearch?: boolean;
isSeededChat?: boolean;
} = {}) => {
let currChatSessionId: number;
let isNewSession = chatSessionId === null;
@@ -419,6 +436,7 @@
undefined,
systemPromptOverride:
searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
useExistingUserMessage: isSeededChat,
})) {
for (const packet of packetBunch) {
if (Object.hasOwn(packet, "answer_piece")) {

View File

@@ -57,6 +57,7 @@ export async function* sendMessage({
modelVersion,
temperature,
systemPromptOverride,
useExistingUserMessage,
}: {
message: string;
parentMessageId: number | null;
@@ -71,6 +72,9 @@
temperature?: number;
// prompt overrides
systemPromptOverride?: string;
// if specified, will use the existing latest user message
// and will ignore the specified `message`
useExistingUserMessage?: boolean;
}) {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
@@ -99,13 +103,19 @@
}
: null,
query_override: queryOverride,
prompt_override: {
system_prompt: systemPromptOverride,
},
llm_override: {
temperature,
model_version: modelVersion,
},
prompt_override: systemPromptOverride
? {
system_prompt: systemPromptOverride,
}
: null,
llm_override:
temperature || modelVersion
? {
temperature,
model_version: modelVersion,
}
: null,
use_existing_user_message: useExistingUserMessage,
}),
});
if (!sendMessageResponse.ok) {