mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-10-10 13:15:18 +02:00
411 lines
12 KiB
Python
411 lines
12 KiB
Python
from collections.abc import Sequence
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import and_
|
|
from sqlalchemy import delete
|
|
from sqlalchemy import func
|
|
from sqlalchemy import not_
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import NoResultFound
|
|
from sqlalchemy.orm import selectinload
|
|
from sqlalchemy.orm import Session
|
|
|
|
from danswer.configs.chat_configs import HARD_DELETE_CHATS
|
|
from danswer.configs.constants import MessageType
|
|
from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
|
from danswer.db.models import ChatMessage
|
|
from danswer.db.models import ChatSession
|
|
from danswer.db.models import DocumentSet as DocumentSetDBModel
|
|
from danswer.db.models import Persona
|
|
from danswer.db.models import ToolInfo
|
|
|
|
|
|
def fetch_chat_sessions_by_user(
    user_id: UUID | None,
    deleted: bool | None,
    db_session: Session,
) -> list[ChatSession]:
    """Return the chat sessions whose ``user_id`` column equals ``user_id``.

    Args:
        user_id: Value compared against ``ChatSession.user_id``.
        deleted: When not ``None``, only sessions whose ``deleted`` flag equals
            this value are returned. ``None`` applies no filter on the flag.
        db_session: Session used to run the query.

    Returns:
        The matching ``ChatSession`` rows as a list.
    """
    filters = [ChatSession.user_id == user_id]
    if deleted is not None:
        filters.append(ChatSession.deleted == deleted)

    query = select(ChatSession).where(*filters)
    return list(db_session.execute(query).scalars().all())
|
|
|
|
|
|
def fetch_chat_messages_by_session(
    chat_session_id: int, db_session: Session
) -> list[ChatMessage]:
    """Return all messages belonging to one chat session.

    Rows are sorted by ascending ``message_number`` and, within the same
    message number, by ascending ``edit_number``.

    Args:
        chat_session_id: ID of the chat session whose messages are wanted.
        db_session: Session used to run the query.

    Returns:
        The matching ``ChatMessage`` rows as a list.
    """
    ordering = (ChatMessage.message_number.asc(), ChatMessage.edit_number.asc())
    query = (
        select(ChatMessage)
        .where(ChatMessage.chat_session_id == chat_session_id)
        .order_by(*ordering)
    )
    messages = db_session.execute(query).scalars().all()
    return list(messages)
|
|
|
|
|
|
def fetch_chat_message(
    chat_session_id: int, message_number: int, edit_number: int, db_session: Session
) -> ChatMessage:
    """Fetch one message identified by session, message number and edit number.

    The related ``chat_session`` is requested through ``selectinload`` as part
    of the same lookup.

    Args:
        chat_session_id: ID of the chat session containing the message.
        message_number: Position of the message within the session.
        edit_number: Which edit at that position to fetch.
        db_session: Session used to run the query.

    Returns:
        The matching ``ChatMessage``.

    Raises:
        ValueError: If no message matches the three identifiers.
    """
    query = (
        select(ChatMessage)
        .where(
            ChatMessage.chat_session_id == chat_session_id,
            ChatMessage.message_number == message_number,
            ChatMessage.edit_number == edit_number,
        )
        .options(selectinload(ChatMessage.chat_session))
    )

    found = db_session.execute(query).scalar_one_or_none()
    if found:
        return found

    raise ValueError("Invalid Chat Message specified")
|
|
|
|
|
|
def fetch_chat_session_by_id(chat_session_id: int, db_session: Session) -> ChatSession:
    """Look up a chat session by its primary ID.

    No filter on the ``deleted`` flag is applied here; callers inspect it
    themselves where it matters.

    Args:
        chat_session_id: ID of the chat session to fetch.
        db_session: Session used to run the query.

    Returns:
        The matching ``ChatSession``.

    Raises:
        ValueError: If no chat session has that ID.
    """
    query = select(ChatSession).where(ChatSession.id == chat_session_id)
    session_row = db_session.execute(query).scalar_one_or_none()
    if session_row:
        return session_row

    raise ValueError("Invalid Chat Session ID provided")
|
|
|
|
|
|
def verify_parent_exists(
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    db_session: Session,
) -> ChatMessage:
    """Return the parent of the message at ``message_number``.

    The parent is the row in the same chat session located at
    ``message_number - 1`` whose ``edit_number`` equals ``parent_edit_number``.

    Args:
        chat_session_id: ID of the chat session both messages belong to.
        message_number: Position of the child message; the parent is looked up
            one position earlier.
        parent_edit_number: Edit number the parent row must have.
        db_session: Session used to run the query.

    Returns:
        The parent ``ChatMessage``.

    Raises:
        ValueError: If no such parent row exists.
    """
    stmt = select(ChatMessage).where(
        (ChatMessage.chat_session_id == chat_session_id)
        & (ChatMessage.message_number == message_number - 1)
        & (ChatMessage.edit_number == parent_edit_number)
    )

    result = db_session.execute(stmt)

    try:
        return result.scalar_one()
    except NoResultFound as err:
        # Chain explicitly so the traceback shows the failed lookup as the
        # direct cause rather than as an unrelated error raised while handling.
        raise ValueError("Invalid message, parent message not found") from err
|
|
|
|
|
|
def create_chat_session(
    db_session: Session,
    description: str,
    user_id: UUID | None,
    persona_id: int | None = None,
) -> ChatSession:
    """Insert a new chat session and commit it.

    Args:
        db_session: Session the new row is added to and committed on.
        description: Description stored on the session.
        user_id: Owner stored on the session; may be ``None``.
        persona_id: Optional persona ID stored on the session.

    Returns:
        The newly created ``ChatSession``.
    """
    new_session = ChatSession(
        user_id=user_id,
        persona_id=persona_id,
        description=description,
    )
    db_session.add(new_session)
    db_session.commit()
    return new_session
|
|
|
|
|
|
def update_chat_session(
    user_id: UUID | None, chat_session_id: int, description: str, db_session: Session
) -> ChatSession:
    """Change the description of a chat session and commit.

    Args:
        user_id: User requesting the change; must equal the session's
            ``user_id``.
        chat_session_id: ID of the chat session to update.
        description: New description.
        db_session: Session used for the lookup and the commit.

    Returns:
        The updated ``ChatSession``.

    Raises:
        ValueError: If the session does not exist (raised by the lookup), is
            marked deleted, or belongs to a different user.
    """
    session_row = fetch_chat_session_by_id(chat_session_id, db_session)

    # The deleted flag is checked before ownership, so a deleted session
    # reports the "deleted" error whoever is asking.
    if session_row.deleted:
        raise ValueError("Trying to rename a deleted chat session")
    if user_id != session_row.user_id:
        raise ValueError("User trying to update chat of another user.")

    session_row.description = description
    db_session.commit()
    return session_row
|
|
|
|
|
|
def delete_chat_session(
    user_id: UUID | None,
    chat_session_id: int,
    db_session: Session,
    hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
    """Delete a chat session, either by flagging it or by removing its rows.

    Args:
        user_id: User requesting the deletion; must equal the session's
            ``user_id``.
        chat_session_id: ID of the chat session to delete.
        db_session: Session used for the lookup, the deletes and the commit.
        hard_delete: When true, the session's messages and the session row are
            deleted. Otherwise the session is only marked ``deleted``.
            Defaults to ``HARD_DELETE_CHATS``.

    Raises:
        ValueError: If the session does not exist (raised by the lookup) or
            belongs to a different user.
    """
    session_row = fetch_chat_session_by_id(chat_session_id, db_session)
    if user_id != session_row.user_id:
        raise ValueError("User trying to delete chat of another user.")

    if not hard_delete:
        # Soft delete: keep the rows and flip the flag.
        session_row.deleted = True
        db_session.commit()
        return

    # Hard delete: messages are removed first, then the session itself.
    db_session.execute(
        delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
    )
    db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
    db_session.commit()
|
|
|
|
|
|
def _set_latest_chat_message_no_commit(
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    edit_number: int,
    db_session: Session,
) -> None:
    """Flag one edit as the latest among the messages sharing its parent.

    Two bulk updates are issued and nothing is committed here. The first
    clears ``latest`` on every message at ``message_number`` in the session
    whose ``parent_edit_number`` matches. The second sets ``latest`` on the
    message at ``message_number`` whose ``edit_number`` matches.

    Args:
        chat_session_id: ID of the chat session the messages belong to.
        message_number: Position of the messages being updated.
        parent_edit_number: Edit number of the shared parent; may be ``None``
            only when ``message_number`` is 0.
        edit_number: Edit that becomes the latest.
        db_session: Session the updates are issued on.

    Raises:
        ValueError: If ``parent_edit_number`` is ``None`` for a message that is
            not at position 0.
    """
    has_no_parent = parent_edit_number is None
    if message_number != 0 and has_no_parent:
        raise ValueError(
            "Only initial message in a chat is allowed to not have a parent"
        )

    # Both updates target the same session and message position.
    at_position = db_session.query(ChatMessage).filter(
        ChatMessage.chat_session_id == chat_session_id,
        ChatMessage.message_number == message_number,
    )

    # Order matters: clear the flag on the siblings first, then set it on the
    # chosen edit, which is normally one of those siblings.
    at_position.filter(
        ChatMessage.parent_edit_number == parent_edit_number
    ).update({ChatMessage.latest: False})

    at_position.filter(ChatMessage.edit_number == edit_number).update(
        {ChatMessage.latest: True}
    )
|
|
|
|
|
|
def create_new_chat_message(
    chat_session_id: int,
    message_number: int,
    message: str,
    token_count: int,
    parent_edit_number: int | None,
    message_type: MessageType,
    db_session: Session,
    retrieval_docs: dict[str, Any] | None = None,
) -> ChatMessage:
    """Add a message as a new edit at ``message_number`` and mark it latest.

    The new edit number is one more than the highest existing edit number at
    that position, or 0 when the position has no messages yet.

    Args:
        chat_session_id: ID of the chat session the message belongs to.
        message_number: Position of the message within the session.
        message: Message text.
        token_count: Token count stored with the message.
        parent_edit_number: Edit number of the parent message; may be ``None``
            only for the message at position 0.
        message_type: Type stored with the message.
        db_session: Session used for the query, the insert and the commit.
        retrieval_docs: Optional mapping stored as the message's
            ``reference_docs``.

    Returns:
        The newly created ``ChatMessage``.

    Raises:
        ValueError: Propagated from the latest-flag helper when
            ``parent_edit_number`` is ``None`` for a non-initial message.
    """
    # Highest edit number already stored at this position, or None if empty.
    highest_edit = (
        db_session.query(func.max(ChatMessage.edit_number))
        .filter(
            ChatMessage.chat_session_id == chat_session_id,
            ChatMessage.message_number == message_number,
        )
        .scalar()
    )

    if highest_edit is None:
        next_edit = 0
    else:
        next_edit = highest_edit + 1

    chat_message = ChatMessage(
        chat_session_id=chat_session_id,
        message_number=message_number,
        parent_edit_number=parent_edit_number,
        edit_number=next_edit,
        message=message,
        reference_docs=retrieval_docs,
        token_count=token_count,
        message_type=message_type,
    )
    db_session.add(chat_message)

    # Demote the previous latest message under the same parent and promote
    # the edit that was just added.
    _set_latest_chat_message_no_commit(
        chat_session_id=chat_session_id,
        message_number=message_number,
        parent_edit_number=parent_edit_number,
        edit_number=next_edit,
        db_session=db_session,
    )

    db_session.commit()
    return chat_message
|
|
|
|
|
|
def set_latest_chat_message(
    chat_session_id: int,
    message_number: int,
    parent_edit_number: int | None,
    edit_number: int,
    db_session: Session,
) -> None:
    """Mark one edit as the latest under its parent, then commit.

    Committing wrapper around ``_set_latest_chat_message_no_commit``; all
    arguments are forwarded unchanged.

    Raises:
        ValueError: Propagated from the helper when ``parent_edit_number`` is
            ``None`` for a message that is not at position 0.
    """
    _set_latest_chat_message_no_commit(
        chat_session_id,
        message_number,
        parent_edit_number,
        edit_number,
        db_session,
    )
    db_session.commit()
|
|
|
|
|
|
def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
    """Fetch a non-deleted persona by ID.

    Args:
        persona_id: ID of the persona to fetch.
        db_session: Session used to run the query.

    Returns:
        The matching ``Persona``.

    Raises:
        ValueError: If no persona with that ID exists or it is marked deleted.
    """
    query = select(Persona).where(
        Persona.id == persona_id,
        Persona.deleted == False,  # noqa: E712
    )
    found = db_session.execute(query).scalar_one_or_none()
    if found is not None:
        return found

    raise ValueError(f"Persona with ID {persona_id} does not exist")
|
|
|
|
|
|
def fetch_default_persona_by_name(
    persona_name: str, db_session: Session
) -> Persona | None:
    """Fetch the non-deleted default persona with the given name, if any.

    Args:
        persona_name: Name to match exactly.
        db_session: Session used to run the query.

    Returns:
        The matching default ``Persona``, or ``None`` when there is none.
    """
    query = select(Persona).where(
        Persona.name == persona_name,
        Persona.default_persona == True,  # noqa: E712
        Persona.deleted == False,  # noqa: E712
    )
    return db_session.execute(query).scalar_one_or_none()
|
|
|
|
|
|
def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | None:
    """Find a non-deleted persona by name, preferring a default persona.

    A default persona with this name is returned when one exists. Otherwise
    the first non-deleted persona with the name is returned. Names are not
    guaranteed unique unless the persona is a default one.

    Args:
        persona_name: Name to match exactly.
        db_session: Session used to run the queries.

    Returns:
        The matching ``Persona``, or ``None`` when no persona has that name.
    """
    default_match = fetch_default_persona_by_name(persona_name, db_session)
    if default_match is not None:
        return default_match

    query = select(Persona).where(
        Persona.name == persona_name,
        Persona.deleted == False,  # noqa: E712
    )
    # `.first()` yields None when nothing matches, so no row unpacking is needed.
    return db_session.scalars(query).first()
|
|
|
|
|
|
def upsert_persona(
    db_session: Session,
    name: str,
    retrieval_enabled: bool,
    datetime_aware: bool,
    description: str | None = None,
    system_text: str | None = None,
    tools: list[ToolInfo] | None = None,
    hint_text: str | None = None,
    num_chunks: int | None = None,
    apply_llm_relevance_filter: bool | None = None,
    persona_id: int | None = None,
    default_persona: bool = False,
    document_sets: list[DocumentSetDBModel] | None = None,
    llm_model_version_override: str | None = None,
    commit: bool = True,
    overwrite_duplicate_named_persona: bool = False,
) -> Persona:
    """Create a persona, or update the existing one matched by ID or by name.

    The persona to update is resolved in this order:

    1. The row whose ID equals ``persona_id``.
    2. If none and ``default_persona`` is true, the default persona with the
       same name.
    3. If none and ``default_persona`` is false, any persona with the same
       name, which is only accepted when
       ``overwrite_duplicate_named_persona`` is true.

    When nothing is resolved, a new persona is added to the session.

    Args:
        db_session: Session used for lookups and for persisting the persona.
        name: Persona name, also used for the by-name lookups.
        retrieval_enabled: Stored on the persona.
        datetime_aware: Stored on the persona.
        description: Stored on the persona.
        system_text: Stored on the persona.
        tools: Stored on the persona.
        hint_text: Stored on the persona.
        num_chunks: Stored on the persona.
        apply_llm_relevance_filter: Stored on the persona.
        persona_id: ID of an existing persona to update, if known.
        default_persona: Stored on the persona and selects which by-name
            lookup is used.
        document_sets: On update, replaces the persona's document sets only
            when not ``None``. On create, ``None`` or an empty list results in
            no document sets.
        llm_model_version_override: Stored on the persona.
        commit: Commit when true; otherwise only flush.
        overwrite_duplicate_named_persona: Allow updating an existing persona
            that was found by name instead of raising.

    Returns:
        The created or updated ``Persona``.

    Raises:
        ValueError: If the persona found by ID is marked deleted, or if a
            non-default persona with the same name exists and overwriting was
            not requested.
    """
    # Column values written identically on the update path and the create path.
    column_values: dict[str, Any] = {
        "name": name,
        "description": description,
        "retrieval_enabled": retrieval_enabled,
        "datetime_aware": datetime_aware,
        "system_text": system_text,
        "tools": tools,
        "hint_text": hint_text,
        "num_chunks": num_chunks,
        "apply_llm_relevance_filter": apply_llm_relevance_filter,
        "default_persona": default_persona,
        "llm_model_version_override": llm_model_version_override,
    }

    target = db_session.query(Persona).filter_by(id=persona_id).first()
    if target and target.deleted:
        raise ValueError("Trying to update a deleted persona")

    # Default personas are defined via yaml files at deployment time
    if target is None:
        if default_persona:
            target = fetch_default_persona_by_name(name, db_session)
        else:
            # only one persona with the same name should exist
            same_named = fetch_persona_by_name(name, db_session)
            if same_named and not overwrite_duplicate_named_persona:
                raise ValueError("Trying to create a persona with a duplicate name")

            # Treat the same-named persona as the one to overwrite.
            target = same_named

    if target:
        for column, value in column_values.items():
            setattr(target, column, value)

        # Associations added manually are left alone unless a replacement
        # list is supplied.
        if document_sets is not None:
            target.document_sets.clear()
            target.document_sets = document_sets
    else:
        target = Persona(**column_values, document_sets=document_sets or [])
        db_session.add(target)

    if commit:
        db_session.commit()
    else:
        # flush the session so that the persona has an ID
        db_session.flush()

    return target
|
|
|
|
|
|
def fetch_personas(
    db_session: Session,
    include_default: bool = False,
    include_slack_bot_personas: bool = False,
) -> Sequence[Persona]:
    """List the non-deleted personas, optionally excluding some categories.

    Args:
        db_session: Session used to run the query.
        include_default: When false, personas flagged ``default_persona`` are
            left out.
        include_slack_bot_personas: When false, personas whose name starts
            with ``SLACK_BOT_PERSONA_PREFIX`` are left out.

    Returns:
        The matching personas.
    """
    conditions = [Persona.deleted == False]  # noqa: E712
    if not include_default:
        conditions.append(Persona.default_persona == False)  # noqa: E712
    if not include_slack_bot_personas:
        conditions.append(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))

    query = select(Persona).where(*conditions)
    return db_session.scalars(query).all()
|
|
|
|
|
|
def mark_persona_as_deleted(db_session: Session, persona_id: int) -> None:
    """Soft-delete a persona by setting its ``deleted`` flag and committing.

    Args:
        db_session: Session used for the lookup and the commit.
        persona_id: ID of the persona to mark as deleted.

    Raises:
        ValueError: Propagated from the lookup when the persona does not exist
            or is already marked deleted.
    """
    target = fetch_persona_by_id(persona_id=persona_id, db_session=db_session)
    target.deleted = True
    db_session.commit()
|