mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-04 03:59:25 +02:00
* Fix pagination for index attempts table * fixed index attempts pagination * fixed query history table * query clearnup * fixed test * fixed weird tests???
173 lines
5.4 KiB
Python
173 lines
5.4 KiB
Python
from collections.abc import Sequence
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy import asc
|
|
from sqlalchemy import BinaryExpression
|
|
from sqlalchemy import ColumnElement
|
|
from sqlalchemy import desc
|
|
from sqlalchemy import distinct
|
|
from sqlalchemy.orm import contains_eager
|
|
from sqlalchemy.orm import joinedload
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql import case
|
|
from sqlalchemy.sql import func
|
|
from sqlalchemy.sql import select
|
|
from sqlalchemy.sql.expression import literal
|
|
from sqlalchemy.sql.expression import UnaryExpression
|
|
|
|
from onyx.configs.constants import QAFeedbackType
|
|
from onyx.db.models import ChatMessage
|
|
from onyx.db.models import ChatMessageFeedback
|
|
from onyx.db.models import ChatSession
|
|
|
|
|
|
def _build_filter_conditions(
|
|
start_time: datetime | None,
|
|
end_time: datetime | None,
|
|
feedback_filter: QAFeedbackType | None,
|
|
) -> list[ColumnElement]:
|
|
"""
|
|
Helper function to build all filter conditions for chat sessions.
|
|
Filters by start and end time, feedback type, and any sessions without messages.
|
|
start_time: Date from which to filter
|
|
end_time: Date to which to filter
|
|
feedback_filter: Feedback type to filter by
|
|
Returns: List of filter conditions
|
|
"""
|
|
conditions = []
|
|
|
|
if start_time is not None:
|
|
conditions.append(ChatSession.time_created >= start_time)
|
|
if end_time is not None:
|
|
conditions.append(ChatSession.time_created <= end_time)
|
|
|
|
if feedback_filter is not None:
|
|
feedback_subq = (
|
|
select(ChatMessage.chat_session_id)
|
|
.join(ChatMessageFeedback)
|
|
.group_by(ChatMessage.chat_session_id)
|
|
.having(
|
|
case(
|
|
(
|
|
case(
|
|
{literal(feedback_filter == QAFeedbackType.LIKE): True},
|
|
else_=False,
|
|
),
|
|
func.bool_and(ChatMessageFeedback.is_positive),
|
|
),
|
|
(
|
|
case(
|
|
{literal(feedback_filter == QAFeedbackType.DISLIKE): True},
|
|
else_=False,
|
|
),
|
|
func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
|
|
),
|
|
else_=func.bool_or(ChatMessageFeedback.is_positive)
|
|
& func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
|
|
)
|
|
)
|
|
)
|
|
conditions.append(ChatSession.id.in_(feedback_subq))
|
|
|
|
return conditions
|
|
|
|
|
|
def get_total_filtered_chat_sessions_count(
|
|
db_session: Session,
|
|
start_time: datetime | None,
|
|
end_time: datetime | None,
|
|
feedback_filter: QAFeedbackType | None,
|
|
) -> int:
|
|
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
|
|
stmt = (
|
|
select(func.count(distinct(ChatSession.id)))
|
|
.select_from(ChatSession)
|
|
.filter(*conditions)
|
|
)
|
|
return db_session.scalar(stmt) or 0
|
|
|
|
|
|
def get_page_of_chat_sessions(
|
|
start_time: datetime | None,
|
|
end_time: datetime | None,
|
|
db_session: Session,
|
|
page_num: int,
|
|
page_size: int,
|
|
feedback_filter: QAFeedbackType | None = None,
|
|
) -> Sequence[ChatSession]:
|
|
conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
|
|
|
|
subquery = (
|
|
select(ChatSession.id)
|
|
.filter(*conditions)
|
|
.order_by(desc(ChatSession.time_created), ChatSession.id)
|
|
.limit(page_size)
|
|
.offset(page_num * page_size)
|
|
.subquery()
|
|
)
|
|
|
|
stmt = (
|
|
select(ChatSession)
|
|
.join(subquery, ChatSession.id == subquery.c.id)
|
|
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
|
|
.options(
|
|
joinedload(ChatSession.user),
|
|
joinedload(ChatSession.persona),
|
|
contains_eager(ChatSession.messages).joinedload(
|
|
ChatMessage.chat_message_feedbacks
|
|
),
|
|
)
|
|
.order_by(
|
|
desc(ChatSession.time_created),
|
|
ChatSession.id,
|
|
asc(ChatMessage.id), # Ensure chronological message order
|
|
)
|
|
)
|
|
|
|
return db_session.scalars(stmt).unique().all()
|
|
|
|
|
|
def fetch_chat_sessions_eagerly_by_time(
|
|
start: datetime,
|
|
end: datetime,
|
|
db_session: Session,
|
|
limit: int | None = 500,
|
|
initial_time: datetime | None = None,
|
|
) -> list[ChatSession]:
|
|
time_order: UnaryExpression = desc(ChatSession.time_created)
|
|
message_order: UnaryExpression = asc(ChatMessage.id)
|
|
|
|
filters: list[ColumnElement | BinaryExpression] = [
|
|
ChatSession.time_created.between(start, end)
|
|
]
|
|
|
|
if initial_time:
|
|
filters.append(ChatSession.time_created > initial_time)
|
|
|
|
subquery = (
|
|
db_session.query(ChatSession.id, ChatSession.time_created)
|
|
.filter(*filters)
|
|
.order_by(ChatSession.id, time_order)
|
|
.distinct(ChatSession.id)
|
|
.limit(limit)
|
|
.subquery()
|
|
)
|
|
|
|
query = (
|
|
db_session.query(ChatSession)
|
|
.join(subquery, ChatSession.id == subquery.c.id)
|
|
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
|
|
.options(
|
|
joinedload(ChatSession.user),
|
|
joinedload(ChatSession.persona),
|
|
contains_eager(ChatSession.messages).joinedload(
|
|
ChatMessage.chat_message_feedbacks
|
|
),
|
|
)
|
|
.order_by(time_order, message_order)
|
|
)
|
|
|
|
chat_sessions = query.all()
|
|
|
|
return chat_sessions
|