diff --git a/backend/ee/danswer/db/query_history.py b/backend/ee/danswer/db/query_history.py index b6a79cb77..8fb77f0a2 100644 --- a/backend/ee/danswer/db/query_history.py +++ b/backend/ee/danswer/db/query_history.py @@ -8,6 +8,7 @@ from sqlalchemy import desc from sqlalchemy.orm import contains_eager from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import UnaryExpression from danswer.db.models import ChatMessage from danswer.db.models import ChatSession @@ -20,22 +21,22 @@ def fetch_chat_sessions_eagerly_by_time( end: datetime.datetime, db_session: Session, limit: int | None = 500, - initial_id: int | None = None, + initial_time: datetime.datetime | None = None, ) -> list[ChatSession]: - id_order = desc(ChatSession.id) # type: ignore - time_order = desc(ChatSession.time_created) # type: ignore - message_order = asc(ChatMessage.id) # type: ignore + time_order: UnaryExpression = desc(ChatSession.time_created) + message_order: UnaryExpression = asc(ChatMessage.id) filters: list[ColumnElement | BinaryExpression] = [ ChatSession.time_created.between(start, end) ] - if initial_id: - filters.append(ChatSession.id < initial_id) + if initial_time: + filters.append(ChatSession.time_created > initial_time) + subquery = ( db_session.query(ChatSession.id, ChatSession.time_created) .filter(*filters) - .order_by(id_order, time_order) + .order_by(ChatSession.id, time_order) .distinct(ChatSession.id) .limit(limit) .subquery() @@ -43,7 +44,7 @@ def fetch_chat_sessions_eagerly_by_time( query = ( db_session.query(ChatSession) - .join(subquery, ChatSession.id == subquery.c.id) # type: ignore + .join(subquery, ChatSession.id == subquery.c.id) .outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id) .options( joinedload(ChatSession.user), diff --git a/backend/ee/danswer/db/usage_export.py b/backend/ee/danswer/db/usage_export.py index 6642def13..074e1ae7d 100644 --- a/backend/ee/danswer/db/usage_export.py +++ b/backend/ee/danswer/db/usage_export.py @@ -2,6 +2,7 @@ import uuid from collections.abc import Generator from datetime import datetime from typing import IO +from typing import Optional from fastapi_users_db_sqlalchemy import UUID_ID from sqlalchemy.orm import Session @@ -19,11 +20,15 @@ from ee.danswer.server.reporting.usage_export_models import UsageReportMetadata def get_empty_chat_messages_entries__paginated( db_session: Session, period: tuple[datetime, datetime], - limit: int | None = 1, - initial_id: int | None = None, -) -> list[ChatMessageSkeleton]: + limit: int | None = 500, + initial_time: datetime | None = None, +) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]: chat_sessions = fetch_chat_sessions_eagerly_by_time( - period[0], period[1], db_session, limit=limit, initial_id=initial_id + start=period[0], + end=period[1], + db_session=db_session, + limit=limit, + initial_time=initial_time, ) message_skeletons: list[ChatMessageSkeleton] = [] @@ -36,7 +41,7 @@ def get_empty_chat_messages_entries__paginated( flow_type = FlowType.CHAT for message in chat_session.messages: - # only count user messages + # Only count user messages if message.message_type != MessageType.USER: continue @@ -49,24 +54,34 @@ def get_empty_chat_messages_entries__paginated( time_sent=message.time_sent, ) ) + if len(chat_sessions) == 0: + return None, [] - return message_skeletons + return chat_sessions[0].time_created, message_skeletons def get_all_empty_chat_message_entries( db_session: Session, period: tuple[datetime, datetime], ) -> Generator[list[ChatMessageSkeleton], None, None]: - initial_id = None + initial_time: Optional[datetime] = period[0] + ind = 0 while True: - message_skeletons = get_empty_chat_messages_entries__paginated( - db_session, period, initial_id=initial_id + ind += 1 + + time_created, message_skeletons = get_empty_chat_messages_entries__paginated( + db_session, + period, + initial_time=initial_time, ) + if not message_skeletons: return yield message_skeletons - initial_id = message_skeletons[-1].chat_session_id + + # Update initial_time for the next iteration + initial_time = time_created def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]: