mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-27 02:02:18 +01:00
100 lines
2.9 KiB
Python
100 lines
2.9 KiB
Python
from sqlalchemy import select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.orm.attributes import flag_modified
|
|
|
|
from onyx.configs.constants import MilestoneRecordType
|
|
from onyx.db.models import Milestone
|
|
from onyx.db.models import User
|
|
|
|
|
|
# Prefix of the per-user keys stored in Milestone.event_tracker; the user id is
# appended to it, and the value under that key is a list of assistant ids.
USER_ASSISTANT_PREFIX = "user_assistants_used_"
|
|
# event_tracker key that is set to True once any per-user list holds more than
# one assistant id; once truthy, the tracking helpers below return early.
MULTI_ASSISTANT_USED = "multi_assistant_used"
|
|
|
|
|
|
def create_milestone(
    user: User | None,
    event_type: MilestoneRecordType,
    db_session: Session,
) -> Milestone:
    """Insert a milestone row for `event_type` and commit it.

    Args:
        user: User the milestone is attributed to. When falsy (e.g. None) the
            row is created with `user_id=None`.
        event_type: The kind of milestone being recorded.
        db_session: Session used to add the row and commit.

    Returns:
        The Milestone instance that was added to the session.

    Note:
        Errors raised by the commit are not handled here; they propagate to
        the caller (create_milestone_if_not_exists catches IntegrityError).
    """
    owner_id = None
    if user:
        owner_id = user.id

    new_record = Milestone(event_type=event_type, user_id=owner_id)
    db_session.add(new_record)
    db_session.commit()
    return new_record
|
|
|
|
|
|
def create_milestone_if_not_exists(
    user: User | None, event_type: MilestoneRecordType, db_session: Session
) -> tuple[Milestone, bool]:
    """Fetch the milestone for `event_type`, creating it when it is missing.

    Args:
        user: User passed through to create_milestone if a row must be created.
        event_type: Milestone type to look up (rows are matched on this only).
        db_session: Session used for the lookup, the insert and any rollback.

    Returns:
        A `(milestone, created)` pair. `created` is True only when this call
        inserted the row; it is False when the row was already present or was
        inserted concurrently by someone else.
    """
    # The same statement serves both the initial lookup and the post-conflict
    # re-read, so build it once.
    by_event_type = select(Milestone).where(Milestone.event_type == event_type)

    found = db_session.execute(by_event_type).scalar_one_or_none()
    if found is not None:
        return found, False

    # Nothing there yet: attempt the insert.
    try:
        inserted = create_milestone(user, event_type, db_session)
    except IntegrityError:
        # Another thread or process inserted it in the meantime. Roll back the
        # failed transaction, then read the row that won the race; it should
        # exist now, hence scalar_one().
        db_session.rollback()
        winner = db_session.execute(by_event_type).scalar_one()
        return winner, False
    else:
        return inserted, True
|
|
|
|
|
|
def update_user_assistant_milestone(
    milestone: Milestone,
    user_id: str | None,
    assistant_id: int,
    db_session: Session,
) -> None:
    """Record that `user_id` used `assistant_id` in the milestone's tracker.

    The tracker (`milestone.event_tracker`) maps
    `USER_ASSISTANT_PREFIX + user_id` to the list of distinct assistant ids
    that user has used. Nothing is recorded or committed once the
    MULTI_ASSISTANT_USED flag in the tracker is truthy.

    Args:
        milestone: Milestone whose `event_tracker` is updated in place.
        user_id: Id interpolated into the tracker key (None is interpolated
            as the text "None").
        assistant_id: Assistant id to add to that user's list.
        db_session: Session used to commit the change.
    """
    tracker = milestone.event_tracker
    if tracker is None:
        # First use: attach a fresh dict and keep working on that same object.
        tracker = {}
        milestone.event_tracker = tracker

    if tracker.get(MULTI_ASSISTANT_USED):
        # No need to keep tracking and populating if the milestone has already been hit
        return

    tracker_key = f"{USER_ASSISTANT_PREFIX}{user_id}"
    seen_assistants = tracker.get(tracker_key)

    if seen_assistants is None:
        tracker[tracker_key] = [assistant_id]
    elif assistant_id not in seen_assistants:
        seen_assistants.append(assistant_id)

    # The dict is mutated in place, so explicitly mark the attribute as changed
    # before committing.
    flag_modified(milestone, "event_tracker")
    db_session.commit()
|
|
|
|
|
|
def check_multi_assistant_milestone(
    milestone: Milestone,
    db_session: Session,
) -> tuple[bool, bool]:
    """Report whether the multi-assistant milestone is hit, and whether just now.

    Args:
        milestone: Milestone whose `event_tracker` is inspected (and updated
            when the milestone is hit for the first time).
        db_session: Session used to commit when the flag is first set.

    Returns:
        `(hit, just_hit)`:
        - `(False, False)` when there is no tracker or no user has more than
          one assistant recorded;
        - `(True, False)` when the MULTI_ASSISTANT_USED flag was already set;
        - `(True, True)` when this call set the flag and committed it.
    """
    tracker = milestone.event_tracker
    if tracker is None:
        return False, False

    if tracker.get(MULTI_ASSISTANT_USED):
        return True, False

    # The milestone is hit as soon as any per-user entry lists more than one
    # assistant.
    some_user_has_several = any(
        name.startswith(USER_ASSISTANT_PREFIX) and len(assistants) > 1
        for name, assistants in tracker.items()
    )
    if not some_user_has_several:
        return False, False

    tracker[MULTI_ASSISTANT_USED] = True
    # In-place dict mutation: mark the attribute as changed before committing.
    flag_modified(milestone, "event_tracker")
    db_session.commit()
    return True, True
|