From abb74f2eaa433ff3267fded868166afd090e1230 Mon Sep 17 00:00:00 2001 From: pablonyx Date: Wed, 26 Feb 2025 18:27:45 -0800 Subject: [PATCH] Improved chat search (#4137) * functional + fast * k * adapt * k * nit * k * k * fix typing * k --- .../versions/3bd4c84fe72f_improved_index.py | 84 ++++++++++ backend/onyx/db/chat_search.py | 157 +++++++----------- backend/onyx/db/models.py | 1 + .../app/chat/chat_search/ChatSearchGroup.tsx | 4 +- .../app/chat/chat_search/ChatSearchItem.tsx | 11 +- 5 files changed, 151 insertions(+), 106 deletions(-) create mode 100644 backend/alembic/versions/3bd4c84fe72f_improved_index.py diff --git a/backend/alembic/versions/3bd4c84fe72f_improved_index.py b/backend/alembic/versions/3bd4c84fe72f_improved_index.py new file mode 100644 index 000000000..ab9497619 --- /dev/null +++ b/backend/alembic/versions/3bd4c84fe72f_improved_index.py @@ -0,0 +1,84 @@ +"""improved index + +Revision ID: 3bd4c84fe72f +Revises: 8f43500ee275 +Create Date: 2025-02-26 13:07:56.217791 + +""" +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "3bd4c84fe72f" +down_revision = "8f43500ee275" +branch_labels = None +depends_on = None + + +# NOTE: +# This migration addresses issues with the previous migration (8f43500ee275) which caused +# an outage by creating an index without using CONCURRENTLY. This migration: +# +# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes +# 2. Uses CONCURRENTLY for all index creation to prevent table locking +# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work +# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY) +# (see: https://github.com/sqlalchemy/alembic/issues/277) +# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search + + +def upgrade() -> None: + # Create a GIN index for full-text search on chat_message.message + op.execute( + """ + ALTER TABLE chat_message + ADD COLUMN message_tsv tsvector + GENERATED ALWAYS AS (to_tsvector('english', message)) STORED; + """ + ) + + # Commit the current transaction before creating concurrent indexes + op.execute("COMMIT") + + op.execute( + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv + ON chat_message + USING GIN (message_tsv) + """ + ) + + # Also add a stored tsvector column for chat_session.description + op.execute( + """ + ALTER TABLE chat_session + ADD COLUMN description_tsv tsvector + GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED; + """ + ) + + # Commit again before creating the second concurrent index + op.execute("COMMIT") + + op.execute( + """ + CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv + ON chat_session + USING GIN (description_tsv) + """ + ) + + +def downgrade() -> None: + # Drop the indexes first (use CONCURRENTLY for dropping too) + op.execute("COMMIT") + op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;") + + op.execute("COMMIT") + op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;") + + # Then drop the columns + op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;") + op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;") + + op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;") diff --git a/backend/onyx/db/chat_search.py b/backend/onyx/db/chat_search.py index fd9a69b22..8fb511680 100644 --- a/backend/onyx/db/chat_search.py +++ b/backend/onyx/db/chat_search.py @@ -3,14 +3,13 @@ from typing import Optional from typing import Tuple from uuid import UUID +from sqlalchemy import column from sqlalchemy import desc from sqlalchemy import func -from sqlalchemy import literal -from sqlalchemy import Select from sqlalchemy import select -from sqlalchemy import union_all from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import ColumnClause from onyx.db.models import ChatMessage from onyx.db.models import ChatSession @@ -26,127 +25,87 @@ def search_chat_sessions( include_onyxbot_flows: bool = False, ) -> Tuple[List[ChatSession], bool]: """ - Search for chat sessions based on the provided query. - If no query is provided, returns recent chat sessions. + Fast full-text search on ChatSession + ChatMessage using tsvectors. - Returns a tuple of (chat_sessions, has_more) + If no query is provided, returns the most recent chat sessions. + Otherwise, searches both chat messages and session descriptions. + + Returns a tuple of (sessions, has_more) where has_more indicates if + there are additional results beyond the requested page. """ - offset = (page - 1) * page_size + offset_val = (page - 1) * page_size - # If no search query, we use standard SQLAlchemy pagination + # If no query, just return the most recent sessions if not query or not query.strip(): - stmt = select(ChatSession) - if user_id: + stmt = ( + select(ChatSession) + .order_by(desc(ChatSession.time_created)) + .offset(offset_val) + .limit(page_size + 1) + ) + if user_id is not None: stmt = stmt.where(ChatSession.user_id == user_id) if not include_onyxbot_flows: stmt = stmt.where(ChatSession.onyxbot_flow.is_(False)) if not include_deleted: stmt = stmt.where(ChatSession.deleted.is_(False)) - stmt = stmt.order_by(desc(ChatSession.time_created)) - - # Apply pagination - stmt = stmt.offset(offset).limit(page_size + 1) result = db_session.execute(stmt.options(joinedload(ChatSession.persona))) - chat_sessions = result.scalars().all() + sessions = result.scalars().all() - has_more = len(chat_sessions) > page_size + has_more = len(sessions) > page_size if has_more: - chat_sessions = chat_sessions[:page_size] + sessions = sessions[:page_size] - return list(chat_sessions), has_more + return list(sessions), has_more - words = query.lower().strip().split() + # Otherwise, proceed with full-text search + query = query.strip() - # Message mach subquery - message_matches = [] - for word in words: - word_like = f"%{word}%" - message_match: Select = ( - select(ChatMessage.chat_session_id, literal(1.0).label("search_rank")) - .join(ChatSession, ChatSession.id == ChatMessage.chat_session_id) - .where(func.lower(ChatMessage.message).like(word_like)) - ) - - if user_id: - message_match = message_match.where(ChatSession.user_id == user_id) - - message_matches.append(message_match) - - if message_matches: - message_matches_query = union_all(*message_matches).alias("message_matches") - else: - return [], False - - # Description matches - description_match: Select = select( - ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank") - ).where(func.lower(ChatSession.description).like(f"%{query.lower()}%")) - - if user_id: - description_match = description_match.where(ChatSession.user_id == user_id) + base_conditions = [] + if user_id is not None: + base_conditions.append(ChatSession.user_id == user_id) if not include_onyxbot_flows: - description_match = description_match.where(ChatSession.onyxbot_flow.is_(False)) + base_conditions.append(ChatSession.onyxbot_flow.is_(False)) if not include_deleted: - description_match = description_match.where(ChatSession.deleted.is_(False)) + base_conditions.append(ChatSession.deleted.is_(False)) - # Combine all match sources - combined_matches = union_all( - message_matches_query.select(), description_match - ).alias("combined_matches") + message_tsv: ColumnClause = column("message_tsv") + description_tsv: ColumnClause = column("description_tsv") - # Use CTE to group and get max rank - session_ranks = ( - select( - combined_matches.c.chat_session_id, - func.max(combined_matches.c.search_rank).label("rank"), - ) - .group_by(combined_matches.c.chat_session_id) - .alias("session_ranks") + ts_query = func.plainto_tsquery("english", query) + + description_session_ids = ( + select(ChatSession.id) + .where(*base_conditions) + .where(description_tsv.op("@@")(ts_query)) ) - # Get ranked sessions with pagination - ranked_query = ( - db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank) - .order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id) - .offset(offset) + message_session_ids = ( + select(ChatMessage.chat_session_id) + .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) + .where(*base_conditions) + .where(message_tsv.op("@@")(ts_query)) + ) + + combined_ids = description_session_ids.union(message_session_ids).alias( + "combined_ids" + ) + + final_stmt = ( + select(ChatSession) + .join(combined_ids, ChatSession.id == combined_ids.c.id) + .order_by(desc(ChatSession.time_created)) + .distinct() + .offset(offset_val) .limit(page_size + 1) + .options(joinedload(ChatSession.persona)) ) - result = ranked_query.all() + session_objs = db_session.execute(final_stmt).scalars().all() - # Extract session IDs and ranks - session_ids_with_ranks = {row.chat_session_id: row.rank for row in result} - session_ids = list(session_ids_with_ranks.keys()) - - if not session_ids: - return [], False - - # Now, let's query the actual ChatSession objects using the IDs - stmt = select(ChatSession).where(ChatSession.id.in_(session_ids)) - - if user_id: - stmt = stmt.where(ChatSession.user_id == user_id) - if not include_onyxbot_flows: - stmt = stmt.where(ChatSession.onyxbot_flow.is_(False)) - if not include_deleted: - stmt = stmt.where(ChatSession.deleted.is_(False)) - - # Full objects with eager loading - result = db_session.execute(stmt.options(joinedload(ChatSession.persona))) - chat_sessions = result.scalars().all() - - # Sort based on above ranking - chat_sessions = sorted( - chat_sessions, - key=lambda session: ( - -session_ids_with_ranks.get(session.id, 0), # Rank (higher first) - session.time_created.timestamp() * -1, # Then by time (newest first) - ), - ) - - has_more = len(chat_sessions) > page_size + has_more = len(session_objs) > page_size if has_more: - chat_sessions = chat_sessions[:page_size] + session_objs = session_objs[:page_size] - return chat_sessions, has_more + return list(session_objs), has_more diff --git a/backend/onyx/db/models.py b/backend/onyx/db/models.py index 132b2d63f..0001ec318 100644 --- a/backend/onyx/db/models.py +++ b/backend/onyx/db/models.py @@ -25,6 +25,7 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Index from sqlalchemy import Integer + from sqlalchemy import Sequence from sqlalchemy import String from sqlalchemy import Text diff --git a/web/src/app/chat/chat_search/ChatSearchGroup.tsx b/web/src/app/chat/chat_search/ChatSearchGroup.tsx index 6d3dee74c..93be2fa93 100644 --- a/web/src/app/chat/chat_search/ChatSearchGroup.tsx +++ b/web/src/app/chat/chat_search/ChatSearchGroup.tsx @@ -15,8 +15,8 @@ export function ChatSearchGroup({ }: ChatSearchGroupProps) { return (
-
-
+
+
{title}
diff --git a/web/src/app/chat/chat_search/ChatSearchItem.tsx b/web/src/app/chat/chat_search/ChatSearchItem.tsx index 3b0320181..6b72ea305 100644 --- a/web/src/app/chat/chat_search/ChatSearchItem.tsx +++ b/web/src/app/chat/chat_search/ChatSearchItem.tsx @@ -1,6 +1,7 @@ import React from "react"; import { MessageSquare } from "lucide-react"; import { ChatSessionSummary } from "./interfaces"; +import { truncateString } from "@/lib/utils"; interface ChatSearchItemProps { chat: ChatSessionSummary; @@ -11,12 +12,12 @@ export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) { return (
  • onSelect(chat.id)}> -
    -
    +
    +
    -
    -
    - {chat.name || "Untitled Chat"} +
    +
    + {truncateString(chat.name || "Untitled Chat", 90)}