Bugfix/usage report (#3075)

* fix pagination

* update side

* fixed query history

* minor update

* minor update

* typing
This commit is contained in:
pablodanswer 2024-11-23 12:11:39 -08:00 committed by GitHub
parent d9b87bbbc2
commit 8ae6b1960b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 34 additions and 18 deletions

View File

@ -8,6 +8,7 @@ from sqlalchemy import desc
from sqlalchemy.orm import contains_eager from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import UnaryExpression
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession from danswer.db.models import ChatSession
@ -20,22 +21,22 @@ def fetch_chat_sessions_eagerly_by_time(
end: datetime.datetime, end: datetime.datetime,
db_session: Session, db_session: Session,
limit: int | None = 500, limit: int | None = 500,
initial_id: int | None = None, initial_time: datetime.datetime | None = None,
) -> list[ChatSession]: ) -> list[ChatSession]:
id_order = desc(ChatSession.id) # type: ignore time_order: UnaryExpression = desc(ChatSession.time_created)
time_order = desc(ChatSession.time_created) # type: ignore message_order: UnaryExpression = asc(ChatMessage.id)
message_order = asc(ChatMessage.id) # type: ignore
filters: list[ColumnElement | BinaryExpression] = [ filters: list[ColumnElement | BinaryExpression] = [
ChatSession.time_created.between(start, end) ChatSession.time_created.between(start, end)
] ]
if initial_id: if initial_time:
filters.append(ChatSession.id < initial_id) filters.append(ChatSession.time_created > initial_time)
subquery = ( subquery = (
db_session.query(ChatSession.id, ChatSession.time_created) db_session.query(ChatSession.id, ChatSession.time_created)
.filter(*filters) .filter(*filters)
.order_by(id_order, time_order) .order_by(ChatSession.id, time_order)
.distinct(ChatSession.id) .distinct(ChatSession.id)
.limit(limit) .limit(limit)
.subquery() .subquery()
@ -43,7 +44,7 @@ def fetch_chat_sessions_eagerly_by_time(
query = ( query = (
db_session.query(ChatSession) db_session.query(ChatSession)
.join(subquery, ChatSession.id == subquery.c.id) # type: ignore .join(subquery, ChatSession.id == subquery.c.id)
.outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id) .outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
.options( .options(
joinedload(ChatSession.user), joinedload(ChatSession.user),

View File

@ -2,6 +2,7 @@ import uuid
from collections.abc import Generator from collections.abc import Generator
from datetime import datetime from datetime import datetime
from typing import IO from typing import IO
from typing import Optional
from fastapi_users_db_sqlalchemy import UUID_ID from fastapi_users_db_sqlalchemy import UUID_ID
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@ -19,11 +20,15 @@ from ee.danswer.server.reporting.usage_export_models import UsageReportMetadata
def get_empty_chat_messages_entries__paginated( def get_empty_chat_messages_entries__paginated(
db_session: Session, db_session: Session,
period: tuple[datetime, datetime], period: tuple[datetime, datetime],
limit: int | None = 1, limit: int | None = 500,
initial_id: int | None = None, initial_time: datetime | None = None,
) -> list[ChatMessageSkeleton]: ) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
chat_sessions = fetch_chat_sessions_eagerly_by_time( chat_sessions = fetch_chat_sessions_eagerly_by_time(
period[0], period[1], db_session, limit=limit, initial_id=initial_id start=period[0],
end=period[1],
db_session=db_session,
limit=limit,
initial_time=initial_time,
) )
message_skeletons: list[ChatMessageSkeleton] = [] message_skeletons: list[ChatMessageSkeleton] = []
@ -36,7 +41,7 @@ def get_empty_chat_messages_entries__paginated(
flow_type = FlowType.CHAT flow_type = FlowType.CHAT
for message in chat_session.messages: for message in chat_session.messages:
# only count user messages # Only count user messages
if message.message_type != MessageType.USER: if message.message_type != MessageType.USER:
continue continue
@ -49,24 +54,34 @@ def get_empty_chat_messages_entries__paginated(
time_sent=message.time_sent, time_sent=message.time_sent,
) )
) )
if len(chat_sessions) == 0:
return None, []
return message_skeletons return chat_sessions[0].time_created, message_skeletons
def get_all_empty_chat_message_entries( def get_all_empty_chat_message_entries(
db_session: Session, db_session: Session,
period: tuple[datetime, datetime], period: tuple[datetime, datetime],
) -> Generator[list[ChatMessageSkeleton], None, None]: ) -> Generator[list[ChatMessageSkeleton], None, None]:
initial_id = None initial_time: Optional[datetime] = period[0]
ind = 0
while True: while True:
message_skeletons = get_empty_chat_messages_entries__paginated( ind += 1
db_session, period, initial_id=initial_id
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
db_session,
period,
initial_time=initial_time,
) )
if not message_skeletons: if not message_skeletons:
return return
yield message_skeletons yield message_skeletons
initial_id = message_skeletons[-1].chat_session_id
# Update initial_time for the next iteration
initial_time = time_created
def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]: def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]: