rkuo-danswer 61e8f371b9
fix blowing up the entire task on exception and trying to reuse an invalid db session (#4179)
* fix blowing up the entire task on exception and trying to reuse an invalid db session

* list comprehension

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-04 00:57:27 +00:00

1092 lines
36 KiB
Python

from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from typing import Tuple
from uuid import UUID
from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import Row
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.agents.agent_search.shared_graph_utils.models import CombinedAgentMetrics
from onyx.agents.agent_search.shared_graph_utils.models import (
SubQuestionAnswerResults,
)
from onyx.auth.schemas import UserRole
from onyx.chat.models import DocumentRelevance
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import MessageType
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDocs
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc as ServerSearchDoc
from onyx.context.search.utils import chunks_or_sections_to_search_docs
from onyx.db.models import AgentSearchMetrics
from onyx.db.models import AgentSubQuery
from onyx.db.models import AgentSubQuestion
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessage__SearchDoc
from onyx.db.models import ChatSession
from onyx.db.models import ChatSessionSharedStatus
from onyx.db.models import Prompt
from onyx.db.models import SearchDoc
from onyx.db.models import SearchDoc as DBSearchDoc
from onyx.db.models import ToolCall
from onyx.db.models import User
from onyx.db.persona import get_best_persona_id_for_user
from onyx.db.pg_file_store import delete_lobj_by_name
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.server.query_and_chat.models import ChatMessageDetail
from onyx.server.query_and_chat.models import SubQueryDetail
from onyx.server.query_and_chat.models import SubQuestionDetail
from onyx.tools.tool_runner import ToolCallFinalResult
from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
logger = setup_logger()
def get_chat_session_by_id(
    chat_session_id: UUID,
    user_id: UUID | None,
    db_session: Session,
    include_deleted: bool = False,
    is_shared: bool = False,
) -> ChatSession:
    """Fetch a single chat session by id, applying visibility rules.

    Raises:
        ValueError: no matching session exists, or it was deleted and
            include_deleted is not set.
    """
    query = select(ChatSession).where(ChatSession.id == chat_session_id)

    if is_shared:
        # Shared links are only valid for publicly shared sessions
        query = query.where(
            ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC
        )
    elif user_id is not None:
        # A None user_id is treated as an admin view (no ownership filter);
        # otherwise the session must belong to the user or be ownerless.
        query = query.where(
            or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None))
        )

    session = db_session.execute(query).scalar_one_or_none()
    if session is None:
        raise ValueError("Invalid Chat Session ID provided")

    if session.deleted and not include_deleted:
        raise ValueError("Chat session has been deleted")

    return session
def get_chat_sessions_by_slack_thread_id(
    slack_thread_id: str,
    user_id: UUID | None,
    db_session: Session,
) -> Sequence[ChatSession]:
    """Return every chat session tied to the given Slack thread.

    When user_id is provided, results are limited to sessions owned by that
    user or owned by nobody; a None user_id sees all matching sessions.
    """
    query = select(ChatSession).where(
        ChatSession.slack_thread_id == slack_thread_id
    )
    if user_id is not None:
        ownership_filter = or_(
            ChatSession.user_id == user_id, ChatSession.user_id.is_(None)
        )
        query = query.where(ownership_filter)
    return db_session.scalars(query).all()
def get_valid_messages_from_query_sessions(
    chat_session_ids: list[UUID],
    db_session: Session,
) -> dict[UUID, str]:
    """Map each chat session id to the text of its first user message.

    A session is only included when it has both a first USER message and a
    first ASSISTANT message that is linked to at least one search doc —
    i.e. sessions where a query actually produced documents.
    """
    # Earliest USER message per session (min id == first inserted)
    user_message_subquery = (
        select(
            ChatMessage.chat_session_id, func.min(ChatMessage.id).label("user_msg_id")
        )
        .where(
            ChatMessage.chat_session_id.in_(chat_session_ids),
            ChatMessage.message_type == MessageType.USER,
        )
        .group_by(ChatMessage.chat_session_id)
        .subquery()
    )
    # Earliest ASSISTANT message per session
    assistant_message_subquery = (
        select(
            ChatMessage.chat_session_id,
            func.min(ChatMessage.id).label("assistant_msg_id"),
        )
        .where(
            ChatMessage.chat_session_id.in_(chat_session_ids),
            ChatMessage.message_type == MessageType.ASSISTANT,
        )
        .group_by(ChatMessage.chat_session_id)
        .subquery()
    )
    # Join against the search-doc link table so only sessions whose first
    # assistant response references documents survive; the final where
    # restricts the outer ChatMessage to the first user message itself.
    query = (
        select(ChatMessage.chat_session_id, ChatMessage.message)
        .join(
            user_message_subquery,
            ChatMessage.chat_session_id == user_message_subquery.c.chat_session_id,
        )
        .join(
            assistant_message_subquery,
            ChatMessage.chat_session_id == assistant_message_subquery.c.chat_session_id,
        )
        .join(
            ChatMessage__SearchDoc,
            ChatMessage__SearchDoc.chat_message_id
            == assistant_message_subquery.c.assistant_msg_id,
        )
        .where(ChatMessage.id == user_message_subquery.c.user_msg_id)
    )
    first_messages = db_session.execute(query).all()
    logger.info(f"Retrieved {len(first_messages)} first messages with documents")
    # Dict build also dedupes rows multiplied by the search-doc join
    return {row.chat_session_id: row.message for row in first_messages}
# Retrieves chat sessions by user
# Chat sessions do not include onyxbot flows unless explicitly requested
def get_chat_sessions_by_user(
    user_id: UUID | None,
    deleted: bool | None,
    db_session: Session,
    include_onyxbot_flows: bool = False,
    limit: int = 50,
) -> list[ChatSession]:
    """Return a user's chat sessions, most recently updated first.

    Args:
        deleted: when not None, filter on the deleted flag
        include_onyxbot_flows: also include bot-driven sessions
        limit: cap on result count; a falsy value means no limit
    """
    stmt = select(ChatSession).where(ChatSession.user_id == user_id)

    if not include_onyxbot_flows:
        stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))

    stmt = stmt.order_by(desc(ChatSession.time_updated))

    if deleted is not None:
        stmt = stmt.where(ChatSession.deleted == deleted)

    if limit:
        stmt = stmt.limit(limit)

    return list(db_session.execute(stmt).scalars().all())
def delete_search_doc_message_relationship(
    message_id: int, db_session: Session
) -> None:
    """Remove the message<->search-doc link rows for the given chat message.

    Uses the same `delete()` statement style as the sibling
    delete_tool_call_for_message_id for consistency (was a legacy
    Query.delete call).
    """
    db_session.execute(
        delete(ChatMessage__SearchDoc).where(
            ChatMessage__SearchDoc.chat_message_id == message_id
        )
    )
    db_session.commit()
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
    """Remove any ToolCall rows attached to the given chat message."""
    db_session.execute(delete(ToolCall).where(ToolCall.message_id == message_id))
    db_session.commit()
def delete_orphaned_search_docs(db_session: Session) -> None:
    """Delete SearchDoc rows that no chat message references anymore."""
    orphans = (
        db_session.query(SearchDoc)
        .outerjoin(ChatMessage__SearchDoc)
        .filter(ChatMessage__SearchDoc.chat_message_id.is_(None))
        .all()
    )
    # Delete via the ORM so relationship bookkeeping stays consistent
    for orphan in orphans:
        db_session.delete(orphan)
    db_session.commit()
def delete_messages_and_files_from_chat_session(
    chat_session_id: UUID, db_session: Session
) -> None:
    """Hard-delete every message in a chat session plus its associated data.

    For each message this removes tool calls, search-doc link rows and any
    uploaded files (large objects), then bulk-deletes the messages and
    finally sweeps search docs left orphaned by the unlinking.
    """
    # Fetch ids and file descriptors for every message in this session
    # (the old comment claimed a cutoff_time filter; there is none)
    messages_with_files = db_session.execute(
        select(ChatMessage.id, ChatMessage.files).where(
            ChatMessage.chat_session_id == chat_session_id,
        )
    ).fetchall()

    # `message_id` instead of `id` to avoid shadowing the builtin
    for message_id, files in messages_with_files:
        delete_tool_call_for_message_id(message_id=message_id, db_session=db_session)
        delete_search_doc_message_relationship(
            message_id=message_id, db_session=db_session
        )
        # `files` is a list of FileDescriptor dicts, or None
        for file_info in files or {}:
            lobj_name = file_info.get("id")
            if lobj_name:
                logger.info(f"Deleting file with name: {lobj_name}")
                delete_lobj_by_name(lobj_name, db_session)

    db_session.execute(
        delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
    )
    db_session.commit()

    delete_orphaned_search_docs(db_session)
def create_chat_session(
    db_session: Session,
    description: str | None,
    user_id: UUID | None,
    persona_id: int | None,  # Can be none if temporary persona is used
    llm_override: LLMOverride | None = None,
    prompt_override: PromptOverride | None = None,
    onyxbot_flow: bool = False,
    slack_thread_id: str | None = None,
) -> ChatSession:
    """Create, persist and return a new ChatSession row."""
    new_session = ChatSession(
        user_id=user_id,
        persona_id=persona_id,
        description=description,
        llm_override=llm_override,
        prompt_override=prompt_override,
        onyxbot_flow=onyxbot_flow,
        slack_thread_id=slack_thread_id,
    )

    db_session.add(new_session)
    db_session.commit()

    return new_session
def duplicate_chat_session_for_user_from_slack(
    db_session: Session,
    user: User | None,
    chat_session_id: UUID,
) -> ChatSession:
    """
    This takes a chat session id for a session in Slack and:
    - Creates a new chat session in the DB
    - Tries to copy the persona from the original chat session
      (if it is available to the user clicking the button)
    - Sets the user to the given user (if provided)

    Raises:
        HTTPException: 400 when the Slack chat session cannot be fetched.
    """
    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=chat_session_id,
            user_id=None,  # Ignore user permissions for this
            db_session=db_session,
        )
    except ValueError:
        # get_chat_session_by_id raises instead of returning None, so the
        # previous `if not chat_session` guard was unreachable; translate the
        # lookup failure into the intended 400 response here.
        raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided")

    # This enforces permissions and sets a default
    new_persona_id = get_best_persona_id_for_user(
        db_session=db_session,
        user=user,
        persona_id=chat_session.persona_id,
    )

    return create_chat_session(
        db_session=db_session,
        user_id=user.id if user else None,
        persona_id=new_persona_id,
        # Set this to empty string so the frontend will force a rename
        description="",
        llm_override=chat_session.llm_override,
        prompt_override=chat_session.prompt_override,
        # Chat is in UI now so this is false
        onyxbot_flow=False,
        # Maybe we want this in the future to track if it was created from Slack
        slack_thread_id=None,
    )
def update_chat_session(
    db_session: Session,
    user_id: UUID | None,
    chat_session_id: UUID,
    description: str | None = None,
    sharing_status: ChatSessionSharedStatus | None = None,
) -> ChatSession:
    """Update the description and/or sharing status of a chat session.

    Raises ValueError when the session is missing or already deleted.
    """
    session = get_chat_session_by_id(
        chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
    )
    if session.deleted:
        raise ValueError("Trying to rename a deleted chat session")

    if description is not None:
        session.description = description
    if sharing_status is not None:
        session.shared_status = sharing_status

    db_session.commit()
    return session
def delete_all_chat_sessions_for_user(
    user: User | None, db_session: Session, hard_delete: bool = HARD_DELETE_CHATS
) -> None:
    """Delete (hard or soft) every non-bot chat session owned by the user."""
    owner_id = user.id if user is not None else None
    sessions = db_session.query(ChatSession).filter(
        ChatSession.user_id == owner_id, ChatSession.onyxbot_flow.is_(False)
    )
    if hard_delete:
        sessions.delete(synchronize_session=False)
    else:
        # Soft delete: just flag the rows as deleted
        sessions.update({ChatSession.deleted: True}, synchronize_session=False)
    db_session.commit()
def delete_chat_session(
    user_id: UUID | None,
    chat_session_id: UUID,
    db_session: Session,
    include_deleted: bool = False,
    hard_delete: bool = HARD_DELETE_CHATS,
) -> None:
    """Delete a chat session, either hard (rows removed) or soft (flagged).

    Raises:
        ValueError: session not found, or already deleted while
            include_deleted is False.
    """
    chat_session = get_chat_session_by_id(
        chat_session_id=chat_session_id,
        user_id=user_id,
        db_session=db_session,
        include_deleted=include_deleted,
    )

    if chat_session.deleted and not include_deleted:
        raise ValueError("Cannot delete an already deleted chat session")

    if hard_delete:
        delete_messages_and_files_from_chat_session(chat_session_id, db_session)
        db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
    else:
        # Soft delete: reuse the session we already loaded. The old code
        # re-fetched here without include_deleted, which was redundant and
        # could spuriously raise for already-deleted sessions when
        # include_deleted=True.
        chat_session.deleted = True

    db_session.commit()
def get_chat_sessions_older_than(
    days_old: int, db_session: Session
) -> list[tuple[UUID | None, UUID]]:
    """
    Retrieves chat sessions older than a specified number of days.
    Args:
        days_old: The number of days to consider as "old".
        db_session: The database session.
    Returns:
        A list of (user_id, chat_session_id) tuples; user_id may be None
        for ownerless sessions.
    """
    # NOTE: naive utcnow kept on purpose — time_created comparison presumably
    # expects naive UTC timestamps; confirm before switching to aware datetimes
    cutoff_time = datetime.utcnow() - timedelta(days=days_old)
    rows: Sequence[Row[Tuple[UUID | None, UUID]]] = db_session.execute(
        select(ChatSession.user_id, ChatSession.id).where(
            ChatSession.time_created < cutoff_time
        )
    ).fetchall()

    # Convert Row objects into plain (user_id, session_id) tuples
    return [(row.user_id, row.id) for row in rows]
def get_chat_message(
    chat_message_id: int,
    user_id: UUID | None,
    db_session: Session,
) -> ChatMessage:
    """Fetch a chat message by id, verifying it belongs to the given user.

    Raises ValueError when the message is missing or owned by someone else.
    """
    message = db_session.execute(
        select(ChatMessage).where(ChatMessage.id == chat_message_id)
    ).scalar_one_or_none()
    if message is None:
        raise ValueError("Invalid Chat Message specified")

    session_owner = message.chat_session.user
    owner_id = session_owner.id if session_owner is not None else None
    if owner_id != user_id:
        logger.error(
            f"User {user_id} tried to fetch a chat message that does not belong to them"
        )
        raise ValueError("Chat message does not belong to user")

    return message
def get_chat_session_by_message_id(
    db_session: Session,
    message_id: int,
) -> ChatSession:
    """
    Should only be used for Slack
    Get the chat session associated with a specific message ID
    Note: this ignores permission checks.
    """
    message = db_session.execute(
        select(ChatMessage).where(ChatMessage.id == message_id)
    ).scalar_one_or_none()

    if message is None:
        raise ValueError(
            f"Unable to find chat session associated with message ID: {message_id}"
        )

    return message.chat_session
def get_chat_messages_by_sessions(
    chat_session_ids: list[UUID],
    user_id: UUID | None,
    db_session: Session,
    skip_permission_check: bool = False,
) -> Sequence[ChatMessage]:
    """Return messages for the given sessions, root (NULL parent) first.

    Unless skip_permission_check is set, each session is verified to be
    visible to the user (ValueError is raised otherwise).
    """
    if not skip_permission_check:
        for session_id in chat_session_ids:
            get_chat_session_by_id(
                chat_session_id=session_id, user_id=user_id, db_session=db_session
            )

    stmt = (
        select(ChatMessage)
        .where(ChatMessage.chat_session_id.in_(chat_session_ids))
        .order_by(nullsfirst(ChatMessage.parent_message))
    )
    return db_session.execute(stmt).scalars().all()
def add_chats_to_session_from_slack_thread(
    db_session: Session,
    slack_chat_session_id: UUID,
    new_chat_session_id: UUID,
) -> None:
    """Copy all non-SYSTEM messages from a Slack chat session into a new one.

    Messages are appended linearly: each copied message becomes the parent of
    the next, starting from the new session's root message.
    """
    new_root_message = get_or_create_root_message(
        chat_session_id=new_chat_session_id,
        db_session=db_session,
    )
    for chat_message in get_chat_messages_by_sessions(
        chat_session_ids=[slack_chat_session_id],
        user_id=None,  # Ignore user permissions for this
        db_session=db_session,
        skip_permission_check=True,
    ):
        # Root/system messages already exist in the target session
        if chat_message.message_type == MessageType.SYSTEM:
            continue
        # Duplicate the message; reassigning new_root_message chains the
        # copies so each one parents the next
        new_root_message = create_new_chat_message(
            db_session=db_session,
            chat_session_id=new_chat_session_id,
            parent_message=new_root_message,
            message=chat_message.message,
            files=chat_message.files,
            rephrased_query=chat_message.rephrased_query,
            error=chat_message.error,
            citations=chat_message.citations,
            reference_docs=chat_message.search_docs,
            tool_call=chat_message.tool_call,
            prompt_id=chat_message.prompt_id,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
            alternate_assistant_id=chat_message.alternate_assistant_id,
            overridden_model=chat_message.overridden_model,
        )
def get_search_docs_for_chat_message(
    chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
    """Return all SearchDoc rows linked to the given chat message."""
    query = (
        select(SearchDoc)
        .join(
            ChatMessage__SearchDoc, ChatMessage__SearchDoc.search_doc_id == SearchDoc.id
        )
        .where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
    )
    return list(db_session.scalars(query).all())
def get_chat_messages_by_session(
    chat_session_id: UUID,
    user_id: UUID | None,
    db_session: Session,
    skip_permission_check: bool = False,
    prefetch_tool_calls: bool = False,
) -> list[ChatMessage]:
    """Return all messages of a session, root (NULL parent) message first."""
    if not skip_permission_check:
        # bug if we ever call this expecting the permission check to not be skipped
        get_chat_session_by_id(
            chat_session_id=chat_session_id, user_id=user_id, db_session=db_session
        )

    stmt = (
        select(ChatMessage)
        .where(ChatMessage.chat_session_id == chat_session_id)
        .order_by(nullsfirst(ChatMessage.parent_message))
    )

    if not prefetch_tool_calls:
        return list(db_session.scalars(stmt).all())

    # Eager-load tool calls plus the agent sub-question/sub-query tree to
    # avoid N+1 queries; unique() is required with joined collections.
    stmt = stmt.options(
        joinedload(ChatMessage.tool_call),
        joinedload(ChatMessage.sub_questions).joinedload(AgentSubQuestion.sub_queries),
    )
    return list(db_session.scalars(stmt).unique().all())
def get_or_create_root_message(
    chat_session_id: UUID,
    db_session: Session,
) -> ChatMessage:
    """Return the session's root (SYSTEM) message, creating one if absent.

    The root message is the unique message with a NULL parent.
    """
    try:
        existing_root: ChatMessage | None = (
            db_session.query(ChatMessage)
            .filter(
                ChatMessage.chat_session_id == chat_session_id,
                ChatMessage.parent_message.is_(None),
            )
            .one_or_none()
        )
    except MultipleResultsFound:
        raise Exception(
            "Multiple root messages found for chat session. Data inconsistency detected."
        )

    if existing_root is not None:
        return existing_root

    root = ChatMessage(
        chat_session_id=chat_session_id,
        prompt_id=None,
        parent_message=None,
        latest_child_message=None,
        message="",
        token_count=0,
        message_type=MessageType.SYSTEM,
    )
    db_session.add(root)
    db_session.commit()
    return root
def reserve_message_id(
    db_session: Session,
    chat_session_id: UUID,
    parent_message: int,
    message_type: MessageType,
) -> int:
    """Insert a placeholder ChatMessage and return its id (flushed, not committed).

    The caller is expected to fill the row in later via
    create_new_chat_message with reserved_message_id set.
    """
    placeholder = ChatMessage(
        chat_session_id=chat_session_id,
        parent_message=parent_message,
        latest_child_message=None,
        message="",
        token_count=0,
        message_type=message_type,
    )
    db_session.add(placeholder)
    # Flush so the DB assigns an id we can hand back without committing
    db_session.flush()
    return placeholder.id
def create_new_chat_message(
    chat_session_id: UUID,
    parent_message: ChatMessage,
    message: str,
    prompt_id: int | None,
    token_count: int,
    message_type: MessageType,
    db_session: Session,
    files: list[FileDescriptor] | None = None,
    rephrased_query: str | None = None,
    error: str | None = None,
    reference_docs: list[DBSearchDoc] | None = None,
    alternate_assistant_id: int | None = None,
    # Maps the citation number [n] to the DB SearchDoc
    citations: dict[int, int] | None = None,
    tool_call: ToolCall | None = None,
    commit: bool = True,
    reserved_message_id: int | None = None,
    overridden_model: str | None = None,
    refined_answer_improvement: bool | None = None,
    is_agentic: bool = False,
) -> ChatMessage:
    """Create a chat message (or fill in a previously reserved placeholder)
    and link it into the tree as the parent's latest child.

    When reserved_message_id is given, the placeholder row created by
    reserve_message_id is updated in place instead of inserting a new row.

    Raises:
        ValueError: reserved_message_id does not match an existing row.
    """
    if reserved_message_id is not None:
        # Edit existing message
        existing_message = db_session.query(ChatMessage).get(reserved_message_id)
        if existing_message is None:
            raise ValueError(f"No message found with id {reserved_message_id}")
        existing_message.chat_session_id = chat_session_id
        existing_message.parent_message = parent_message.id
        existing_message.message = message
        existing_message.rephrased_query = rephrased_query
        existing_message.prompt_id = prompt_id
        existing_message.token_count = token_count
        existing_message.message_type = message_type
        existing_message.citations = citations
        existing_message.files = files
        existing_message.tool_call = tool_call
        existing_message.error = error
        existing_message.alternate_assistant_id = alternate_assistant_id
        existing_message.overridden_model = overridden_model
        existing_message.refined_answer_improvement = refined_answer_improvement
        existing_message.is_agentic = is_agentic
        new_chat_message = existing_message
    else:
        # Create new message
        new_chat_message = ChatMessage(
            chat_session_id=chat_session_id,
            parent_message=parent_message.id,
            latest_child_message=None,
            message=message,
            rephrased_query=rephrased_query,
            prompt_id=prompt_id,
            token_count=token_count,
            message_type=message_type,
            citations=citations,
            files=files,
            tool_call=tool_call,
            error=error,
            alternate_assistant_id=alternate_assistant_id,
            overridden_model=overridden_model,
            refined_answer_improvement=refined_answer_improvement,
            is_agentic=is_agentic,
        )
        db_session.add(new_chat_message)
    # SQL Alchemy will propagate this to update the reference_docs' foreign keys
    if reference_docs:
        new_chat_message.search_docs = reference_docs
    # Flush the session to get an ID for the new chat message
    db_session.flush()
    # Re-point the parent so this message becomes the active branch
    parent_message.latest_child_message = new_chat_message.id
    if commit:
        db_session.commit()
    return new_chat_message
def set_as_latest_chat_message(
    chat_message: ChatMessage,
    user_id: UUID | None,
    db_session: Session,
) -> None:
    """Mark chat_message as the latest child of its parent message."""
    parent_id = chat_message.parent_message
    if parent_id is None:
        raise RuntimeError(
            f"Trying to set a latest message without parent, message id: {chat_message.id}"
        )

    # Permission check happens inside get_chat_message
    parent = get_chat_message(
        chat_message_id=parent_id, user_id=user_id, db_session=db_session
    )
    parent.latest_child_message = chat_message.id
    db_session.commit()
def attach_files_to_chat_message(
    chat_message: ChatMessage,
    files: list[FileDescriptor],
    db_session: Session,
    commit: bool = True,
) -> None:
    """Replace the message's file descriptor list, optionally committing."""
    chat_message.files = files
    if not commit:
        return
    db_session.commit()
def get_prompt_by_id(
    prompt_id: int,
    user: User | None,
    db_session: Session,
    include_deleted: bool = False,
) -> Prompt:
    """Fetch a prompt by id, scoped to the user unless they are an admin.

    Raises ValueError when the prompt is missing or not visible to the user.
    """
    stmt = select(Prompt).where(Prompt.id == prompt_id)

    # An unspecified user or an admin can access all prompts; everyone else
    # only sees their own prompts plus ownerless (global) ones.
    if user and user.role != UserRole.ADMIN:
        stmt = stmt.where(or_(Prompt.user_id == user.id, Prompt.user_id.is_(None)))

    if not include_deleted:
        stmt = stmt.where(Prompt.deleted.is_(False))

    prompt = db_session.execute(stmt).scalar_one_or_none()
    if prompt is None:
        raise ValueError(
            f"Prompt with ID {prompt_id} does not exist or does not belong to user"
        )
    return prompt
def get_doc_query_identifiers_from_model(
    search_doc_ids: list[int],
    chat_session: ChatSession,
    user_id: UUID | None,
    db_session: Session,
    enforce_chat_session_id_for_search_docs: bool,
) -> list[tuple[str, int]]:
    """Resolve SearchDoc ids into (document_id, chunk_ind) pairs.

    Verifies the referenced docs belong to the given user's chat session.

    Raises:
        ValueError: ownership check fails, or (when enforcement is on) a doc
            comes from a different chat session.
        RuntimeError: a doc has no associated chat message, which happens
            when the chat message failed to save (LLM failure).
    """
    search_docs = (
        db_session.query(SearchDoc).filter(SearchDoc.id.in_(search_doc_ids)).all()
    )

    if user_id != chat_session.user_id:
        logger.error(
            f"Docs referenced are from a chat session not belonging to user {user_id}"
        )
        raise ValueError("Docs references do not belong to user")

    try:
        # generator expression (was any([...]) with a materialized list) so
        # the membership check short-circuits on the first mismatch
        if any(
            doc.chat_messages[0].chat_session_id != chat_session.id
            for doc in search_docs
        ):
            if enforce_chat_session_id_for_search_docs:
                raise ValueError("Invalid reference doc, not from this chat session.")
    except IndexError:
        # This happens when the doc has no chat_messages associated with it.
        # which happens as an edge case where the chat message failed to save
        # This usually happens when the LLM fails either immediately or partially through.
        raise RuntimeError("Chat session failed, please start a new session.")

    return [(doc.document_id, doc.chunk_ind) for doc in search_docs]
def update_search_docs_table_with_relevance(
    db_session: Session,
    reference_db_search_docs: list[SearchDoc],
    relevance_summary: DocumentRelevance,
) -> None:
    """Write relevance judgments back onto the referenced SearchDoc rows."""
    summaries = relevance_summary.relevance_summaries
    for search_doc in reference_db_search_docs:
        relevance = summaries.get(search_doc.document_id)
        if relevance is None:
            # No judgment produced for this doc; leave the row untouched
            continue
        db_session.execute(
            update(SearchDoc)
            .where(SearchDoc.id == search_doc.id)
            .values(
                is_relevant=relevance.relevant,
                relevance_explanation=relevance.content,
            )
        )
    db_session.commit()
def create_db_search_doc(
    server_search_doc: ServerSearchDoc,
    db_session: Session,
) -> SearchDoc:
    """Persist a server-side search doc as a SearchDoc row and return it."""
    doc = server_search_doc
    db_search_doc = SearchDoc(
        document_id=doc.document_id,
        chunk_ind=doc.chunk_ind,
        semantic_id=doc.semantic_identifier,
        link=doc.link,
        blurb=doc.blurb,
        source_type=doc.source_type,
        boost=doc.boost,
        hidden=doc.hidden,
        doc_metadata=doc.metadata,
        is_relevant=doc.is_relevant,
        relevance_explanation=doc.relevance_explanation,
        # For docs further down that aren't reranked, we can't use the
        # retrieval score — store 0.0 instead
        score=doc.score or 0.0,
        match_highlights=doc.match_highlights,
        updated_at=doc.updated_at,
        primary_owners=doc.primary_owners,
        secondary_owners=doc.secondary_owners,
        is_internet=doc.is_internet,
    )
    db_session.add(db_search_doc)
    db_session.commit()
    return db_search_doc
def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | None:
    """There are no safety checks here like user permission etc., use with caution"""
    return db_session.query(SearchDoc).filter(SearchDoc.id == doc_id).first()
def translate_db_search_doc_to_server_search_doc(
    db_search_doc: SearchDoc,
    remove_doc_content: bool = False,
) -> SavedSearchDoc:
    """Convert a DB SearchDoc row into the API-facing SavedSearchDoc.

    With remove_doc_content set, content-bearing fields are blanked so the
    document text is not exposed.
    """
    doc = db_search_doc
    keep = not remove_doc_content
    return SavedSearchDoc(
        db_doc_id=doc.id,
        document_id=doc.document_id,
        chunk_ind=doc.chunk_ind,
        semantic_identifier=doc.semantic_id,
        link=doc.link,
        blurb=doc.blurb if keep else "",
        source_type=doc.source_type,
        boost=doc.boost,
        hidden=doc.hidden,
        metadata=doc.doc_metadata if keep else {},
        score=doc.score,
        match_highlights=doc.match_highlights if keep else [],
        relevance_explanation=doc.relevance_explanation,
        is_relevant=doc.is_relevant,
        updated_at=doc.updated_at if keep else None,
        primary_owners=doc.primary_owners if keep else [],
        secondary_owners=doc.secondary_owners if keep else [],
        is_internet=doc.is_internet,
    )
def translate_db_sub_questions_to_server_objects(
    db_sub_questions: list[AgentSubQuestion],
) -> list[SubQuestionDetail]:
    """Convert AgentSubQuestion rows (plus their sub-queries and retrieved
    docs) into API-facing SubQuestionDetail objects."""
    sub_questions = []
    for sub_question in db_sub_questions:
        sub_queries = []
        # document_id -> SearchDoc, collected across all of this
        # sub-question's sub-queries
        docs: dict[str, SearchDoc] = {}
        doc_results = cast(
            list[dict[str, JSON_ro]], sub_question.sub_question_doc_results
        )
        # ids of the docs the agent actually cited for this sub-question
        verified_doc_ids = [x["document_id"] for x in doc_results]
        for sub_query in sub_question.sub_queries:
            doc_ids = [doc.id for doc in sub_query.search_docs]
            sub_queries.append(
                SubQueryDetail(
                    query=sub_query.sub_query,
                    query_id=sub_query.id,
                    doc_ids=doc_ids,
                )
            )
            for doc in sub_query.search_docs:
                docs[doc.document_id] = doc
        # preserve the verified order; drop ids with no retrieved doc
        verified_docs = [
            docs[cast(str, doc_id)] for doc_id in verified_doc_ids if doc_id in docs
        ]
        sub_questions.append(
            SubQuestionDetail(
                level=sub_question.level,
                level_question_num=sub_question.level_question_num,
                question=sub_question.sub_question,
                answer=sub_question.sub_answer,
                sub_queries=sub_queries,
                context_docs=get_retrieval_docs_from_search_docs(
                    verified_docs, sort_by_score=False
                ),
            )
        )
    return sub_questions
def get_retrieval_docs_from_search_docs(
    search_docs: list[SearchDoc],
    remove_doc_content: bool = False,
    sort_by_score: bool = True,
) -> RetrievalDocs:
    """Wrap DB search docs as RetrievalDocs, optionally sorted by score desc."""
    top_documents = [
        translate_db_search_doc_to_server_search_doc(
            doc, remove_doc_content=remove_doc_content
        )
        for doc in search_docs
    ]
    if sort_by_score:
        top_documents.sort(key=lambda d: d.score, reverse=True)  # type: ignore
    return RetrievalDocs(top_documents=top_documents)
def translate_db_message_to_chat_message_detail(
    chat_message: ChatMessage,
    remove_doc_content: bool = False,
) -> ChatMessageDetail:
    """Convert a ChatMessage ORM row into the API-facing ChatMessageDetail.

    remove_doc_content strips document text from the attached search docs.
    """
    chat_msg_detail = ChatMessageDetail(
        chat_session_id=chat_message.chat_session_id,
        message_id=chat_message.id,
        parent_message=chat_message.parent_message,
        latest_child_message=chat_message.latest_child_message,
        message=chat_message.message,
        rephrased_query=chat_message.rephrased_query,
        context_docs=get_retrieval_docs_from_search_docs(
            chat_message.search_docs, remove_doc_content=remove_doc_content
        ),
        message_type=chat_message.message_type,
        time_sent=chat_message.time_sent,
        citations=chat_message.citations,
        files=chat_message.files or [],
        # Flatten the ToolCall row (if any) into the final-result model
        tool_call=ToolCallFinalResult(
            tool_name=chat_message.tool_call.tool_name,
            tool_args=chat_message.tool_call.tool_arguments,
            tool_result=chat_message.tool_call.tool_result,
        )
        if chat_message.tool_call
        else None,
        alternate_assistant_id=chat_message.alternate_assistant_id,
        overridden_model=chat_message.overridden_model,
        sub_questions=translate_db_sub_questions_to_server_objects(
            chat_message.sub_questions
        ),
        refined_answer_improvement=chat_message.refined_answer_improvement,
        is_agentic=chat_message.is_agentic,
        error=chat_message.error,
    )
    return chat_msg_detail
def log_agent_metrics(
    db_session: Session,
    user_id: UUID | None,
    persona_id: int | None,  # Can be none if temporary persona is used
    agent_type: str,
    start_time: datetime | None,
    agent_metrics: CombinedAgentMetrics,
) -> AgentSearchMetrics:
    """Record metrics for one agent search run (flushed, not committed)."""
    timings = agent_metrics.timings
    base = agent_metrics.base_metrics
    refined = agent_metrics.refined_metrics
    additional = agent_metrics.additional_metrics

    record = AgentSearchMetrics(
        user_id=user_id,
        persona_id=persona_id,
        agent_type=agent_type,
        start_time=start_time,
        base_duration_s=timings.base_duration_s,
        full_duration_s=timings.full_duration_s,
        # vars() converts the metric dataclass-like objects to JSON dicts
        base_metrics=vars(base) if base else None,
        refined_metrics=vars(refined) if refined else None,
        all_metrics=vars(additional) if additional else None,
    )

    db_session.add(record)
    db_session.flush()
    return record
def log_agent_sub_question_results(
    db_session: Session,
    chat_session_id: UUID | None,
    primary_message_id: int | None,
    sub_question_answer_results: list[SubQuestionAnswerResults],
) -> None:
    """Persist agent sub-questions, their sub-queries and retrieved docs.

    For each answered sub-question this writes an AgentSubQuestion row, then
    one AgentSubQuery row per retrieval query, and links the retrieved
    documents (persisted via create_db_search_doc) to that sub-query.
    """

    def _create_citation_format_list(
        document_citations: list[InferenceSection],
    ) -> list[dict[str, Any]]:
        # Flatten cited sections into JSON-serializable dicts for storage
        citation_list: list[dict[str, Any]] = []
        for document_citation in document_citations:
            chunk = document_citation.center_chunk
            citation_list.append(
                {
                    "link": "",
                    "blurb": chunk.blurb,
                    "content": chunk.content,
                    "metadata": chunk.metadata,
                    "updated_at": str(chunk.updated_at),
                    "document_id": chunk.document_id,
                    "source_type": "file",
                    "source_links": chunk.source_links,
                    "match_highlights": chunk.match_highlights,
                    "semantic_identifier": chunk.semantic_identifier,
                }
            )
        return citation_list

    now = datetime.now()

    for sub_question_answer_result in sub_question_answer_results:
        # question_id is formatted as "<level>_<question_num>"
        level, level_question_num = [
            int(x) for x in sub_question_answer_result.question_id.split("_")
        ]
        sub_document_results = _create_citation_format_list(
            sub_question_answer_result.context_documents
        )

        sub_question_object = AgentSubQuestion(
            chat_session_id=chat_session_id,
            primary_question_id=primary_message_id,
            level=level,
            level_question_num=level_question_num,
            sub_question=sub_question_answer_result.question,
            sub_answer=sub_question_answer_result.answer,
            sub_question_doc_results=sub_document_results,
        )
        db_session.add(sub_question_object)
        # Commit so sub_question_object.id is assigned for the FK below
        db_session.commit()

        sub_question_id = sub_question_object.id
        for sub_query in sub_question_answer_result.sub_query_retrieval_results:
            sub_query_object = AgentSubQuery(
                parent_question_id=sub_question_id,
                chat_session_id=chat_session_id,
                sub_query=sub_query.query,
                time_created=now,
            )
            db_session.add(sub_query_object)
            db_session.commit()

            search_docs = chunks_or_sections_to_search_docs(
                sub_query.retrieved_documents
            )
            for doc in search_docs:
                # create_db_search_doc already adds + commits the row, so the
                # old second db_session.add(db_doc) was redundant (dropped)
                db_doc = create_db_search_doc(doc, db_session)
                sub_query_object.search_docs.append(db_doc)
            db_session.commit()

    return None