From 745f68241d248388c22e787ac60bded202820df0 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 3 May 2024 16:37:18 -0700 Subject: [PATCH] Chat Folders Backend (#1419) --- .../versions/7547d982db8f_chat_folders.py | 51 ++++++ backend/danswer/db/folder.py | 132 ++++++++++++++ backend/danswer/db/models.py | 34 ++++ backend/danswer/main.py | 2 + .../server/features/folder/__init__.py | 0 backend/danswer/server/features/folder/api.py | 171 ++++++++++++++++++ .../danswer/server/features/folder/models.py | 33 ++++ .../danswer/server/features/persona/api.py | 6 +- backend/danswer/server/models.py | 4 + 9 files changed, 428 insertions(+), 5 deletions(-) create mode 100644 backend/alembic/versions/7547d982db8f_chat_folders.py create mode 100644 backend/danswer/db/folder.py create mode 100644 backend/danswer/server/features/folder/__init__.py create mode 100644 backend/danswer/server/features/folder/api.py create mode 100644 backend/danswer/server/features/folder/models.py diff --git a/backend/alembic/versions/7547d982db8f_chat_folders.py b/backend/alembic/versions/7547d982db8f_chat_folders.py new file mode 100644 index 000000000..b0eb65785 --- /dev/null +++ b/backend/alembic/versions/7547d982db8f_chat_folders.py @@ -0,0 +1,51 @@ +"""Chat Folders + +Revision ID: 7547d982db8f +Revises: ef7da92f7213 +Create Date: 2024-05-02 15:18:56.573347 + +""" +from alembic import op +import sqlalchemy as sa +import fastapi_users_db_sqlalchemy + +# revision identifiers, used by Alembic. +revision = "7547d982db8f" +down_revision = "ef7da92f7213" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.create_table( + "chat_folder", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column( + "user_id", + fastapi_users_db_sqlalchemy.generics.GUID(), + nullable=True, + ), + sa.Column("name", sa.String(), nullable=True), + sa.Column("display_priority", sa.Integer(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["user.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + op.add_column("chat_session", sa.Column("folder_id", sa.Integer(), nullable=True)) + op.create_foreign_key( + "chat_session_chat_folder_fk", + "chat_session", + "chat_folder", + ["folder_id"], + ["id"], + ) + + +def downgrade() -> None: + op.drop_constraint( + "chat_session_chat_folder_fk", "chat_session", type_="foreignkey" + ) + op.drop_column("chat_session", "folder_id") + op.drop_table("chat_folder") diff --git a/backend/danswer/db/folder.py b/backend/danswer/db/folder.py new file mode 100644 index 000000000..77e543a8d --- /dev/null +++ b/backend/danswer/db/folder.py @@ -0,0 +1,132 @@ +from uuid import UUID + +from sqlalchemy.orm import Session + +from danswer.db.chat import delete_chat_session +from danswer.db.models import ChatFolder +from danswer.db.models import ChatSession +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def get_user_folders( + user_id: UUID | None, + db_session: Session, +) -> list[ChatFolder]: + return db_session.query(ChatFolder).filter(ChatFolder.user_id == user_id).all() + + +def update_folder_display_priority( + user_id: UUID | None, + display_priority_map: dict[int, int], + db_session: Session, +) -> None: + folders = get_user_folders(user_id=user_id, db_session=db_session) + folder_ids = {folder.id for folder in folders} + if folder_ids != set(display_priority_map.keys()): + raise ValueError("Invalid Folder IDs provided") + + for folder in folders: + folder.display_priority = display_priority_map[folder.id] + + db_session.commit() + + +def get_folder_by_id( + user_id: UUID | None, + folder_id: int, + db_session: Session, +) -> ChatFolder: + folder = ( + db_session.query(ChatFolder).filter(ChatFolder.id == folder_id).one_or_none() + ) + if not folder: + raise ValueError("Folder by specified id does not exist") + + if folder.user_id != user_id: + raise PermissionError(f"Folder does not belong to user: {user_id}") + + return folder + + +def create_folder( + user_id: UUID | None, folder_name: str | None, db_session: Session +) -> int: + new_folder = ChatFolder( + user_id=user_id, + name=folder_name, + ) + db_session.add(new_folder) + db_session.commit() + + return new_folder.id + + +def rename_folder( + user_id: UUID | None, folder_id: int, folder_name: str | None, db_session: Session +) -> None: + folder = get_folder_by_id( + user_id=user_id, folder_id=folder_id, db_session=db_session + ) + + folder.name = folder_name + db_session.commit() + + +def add_chat_to_folder( + user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session +) -> None: + folder = get_folder_by_id( + user_id=user_id, folder_id=folder_id, db_session=db_session + ) + + chat_session.folder_id = folder.id + + db_session.commit() + + +def remove_chat_from_folder( + user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session +) -> None: + folder = get_folder_by_id( + user_id=user_id, folder_id=folder_id, db_session=db_session + ) + + if chat_session.folder_id != folder.id: + raise ValueError("The chat session is not in the specified folder.") + + if folder.user_id != user_id: + raise ValueError( + f"Tried to remove a chat session from a folder that does not below to " + f"this user, user id: {user_id}" + ) + + chat_session.folder_id = None + if chat_session in folder.chat_sessions: + folder.chat_sessions.remove(chat_session) + + db_session.commit() + + +def delete_folder( + user_id: UUID | None, + folder_id: int, + including_chats: bool, + db_session: Session, +) -> None: + folder = get_folder_by_id( + user_id=user_id, folder_id=folder_id, db_session=db_session + ) + + # Assuming there will not be a massive number of chats in any given folder + if including_chats: + for chat_session in folder.chat_sessions: + delete_chat_session( + user_id=user_id, + chat_session_id=chat_session.id, + db_session=db_session, + ) + + db_session.delete(folder) + db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index beb0c9a69..75cff3c25 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -76,6 +76,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base): chat_sessions: Mapped[List["ChatSession"]] = relationship( "ChatSession", back_populates="user" ) + chat_folders: Mapped[List["ChatFolder"]] = relationship( + "ChatFolder", back_populates="user" + ) prompts: Mapped[List["Prompt"]] = relationship("Prompt", back_populates="user") # Personas owned by this user personas: Mapped[List["Persona"]] = relationship("Persona", back_populates="user") @@ -572,6 +575,9 @@ class ChatSession(Base): Enum(ChatSessionSharedStatus, native_enum=False), default=ChatSessionSharedStatus.PRIVATE, ) + folder_id: Mapped[int | None] = mapped_column( + ForeignKey("chat_folder.id"), nullable=True + ) # the latest "overrides" specified by the user. These take precedence over # the attached persona. However, overrides specified directly in the @@ -596,6 +602,9 @@ class ChatSession(Base): ) user: Mapped[User] = relationship("User", back_populates="chat_sessions") + folder: Mapped["ChatFolder"] = relationship( + "ChatFolder", back_populates="chat_sessions" + ) messages: Mapped[List["ChatMessage"]] = relationship( "ChatMessage", back_populates="chat_session", cascade="delete" ) @@ -656,6 +665,31 @@ class ChatMessage(Base): ) +class ChatFolder(Base): + """For organizing chat sessions""" + + __tablename__ = "chat_folder" + + id: Mapped[int] = mapped_column(primary_key=True) + # Only null if auth is off + user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) + name: Mapped[str | None] = mapped_column(String, nullable=True) + display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0) + + user: Mapped[User] = relationship("User", back_populates="chat_folders") + chat_sessions: Mapped[List["ChatSession"]] = relationship( + "ChatSession", back_populates="folder" + ) + + def __lt__(self, other: Any) -> bool: + if not isinstance(other, ChatFolder): + return NotImplemented + if self.display_priority == other.display_priority: + # Bigger ID (created later) show earlier + return self.id > other.id + return self.display_priority < other.display_priority + + """ Feedback, Logging, Metrics Tables """ diff --git a/backend/danswer/main.py b/backend/danswer/main.py index c43833e56..badc435ed 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -58,6 +58,7 @@ from danswer.server.documents.connector import router as connector_router from danswer.server.documents.credential import router as credential_router from danswer.server.documents.document import router as document_router from danswer.server.features.document_set.api import router as document_set_router +from danswer.server.features.folder.api import router as folder_router from danswer.server.features.persona.api import admin_router as admin_persona_router from danswer.server.features.persona.api import basic_router as persona_router from danswer.server.features.prompt.api import basic_router as prompt_router @@ -261,6 +262,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, connector_router) include_router_with_global_prefix_prepended(application, credential_router) include_router_with_global_prefix_prepended(application, cc_pair_router) + include_router_with_global_prefix_prepended(application, folder_router) include_router_with_global_prefix_prepended(application, document_set_router) include_router_with_global_prefix_prepended(application, secondary_index_router) include_router_with_global_prefix_prepended( diff --git a/backend/danswer/server/features/folder/__init__.py b/backend/danswer/server/features/folder/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/danswer/server/features/folder/api.py b/backend/danswer/server/features/folder/api.py new file mode 100644 index 000000000..a1204e342 --- /dev/null +++ b/backend/danswer/server/features/folder/api.py @@ -0,0 +1,171 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException +from fastapi import Path +from sqlalchemy.orm import Session + +from danswer.auth.users import current_user +from danswer.db.chat import get_chat_session_by_id +from danswer.db.engine import get_session +from danswer.db.folder import add_chat_to_folder +from danswer.db.folder import create_folder +from danswer.db.folder import delete_folder +from danswer.db.folder import get_user_folders +from danswer.db.folder import remove_chat_from_folder +from danswer.db.folder import rename_folder +from danswer.db.folder import update_folder_display_priority +from danswer.db.models import User +from danswer.server.features.folder.models import DeleteFolderOptions +from danswer.server.features.folder.models import FolderChatMinimalInfo +from danswer.server.features.folder.models import FolderChatSessionRequest +from danswer.server.features.folder.models import FolderCreationRequest +from danswer.server.features.folder.models import FolderResponse +from danswer.server.features.folder.models import FolderUpdateRequest +from danswer.server.features.folder.models import GetUserFoldersResponse +from danswer.server.models import DisplayPriorityRequest + +router = APIRouter(prefix="/folder") + + +@router.get("") +def get_folders( + user: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> GetUserFoldersResponse: + folders = get_user_folders( + user_id=user.id if user else None, + db_session=db_session, + ) + folders.sort() + return GetUserFoldersResponse( + folders=[ + FolderResponse( + folder_id=folder.id, + folder_name=folder.name, + display_priority=folder.display_priority, + chat_sessions=[ + FolderChatMinimalInfo( + chat_session_id=chat_session.id, + chat_session_name=chat_session.description, + ) + for chat_session in folder.chat_sessions + ], + ) + for folder in folders + ] + ) + + +@router.put("/reorder") +def put_folder_display_priority( + display_priority_request: DisplayPriorityRequest, + user: User | None = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + update_folder_display_priority( + user_id=user.id if user else None, + display_priority_map=display_priority_request.display_priority_map, + db_session=db_session, + ) + + +@router.post("") +def create_folder_endpoint( + request: FolderCreationRequest, + user: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> int: + return create_folder( + user_id=user.id if user else None, + folder_name=request.folder_name, + db_session=db_session, + ) + + +@router.patch("/{folder_id}") +def patch_folder_endpoint( + request: FolderUpdateRequest, + folder_id: int = Path(..., description="The ID of the folder to rename"), + user: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + try: + rename_folder( + user_id=user.id if user else None, + folder_id=folder_id, + folder_name=request.folder_name, + db_session=db_session, + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.delete("/{folder_id}") +def delete_folder_endpoint( + request: DeleteFolderOptions, + folder_id: int = Path(..., description="The ID of the folder to delete"), + user: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + user_id = user.id if user else None + try: + delete_folder( + user_id=user_id, + folder_id=folder_id, + including_chats=request.including_chats, + db_session=db_session, + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.post("/{folder_id}/add-chat-session") +def add_chat_to_folder_endpoint( + request: FolderChatSessionRequest, + folder_id: int = Path( + ..., description="The ID of the folder in which to add the chat session" + ), + user: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + user_id = user.id if user else None + try: + chat_session = get_chat_session_by_id( + chat_session_id=request.chat_session_id, + user_id=user_id, + db_session=db_session, + ) + add_chat_to_folder( + user_id=user.id if user else None, + folder_id=folder_id, + chat_session=chat_session, + db_session=db_session, + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) + + +@router.post("/{folder_id}/remove-chat-session/") +def remove_chat_from_folder_endpoint( + request: FolderChatSessionRequest, + folder_id: int = Path( + ..., description="The ID of the folder from which to remove the chat session" + ), + user: User = Depends(current_user), + db_session: Session = Depends(get_session), +) -> None: + user_id = user.id if user else None + try: + chat_session = get_chat_session_by_id( + chat_session_id=request.chat_session_id, + user_id=user_id, + db_session=db_session, + ) + remove_chat_from_folder( + user_id=user_id, + folder_id=folder_id, + chat_session=chat_session, + db_session=db_session, + ) + except Exception as e: + raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/danswer/server/features/folder/models.py b/backend/danswer/server/features/folder/models.py new file mode 100644 index 000000000..87b65072a --- /dev/null +++ b/backend/danswer/server/features/folder/models.py @@ -0,0 +1,33 @@ +from pydantic import BaseModel + + +class FolderChatMinimalInfo(BaseModel): + chat_session_id: int + chat_session_name: str + + +class FolderResponse(BaseModel): + folder_id: int + folder_name: str | None + display_priority: int + chat_sessions: list[FolderChatMinimalInfo] + + +class GetUserFoldersResponse(BaseModel): + folders: list[FolderResponse] + + +class FolderCreationRequest(BaseModel): + folder_name: str | None = None + + +class FolderUpdateRequest(BaseModel): + folder_name: str | None + + +class FolderChatSessionRequest(BaseModel): + chat_session_id: int + + +class DeleteFolderOptions(BaseModel): + including_chats: bool = False diff --git a/backend/danswer/server/features/persona/api.py b/backend/danswer/server/features/persona/api.py index f316560d0..e68570976 100644 --- a/backend/danswer/server/features/persona/api.py +++ b/backend/danswer/server/features/persona/api.py @@ -17,6 +17,7 @@ from danswer.llm.answering.prompts.utils import build_dummy_prompt from danswer.server.features.persona.models import CreatePersonaRequest from danswer.server.features.persona.models import PersonaSnapshot from danswer.server.features.persona.models import PromptTemplateResponse +from danswer.server.models import DisplayPriorityRequest from danswer.utils.logger import setup_logger logger = setup_logger() @@ -44,11 +45,6 @@ def patch_persona_visibility( ) -class DisplayPriorityRequest(BaseModel): - # maps persona id to display priority - display_priority_map: dict[int, int] - - @admin_router.put("/display-priority") def patch_persona_display_priority( display_priority_request: DisplayPriorityRequest, diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index ca23f0a15..21349ae07 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -27,3 +27,7 @@ class IdReturn(BaseModel): class MinimalUserSnapshot(BaseModel): id: UUID email: str + + +class DisplayPriorityRequest(BaseModel): + display_priority_map: dict[int, int]