diff --git a/backend/alembic/versions/e0a68a81d434_add_chat_feedback.py b/backend/alembic/versions/e0a68a81d434_add_chat_feedback.py new file mode 100644 index 000000000..528711bce --- /dev/null +++ b/backend/alembic/versions/e0a68a81d434_add_chat_feedback.py @@ -0,0 +1,44 @@ +"""Add Chat Feedback + +Revision ID: e0a68a81d434 +Revises: ae62505e3acc +Create Date: 2023-10-04 20:22:33.380286 + +""" +from alembic import op +import sqlalchemy as sa + +# revision identifiers, used by Alembic. +revision = "e0a68a81d434" +down_revision = "ae62505e3acc" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "chat_feedback", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("chat_message_chat_session_id", sa.Integer(), nullable=False), + sa.Column("chat_message_message_number", sa.Integer(), nullable=False), + sa.Column("chat_message_edit_number", sa.Integer(), nullable=False), + sa.Column("is_positive", sa.Boolean(), nullable=True), + sa.Column("feedback_text", sa.Text(), nullable=True), + sa.ForeignKeyConstraint( + [ + "chat_message_chat_session_id", + "chat_message_message_number", + "chat_message_edit_number", + ], + [ + "chat_message.chat_session_id", + "chat_message.message_number", + "chat_message.edit_number", + ], + ), + sa.PrimaryKeyConstraint("id"), + ) + + +def downgrade() -> None: + op.drop_table("chat_feedback") diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 06bac1553..2348a31b4 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -4,12 +4,16 @@ from sqlalchemy import asc from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import select +from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import Session +from danswer.configs.constants import MessageType from danswer.configs.constants import QAFeedbackType from danswer.configs.constants import SearchFeedbackType from danswer.datastores.document_index import get_default_document_index from danswer.datastores.interfaces import UpdateRequest +from danswer.db.models import ChatMessage as DbChatMessage +from danswer.db.models import ChatMessageFeedback from danswer.db.models import Document as DbDocument from danswer.db.models import DocumentRetrievalFeedback from danswer.db.models import QueryEvent @@ -164,3 +168,46 @@ def delete_document_feedback_for_documents( DocumentRetrievalFeedback.document_id.in_(document_ids) ) db_session.execute(stmt) + + +def create_chat_message_feedback( + chat_session_id: int, + message_number: int, + edit_number: int, + user_id: UUID | None, + db_session: Session, + is_positive: bool | None = None, + feedback_text: str | None = None, +) -> None: + if is_positive is None and feedback_text is None: + raise ValueError("No feedback provided") + + try: + chat_message = ( + db_session.query(DbChatMessage) + .filter_by( + chat_session_id=chat_session_id, + message_number=message_number, + edit_number=edit_number, + ) + .one() + ) + except NoResultFound: + raise ValueError("ChatMessage not found") + + if chat_message.message_type != MessageType.ASSISTANT: + raise ValueError("Can only provide feedback on LLM Outputs") + + if user_id is not None and chat_message.chat_session.user_id != user_id: + raise ValueError("User trying to give feedback on a message by another user.") + + message_feedback = ChatMessageFeedback( + chat_message_chat_session_id=chat_session_id, + chat_message_message_number=message_number, + chat_message_edit_number=edit_number, + is_positive=is_positive, + feedback_text=feedback_text, + ) + + db_session.add(message_feedback) + db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 54ddb8cf2..09bdad108 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -14,6 +14,7 @@ from sqlalchemy import Boolean from sqlalchemy import DateTime from sqlalchemy import Enum from sqlalchemy import ForeignKey +from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func from sqlalchemy import Index from sqlalchemy import Integer @@ -478,6 +479,42 @@ class ChatMessage(Base): persona: Mapped[Persona | None] = relationship("Persona") +class ChatMessageFeedback(Base): + __tablename__ = "chat_feedback" + + id: Mapped[int] = mapped_column(Integer, primary_key=True) + chat_message_chat_session_id: Mapped[int] = mapped_column(Integer) + chat_message_message_number: Mapped[int] = mapped_column(Integer) + chat_message_edit_number: Mapped[int] = mapped_column(Integer) + is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True) + feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True) + + __table_args__ = ( + ForeignKeyConstraint( + [ + "chat_message_chat_session_id", + "chat_message_message_number", + "chat_message_edit_number", + ], + [ + "chat_message.chat_session_id", + "chat_message.message_number", + "chat_message.edit_number", + ], + ), + ) + + chat_message: Mapped[ChatMessage] = relationship( + "ChatMessage", + foreign_keys=[ + chat_message_chat_session_id, + chat_message_message_number, + chat_message_edit_number, + ], + backref="feedbacks", + ) + + AllowedAnswerFilters = ( Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"] ) diff --git a/backend/danswer/server/chat_backend.py b/backend/danswer/server/chat_backend.py index 5a57ba240..d832e0d77 100644 --- a/backend/danswer/server/chat_backend.py +++ b/backend/danswer/server/chat_backend.py @@ -21,11 +21,13 @@ from danswer.db.chat import set_latest_chat_message from danswer.db.chat import update_chat_session from danswer.db.chat import verify_parent_exists from danswer.db.engine import get_session +from danswer.db.feedback import create_chat_message_feedback from danswer.db.models import ChatMessage from danswer.db.models import User from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.llm.utils import get_default_llm_tokenizer from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name +from danswer.server.models import ChatFeedbackRequest from danswer.server.models import ChatMessageDetail from danswer.server.models import ChatMessageIdentifier from danswer.server.models import ChatRenameRequest @@ -163,6 +165,25 @@ def delete_chat_session_by_id( delete_chat_session(user_id, session_id, db_session) +@router.post("/create-chat-message-feedback") +def create_chat_feedback( + feedback: ChatFeedbackRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + user_id = user.id if user else None + + create_chat_message_feedback( + chat_session_id=feedback.chat_session_id, + message_number=feedback.message_number, + edit_number=feedback.edit_number, + user_id=user_id, + db_session=db_session, + is_positive=feedback.is_positive, + feedback_text=feedback.feedback_text, + ) + + def _create_chat_chain( chat_session_id: int, db_session: Session, diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 9cb6179f4..d4e3c762c 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -166,6 +166,14 @@ class QAFeedbackRequest(BaseModel): feedback: QAFeedbackType +class ChatFeedbackRequest(BaseModel): + chat_session_id: int + message_number: int + edit_number: int + is_positive: bool | None = None + feedback_text: str | None = None + + class SearchFeedbackRequest(BaseModel): query_id: int document_id: str