mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-10 03:40:35 +02:00
No Context Chat Backend (#397)
This commit is contained in:
parent
630386c8c4
commit
5977a28f58
85
backend/alembic/versions/5809c0787398_add_chat_sessions.py
Normal file
85
backend/alembic/versions/5809c0787398_add_chat_sessions.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""Add Chat Sessions
|
||||
|
||||
Revision ID: 5809c0787398
|
||||
Revises: d929f0c1c6af
|
||||
Create Date: 2023-09-04 15:29:44.002164
|
||||
|
||||
"""
|
||||
import fastapi_users_db_sqlalchemy
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "5809c0787398"
|
||||
down_revision = "d929f0c1c6af"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chat_session",
|
||||
sa.Column("id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("description", sa.Text(), nullable=False),
|
||||
sa.Column("deleted", sa.Boolean(), nullable=False),
|
||||
sa.Column(
|
||||
"time_updated",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_created",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["user.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
op.create_table(
|
||||
"chat_message",
|
||||
sa.Column("chat_session_id", sa.Integer(), nullable=False),
|
||||
sa.Column("message_number", sa.Integer(), nullable=False),
|
||||
sa.Column("edit_number", sa.Integer(), nullable=False),
|
||||
sa.Column("parent_edit_number", sa.Integer(), nullable=True),
|
||||
sa.Column("latest", sa.Boolean(), nullable=False),
|
||||
sa.Column("message", sa.Text(), nullable=False),
|
||||
sa.Column(
|
||||
"message_type",
|
||||
sa.Enum(
|
||||
"SYSTEM",
|
||||
"USER",
|
||||
"ASSISTANT",
|
||||
"DANSWER",
|
||||
name="messagetype",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"time_sent",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.text("now()"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["chat_session_id"],
|
||||
["chat_session.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("chat_session_id", "message_number", "edit_number"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("chat_message")
|
||||
op.drop_table("chat_session")
|
@ -24,13 +24,13 @@ def upgrade() -> None:
|
||||
sa.Column("query", sa.String(), nullable=False),
|
||||
sa.Column(
|
||||
"selected_search_flow",
|
||||
sa.Enum("KEYWORD", "SEMANTIC", name="searchtype"),
|
||||
sa.Enum("KEYWORD", "SEMANTIC", name="searchtype", native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("llm_answer", sa.String(), nullable=True),
|
||||
sa.Column(
|
||||
"feedback",
|
||||
sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype"),
|
||||
sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype", native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
@ -65,6 +65,7 @@ def upgrade() -> None:
|
||||
"HIDE",
|
||||
"UNHIDE",
|
||||
name="searchfeedbacktype",
|
||||
native_enum=False,
|
||||
),
|
||||
nullable=True,
|
||||
),
|
||||
|
0
backend/danswer/chat/__init__.py
Normal file
0
backend/danswer/chat/__init__.py
Normal file
27
backend/danswer/chat/chat_llm.py
Normal file
27
backend/danswer/chat/chat_llm.py
Normal file
@ -0,0 +1,27 @@
|
||||
from collections.abc import Iterator
|
||||
|
||||
from langchain.schema.messages import AIMessage
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import HumanMessage
|
||||
from langchain.schema.messages import SystemMessage
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.build import get_default_llm
|
||||
|
||||
|
||||
def llm_chat_answer(previous_messages: list[ChatMessage]) -> Iterator[str]:
|
||||
prompt: list[BaseMessage] = []
|
||||
for msg in previous_messages:
|
||||
content = msg.message
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
prompt.append(SystemMessage(content=content))
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
prompt.append(AIMessage(content=content))
|
||||
if (
|
||||
msg.message_type == MessageType.USER
|
||||
or msg.message_type == MessageType.DANSWER # consider using FunctionMessage
|
||||
):
|
||||
prompt.append(HumanMessage(content=content))
|
||||
|
||||
return get_default_llm().stream(prompt)
|
@ -149,6 +149,7 @@ QUOTE_ALLOWED_ERROR_PERCENT = 0.05
|
||||
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "10") # 10 seconds
|
||||
# Include additional document/chunk metadata in prompt to GenerativeAI
|
||||
INCLUDE_METADATA = False
|
||||
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
|
||||
|
||||
|
||||
#####
|
||||
|
@ -83,3 +83,11 @@ class SearchFeedbackType(str, Enum):
|
||||
REJECT = "reject" # down-boost this document for all future queries
|
||||
HIDE = "hide" # mark this document as untrusted, hide from LLM
|
||||
UNHIDE = "unhide"
|
||||
|
||||
|
||||
class MessageType(str, Enum):
|
||||
# Using OpenAI standards, Langchain equivalent shown in comment
|
||||
SYSTEM = "system" # SystemMessage
|
||||
USER = "user" # HumanMessage
|
||||
ASSISTANT = "assistant" # AIMessage
|
||||
DANSWER = "danswer" # FunctionMessage
|
||||
|
247
backend/danswer/db/chat.py
Normal file
247
backend/danswer/db/chat.py
Normal file
@ -0,0 +1,247 @@
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import NoResultFound
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import HARD_DELETE_CHATS
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import ChatSession
|
||||
|
||||
|
||||
def fetch_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
return list(chat_sessions)
|
||||
|
||||
|
||||
def fetch_chat_messages_by_session(
|
||||
chat_session_id: int, db_session: Session
|
||||
) -> list[ChatMessage]:
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(ChatMessage.chat_session_id == chat_session_id)
|
||||
.order_by(ChatMessage.message_number.asc(), ChatMessage.edit_number.asc())
|
||||
)
|
||||
result = db_session.execute(stmt).scalars().all()
|
||||
return list(result)
|
||||
|
||||
|
||||
def fetch_chat_message(
|
||||
chat_session_id: int, message_number: int, edit_number: int, db_session: Session
|
||||
) -> ChatMessage:
|
||||
stmt = (
|
||||
select(ChatMessage)
|
||||
.where(
|
||||
(ChatMessage.chat_session_id == chat_session_id)
|
||||
& (ChatMessage.message_number == message_number)
|
||||
& (ChatMessage.edit_number == edit_number)
|
||||
)
|
||||
.options(selectinload(ChatMessage.chat_session))
|
||||
)
|
||||
|
||||
chat_message = db_session.execute(stmt).scalar_one_or_none()
|
||||
|
||||
if not chat_message:
|
||||
raise ValueError("Invalid Chat Message specified")
|
||||
|
||||
return chat_message
|
||||
|
||||
|
||||
def fetch_chat_session_by_id(chat_session_id: int, db_session: Session) -> ChatSession:
|
||||
stmt = select(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
result = db_session.execute(stmt)
|
||||
chat_session = result.scalar_one_or_none()
|
||||
|
||||
if not chat_session:
|
||||
raise ValueError("Invalid Chat Session ID provided")
|
||||
|
||||
return chat_session
|
||||
|
||||
|
||||
def verify_parent_exists(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
parent_edit_number: int | None,
|
||||
db_session: Session,
|
||||
) -> ChatMessage:
|
||||
stmt = select(ChatMessage).where(
|
||||
(ChatMessage.chat_session_id == chat_session_id)
|
||||
& (ChatMessage.message_number == message_number - 1)
|
||||
& (ChatMessage.edit_number == parent_edit_number)
|
||||
)
|
||||
|
||||
result = db_session.execute(stmt)
|
||||
|
||||
try:
|
||||
return result.scalar_one()
|
||||
except NoResultFound:
|
||||
raise ValueError("Invalid message, parent message not found")
|
||||
|
||||
|
||||
def create_chat_session(
|
||||
description: str, user_id: UUID | None, db_session: Session
|
||||
) -> ChatSession:
|
||||
chat_session = ChatSession(
|
||||
user_id=user_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
db_session.add(chat_session)
|
||||
db_session.commit()
|
||||
|
||||
return chat_session
|
||||
|
||||
|
||||
def update_chat_session(
|
||||
user_id: UUID | None, chat_session_id: int, description: str, db_session: Session
|
||||
) -> ChatSession:
|
||||
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Trying to rename a deleted chat session")
|
||||
|
||||
if user_id != chat_session.user_id:
|
||||
raise ValueError("User trying to update chat of another user.")
|
||||
|
||||
chat_session.description = description
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return chat_session
|
||||
|
||||
|
||||
def delete_chat_session(
|
||||
user_id: UUID | None,
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
hard_delete: bool = HARD_DELETE_CHATS,
|
||||
) -> None:
|
||||
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
|
||||
|
||||
if user_id != chat_session.user_id:
|
||||
raise ValueError("User trying to delete chat of another user.")
|
||||
|
||||
if hard_delete:
|
||||
stmt_messages = delete(ChatMessage).where(
|
||||
ChatMessage.chat_session_id == chat_session_id
|
||||
)
|
||||
db_session.execute(stmt_messages)
|
||||
|
||||
stmt = delete(ChatSession).where(ChatSession.id == chat_session_id)
|
||||
db_session.execute(stmt)
|
||||
|
||||
else:
|
||||
chat_session.deleted = True
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def _set_latest_chat_message_no_commit(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
parent_edit_number: int | None,
|
||||
edit_number: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
if message_number != 0 and parent_edit_number is None:
|
||||
raise ValueError(
|
||||
"Only initial message in a chat is allowed to not have a parent"
|
||||
)
|
||||
|
||||
db_session.query(ChatMessage).filter(
|
||||
and_(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
ChatMessage.message_number == message_number,
|
||||
ChatMessage.parent_edit_number == parent_edit_number,
|
||||
)
|
||||
).update({ChatMessage.latest: False})
|
||||
|
||||
db_session.query(ChatMessage).filter(
|
||||
and_(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
ChatMessage.message_number == message_number,
|
||||
ChatMessage.edit_number == edit_number,
|
||||
)
|
||||
).update({ChatMessage.latest: True})
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
message: str,
|
||||
parent_edit_number: int | None,
|
||||
message_type: MessageType,
|
||||
db_session: Session,
|
||||
) -> ChatMessage:
|
||||
"""Creates a new chat message and sets it to the latest message of its parent message"""
|
||||
# Get the count of existing edits at the provided message number
|
||||
latest_edit_number = (
|
||||
db_session.query(func.max(ChatMessage.edit_number))
|
||||
.filter_by(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
)
|
||||
.scalar()
|
||||
)
|
||||
|
||||
# The new message is a new edit at the provided message number
|
||||
new_edit_number = latest_edit_number + 1 if latest_edit_number is not None else 0
|
||||
|
||||
# Create a new message and set it to be the latest for its parent message
|
||||
new_chat_message = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
edit_number=new_edit_number,
|
||||
message=message,
|
||||
message_type=message_type,
|
||||
)
|
||||
|
||||
db_session.add(new_chat_message)
|
||||
|
||||
# Set the previous latest message of the same parent, as no longer the latest
|
||||
_set_latest_chat_message_no_commit(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
edit_number=new_edit_number,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return new_chat_message
|
||||
|
||||
|
||||
def set_latest_chat_message(
|
||||
chat_session_id: int,
|
||||
message_number: int,
|
||||
parent_edit_number: int | None,
|
||||
edit_number: int,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
_set_latest_chat_message_no_commit(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
edit_number=edit_number,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
db_session.commit()
|
@ -22,7 +22,7 @@ def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent:
|
||||
query_event = result.scalar_one_or_none()
|
||||
|
||||
if not query_event:
|
||||
raise ValueError("Invalid Query Event provided for updating")
|
||||
raise ValueError("Invalid Query Event ID Provided")
|
||||
|
||||
return query_event
|
||||
|
||||
@ -33,7 +33,7 @@ def fetch_docs_by_id(doc_id: str, db_session: Session) -> DbDocument:
|
||||
doc = result.scalar_one_or_none()
|
||||
|
||||
if not doc:
|
||||
raise ValueError("Invalid Document provided for updating")
|
||||
raise ValueError("Invalid Document ID Provided")
|
||||
|
||||
return doc
|
||||
|
||||
|
@ -13,6 +13,7 @@ from sqlalchemy import DateTime
|
||||
from sqlalchemy import Enum
|
||||
from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
@ -25,6 +26,7 @@ from sqlalchemy.orm import relationship
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import QAFeedbackType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.connectors.models import InputType
|
||||
@ -52,7 +54,7 @@ class Base(DeclarativeBase):
|
||||
|
||||
class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base):
|
||||
# even an almost empty token from keycloak will not fit the default 1024 bytes
|
||||
access_token: Mapped[str] = mapped_column(Text(), nullable=False) # type: ignore
|
||||
access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore
|
||||
|
||||
|
||||
class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
@ -68,6 +70,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
query_events: Mapped[List["QueryEvent"]] = relationship(
|
||||
"QueryEvent", back_populates="user"
|
||||
)
|
||||
chat_sessions: Mapped[List["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="user"
|
||||
)
|
||||
|
||||
|
||||
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
|
||||
@ -193,7 +198,7 @@ class IndexAttempt(Base):
|
||||
status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
|
||||
num_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
|
||||
error_msg: Mapped[str | None] = mapped_column(
|
||||
String(), default=None
|
||||
Text, default=None
|
||||
) # only filled if status = "failed"
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
@ -217,6 +222,15 @@ class IndexAttempt(Base):
|
||||
"Credential", back_populates="index_attempts"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_index_attempt_latest_for_connector_credential_pair",
|
||||
"connector_id",
|
||||
"credential_id",
|
||||
"time_created",
|
||||
),
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<IndexAttempt(id={self.id!r}, "
|
||||
@ -245,7 +259,7 @@ class DeletionAttempt(Base):
|
||||
status: Mapped[DeletionStatus] = mapped_column(Enum(DeletionStatus))
|
||||
num_docs_deleted: Mapped[int] = mapped_column(Integer, default=0)
|
||||
error_msg: Mapped[str | None] = mapped_column(
|
||||
String(), default=None
|
||||
Text, default=None
|
||||
) # only filled if status = "failed"
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
@ -291,12 +305,12 @@ class QueryEvent(Base):
|
||||
__tablename__ = "query_event"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
query: Mapped[str] = mapped_column(String())
|
||||
query: Mapped[str] = mapped_column(Text)
|
||||
# search_flow refers to user selection, None if user used auto
|
||||
selected_search_flow: Mapped[SearchType | None] = mapped_column(
|
||||
Enum(SearchType), nullable=True
|
||||
)
|
||||
llm_answer: Mapped[str | None] = mapped_column(String(), default=None)
|
||||
llm_answer: Mapped[str | None] = mapped_column(Text, default=None)
|
||||
feedback: Mapped[QAFeedbackType | None] = mapped_column(
|
||||
Enum(QAFeedbackType), nullable=True
|
||||
)
|
||||
@ -340,8 +354,8 @@ class DocumentRetrievalFeedback(Base):
|
||||
class Document(Base):
|
||||
__tablename__ = "document"
|
||||
|
||||
# this should correspond to the ID of the document (as is passed around
|
||||
# in Danswer)
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Danswer)
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
# 0 for neutral, positive for mostly endorse, negative for mostly reject
|
||||
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
|
||||
@ -354,3 +368,46 @@ class Document(Base):
|
||||
retrieval_feedbacks: Mapped[List[DocumentRetrievalFeedback]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
|
||||
|
||||
class ChatSession(Base):
|
||||
__tablename__ = "chat_session"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
time_updated: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True),
|
||||
server_default=func.now(),
|
||||
onupdate=func.now(),
|
||||
)
|
||||
time_created: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
|
||||
messages: Mapped[List["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session", cascade="delete"
|
||||
)
|
||||
|
||||
|
||||
class ChatMessage(Base):
|
||||
__tablename__ = "chat_message"
|
||||
|
||||
chat_session_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_session.id"), primary_key=True
|
||||
)
|
||||
message_number: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
edit_number: Mapped[int] = mapped_column(Integer, default=0, primary_key=True)
|
||||
parent_edit_number: Mapped[int | None] = mapped_column(
|
||||
Integer, nullable=True
|
||||
) # null if first message
|
||||
latest: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
message: Mapped[str] = mapped_column(Text)
|
||||
message_type: Mapped[MessageType] = mapped_column(Enum(MessageType))
|
||||
time_sent: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
chat_session: Mapped[ChatSession] = relationship("ChatSession")
|
||||
|
@ -30,6 +30,7 @@ from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.direct_qa.llm_utils import get_default_qa_model
|
||||
from danswer.server.chat_backend import router as chat_router
|
||||
from danswer.server.credential import router as credential_router
|
||||
from danswer.server.event_loading import router as event_processing_router
|
||||
from danswer.server.health import router as health_router
|
||||
@ -66,6 +67,7 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
|
||||
application.include_router(backend_router)
|
||||
application.include_router(chat_router)
|
||||
application.include_router(event_processing_router)
|
||||
application.include_router(admin_router)
|
||||
application.include_router(user_router)
|
||||
|
19
backend/danswer/secondary_llm_flows/chat_helpers.py
Normal file
19
backend/danswer/secondary_llm_flows/chat_helpers.py
Normal file
@ -0,0 +1,19 @@
|
||||
from danswer.llm.build import get_default_llm
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
|
||||
|
||||
def get_chat_name_messages(user_query: str) -> list[dict[str, str]]:
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "Give a short name for this chat session based on the user's first message.",
|
||||
},
|
||||
{"role": "user", "content": user_query},
|
||||
]
|
||||
return messages
|
||||
|
||||
|
||||
def get_new_chat_name(user_query: str) -> str:
|
||||
messages = get_chat_name_messages(user_query)
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
return get_default_llm().invoke(filled_llm_prompt)
|
373
backend/danswer/server/chat_backend.py
Normal file
373
backend/danswer/server/chat_backend.py
Normal file
@ -0,0 +1,373 @@
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import asdict
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.chat.chat_llm import llm_chat_answer
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import delete_chat_session
|
||||
from danswer.db.chat import fetch_chat_message
|
||||
from danswer.db.chat import fetch_chat_messages_by_session
|
||||
from danswer.db.chat import fetch_chat_session_by_id
|
||||
from danswer.db.chat import fetch_chat_sessions_by_user
|
||||
from danswer.db.chat import set_latest_chat_message
|
||||
from danswer.db.chat import update_chat_session
|
||||
from danswer.db.chat import verify_parent_exists
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name
|
||||
from danswer.server.models import ChatMessageDetail
|
||||
from danswer.server.models import ChatMessageIdentifier
|
||||
from danswer.server.models import ChatRenameRequest
|
||||
from danswer.server.models import ChatSessionDetailResponse
|
||||
from danswer.server.models import ChatSessionIdsResponse
|
||||
from danswer.server.models import CreateChatID
|
||||
from danswer.server.models import CreateChatRequest
|
||||
from danswer.server.models import RenameChatSessionResponse
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/chat")
|
||||
|
||||
|
||||
@router.get("/get-user-chat-sessions")
|
||||
def get_user_chat_sessions(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionIdsResponse:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
# Don't included deleted chats, even if soft delete only
|
||||
chat_sessions = fetch_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session
|
||||
)
|
||||
|
||||
return ChatSessionIdsResponse(sessions=[chat.id for chat in chat_sessions])
|
||||
|
||||
|
||||
@router.get("/get-chat-session/{session_id}")
|
||||
def get_chat_session_messages(
|
||||
session_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionDetailResponse:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
try:
|
||||
session = fetch_chat_session_by_id(session_id, db_session)
|
||||
except ValueError:
|
||||
raise ValueError("Chat Session has been deleted")
|
||||
|
||||
if session.deleted:
|
||||
raise ValueError("Chat Session has been deleted")
|
||||
|
||||
if user_id != session.user_id:
|
||||
if user is None:
|
||||
raise PermissionError(
|
||||
"The No-Auth User is trying to read a different user's chat"
|
||||
)
|
||||
raise PermissionError(
|
||||
f"User {user.email} is trying to read a different user's chat"
|
||||
)
|
||||
|
||||
session_messages = fetch_chat_messages_by_session(
|
||||
chat_session_id=session_id, db_session=db_session
|
||||
)
|
||||
|
||||
return ChatSessionDetailResponse(
|
||||
chat_session_id=session_id,
|
||||
description=session.description,
|
||||
messages=[
|
||||
ChatMessageDetail(
|
||||
message_number=msg.message_number,
|
||||
edit_number=msg.edit_number,
|
||||
parent_edit_number=msg.parent_edit_number,
|
||||
latest=msg.latest,
|
||||
message=msg.message,
|
||||
message_type=msg.message_type,
|
||||
time_sent=msg.time_sent,
|
||||
)
|
||||
for msg in session_messages
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.post("/create-chat-session")
|
||||
def create_new_chat_session(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateChatID:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
new_chat_session = create_chat_session(
|
||||
"", user_id, db_session # Leave the naming till later to prevent delay
|
||||
)
|
||||
|
||||
return CreateChatID(chat_session_id=new_chat_session.id)
|
||||
|
||||
|
||||
@router.put("/rename-chat-session")
|
||||
def rename_chat_session(
|
||||
rename: ChatRenameRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> RenameChatSessionResponse:
|
||||
name = rename.name
|
||||
message = rename.first_message
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
if not name and not message:
|
||||
raise ValueError("Can't assign a name for the chat without context")
|
||||
|
||||
new_name = name or get_new_chat_name(str(message))
|
||||
|
||||
update_chat_session(user_id, rename.chat_session_id, new_name, db_session)
|
||||
|
||||
return RenameChatSessionResponse(new_name=new_name)
|
||||
|
||||
|
||||
@router.delete("/delete-chat-session/{session_id}")
|
||||
def delete_chat_session_by_id(
|
||||
session_id: int,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
user_id = user.id if user is not None else None
|
||||
delete_chat_session(user_id, session_id, db_session)
|
||||
|
||||
|
||||
def _create_chat_chain(
|
||||
chat_session_id: int,
|
||||
db_session: Session,
|
||||
stop_after: int | None = None,
|
||||
) -> list[ChatMessage]:
|
||||
mainline_messages: list[ChatMessage] = []
|
||||
all_chat_messages = fetch_chat_messages_by_session(chat_session_id, db_session)
|
||||
target_message_num = 0
|
||||
target_parent_edit_num = None
|
||||
|
||||
# Chat messages must be ordered by message_number
|
||||
# (fetch_chat_messages_by_session ensures this so no resorting here necessary)
|
||||
for msg in all_chat_messages:
|
||||
if (
|
||||
msg.message_number != target_message_num
|
||||
or msg.parent_edit_number != target_parent_edit_num
|
||||
or not msg.latest
|
||||
):
|
||||
continue
|
||||
|
||||
target_parent_edit_num = msg.edit_number
|
||||
target_message_num += 1
|
||||
|
||||
mainline_messages.append(msg)
|
||||
|
||||
if stop_after is not None and target_message_num > stop_after:
|
||||
break
|
||||
|
||||
if not mainline_messages:
|
||||
raise RuntimeError("Could not trace chat message history")
|
||||
|
||||
return mainline_messages
|
||||
|
||||
|
||||
@router.post("/send-message")
|
||||
def handle_new_chat_message(
|
||||
chat_message: CreateChatRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
"""This endpoint is both used for sending new messages and for sending edited messages.
|
||||
To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
|
||||
have already been set as latest"""
|
||||
chat_session_id = chat_message.chat_session_id
|
||||
message_number = chat_message.message_number
|
||||
message_content = chat_message.message
|
||||
parent_edit_number = chat_message.parent_edit_number
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
chat_session = fetch_chat_session_by_id(chat_session_id, db_session)
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Cannot send messages to a deleted chat session")
|
||||
|
||||
if chat_session.user_id != user_id:
|
||||
if user is None:
|
||||
raise PermissionError(
|
||||
"The No-Auth User trying to interact with a different user's chat"
|
||||
)
|
||||
raise PermissionError(
|
||||
f"User {user.email} trying to interact with a different user's chat"
|
||||
)
|
||||
|
||||
if message_number != 0:
|
||||
if parent_edit_number is None:
|
||||
raise ValueError("Message must have a valid parent message")
|
||||
|
||||
verify_parent_exists(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
db_session=db_session,
|
||||
)
|
||||
else:
|
||||
if parent_edit_number is not None:
|
||||
raise ValueError("Initial message in session cannot have parent")
|
||||
|
||||
# Create new message at the right place in the tree and label it latest for its parent
|
||||
new_message = create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
parent_edit_number=parent_edit_number,
|
||||
message=message_content,
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
mainline_messages = _create_chat_chain(
|
||||
chat_session_id,
|
||||
db_session,
|
||||
)
|
||||
|
||||
if mainline_messages[-1].message != message_content:
|
||||
raise RuntimeError(
|
||||
"The new message was not on the mainline. "
|
||||
"Be sure to update latests before calling this."
|
||||
)
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_chat_tokens() -> Iterator[str]:
|
||||
tokens = llm_chat_answer(mainline_messages)
|
||||
llm_output = ""
|
||||
for token in tokens:
|
||||
llm_output += token
|
||||
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
|
||||
|
||||
create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number + 1,
|
||||
parent_edit_number=new_message.edit_number,
|
||||
message=llm_output,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return StreamingResponse(stream_chat_tokens(), media_type="application/json")
|
||||
|
||||
|
||||
@router.post("/regenerate-from-parent")
|
||||
def regenerate_message_given_parent(
|
||||
parent_message: ChatMessageIdentifier,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
"""Regenerate an LLM response given a particular parent message
|
||||
The parent message is set as latest and a new LLM response is set as
|
||||
the latest following message"""
|
||||
chat_session_id = parent_message.chat_session_id
|
||||
message_number = parent_message.message_number
|
||||
edit_number = parent_message.edit_number
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
chat_message = fetch_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number,
|
||||
edit_number=edit_number,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
chat_session = chat_message.chat_session
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Chat session has been deleted")
|
||||
|
||||
if chat_session.user_id != user_id:
|
||||
if user is None:
|
||||
raise PermissionError(
|
||||
"The No-Auth User trying to regenerate chat messages of another user"
|
||||
)
|
||||
raise PermissionError(
|
||||
f"User {user.email} trying to regenerate chat messages of another user"
|
||||
)
|
||||
|
||||
set_latest_chat_message(
|
||||
chat_session_id,
|
||||
message_number,
|
||||
chat_message.parent_edit_number,
|
||||
edit_number,
|
||||
db_session,
|
||||
)
|
||||
|
||||
# The parent message, now set as latest, may have follow on messages
|
||||
# Don't want to include those in the context to LLM
|
||||
mainline_messages = _create_chat_chain(
|
||||
chat_session_id, db_session, stop_after=message_number
|
||||
)
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_regenerate_tokens() -> Iterator[str]:
|
||||
tokens = llm_chat_answer(mainline_messages)
|
||||
llm_output = ""
|
||||
for token in tokens:
|
||||
llm_output += token
|
||||
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=token)))
|
||||
|
||||
create_new_chat_message(
|
||||
chat_session_id=chat_session_id,
|
||||
message_number=message_number + 1,
|
||||
parent_edit_number=edit_number,
|
||||
message=llm_output,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
return StreamingResponse(stream_regenerate_tokens(), media_type="application/json")
|
||||
|
||||
|
||||
@router.put("/set-message-as-latest")
|
||||
def set_message_as_latest(
|
||||
message_identifier: ChatMessageIdentifier,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
chat_message = fetch_chat_message(
|
||||
chat_session_id=message_identifier.chat_session_id,
|
||||
message_number=message_identifier.message_number,
|
||||
edit_number=message_identifier.edit_number,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
chat_session = chat_message.chat_session
|
||||
|
||||
if chat_session.deleted:
|
||||
raise ValueError("Chat session has been deleted")
|
||||
|
||||
if chat_session.user_id != user_id:
|
||||
if user is None:
|
||||
raise PermissionError(
|
||||
"The No-Auth User trying to update chat messages of another user"
|
||||
)
|
||||
raise PermissionError(
|
||||
f"User {user.email} trying to update chat messages of another user"
|
||||
)
|
||||
|
||||
set_latest_chat_message(
|
||||
chat_session_id=chat_message.chat_session_id,
|
||||
message_number=chat_message.message_number,
|
||||
parent_edit_number=chat_message.parent_edit_number,
|
||||
edit_number=chat_message.edit_number,
|
||||
db_session=db_session,
|
||||
)
|
@ -11,6 +11,7 @@ from pydantic.generics import GenericModel
|
||||
|
||||
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import QAFeedbackType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.connectors.models import InputType
|
||||
@ -130,6 +131,10 @@ class SearchDoc(BaseModel):
|
||||
score: float | None
|
||||
|
||||
|
||||
class CreateChatID(BaseModel):
|
||||
chat_session_id: int
|
||||
|
||||
|
||||
class QuestionRequest(BaseModel):
|
||||
query: str
|
||||
collection: str
|
||||
@ -151,6 +156,49 @@ class SearchFeedbackRequest(BaseModel):
|
||||
search_feedback: SearchFeedbackType
|
||||
|
||||
|
||||
class CreateChatRequest(BaseModel):
|
||||
chat_session_id: int
|
||||
message_number: int
|
||||
parent_edit_number: int | None
|
||||
message: str
|
||||
|
||||
|
||||
class ChatMessageIdentifier(BaseModel):
|
||||
chat_session_id: int
|
||||
message_number: int
|
||||
edit_number: int
|
||||
|
||||
|
||||
class ChatRenameRequest(BaseModel):
|
||||
chat_session_id: int
|
||||
name: str | None
|
||||
first_message: str | None
|
||||
|
||||
|
||||
class RenameChatSessionResponse(BaseModel):
|
||||
new_name: str # This is only really useful if the name is generated
|
||||
|
||||
|
||||
class ChatSessionIdsResponse(BaseModel):
|
||||
sessions: list[int]
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
message_number: int
|
||||
edit_number: int
|
||||
parent_edit_number: int | None
|
||||
latest: bool
|
||||
message: str
|
||||
message_type: MessageType
|
||||
time_sent: datetime
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: int
|
||||
description: str
|
||||
messages: list[ChatMessageDetail]
|
||||
|
||||
|
||||
class QueryValidationResponse(BaseModel):
|
||||
reasoning: str
|
||||
answerable: bool
|
||||
|
@ -1,6 +1,7 @@
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
@ -10,7 +11,7 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
FG = TypeVar("FG", bound=Callable[..., Generator])
|
||||
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
|
||||
|
||||
|
||||
def log_function_time(
|
||||
|
Loading…
x
Reference in New Issue
Block a user