mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-28 18:52:31 +01:00
fix usage report pagination (#4183)
* early work in progress * rename utility script * move actual data seeding to a shareable function * add test * make the test pass with the fix * fix comment --------- Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
This commit is contained in:
parent
b6e9e65bb8
commit
a7acc07e79
@ -134,7 +134,9 @@ def fetch_chat_sessions_eagerly_by_time(
|
|||||||
limit: int | None = 500,
|
limit: int | None = 500,
|
||||||
initial_time: datetime | None = None,
|
initial_time: datetime | None = None,
|
||||||
) -> list[ChatSession]:
|
) -> list[ChatSession]:
|
||||||
time_order: UnaryExpression = desc(ChatSession.time_created)
|
"""Sorted by oldest to newest, then by message id"""
|
||||||
|
|
||||||
|
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
|
||||||
message_order: UnaryExpression = asc(ChatMessage.id)
|
message_order: UnaryExpression = asc(ChatMessage.id)
|
||||||
|
|
||||||
filters: list[ColumnElement | BinaryExpression] = [
|
filters: list[ColumnElement | BinaryExpression] = [
|
||||||
@ -147,8 +149,7 @@ def fetch_chat_sessions_eagerly_by_time(
|
|||||||
subquery = (
|
subquery = (
|
||||||
db_session.query(ChatSession.id, ChatSession.time_created)
|
db_session.query(ChatSession.id, ChatSession.time_created)
|
||||||
.filter(*filters)
|
.filter(*filters)
|
||||||
.order_by(ChatSession.id, time_order)
|
.order_by(asc_time_order)
|
||||||
.distinct(ChatSession.id)
|
|
||||||
.limit(limit)
|
.limit(limit)
|
||||||
.subquery()
|
.subquery()
|
||||||
)
|
)
|
||||||
@ -164,7 +165,7 @@ def fetch_chat_sessions_eagerly_by_time(
|
|||||||
ChatMessage.chat_message_feedbacks
|
ChatMessage.chat_message_feedbacks
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
.order_by(time_order, message_order)
|
.order_by(asc_time_order, message_order)
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_sessions = query.all()
|
chat_sessions = query.all()
|
||||||
|
@ -16,13 +16,18 @@ from onyx.db.models import UsageReport
|
|||||||
from onyx.file_store.file_store import get_default_file_store
|
from onyx.file_store.file_store import get_default_file_store
|
||||||
|
|
||||||
|
|
||||||
# Gets skeletons of all message
|
# Gets skeletons of all messages in the given range
|
||||||
def get_empty_chat_messages_entries__paginated(
|
def get_empty_chat_messages_entries__paginated(
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
period: tuple[datetime, datetime],
|
period: tuple[datetime, datetime],
|
||||||
limit: int | None = 500,
|
limit: int | None = 500,
|
||||||
initial_time: datetime | None = None,
|
initial_time: datetime | None = None,
|
||||||
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
|
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
|
||||||
|
"""Returns a tuple where:
|
||||||
|
first element is the most recent timestamp out of the sessions iterated
|
||||||
|
- this timestamp can be used to paginate forward in time
|
||||||
|
second element is a list of messages belonging to all the sessions iterated
|
||||||
|
"""
|
||||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||||
start=period[0],
|
start=period[0],
|
||||||
end=period[1],
|
end=period[1],
|
||||||
@ -52,18 +57,17 @@ def get_empty_chat_messages_entries__paginated(
|
|||||||
if len(chat_sessions) == 0:
|
if len(chat_sessions) == 0:
|
||||||
return None, []
|
return None, []
|
||||||
|
|
||||||
return chat_sessions[0].time_created, message_skeletons
|
return chat_sessions[-1].time_created, message_skeletons
|
||||||
|
|
||||||
|
|
||||||
def get_all_empty_chat_message_entries(
|
def get_all_empty_chat_message_entries(
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
period: tuple[datetime, datetime],
|
period: tuple[datetime, datetime],
|
||||||
) -> Generator[list[ChatMessageSkeleton], None, None]:
|
) -> Generator[list[ChatMessageSkeleton], None, None]:
|
||||||
|
"""period is the range of time over which to fetch messages."""
|
||||||
initial_time: Optional[datetime] = period[0]
|
initial_time: Optional[datetime] = period[0]
|
||||||
ind = 0
|
|
||||||
while True:
|
while True:
|
||||||
ind += 1
|
# iterate from oldest to newest
|
||||||
|
|
||||||
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
|
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
|
||||||
db_session,
|
db_session,
|
||||||
period,
|
period,
|
||||||
|
@ -144,6 +144,12 @@ class OnyxConfluence:
|
|||||||
self.static_credentials = credential_json
|
self.static_credentials = credential_json
|
||||||
return credential_json, False
|
return credential_json, False
|
||||||
|
|
||||||
|
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID:
|
||||||
|
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!")
|
||||||
|
|
||||||
|
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET:
|
||||||
|
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!")
|
||||||
|
|
||||||
# check if we should refresh tokens. we're deciding to refresh halfway
|
# check if we should refresh tokens. we're deciding to refresh halfway
|
||||||
# to expiration
|
# to expiration
|
||||||
now = datetime.now(timezone.utc)
|
now = datetime.now(timezone.utc)
|
||||||
|
53
backend/onyx/db/seeding/chat_history_seeding.py
Normal file
53
backend/onyx/db/seeding/chat_history_seeding.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import random
|
||||||
|
from datetime import datetime
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
from onyx.configs.constants import MessageType
|
||||||
|
from onyx.db.chat import create_chat_session
|
||||||
|
from onyx.db.chat import create_new_chat_message
|
||||||
|
from onyx.db.chat import get_or_create_root_message
|
||||||
|
from onyx.db.engine import get_session_with_current_tenant
|
||||||
|
from onyx.db.models import ChatSession
|
||||||
|
|
||||||
|
|
||||||
|
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
|
||||||
|
"""Utility function to seed chat history for testing.
|
||||||
|
|
||||||
|
num_sessions: the number of sessions to seed
|
||||||
|
num_messages: the number of messages to seed per sessions
|
||||||
|
days: the number of days looking backwards from the current time over which to randomize
|
||||||
|
the times.
|
||||||
|
"""
|
||||||
|
with get_session_with_current_tenant() as db_session:
|
||||||
|
for y in range(0, num_sessions):
|
||||||
|
create_chat_session(db_session, f"pytest_session_{y}", None, None)
|
||||||
|
|
||||||
|
# randomize all session times
|
||||||
|
rows = db_session.query(ChatSession).all()
|
||||||
|
for row in rows:
|
||||||
|
row.time_created = datetime.utcnow() - timedelta(
|
||||||
|
days=random.randint(0, days)
|
||||||
|
)
|
||||||
|
row.time_updated = row.time_created + timedelta(
|
||||||
|
minutes=random.randint(0, 10)
|
||||||
|
)
|
||||||
|
|
||||||
|
root_message = get_or_create_root_message(row.id, db_session)
|
||||||
|
|
||||||
|
for x in range(0, num_messages):
|
||||||
|
chat_message = create_new_chat_message(
|
||||||
|
row.id,
|
||||||
|
root_message,
|
||||||
|
f"pytest_message_{x}",
|
||||||
|
None,
|
||||||
|
0,
|
||||||
|
MessageType.USER,
|
||||||
|
db_session,
|
||||||
|
)
|
||||||
|
|
||||||
|
chat_message.time_sent = row.time_created + timedelta(
|
||||||
|
minutes=random.randint(0, 10)
|
||||||
|
)
|
||||||
|
db_session.commit()
|
||||||
|
|
||||||
|
db_session.commit()
|
45
backend/scripts/chat_history_seeding.py
Normal file
45
backend/scripts/chat_history_seeding.py
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
from onyx.db.seeding.chat_history_seeding import seed_chat_history
|
||||||
|
|
||||||
|
# Configure the logger
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO, # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", # Log format
|
||||||
|
handlers=[logging.StreamHandler()], # Output logs to console
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def go_main(num_sessions: int, num_messages: int, num_days: int) -> None:
|
||||||
|
seed_chat_history(num_sessions, num_messages, num_days)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="Seed chat history")
|
||||||
|
parser.add_argument(
|
||||||
|
"--sessions",
|
||||||
|
type=int,
|
||||||
|
default=2048,
|
||||||
|
help="Number of chat sessions to seed",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--messages",
|
||||||
|
type=int,
|
||||||
|
default=4,
|
||||||
|
help="Number of chat messages to seed per session",
|
||||||
|
)
|
||||||
|
|
||||||
|
parser.add_argument(
|
||||||
|
"--days",
|
||||||
|
type=int,
|
||||||
|
default=90,
|
||||||
|
help="Number of days looking backwards over which to seed the timestamps with",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
go_main(args.sessions, args.messages, args.days)
|
@ -0,0 +1,46 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from datetime import timedelta
|
||||||
|
from datetime import timezone
|
||||||
|
|
||||||
|
from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
|
||||||
|
from onyx.db.engine import get_session_with_current_tenant
|
||||||
|
from onyx.db.seeding.chat_history_seeding import seed_chat_history
|
||||||
|
|
||||||
|
|
||||||
|
def test_usage_reports(reset: None) -> None:
|
||||||
|
EXPECTED_SESSIONS = 2048
|
||||||
|
MESSAGES_PER_SESSION = 4
|
||||||
|
EXPECTED_MESSAGES = EXPECTED_SESSIONS * MESSAGES_PER_SESSION
|
||||||
|
|
||||||
|
seed_chat_history(EXPECTED_SESSIONS, MESSAGES_PER_SESSION, 90)
|
||||||
|
|
||||||
|
with get_session_with_current_tenant() as db_session:
|
||||||
|
# count of all entries should be exact
|
||||||
|
period = (
|
||||||
|
datetime.fromtimestamp(0, tz=timezone.utc),
|
||||||
|
datetime.now(tz=timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for entry_batch in get_all_empty_chat_message_entries(db_session, period):
|
||||||
|
for entry in entry_batch:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
assert count == EXPECTED_MESSAGES
|
||||||
|
|
||||||
|
# count in a one month time range should be within a certain range statistically
|
||||||
|
# this can be improved if we seed the chat history data deterministically
|
||||||
|
period = (
|
||||||
|
datetime.now(tz=timezone.utc) - timedelta(days=30),
|
||||||
|
datetime.now(tz=timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
for entry_batch in get_all_empty_chat_message_entries(db_session, period):
|
||||||
|
for entry in entry_batch:
|
||||||
|
count += 1
|
||||||
|
|
||||||
|
lower = EXPECTED_MESSAGES // 3 - (EXPECTED_MESSAGES // (3 * 3))
|
||||||
|
upper = EXPECTED_MESSAGES // 3 + (EXPECTED_MESSAGES // (3 * 3))
|
||||||
|
assert count > lower
|
||||||
|
assert count < upper
|
Loading…
x
Reference in New Issue
Block a user