Mirror of https://github.com/danswer-ai/danswer.git
commit 7ba7224929 (parent 33da86c802)
Allow seeding of chat sessions via POST
@@ -0,0 +1,40 @@
+"""Add overrides to the chat session
+
+Revision ID: ecab2b3f1a3b
+Revises: 38eda64af7fe
+Create Date: 2024-04-01 19:08:21.359102
+
+"""
+from alembic import op
+import sqlalchemy as sa
+from sqlalchemy.dialects import postgresql
+
+# revision identifiers, used by Alembic.
+revision = "ecab2b3f1a3b"
+down_revision = "38eda64af7fe"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    op.add_column(
+        "chat_session",
+        sa.Column(
+            "llm_override",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+    )
+    op.add_column(
+        "chat_session",
+        sa.Column(
+            "prompt_override",
+            postgresql.JSONB(astext_type=sa.Text()),
+            nullable=True,
+        ),
+    )
+
+
+def downgrade() -> None:
+    op.drop_column("chat_session", "prompt_override")
+    op.drop_column("chat_session", "llm_override")
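For context, a minimal sketch (not part of the commit) of the payload shape these JSONB columns end up holding once the ORM serializes the Pydantic override models added later in this commit; the concrete values are illustrative:

import json

from pydantic import BaseModel


class LLMOverride(BaseModel):
    # mirrors backend/danswer/llm/override_models.py, added below
    model_provider: str | None = None
    model_version: str | None = None
    temperature: float | None = None


# json.loads(model.json()) is what PydanticType.process_bind_param (also added
# in this commit) writes into the JSONB column
override = LLMOverride(model_version="gpt-4", temperature=0.2)
print(json.loads(override.json()))
# {'model_provider': None, 'model_version': 'gpt-4', 'temperature': 0.2}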
@@ -95,6 +95,10 @@ def stream_chat_message_objects(
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
+    # if specified, uses the last user message and does not create a new user message based
+    # on the `new_msg_req.message`. Currently, requires a state where the last message is a
+    # user message (e.g. this can only be used for the chat-seeding flow).
+    use_existing_user_message: bool = False,
 ) -> ChatPacketStream:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -161,33 +165,43 @@ def stream_chat_message_objects(
     else:
         parent_message = root_message
 
-    # Create new message at the right place in the tree and update the parent's child pointer
-    # Don't commit yet until we verify the chat message chain
-    new_user_message = create_new_chat_message(
-        chat_session_id=chat_session_id,
-        parent_message=parent_message,
-        prompt_id=prompt_id,
-        message=message_text,
-        token_count=len(llm_tokenizer_encode_func(message_text)),
-        message_type=MessageType.USER,
-        db_session=db_session,
-        commit=False,
-    )
-
-    # Create linear history of messages
-    final_msg, history_msgs = create_chat_chain(
-        chat_session_id=chat_session_id, db_session=db_session
-    )
-
-    if final_msg.id != new_user_message.id:
-        db_session.rollback()
-        raise RuntimeError(
-            "The new message was not on the mainline. "
-            "Be sure to update the chat pointers before calling this."
-        )
-
-    # Save now to save the latest chat message
-    db_session.commit()
+    if not use_existing_user_message:
+        # Create new message at the right place in the tree and update the parent's child pointer
+        # Don't commit yet until we verify the chat message chain
+        user_message = create_new_chat_message(
+            chat_session_id=chat_session_id,
+            parent_message=parent_message,
+            prompt_id=prompt_id,
+            message=message_text,
+            token_count=len(llm_tokenizer_encode_func(message_text)),
+            message_type=MessageType.USER,
+            db_session=db_session,
+            commit=False,
+        )
+        # re-create linear history of messages
+        final_msg, history_msgs = create_chat_chain(
+            chat_session_id=chat_session_id, db_session=db_session
+        )
+        if final_msg.id != user_message.id:
+            db_session.rollback()
+            raise RuntimeError(
+                "The new message was not on the mainline. "
+                "Be sure to update the chat pointers before calling this."
+            )
+
+        # Save now to save the latest chat message
+        db_session.commit()
+    else:
+        # re-create linear history of messages
+        final_msg, history_msgs = create_chat_chain(
+            chat_session_id=chat_session_id, db_session=db_session
+        )
+        if final_msg.message_type != MessageType.USER:
+            raise RuntimeError(
+                "The last message was not a user message. Cannot call "
+                "`stream_chat_message_objects` with `is_regenerate=True` "
+                "when the last message is not a user message."
+            )
 
     run_search = False
     # Retrieval options are only None if reference_doc_ids are provided
@@ -304,7 +318,7 @@ def stream_chat_message_objects(
     partial_response = partial(
         create_new_chat_message,
         chat_session_id=chat_session_id,
-        parent_message=new_user_message,
+        parent_message=final_msg,
         prompt_id=prompt_id,
         # message=,
         rephrased_query=rephrased_query,
@@ -346,10 +360,14 @@ def stream_chat_message_objects(
             document_pruning_config=document_pruning_config,
         ),
         prompt_config=PromptConfig.from_model(
-            final_msg.prompt, prompt_override=new_msg_req.prompt_override
+            final_msg.prompt,
+            prompt_override=(
+                new_msg_req.prompt_override or chat_session.prompt_override
+            ),
         ),
         llm_config=LLMConfig.from_persona(
-            persona, llm_override=new_msg_req.llm_override
+            persona,
+            llm_override=(new_msg_req.llm_override or chat_session.llm_override),
         ),
         message_history=[
             PreviousMessage.from_chat_message(msg) for msg in history_msgs
@@ -399,12 +417,14 @@ def stream_chat_message_objects(
 def stream_chat_message(
     new_msg_req: CreateChatMessageRequest,
     user: User | None,
+    use_existing_user_message: bool = False,
 ) -> Iterator[str]:
     with get_session_context_manager() as db_session:
         objects = stream_chat_message_objects(
             new_msg_req=new_msg_req,
             user=user,
             db_session=db_session,
             use_existing_user_message=use_existing_user_message,
         )
         for obj in objects:
             yield get_json_line(obj.dict())
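Taken together, the hunks above give stream_chat_message_objects two entry modes. A standalone sketch (illustrative stand-ins, not danswer code) of the branching that use_existing_user_message introduces: either append a new user message to the chain, or reuse the seeded one already at the tail.

from dataclasses import dataclass, field


@dataclass
class Msg:
    text: str
    type: str  # "user" or "assistant"


@dataclass
class Chain:
    msgs: list[Msg] = field(default_factory=list)

    def prepare_user_message(self, text: str, use_existing: bool) -> Msg:
        if not use_existing:
            # normal flow: append the incoming message, as create_new_chat_message does
            self.msgs.append(Msg(text, "user"))
        final = self.msgs[-1]
        if final.type != "user":
            # mirrors the RuntimeError guard in the seeded branch above
            raise RuntimeError("The last message was not a user message.")
        return final


# seeded flow: the user message already exists, so nothing is appended
chain = Chain([Msg("seeded question", "user")])
assert chain.prepare_user_message("", use_existing=True).text == "seeded question"
assert len(chain.msgs) == 1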
@@ -28,6 +28,8 @@ from danswer.db.models import SearchDoc
 from danswer.db.models import SearchDoc as DBSearchDoc
 from danswer.db.models import StarterMessage
 from danswer.db.models import User__UserGroup
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.search.enums import RecencyBiasSetting
 from danswer.search.models import RetrievalDocs
 from danswer.search.models import SavedSearchDoc
@@ -53,7 +55,9 @@ def get_chat_session_by_id(
     # if user_id is None, assume this is an admin who should be able
     # to view all chat sessions
     if user_id is not None:
-        stmt = stmt.where(ChatSession.user_id == user_id)
+        stmt = stmt.where(
+            or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
+        )
 
     result = db_session.execute(stmt)
     chat_session = result.scalar_one_or_none()
@@ -92,12 +96,16 @@ def create_chat_session(
     description: str,
     user_id: UUID | None,
     persona_id: int | None = None,
+    llm_override: LLMOverride | None = None,
+    prompt_override: PromptOverride | None = None,
     one_shot: bool = False,
 ) -> ChatSession:
     chat_session = ChatSession(
         user_id=user_id,
         persona_id=persona_id,
         description=description,
+        llm_override=llm_override,
+        prompt_override=prompt_override,
         one_shot=one_shot,
     )
 
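The or_ change is what lets a seeded, ownerless session be opened by whoever follows the redirect link. A self-contained sketch of the semantics, using SQLite and a simplified stand-in table (ChatSessionStub and the ids here are hypothetical):

from sqlalchemy import Column, Integer, String, create_engine, or_, select
from sqlalchemy.orm import DeclarativeBase, Session


class Base(DeclarativeBase):
    pass


class ChatSessionStub(Base):
    __tablename__ = "chat_session"
    id = Column(Integer, primary_key=True)
    user_id = Column(String, nullable=True)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as db:
    db.add(ChatSessionStub(id=1, user_id=None))  # seeded, unassigned session
    db.commit()

    user_id = "some-user-uuid"
    stmt = select(ChatSessionStub).where(
        # same filter as the updated get_chat_session_by_id
        or_(ChatSessionStub.user_id == user_id, ChatSessionStub.user_id.is_(None))
    )
    session = db.execute(stmt).scalar_one()  # visible despite user_id being NULL

    # the API layer (get_chat_session, later in this commit) then claims it
    session.user_id = user_id
    db.commit()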
backend/danswer/db/enums.py (new file, 35 lines)
@@ -0,0 +1,35 @@
+from enum import Enum as PyEnum
+
+
+class IndexingStatus(str, PyEnum):
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    SUCCESS = "success"
+    FAILED = "failed"
+
+
+# these may differ in the future, which is why we're okay with this duplication
+class DeletionStatus(str, PyEnum):
+    NOT_STARTED = "not_started"
+    IN_PROGRESS = "in_progress"
+    SUCCESS = "success"
+    FAILED = "failed"
+
+
+# Consistent with Celery task statuses
+class TaskStatus(str, PyEnum):
+    PENDING = "PENDING"
+    STARTED = "STARTED"
+    SUCCESS = "SUCCESS"
+    FAILURE = "FAILURE"
+
+
+class IndexModelStatus(str, PyEnum):
+    PAST = "PAST"
+    PRESENT = "PRESENT"
+    FUTURE = "FUTURE"
+
+
+class ChatSessionSharedStatus(str, PyEnum):
+    PUBLIC = "public"
+    PRIVATE = "private"
@@ -35,45 +35,18 @@ from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
 from danswer.connectors.models import InputType
+from danswer.db.enums import ChatSessionSharedStatus
+from danswer.db.enums import IndexingStatus
+from danswer.db.enums import IndexModelStatus
+from danswer.db.enums import TaskStatus
+from danswer.db.pydantic_type import PydanticType
 from danswer.dynamic_configs.interface import JSON_ro
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.search.enums import RecencyBiasSetting
 from danswer.search.enums import SearchType
 
 
-class IndexingStatus(str, PyEnum):
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
-# these may differ in the future, which is why we're okay with this duplication
-class DeletionStatus(str, PyEnum):
-    NOT_STARTED = "not_started"
-    IN_PROGRESS = "in_progress"
-    SUCCESS = "success"
-    FAILED = "failed"
-
-
-# Consistent with Celery task statuses
-class TaskStatus(str, PyEnum):
-    PENDING = "PENDING"
-    STARTED = "STARTED"
-    SUCCESS = "SUCCESS"
-    FAILURE = "FAILURE"
-
-
-class IndexModelStatus(str, PyEnum):
-    PAST = "PAST"
-    PRESENT = "PRESENT"
-    FUTURE = "FUTURE"
-
-
-class ChatSessionSharedStatus(str, PyEnum):
-    PUBLIC = "public"
-    PRIVATE = "private"
-
-
 class Base(DeclarativeBase):
     pass
 
@@ -596,6 +569,20 @@ class ChatSession(Base):
         Enum(ChatSessionSharedStatus, native_enum=False),
         default=ChatSessionSharedStatus.PRIVATE,
     )
+
+    # the latest "overrides" specified by the user. These take precedence over
+    # the attached persona. However, overrides specified directly in the
+    # `send-message` call will take precedence over these.
+    # NOTE: currently only used by the chat seeding flow, will be used in the
+    # future once we allow users to override default values via the Chat UI
+    # itself
+    llm_override: Mapped[LLMOverride | None] = mapped_column(
+        PydanticType(LLMOverride), nullable=True
+    )
+    prompt_override: Mapped[PromptOverride | None] = mapped_column(
+        PydanticType(PromptOverride), nullable=True
+    )
 
     time_updated: Mapped[datetime.datetime] = mapped_column(
         DateTime(timezone=True),
         server_default=func.now(),
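The comment above pins down a three-level precedence: request-level override beats session-level override, which beats the persona default. A minimal sketch of the exact expression used in stream_chat_message_objects (values illustrative):

from pydantic import BaseModel


class LLMOverride(BaseModel):
    # mirrors backend/danswer/llm/override_models.py
    model_provider: str | None = None
    model_version: str | None = None
    temperature: float | None = None


request_override: LLMOverride | None = None  # nothing sent in the send-message call
session_override: LLMOverride | None = LLMOverride(model_version="gpt-4")

# same shape as: new_msg_req.llm_override or chat_session.llm_override
effective = request_override or session_override
assert effective == session_override  # session value wins; persona is the final fallback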
backend/danswer/db/pydantic_type.py (new file, 32 lines)
@@ -0,0 +1,32 @@
+import json
+from typing import Any
+from typing import Optional
+from typing import Type
+
+from pydantic import BaseModel
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.types import TypeDecorator
+
+
+class PydanticType(TypeDecorator):
+    impl = JSONB
+
+    def __init__(
+        self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any
+    ) -> None:
+        super().__init__(*args, **kwargs)
+        self.pydantic_model = pydantic_model
+
+    def process_bind_param(
+        self, value: Optional[BaseModel], dialect: Any
+    ) -> Optional[dict]:
+        if value is not None:
+            return json.loads(value.json())
+        return None
+
+    def process_result_value(
+        self, value: Optional[dict], dialect: Any
+    ) -> Optional[BaseModel]:
+        if value is not None:
+            return self.pydantic_model.parse_obj(value)
+        return None
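A quick round-trip sketch for PydanticType, exercised directly without a database. It assumes the module added above is importable; PromptOverride is redeclared here for self-containment (in the repo it comes from danswer.llm.override_models):

from pydantic import BaseModel

from danswer.db.pydantic_type import PydanticType  # the class defined above


class PromptOverride(BaseModel):
    system_prompt: str | None = None
    task_prompt: str | None = None


pt = PydanticType(PromptOverride)
# bind: model -> JSON-safe dict (what lands in the JSONB column)
stored = pt.process_bind_param(PromptOverride(system_prompt="Be brief."), dialect=None)
assert stored == {"system_prompt": "Be brief.", "task_prompt": None}
# result: dict -> model (what the ORM hands back)
loaded = pt.process_result_value(stored, dialect=None)
assert loaded == PromptOverride(system_prompt="Be brief.")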
@@ -10,9 +10,9 @@ from pydantic import root_validator
 from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.configs.constants import MessageType
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.llm.utils import get_default_llm_version
-from danswer.server.query_and_chat.models import LLMOverride
-from danswer.server.query_and_chat.models import PromptOverride
 
 if TYPE_CHECKING:
     from danswer.db.models import ChatMessage
backend/danswer/llm/override_models.py (new file, 17 lines)
@@ -0,0 +1,17 @@
+"""Overrides sent over the wire / stored in the DB
+
+NOTE: these models are used in many places, so have to be
+kept in a separate file to avoid circular imports.
+"""
+from pydantic import BaseModel
+
+
+class LLMOverride(BaseModel):
+    model_provider: str | None = None
+    model_version: str | None = None
+    temperature: float | None = None
+
+
+class PromptOverride(BaseModel):
+    system_prompt: str | None = None
+    task_prompt: str | None = None
@@ -8,12 +8,16 @@ from sqlalchemy.orm import Session
 from danswer.auth.users import current_user
 from danswer.chat.chat_utils import create_chat_chain
 from danswer.chat.process_message import stream_chat_message
+from danswer.configs.app_configs import WEB_DOMAIN
+from danswer.configs.constants import MessageType
 from danswer.db.chat import create_chat_session
+from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import delete_chat_session
 from danswer.db.chat import get_chat_message
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.chat import get_chat_session_by_id
 from danswer.db.chat import get_chat_sessions_by_user
+from danswer.db.chat import get_or_create_root_message
 from danswer.db.chat import get_persona_by_id
 from danswer.db.chat import set_as_latest_chat_message
 from danswer.db.chat import translate_db_message_to_chat_message_detail
@@ -27,6 +31,7 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
+from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.secondary_llm_flows.chat_session_naming import (
     get_renamed_conversation_name,
 )
@@ -40,6 +45,8 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse
 from danswer.server.query_and_chat.models import ChatSessionUpdateRequest
 from danswer.server.query_and_chat.models import CreateChatMessageRequest
 from danswer.server.query_and_chat.models import CreateChatSessionID
+from danswer.server.query_and_chat.models import LLMOverride
+from danswer.server.query_and_chat.models import PromptOverride
 from danswer.server.query_and_chat.models import RenameChatSessionResponse
 from danswer.server.query_and_chat.models import SearchFeedbackRequest
 from danswer.utils.logger import setup_logger
@@ -93,6 +100,13 @@ def get_chat_session(
     except ValueError:
         raise ValueError("Chat session does not exist or has been deleted")
 
+    # for chat-seeding: if the session is unassigned, assign it now. This is done here
+    # to avoid another back and forth between FE -> BE before starting the first
+    # message generation
+    if chat_session.user_id is None and user_id is not None:
+        chat_session.user_id = user_id
+        db_session.commit()
+
     session_messages = get_chat_messages_by_session(
         chat_session_id=session_id, user_id=user_id, db_session=db_session
     )
@@ -209,15 +223,24 @@ def handle_new_chat_message(
     - Sending a new message in the session
     - Regenerating a message in the session (just send the same one again)
     - Editing a message (similar to regenerating but sending a different message)
+    - Kicking off a seeded chat session (set `use_existing_user_message`)
 
     To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
     have already been set as latest"""
     logger.info(f"Received new chat message: {chat_message_req.message}")
 
-    if not chat_message_req.message and chat_message_req.prompt_id is not None:
+    if (
+        not chat_message_req.message
+        and chat_message_req.prompt_id is not None
+        and not chat_message_req.use_existing_user_message
+    ):
         raise HTTPException(status_code=400, detail="Empty chat message is invalid")
 
-    packets = stream_chat_message(new_msg_req=chat_message_req, user=user)
+    packets = stream_chat_message(
+        new_msg_req=chat_message_req,
+        user=user,
+        use_existing_user_message=chat_message_req.use_existing_user_message,
+    )
 
     return StreamingResponse(packets, media_type="application/json")
 
@@ -308,3 +331,71 @@ def get_max_document_tokens(
     return MaxSelectedDocumentTokens(
         max_tokens=compute_max_document_tokens_for_persona(persona),
     )
+
+
+"""Endpoints for chat seeding"""
+
+
+class ChatSeedRequest(BaseModel):
+    # standard chat session stuff
+    persona_id: int
+    prompt_id: int | None = None
+
+    # overrides / seeding
+    llm_override: LLMOverride | None = None
+    prompt_override: PromptOverride | None = None
+    description: str | None = None
+    message: str | None = None
+
+    # TODO: support this
+    # initial_message_retrieval_options: RetrievalDetails | None = None
+
+
+class ChatSeedResponse(BaseModel):
+    redirect_url: str
+
+
+@router.post("/seed-chat-session")
+def seed_chat(
+    chat_seed_request: ChatSeedRequest,
+    # NOTE: realistically, this will be an API key not an actual user
+    _: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> ChatSeedResponse:
+    try:
+        new_chat_session = create_chat_session(
+            db_session=db_session,
+            description=chat_seed_request.description or "",
+            user_id=None,  # this chat session is "unassigned" until a user visits the web UI
+            persona_id=chat_seed_request.persona_id,
+            llm_override=chat_seed_request.llm_override,
+            prompt_override=chat_seed_request.prompt_override,
+        )
+    except Exception as e:
+        logger.exception(e)
+        raise HTTPException(status_code=400, detail="Invalid Persona provided.")
+
+    if chat_seed_request.message is not None:
+        root_message = get_or_create_root_message(
+            chat_session_id=new_chat_session.id, db_session=db_session
+        )
+        create_new_chat_message(
+            chat_session_id=new_chat_session.id,
+            parent_message=root_message,
+            prompt_id=chat_seed_request.prompt_id
+            or (
+                new_chat_session.persona.prompts[0].id
+                if new_chat_session.persona.prompts
+                else None
+            ),
+            message=chat_seed_request.message,
+            token_count=len(
+                get_default_llm_tokenizer().encode(chat_seed_request.message)
+            ),
+            message_type=MessageType.USER,
+            db_session=db_session,
+        )
+
+    return ChatSeedResponse(
+        redirect_url=f"{WEB_DOMAIN}/chat?chatId={new_chat_session.id}"
+    )
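End to end, seeding then looks like this from a caller's side. A hedged sketch using requests; the base URL, the API key placeholder, the router's /chat prefix, and persona_id=1 are assumptions, not values from the diff:

import requests

resp = requests.post(
    "http://localhost:8080/chat/seed-chat-session",  # assumes the chat router prefix
    headers={"Authorization": "Bearer <API_KEY>"},  # placeholder credential
    json={
        "persona_id": 1,  # illustrative; must be a real persona id
        "message": "What changed in the latest release?",
        "prompt_override": {"system_prompt": "Answer from release notes only."},
        "llm_override": {"temperature": 0.0},
    },
)
resp.raise_for_status()
# the caller hands this URL to the end user; visiting it claims the session and
# the frontend kicks off generation for the seeded message
print(resp.json()["redirect_url"])  # e.g. <WEB_DOMAIN>/chat?chatId=123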
@@ -8,7 +8,9 @@ from danswer.chat.models import RetrievalDocs
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import MessageType
 from danswer.configs.constants import SearchFeedbackType
-from danswer.db.models import ChatSessionSharedStatus
+from danswer.db.enums import ChatSessionSharedStatus
+from danswer.llm.override_models import LLMOverride
+from danswer.llm.override_models import PromptOverride
 from danswer.search.models import BaseFilters
 from danswer.search.models import RetrievalDetails
 from danswer.search.models import SearchDoc
@@ -70,17 +72,6 @@ class DocumentSearchRequest(BaseModel):
     skip_llm_chunk_filter: bool | None = None
 
 
-class LLMOverride(BaseModel):
-    model_provider: str | None = None
-    model_version: str | None = None
-    temperature: float | None = None
-
-
-class PromptOverride(BaseModel):
-    system_prompt: str | None = None
-    task_prompt: str | None = None
-
-
 """
 Currently the different branches are generated by changing the search query
 
@@ -116,6 +107,9 @@ class CreateChatMessageRequest(BaseModel):
     llm_override: LLMOverride | None = None
     prompt_override: PromptOverride | None = None
 
+    # used for seeded chats to kick off the generation of an AI answer
+    use_existing_user_message: bool = False
+
     @root_validator
     def check_search_doc_ids_or_retrieval_options(cls: BaseModel, values: dict) -> dict:
         search_doc_ids, retrieval_options = values.get("search_doc_ids"), values.get(
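When the user then opens the seeded session, the frontend fires a send-message request that exercises the new field. A hedged sketch of the relevant parts of that body (only fields visible in this diff are shown; the remaining CreateChatMessageRequest fields are omitted):

import json

kickoff_body = {
    "message": "",  # ignored: the seeded user message already in the DB is reused
    "prompt_id": None,
    "llm_override": None,  # None falls back to the session-level override
    "prompt_override": None,
    "use_existing_user_message": True,
}
print(json.dumps(kickoff_body))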
@@ -133,7 +133,7 @@ export const Chat = ({
         !submitOnLoadPerformed.current
       ) {
         submitOnLoadPerformed.current = true;
-        onSubmit();
+        await onSubmit();
       }
 
       return;
@@ -162,6 +162,21 @@ export const Chat = ({
       setChatSessionSharedStatus(chatSession.shared_status);
 
       setIsFetchingChatMessages(false);
+
+      // if this is a seeded chat, then kick off the AI message generation
+      if (newMessageHistory.length === 1 && !submitOnLoadPerformed.current) {
+        submitOnLoadPerformed.current = true;
+        const seededMessage = newMessageHistory[0].message;
+        await onSubmit({
+          isSeededChat: true,
+          messageOverride: seededMessage,
+        });
+        // force re-name if the chat session doesn't have one
+        if (!chatSession.description) {
+          await nameChatSession(existingChatSessionId, seededMessage);
+          router.refresh(); // need to refresh to update name on sidebar
+        }
+      }
     }
 
     initialSessionFetch();
@@ -326,11 +341,13 @@ export const Chat = ({
     messageOverride,
     queryOverride,
     forceSearch,
+    isSeededChat,
   }: {
     messageIdToResend?: number;
     messageOverride?: string;
     queryOverride?: string;
     forceSearch?: boolean;
+    isSeededChat?: boolean;
   } = {}) => {
     let currChatSessionId: number;
     let isNewSession = chatSessionId === null;
@@ -419,6 +436,7 @@ export const Chat = ({
           undefined,
         systemPromptOverride:
           searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
+        useExistingUserMessage: isSeededChat,
       })) {
         for (const packet of packetBunch) {
           if (Object.hasOwn(packet, "answer_piece")) {
@@ -57,6 +57,7 @@ export async function* sendMessage({
   modelVersion,
   temperature,
   systemPromptOverride,
+  useExistingUserMessage,
 }: {
   message: string;
   parentMessageId: number | null;
@@ -71,6 +72,9 @@ export async function* sendMessage({
   temperature?: number;
   // prompt overrides
   systemPromptOverride?: string;
+  // if specified, will use the existing latest user message
+  // and will ignore the specified `message`
+  useExistingUserMessage?: boolean;
 }) {
   const documentsAreSelected =
     selectedDocumentIds && selectedDocumentIds.length > 0;
@@ -99,13 +103,19 @@ export async function* sendMessage({
         }
       : null,
     query_override: queryOverride,
-    prompt_override: {
-      system_prompt: systemPromptOverride,
-    },
-    llm_override: {
-      temperature,
-      model_version: modelVersion,
-    },
+    prompt_override: systemPromptOverride
+      ? {
+          system_prompt: systemPromptOverride,
+        }
+      : null,
+    llm_override:
+      temperature || modelVersion
+        ? {
+            temperature,
+            model_version: modelVersion,
+          }
+        : null,
+    use_existing_user_message: useExistingUserMessage,
   }),
 });
 if (!sendMessageResponse.ok) {