danswer/backend/onyx/db/chat_search.py
pablonyx abb74f2eaa
Improved chat search (#4137)
* functional + fast

* k

* adapt

* k

* nit

* k

* k

* fix typing

* k
2025-02-27 02:27:45 +00:00

112 lines
3.4 KiB
Python

from typing import List
from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import column
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
def _session_filter_conditions(
    user_id: UUID | None,
    include_deleted: bool,
    include_onyxbot_flows: bool,
) -> list:
    """Build the ChatSession WHERE conditions shared by both search paths.

    Returns a (possibly empty) list of boolean SQL expressions that:
    - restrict results to ``user_id`` when one is given;
    - exclude Onyxbot-flow sessions unless ``include_onyxbot_flows``;
    - exclude deleted sessions unless ``include_deleted``.
    """
    conditions = []
    if user_id is not None:
        conditions.append(ChatSession.user_id == user_id)
    if not include_onyxbot_flows:
        conditions.append(ChatSession.onyxbot_flow.is_(False))
    if not include_deleted:
        conditions.append(ChatSession.deleted.is_(False))
    return conditions


def _split_page(
    rows: List[ChatSession], page_size: int
) -> Tuple[List[ChatSession], bool]:
    """Trim an over-fetched result down to a single page.

    Callers fetch ``page_size + 1`` rows. The presence of that extra row is
    how ``has_more`` is detected without a separate COUNT query.
    """
    has_more = len(rows) > page_size
    return rows[:page_size], has_more


def search_chat_sessions(
    user_id: UUID | None,
    db_session: Session,
    query: Optional[str] = None,
    page: int = 1,
    page_size: int = 10,
    include_deleted: bool = False,
    include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
    """
    Page through chat sessions, optionally filtered by a full-text query.

    If ``query`` is None, empty, or whitespace-only, this returns the most
    recent chat sessions. Otherwise it returns sessions whose
    ``description_tsv`` matches the query, plus sessions containing at least
    one ChatMessage whose ``message_tsv`` matches. Matching uses PostgreSQL
    ``plainto_tsquery('english', ...)``.

    In both modes:
    - results are ordered newest first by ``time_created``, with ``id`` as a
      tie-breaker so that pages are stable;
    - ``ChatSession.persona`` is eagerly loaded.

    Args:
        user_id: If not None, only sessions owned by this user are returned.
        db_session: SQLAlchemy session used to run the queries.
        query: Free-text search string.
        page: 1-based page number.
        page_size: Maximum number of sessions per page.
        include_deleted: Include sessions flagged ``deleted``.
        include_onyxbot_flows: Include sessions flagged ``onyxbot_flow``.

    Returns:
        A tuple ``(sessions, has_more)``. ``has_more`` is True when at least
        one more session exists beyond the requested page.

    Raises:
        ValueError: If ``page`` or ``page_size`` is less than 1.
    """
    # A negative OFFSET/LIMIT is rejected by the database, and page_size == 0
    # would report has_more=True on every (empty) page. Fail early instead.
    if page < 1:
        raise ValueError(f"page must be >= 1, got {page}")
    if page_size < 1:
        raise ValueError(f"page_size must be >= 1, got {page_size}")

    offset_val = (page - 1) * page_size
    base_conditions = _session_filter_conditions(
        user_id, include_deleted, include_onyxbot_flows
    )
    # Newest first. The id breaks ties between identical timestamps so that
    # OFFSET-based pages are deterministic (no rows skipped or repeated).
    ordering = (desc(ChatSession.time_created), desc(ChatSession.id))

    search_text = (query or "").strip()
    if not search_text:
        # No query: just list the most recent sessions.
        recent_stmt = (
            select(ChatSession)
            .where(*base_conditions)
            .order_by(*ordering)
            .offset(offset_val)
            .limit(page_size + 1)
            .options(joinedload(ChatSession.persona))
        )
        sessions = list(db_session.execute(recent_stmt).scalars().all())
        return _split_page(sessions, page_size)

    # The tsvector columns are referenced by name instead of via model
    # attributes. Presumably they exist only in the database schema (e.g.
    # created by a migration) and are not mapped on the ORM models.
    # TODO confirm.
    message_tsv: ColumnClause = column("message_tsv")
    description_tsv: ColumnClause = column("description_tsv")
    # plainto_tsquery does not interpret tsquery operators in its input, so
    # raw user text is safe to pass (as a bound parameter).
    ts_query = func.plainto_tsquery("english", search_text)

    description_session_ids = (
        select(ChatSession.id)
        .where(*base_conditions)
        .where(description_tsv.op("@@")(ts_query))
    )
    message_session_ids = (
        select(ChatMessage.chat_session_id)
        .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
        .where(*base_conditions)
        .where(message_tsv.op("@@")(ts_query))
    )
    # UNION (not UNION ALL) already de-duplicates the ids. The join below is
    # on the primary key, so no DISTINCT is needed on the outer query. The
    # union's column takes its name ("id") from the first select.
    combined_ids = description_session_ids.union(message_session_ids).subquery(
        "combined_ids"
    )

    final_stmt = (
        select(ChatSession)
        .join(combined_ids, ChatSession.id == combined_ids.c.id)
        .order_by(*ordering)
        .offset(offset_val)
        .limit(page_size + 1)
        .options(joinedload(ChatSession.persona))
    )
    session_objs = list(db_session.execute(final_stmt).scalars().all())
    return _split_page(session_objs, page_size)