mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-03-26 17:51:54 +01:00
112 lines
3.4 KiB
Python
112 lines
3.4 KiB
Python
from typing import List
|
|
from typing import Optional
|
|
from typing import Tuple
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import column
|
|
from sqlalchemy import desc
|
|
from sqlalchemy import func
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import joinedload
|
|
from sqlalchemy.orm import Session
|
|
from sqlalchemy.sql.expression import ColumnClause
|
|
|
|
from onyx.db.models import ChatMessage
|
|
from onyx.db.models import ChatSession
|
|
|
|
|
|
def _session_filters(
    user_id: UUID | None,
    include_deleted: bool,
    include_onyxbot_flows: bool,
) -> list:
    """Build the WHERE conditions on ChatSession shared by both search paths."""
    conditions = []
    if user_id is not None:
        conditions.append(ChatSession.user_id == user_id)
    if not include_onyxbot_flows:
        conditions.append(ChatSession.onyxbot_flow.is_(False))
    if not include_deleted:
        conditions.append(ChatSession.deleted.is_(False))
    return conditions


def _matching_session_ids(search_text: str, conditions: list):
    """Return an aliased subquery of ids of sessions whose description or any
    of whose messages match `search_text`, restricted by `conditions`."""
    ts_query = func.plainto_tsquery("english", search_text)

    # The tsvector columns are referenced by bare name rather than through the
    # ORM models. NOTE(review): presumably they live on chat_session /
    # chat_message and are built with the same "english" config — confirm
    # against the migrations.
    message_tsv: ColumnClause = column("message_tsv")
    description_tsv: ColumnClause = column("description_tsv")

    description_hits = (
        select(ChatSession.id)
        .where(*conditions)
        .where(description_tsv.op("@@")(ts_query))
    )
    message_hits = (
        select(ChatMessage.chat_session_id)
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .where(*conditions)
        .where(message_tsv.op("@@")(ts_query))
    )

    # UNION (not UNION ALL) de-duplicates, so each session id appears at most
    # once and the outer query needs no DISTINCT. The resulting column is
    # named after the first SELECT's column, i.e. `id`.
    return description_hits.union(message_hits).alias("combined_ids")


def _split_page(
    rows: List[ChatSession], page_size: int
) -> Tuple[List[ChatSession], bool]:
    """Trim rows fetched with LIMIT page_size + 1 to one page and report overflow."""
    # The query deliberately over-fetches by one row: if that extra row came
    # back, at least one more page exists.
    has_more = len(rows) > page_size
    return rows[:page_size], has_more


def search_chat_sessions(
    user_id: UUID | None,
    db_session: Session,
    query: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    include_deleted: bool = False,
    include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
    """
    Fast full-text search on ChatSession + ChatMessage using tsvectors.

    If no query is provided (None, empty, or whitespace only), returns the most
    recent chat sessions. Otherwise, returns sessions whose description or any
    of whose messages match the query via Postgres `plainto_tsquery` with the
    "english" configuration.

    Results are ordered newest first by `time_created` (ties broken by id, so
    page boundaries are stable) and have `ChatSession.persona` eagerly loaded.

    Args:
        user_id: If not None, only sessions owned by this user are returned.
        db_session: SQLAlchemy session used to run the query.
        query: Free-text search string.
        page: 1-based page number.
        page_size: Maximum number of sessions per page.
        include_deleted: Include sessions flagged as deleted.
        include_onyxbot_flows: Include sessions flagged as onyxbot flows.

    Returns:
        A tuple of (sessions, has_more) where has_more indicates if there are
        additional results beyond the requested page.

    Raises:
        ValueError: If `page` or `page_size` is less than 1.
    """
    # Fail fast instead of letting Postgres reject a negative OFFSET/LIMIT.
    if page < 1:
        raise ValueError(f"page must be >= 1, got {page}")
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")

    conditions = _session_filters(user_id, include_deleted, include_onyxbot_flows)
    search_text = query.strip() if query else ""

    if search_text:
        # The filters are applied inside the subquery, so the outer join alone
        # restricts the result to matching, permitted sessions.
        matching_ids = _matching_session_ids(search_text, conditions)
        stmt = select(ChatSession).join(
            matching_ids, ChatSession.id == matching_ids.c.id
        )
    else:
        # No query: just return the most recent sessions.
        stmt = select(ChatSession).where(*conditions)

    stmt = (
        stmt
        # The id tie-breaker makes the ordering total; without it OFFSET
        # pagination can repeat or skip sessions sharing a timestamp.
        .order_by(desc(ChatSession.time_created), desc(ChatSession.id))
        .offset((page - 1) * page_size)
        .limit(page_size + 1)
        .options(joinedload(ChatSession.persona))
    )

    rows = list(db_session.execute(stmt).scalars().all())
    return _split_page(rows, page_size)
|