Chat Folders Backend (#1419)

This commit is contained in:
Yuhong Sun 2024-05-03 16:37:18 -07:00 committed by GitHub
parent 6cbfe1bcdb
commit 745f68241d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 428 additions and 5 deletions

View File

@ -0,0 +1,51 @@
"""Chat Folders
Revision ID: 7547d982db8f
Revises: ef7da92f7213
Create Date: 2024-05-02 15:18:56.573347
"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
# revision identifiers, used by Alembic.
revision = "7547d982db8f"
down_revision = "ef7da92f7213"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chat_folder",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column("name", sa.String(), nullable=True),
sa.Column("display_priority", sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("chat_session", sa.Column("folder_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"chat_session_chat_folder_fk",
"chat_session",
"chat_folder",
["folder_id"],
["id"],
)
def downgrade() -> None:
op.drop_constraint(
"chat_session_chat_folder_fk", "chat_session", type_="foreignkey"
)
op.drop_column("chat_session", "folder_id")
op.drop_table("chat_folder")

View File

@ -0,0 +1,132 @@
from uuid import UUID
from sqlalchemy.orm import Session
from danswer.db.chat import delete_chat_session
from danswer.db.models import ChatFolder
from danswer.db.models import ChatSession
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_user_folders(
user_id: UUID | None,
db_session: Session,
) -> list[ChatFolder]:
return db_session.query(ChatFolder).filter(ChatFolder.user_id == user_id).all()
def update_folder_display_priority(
user_id: UUID | None,
display_priority_map: dict[int, int],
db_session: Session,
) -> None:
folders = get_user_folders(user_id=user_id, db_session=db_session)
folder_ids = {folder.id for folder in folders}
if folder_ids != set(display_priority_map.keys()):
raise ValueError("Invalid Folder IDs provided")
for folder in folders:
folder.display_priority = display_priority_map[folder.id]
db_session.commit()
def get_folder_by_id(
user_id: UUID | None,
folder_id: int,
db_session: Session,
) -> ChatFolder:
folder = (
db_session.query(ChatFolder).filter(ChatFolder.id == folder_id).one_or_none()
)
if not folder:
raise ValueError("Folder by specified id does not exist")
if folder.user_id != user_id:
raise PermissionError(f"Folder does not belong to user: {user_id}")
return folder
def create_folder(
user_id: UUID | None, folder_name: str | None, db_session: Session
) -> int:
new_folder = ChatFolder(
user_id=user_id,
name=folder_name,
)
db_session.add(new_folder)
db_session.commit()
return new_folder.id
def rename_folder(
user_id: UUID | None, folder_id: int, folder_name: str | None, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
folder.name = folder_name
db_session.commit()
def add_chat_to_folder(
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
chat_session.folder_id = folder.id
db_session.commit()
def remove_chat_from_folder(
user_id: UUID | None, folder_id: int, chat_session: ChatSession, db_session: Session
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
if chat_session.folder_id != folder.id:
raise ValueError("The chat session is not in the specified folder.")
if folder.user_id != user_id:
raise ValueError(
f"Tried to remove a chat session from a folder that does not below to "
f"this user, user id: {user_id}"
)
chat_session.folder_id = None
if chat_session in folder.chat_sessions:
folder.chat_sessions.remove(chat_session)
db_session.commit()
def delete_folder(
user_id: UUID | None,
folder_id: int,
including_chats: bool,
db_session: Session,
) -> None:
folder = get_folder_by_id(
user_id=user_id, folder_id=folder_id, db_session=db_session
)
# Assuming there will not be a massive number of chats in any given folder
if including_chats:
for chat_session in folder.chat_sessions:
delete_chat_session(
user_id=user_id,
chat_session_id=chat_session.id,
db_session=db_session,
)
db_session.delete(folder)
db_session.commit()

View File

@ -76,6 +76,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
chat_sessions: Mapped[List["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
)
chat_folders: Mapped[List["ChatFolder"]] = relationship(
"ChatFolder", back_populates="user"
)
prompts: Mapped[List["Prompt"]] = relationship("Prompt", back_populates="user")
# Personas owned by this user
personas: Mapped[List["Persona"]] = relationship("Persona", back_populates="user")
@ -572,6 +575,9 @@ class ChatSession(Base):
Enum(ChatSessionSharedStatus, native_enum=False),
default=ChatSessionSharedStatus.PRIVATE,
)
folder_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_folder.id"), nullable=True
)
# the latest "overrides" specified by the user. These take precedence over
# the attached persona. However, overrides specified directly in the
@ -596,6 +602,9 @@ class ChatSession(Base):
)
user: Mapped[User] = relationship("User", back_populates="chat_sessions")
folder: Mapped["ChatFolder"] = relationship(
"ChatFolder", back_populates="chat_sessions"
)
messages: Mapped[List["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="delete"
)
@ -656,6 +665,31 @@ class ChatMessage(Base):
)
class ChatFolder(Base):
"""For organizing chat sessions"""
__tablename__ = "chat_folder"
id: Mapped[int] = mapped_column(primary_key=True)
# Only null if auth is off
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
name: Mapped[str | None] = mapped_column(String, nullable=True)
display_priority: Mapped[int] = mapped_column(Integer, nullable=True, default=0)
user: Mapped[User] = relationship("User", back_populates="chat_folders")
chat_sessions: Mapped[List["ChatSession"]] = relationship(
"ChatSession", back_populates="folder"
)
def __lt__(self, other: Any) -> bool:
if not isinstance(other, ChatFolder):
return NotImplemented
if self.display_priority == other.display_priority:
# Bigger ID (created later) show earlier
return self.id > other.id
return self.display_priority < other.display_priority
"""
Feedback, Logging, Metrics Tables
"""

View File

@ -58,6 +58,7 @@ from danswer.server.documents.connector import router as connector_router
from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
from danswer.server.features.prompt.api import basic_router as prompt_router
@ -261,6 +262,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, connector_router)
include_router_with_global_prefix_prepended(application, credential_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, folder_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, secondary_index_router)
include_router_with_global_prefix_prepended(

View File

@ -0,0 +1,171 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Path
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.chat import get_chat_session_by_id
from danswer.db.engine import get_session
from danswer.db.folder import add_chat_to_folder
from danswer.db.folder import create_folder
from danswer.db.folder import delete_folder
from danswer.db.folder import get_user_folders
from danswer.db.folder import remove_chat_from_folder
from danswer.db.folder import rename_folder
from danswer.db.folder import update_folder_display_priority
from danswer.db.models import User
from danswer.server.features.folder.models import DeleteFolderOptions
from danswer.server.features.folder.models import FolderChatMinimalInfo
from danswer.server.features.folder.models import FolderChatSessionRequest
from danswer.server.features.folder.models import FolderCreationRequest
from danswer.server.features.folder.models import FolderResponse
from danswer.server.features.folder.models import FolderUpdateRequest
from danswer.server.features.folder.models import GetUserFoldersResponse
from danswer.server.models import DisplayPriorityRequest
router = APIRouter(prefix="/folder")
@router.get("")
def get_folders(
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> GetUserFoldersResponse:
folders = get_user_folders(
user_id=user.id if user else None,
db_session=db_session,
)
folders.sort()
return GetUserFoldersResponse(
folders=[
FolderResponse(
folder_id=folder.id,
folder_name=folder.name,
display_priority=folder.display_priority,
chat_sessions=[
FolderChatMinimalInfo(
chat_session_id=chat_session.id,
chat_session_name=chat_session.description,
)
for chat_session in folder.chat_sessions
],
)
for folder in folders
]
)
@router.put("/reorder")
def put_folder_display_priority(
display_priority_request: DisplayPriorityRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_folder_display_priority(
user_id=user.id if user else None,
display_priority_map=display_priority_request.display_priority_map,
db_session=db_session,
)
@router.post("")
def create_folder_endpoint(
request: FolderCreationRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> int:
return create_folder(
user_id=user.id if user else None,
folder_name=request.folder_name,
db_session=db_session,
)
@router.patch("/{folder_id}")
def patch_folder_endpoint(
request: FolderUpdateRequest,
folder_id: int = Path(..., description="The ID of the folder to rename"),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
rename_folder(
user_id=user.id if user else None,
folder_id=folder_id,
folder_name=request.folder_name,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.delete("/{folder_id}")
def delete_folder_endpoint(
request: DeleteFolderOptions,
folder_id: int = Path(..., description="The ID of the folder to delete"),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
delete_folder(
user_id=user_id,
folder_id=folder_id,
including_chats=request.including_chats,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/{folder_id}/add-chat-session")
def add_chat_to_folder_endpoint(
request: FolderChatSessionRequest,
folder_id: int = Path(
..., description="The ID of the folder in which to add the chat session"
),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=request.chat_session_id,
user_id=user_id,
db_session=db_session,
)
add_chat_to_folder(
user_id=user.id if user else None,
folder_id=folder_id,
chat_session=chat_session,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
@router.post("/{folder_id}/remove-chat-session/")
def remove_chat_from_folder_endpoint(
request: FolderChatSessionRequest,
folder_id: int = Path(
..., description="The ID of the folder from which to remove the chat session"
),
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
try:
chat_session = get_chat_session_by_id(
chat_session_id=request.chat_session_id,
user_id=user_id,
db_session=db_session,
)
remove_chat_from_folder(
user_id=user_id,
folder_id=folder_id,
chat_session=chat_session,
db_session=db_session,
)
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))

View File

@ -0,0 +1,33 @@
from pydantic import BaseModel
class FolderChatMinimalInfo(BaseModel):
chat_session_id: int
chat_session_name: str
class FolderResponse(BaseModel):
folder_id: int
folder_name: str | None
display_priority: int
chat_sessions: list[FolderChatMinimalInfo]
class GetUserFoldersResponse(BaseModel):
folders: list[FolderResponse]
class FolderCreationRequest(BaseModel):
folder_name: str | None = None
class FolderUpdateRequest(BaseModel):
folder_name: str | None
class FolderChatSessionRequest(BaseModel):
chat_session_id: int
class DeleteFolderOptions(BaseModel):
including_chats: bool = False

View File

@ -17,6 +17,7 @@ from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.features.persona.models import PromptTemplateResponse
from danswer.server.models import DisplayPriorityRequest
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -44,11 +45,6 @@ def patch_persona_visibility(
)
class DisplayPriorityRequest(BaseModel):
# maps persona id to display priority
display_priority_map: dict[int, int]
@admin_router.put("/display-priority")
def patch_persona_display_priority(
display_priority_request: DisplayPriorityRequest,

View File

@ -27,3 +27,7 @@ class IdReturn(BaseModel):
class MinimalUserSnapshot(BaseModel):
id: UUID
email: str
class DisplayPriorityRequest(BaseModel):
display_priority_map: dict[int, int]