mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-09 12:30:49 +02:00
Customizable personas (#772)
Also includes a small fix to LLM filtering when combined with reranking
This commit is contained in:
parent
87beb1f4d1
commit
78d1ae0379
@ -0,0 +1,28 @@
|
||||
"""Add additional retrieval controls to Persona
|
||||
|
||||
Revision ID: 50b683a8295c
|
||||
Revises: 7da0ae5ad583
|
||||
Create Date: 2023-11-27 17:23:29.668422
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "50b683a8295c"
|
||||
down_revision = "7da0ae5ad583"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Integer(), nullable=True))
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("apply_llm_relevance_filter", sa.Boolean(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "apply_llm_relevance_filter")
|
||||
op.drop_column("persona", "num_chunks")
|
@ -0,0 +1,23 @@
|
||||
"""Add description to persona
|
||||
|
||||
Revision ID: 7da0ae5ad583
|
||||
Revises: e86866a9c78a
|
||||
Create Date: 2023-11-27 00:16:19.959414
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7da0ae5ad583"
|
||||
down_revision = "e86866a9c78a"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("persona", sa.Column("description", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("persona", "description")
|
@ -0,0 +1,36 @@
|
||||
"""Add chat session to query_event
|
||||
|
||||
Revision ID: 80696cf850ae
|
||||
Revises: 15326fcec57e
|
||||
Create Date: 2023-11-26 02:38:35.008070
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "80696cf850ae"
|
||||
down_revision = "15326fcec57e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"query_event",
|
||||
sa.Column("chat_session_id", sa.Integer(), nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_query_event_chat_session_id",
|
||||
"query_event",
|
||||
"chat_session",
|
||||
["chat_session_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"fk_query_event_chat_session_id", "query_event", type_="foreignkey"
|
||||
)
|
||||
op.drop_column("query_event", "chat_session_id")
|
@ -0,0 +1,27 @@
|
||||
"""Add persona to chat_session
|
||||
|
||||
Revision ID: e86866a9c78a
|
||||
Revises: 80696cf850ae
|
||||
Create Date: 2023-11-26 02:51:47.657357
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "e86866a9c78a"
|
||||
down_revision = "80696cf850ae"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("chat_session", sa.Column("persona_id", sa.Integer(), nullable=True))
|
||||
op.create_foreign_key(
|
||||
"fk_chat_session_persona_id", "chat_session", "persona", ["persona_id"], ["id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
|
||||
op.drop_column("chat_session", "persona_id")
|
@ -22,12 +22,13 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import ChannelIdAdapter
|
||||
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import SlackBotConfig
|
||||
from danswer.direct_qa.answer_question import answer_qa_query
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.server.models import NewMessageRequest
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger_base = setup_logger()
|
||||
@ -171,12 +172,12 @@ def handle_message(
|
||||
backoff=2,
|
||||
logger=logger,
|
||||
)
|
||||
def _get_answer(question: QuestionRequest) -> QAResponse:
|
||||
def _get_answer(new_message_request: NewMessageRequest) -> QAResponse:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine, expire_on_commit=False) as db_session:
|
||||
# This also handles creating the query event in postgres
|
||||
answer = answer_qa_query(
|
||||
question=question,
|
||||
new_message_request=new_message_request,
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
@ -188,6 +189,15 @@ def handle_message(
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
# create a chat session for this interaction
|
||||
# TODO: when chat support is added to Slack, this should check
|
||||
# for an existing chat session associated with this thread
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session, description="", user_id=None
|
||||
)
|
||||
chat_session_id = chat_session.id
|
||||
|
||||
answer_failed = False
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
@ -200,7 +210,8 @@ def handle_message(
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
QuestionRequest(
|
||||
NewMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
query=msg,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=not disable_auto_detect_filters,
|
||||
|
@ -1,3 +1,4 @@
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@ -99,10 +100,14 @@ def verify_parent_exists(
|
||||
|
||||
|
||||
def create_chat_session(
|
||||
description: str, user_id: UUID | None, db_session: Session
|
||||
db_session: Session,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None = None,
|
||||
) -> ChatSession:
|
||||
chat_session = ChatSession(
|
||||
user_id=user_id,
|
||||
persona_id=persona_id,
|
||||
description=description,
|
||||
)
|
||||
|
||||
@ -256,7 +261,11 @@ def set_latest_chat_message(
|
||||
|
||||
|
||||
def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
|
||||
stmt = select(Persona).where(Persona.id == persona_id)
|
||||
stmt = (
|
||||
select(Persona)
|
||||
.where(Persona.id == persona_id)
|
||||
.where(Persona.deleted == False) # noqa: E712
|
||||
)
|
||||
result = db_session.execute(stmt)
|
||||
persona = result.scalar_one_or_none()
|
||||
|
||||
@ -269,8 +278,12 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
|
||||
def fetch_default_persona_by_name(
|
||||
persona_name: str, db_session: Session
|
||||
) -> Persona | None:
|
||||
stmt = select(Persona).where(
|
||||
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
|
||||
stmt = (
|
||||
select(Persona)
|
||||
.where(
|
||||
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
|
||||
)
|
||||
.where(Persona.deleted == False) # noqa: E712
|
||||
)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
@ -284,7 +297,11 @@ def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | N
|
||||
if persona is not None:
|
||||
return persona
|
||||
|
||||
stmt = select(Persona).where(Persona.name == persona_name) # noqa: E712
|
||||
stmt = (
|
||||
select(Persona)
|
||||
.where(Persona.name == persona_name)
|
||||
.where(Persona.deleted == False) # noqa: E712
|
||||
)
|
||||
result = db_session.execute(stmt).first()
|
||||
if result:
|
||||
return result[0]
|
||||
@ -292,31 +309,44 @@ def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | N
|
||||
|
||||
|
||||
def upsert_persona(
|
||||
db_session: Session,
|
||||
name: str,
|
||||
retrieval_enabled: bool,
|
||||
datetime_aware: bool,
|
||||
system_text: str | None,
|
||||
tools: list[ToolInfo] | None,
|
||||
hint_text: str | None,
|
||||
db_session: Session,
|
||||
description: str | None = None,
|
||||
system_text: str | None = None,
|
||||
tools: list[ToolInfo] | None = None,
|
||||
hint_text: str | None = None,
|
||||
num_chunks: int | None = None,
|
||||
apply_llm_relevance_filter: bool | None = None,
|
||||
persona_id: int | None = None,
|
||||
default_persona: bool = False,
|
||||
document_sets: list[DocumentSetDBModel] | None = None,
|
||||
commit: bool = True,
|
||||
) -> Persona:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
if persona and persona.deleted:
|
||||
raise ValueError("Trying to update a deleted persona")
|
||||
|
||||
# Default personas are defined via yaml files at deployment time
|
||||
if persona is None and default_persona:
|
||||
persona = fetch_default_persona_by_name(name, db_session)
|
||||
if persona is None:
|
||||
if default_persona:
|
||||
persona = fetch_default_persona_by_name(name, db_session)
|
||||
else:
|
||||
# only one persona with the same name should exist
|
||||
if fetch_persona_by_name(name, db_session):
|
||||
raise ValueError("Trying to create a persona with a duplicate name")
|
||||
|
||||
if persona:
|
||||
persona.name = name
|
||||
persona.description = description
|
||||
persona.retrieval_enabled = retrieval_enabled
|
||||
persona.datetime_aware = datetime_aware
|
||||
persona.system_text = system_text
|
||||
persona.tools = tools
|
||||
persona.hint_text = hint_text
|
||||
persona.num_chunks = num_chunks
|
||||
persona.apply_llm_relevance_filter = apply_llm_relevance_filter
|
||||
persona.default_persona = default_persona
|
||||
|
||||
# Do not delete any associations manually added unless
|
||||
@ -328,11 +358,14 @@ def upsert_persona(
|
||||
else:
|
||||
persona = Persona(
|
||||
name=name,
|
||||
description=description,
|
||||
retrieval_enabled=retrieval_enabled,
|
||||
datetime_aware=datetime_aware,
|
||||
system_text=system_text,
|
||||
tools=tools,
|
||||
hint_text=hint_text,
|
||||
num_chunks=num_chunks,
|
||||
apply_llm_relevance_filter=apply_llm_relevance_filter,
|
||||
default_persona=default_persona,
|
||||
document_sets=document_sets if document_sets else [],
|
||||
)
|
||||
@ -345,3 +378,18 @@ def upsert_persona(
|
||||
db_session.flush()
|
||||
|
||||
return persona
|
||||
|
||||
|
||||
def fetch_personas(
    db_session: Session, include_default: bool = False
) -> Sequence[Persona]:
    """Return all non-deleted personas.

    Default personas (defined via yaml at deployment time) are excluded
    unless `include_default` is True.
    """
    stmt = select(Persona).where(Persona.deleted == False)  # noqa: E712
    if not include_default:
        stmt = stmt.where(Persona.default_persona == False)  # noqa: E712
    return db_session.scalars(stmt).all()
|
||||
|
||||
|
||||
def mark_persona_as_deleted(db_session: Session, persona_id: int) -> None:
    """Soft-delete a persona by flagging it; the row is kept for history.

    Raises whatever `fetch_persona_by_id` raises when the id is unknown.
    Commits immediately.
    """
    persona = fetch_persona_by_id(persona_id, db_session)
    persona.deleted = True
    db_session.commit()
|
||||
|
@ -100,6 +100,7 @@ def update_document_hidden(
|
||||
def create_query_event(
|
||||
db_session: Session,
|
||||
query: str,
|
||||
chat_session_id: int,
|
||||
search_type: SearchType | None,
|
||||
llm_answer: str | None,
|
||||
user_id: UUID | None,
|
||||
@ -107,6 +108,7 @@ def create_query_event(
|
||||
) -> int:
|
||||
query_event = QueryEvent(
|
||||
query=query,
|
||||
chat_session_id=chat_session_id,
|
||||
selected_search_flow=search_type,
|
||||
llm_answer=llm_answer,
|
||||
retrieved_document_ids=retrieved_document_ids,
|
||||
|
@ -4,6 +4,7 @@ from typing import Any
|
||||
from typing import List
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
from typing import Optional
|
||||
from typing import TypedDict
|
||||
from uuid import UUID
|
||||
|
||||
@ -341,6 +342,11 @@ class QueryEvent(Base):
|
||||
__tablename__ = "query_event"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
# TODO: make this non-nullable after migration to consolidate chat /
|
||||
# QA flows is complete
|
||||
chat_session_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("chat_session.id"), nullable=True
|
||||
)
|
||||
query: Mapped[str] = mapped_column(Text)
|
||||
# search_flow refers to user selection, None if user used auto
|
||||
selected_search_flow: Mapped[SearchType | None] = mapped_column(
|
||||
@ -459,6 +465,9 @@ class ChatSession(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), default=None
|
||||
)
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# The following texts help build up the model's ability to use the context effectively
|
||||
@ -475,6 +484,7 @@ class ChatSession(Base):
|
||||
messages: Mapped[List["ChatMessage"]] = relationship(
|
||||
"ChatMessage", back_populates="chat_session", cascade="delete"
|
||||
)
|
||||
persona: Mapped[Optional["Persona"]] = relationship("Persona")
|
||||
|
||||
|
||||
class ToolInfo(TypedDict):
|
||||
@ -488,6 +498,7 @@ class Persona(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# Danswer retrieval, treated as a special tool
|
||||
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
@ -496,6 +507,13 @@ class Persona(Base):
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# number of chunks to use for retrieval. If unspecified, uses the default set
|
||||
# in the env variables
|
||||
num_chunks: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||
# if unspecified, then uses the default set in the env variables
|
||||
apply_llm_relevance_filter: Mapped[bool | None] = mapped_column(
|
||||
Boolean, nullable=True
|
||||
)
|
||||
# Default personas are configured via backend during deployment
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
@ -5,36 +5,40 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import CHUNK_SIZE
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
|
||||
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
|
||||
from danswer.configs.app_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import QUERY_EVENT_ID
|
||||
from danswer.db.chat import fetch_chat_session_by_id
|
||||
from danswer.db.feedback import create_query_event
|
||||
from danswer.db.feedback import update_query_event_llm_answer
|
||||
from danswer.db.feedback import update_query_event_retrieved_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa.factory import get_default_qa_model
|
||||
from danswer.direct_qa.factory import get_qa_model_for_persona
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import StreamingError
|
||||
from danswer.direct_qa.models import LLMMetricsContainer
|
||||
from danswer.direct_qa.qa_utils import get_chunks_for_qa
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.request_preprocessing import retrieval_preprocessing
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import danswer_search
|
||||
from danswer.search.search_runner import danswer_search_generator
|
||||
from danswer.search.search_runner import full_chunk_search
|
||||
from danswer.search.search_runner import full_chunk_search_generator
|
||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
|
||||
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
|
||||
from danswer.server.models import LLMRelevanceFilterResponse
|
||||
from danswer.server.models import NewMessageRequest
|
||||
from danswer.server.models import QADocsResponse
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from danswer.utils.timing import log_function_time
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
@ -43,7 +47,7 @@ logger = setup_logger()
|
||||
|
||||
@log_function_time()
|
||||
def answer_qa_query(
|
||||
question: QuestionRequest,
|
||||
new_message_request: NewMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
|
||||
@ -55,43 +59,36 @@ def answer_qa_query(
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
|
||||
) -> QAResponse:
|
||||
query = question.query
|
||||
offset_count = question.offset if question.offset is not None else 0
|
||||
query = new_message_request.query
|
||||
offset_count = (
|
||||
new_message_request.offset if new_message_request.offset is not None else 0
|
||||
)
|
||||
logger.info(f"Received QA query: {query}")
|
||||
|
||||
run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
|
||||
run_source_filters = FunctionCall(
|
||||
extract_question_source_filters, (question, db_session), {}
|
||||
)
|
||||
run_query_intent = FunctionCall(query_intent, (query,), {})
|
||||
|
||||
parallel_results = run_functions_in_parallel(
|
||||
[
|
||||
run_time_filters,
|
||||
run_source_filters,
|
||||
run_query_intent,
|
||||
]
|
||||
# create record for this query in Postgres
|
||||
query_event_id = create_query_event(
|
||||
query=new_message_request.query,
|
||||
chat_session_id=new_message_request.chat_session_id,
|
||||
search_type=new_message_request.search_type,
|
||||
llm_answer=None,
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
|
||||
source_filters = parallel_results[run_source_filters.result_id]
|
||||
predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
|
||||
retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
|
||||
new_message_request=new_message_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
|
||||
# Set flow as search so frontend doesn't ask the user if they want to run QA over more docs
|
||||
if disable_generative_answer:
|
||||
predicted_flow = QueryFlow.SEARCH
|
||||
|
||||
# Modifies the question object but nothing upstream uses it
|
||||
question.filters.time_cutoff = time_cutoff
|
||||
question.favor_recent = favor_recent
|
||||
question.filters.source_type = source_filters
|
||||
|
||||
top_chunks, llm_chunk_selection, query_event_id = danswer_search(
|
||||
question=question,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
top_chunks, llm_chunk_selection = full_chunk_search(
|
||||
query=retrieval_request,
|
||||
document_index=get_default_document_index(),
|
||||
bypass_acl=bypass_acl,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
@ -102,11 +99,11 @@ def answer_qa_query(
|
||||
QAResponse,
|
||||
top_documents=chunks_to_search_docs(top_chunks),
|
||||
predicted_flow=predicted_flow,
|
||||
predicted_search=predicted_search,
|
||||
predicted_search=predicted_search_type,
|
||||
query_event_id=query_event_id,
|
||||
source_type=source_filters,
|
||||
time_cutoff=time_cutoff,
|
||||
favor_recent=favor_recent,
|
||||
source_type=retrieval_request.filters.source_type,
|
||||
time_cutoff=retrieval_request.filters.time_cutoff,
|
||||
favor_recent=retrieval_request.favor_recent,
|
||||
)
|
||||
|
||||
if disable_generative_answer or not top_docs:
|
||||
@ -115,9 +112,20 @@ def answer_qa_query(
|
||||
quotes=None,
|
||||
)
|
||||
|
||||
# update record for this query to include top docs
|
||||
update_query_event_retrieved_documents(
|
||||
db_session=db_session,
|
||||
retrieved_document_ids=[doc.document_id for doc in top_chunks]
|
||||
if top_chunks
|
||||
else [],
|
||||
query_id=query_event_id,
|
||||
user_id=None if user is None else user.id,
|
||||
)
|
||||
|
||||
try:
|
||||
qa_model = get_default_qa_model(
|
||||
timeout=answer_generation_timeout, real_time_flow=question.real_time
|
||||
timeout=answer_generation_timeout,
|
||||
real_time_flow=new_message_request.real_time,
|
||||
)
|
||||
except Exception as e:
|
||||
return partial_response(
|
||||
@ -131,9 +139,7 @@ def answer_qa_query(
|
||||
llm_chunk_selection=llm_chunk_selection,
|
||||
batch_offset=offset_count,
|
||||
)
|
||||
|
||||
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
|
||||
|
||||
logger.debug(
|
||||
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
|
||||
)
|
||||
@ -158,7 +164,7 @@ def answer_qa_query(
|
||||
)
|
||||
|
||||
validity = None
|
||||
if not question.real_time and enable_reflexion and d_answer is not None:
|
||||
if not new_message_request.real_time and enable_reflexion and d_answer is not None:
|
||||
validity = False
|
||||
if d_answer.answer is not None:
|
||||
validity = get_answer_validity(query, d_answer.answer)
|
||||
@ -174,47 +180,61 @@ def answer_qa_query(
|
||||
|
||||
@log_generator_function_time()
|
||||
def answer_qa_query_stream(
|
||||
question: QuestionRequest,
|
||||
new_message_request: NewMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
|
||||
) -> Iterator[str]:
|
||||
logger.debug(
|
||||
f"Received QA query ({question.search_type.value} search): {question.query}"
|
||||
f"Received QA query ({new_message_request.search_type.value} search): {new_message_request.query}"
|
||||
)
|
||||
logger.debug(f"Query filters: {question.filters}")
|
||||
logger.debug(f"Query filters: {new_message_request.filters}")
|
||||
|
||||
answer_so_far: str = ""
|
||||
query = question.query
|
||||
offset_count = question.offset if question.offset is not None else 0
|
||||
|
||||
run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
|
||||
run_source_filters = FunctionCall(
|
||||
extract_question_source_filters, (question, db_session), {}
|
||||
)
|
||||
run_query_intent = FunctionCall(query_intent, (query,), {})
|
||||
|
||||
parallel_results = run_functions_in_parallel(
|
||||
[
|
||||
run_time_filters,
|
||||
run_source_filters,
|
||||
run_query_intent,
|
||||
]
|
||||
query = new_message_request.query
|
||||
offset_count = (
|
||||
new_message_request.offset if new_message_request.offset is not None else 0
|
||||
)
|
||||
|
||||
time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
|
||||
source_filters = parallel_results[run_source_filters.result_id]
|
||||
predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
|
||||
# create record for this query in Postgres
|
||||
query_event_id = create_query_event(
|
||||
query=new_message_request.query,
|
||||
chat_session_id=new_message_request.chat_session_id,
|
||||
search_type=new_message_request.search_type,
|
||||
llm_answer=None,
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
chat_session = fetch_chat_session_by_id(
|
||||
chat_session_id=new_message_request.chat_session_id, db_session=db_session
|
||||
)
|
||||
persona = chat_session.persona
|
||||
persona_skip_llm_chunk_filter = (
|
||||
not persona.apply_llm_relevance_filter if persona else None
|
||||
)
|
||||
persona_num_chunks = persona.num_chunks if persona else None
|
||||
if persona:
|
||||
logger.info(f"Using persona: {persona.name}")
|
||||
logger.info(
|
||||
"Persona retrieval settings: skip_llm_chunk_filter: "
|
||||
f"{persona_skip_llm_chunk_filter}, "
|
||||
f"num_chunks: {persona_num_chunks}"
|
||||
)
|
||||
|
||||
# Modifies the question object but nothing upstream uses it
|
||||
question.filters.time_cutoff = time_cutoff
|
||||
question.favor_recent = favor_recent
|
||||
question.filters.source_type = source_filters
|
||||
|
||||
search_generator = danswer_search_generator(
|
||||
question=question,
|
||||
retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
|
||||
new_message_request=new_message_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
skip_llm_chunk_filter=persona_skip_llm_chunk_filter
|
||||
if persona_skip_llm_chunk_filter is not None
|
||||
else DISABLE_LLM_CHUNK_FILTER,
|
||||
)
|
||||
# if a persona is specified, always respond with an answer
|
||||
if persona:
|
||||
predicted_flow = QueryFlow.QUESTION_ANSWER
|
||||
|
||||
search_generator = full_chunk_search_generator(
|
||||
query=retrieval_request,
|
||||
document_index=get_default_document_index(),
|
||||
)
|
||||
|
||||
@ -228,10 +248,10 @@ def answer_qa_query_stream(
|
||||
# doesn't ask the user if they want to run QA over more documents
|
||||
predicted_flow=QueryFlow.SEARCH
|
||||
if disable_generative_answer
|
||||
else predicted_flow,
|
||||
predicted_search=predicted_search,
|
||||
time_cutoff=time_cutoff,
|
||||
favor_recent=favor_recent,
|
||||
else cast(QueryFlow, predicted_flow),
|
||||
predicted_search=cast(SearchType, predicted_search_type),
|
||||
time_cutoff=retrieval_request.filters.time_cutoff,
|
||||
favor_recent=retrieval_request.favor_recent,
|
||||
).dict()
|
||||
yield get_json_line(initial_response)
|
||||
|
||||
@ -239,31 +259,44 @@ def answer_qa_query_stream(
|
||||
logger.debug("No Documents Found")
|
||||
return
|
||||
|
||||
# next apply the LLM filtering
|
||||
# update record for this query to include top docs
|
||||
update_query_event_retrieved_documents(
|
||||
db_session=db_session,
|
||||
retrieved_document_ids=[doc.document_id for doc in top_chunks]
|
||||
if top_chunks
|
||||
else [],
|
||||
query_id=query_event_id,
|
||||
user_id=None if user is None else user.id,
|
||||
)
|
||||
|
||||
# next get the final chunks to be fed into the LLM
|
||||
llm_chunk_selection = cast(list[bool], next(search_generator))
|
||||
llm_chunks_indices = get_chunks_for_qa(
|
||||
chunks=top_chunks,
|
||||
llm_chunk_selection=llm_chunk_selection,
|
||||
token_limit=persona_num_chunks * CHUNK_SIZE
|
||||
if persona_num_chunks
|
||||
else NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
|
||||
batch_offset=offset_count,
|
||||
)
|
||||
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
|
||||
relevant_chunk_indices=[
|
||||
index for index, value in enumerate(llm_chunk_selection) if value
|
||||
]
|
||||
if not retrieval_request.skip_llm_chunk_filter
|
||||
else []
|
||||
).dict()
|
||||
yield get_json_line(llm_relevance_filtering_response)
|
||||
|
||||
# finally get the query ID from the search generator for updating the
|
||||
# row in Postgres. This is the end of the `search_generator` - any future
|
||||
# calls to `next` will raise StopIteration
|
||||
query_event_id = cast(int, next(search_generator))
|
||||
|
||||
if disable_generative_answer:
|
||||
logger.debug("Skipping QA because generative AI is disabled")
|
||||
return
|
||||
|
||||
try:
|
||||
qa_model = get_default_qa_model()
|
||||
if not persona:
|
||||
qa_model = get_default_qa_model()
|
||||
else:
|
||||
qa_model = get_qa_model_for_persona(persona=persona)
|
||||
except Exception as e:
|
||||
logger.exception("Unable to get QA model")
|
||||
error = StreamingError(error=str(e))
|
||||
|
@ -1,6 +1,8 @@
|
||||
from danswer.configs.app_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.configs.app_configs import QA_TIMEOUT
|
||||
from danswer.db.models import Persona
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_block import PersonaBasedQAHandler
|
||||
from danswer.direct_qa.qa_block import QABlock
|
||||
from danswer.direct_qa.qa_block import QAHandler
|
||||
from danswer.direct_qa.qa_block import SingleMessageQAHandler
|
||||
@ -44,3 +46,16 @@ def get_default_qa_model(
|
||||
llm=llm,
|
||||
qa_handler=qa_handler,
|
||||
)
|
||||
|
||||
|
||||
def get_qa_model_for_persona(
    persona: Persona,
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
) -> QAModel:
    """Build a QAModel whose prompts come from the given persona.

    Uses the default LLM but swaps in a PersonaBasedQAHandler so the
    persona's system / hint texts drive the prompt. Missing texts fall
    back to empty strings rather than None.
    """
    return QABlock(
        llm=get_default_llm(api_key=api_key, timeout=timeout),
        qa_handler=PersonaBasedQAHandler(
            system_prompt=persona.system_text or "", task_prompt=persona.hint_text or ""
        ),
    )
|
||||
|
@ -10,6 +10,7 @@ from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionReturn
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerQuotes
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.models import LLMMetricsContainer
|
||||
@ -24,6 +25,7 @@ from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
from danswer.prompts.direct_qa_prompts import COT_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
|
||||
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
|
||||
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import clean_up_code_blocks
|
||||
@ -190,6 +192,56 @@ class SingleMessageScratchpadHandler(QAHandler):
|
||||
)
|
||||
|
||||
|
||||
class PersonaBasedQAHandler(QAHandler):
    """QAHandler that injects a persona's system / task prompts into a
    parameterized single-message prompt.

    Persona answers are free-form text (not JSON), so no quote extraction
    is performed — quote outputs are always empty.
    """

    def __init__(self, system_prompt: str, task_prompt: str) -> None:
        self.system_prompt = system_prompt
        self.task_prompt = task_prompt

    @property
    def is_json_output(self) -> bool:
        # Plain-text output; downstream JSON parsing is skipped.
        return False

    def build_prompt(
        self,
        query: str,
        context_chunks: list[InferenceChunk],
    ) -> list[BaseMessage]:
        """Render the persona prompt with retrieved context and the user query."""
        context_docs_str = build_context_str(context_chunks)

        single_message = PARAMATERIZED_PROMPT.format(
            context_docs_str=context_docs_str,
            user_query=query,
            system_prompt=self.system_prompt,
            task_prompt=self.task_prompt,
        ).strip()

        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
        return prompt

    def build_dummy_prompt(
        self,
    ) -> str:
        """Render the prompt with placeholder context/query (e.g. for previews)."""
        return PARAMATERIZED_PROMPT.format(
            context_docs_str="<CONTEXT_DOCS>",
            user_query="<USER_QUERY>",
            system_prompt=self.system_prompt,
            task_prompt=self.task_prompt,
        ).strip()

    def process_llm_output(
        self, model_output: str, context_chunks: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, DanswerQuotes]:
        # The raw model output IS the answer; personas get no quotes.
        return DanswerAnswer(answer=model_output), DanswerQuotes(quotes=[])

    def process_llm_token_stream(
        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        # Stream tokens through unchanged, then terminate with empty quotes
        # so consumers see the same stream shape as other handlers.
        for token in tokens:
            yield DanswerAnswerPiece(answer_piece=token)

        yield DanswerQuotes(quotes=[])
|
||||
|
||||
|
||||
class QABlock(QAModel):
|
||||
def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
|
||||
self._llm = llm
|
||||
|
@ -45,13 +45,14 @@ from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.search.search_nlp_models import warm_up_models
|
||||
from danswer.server.cc_pair.api import router as cc_pair_router
|
||||
from danswer.server.chat_backend import router as chat_router
|
||||
from danswer.server.chat.api import router as chat_router
|
||||
from danswer.server.connector import router as connector_router
|
||||
from danswer.server.credential import router as credential_router
|
||||
from danswer.server.danswer_api import get_danswer_api_key
|
||||
from danswer.server.danswer_api import router as danswer_api_router
|
||||
from danswer.server.document_set import router as document_set_router
|
||||
from danswer.server.manage import router as admin_router
|
||||
from danswer.server.persona.api import router as persona_router
|
||||
from danswer.server.search_backend import router as backend_router
|
||||
from danswer.server.slack_bot_management import router as slack_bot_management_router
|
||||
from danswer.server.state import router as state_router
|
||||
@ -97,6 +98,7 @@ def get_application() -> FastAPI:
|
||||
application.include_router(cc_pair_router)
|
||||
application.include_router(document_set_router)
|
||||
application.include_router(slack_bot_management_router)
|
||||
application.include_router(persona_router)
|
||||
application.include_router(state_router)
|
||||
application.include_router(danswer_api_router)
|
||||
|
||||
|
@ -118,6 +118,23 @@ Answer the user query based on the following document:
|
||||
""".strip()
|
||||
|
||||
|
||||
# Paramaterized prompt which allows the user to specify their
|
||||
# own system / task prompt
|
||||
PARAMATERIZED_PROMPT = f"""
|
||||
{{system_prompt}}
|
||||
|
||||
CONTEXT:
|
||||
{GENERAL_SEP_PAT}
|
||||
{{context_docs_str}}
|
||||
{GENERAL_SEP_PAT}
|
||||
|
||||
{{task_prompt}}
|
||||
|
||||
{QUESTION_PAT.upper()} {{user_query}}
|
||||
RESPONSE:
|
||||
""".strip()
|
||||
|
||||
|
||||
# User the following for easy viewing of prompts
|
||||
if __name__ == "__main__":
|
||||
print(JSON_PROMPT) # Default prompt used in the Danswer UI flow
|
||||
|
@ -62,6 +62,9 @@ class SearchQuery(BaseModel):
|
||||
# Only used if not skip_llm_chunk_filter
|
||||
max_llm_filter_chunks: int = NUM_RERANKED_RESULTS
|
||||
|
||||
class Config:
|
||||
frozen = True
|
||||
|
||||
|
||||
class RetrievalMetricsContainer(BaseModel):
|
||||
search_type: SearchType
|
||||
|
121
backend/danswer/search/request_preprocessing.py
Normal file
121
backend/danswer/search/request_preprocessing.py
Normal file
@ -0,0 +1,121 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
|
||||
from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
|
||||
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
||||
from danswer.configs.model_configs import SKIP_RERANKING
|
||||
from danswer.db.models import User
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from danswer.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from danswer.server.models import NewMessageRequest
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
|
||||
def retrieval_preprocessing(
|
||||
new_message_request: NewMessageRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
bypass_acl: bool = False,
|
||||
include_query_intent: bool = True,
|
||||
skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
|
||||
skip_rerank_non_realtime: bool = SKIP_RERANKING,
|
||||
disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
|
||||
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
||||
) -> tuple[SearchQuery, SearchType | None, QueryFlow | None]:
|
||||
auto_filters_enabled = (
|
||||
not disable_llm_filter_extraction
|
||||
and new_message_request.enable_auto_detect_filters
|
||||
)
|
||||
|
||||
# based on the query figure out if we should apply any hard time filters /
|
||||
# if we should bias more recent docs even more strongly
|
||||
run_time_filters = (
|
||||
FunctionCall(extract_time_filter, (new_message_request.query,), {})
|
||||
if auto_filters_enabled
|
||||
else None
|
||||
)
|
||||
|
||||
# based on the query, figure out if we should apply any source filters
|
||||
should_run_source_filters = (
|
||||
auto_filters_enabled and not new_message_request.filters.source_type
|
||||
)
|
||||
run_source_filters = (
|
||||
FunctionCall(extract_source_filter, (new_message_request.query, db_session), {})
|
||||
if should_run_source_filters
|
||||
else None
|
||||
)
|
||||
# NOTE: this isn't really part of building the retrieval request, but is done here
|
||||
# so it can be simply done in parallel with the filters without multi-level multithreading
|
||||
run_query_intent = (
|
||||
FunctionCall(query_intent, (new_message_request.query,), {})
|
||||
if include_query_intent
|
||||
else None
|
||||
)
|
||||
|
||||
functions_to_run = [
|
||||
filter_fn
|
||||
for filter_fn in [
|
||||
run_time_filters,
|
||||
run_source_filters,
|
||||
run_query_intent,
|
||||
]
|
||||
if filter_fn
|
||||
]
|
||||
parallel_results = run_functions_in_parallel(functions_to_run)
|
||||
|
||||
time_cutoff, favor_recent = (
|
||||
parallel_results[run_time_filters.result_id]
|
||||
if run_time_filters
|
||||
else (None, None)
|
||||
)
|
||||
source_filters = (
|
||||
parallel_results[run_source_filters.result_id] if run_source_filters else None
|
||||
)
|
||||
predicted_search_type, predicted_flow = (
|
||||
parallel_results[run_query_intent.result_id]
|
||||
if run_query_intent
|
||||
else (None, None)
|
||||
)
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
final_filters = IndexFilters(
|
||||
source_type=new_message_request.filters.source_type or source_filters,
|
||||
document_set=new_message_request.filters.document_set,
|
||||
time_cutoff=new_message_request.filters.time_cutoff or time_cutoff,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
|
||||
# figure out if we should skip running Tranformer-based re-ranking of the
|
||||
# top chunks
|
||||
skip_reranking = (
|
||||
skip_rerank_realtime
|
||||
if new_message_request.real_time
|
||||
else skip_rerank_non_realtime
|
||||
)
|
||||
|
||||
return (
|
||||
SearchQuery(
|
||||
query=new_message_request.query,
|
||||
search_type=new_message_request.search_type,
|
||||
filters=final_filters,
|
||||
# use user specified favor_recent over generated favor_recent
|
||||
favor_recent=(
|
||||
new_message_request.favor_recent
|
||||
if new_message_request.favor_recent is not None
|
||||
else (favor_recent or False)
|
||||
),
|
||||
skip_rerank=skip_reranking,
|
||||
skip_llm_chunk_filter=skip_llm_chunk_filter,
|
||||
),
|
||||
predicted_search_type,
|
||||
predicted_flow,
|
||||
)
|
@ -7,30 +7,21 @@ import numpy
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
from nltk.stem import WordNetLemmatizer # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
|
||||
from danswer.configs.app_configs import HYBRID_ALPHA
|
||||
from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
|
||||
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
||||
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
|
||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
|
||||
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
|
||||
from danswer.configs.model_configs import SKIP_RERANKING
|
||||
from danswer.db.feedback import create_query_event
|
||||
from danswer.db.feedback import update_query_event_retrieved_documents
|
||||
from danswer.db.models import User
|
||||
from danswer.document_index.document_index_utils import (
|
||||
translate_boost_count_to_multiplier,
|
||||
)
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.models import ChunkMetric
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
@ -40,7 +31,6 @@ from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
|
||||
from danswer.secondary_llm_flows.query_expansion import rephrase_query
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.server.models import SearchDoc
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
@ -547,112 +537,9 @@ def full_chunk_search_generator(
|
||||
else None,
|
||||
)
|
||||
if llm_chunk_selection is not None:
|
||||
yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks]
|
||||
yield [
|
||||
chunk.unique_id in llm_chunk_selection
|
||||
for chunk in reranked_chunks or retrieved_chunks
|
||||
]
|
||||
else:
|
||||
yield [True for _ in reranked_chunks or retrieved_chunks]
|
||||
|
||||
|
||||
def danswer_search_generator(
|
||||
question: QuestionRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
||||
skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
|
||||
skip_rerank_non_realtime: bool = SKIP_RERANKING,
|
||||
bypass_acl: bool = False,
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> Iterator[list[InferenceChunk] | list[bool] | int]:
|
||||
"""The main entry point for search. This fetches the relevant documents from Vespa
|
||||
based on the provided query (applying permissions / filters), does any specified
|
||||
post-processing, and returns the results. It also creates an entry in the query_event table
|
||||
for this search event."""
|
||||
query_event_id = create_query_event(
|
||||
query=question.query,
|
||||
search_type=question.search_type,
|
||||
llm_answer=None,
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
final_filters = IndexFilters(
|
||||
source_type=question.filters.source_type,
|
||||
document_set=question.filters.document_set,
|
||||
time_cutoff=question.filters.time_cutoff,
|
||||
access_control_list=user_acl_filters,
|
||||
)
|
||||
|
||||
skip_reranking = (
|
||||
skip_rerank_realtime if question.real_time else skip_rerank_non_realtime
|
||||
)
|
||||
|
||||
search_query = SearchQuery(
|
||||
query=question.query,
|
||||
search_type=question.search_type,
|
||||
filters=final_filters,
|
||||
# Still applies time decay but not magnified
|
||||
favor_recent=question.favor_recent
|
||||
if question.favor_recent is not None
|
||||
else False,
|
||||
skip_rerank=skip_reranking,
|
||||
skip_llm_chunk_filter=skip_llm_chunk_filter,
|
||||
)
|
||||
|
||||
search_generator = full_chunk_search_generator(
|
||||
query=search_query,
|
||||
document_index=document_index,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
top_chunks = cast(list[InferenceChunk], next(search_generator))
|
||||
yield top_chunks
|
||||
|
||||
llm_chunk_selection = cast(list[bool], next(search_generator))
|
||||
yield llm_chunk_selection
|
||||
|
||||
update_query_event_retrieved_documents(
|
||||
db_session=db_session,
|
||||
retrieved_document_ids=[doc.document_id for doc in top_chunks]
|
||||
if top_chunks
|
||||
else [],
|
||||
query_id=query_event_id,
|
||||
user_id=None if user is None else user.id,
|
||||
)
|
||||
yield query_event_id
|
||||
|
||||
|
||||
def danswer_search(
|
||||
question: QuestionRequest,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
||||
bypass_acl: bool = False,
|
||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||
| None = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> tuple[list[InferenceChunk], list[bool], int]:
|
||||
"""Returns a tuple of the top chunks, the LLM relevance filter results, and the query event ID.
|
||||
|
||||
Presents a simpler interface than the underlying `danswer_search_generator`, as callers no
|
||||
longer need to worry about the order / have nicer typing. This should be used for flows which
|
||||
do not require streaming."""
|
||||
search_generator = danswer_search_generator(
|
||||
question=question,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
skip_llm_chunk_filter=skip_llm_chunk_filter,
|
||||
bypass_acl=bypass_acl,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
top_chunks = cast(list[InferenceChunk], next(search_generator))
|
||||
llm_chunk_selection = cast(list[bool], next(search_generator))
|
||||
query_event_id = cast(int, next(search_generator))
|
||||
return top_chunks, llm_chunk_selection, query_event_id
|
||||
|
@ -3,7 +3,6 @@ import random
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import fetch_unique_document_sources
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
@ -13,7 +12,6 @@ from danswer.prompts.constants import SOURCES_KEY
|
||||
from danswer.prompts.secondary_llm_flows import FILE_SOURCE_WARNING
|
||||
from danswer.prompts.secondary_llm_flows import SOURCE_FILTER_PROMPT
|
||||
from danswer.prompts.secondary_llm_flows import WEB_SOURCE_WARNING
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import extract_embedded_json
|
||||
from danswer.utils.timing import log_function_time
|
||||
@ -161,21 +159,6 @@ def extract_source_filter(
|
||||
return _extract_source_filters_from_llm_out(model_output)
|
||||
|
||||
|
||||
def extract_question_source_filters(
|
||||
question: QuestionRequest,
|
||||
db_session: Session,
|
||||
disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
|
||||
) -> list[DocumentSource] | None:
|
||||
# If specified in the question, don't update
|
||||
if question.filters.source_type:
|
||||
return question.filters.source_type
|
||||
|
||||
if not question.enable_auto_detect_filters or disable_llm_extraction:
|
||||
return None
|
||||
|
||||
return extract_source_filter(question.query, db_session)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Just for testing purposes
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
|
@ -5,12 +5,10 @@ from datetime import timezone
|
||||
|
||||
from dateutil.parser import parse
|
||||
|
||||
from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
|
||||
from danswer.llm.factory import get_default_llm
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.prompts.prompt_utils import get_current_llm_day_time
|
||||
from danswer.prompts.secondary_llm_flows import TIME_FILTER_PROMPT
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
@ -157,32 +155,6 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
|
||||
return _extract_time_filter_from_llm_out(model_output)
|
||||
|
||||
|
||||
def extract_question_time_filters(
|
||||
question: QuestionRequest,
|
||||
disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
|
||||
) -> tuple[datetime | None, bool]:
|
||||
time_cutoff = question.filters.time_cutoff
|
||||
favor_recent = question.favor_recent
|
||||
# Frontend needs to be able to set this flag so that if user deletes the time filter,
|
||||
# we don't automatically reapply it. The env variable is a global disable of this feature
|
||||
# for the sake of latency
|
||||
if not question.enable_auto_detect_filters or disable_llm_extraction:
|
||||
if favor_recent is None:
|
||||
favor_recent = False
|
||||
return time_cutoff, favor_recent
|
||||
|
||||
llm_cutoff, llm_favor_recent = extract_time_filter(question.query)
|
||||
|
||||
# For all extractable filters, don't overwrite the provided values if any is provided
|
||||
if time_cutoff is None:
|
||||
time_cutoff = llm_cutoff
|
||||
|
||||
if favor_recent is None:
|
||||
favor_recent = llm_favor_recent
|
||||
|
||||
return time_cutoff, favor_recent
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Just for testing purposes, too tedious to unit test as it relies on an LLM
|
||||
while True:
|
||||
|
@ -26,6 +26,7 @@ from danswer.db.models import User
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.llm.utils import get_default_llm_token_encode
|
||||
from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name
|
||||
from danswer.server.chat.models import ChatSessionCreationRequest
|
||||
from danswer.server.models import ChatFeedbackRequest
|
||||
from danswer.server.models import ChatMessageDetail
|
||||
from danswer.server.models import ChatMessageIdentifier
|
||||
@ -124,15 +125,17 @@ def get_chat_session_messages(
|
||||
|
||||
@router.post("/create-chat-session")
|
||||
def create_new_chat_session(
|
||||
chat_session_creation_request: ChatSessionCreationRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateChatSessionID:
|
||||
user_id = user.id if user is not None else None
|
||||
|
||||
new_chat_session = create_chat_session(
|
||||
"",
|
||||
user_id,
|
||||
db_session, # Leave the naming till later to prevent delay
|
||||
db_session=db_session,
|
||||
description="", # Leave the naming till later to prevent delay
|
||||
user_id=user_id,
|
||||
persona_id=chat_session_creation_request.persona_id,
|
||||
)
|
||||
|
||||
return CreateChatSessionID(chat_session_id=new_chat_session.id)
|
5
backend/danswer/server/chat/models.py
Normal file
5
backend/danswer/server/chat/models.py
Normal file
@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ChatSessionCreationRequest(BaseModel):
|
||||
persona_id: int | None = None
|
@ -174,7 +174,9 @@ class SearchDoc(BaseModel):
|
||||
return initial_dict
|
||||
|
||||
|
||||
class QuestionRequest(BaseModel):
|
||||
# TODO: rename/consolidate once the chat / QA flows are merged
|
||||
class NewMessageRequest(BaseModel):
|
||||
chat_session_id: int
|
||||
query: str
|
||||
filters: BaseFilters
|
||||
collection: str = DOCUMENT_INDEX_NAME
|
||||
|
136
backend/danswer/server/persona/api.py
Normal file
136
backend/danswer/server/persona/api.py
Normal file
@ -0,0 +1,136 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.chat import fetch_persona_by_id
|
||||
from danswer.db.chat import fetch_personas
|
||||
from danswer.db.chat import mark_persona_as_deleted
|
||||
from danswer.db.chat import upsert_persona
|
||||
from danswer.db.document_set import get_document_sets_by_ids
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa.qa_block import PersonaBasedQAHandler
|
||||
from danswer.server.persona.models import CreatePersonaRequest
|
||||
from danswer.server.persona.models import PersonaSnapshot
|
||||
from danswer.server.persona.models import PromptTemplateResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/admin/persona")
|
||||
def create_persona(
|
||||
create_persona_request: CreatePersonaRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PersonaSnapshot:
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
db_session=db_session,
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
)
|
||||
if create_persona_request.document_set_ids
|
||||
else []
|
||||
)
|
||||
try:
|
||||
persona = upsert_persona(
|
||||
db_session=db_session,
|
||||
name=create_persona_request.name,
|
||||
description=create_persona_request.description,
|
||||
retrieval_enabled=True,
|
||||
datetime_aware=True,
|
||||
system_text=create_persona_request.system_prompt,
|
||||
hint_text=create_persona_request.task_prompt,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
apply_llm_relevance_filter=create_persona_request.apply_llm_relevance_filter,
|
||||
document_sets=document_sets,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return PersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
@router.patch("/admin/persona/{persona_id}")
|
||||
def update_persona(
|
||||
persona_id: int,
|
||||
update_persona_request: CreatePersonaRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PersonaSnapshot:
|
||||
document_sets = list(
|
||||
get_document_sets_by_ids(
|
||||
db_session=db_session,
|
||||
document_set_ids=update_persona_request.document_set_ids,
|
||||
)
|
||||
if update_persona_request.document_set_ids
|
||||
else []
|
||||
)
|
||||
try:
|
||||
persona = upsert_persona(
|
||||
db_session=db_session,
|
||||
name=update_persona_request.name,
|
||||
description=update_persona_request.description,
|
||||
retrieval_enabled=True,
|
||||
datetime_aware=True,
|
||||
system_text=update_persona_request.system_prompt,
|
||||
hint_text=update_persona_request.task_prompt,
|
||||
num_chunks=update_persona_request.num_chunks,
|
||||
apply_llm_relevance_filter=update_persona_request.apply_llm_relevance_filter,
|
||||
document_sets=document_sets,
|
||||
persona_id=persona_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona")
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return PersonaSnapshot.from_model(persona)
|
||||
|
||||
|
||||
@router.delete("/admin/persona/{persona_id}")
|
||||
def delete_persona(
|
||||
persona_id: int,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
mark_persona_as_deleted(db_session=db_session, persona_id=persona_id)
|
||||
|
||||
|
||||
@router.get("/persona")
|
||||
def list_personas(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[PersonaSnapshot]:
|
||||
return [
|
||||
PersonaSnapshot.from_model(persona)
|
||||
for persona in fetch_personas(db_session=db_session)
|
||||
]
|
||||
|
||||
|
||||
@router.get("/persona/{persona_id}")
|
||||
def get_persona(
|
||||
persona_id: int,
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PersonaSnapshot:
|
||||
return PersonaSnapshot.from_model(
|
||||
fetch_persona_by_id(db_session=db_session, persona_id=persona_id)
|
||||
)
|
||||
|
||||
|
||||
@router.get("/persona-utils/prompt-explorer")
|
||||
def build_final_template_prompt(
|
||||
system_prompt: str,
|
||||
task_prompt: str,
|
||||
_: User | None = Depends(current_user),
|
||||
) -> PromptTemplateResponse:
|
||||
return PromptTemplateResponse(
|
||||
final_prompt_template=PersonaBasedQAHandler(
|
||||
system_prompt=system_prompt, task_prompt=task_prompt
|
||||
).build_dummy_prompt()
|
||||
)
|
41
backend/danswer/server/persona/models.py
Normal file
41
backend/danswer/server/persona/models.py
Normal file
@ -0,0 +1,41 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.db.models import Persona
|
||||
from danswer.server.models import DocumentSet
|
||||
|
||||
|
||||
class CreatePersonaRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
document_set_ids: list[int]
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
num_chunks: int | None = None
|
||||
apply_llm_relevance_filter: bool | None = None
|
||||
|
||||
|
||||
class PersonaSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
system_prompt: str
|
||||
task_prompt: str
|
||||
document_sets: list[DocumentSet]
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
|
||||
return PersonaSnapshot(
|
||||
id=persona.id,
|
||||
name=persona.name,
|
||||
description=persona.description or "",
|
||||
system_prompt=persona.system_text or "",
|
||||
task_prompt=persona.hint_text or "",
|
||||
document_sets=[
|
||||
DocumentSet.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class PromptTemplateResponse(BaseModel):
|
||||
final_prompt_template: str
|
@ -8,6 +8,7 @@ from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.feedback import create_doc_retrieval_feedback
|
||||
from danswer.db.feedback import create_query_event
|
||||
from danswer.db.feedback import update_query_event_feedback
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa.answer_question import answer_qa_query
|
||||
@ -17,25 +18,22 @@ from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.search.access_filters import build_access_filters_for_user
|
||||
from danswer.search.danswer_helper import recommend_search_flow
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.request_preprocessing import retrieval_preprocessing
|
||||
from danswer.search.search_runner import chunks_to_search_docs
|
||||
from danswer.search.search_runner import danswer_search
|
||||
from danswer.search.search_runner import full_chunk_search
|
||||
from danswer.secondary_llm_flows.query_validation import get_query_answerability
|
||||
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
|
||||
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
|
||||
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
|
||||
from danswer.server.models import AdminSearchRequest
|
||||
from danswer.server.models import AdminSearchResponse
|
||||
from danswer.server.models import HelperResponse
|
||||
from danswer.server.models import NewMessageRequest
|
||||
from danswer.server.models import QAFeedbackRequest
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QueryValidationResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.server.models import SearchDoc
|
||||
from danswer.server.models import SearchFeedbackRequest
|
||||
from danswer.server.models import SearchResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@ -87,26 +85,26 @@ def admin_search(
|
||||
|
||||
@router.post("/search-intent")
|
||||
def get_search_type(
|
||||
question: QuestionRequest, _: User = Depends(current_user)
|
||||
new_message_request: NewMessageRequest, _: User = Depends(current_user)
|
||||
) -> HelperResponse:
|
||||
query = question.query
|
||||
query = new_message_request.query
|
||||
return recommend_search_flow(query)
|
||||
|
||||
|
||||
@router.post("/query-validation")
|
||||
def query_validation(
|
||||
question: QuestionRequest, _: User = Depends(current_user)
|
||||
new_message_request: NewMessageRequest, _: User = Depends(current_user)
|
||||
) -> QueryValidationResponse:
|
||||
query = question.query
|
||||
query = new_message_request.query
|
||||
reasoning, answerable = get_query_answerability(query)
|
||||
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
|
||||
|
||||
|
||||
@router.post("/stream-query-validation")
|
||||
def stream_query_validation(
|
||||
question: QuestionRequest, _: User = Depends(current_user)
|
||||
new_message_request: NewMessageRequest, _: User = Depends(current_user)
|
||||
) -> StreamingResponse:
|
||||
query = question.query
|
||||
query = new_message_request.query
|
||||
return StreamingResponse(
|
||||
stream_query_answerability(query), media_type="application/json"
|
||||
)
|
||||
@ -114,65 +112,68 @@ def stream_query_validation(
|
||||
|
||||
@router.post("/document-search")
|
||||
def handle_search_request(
|
||||
question: QuestionRequest,
|
||||
new_message_request: NewMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SearchResponse:
|
||||
query = question.query
|
||||
logger.info(f"Received {question.search_type.value} " f"search query: {query}")
|
||||
|
||||
functions_to_run = [
|
||||
FunctionCall(extract_question_time_filters, (question,), {}),
|
||||
FunctionCall(extract_question_source_filters, (question, db_session), {}),
|
||||
]
|
||||
|
||||
parallel_results = run_functions_in_parallel(functions_to_run)
|
||||
|
||||
time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
|
||||
source_filters = parallel_results["extract_question_source_filters"]
|
||||
|
||||
question.filters.time_cutoff = time_cutoff
|
||||
question.favor_recent = favor_recent
|
||||
question.filters.source_type = source_filters
|
||||
|
||||
top_chunks, _, query_event_id = danswer_search(
|
||||
question=question,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
document_index=get_default_document_index(),
|
||||
skip_llm_chunk_filter=True,
|
||||
query = new_message_request.query
|
||||
logger.info(
|
||||
f"Received {new_message_request.search_type.value} " f"search query: {query}"
|
||||
)
|
||||
|
||||
# create record for this query in Postgres
|
||||
query_event_id = create_query_event(
|
||||
query=new_message_request.query,
|
||||
chat_session_id=new_message_request.chat_session_id,
|
||||
search_type=new_message_request.search_type,
|
||||
llm_answer=None,
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
retrieval_request, _, _ = retrieval_preprocessing(
|
||||
new_message_request=new_message_request,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
include_query_intent=False,
|
||||
)
|
||||
|
||||
top_chunks, _ = full_chunk_search(
|
||||
query=retrieval_request,
|
||||
document_index=get_default_document_index(),
|
||||
)
|
||||
top_docs = chunks_to_search_docs(top_chunks)
|
||||
|
||||
return SearchResponse(
|
||||
top_documents=top_docs,
|
||||
query_event_id=query_event_id,
|
||||
source_type=source_filters,
|
||||
time_cutoff=time_cutoff,
|
||||
favor_recent=favor_recent,
|
||||
source_type=retrieval_request.filters.source_type,
|
||||
time_cutoff=retrieval_request.filters.time_cutoff,
|
||||
favor_recent=retrieval_request.favor_recent,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/direct-qa")
|
||||
def direct_qa(
|
||||
question: QuestionRequest,
|
||||
new_message_request: NewMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> QAResponse:
|
||||
# Everything handled via answer_qa_query which is also used by default
|
||||
# for the DanswerBot flow
|
||||
return answer_qa_query(question=question, user=user, db_session=db_session)
|
||||
return answer_qa_query(
|
||||
new_message_request=new_message_request, user=user, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@router.post("/stream-direct-qa")
|
||||
def stream_direct_qa(
|
||||
question: QuestionRequest,
|
||||
new_message_request: NewMessageRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
packets = answer_qa_query_stream(
|
||||
question=question, user=user, db_session=db_session
|
||||
new_message_request=new_message_request, user=user, db_session=db_session
|
||||
)
|
||||
return StreamingResponse(packets, media_type="application/json")
|
||||
|
||||
|
@ -3,11 +3,15 @@ from collections.abc import Callable
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def run_functions_tuples_in_parallel(
|
||||
functions_with_args: list[tuple[Callable, tuple]],
|
||||
@ -45,19 +49,21 @@ def run_functions_tuples_in_parallel(
|
||||
return [result for index, result in results]
|
||||
|
||||
|
||||
class FunctionCall:
|
||||
class FunctionCall(Generic[R]):
|
||||
"""
|
||||
Container for run_functions_in_parallel, fetch the results from the output of
|
||||
run_functions_in_parallel via the FunctionCall.result_id.
|
||||
"""
|
||||
|
||||
def __init__(self, func: Callable, args: tuple = (), kwargs: dict | None = None):
|
||||
def __init__(
|
||||
self, func: Callable[..., R], args: tuple = (), kwargs: dict | None = None
|
||||
):
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs if kwargs is not None else {}
|
||||
self.result_id = str(uuid.uuid4())
|
||||
|
||||
def execute(self) -> Any:
|
||||
def execute(self) -> R:
|
||||
return self.func(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
|
@ -8,13 +8,14 @@ from typing import TextIO
|
||||
import yaml
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.direct_qa.answer_question import answer_qa_query
|
||||
from danswer.direct_qa.models import LLMMetricsContainer
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.server.models import NewMessageRequest
|
||||
from danswer.utils.callbacks import MetricsHander
|
||||
|
||||
|
||||
@ -81,7 +82,13 @@ def get_answer_for_question(
|
||||
time_cutoff=None,
|
||||
access_control_list=None,
|
||||
)
|
||||
question = QuestionRequest(
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="Regression Test Session",
|
||||
user_id=None,
|
||||
)
|
||||
new_message_request = NewMessageRequest(
|
||||
chat_session_id=chat_session.id,
|
||||
query=query,
|
||||
filters=filters,
|
||||
real_time=False,
|
||||
@ -93,7 +100,7 @@ def get_answer_for_question(
|
||||
llm_metrics = MetricsHander[LLMMetricsContainer]()
|
||||
|
||||
answer = answer_qa_query(
|
||||
question=question,
|
||||
new_message_request=new_message_request,
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=100,
|
||||
|
@ -5,8 +5,6 @@ from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from typing import TextIO
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.direct_qa.qa_utils import get_chunks_for_qa
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
@ -14,8 +12,9 @@ from danswer.indexing.models import InferenceChunk
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.search_runner import danswer_search
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_runner import full_chunk_search
|
||||
from danswer.utils.callbacks import MetricsHander
|
||||
|
||||
|
||||
@ -74,7 +73,7 @@ def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str
|
||||
|
||||
|
||||
def get_search_results(
|
||||
query: str, enable_llm: bool, db_session: Session
|
||||
query: str,
|
||||
) -> tuple[
|
||||
list[InferenceChunk],
|
||||
RetrievalMetricsContainer | None,
|
||||
@ -86,22 +85,19 @@ def get_search_results(
|
||||
time_cutoff=None,
|
||||
access_control_list=None,
|
||||
)
|
||||
question = QuestionRequest(
|
||||
search_query = SearchQuery(
|
||||
query=query,
|
||||
search_type=SearchType.HYBRID,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
favor_recent=False,
|
||||
)
|
||||
|
||||
retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
|
||||
rerank_metrics = MetricsHander[RerankMetricsContainer]()
|
||||
|
||||
top_chunks, llm_chunk_selection, query_id = danswer_search(
|
||||
question=question,
|
||||
user=None,
|
||||
db_session=db_session,
|
||||
top_chunks, llm_chunk_selection = full_chunk_search(
|
||||
query=search_query,
|
||||
document_index=get_default_document_index(),
|
||||
bypass_acl=True,
|
||||
skip_llm_chunk_filter=not enable_llm,
|
||||
retrieval_metrics_callback=retrieval_metrics.record_metric,
|
||||
rerank_metrics_callback=rerank_metrics.record_metric,
|
||||
)
|
||||
@ -177,58 +173,49 @@ def main(
|
||||
with open(output_file, "w") as outfile:
|
||||
with redirect_print_to_file(outfile):
|
||||
print("Running Document Retrieval Test\n")
|
||||
for ind, (question, targets) in enumerate(questions_info.items()):
|
||||
if ind >= stop_after:
|
||||
break
|
||||
|
||||
with Session(engine, expire_on_commit=False) as db_session:
|
||||
for ind, (question, targets) in enumerate(questions_info.items()):
|
||||
if ind >= stop_after:
|
||||
break
|
||||
print(f"\n\nQuestion: {question}")
|
||||
|
||||
print(f"\n\nQuestion: {question}")
|
||||
(
|
||||
top_chunks,
|
||||
retrieval_metrics,
|
||||
rerank_metrics,
|
||||
) = get_search_results(query=question)
|
||||
|
||||
(
|
||||
top_chunks,
|
||||
retrieval_metrics,
|
||||
rerank_metrics,
|
||||
) = get_search_results(
|
||||
query=question, enable_llm=enable_llm, db_session=db_session
|
||||
)
|
||||
assert retrieval_metrics is not None and rerank_metrics is not None
|
||||
|
||||
assert retrieval_metrics is not None and rerank_metrics is not None
|
||||
retrieval_ids = [
|
||||
metric.document_id for metric in retrieval_metrics.metrics
|
||||
]
|
||||
retrieval_score = calculate_score("Retrieval", retrieval_ids, targets)
|
||||
running_retrieval_score += retrieval_score
|
||||
print(f"Average: {running_retrieval_score / (ind + 1)}")
|
||||
|
||||
retrieval_ids = [
|
||||
metric.document_id for metric in retrieval_metrics.metrics
|
||||
]
|
||||
retrieval_score = calculate_score(
|
||||
"Retrieval", retrieval_ids, targets
|
||||
)
|
||||
running_retrieval_score += retrieval_score
|
||||
print(f"Average: {running_retrieval_score / (ind + 1)}")
|
||||
rerank_ids = [metric.document_id for metric in rerank_metrics.metrics]
|
||||
rerank_score = calculate_score("Rerank", rerank_ids, targets)
|
||||
running_rerank_score += rerank_score
|
||||
print(f"Average: {running_rerank_score / (ind + 1)}")
|
||||
|
||||
rerank_ids = [
|
||||
metric.document_id for metric in rerank_metrics.metrics
|
||||
]
|
||||
rerank_score = calculate_score("Rerank", rerank_ids, targets)
|
||||
running_rerank_score += rerank_score
|
||||
print(f"Average: {running_rerank_score / (ind + 1)}")
|
||||
llm_ids = [chunk.document_id for chunk in top_chunks]
|
||||
llm_score = calculate_score("LLM Filter", llm_ids, targets)
|
||||
running_llm_filter_score += llm_score
|
||||
print(f"Average: {running_llm_filter_score / (ind + 1)}")
|
||||
|
||||
if enable_llm:
|
||||
llm_ids = [chunk.document_id for chunk in top_chunks]
|
||||
llm_score = calculate_score("LLM Filter", llm_ids, targets)
|
||||
running_llm_filter_score += llm_score
|
||||
print(f"Average: {running_llm_filter_score / (ind + 1)}")
|
||||
if show_details:
|
||||
print("\nRetrieval Metrics:")
|
||||
if retrieval_metrics is None:
|
||||
print("No Retrieval Metrics Available")
|
||||
else:
|
||||
_print_retrieval_metrics(retrieval_metrics)
|
||||
|
||||
if show_details:
|
||||
print("\nRetrieval Metrics:")
|
||||
if retrieval_metrics is None:
|
||||
print("No Retrieval Metrics Available")
|
||||
else:
|
||||
_print_retrieval_metrics(retrieval_metrics)
|
||||
|
||||
print("\nReranking Metrics:")
|
||||
if rerank_metrics is None:
|
||||
print("No Reranking Metrics Available")
|
||||
else:
|
||||
_print_reranking_metrics(rerank_metrics)
|
||||
print("\nReranking Metrics:")
|
||||
if rerank_metrics is None:
|
||||
print("No Reranking Metrics Available")
|
||||
else:
|
||||
_print_reranking_metrics(rerank_metrics)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
411
web/src/app/admin/personas/PersonaEditor.tsx
Normal file
411
web/src/app/admin/personas/PersonaEditor.tsx
Normal file
@ -0,0 +1,411 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
BooleanFormField,
|
||||
TextArrayField,
|
||||
TextFormField,
|
||||
} from "@/components/admin/connectors/Field";
|
||||
import { DocumentSet } from "@/lib/types";
|
||||
import { Button, Divider, Text, Title } from "@tremor/react";
|
||||
import {
|
||||
ArrayHelpers,
|
||||
ErrorMessage,
|
||||
Field,
|
||||
FieldArray,
|
||||
Form,
|
||||
Formik,
|
||||
} from "formik";
|
||||
|
||||
import * as Yup from "yup";
|
||||
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { Persona } from "./interfaces";
|
||||
import Link from "next/link";
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
function SectionHeader({ children }: { children: string | JSX.Element }) {
|
||||
return <div className="mb-4 font-bold text-lg">{children}</div>;
|
||||
}
|
||||
|
||||
function Label({ children }: { children: string | JSX.Element }) {
|
||||
return (
|
||||
<div className="block font-medium text-base text-gray-200">{children}</div>
|
||||
);
|
||||
}
|
||||
|
||||
function SubLabel({ children }: { children: string | JSX.Element }) {
|
||||
return <div className="text-sm text-gray-300 mb-2">{children}</div>;
|
||||
}
|
||||
|
||||
// TODO: make this the default text input across all forms
|
||||
function PersonaTextInput({
|
||||
name,
|
||||
label,
|
||||
subtext,
|
||||
placeholder,
|
||||
onChange,
|
||||
type = "text",
|
||||
isTextArea = false,
|
||||
disabled = false,
|
||||
autoCompleteDisabled = true,
|
||||
}: {
|
||||
name: string;
|
||||
label: string;
|
||||
subtext?: string | JSX.Element;
|
||||
placeholder?: string;
|
||||
onChange?: (e: React.ChangeEvent<HTMLInputElement>) => void;
|
||||
type?: string;
|
||||
isTextArea?: boolean;
|
||||
disabled?: boolean;
|
||||
autoCompleteDisabled?: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<Label>{label}</Label>
|
||||
{subtext && <SubLabel>{subtext}</SubLabel>}
|
||||
<Field
|
||||
as={isTextArea ? "textarea" : "input"}
|
||||
type={type}
|
||||
name={name}
|
||||
id={name}
|
||||
className={
|
||||
`
|
||||
border
|
||||
text-gray-200
|
||||
border-gray-600
|
||||
rounded
|
||||
w-full
|
||||
py-2
|
||||
px-3
|
||||
mt-1
|
||||
${isTextArea ? " h-28" : ""}
|
||||
` + (disabled ? " bg-gray-900" : " bg-gray-800")
|
||||
}
|
||||
disabled={disabled}
|
||||
placeholder={placeholder}
|
||||
autoComplete={autoCompleteDisabled ? "off" : undefined}
|
||||
{...(onChange ? { onChange } : {})}
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={name}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
function PersonaBooleanInput({
|
||||
name,
|
||||
label,
|
||||
subtext,
|
||||
}: {
|
||||
name: string;
|
||||
label: string;
|
||||
subtext?: string | JSX.Element;
|
||||
}) {
|
||||
return (
|
||||
<div className="mb-4">
|
||||
<Label>{label}</Label>
|
||||
{subtext && <SubLabel>{subtext}</SubLabel>}
|
||||
<Field
|
||||
type="checkbox"
|
||||
name={name}
|
||||
id={name}
|
||||
className={`
|
||||
ml-2
|
||||
border
|
||||
text-gray-200
|
||||
border-gray-600
|
||||
rounded
|
||||
py-2
|
||||
px-3
|
||||
mt-1
|
||||
`}
|
||||
/>
|
||||
<ErrorMessage
|
||||
name={name}
|
||||
component="div"
|
||||
className="text-red-500 text-sm mt-1"
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function PersonaEditor({
|
||||
existingPersona,
|
||||
documentSets,
|
||||
}: {
|
||||
existingPersona?: Persona | null;
|
||||
documentSets: DocumentSet[];
|
||||
}) {
|
||||
const router = useRouter();
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
const [finalPrompt, setFinalPrompt] = useState<string | null>("");
|
||||
|
||||
const triggerFinalPromptUpdate = async (
|
||||
systemPrompt: string,
|
||||
taskPrompt: string
|
||||
) => {
|
||||
const response = await buildFinalPrompt(systemPrompt, taskPrompt);
|
||||
if (response.ok) {
|
||||
setFinalPrompt((await response.json()).final_prompt_template);
|
||||
}
|
||||
};
|
||||
|
||||
const isUpdate = existingPersona !== undefined && existingPersona !== null;
|
||||
|
||||
useEffect(() => {
|
||||
if (isUpdate) {
|
||||
triggerFinalPromptUpdate(
|
||||
existingPersona.system_prompt,
|
||||
existingPersona.task_prompt
|
||||
);
|
||||
}
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="dark">
|
||||
{popup}
|
||||
<Formik
|
||||
initialValues={{
|
||||
name: existingPersona?.name ?? "",
|
||||
description: existingPersona?.description ?? "",
|
||||
system_prompt: existingPersona?.system_prompt ?? "",
|
||||
task_prompt: existingPersona?.task_prompt ?? "",
|
||||
document_set_ids:
|
||||
existingPersona?.document_sets?.map(
|
||||
(documentSet) => documentSet.id
|
||||
) ?? ([] as number[]),
|
||||
num_chunks: existingPersona?.num_chunks ?? null,
|
||||
apply_llm_relevance_filter:
|
||||
existingPersona?.apply_llm_relevance_filter ?? false,
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
name: Yup.string().required("Must give the Persona a name!"),
|
||||
description: Yup.string().required(
|
||||
"Must give the Persona a description!"
|
||||
),
|
||||
system_prompt: Yup.string().required(
|
||||
"Must give the Persona a system prompt!"
|
||||
),
|
||||
task_prompt: Yup.string().required(
|
||||
"Must give the Persona a task prompt!"
|
||||
),
|
||||
document_set_ids: Yup.array().of(Yup.number()),
|
||||
num_chunks: Yup.number().max(20).nullable(),
|
||||
apply_llm_relevance_filter: Yup.boolean().required(),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
|
||||
let response;
|
||||
if (isUpdate) {
|
||||
response = await updatePersona({
|
||||
id: existingPersona.id,
|
||||
...values,
|
||||
num_chunks: values.num_chunks || null,
|
||||
});
|
||||
} else {
|
||||
response = await createPersona({
|
||||
...values,
|
||||
num_chunks: values.num_chunks || null,
|
||||
});
|
||||
}
|
||||
if (response.ok) {
|
||||
router.push("/admin/personas");
|
||||
return;
|
||||
}
|
||||
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to create Persona - ${await response.text()}`,
|
||||
});
|
||||
formikHelpers.setSubmitting(false);
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, values, setFieldValue }) => (
|
||||
<Form>
|
||||
<div className="pb-6">
|
||||
<SectionHeader>Who am I?</SectionHeader>
|
||||
|
||||
<PersonaTextInput
|
||||
name="name"
|
||||
label="Name"
|
||||
disabled={isUpdate}
|
||||
subtext="Users will be able to select this Persona based on this name."
|
||||
/>
|
||||
|
||||
<PersonaTextInput
|
||||
name="description"
|
||||
label="Description"
|
||||
subtext="Provide a short descriptions which gives users a hint as to what they should use this Persona for."
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
|
||||
<SectionHeader>Customize my response style</SectionHeader>
|
||||
|
||||
<PersonaTextInput
|
||||
name="system_prompt"
|
||||
label="System Prompt"
|
||||
isTextArea={true}
|
||||
subtext={
|
||||
'Give general info about what the Persona is about. For example, "You are an assistant for On-Call engineers. Your goal is to read the provided context documents and give recommendations as to how to resolve the issue."'
|
||||
}
|
||||
onChange={(e) => {
|
||||
setFieldValue("system_prompt", e.target.value);
|
||||
triggerFinalPromptUpdate(e.target.value, values.task_prompt);
|
||||
}}
|
||||
/>
|
||||
|
||||
<PersonaTextInput
|
||||
name="task_prompt"
|
||||
label="Task Prompt"
|
||||
isTextArea={true}
|
||||
subtext={
|
||||
'Give specific instructions as to what to do with the user query. For example, "Find any relevant sections from the provided documents that can help the user resolve their issue and explain how they are relevant."'
|
||||
}
|
||||
onChange={(e) => {
|
||||
setFieldValue("task_prompt", e.target.value);
|
||||
triggerFinalPromptUpdate(
|
||||
values.system_prompt,
|
||||
e.target.value
|
||||
);
|
||||
}}
|
||||
/>
|
||||
|
||||
<Label>Final Prompt</Label>
|
||||
|
||||
{finalPrompt ? (
|
||||
<pre className="text-sm mt-2 whitespace-pre-wrap">
|
||||
{finalPrompt.replaceAll("\\n", "\n")}
|
||||
</pre>
|
||||
) : (
|
||||
"-"
|
||||
)}
|
||||
|
||||
<Divider />
|
||||
|
||||
<SectionHeader>What data should I have access to?</SectionHeader>
|
||||
|
||||
<FieldArray
|
||||
name="document_set_ids"
|
||||
render={(arrayHelpers: ArrayHelpers) => (
|
||||
<div>
|
||||
<div>
|
||||
<SubLabel>
|
||||
<>
|
||||
Select which{" "}
|
||||
<Link
|
||||
href="/admin/documents/sets"
|
||||
className="text-blue-500"
|
||||
target="_blank"
|
||||
>
|
||||
Document Sets
|
||||
</Link>{" "}
|
||||
that this Persona should search through. If none are
|
||||
specified, the Persona will search through all
|
||||
available documents in order to try and response to
|
||||
queries.
|
||||
</>
|
||||
</SubLabel>
|
||||
</div>
|
||||
<div className="mb-3 mt-2 flex gap-2 flex-wrap text-sm">
|
||||
{documentSets.map((documentSet) => {
|
||||
const ind = values.document_set_ids.indexOf(
|
||||
documentSet.id
|
||||
);
|
||||
let isSelected = ind !== -1;
|
||||
return (
|
||||
<div
|
||||
key={documentSet.id}
|
||||
className={
|
||||
`
|
||||
px-3
|
||||
py-1
|
||||
rounded-lg
|
||||
border
|
||||
border-gray-700
|
||||
w-fit
|
||||
flex
|
||||
cursor-pointer ` +
|
||||
(isSelected
|
||||
? " bg-gray-600"
|
||||
: " bg-gray-900 hover:bg-gray-700")
|
||||
}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
arrayHelpers.remove(ind);
|
||||
} else {
|
||||
arrayHelpers.push(documentSet.id);
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="my-auto">{documentSet.name}</div>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
|
||||
<SectionHeader>[Advanced] Retrieval Customization</SectionHeader>
|
||||
|
||||
<PersonaTextInput
|
||||
name="num_chunks"
|
||||
label="Number of Chunks"
|
||||
subtext={
|
||||
<div>
|
||||
How many chunks should we feed into the LLM when generating
|
||||
the final response? Each chunk is ~400 words long. If you
|
||||
are using gpt-3.5-turbo or other similar models, setting
|
||||
this to a value greater than 5 will result in errors at
|
||||
query time due to the model's input length limit.
|
||||
<br />
|
||||
<br />
|
||||
If unspecified, will use 5 chunks.
|
||||
</div>
|
||||
}
|
||||
onChange={(e) => {
|
||||
const value = e.target.value;
|
||||
// Allow only integer values
|
||||
if (value === "" || /^[0-9]+$/.test(value)) {
|
||||
setFieldValue("num_chunks", value);
|
||||
}
|
||||
}}
|
||||
/>
|
||||
|
||||
<PersonaBooleanInput
|
||||
name="apply_llm_relevance_filter"
|
||||
label="Apply LLM Relevance Filter"
|
||||
subtext={
|
||||
"If enabled, the LLM will filter out chunks that are not relevant to the user query."
|
||||
}
|
||||
/>
|
||||
|
||||
<Divider />
|
||||
|
||||
<div className="flex">
|
||||
<Button
|
||||
className="mx-auto"
|
||||
variant="secondary"
|
||||
size="md"
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
{isUpdate ? "Update!" : "Create!"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
}
|
52
web/src/app/admin/personas/PersonaTable.tsx
Normal file
52
web/src/app/admin/personas/PersonaTable.tsx
Normal file
@ -0,0 +1,52 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Table,
|
||||
TableHead,
|
||||
TableRow,
|
||||
TableHeaderCell,
|
||||
TableBody,
|
||||
TableCell,
|
||||
} from "@tremor/react";
|
||||
import { Persona } from "./interfaces";
|
||||
import Link from "next/link";
|
||||
import { EditButton } from "@/components/EditButton";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function PersonasTable({ personas }: { personas: Persona[] }) {
|
||||
const router = useRouter();
|
||||
|
||||
const sortedPersonas = [...personas];
|
||||
sortedPersonas.sort((a, b) => a.name.localeCompare(b.name));
|
||||
|
||||
return (
|
||||
<div className="dark">
|
||||
<Table className="overflow-visible">
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
<TableHeaderCell>Name</TableHeaderCell>
|
||||
<TableHeaderCell>Description</TableHeaderCell>
|
||||
<TableHeaderCell></TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{sortedPersonas.map((persona) => {
|
||||
return (
|
||||
<TableRow key={persona.id}>
|
||||
<TableCell className="whitespace-normal break-all">
|
||||
<p className="text font-medium">{persona.name}</p>
|
||||
</TableCell>
|
||||
<TableCell>{persona.description}</TableCell>
|
||||
<TableCell>
|
||||
<EditButton
|
||||
onClick={() => router.push(`/admin/personas/${persona.id}`)}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -0,0 +1,29 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@tremor/react";
|
||||
import { FiTrash } from "react-icons/fi";
|
||||
import { deletePersona } from "../lib";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
export function DeletePersonaButton({ personaId }: { personaId: number }) {
|
||||
const router = useRouter();
|
||||
|
||||
return (
|
||||
<Button
|
||||
variant="secondary"
|
||||
size="xs"
|
||||
color="red"
|
||||
onClick={async () => {
|
||||
const response = await deletePersona(personaId);
|
||||
if (response.ok) {
|
||||
router.push("/admin/personas");
|
||||
} else {
|
||||
alert(`Failed to delete persona - ${await response.text()}`);
|
||||
}
|
||||
}}
|
||||
icon={FiTrash}
|
||||
>
|
||||
Delete
|
||||
</Button>
|
||||
);
|
||||
}
|
62
web/src/app/admin/personas/[personaId]/page.tsx
Normal file
62
web/src/app/admin/personas/[personaId]/page.tsx
Normal file
@ -0,0 +1,62 @@
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
import { Persona } from "../interfaces";
|
||||
import { PersonaEditor } from "../PersonaEditor";
|
||||
import { DocumentSet } from "@/lib/types";
|
||||
import { RobotIcon } from "@/components/icons/icons";
|
||||
import { BackButton } from "@/components/BackButton";
|
||||
import { Card, Title, Text, Divider, Button } from "@tremor/react";
|
||||
import { FiTrash } from "react-icons/fi";
|
||||
import { DeletePersonaButton } from "./DeletePersonaButton";
|
||||
|
||||
export default async function Page({
|
||||
params,
|
||||
}: {
|
||||
params: { personaId: string };
|
||||
}) {
|
||||
const personaResponse = await fetchSS(`/persona/${params.personaId}`);
|
||||
|
||||
if (!personaResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch Persona - ${await personaResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const documentSetsResponse = await fetchSS("/manage/document-set");
|
||||
|
||||
if (!documentSetsResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const documentSets = (await documentSetsResponse.json()) as DocumentSet[];
|
||||
const persona = (await personaResponse.json()) as Persona;
|
||||
|
||||
return (
|
||||
<div className="dark">
|
||||
<BackButton />
|
||||
<div className="pb-2 mb-4 flex">
|
||||
<h1 className="text-3xl font-bold pl-2">Edit Persona</h1>
|
||||
</div>
|
||||
|
||||
<Card>
|
||||
<PersonaEditor existingPersona={persona} documentSets={documentSets} />
|
||||
</Card>
|
||||
|
||||
<div className="mt-12">
|
||||
<Title>Delete Persona</Title>
|
||||
<div className="flex mt-6">
|
||||
<DeletePersonaButton personaId={persona.id} />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
12
web/src/app/admin/personas/interfaces.ts
Normal file
12
web/src/app/admin/personas/interfaces.ts
Normal file
@ -0,0 +1,12 @@
|
||||
import { DocumentSet } from "@/lib/types";
|
||||
|
||||
export interface Persona {
|
||||
id: number;
|
||||
name: string;
|
||||
description: string;
|
||||
system_prompt: string;
|
||||
task_prompt: string;
|
||||
document_sets: DocumentSet[];
|
||||
num_chunks?: number;
|
||||
apply_llm_relevance_filter?: boolean;
|
||||
}
|
61
web/src/app/admin/personas/lib.ts
Normal file
61
web/src/app/admin/personas/lib.ts
Normal file
@ -0,0 +1,61 @@
|
||||
interface PersonaCreationRequest {
|
||||
name: string;
|
||||
description: string;
|
||||
system_prompt: string;
|
||||
task_prompt: string;
|
||||
document_set_ids: number[];
|
||||
num_chunks: number | null;
|
||||
apply_llm_relevance_filter: boolean | null;
|
||||
}
|
||||
|
||||
export function createPersona(personaCreationRequest: PersonaCreationRequest) {
|
||||
return fetch("/api/admin/persona", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(personaCreationRequest),
|
||||
});
|
||||
}
|
||||
|
||||
interface PersonaUpdateRequest {
|
||||
id: number;
|
||||
description: string;
|
||||
system_prompt: string;
|
||||
task_prompt: string;
|
||||
document_set_ids: number[];
|
||||
num_chunks: number | null;
|
||||
apply_llm_relevance_filter: boolean | null;
|
||||
}
|
||||
|
||||
export function updatePersona(personaUpdateRequest: PersonaUpdateRequest) {
|
||||
const { id, ...requestBody } = personaUpdateRequest;
|
||||
|
||||
return fetch(`/api/admin/persona/${id}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(requestBody),
|
||||
});
|
||||
}
|
||||
|
||||
export function deletePersona(personaId: number) {
|
||||
return fetch(`/api/admin/persona/${personaId}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
}
|
||||
|
||||
export function buildFinalPrompt(systemPrompt: string, taskPrompt: string) {
|
||||
let queryString = Object.entries({
|
||||
system_prompt: systemPrompt,
|
||||
task_prompt: taskPrompt,
|
||||
})
|
||||
.map(
|
||||
([key, value]) =>
|
||||
`${encodeURIComponent(key)}=${encodeURIComponent(value)}`
|
||||
)
|
||||
.join("&");
|
||||
|
||||
return fetch(`/api/persona-utils/prompt-explorer?${queryString}`);
|
||||
}
|
37
web/src/app/admin/personas/new/page.tsx
Normal file
37
web/src/app/admin/personas/new/page.tsx
Normal file
@ -0,0 +1,37 @@
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
import { PersonaEditor } from "../PersonaEditor";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { DocumentSet } from "@/lib/types";
|
||||
import { RobotIcon } from "@/components/icons/icons";
|
||||
import { BackButton } from "@/components/BackButton";
|
||||
import { Card } from "@tremor/react";
|
||||
|
||||
export default async function Page() {
|
||||
const documentSetsResponse = await fetchSS("/manage/document-set");
|
||||
|
||||
if (!documentSetsResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const documentSets = (await documentSetsResponse.json()) as DocumentSet[];
|
||||
|
||||
return (
|
||||
<div className="dark">
|
||||
<BackButton />
|
||||
<div className="border-solid border-gray-600 border-b pb-2 mb-4 flex">
|
||||
<RobotIcon size={32} />
|
||||
<h1 className="text-3xl font-bold pl-2">Create a New Persona</h1>
|
||||
</div>
|
||||
|
||||
<Card>
|
||||
<PersonaEditor documentSets={documentSets} />
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
}
|
64
web/src/app/admin/personas/page.tsx
Normal file
64
web/src/app/admin/personas/page.tsx
Normal file
@ -0,0 +1,64 @@
|
||||
import { PersonasTable } from "./PersonaTable";
|
||||
import { FiPlusSquare } from "react-icons/fi";
|
||||
import Link from "next/link";
|
||||
import { Divider, Text, Title } from "@tremor/react";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { Persona } from "./interfaces";
|
||||
import { RobotIcon } from "@/components/icons/icons";
|
||||
|
||||
export default async function Page() {
|
||||
const personaResponse = await fetchSS("/persona");
|
||||
|
||||
if (!personaResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch personas - ${await personaResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const personas = (await personaResponse.json()) as Persona[];
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="border-solid border-gray-600 border-b pb-2 mb-4 flex">
|
||||
<RobotIcon size={32} />
|
||||
<h1 className="text-3xl font-bold pl-2">Personas</h1>
|
||||
</div>
|
||||
|
||||
<div className="text-gray-300 text-sm mb-2">
|
||||
Personas are a way to build custom search/question-answering experiences
|
||||
for different use cases.
|
||||
<p className="mt-2">They allow you to customize:</p>
|
||||
<ul className="list-disc mt-2 ml-4">
|
||||
<li>
|
||||
The prompt used by your LLM of choice to respond to the user query
|
||||
</li>
|
||||
<li>The documents that are used as context</li>
|
||||
</ul>
|
||||
</div>
|
||||
|
||||
<div className="dark">
|
||||
<Divider />
|
||||
|
||||
<Title>Create a Persona</Title>
|
||||
<Link
|
||||
href="/admin/personas/new"
|
||||
className="text-gray-100 flex py-2 px-4 mt-2 border border-gray-800 h-fit cursor-pointer hover:bg-gray-800 text-sm w-36"
|
||||
>
|
||||
<div className="mx-auto flex">
|
||||
<FiPlusSquare className="my-auto mr-2" />
|
||||
New Persona
|
||||
</div>
|
||||
</Link>
|
||||
|
||||
<Divider />
|
||||
|
||||
<Title>Existing Personas</Title>
|
||||
<PersonasTable personas={personas} />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
@ -8,10 +8,11 @@ import {
|
||||
import { redirect } from "next/navigation";
|
||||
import { HealthCheckBanner } from "@/components/health/healthcheck";
|
||||
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
|
||||
import { buildUrl } from "@/lib/utilsSS";
|
||||
import { buildUrl, fetchSS } from "@/lib/utilsSS";
|
||||
import { Connector, DocumentSet, User } from "@/lib/types";
|
||||
import { cookies } from "next/headers";
|
||||
import { SearchType } from "@/lib/search/interfaces";
|
||||
import { Persona } from "./admin/personas/interfaces";
|
||||
|
||||
export default async function Home() {
|
||||
const tasks = [
|
||||
@ -29,6 +30,7 @@ export default async function Home() {
|
||||
cookie: processCookies(cookies()),
|
||||
},
|
||||
}),
|
||||
fetchSS("/persona"),
|
||||
];
|
||||
|
||||
// catch cases where the backend is completely unreachable here
|
||||
@ -44,6 +46,7 @@ export default async function Home() {
|
||||
const user = results[1] as User | null;
|
||||
const connectorsResponse = results[2] as Response | null;
|
||||
const documentSetsResponse = results[3] as Response | null;
|
||||
const personaResponse = results[4] as Response | null;
|
||||
|
||||
if (!authDisabled && !user) {
|
||||
return redirect("/auth/login");
|
||||
@ -65,6 +68,13 @@ export default async function Home() {
|
||||
);
|
||||
}
|
||||
|
||||
let personas: Persona[] = [];
|
||||
if (personaResponse?.ok) {
|
||||
personas = await personaResponse.json();
|
||||
} else {
|
||||
console.log(`Failed to fetch personas - ${personaResponse?.status}`);
|
||||
}
|
||||
|
||||
// needs to be done in a non-client side component due to nextjs
|
||||
const storedSearchType = cookies().get("searchType")?.value as
|
||||
| string
|
||||
@ -87,6 +97,7 @@ export default async function Home() {
|
||||
<SearchSection
|
||||
connectors={connectors}
|
||||
documentSets={documentSets}
|
||||
personas={personas}
|
||||
defaultSearchType={searchTypeDefault}
|
||||
/>
|
||||
</div>
|
||||
|
@ -14,11 +14,7 @@ interface DropdownProps {
|
||||
onSelect: (selected: Option) => void;
|
||||
}
|
||||
|
||||
export const Dropdown: FC<DropdownProps> = ({
|
||||
options,
|
||||
selected,
|
||||
onSelect,
|
||||
}) => {
|
||||
export const Dropdown = ({ options, selected, onSelect }: DropdownProps) => {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const dropdownRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
|
27
web/src/components/EditButton.tsx
Normal file
27
web/src/components/EditButton.tsx
Normal file
@ -0,0 +1,27 @@
|
||||
"use client";
|
||||
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
import { FiChevronLeft, FiEdit } from "react-icons/fi";
|
||||
|
||||
export function EditButton({ onClick }: { onClick: () => void }) {
|
||||
return (
|
||||
<div
|
||||
className={`
|
||||
my-auto
|
||||
flex
|
||||
mb-1
|
||||
hover:bg-gray-800
|
||||
w-fit
|
||||
p-2
|
||||
cursor-pointer
|
||||
rounded-lg
|
||||
border-gray-800
|
||||
text-sm`}
|
||||
onClick={onClick}
|
||||
>
|
||||
<FiEdit className="mr-1 my-auto" />
|
||||
Edit
|
||||
</div>
|
||||
);
|
||||
}
|
@ -28,6 +28,7 @@ import {
|
||||
GongIcon,
|
||||
ZoomInIcon,
|
||||
ZendeskIcon,
|
||||
RobotIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { getAuthDisabledSS, getCurrentUserSS } from "@/lib/userSS";
|
||||
import { redirect } from "next/navigation";
|
||||
@ -314,13 +315,22 @@ export async function Layout({ children }: { children: React.ReactNode }) {
|
||||
],
|
||||
},
|
||||
{
|
||||
name: "Bots",
|
||||
name: "Custom Assistants",
|
||||
items: [
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<RobotIcon size={18} />
|
||||
<div className="ml-1">Personas</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/personas",
|
||||
},
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<CPUIcon size={18} />
|
||||
<div className="ml-1">Slack Bot</div>
|
||||
<div className="ml-1">Slack Bots</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/bot",
|
||||
|
@ -49,6 +49,7 @@ import hubSpotIcon from "../../../public/HubSpot.png";
|
||||
import document360Icon from "../../../public/Document360.png";
|
||||
import googleSitesIcon from "../../../public/GoogleSites.png";
|
||||
import zendeskIcon from "../../../public/Zendesk.svg";
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
|
||||
interface IconProps {
|
||||
size?: number;
|
||||
@ -281,6 +282,13 @@ export const CPUIcon = ({
|
||||
return <FiCpu size={size} className={className} />;
|
||||
};
|
||||
|
||||
export const RobotIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return <FaRobot size={size} className={className} />;
|
||||
};
|
||||
|
||||
//
|
||||
// COMPANY LOGOS
|
||||
//
|
||||
|
@ -173,7 +173,7 @@ export const DocumentDisplay = ({
|
||||
ml-auto
|
||||
mr-2`}
|
||||
>
|
||||
{document.score.toFixed(2)}
|
||||
{Math.abs(document.score).toFixed(2)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
120
web/src/components/search/PersonaSelector.tsx
Normal file
120
web/src/components/search/PersonaSelector.tsx
Normal file
@ -0,0 +1,120 @@
|
||||
import { Persona } from "@/app/admin/personas/interfaces";
|
||||
import { CustomDropdown } from "../Dropdown";
|
||||
import { FiCheck, FiChevronDown } from "react-icons/fi";
|
||||
import { FaRobot } from "react-icons/fa";
|
||||
|
||||
function PersonaItem({
|
||||
id,
|
||||
name,
|
||||
onSelect,
|
||||
isSelected,
|
||||
isFinal,
|
||||
}: {
|
||||
id: number;
|
||||
name: string;
|
||||
onSelect: (personaId: number) => void;
|
||||
isSelected: boolean;
|
||||
isFinal: boolean;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
key={id}
|
||||
className={`
|
||||
flex
|
||||
px-3
|
||||
text-sm
|
||||
text-gray-200
|
||||
py-2.5
|
||||
select-none
|
||||
cursor-pointer
|
||||
${isFinal ? "" : "border-b border-gray-800"}
|
||||
${
|
||||
isSelected
|
||||
? "bg-dark-tremor-background-muted"
|
||||
: "hover:bg-dark-tremor-background-muted "
|
||||
}
|
||||
`}
|
||||
onClick={() => {
|
||||
onSelect(id);
|
||||
}}
|
||||
>
|
||||
{name}
|
||||
{isSelected && (
|
||||
<div className="ml-auto mr-1">
|
||||
<FiCheck />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function PersonaSelector({
|
||||
personas,
|
||||
selectedPersonaId,
|
||||
onPersonaChange,
|
||||
}: {
|
||||
personas: Persona[];
|
||||
selectedPersonaId: number | null;
|
||||
onPersonaChange: (persona: Persona | null) => void;
|
||||
}) {
|
||||
const currentlySelectedPersona = personas.find(
|
||||
(persona) => persona.id === selectedPersonaId
|
||||
);
|
||||
|
||||
return (
|
||||
<CustomDropdown
|
||||
dropdown={
|
||||
<div
|
||||
className={`
|
||||
border
|
||||
border-gray-800
|
||||
rounded-lg
|
||||
flex
|
||||
flex-col
|
||||
w-64
|
||||
max-h-96
|
||||
overflow-y-auto
|
||||
flex
|
||||
overscroll-contain`}
|
||||
>
|
||||
<PersonaItem
|
||||
key={-1}
|
||||
id={-1}
|
||||
name="Default"
|
||||
onSelect={() => {
|
||||
onPersonaChange(null);
|
||||
}}
|
||||
isSelected={selectedPersonaId === null}
|
||||
isFinal={false}
|
||||
/>
|
||||
{personas.map((persona, ind) => {
|
||||
const isSelected = persona.id === selectedPersonaId;
|
||||
return (
|
||||
<PersonaItem
|
||||
key={persona.id}
|
||||
id={persona.id}
|
||||
name={persona.name}
|
||||
onSelect={(clickedPersonaId) => {
|
||||
const clickedPersona = personas.find(
|
||||
(persona) => persona.id === clickedPersonaId
|
||||
);
|
||||
if (clickedPersona) {
|
||||
onPersonaChange(clickedPersona);
|
||||
}
|
||||
}}
|
||||
isSelected={isSelected}
|
||||
isFinal={ind === personas.length - 1}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
}
|
||||
>
|
||||
<div className="select-none text-sm flex text-gray-300 px-1 py-1.5 cursor-pointer w-64">
|
||||
<FaRobot className="my-auto mr-2" />
|
||||
{currentlySelectedPersona?.name || "Default"}{" "}
|
||||
<FiChevronDown className="my-auto ml-2" />
|
||||
</div>
|
||||
</CustomDropdown>
|
||||
);
|
||||
}
|
@ -7,11 +7,7 @@ interface SearchBarProps {
|
||||
onSearch: () => void;
|
||||
}
|
||||
|
||||
export const SearchBar: React.FC<SearchBarProps> = ({
|
||||
query,
|
||||
setQuery,
|
||||
onSearch,
|
||||
}) => {
|
||||
export const SearchBar = ({ query, setQuery, onSearch }: SearchBarProps) => {
|
||||
const handleChange = (event: ChangeEvent<HTMLTextAreaElement>) => {
|
||||
const target = event.target;
|
||||
setQuery(target.value);
|
||||
@ -30,7 +26,7 @@ export const SearchBar: React.FC<SearchBarProps> = ({
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex justify-center py-3">
|
||||
<div className="flex justify-center">
|
||||
<div className="flex items-center w-full border-2 border-gray-600 rounded px-4 py-2 focus-within:border-blue-500">
|
||||
<MagnifyingGlass className="text-gray-400" />
|
||||
<textarea
|
||||
|
@ -20,6 +20,7 @@ import {
|
||||
} from "@/lib/search/aiThoughtUtils";
|
||||
import { ThreeDots } from "react-loader-spinner";
|
||||
import { usePopup } from "../admin/connectors/Popup";
|
||||
import { AlertIcon } from "../icons/icons";
|
||||
|
||||
const removeDuplicateDocs = (documents: DanswerDocument[]) => {
|
||||
const seen = new Set<string>();
|
||||
@ -49,14 +50,16 @@ interface SearchResultsDisplayProps {
|
||||
validQuestionResponse: ValidQuestionResponse;
|
||||
isFetching: boolean;
|
||||
defaultOverrides: SearchDefaultOverrides;
|
||||
personaName?: string | null;
|
||||
}
|
||||
|
||||
export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
|
||||
export const SearchResultsDisplay = ({
|
||||
searchResponse,
|
||||
validQuestionResponse,
|
||||
isFetching,
|
||||
defaultOverrides,
|
||||
}) => {
|
||||
personaName = null,
|
||||
}: SearchResultsDisplayProps) => {
|
||||
const { popup, setPopup } = usePopup();
|
||||
const [isAIThoughtsOpen, setIsAIThoughtsOpen] = React.useState<boolean>(
|
||||
getAIThoughtsIsOpenSavedValue()
|
||||
@ -70,6 +73,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
|
||||
return null;
|
||||
}
|
||||
|
||||
const isPersona = personaName !== null;
|
||||
const { answer, quotes, documents, error, queryEventId } = searchResponse;
|
||||
|
||||
if (isFetching && !answer && !documents) {
|
||||
@ -92,6 +96,17 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
|
||||
}
|
||||
|
||||
if (answer === null && documents === null && quotes === null) {
|
||||
if (error) {
|
||||
return (
|
||||
<div className="text-red-500 text-sm">
|
||||
<div className="flex">
|
||||
<AlertIcon size={16} className="text-red-500 my-auto mr-1" />
|
||||
<p className="italic">{error}</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return <div className="text-gray-300">No matching documents found.</div>;
|
||||
}
|
||||
|
||||
@ -132,34 +147,38 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
|
||||
<h2 className="text font-bold my-auto mb-1 w-full">AI Answer</h2>
|
||||
</div>
|
||||
|
||||
<div className="mb-2 w-full">
|
||||
<ResponseSection
|
||||
status={questionValidityCheckStatus}
|
||||
header={
|
||||
validQuestionResponse.answerable === null ? (
|
||||
<div className="flex ml-2">Evaluating question...</div>
|
||||
) : (
|
||||
<div className="flex ml-2">AI thoughts</div>
|
||||
)
|
||||
}
|
||||
body={<div>{validQuestionResponse.reasoning}</div>}
|
||||
desiredOpenStatus={isAIThoughtsOpen}
|
||||
setDesiredOpenStatus={handleAIThoughtToggle}
|
||||
/>
|
||||
</div>
|
||||
{!isPersona && (
|
||||
<div className="mb-2 w-full">
|
||||
<ResponseSection
|
||||
status={questionValidityCheckStatus}
|
||||
header={
|
||||
validQuestionResponse.answerable === null ? (
|
||||
<div className="flex ml-2">Evaluating question...</div>
|
||||
) : (
|
||||
<div className="flex ml-2">AI thoughts</div>
|
||||
)
|
||||
}
|
||||
body={<div>{validQuestionResponse.reasoning}</div>}
|
||||
desiredOpenStatus={isAIThoughtsOpen}
|
||||
setDesiredOpenStatus={handleAIThoughtToggle}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-2 pt-1 border-t border-gray-700 w-full">
|
||||
<AnswerSection
|
||||
answer={answer}
|
||||
quotes={quotes}
|
||||
error={error}
|
||||
isAnswerable={validQuestionResponse.answerable}
|
||||
isAnswerable={
|
||||
validQuestionResponse.answerable || (isPersona ? true : null)
|
||||
}
|
||||
isFetching={isFetching}
|
||||
aiThoughtsIsOpen={isAIThoughtsOpen}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{quotes !== null && answer && (
|
||||
{quotes !== null && answer && !isPersona && (
|
||||
<div className="pt-1 border-t border-gray-700 w-full">
|
||||
<QuotesSection
|
||||
quotes={dedupedQuotes}
|
||||
|
@ -20,8 +20,11 @@ import { SearchHelper } from "./SearchHelper";
|
||||
import { CancellationToken, cancellable } from "@/lib/search/cancellable";
|
||||
import { NEXT_PUBLIC_DISABLE_STREAMING } from "@/lib/constants";
|
||||
import { searchRequest } from "@/lib/search/qa";
|
||||
import { useFilters, useObjectState, useTimeRange } from "@/lib/hooks";
|
||||
import { useFilters, useObjectState } from "@/lib/hooks";
|
||||
import { questionValidationStreamed } from "@/lib/search/streamingQuestionValidation";
|
||||
import { createChatSession } from "@/lib/search/chatSessions";
|
||||
import { Persona } from "@/app/admin/personas/interfaces";
|
||||
import { PersonaSelector } from "./PersonaSelector";
|
||||
|
||||
const SEARCH_DEFAULT_OVERRIDES_START: SearchDefaultOverrides = {
|
||||
forceDisplayQA: false,
|
||||
@ -36,14 +39,16 @@ const VALID_QUESTION_RESPONSE_DEFAULT: ValidQuestionResponse = {
|
||||
interface SearchSectionProps {
|
||||
connectors: Connector<any>[];
|
||||
documentSets: DocumentSet[];
|
||||
personas: Persona[];
|
||||
defaultSearchType: SearchType;
|
||||
}
|
||||
|
||||
export const SearchSection: React.FC<SearchSectionProps> = ({
|
||||
export const SearchSection = ({
|
||||
connectors,
|
||||
documentSets,
|
||||
personas,
|
||||
defaultSearchType,
|
||||
}) => {
|
||||
}: SearchSectionProps) => {
|
||||
// Search Bar
|
||||
const [query, setQuery] = useState<string>("");
|
||||
|
||||
@ -63,6 +68,8 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
|
||||
const [selectedSearchType, setSelectedSearchType] =
|
||||
useState<SearchType>(defaultSearchType);
|
||||
|
||||
const [selectedPersona, setSelectedPersona] = useState<number | null>(null);
|
||||
|
||||
// Overrides for default behavior that only last a single query
|
||||
const [defaultOverrides, setDefaultOverrides] =
|
||||
useState<SearchDefaultOverrides>(SEARCH_DEFAULT_OVERRIDES_START);
|
||||
@ -134,11 +141,23 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
|
||||
setSearchResponse(initialSearchResponse);
|
||||
setValidQuestionResponse(VALID_QUESTION_RESPONSE_DEFAULT);
|
||||
|
||||
const chatSessionResponse = await createChatSession(selectedPersona);
|
||||
if (!chatSessionResponse.ok) {
|
||||
updateError(
|
||||
`Unable to create chat session - ${await chatSessionResponse.text()}`
|
||||
);
|
||||
setIsFetching(false);
|
||||
return;
|
||||
}
|
||||
const chatSessionId = (await chatSessionResponse.json())
|
||||
.chat_session_id as number;
|
||||
|
||||
const searchFn = NEXT_PUBLIC_DISABLE_STREAMING
|
||||
? searchRequest
|
||||
: searchRequestStreamed;
|
||||
const searchFnArgs = {
|
||||
query,
|
||||
chatSessionId,
|
||||
sources: filterManager.selectedSources,
|
||||
documentSets: filterManager.selectedDocumentSets,
|
||||
timeRange: filterManager.timeRange,
|
||||
@ -180,6 +199,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
|
||||
|
||||
const questionValidationArgs = {
|
||||
query,
|
||||
chatSessionId,
|
||||
update: setValidQuestionResponse,
|
||||
};
|
||||
|
||||
@ -226,6 +246,20 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-[800px] mx-auto">
|
||||
{personas.length > 0 ? (
|
||||
<div className="flex mb-2 w-64">
|
||||
<PersonaSelector
|
||||
personas={personas}
|
||||
selectedPersonaId={selectedPersona}
|
||||
onPersonaChange={(persona) =>
|
||||
setSelectedPersona(persona ? persona.id : null)
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className="pt-3" />
|
||||
)}
|
||||
|
||||
<SearchBar
|
||||
query={query}
|
||||
setQuery={setQuery}
|
||||
@ -241,6 +275,11 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
|
||||
validQuestionResponse={validQuestionResponse}
|
||||
isFetching={isFetching}
|
||||
defaultOverrides={defaultOverrides}
|
||||
personaName={
|
||||
selectedPersona
|
||||
? personas.find((p) => p.id === selectedPersona)?.name
|
||||
: null
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
12
web/src/lib/search/chatSessions.ts
Normal file
12
web/src/lib/search/chatSessions.ts
Normal file
@ -0,0 +1,12 @@
|
||||
export async function createChatSession(personaId?: number | null) {
|
||||
const chatSessionResponse = await fetch("/api/chat/create-chat-session", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
persona_id: personaId,
|
||||
}),
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
return chatSessionResponse;
|
||||
}
|
@ -92,6 +92,7 @@ export interface Filters {
|
||||
|
||||
export interface SearchRequestArgs {
|
||||
query: string;
|
||||
chatSessionId: number;
|
||||
sources: Source[];
|
||||
documentSets: string[];
|
||||
timeRange: DateRangePickerValue | null;
|
||||
|
@ -14,6 +14,7 @@ import { buildFilters } from "./utils";
|
||||
|
||||
export const searchRequestStreamed = async ({
|
||||
query,
|
||||
chatSessionId,
|
||||
sources,
|
||||
documentSets,
|
||||
timeRange,
|
||||
@ -35,6 +36,7 @@ export const searchRequestStreamed = async ({
|
||||
const response = await fetch("/api/stream-direct-qa", {
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
chat_session_id: chatSessionId,
|
||||
query,
|
||||
collection: "danswer_index",
|
||||
filters,
|
||||
|
@ -3,11 +3,13 @@ import { processRawChunkString } from "./streamingUtils";
|
||||
|
||||
export interface QuestionValidationArgs {
|
||||
query: string;
|
||||
chatSessionId: number;
|
||||
update: (update: Partial<ValidQuestionResponse>) => void;
|
||||
}
|
||||
|
||||
export const questionValidationStreamed = async <T>({
|
||||
query,
|
||||
chatSessionId,
|
||||
update,
|
||||
}: QuestionValidationArgs) => {
|
||||
const emptyFilters = {
|
||||
@ -20,6 +22,7 @@ export const questionValidationStreamed = async <T>({
|
||||
method: "POST",
|
||||
body: JSON.stringify({
|
||||
query,
|
||||
chat_session_id: chatSessionId,
|
||||
collection: "danswer_index",
|
||||
filters: emptyFilters,
|
||||
enable_auto_detect_filters: false,
|
||||
|
Loading…
x
Reference in New Issue
Block a user