mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-27 08:21:00 +02:00
Improved chat search (#4137)
* functional + fast * k * adapt * k * nit * k * k * fix typing * k
This commit is contained in:
parent
a3e3d83b7e
commit
abb74f2eaa
84
backend/alembic/versions/3bd4c84fe72f_improved_index.py
Normal file
84
backend/alembic/versions/3bd4c84fe72f_improved_index.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
"""improved index
|
||||||
|
|
||||||
|
Revision ID: 3bd4c84fe72f
|
||||||
|
Revises: 8f43500ee275
|
||||||
|
Create Date: 2025-02-26 13:07:56.217791
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "3bd4c84fe72f"
|
||||||
|
down_revision = "8f43500ee275"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE:
|
||||||
|
# This migration addresses issues with the previous migration (8f43500ee275) which caused
|
||||||
|
# an outage by creating an index without using CONCURRENTLY. This migration:
|
||||||
|
#
|
||||||
|
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
|
||||||
|
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
|
||||||
|
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
|
||||||
|
# (see: https://www.postgresql.org/docs/current/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
|
||||||
|
# (see: https://github.com/sqlalchemy/alembic/issues/277)
|
||||||
|
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add stored tsvector columns and GIN indexes for chat full-text search.

    Adds ``chat_message.message_tsv`` and ``chat_session.description_tsv`` as
    generated, stored tsvector columns, then builds a GIN index on each one
    with ``CREATE INDEX CONCURRENTLY``.

    Explicit ``COMMIT`` statements are issued so the concurrent index builds
    can run, which means this migration is NOT atomic: a failure part-way
    through leaves the earlier steps applied. Every DDL statement therefore
    uses ``IF NOT EXISTS`` so the migration can be re-run after a partial
    failure.
    """
    # Stored, generated tsvector for chat_message.message.
    # IF NOT EXISTS keeps a re-run from failing with "column already exists"
    # when a previous attempt got past this step but failed later.
    op.execute(
        """
        ALTER TABLE chat_message
        ADD COLUMN IF NOT EXISTS message_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
        """
    )

    # Commit the current transaction before creating concurrent indexes
    # (CONCURRENTLY is what the explicit COMMIT is here to enable).
    op.execute("COMMIT")

    # NOTE: per the PostgreSQL docs, an interrupted concurrent build leaves an
    # INVALID index behind; IF NOT EXISTS will skip over it, so such an index
    # must be dropped manually before re-running.
    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
        ON chat_message
        USING GIN (message_tsv)
        """
    )

    # Also add a stored tsvector column for chat_session.description.
    # coalesce() maps a NULL description to '' before building the tsvector.
    op.execute(
        """
        ALTER TABLE chat_session
        ADD COLUMN IF NOT EXISTS description_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
        """
    )

    # Commit again before creating the second concurrent index
    op.execute("COMMIT")

    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
        ON chat_session
        USING GIN (description_tsv)
        """
    )
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Remove the full-text search indexes and tsvector columns.

    The statements below are executed strictly in the order listed: each
    concurrent index drop is preceded by a COMMIT, then the generated columns
    are dropped, and finally ``idx_chat_message_message_lower`` is dropped if
    it exists.
    """
    statements = (
        # Drop the indexes first (use CONCURRENTLY for dropping too).
        "COMMIT",
        "DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;",
        "COMMIT",
        "DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;",
        # Then drop the columns.
        "ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;",
        "ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;",
        # NOTE(review): this index is not created by this migration's
        # upgrade(); presumably a leftover from an earlier revision - confirm.
        "DROP INDEX IF EXISTS idx_chat_message_message_lower;",
    )
    for statement in statements:
        op.execute(statement)
|
@ -3,14 +3,13 @@ from typing import Optional
|
|||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
|
from sqlalchemy import column
|
||||||
from sqlalchemy import desc
|
from sqlalchemy import desc
|
||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy import literal
|
|
||||||
from sqlalchemy import Select
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy import union_all
|
|
||||||
from sqlalchemy.orm import joinedload
|
from sqlalchemy.orm import joinedload
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
from sqlalchemy.sql.expression import ColumnClause
|
||||||
|
|
||||||
from onyx.db.models import ChatMessage
|
from onyx.db.models import ChatMessage
|
||||||
from onyx.db.models import ChatSession
|
from onyx.db.models import ChatSession
|
||||||
@ -26,127 +25,87 @@ def search_chat_sessions(
|
|||||||
include_onyxbot_flows: bool = False,
|
include_onyxbot_flows: bool = False,
|
||||||
) -> Tuple[List[ChatSession], bool]:
|
) -> Tuple[List[ChatSession], bool]:
|
||||||
"""
|
"""
|
||||||
Search for chat sessions based on the provided query.
|
Fast full-text search on ChatSession + ChatMessage using tsvectors.
|
||||||
If no query is provided, returns recent chat sessions.
|
|
||||||
|
|
||||||
Returns a tuple of (chat_sessions, has_more)
|
If no query is provided, returns the most recent chat sessions.
|
||||||
|
Otherwise, searches both chat messages and session descriptions.
|
||||||
|
|
||||||
|
Returns a tuple of (sessions, has_more) where has_more indicates if
|
||||||
|
there are additional results beyond the requested page.
|
||||||
"""
|
"""
|
||||||
offset = (page - 1) * page_size
|
offset_val = (page - 1) * page_size
|
||||||
|
|
||||||
# If no search query, we use standard SQLAlchemy pagination
|
# If no query, just return the most recent sessions
|
||||||
if not query or not query.strip():
|
if not query or not query.strip():
|
||||||
stmt = select(ChatSession)
|
stmt = (
|
||||||
if user_id:
|
select(ChatSession)
|
||||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
.order_by(desc(ChatSession.time_created))
|
||||||
if not include_onyxbot_flows:
|
.offset(offset_val)
|
||||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
|
||||||
if not include_deleted:
|
|
||||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
|
||||||
|
|
||||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
|
||||||
|
|
||||||
# Apply pagination
|
|
||||||
stmt = stmt.offset(offset).limit(page_size + 1)
|
|
||||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
|
||||||
chat_sessions = result.scalars().all()
|
|
||||||
|
|
||||||
has_more = len(chat_sessions) > page_size
|
|
||||||
if has_more:
|
|
||||||
chat_sessions = chat_sessions[:page_size]
|
|
||||||
|
|
||||||
return list(chat_sessions), has_more
|
|
||||||
|
|
||||||
words = query.lower().strip().split()
|
|
||||||
|
|
||||||
# Message match subquery
|
|
||||||
message_matches = []
|
|
||||||
for word in words:
|
|
||||||
word_like = f"%{word}%"
|
|
||||||
message_match: Select = (
|
|
||||||
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
|
|
||||||
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
|
|
||||||
.where(func.lower(ChatMessage.message).like(word_like))
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
message_match = message_match.where(ChatSession.user_id == user_id)
|
|
||||||
|
|
||||||
message_matches.append(message_match)
|
|
||||||
|
|
||||||
if message_matches:
|
|
||||||
message_matches_query = union_all(*message_matches).alias("message_matches")
|
|
||||||
else:
|
|
||||||
return [], False
|
|
||||||
|
|
||||||
# Description matches
|
|
||||||
description_match: Select = select(
|
|
||||||
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
|
|
||||||
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
description_match = description_match.where(ChatSession.user_id == user_id)
|
|
||||||
if not include_onyxbot_flows:
|
|
||||||
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
|
|
||||||
if not include_deleted:
|
|
||||||
description_match = description_match.where(ChatSession.deleted.is_(False))
|
|
||||||
|
|
||||||
# Combine all match sources
|
|
||||||
combined_matches = union_all(
|
|
||||||
message_matches_query.select(), description_match
|
|
||||||
).alias("combined_matches")
|
|
||||||
|
|
||||||
# Use CTE to group and get max rank
|
|
||||||
session_ranks = (
|
|
||||||
select(
|
|
||||||
combined_matches.c.chat_session_id,
|
|
||||||
func.max(combined_matches.c.search_rank).label("rank"),
|
|
||||||
)
|
|
||||||
.group_by(combined_matches.c.chat_session_id)
|
|
||||||
.alias("session_ranks")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Get ranked sessions with pagination
|
|
||||||
ranked_query = (
|
|
||||||
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
|
|
||||||
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
|
|
||||||
.offset(offset)
|
|
||||||
.limit(page_size + 1)
|
.limit(page_size + 1)
|
||||||
)
|
)
|
||||||
|
if user_id is not None:
|
||||||
result = ranked_query.all()
|
|
||||||
|
|
||||||
# Extract session IDs and ranks
|
|
||||||
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
|
|
||||||
session_ids = list(session_ids_with_ranks.keys())
|
|
||||||
|
|
||||||
if not session_ids:
|
|
||||||
return [], False
|
|
||||||
|
|
||||||
# Now, let's query the actual ChatSession objects using the IDs
|
|
||||||
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
|
|
||||||
|
|
||||||
if user_id:
|
|
||||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||||
if not include_onyxbot_flows:
|
if not include_onyxbot_flows:
|
||||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||||
if not include_deleted:
|
if not include_deleted:
|
||||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||||
|
|
||||||
# Full objects with eager loading
|
|
||||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||||
chat_sessions = result.scalars().all()
|
sessions = result.scalars().all()
|
||||||
|
|
||||||
# Sort based on above ranking
|
has_more = len(sessions) > page_size
|
||||||
chat_sessions = sorted(
|
if has_more:
|
||||||
chat_sessions,
|
sessions = sessions[:page_size]
|
||||||
key=lambda session: (
|
|
||||||
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
|
return list(sessions), has_more
|
||||||
session.time_created.timestamp() * -1, # Then by time (newest first)
|
|
||||||
),
|
# Otherwise, proceed with full-text search
|
||||||
|
query = query.strip()
|
||||||
|
|
||||||
|
base_conditions = []
|
||||||
|
if user_id is not None:
|
||||||
|
base_conditions.append(ChatSession.user_id == user_id)
|
||||||
|
if not include_onyxbot_flows:
|
||||||
|
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
|
||||||
|
if not include_deleted:
|
||||||
|
base_conditions.append(ChatSession.deleted.is_(False))
|
||||||
|
|
||||||
|
message_tsv: ColumnClause = column("message_tsv")
|
||||||
|
description_tsv: ColumnClause = column("description_tsv")
|
||||||
|
|
||||||
|
ts_query = func.plainto_tsquery("english", query)
|
||||||
|
|
||||||
|
description_session_ids = (
|
||||||
|
select(ChatSession.id)
|
||||||
|
.where(*base_conditions)
|
||||||
|
.where(description_tsv.op("@@")(ts_query))
|
||||||
)
|
)
|
||||||
|
|
||||||
has_more = len(chat_sessions) > page_size
|
message_session_ids = (
|
||||||
if has_more:
|
select(ChatMessage.chat_session_id)
|
||||||
chat_sessions = chat_sessions[:page_size]
|
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
|
||||||
|
.where(*base_conditions)
|
||||||
|
.where(message_tsv.op("@@")(ts_query))
|
||||||
|
)
|
||||||
|
|
||||||
return chat_sessions, has_more
|
combined_ids = description_session_ids.union(message_session_ids).alias(
|
||||||
|
"combined_ids"
|
||||||
|
)
|
||||||
|
|
||||||
|
final_stmt = (
|
||||||
|
select(ChatSession)
|
||||||
|
.join(combined_ids, ChatSession.id == combined_ids.c.id)
|
||||||
|
.order_by(desc(ChatSession.time_created))
|
||||||
|
.distinct()
|
||||||
|
.offset(offset_val)
|
||||||
|
.limit(page_size + 1)
|
||||||
|
.options(joinedload(ChatSession.persona))
|
||||||
|
)
|
||||||
|
|
||||||
|
session_objs = db_session.execute(final_stmt).scalars().all()
|
||||||
|
|
||||||
|
has_more = len(session_objs) > page_size
|
||||||
|
if has_more:
|
||||||
|
session_objs = session_objs[:page_size]
|
||||||
|
|
||||||
|
return list(session_objs), has_more
|
||||||
|
@ -25,6 +25,7 @@ from sqlalchemy import ForeignKey
|
|||||||
from sqlalchemy import func
|
from sqlalchemy import func
|
||||||
from sqlalchemy import Index
|
from sqlalchemy import Index
|
||||||
from sqlalchemy import Integer
|
from sqlalchemy import Integer
|
||||||
|
|
||||||
from sqlalchemy import Sequence
|
from sqlalchemy import Sequence
|
||||||
from sqlalchemy import String
|
from sqlalchemy import String
|
||||||
from sqlalchemy import Text
|
from sqlalchemy import Text
|
||||||
|
@ -15,8 +15,8 @@ export function ChatSearchGroup({
|
|||||||
}: ChatSearchGroupProps) {
|
}: ChatSearchGroupProps) {
|
||||||
return (
|
return (
|
||||||
<div className="mb-4">
|
<div className="mb-4">
|
||||||
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-gray-800/90 py-2 px-4 px-4">
|
<div className="sticky -top-1 mt-1 z-10 bg-[#fff]/90 dark:bg-neutral-800/90 py-2 px-4 px-4">
|
||||||
<div className="text-xs font-medium leading-4 text-gray-600 dark:text-gray-400">
|
<div className="text-xs font-medium leading-4 text-neutral-600 dark:text-neutral-400">
|
||||||
{title}
|
{title}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
import React from "react";
|
import React from "react";
|
||||||
import { MessageSquare } from "lucide-react";
|
import { MessageSquare } from "lucide-react";
|
||||||
import { ChatSessionSummary } from "./interfaces";
|
import { ChatSessionSummary } from "./interfaces";
|
||||||
|
import { truncateString } from "@/lib/utils";
|
||||||
|
|
||||||
interface ChatSearchItemProps {
|
interface ChatSearchItemProps {
|
||||||
chat: ChatSessionSummary;
|
chat: ChatSessionSummary;
|
||||||
@ -11,12 +12,12 @@ export function ChatSearchItem({ chat, onSelect }: ChatSearchItemProps) {
|
|||||||
return (
|
return (
|
||||||
<li>
|
<li>
|
||||||
<div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
|
<div className="cursor-pointer" onClick={() => onSelect(chat.id)}>
|
||||||
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-800">
|
<div className="group relative flex flex-col rounded-lg px-4 py-3 hover:bg-neutral-100 dark:hover:bg-neutral-700">
|
||||||
<div className="flex items-center">
|
<div className="flex max-w-full mx-2 items-center">
|
||||||
<MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
|
<MessageSquare className="h-5 w-5 text-neutral-600 dark:text-neutral-400" />
|
||||||
<div className="relative grow overflow-hidden whitespace-nowrap pl-4">
|
<div className="relative max-w-full grow overflow-hidden whitespace-nowrap pl-4">
|
||||||
<div className="text-sm dark:text-neutral-200">
|
<div className="text-sm max-w-full dark:text-neutral-200">
|
||||||
{chat.name || "Untitled Chat"}
|
{truncateString(chat.name || "Untitled Chat", 90)}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">
|
<div className="opacity-0 group-hover:opacity-100 transition-opacity text-xs text-neutral-500 dark:text-neutral-400">
|
||||||
|
Loading…
x
Reference in New Issue
Block a user