No Context Chat Backend (#397)

This commit is contained in:
Yuhong Sun 2023-09-05 22:32:00 -07:00 committed by GitHub
parent 630386c8c4
commit 5977a28f58
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 881 additions and 12 deletions

View File

@ -0,0 +1,85 @@
"""Add Chat Sessions
Revision ID: 5809c0787398
Revises: d929f0c1c6af
Create Date: 2023-09-04 15:29:44.002164
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
# `revision` is this migration's ID; `down_revision` is the migration it revises
# (both values match the "Revision ID" / "Revises" lines of the module docstring).
revision = "5809c0787398"
down_revision = "d929f0c1c6af"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Create the ``chat_session`` table and the ``chat_message`` table that references it."""

    def _now_timestamp(column_name: str) -> sa.Column:
        # Non-null, timezone-aware timestamp that the database defaults to now()
        return sa.Column(
            column_name,
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        )

    # The owning user is optional: the foreign key column is nullable
    owner_column = sa.Column(
        "user_id",
        fastapi_users_db_sqlalchemy.generics.GUID(),
        nullable=True,
    )
    op.create_table(
        "chat_session",
        sa.Column("id", sa.Integer(), nullable=False),
        owner_column,
        sa.Column("description", sa.Text(), nullable=False),
        sa.Column("deleted", sa.Boolean(), nullable=False),
        _now_timestamp("time_updated"),
        _now_timestamp("time_created"),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"]),
        sa.PrimaryKeyConstraint("id"),
    )

    # Declared with native_enum=False, i.e. not as a database-native enum type
    message_type_enum = sa.Enum(
        "SYSTEM",
        "USER",
        "ASSISTANT",
        "DANSWER",
        name="messagetype",
        native_enum=False,
    )
    op.create_table(
        "chat_message",
        sa.Column("chat_session_id", sa.Integer(), nullable=False),
        sa.Column("message_number", sa.Integer(), nullable=False),
        sa.Column("edit_number", sa.Integer(), nullable=False),
        sa.Column("parent_edit_number", sa.Integer(), nullable=True),
        sa.Column("latest", sa.Boolean(), nullable=False),
        sa.Column("message", sa.Text(), nullable=False),
        sa.Column("message_type", message_type_enum, nullable=False),
        _now_timestamp("time_sent"),
        sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]),
        # A message is identified by its session, its position and its edit number
        sa.PrimaryKeyConstraint("chat_session_id", "message_number", "edit_number"),
    )
def downgrade() -> None:
    """Drop the chat tables created by ``upgrade``."""
    # chat_message holds a foreign key to chat_session, so it is dropped first
    for table_name in ("chat_message", "chat_session"):
        op.drop_table(table_name)

View File

@ -24,13 +24,13 @@ def upgrade() -> None:
sa.Column("query", sa.String(), nullable=False),
sa.Column(
"selected_search_flow",
sa.Enum("KEYWORD", "SEMANTIC", name="searchtype"),
sa.Enum("KEYWORD", "SEMANTIC", name="searchtype", native_enum=False),
nullable=True,
),
sa.Column("llm_answer", sa.String(), nullable=True),
sa.Column(
"feedback",
sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype"),
sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype", native_enum=False),
nullable=True,
),
sa.Column(
@ -65,6 +65,7 @@ def upgrade() -> None:
"HIDE",
"UNHIDE",
name="searchfeedbacktype",
native_enum=False,
),
nullable=True,
),

View File

View File

@ -0,0 +1,27 @@
from collections.abc import Iterator
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.llm.build import get_default_llm
def llm_chat_answer(previous_messages: list[ChatMessage]) -> Iterator[str]:
    """Stream the default LLM's reply to a chat history.

    Each stored message is converted to the Langchain message class matching
    its type, preserving order. Messages whose type matches none of the
    handled types are left out of the prompt.
    """
    langchain_history: list[BaseMessage] = []
    for chat_msg in previous_messages:
        text = chat_msg.message
        kind = chat_msg.message_type
        if kind == MessageType.SYSTEM:
            langchain_history.append(SystemMessage(content=text))
        elif kind == MessageType.ASSISTANT:
            langchain_history.append(AIMessage(content=text))
        elif kind in (MessageType.USER, MessageType.DANSWER):
            # DANSWER messages are sent as human messages for now
            # (consider using FunctionMessage)
            langchain_history.append(HumanMessage(content=text))

    return get_default_llm().stream(langchain_history)

View File

@ -149,6 +149,7 @@ QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10") # 10 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
# Defaults to True when the variable is unset; only the value "false"
# (compared case-insensitively) turns it off.
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
#####

View File

@ -83,3 +83,11 @@ class SearchFeedbackType(str, Enum):
REJECT = "reject" # down-boost this document for all future queries
HIDE = "hide" # mark this document as untrusted, hide from LLM
UNHIDE = "unhide"
class MessageType(str, Enum):
    """Kind of a chat message, with values following the OpenAI role names.

    The Langchain message class equivalent is noted beside each member.
    """

    SYSTEM = "system"  # Langchain: SystemMessage
    USER = "user"  # Langchain: HumanMessage
    ASSISTANT = "assistant"  # Langchain: AIMessage
    DANSWER = "danswer"  # Langchain: FunctionMessage

247
backend/danswer/db/chat.py Normal file
View File

@ -0,0 +1,247 @@
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from danswer.configs.app_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
def fetch_chat_sessions_by_user(
    user_id: UUID | None,
    deleted: bool | None,
    db_session: Session,
) -> list[ChatSession]:
    """Return the chat sessions whose ``user_id`` equals the given value.

    When ``deleted`` is not None, only sessions whose ``deleted`` flag equals
    it are returned; passing None applies no filter on the flag.
    """
    criteria = [ChatSession.user_id == user_id]
    if deleted is not None:
        criteria.append(ChatSession.deleted == deleted)

    query = select(ChatSession).where(*criteria)
    return list(db_session.execute(query).scalars().all())
def fetch_chat_messages_by_session(
    chat_session_id: int, db_session: Session
) -> list[ChatMessage]:
    """Return every message of a session, ordered by message number and then edit number (both ascending)."""
    ordering = (ChatMessage.message_number.asc(), ChatMessage.edit_number.asc())
    query = (
        select(ChatMessage)
        .where(ChatMessage.chat_session_id == chat_session_id)
        .order_by(*ordering)
    )
    messages = db_session.execute(query).scalars().all()
    return list(messages)
def fetch_chat_message(
    chat_session_id: int, message_number: int, edit_number: int, db_session: Session
) -> ChatMessage:
    """Fetch one message by its full key, with its chat session loaded alongside.

    Raises:
        ValueError: if no message matches the given key.
    """
    query = select(ChatMessage).where(
        ChatMessage.chat_session_id == chat_session_id,
        ChatMessage.message_number == message_number,
        ChatMessage.edit_number == edit_number,
    )
    # Load the owning session with the message so callers can inspect it
    query = query.options(selectinload(ChatMessage.chat_session))

    found = db_session.execute(query).scalar_one_or_none()
    if found:
        return found

    raise ValueError("Invalid Chat Message specified")
def fetch_chat_session_by_id(chat_session_id: int, db_session: Session) -> ChatSession:
    """Fetch a chat session by ID.

    Raises:
        ValueError: if no session has the given ID.
    """
    query = select(ChatSession).where(ChatSession.id == chat_session_id)
    found = db_session.execute(query).scalar_one_or_none()
    if found:
        return found

    raise ValueError("Invalid Chat Session ID provided")
def verify_parent_exists(
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    db_session: Session,
) -> ChatMessage:
    """Return the parent of the message at ``message_number``.

    The parent is the message in the same session at ``message_number - 1``
    whose edit number equals ``parent_edit_number``.

    Raises:
        ValueError: if no such parent message exists. The underlying
            ``NoResultFound`` is attached as the explicit cause.
    """
    stmt = select(ChatMessage).where(
        (ChatMessage.chat_session_id == chat_session_id)
        & (ChatMessage.message_number == message_number - 1)
        & (ChatMessage.edit_number == parent_edit_number)
    )

    result = db_session.execute(stmt)

    try:
        return result.scalar_one()
    except NoResultFound as err:
        # Chain explicitly so the traceback reads as a deliberate translation
        # rather than an error raised while handling another error
        raise ValueError("Invalid message, parent message not found") from err
def create_chat_session(
    description: str, user_id: UUID | None, db_session: Session
) -> ChatSession:
    """Persist a new chat session for the given user and return it (commits the session)."""
    new_session = ChatSession(description=description, user_id=user_id)
    db_session.add(new_session)
    db_session.commit()
    return new_session
def update_chat_session(
    user_id: UUID | None, chat_session_id: int, description: str, db_session: Session
) -> ChatSession:
    """Set a new description on a chat session and commit.

    Raises:
        ValueError: if the session does not exist, is marked deleted, or
            belongs to a different user than ``user_id``.
    """
    target = fetch_chat_session_by_id(chat_session_id, db_session)

    # The deleted check runs before the ownership check
    if target.deleted:
        raise ValueError("Trying to rename a deleted chat session")

    if target.user_id != user_id:
        raise ValueError("User trying to update chat of another user.")

    target.description = description
    db_session.commit()
    return target
def delete_chat_session(
    user_id: UUID | None,
    chat_session_id: int,
    db_session: Session,
    hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
    """Delete a chat session owned by ``user_id`` and commit.

    With ``hard_delete`` the session's messages and then the session row are
    removed; otherwise the session is only flagged as deleted.

    Raises:
        ValueError: if the session does not exist or belongs to another user.
    """
    target = fetch_chat_session_by_id(chat_session_id, db_session)
    if target.user_id != user_id:
        raise ValueError("User trying to delete chat of another user.")

    if not hard_delete:
        # Soft delete: keep the rows, just mark the session
        target.deleted = True
    else:
        # Messages are removed before the session they point to
        db_session.execute(
            delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
        )
        db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))

    db_session.commit()
def _set_latest_chat_message_no_commit(
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    edit_number: int,
    db_session: Session,
) -> None:
    """Mark one edit as the latest among the messages sharing its parent.

    Issues the two updates without committing; the caller is responsible for
    the commit.

    Raises:
        ValueError: if a non-initial message is given no parent edit number.
    """
    if parent_edit_number is None and message_number != 0:
        raise ValueError(
            "Only initial message in a chat is allowed to not have a parent"
        )

    def _messages_at_position():
        # Fresh query over all edits at this position of this session
        return db_session.query(ChatMessage).filter(
            ChatMessage.chat_session_id == chat_session_id,
            ChatMessage.message_number == message_number,
        )

    # First unflag every edit that hangs off the same parent ...
    _messages_at_position().filter(
        ChatMessage.parent_edit_number == parent_edit_number
    ).update({ChatMessage.latest: False})

    # ... then flag the requested edit
    _messages_at_position().filter(ChatMessage.edit_number == edit_number).update(
        {ChatMessage.latest: True}
    )
def create_new_chat_message(
    chat_session_id: int,
    message_number: int,
    message: str,
    parent_edit_number: int | None,
    message_type: MessageType,
    db_session: Session,
) -> ChatMessage:
    """Add a message as a new edit at ``message_number`` and make it the latest for its parent.

    The new edit number is one above the highest existing edit number at that
    position, or 0 when the position has no messages yet. Commits the session.
    """
    # Highest edit number already stored at this position (None if there is none)
    highest_edit = (
        db_session.query(func.max(ChatMessage.edit_number))
        .filter(
            ChatMessage.chat_session_id == chat_session_id,
            ChatMessage.message_number == message_number,
        )
        .scalar()
    )

    if highest_edit is None:
        new_edit_number = 0
    else:
        new_edit_number = highest_edit + 1

    created = ChatMessage(
        chat_session_id=chat_session_id,
        message_number=message_number,
        parent_edit_number=parent_edit_number,
        edit_number=new_edit_number,
        message=message,
        message_type=message_type,
    )
    db_session.add(created)

    # Unflag the previous latest message under the same parent and flag the new one
    _set_latest_chat_message_no_commit(
        chat_session_id=chat_session_id,
        message_number=message_number,
        parent_edit_number=parent_edit_number,
        edit_number=new_edit_number,
        db_session=db_session,
    )

    db_session.commit()
    return created
def set_latest_chat_message(
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    edit_number: int,
    db_session: Session,
) -> None:
    """Mark the given edit as the latest among messages sharing its parent, then commit."""
    _set_latest_chat_message_no_commit(
        chat_session_id,
        message_number,
        parent_edit_number,
        edit_number,
        db_session,
    )
    db_session.commit()

View File

@ -22,7 +22,7 @@ def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent:
query_event = result.scalar_one_or_none()
if not query_event:
raise ValueError("Invalid Query Event provided for updating")
raise ValueError("Invalid Query Event ID Provided")
return query_event
@ -33,7 +33,7 @@ def fetch_docs_by_id(doc_id: str, db_session: Session) -> DbDocument:
doc = result.scalar_one_or_none()
if not doc:
raise ValueError("Invalid Document provided for updating")
raise ValueError("Invalid Document ID Provided")
return doc

View File

@ -13,6 +13,7 @@ from sqlalchemy import DateTime
from sqlalchemy import Enum
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import String
from sqlalchemy import Text
@ -25,6 +26,7 @@ from sqlalchemy.orm import relationship
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
@ -52,7 +54,7 @@ class Base(DeclarativeBase):
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
# even an almost empty token from keycloak will not fit the default 1024 bytes
access_token: Mapped[str] = mapped_column(Text(), nullable=False) # type: ignore
access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore
class User(SQLAlchemyBaseUserTableUUID, Base):
@ -68,6 +70,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
query_events: Mapped[List["QueryEvent"]] = relationship(
"QueryEvent", back_populates="user"
)
chat_sessions: Mapped[List["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@ -193,7 +198,7 @@ class IndexAttempt(Base):
status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
num_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
error_msg: Mapped[str | None] = mapped_column(
String(), default=None
Text, default=None
) # only filled if status = "failed"
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
@ -217,6 +222,15 @@ class IndexAttempt(Base):
"Credential", back_populates="index_attempts"
)
__table_args__ = (
Index(
"ix_index_attempt_latest_for_connector_credential_pair",
"connector_id",
"credential_id",
"time_created",
),
)
def __repr__(self) -> str:
return (
f"<IndexAttempt(id={self.id!r}, "
@ -245,7 +259,7 @@ class DeletionAttempt(Base):
status: Mapped[DeletionStatus] = mapped_column(Enum(DeletionStatus))
num_docs_deleted: Mapped[int] = mapped_column(Integer, default=0)
error_msg: Mapped[str | None] = mapped_column(
String(), default=None
Text, default=None
) # only filled if status = "failed"
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
@ -291,12 +305,12 @@ class QueryEvent(Base):
__tablename__ = "query_event"
id: Mapped[int] = mapped_column(primary_key=True)
query: Mapped[str] = mapped_column(String())
query: Mapped[str] = mapped_column(Text)
# search_flow refers to user selection, None if user used auto
selected_search_flow: Mapped[SearchType | None] = mapped_column(
Enum(SearchType), nullable=True
)
llm_answer: Mapped[str | None] = mapped_column(String(), default=None)
llm_answer: Mapped[str | None] = mapped_column(Text, default=None)
feedback: Mapped[QAFeedbackType | None] = mapped_column(
Enum(QAFeedbackType), nullable=True
)
@ -340,8 +354,8 @@ class DocumentRetrievalFeedback(Base):
class Document(Base):
__tablename__ = "document"
# this should correspond to the ID of the document (as is passed around
# in Danswer)
# this should correspond to the ID of the document
# (as is passed around in Danswer)
id: Mapped[str] = mapped_column(String, primary_key=True)
# 0 for neutral, positive for mostly endorse, negative for mostly reject
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
@ -354,3 +368,46 @@ class Document(Base):
retrieval_feedbacks: Mapped[List[DocumentRetrievalFeedback]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
class ChatSession(Base):
    """A chat conversation, optionally tied to a user, holding a list of messages."""

    __tablename__ = "chat_session"

    id: Mapped[int] = mapped_column(primary_key=True)
    # Nullable: a session may exist without an associated user row
    user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
    description: Mapped[str] = mapped_column(Text)
    # Soft-delete marker; new sessions start as not deleted
    deleted: Mapped[bool] = mapped_column(Boolean, default=False)
    # Defaulted by the database on insert and refreshed on update
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    user: Mapped[User] = relationship("User", back_populates="chat_sessions")
    # Deleting a session through the ORM cascades the delete to its messages
    messages: Mapped[List["ChatMessage"]] = relationship(
        "ChatMessage", back_populates="chat_session", cascade="delete"
    )
class ChatMessage(Base):
    """One message, or one edit of a message, inside a chat session.

    A row is identified by the composite key
    (chat_session_id, message_number, edit_number).
    """

    __tablename__ = "chat_message"

    chat_session_id: Mapped[int] = mapped_column(
        ForeignKey("chat_session.id"), primary_key=True
    )
    message_number: Mapped[int] = mapped_column(Integer, primary_key=True)
    edit_number: Mapped[int] = mapped_column(Integer, default=0, primary_key=True)
    parent_edit_number: Mapped[int | None] = mapped_column(
        Integer, nullable=True
    )  # null if first message
    latest: Mapped[bool] = mapped_column(Boolean, default=True)
    message: Mapped[str] = mapped_column(Text)
    # NOTE(review): the Alembic migration creating this table declares the enum
    # with native_enum=False, while this mapping uses the Enum default - confirm
    # the two are meant to differ.
    message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
    time_sent: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # ChatSession.messages declares back_populates="chat_session"; naming the
    # reverse side here as well keeps both attributes in sync in both directions.
    chat_session: Mapped[ChatSession] = relationship(
        "ChatSession", back_populates="messages"
    )

View File

@ -30,6 +30,7 @@ from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.datastores.document_index import get_default_document_index
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.server.chat_backend import router as chat_router
from danswer.server.credential import router as credential_router
from danswer.server.event_loading import router as event_processing_router
from danswer.server.health import router as health_router
@ -66,6 +67,7 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
def get_application() -> FastAPI:
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
application.include_router(backend_router)
application.include_router(chat_router)
application.include_router(event_processing_router)
application.include_router(admin_router)
application.include_router(user_router)

View File

@ -0,0 +1,19 @@
from danswer.llm.build import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
def get_chat_name_messages(user_query: str) -> list[dict[str, str]]:
    """Build the role/content prompt asking for a short chat name based on ``user_query``."""
    naming_instruction = (
        "Give a short name for this chat session based on the user's first message."
    )
    return [
        dict(role="system", content=naming_instruction),
        dict(role="user", content=user_query),
    ]
def get_new_chat_name(user_query: str) -> str:
    """Ask the default LLM to name a chat session based on the user's first message."""
    langchain_prompt = dict_based_prompt_to_langchain_prompt(
        get_chat_name_messages(user_query)
    )
    llm = get_default_llm()
    return llm.invoke(langchain_prompt)

View File

@ -0,0 +1,373 @@
from collections.abc import Iterator
from dataclasses import asdict
from fastapi import APIRouter
from fastapi import Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chat.chat_llm import llm_chat_answer
from danswer.configs.constants import MessageType
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import delete_chat_session
from danswer.db.chat import fetch_chat_message
from danswer.db.chat import fetch_chat_messages_by_session
from danswer.db.chat import fetch_chat_session_by_id
from danswer.db.chat import fetch_chat_sessions_by_user
from danswer.db.chat import set_latest_chat_message
from danswer.db.chat import update_chat_session
from danswer.db.chat import verify_parent_exists
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name
from danswer.server.models import ChatMessageDetail
from danswer.server.models import ChatMessageIdentifier
from danswer.server.models import ChatRenameRequest
from danswer.server.models import ChatSessionDetailResponse
from danswer.server.models import ChatSessionIdsResponse
from danswer.server.models import CreateChatID
from danswer.server.models import CreateChatRequest
from danswer.server.models import RenameChatSessionResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
# Module-level logger
logger = setup_logger()
# Every route registered on this router is served under the /chat prefix
router = APIRouter(prefix="/chat")
@router.get("/get-user-chat-sessions")
def get_user_chat_sessions(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionIdsResponse:
    """List the IDs of the current user's chat sessions that are not deleted."""
    user_id = None if user is None else user.id

    # Deleted chats are never included, even when they were only soft-deleted
    visible_sessions = fetch_chat_sessions_by_user(
        user_id=user_id, deleted=False, db_session=db_session
    )

    session_ids = [chat_session.id for chat_session in visible_sessions]
    return ChatSessionIdsResponse(sessions=session_ids)
@router.get("/get-chat-session/{session_id}")
def get_chat_session_messages(
    session_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
    """Return a chat session's description together with all of its messages.

    Raises:
        ValueError: if the session cannot be found or is marked deleted
            (both cases report the same message).
        PermissionError: if the session belongs to a different user.
    """
    user_id = None if user is None else user.id

    try:
        session = fetch_chat_session_by_id(session_id, db_session)
    except ValueError:
        # A missing session is reported exactly like a deleted one
        raise ValueError("Chat Session has been deleted")

    if session.deleted:
        raise ValueError("Chat Session has been deleted")

    if session.user_id != user_id:
        denial = (
            "The No-Auth User is trying to read a different user's chat"
            if user is None
            else f"User {user.email} is trying to read a different user's chat"
        )
        raise PermissionError(denial)

    stored_messages = fetch_chat_messages_by_session(
        chat_session_id=session_id, db_session=db_session
    )

    message_details = []
    for stored in stored_messages:
        message_details.append(
            ChatMessageDetail(
                message_number=stored.message_number,
                edit_number=stored.edit_number,
                parent_edit_number=stored.parent_edit_number,
                latest=stored.latest,
                message=stored.message,
                message_type=stored.message_type,
                time_sent=stored.time_sent,
            )
        )

    return ChatSessionDetailResponse(
        chat_session_id=session_id,
        description=session.description,
        messages=message_details,
    )
@router.post("/create-chat-session")
def create_new_chat_session(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CreateChatID:
    """Create an unnamed chat session for the current user and return its ID."""
    user_id = None if user is None else user.id

    # The description starts empty; naming is left till later to prevent delay
    created = create_chat_session(
        description="", user_id=user_id, db_session=db_session
    )
    return CreateChatID(chat_session_id=created.id)
@router.put("/rename-chat-session")
def rename_chat_session(
    rename: ChatRenameRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RenameChatSessionResponse:
    """Rename a chat session, generating a name from the first message when none is supplied.

    Raises:
        ValueError: if the request carries neither a name nor a first message.
    """
    requested_name = rename.name
    first_message = rename.first_message
    user_id = None if user is None else user.id

    if not (requested_name or first_message):
        raise ValueError("Can't assign a name for the chat without context")

    # An explicit name wins; otherwise a name is generated from the first message
    if requested_name:
        new_name = requested_name
    else:
        new_name = get_new_chat_name(str(first_message))

    update_chat_session(
        user_id=user_id,
        chat_session_id=rename.chat_session_id,
        description=new_name,
        db_session=db_session,
    )
    return RenameChatSessionResponse(new_name=new_name)
@router.delete("/delete-chat-session/{session_id}")
def delete_chat_session_by_id(
    session_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete one of the current user's chat sessions via ``delete_chat_session``."""
    user_id = None if user is None else user.id
    delete_chat_session(
        user_id=user_id, chat_session_id=session_id, db_session=db_session
    )
def _create_chat_chain(
    chat_session_id: int,
    db_session: Session,
    stop_after: int | None = None,
) -> list[ChatMessage]:
    """Trace the mainline of a chat: the path of ``latest`` messages from message 0 on.

    A message joins the chain when it sits at the next expected message
    number, its parent edit number matches the edit number of the previously
    chosen message (None for the first), and it is flagged latest. When
    ``stop_after`` is given, tracing stops once the message at that number
    has been added.

    Raises:
        RuntimeError: if no message could be placed on the chain.
    """
    chain: list[ChatMessage] = []
    expected_number = 0
    expected_parent_edit = None

    # Relies on the messages arriving ordered by message_number
    # (fetch_chat_messages_by_session ensures this, so no re-sorting is needed)
    for candidate in fetch_chat_messages_by_session(chat_session_id, db_session):
        on_mainline = (
            candidate.message_number == expected_number
            and candidate.parent_edit_number == expected_parent_edit
            and candidate.latest
        )
        if not on_mainline:
            continue

        chain.append(candidate)
        expected_parent_edit = candidate.edit_number
        expected_number += 1

        if stop_after is not None and expected_number > stop_after:
            break

    if not chain:
        raise RuntimeError("Could not trace chat message history")

    return chain
@router.post("/send-message")
def handle_new_chat_message(
    chat_message: CreateChatRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    """Store a user message and stream back the LLM's reply.

    This endpoint is both used for sending new messages and for sending edited messages.
    To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
    have already been set as latest.

    Once the stream has been fully consumed, the accumulated reply is stored
    as an assistant message whose parent is the new user message.

    Raises:
        ValueError: if the session is deleted or missing, or the
            message_number / parent_edit_number combination is invalid.
        PermissionError: if the session belongs to a different user.
        RuntimeError: if the stored message does not end the traced mainline.
    """
    chat_session_id = chat_message.chat_session_id
    message_number = chat_message.message_number
    message_content = chat_message.message
    parent_edit_number = chat_message.parent_edit_number
    user_id = user.id if user is not None else None

    chat_session = fetch_chat_session_by_id(chat_session_id, db_session)

    if chat_session.deleted:
        raise ValueError("Cannot send messages to a deleted chat session")

    if chat_session.user_id != user_id:
        if user is None:
            raise PermissionError(
                "The No-Auth User trying to interact with a different user's chat"
            )
        raise PermissionError(
            f"User {user.email} trying to interact with a different user's chat"
        )

    if message_number != 0:
        if parent_edit_number is None:
            raise ValueError("Message must have a valid parent message")

        verify_parent_exists(
            chat_session_id=chat_session_id,
            message_number=message_number,
            parent_edit_number=parent_edit_number,
            db_session=db_session,
        )
    else:
        if parent_edit_number is not None:
            raise ValueError("Initial message in session cannot have parent")

    # Create new message at the right place in the tree and label it latest for its parent
    new_message = create_new_chat_message(
        chat_session_id=chat_session_id,
        message_number=message_number,
        parent_edit_number=parent_edit_number,
        message=message_content,
        message_type=MessageType.USER,
        db_session=db_session,
    )

    mainline_messages = _create_chat_chain(
        chat_session_id,
        db_session,
    )

    # Compare by message identity rather than by text: a different message that
    # merely has the same content must not pass for the newly stored one.
    last_mainline = mainline_messages[-1]
    last_key = (last_mainline.message_number, last_mainline.edit_number)
    new_key = (new_message.message_number, new_message.edit_number)
    if last_key != new_key:
        raise RuntimeError(
            "The new message was not on the mainline. "
            "Be sure to update latests before calling this."
        )

    @log_generator_function_time()
    def stream_chat_tokens() -> Iterator[str]:
        tokens = llm_chat_answer(mainline_messages)
        llm_output = ""
        for token in tokens:
            llm_output += token
            yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))

        # Runs only after every token has been yielded: persist the full reply
        create_new_chat_message(
            chat_session_id=chat_session_id,
            message_number=message_number + 1,
            parent_edit_number=new_message.edit_number,
            message=llm_output,
            message_type=MessageType.ASSISTANT,
            db_session=db_session,
        )

    return StreamingResponse(stream_chat_tokens(), media_type="application/json")
@router.post("/regenerate-from-parent")
def regenerate_message_given_parent(
    parent_message: ChatMessageIdentifier,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    """Stream a fresh LLM response to the given parent message.

    The parent is first marked as latest. After the stream has been fully
    consumed, the new response is stored as the message following the parent,
    which makes it the latest reply to it.

    Raises:
        ValueError: if the message does not exist or its session is deleted.
        PermissionError: if the session belongs to a different user.
    """
    chat_session_id = parent_message.chat_session_id
    message_number = parent_message.message_number
    edit_number = parent_message.edit_number
    user_id = None if user is None else user.id

    parent = fetch_chat_message(
        chat_session_id=chat_session_id,
        message_number=message_number,
        edit_number=edit_number,
        db_session=db_session,
    )

    owning_session = parent.chat_session
    if owning_session.deleted:
        raise ValueError("Chat session has been deleted")

    if owning_session.user_id != user_id:
        denial = (
            "The No-Auth User trying to regenerate chat messages of another user"
            if user is None
            else f"User {user.email} trying to regenerate chat messages of another user"
        )
        raise PermissionError(denial)

    set_latest_chat_message(
        chat_session_id=chat_session_id,
        message_number=message_number,
        parent_edit_number=parent.parent_edit_number,
        edit_number=edit_number,
        db_session=db_session,
    )

    # The parent, now latest, may already have follow-on messages;
    # those must not be part of the context sent to the LLM
    history = _create_chat_chain(
        chat_session_id, db_session, stop_after=message_number
    )

    @log_generator_function_time()
    def stream_regenerate_tokens() -> Iterator[str]:
        answer_so_far = ""
        for piece in llm_chat_answer(history):
            answer_so_far += piece
            yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=piece)))

        # Reached only once all pieces were yielded: store the complete answer
        create_new_chat_message(
            chat_session_id=chat_session_id,
            message_number=message_number + 1,
            parent_edit_number=edit_number,
            message=answer_so_far,
            message_type=MessageType.ASSISTANT,
            db_session=db_session,
        )

    return StreamingResponse(stream_regenerate_tokens(), media_type="application/json")
@router.put("/set-message-as-latest")
def set_message_as_latest(
    message_identifier: ChatMessageIdentifier,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Mark the identified message as the latest among messages sharing its parent.

    Raises:
        ValueError: if the message does not exist or its session is deleted.
        PermissionError: if the session belongs to a different user.
    """
    user_id = None if user is None else user.id

    target = fetch_chat_message(
        chat_session_id=message_identifier.chat_session_id,
        message_number=message_identifier.message_number,
        edit_number=message_identifier.edit_number,
        db_session=db_session,
    )

    owning_session = target.chat_session
    if owning_session.deleted:
        raise ValueError("Chat session has been deleted")

    if owning_session.user_id != user_id:
        denial = (
            "The No-Auth User trying to update chat messages of another user"
            if user is None
            else f"User {user.email} trying to update chat messages of another user"
        )
        raise PermissionError(denial)

    set_latest_chat_message(
        chat_session_id=target.chat_session_id,
        message_number=target.message_number,
        parent_edit_number=target.parent_edit_number,
        edit_number=target.edit_number,
        db_session=db_session,
    )

View File

@ -11,6 +11,7 @@ from pydantic.generics import GenericModel
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
@ -130,6 +131,10 @@ class SearchDoc(BaseModel):
score: float | None
class CreateChatID(BaseModel):
    """Response body carrying the ID of a newly created chat session."""

    chat_session_id: int
class QuestionRequest(BaseModel):
query: str
collection: str
@ -151,6 +156,49 @@ class SearchFeedbackRequest(BaseModel):
search_feedback: SearchFeedbackType
class CreateChatRequest(BaseModel):
    """Request body for sending a new or edited message to a chat session."""

    chat_session_id: int
    # Position of the message in the chat; 0 is the initial message
    message_number: int
    # Edit number of the parent message; None for the initial message
    parent_edit_number: int | None
    message: str
class ChatMessageIdentifier(BaseModel):
    """The three-part key identifying one stored chat message."""

    chat_session_id: int
    message_number: int
    edit_number: int
class ChatRenameRequest(BaseModel):
    """Request body for renaming a chat session."""

    chat_session_id: int
    # Explicit new name, if the caller supplies one
    name: str | None
    # First message of the chat, for generating a name when none is supplied
    first_message: str | None
class RenameChatSessionResponse(BaseModel):
    """Response body reporting the name a chat session ended up with."""

    # This is only really useful if the name is generated
    new_name: str
class ChatSessionIdsResponse(BaseModel):
    """Response body listing chat session IDs."""

    sessions: list[int]
class ChatMessageDetail(BaseModel):
    """Details of a single chat message as returned to API clients."""

    message_number: int
    edit_number: int
    # None when the message has no parent
    parent_edit_number: int | None
    latest: bool
    message: str
    message_type: MessageType
    time_sent: datetime
class ChatSessionDetailResponse(BaseModel):
    """Response body describing a chat session together with its messages."""

    chat_session_id: int
    description: str
    messages: list[ChatMessageDetail]
class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool

View File

@ -1,6 +1,7 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TypeVar
@ -10,7 +11,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator])
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
def log_function_time(