Improved chat search ()

* functional + fast

* k

* adapt

* k

* nit

* k

* k

* fix typing

* k
This commit is contained in:
pablonyx 2025-02-26 18:27:45 -08:00 committed by GitHub
parent a3e3d83b7e
commit abb74f2eaa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 151 additions and 106 deletions

@ -0,0 +1,84 @@
"""improved index
Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None
# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)
def downgrade() -> None:
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

@ -3,14 +3,13 @@ from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import column
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import literal
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import union_all
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@ -26,127 +25,87 @@ def search_chat_sessions(
include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
"""
Search for chat sessions based on the provided query.
If no query is provided, returns recent chat sessions.
Fast full-text search on ChatSession + ChatMessage using tsvectors.
Returns a tuple of (chat_sessions, has_more)
If no query is provided, returns the most recent chat sessions.
Otherwise, searches both chat messages and session descriptions.
Returns a tuple of (sessions, has_more) where has_more indicates if
there are additional results beyond the requested page.
"""
offset = (page - 1) * page_size
offset_val = (page - 1) * page_size
# If no search query, we use standard SQLAlchemy pagination
# If no query, just return the most recent sessions
if not query or not query.strip():
stmt = select(ChatSession)
if user_id:
stmt = (
select(ChatSession)
.order_by(desc(ChatSession.time_created))
.offset(offset_val)
.limit(page_size + 1)
)
if user_id is not None:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
# Apply pagination
stmt = stmt.offset(offset).limit(page_size + 1)
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
sessions = result.scalars().all()
has_more = len(chat_sessions) > page_size
has_more = len(sessions) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
sessions = sessions[:page_size]
return list(chat_sessions), has_more
return list(sessions), has_more
words = query.lower().strip().split()
# Otherwise, proceed with full-text search
query = query.strip()
# Message mach subquery
message_matches = []
for word in words:
word_like = f"%{word}%"
message_match: Select = (
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
.where(func.lower(ChatMessage.message).like(word_like))
)
if user_id:
message_match = message_match.where(ChatSession.user_id == user_id)
message_matches.append(message_match)
if message_matches:
message_matches_query = union_all(*message_matches).alias("message_matches")
else:
return [], False
# Description matches
description_match: Select = select(
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
if user_id:
description_match = description_match.where(ChatSession.user_id == user_id)
base_conditions = []
if user_id is not None:
base_conditions.append(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
description_match = description_match.where(ChatSession.deleted.is_(False))
base_conditions.append(ChatSession.deleted.is_(False))
# Combine all match sources
combined_matches = union_all(
message_matches_query.select(), description_match
).alias("combined_matches")
message_tsv: ColumnClause = column("message_tsv")
description_tsv: ColumnClause = column("description_tsv")
# Use CTE to group and get max rank
session_ranks = (
select(
combined_matches.c.chat_session_id,
func.max(combined_matches.c.search_rank).label("rank"),
)
.group_by(combined_matches.c.chat_session_id)
.alias("session_ranks")
ts_query = func.plainto_tsquery("english", query)
description_session_ids = (
select(ChatSession.id)
.where(*base_conditions)
.where(description_tsv.op("@@")(ts_query))
)
# Get ranked sessions with pagination
ranked_query = (
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
.offset(offset)
message_session_ids = (
select(ChatMessage.chat_session_id)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.where(*base_conditions)
.where(message_tsv.op("@@")(ts_query))
)
combined_ids = description_session_ids.union(message_session_ids).alias(
"combined_ids"
)
final_stmt = (
select(ChatSession)
.join(combined_ids, ChatSession.id == combined_ids.c.id)
.order_by(desc(ChatSession.time_created))
.distinct()
.offset(offset_val)
.limit(page_size + 1)
.options(joinedload(ChatSession.persona))
)
result = ranked_query.all()
session_objs = db_session.execute(final_stmt).scalars().all()
# Extract session IDs and ranks
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
session_ids = list(session_ids_with_ranks.keys())
if not session_ids:
return [], False
# Now, let's query the actual ChatSession objects using the IDs
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
if user_id:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
# Full objects with eager loading
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
# Sort based on above ranking
chat_sessions = sorted(
chat_sessions,
key=lambda session: (
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
session.time_created.timestamp() * -1, # Then by time (newest first)
),
)
has_more = len(chat_sessions) > page_size
has_more = len(session_objs) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
session_objs = session_objs[:page_size]
return chat_sessions, has_more
return list(session_objs), has_more

@ -25,6 +25,7 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Text

@ -15,8 +15,8 @@ export function ChatSearchGroup({
}: ChatSearchGroupProps) {
return (
<div className="mb-4">
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-gray-800/90 py-2 px-4 px-4">
<div className="text-xs font-medium leading-4 text-gray-600 dark:text-gray-400">
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-neutral-800/90 py-2 px-4 px-4">
<div className="text-xs font-medium leading-4 text-neutral-600 dark:text-neutral-400">
{title}
</div>
</div>

@ -1,6 +1,7 @@
import React from "react";
import { MessageSquare } from "lucide-react";
import { ChatSessionSummary } from "./interfaces";
import { truncateString } from "@/lib/utils";
interface ChatSearchItemProps {
chat: ChatSessionSummary;
@ -11,12 +12,12 @@ export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) {
return (
<li>
<div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-800">
<div className="flex items-center">
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-700">
<div className="flex max-w-full mx-2 items-center">
<MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
<div className="relative grow overflow-hidden whitespace-nowrap pl-4">
<div className="text-sm dark:text-neutral-200">
{chat.name || "Untitled Chat"}
<div className="relative max-w-full grow overflow-hidden whitespace-nowrap pl-4">
<div className="text-sm max-w-full dark:text-neutral-200">
{truncateString(chat.name || "Untitled Chat", 90)}
</div>
</div>
<div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">