diff --git a/backend/ee/onyx/db/query_history.py b/backend/ee/onyx/db/query_history.py index 5b961fcb7..bddef18a4 100644 --- a/backend/ee/onyx/db/query_history.py +++ b/backend/ee/onyx/db/query_history.py @@ -134,7 +134,9 @@ def fetch_chat_sessions_eagerly_by_time( limit: int | None = 500, initial_time: datetime | None = None, ) -> list[ChatSession]: - time_order: UnaryExpression = desc(ChatSession.time_created) + """Sorted by oldest to newest, then by message id""" + + asc_time_order: UnaryExpression = asc(ChatSession.time_created) message_order: UnaryExpression = asc(ChatMessage.id) filters: list[ColumnElement | BinaryExpression] = [ @@ -147,8 +149,7 @@ def fetch_chat_sessions_eagerly_by_time( subquery = ( db_session.query(ChatSession.id, ChatSession.time_created) .filter(*filters) - .order_by(ChatSession.id, time_order) - .distinct(ChatSession.id) + .order_by(asc_time_order) .limit(limit) .subquery() ) @@ -164,7 +165,7 @@ def fetch_chat_sessions_eagerly_by_time( ChatMessage.chat_message_feedbacks ), ) - .order_by(time_order, message_order) + .order_by(asc_time_order, message_order) ) chat_sessions = query.all() diff --git a/backend/ee/onyx/db/usage_export.py b/backend/ee/onyx/db/usage_export.py index affd61e0a..9cd9756ce 100644 --- a/backend/ee/onyx/db/usage_export.py +++ b/backend/ee/onyx/db/usage_export.py @@ -16,13 +16,18 @@ from onyx.db.models import UsageReport from onyx.file_store.file_store import get_default_file_store -# Gets skeletons of all message +# Gets skeletons of all messages in the given range def get_empty_chat_messages_entries__paginated( db_session: Session, period: tuple[datetime, datetime], limit: int | None = 500, initial_time: datetime | None = None, ) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]: + """Returns a tuple where: + first element is the most recent timestamp out of the sessions iterated + - this timestamp can be used to paginate forward in time + second element is a list of messages belonging to all the sessions iterated + """ chat_sessions = fetch_chat_sessions_eagerly_by_time( start=period[0], end=period[1], @@ -52,18 +57,17 @@ def get_empty_chat_messages_entries__paginated( if len(chat_sessions) == 0: return None, [] - return chat_sessions[0].time_created, message_skeletons + return chat_sessions[-1].time_created, message_skeletons def get_all_empty_chat_message_entries( db_session: Session, period: tuple[datetime, datetime], ) -> Generator[list[ChatMessageSkeleton], None, None]: + """period is the range of time over which to fetch messages.""" initial_time: Optional[datetime] = period[0] - ind = 0 while True: - ind += 1 - + # iterate from oldest to newest time_created, message_skeletons = get_empty_chat_messages_entries__paginated( db_session, period, diff --git a/backend/onyx/connectors/confluence/onyx_confluence.py b/backend/onyx/connectors/confluence/onyx_confluence.py index 147ed82c6..427866a78 100644 --- a/backend/onyx/connectors/confluence/onyx_confluence.py +++ b/backend/onyx/connectors/confluence/onyx_confluence.py @@ -144,6 +144,12 @@ class OnyxConfluence: self.static_credentials = credential_json return credential_json, False + if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID: + raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!") + + if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET: + raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!") + # check if we should refresh tokens. we're deciding to refresh halfway # to expiration now = datetime.now(timezone.utc) diff --git a/backend/onyx/db/seeding/chat_history_seeding.py b/backend/onyx/db/seeding/chat_history_seeding.py new file mode 100644 index 000000000..ee0c558ae --- /dev/null +++ b/backend/onyx/db/seeding/chat_history_seeding.py @@ -0,0 +1,53 @@ +import random +from datetime import datetime +from datetime import timedelta + +from onyx.configs.constants import MessageType +from onyx.db.chat import create_chat_session +from onyx.db.chat import create_new_chat_message +from onyx.db.chat import get_or_create_root_message +from onyx.db.engine import get_session_with_current_tenant +from onyx.db.models import ChatSession + + +def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None: + """Utility function to seed chat history for testing. + + num_sessions: the number of sessions to seed + num_messages: the number of messages to seed per sessions + days: the number of days looking backwards from the current time over which to randomize + the times. + """ + with get_session_with_current_tenant() as db_session: + for y in range(0, num_sessions): + create_chat_session(db_session, f"pytest_session_{y}", None, None) + + # randomize all session times + rows = db_session.query(ChatSession).all() + for row in rows: + row.time_created = datetime.utcnow() - timedelta( + days=random.randint(0, days) + ) + row.time_updated = row.time_created + timedelta( + minutes=random.randint(0, 10) + ) + + root_message = get_or_create_root_message(row.id, db_session) + + for x in range(0, num_messages): + chat_message = create_new_chat_message( + row.id, + root_message, + f"pytest_message_{x}", + None, + 0, + MessageType.USER, + db_session, + ) + + chat_message.time_sent = row.time_created + timedelta( + minutes=random.randint(0, 10) + ) + db_session.commit() + + db_session.commit() diff --git a/backend/scripts/chat_history_seeding.py b/backend/scripts/chat_history_seeding.py new file mode 100644 index 000000000..694ccff86 --- /dev/null +++ b/backend/scripts/chat_history_seeding.py @@ -0,0 +1,45 @@ +import argparse +import logging +from logging import getLogger + +from onyx.db.seeding.chat_history_seeding import seed_chat_history + +# Configure the logger +logging.basicConfig( + level=logging.INFO, # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL) + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", # Log format + handlers=[logging.StreamHandler()], # Output logs to console +) + +logger = getLogger(__name__) + + +def go_main(num_sessions: int, num_messages: int, num_days: int) -> None: + seed_chat_history(num_sessions, num_messages, num_days) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Seed chat history") + parser.add_argument( + "--sessions", + type=int, + default=2048, + help="Number of chat sessions to seed", + ) + + parser.add_argument( + "--messages", + type=int, + default=4, + help="Number of chat messages to seed per session", + ) + + parser.add_argument( + "--days", + type=int, + default=90, + help="Number of days looking backwards over which to seed the timestamps with", + ) + + args = parser.parse_args() + go_main(args.sessions, args.messages, args.days) diff --git a/backend/tests/integration/tests/query_history/test_usage_reports.py b/backend/tests/integration/tests/query_history/test_usage_reports.py new file mode 100644 index 000000000..3fbe70e9c --- /dev/null +++ b/backend/tests/integration/tests/query_history/test_usage_reports.py @@ -0,0 +1,46 @@ +from datetime import datetime +from datetime import timedelta +from datetime import timezone + +from ee.onyx.db.usage_export import get_all_empty_chat_message_entries +from onyx.db.engine import get_session_with_current_tenant +from onyx.db.seeding.chat_history_seeding import seed_chat_history + + +def test_usage_reports(reset: None) -> None: + EXPECTED_SESSIONS = 2048 + MESSAGES_PER_SESSION = 4 + EXPECTED_MESSAGES = EXPECTED_SESSIONS * MESSAGES_PER_SESSION + + seed_chat_history(EXPECTED_SESSIONS, MESSAGES_PER_SESSION, 90) + + with get_session_with_current_tenant() as db_session: + # count of all entries should be exact + period = ( + datetime.fromtimestamp(0, tz=timezone.utc), + datetime.now(tz=timezone.utc), + ) + + count = 0 + for entry_batch in get_all_empty_chat_message_entries(db_session, period): + for entry in entry_batch: + count += 1 + + assert count == EXPECTED_MESSAGES + + # count in a one month time range should be within a certain range statistically + # this can be improved if we seed the chat history data deterministically + period = ( + datetime.now(tz=timezone.utc) - timedelta(days=30), + datetime.now(tz=timezone.utc), + ) + + count = 0 + for entry_batch in get_all_empty_chat_message_entries(db_session, period): + for entry in entry_batch: + count += 1 + + lower = EXPECTED_MESSAGES // 3 - (EXPECTED_MESSAGES // (3 * 3)) + upper = EXPECTED_MESSAGES // 3 + (EXPECTED_MESSAGES // (3 * 3)) + assert count > lower + assert count < upper