Customizable personas (#772)

Also includes a small fix to LLM filtering when combined with reranking
This commit is contained in:
Chris Weaver 2023-11-28 00:57:48 -08:00 committed by GitHub
parent 87beb1f4d1
commit 78d1ae0379
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 1846 additions and 408 deletions

View File

@ -0,0 +1,28 @@
"""Add additional retrieval controls to Persona
Revision ID: 50b683a8295c
Revises: 7da0ae5ad583
Create Date: 2023-11-27 17:23:29.668422
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "50b683a8295c"
down_revision = "7da0ae5ad583"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("persona", sa.Column("num_chunks", sa.Integer(), nullable=True))
op.add_column(
"persona",
sa.Column("apply_llm_relevance_filter", sa.Boolean(), nullable=True),
)
def downgrade() -> None:
op.drop_column("persona", "apply_llm_relevance_filter")
op.drop_column("persona", "num_chunks")

View File

@ -0,0 +1,23 @@
"""Add description to persona
Revision ID: 7da0ae5ad583
Revises: e86866a9c78a
Create Date: 2023-11-27 00:16:19.959414
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "7da0ae5ad583"
down_revision = "e86866a9c78a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("persona", sa.Column("description", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("persona", "description")

View File

@ -0,0 +1,36 @@
"""Add chat session to query_event
Revision ID: 80696cf850ae
Revises: 15326fcec57e
Create Date: 2023-11-26 02:38:35.008070
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "80696cf850ae"
down_revision = "15326fcec57e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"query_event",
sa.Column("chat_session_id", sa.Integer(), nullable=True),
)
op.create_foreign_key(
"fk_query_event_chat_session_id",
"query_event",
"chat_session",
["chat_session_id"],
["id"],
)
def downgrade() -> None:
op.drop_constraint(
"fk_query_event_chat_session_id", "query_event", type_="foreignkey"
)
op.drop_column("query_event", "chat_session_id")

View File

@ -0,0 +1,27 @@
"""Add persona to chat_session
Revision ID: e86866a9c78a
Revises: 80696cf850ae
Create Date: 2023-11-26 02:51:47.657357
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "e86866a9c78a"
down_revision = "80696cf850ae"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column("chat_session", sa.Column("persona_id", sa.Integer(), nullable=True))
op.create_foreign_key(
"fk_chat_session_persona_id", "chat_session", "persona", ["persona_id"], ["id"]
)
def downgrade() -> None:
op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey")
op.drop_column("chat_session", "persona_id")

View File

@ -22,12 +22,13 @@ from danswer.danswerbot.slack.models import SlackMessageInfo
from danswer.danswerbot.slack.utils import ChannelIdAdapter
from danswer.danswerbot.slack.utils import fetch_userids_from_emails
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.chat import create_chat_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import SlackBotConfig
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.search.models import BaseFilters
from danswer.server.models import NewMessageRequest
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
logger_base = setup_logger()
@ -171,12 +172,12 @@ def handle_message(
backoff=2,
logger=logger,
)
def _get_answer(question: QuestionRequest) -> QAResponse:
def _get_answer(new_message_request: NewMessageRequest) -> QAResponse:
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as db_session:
# This also handles creating the query event in postgres
answer = answer_qa_query(
question=question,
new_message_request=new_message_request,
user=None,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
@ -188,6 +189,15 @@ def handle_message(
else:
raise RuntimeError(answer.error_msg)
# create a chat session for this interaction
# TODO: when chat support is added to Slack, this should check
# for an existing chat session associated with this thread
with Session(get_sqlalchemy_engine()) as db_session:
chat_session = create_chat_session(
db_session=db_session, description="", user_id=None
)
chat_session_id = chat_session.id
answer_failed = False
try:
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
@ -200,7 +210,8 @@ def handle_message(
# This includes throwing out answer via reflexion
answer = _get_answer(
QuestionRequest(
NewMessageRequest(
chat_session_id=chat_session_id,
query=msg,
filters=filters,
enable_auto_detect_filters=not disable_auto_detect_filters,

View File

@ -1,3 +1,4 @@
from collections.abc import Sequence
from typing import Any
from uuid import UUID
@ -99,10 +100,14 @@ def verify_parent_exists(
def create_chat_session(
description: str, user_id: UUID | None, db_session: Session
db_session: Session,
description: str,
user_id: UUID | None,
persona_id: int | None = None,
) -> ChatSession:
chat_session = ChatSession(
user_id=user_id,
persona_id=persona_id,
description=description,
)
@ -256,7 +261,11 @@ def set_latest_chat_message(
def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
stmt = select(Persona).where(Persona.id == persona_id)
stmt = (
select(Persona)
.where(Persona.id == persona_id)
.where(Persona.deleted == False) # noqa: E712
)
result = db_session.execute(stmt)
persona = result.scalar_one_or_none()
@ -269,8 +278,12 @@ def fetch_persona_by_id(persona_id: int, db_session: Session) -> Persona:
def fetch_default_persona_by_name(
persona_name: str, db_session: Session
) -> Persona | None:
stmt = select(Persona).where(
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
stmt = (
select(Persona)
.where(
Persona.name == persona_name, Persona.default_persona == True # noqa: E712
)
.where(Persona.deleted == False) # noqa: E712
)
result = db_session.execute(stmt).scalar_one_or_none()
return result
@ -284,7 +297,11 @@ def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | N
if persona is not None:
return persona
stmt = select(Persona).where(Persona.name == persona_name) # noqa: E712
stmt = (
select(Persona)
.where(Persona.name == persona_name)
.where(Persona.deleted == False) # noqa: E712
)
result = db_session.execute(stmt).first()
if result:
return result[0]
@ -292,31 +309,44 @@ def fetch_persona_by_name(persona_name: str, db_session: Session) -> Persona | N
def upsert_persona(
db_session: Session,
name: str,
retrieval_enabled: bool,
datetime_aware: bool,
system_text: str | None,
tools: list[ToolInfo] | None,
hint_text: str | None,
db_session: Session,
description: str | None = None,
system_text: str | None = None,
tools: list[ToolInfo] | None = None,
hint_text: str | None = None,
num_chunks: int | None = None,
apply_llm_relevance_filter: bool | None = None,
persona_id: int | None = None,
default_persona: bool = False,
document_sets: list[DocumentSetDBModel] | None = None,
commit: bool = True,
) -> Persona:
persona = db_session.query(Persona).filter_by(id=persona_id).first()
if persona and persona.deleted:
raise ValueError("Trying to update a deleted persona")
# Default personas are defined via yaml files at deployment time
if persona is None and default_persona:
persona = fetch_default_persona_by_name(name, db_session)
if persona is None:
if default_persona:
persona = fetch_default_persona_by_name(name, db_session)
else:
# only one persona with the same name should exist
if fetch_persona_by_name(name, db_session):
raise ValueError("Trying to create a persona with a duplicate name")
if persona:
persona.name = name
persona.description = description
persona.retrieval_enabled = retrieval_enabled
persona.datetime_aware = datetime_aware
persona.system_text = system_text
persona.tools = tools
persona.hint_text = hint_text
persona.num_chunks = num_chunks
persona.apply_llm_relevance_filter = apply_llm_relevance_filter
persona.default_persona = default_persona
# Do not delete any associations manually added unless
@ -328,11 +358,14 @@ def upsert_persona(
else:
persona = Persona(
name=name,
description=description,
retrieval_enabled=retrieval_enabled,
datetime_aware=datetime_aware,
system_text=system_text,
tools=tools,
hint_text=hint_text,
num_chunks=num_chunks,
apply_llm_relevance_filter=apply_llm_relevance_filter,
default_persona=default_persona,
document_sets=document_sets if document_sets else [],
)
@ -345,3 +378,18 @@ def upsert_persona(
db_session.flush()
return persona
def fetch_personas(
    db_session: Session, include_default: bool = False
) -> Sequence[Persona]:
    """Return all non-deleted personas.

    Default personas (defined via yaml at deployment time) are excluded
    unless `include_default` is True.
    """
    conditions = [Persona.deleted == False]  # noqa: E712
    if not include_default:
        conditions.append(Persona.default_persona == False)  # noqa: E712
    return db_session.scalars(select(Persona).where(*conditions)).all()
def mark_persona_as_deleted(db_session: Session, persona_id: int) -> None:
    """Soft-delete a persona by flipping its `deleted` flag and committing.

    Lookup errors (missing / already-deleted persona) propagate from
    `fetch_persona_by_id`.
    """
    target = fetch_persona_by_id(persona_id, db_session)
    target.deleted = True
    db_session.commit()

View File

@ -100,6 +100,7 @@ def update_document_hidden(
def create_query_event(
db_session: Session,
query: str,
chat_session_id: int,
search_type: SearchType | None,
llm_answer: str | None,
user_id: UUID | None,
@ -107,6 +108,7 @@ def create_query_event(
) -> int:
query_event = QueryEvent(
query=query,
chat_session_id=chat_session_id,
selected_search_flow=search_type,
llm_answer=llm_answer,
retrieved_document_ids=retrieved_document_ids,

View File

@ -4,6 +4,7 @@ from typing import Any
from typing import List
from typing import Literal
from typing import NotRequired
from typing import Optional
from typing import TypedDict
from uuid import UUID
@ -341,6 +342,11 @@ class QueryEvent(Base):
__tablename__ = "query_event"
id: Mapped[int] = mapped_column(primary_key=True)
# TODO: make this non-nullable after migration to consolidate chat /
# QA flows is complete
chat_session_id: Mapped[int | None] = mapped_column(
ForeignKey("chat_session.id"), nullable=True
)
query: Mapped[str] = mapped_column(Text)
# search_flow refers to user selection, None if user used auto
selected_search_flow: Mapped[SearchType | None] = mapped_column(
@ -459,6 +465,9 @@ class ChatSession(Base):
id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
persona_id: Mapped[int | None] = mapped_column(
ForeignKey("persona.id"), default=None
)
description: Mapped[str] = mapped_column(Text)
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
# The following texts help build up the model's ability to use the context effectively
@ -475,6 +484,7 @@ class ChatSession(Base):
messages: Mapped[List["ChatMessage"]] = relationship(
"ChatMessage", back_populates="chat_session", cascade="delete"
)
persona: Mapped[Optional["Persona"]] = relationship("Persona")
class ToolInfo(TypedDict):
@ -488,6 +498,7 @@ class Persona(Base):
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String)
description: Mapped[str | None] = mapped_column(String, nullable=True)
# Danswer retrieval, treated as a special tool
retrieval_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
@ -496,6 +507,13 @@ class Persona(Base):
postgresql.JSONB(), nullable=True
)
hint_text: Mapped[str | None] = mapped_column(Text, nullable=True)
# number of chunks to use for retrieval. If unspecified, uses the default set
# in the env variables
num_chunks: Mapped[int | None] = mapped_column(Integer, nullable=True)
# if unspecified, then uses the default set in the env variables
apply_llm_relevance_filter: Mapped[bool | None] = mapped_column(
Boolean, nullable=True
)
# Default personas are configured via backend during deployment
# Treated specially (cannot be user edited etc.)
default_persona: Mapped[bool] = mapped_column(Boolean, default=False)

View File

@ -5,36 +5,40 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.configs.app_configs import CHUNK_SIZE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import QUERY_EVENT_ID
from danswer.db.chat import fetch_chat_session_by_id
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_llm_answer
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.direct_qa.factory import get_default_qa_model
from danswer.direct_qa.factory import get_qa_model_for_persona
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.danswer_helper import query_intent
from danswer.search.models import QueryFlow
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchType
from danswer.search.request_preprocessing import retrieval_preprocessing
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import danswer_search
from danswer.search.search_runner import danswer_search_generator
from danswer.search.search_runner import full_chunk_search
from danswer.search.search_runner import full_chunk_search_generator
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
from danswer.server.models import LLMRelevanceFilterResponse
from danswer.server.models import NewMessageRequest
from danswer.server.models import QADocsResponse
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
@ -43,7 +47,7 @@ logger = setup_logger()
@log_function_time()
def answer_qa_query(
question: QuestionRequest,
new_message_request: NewMessageRequest,
user: User | None,
db_session: Session,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
@ -55,43 +59,36 @@ def answer_qa_query(
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> QAResponse:
query = question.query
offset_count = question.offset if question.offset is not None else 0
query = new_message_request.query
offset_count = (
new_message_request.offset if new_message_request.offset is not None else 0
)
logger.info(f"Received QA query: {query}")
run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
run_source_filters = FunctionCall(
extract_question_source_filters, (question, db_session), {}
)
run_query_intent = FunctionCall(query_intent, (query,), {})
parallel_results = run_functions_in_parallel(
[
run_time_filters,
run_source_filters,
run_query_intent,
]
# create record for this query in Postgres
query_event_id = create_query_event(
query=new_message_request.query,
chat_session_id=new_message_request.chat_session_id,
search_type=new_message_request.search_type,
llm_answer=None,
user_id=user.id if user is not None else None,
db_session=db_session,
)
time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
source_filters = parallel_results[run_source_filters.result_id]
predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
new_message_request=new_message_request,
user=user,
db_session=db_session,
bypass_acl=bypass_acl,
)
# Set flow as search so frontend doesn't ask the user if they want to run QA over more docs
if disable_generative_answer:
predicted_flow = QueryFlow.SEARCH
# Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent
question.filters.source_type = source_filters
top_chunks, llm_chunk_selection, query_event_id = danswer_search(
question=question,
user=user,
db_session=db_session,
top_chunks, llm_chunk_selection = full_chunk_search(
query=retrieval_request,
document_index=get_default_document_index(),
bypass_acl=bypass_acl,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
@ -102,11 +99,11 @@ def answer_qa_query(
QAResponse,
top_documents=chunks_to_search_docs(top_chunks),
predicted_flow=predicted_flow,
predicted_search=predicted_search,
predicted_search=predicted_search_type,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
source_type=retrieval_request.filters.source_type,
time_cutoff=retrieval_request.filters.time_cutoff,
favor_recent=retrieval_request.favor_recent,
)
if disable_generative_answer or not top_docs:
@ -115,9 +112,20 @@ def answer_qa_query(
quotes=None,
)
# update record for this query to include top docs
update_query_event_retrieved_documents(
db_session=db_session,
retrieved_document_ids=[doc.document_id for doc in top_chunks]
if top_chunks
else [],
query_id=query_event_id,
user_id=None if user is None else user.id,
)
try:
qa_model = get_default_qa_model(
timeout=answer_generation_timeout, real_time_flow=question.real_time
timeout=answer_generation_timeout,
real_time_flow=new_message_request.real_time,
)
except Exception as e:
return partial_response(
@ -131,9 +139,7 @@ def answer_qa_query(
llm_chunk_selection=llm_chunk_selection,
batch_offset=offset_count,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
)
@ -158,7 +164,7 @@ def answer_qa_query(
)
validity = None
if not question.real_time and enable_reflexion and d_answer is not None:
if not new_message_request.real_time and enable_reflexion and d_answer is not None:
validity = False
if d_answer.answer is not None:
validity = get_answer_validity(query, d_answer.answer)
@ -174,47 +180,61 @@ def answer_qa_query(
@log_generator_function_time()
def answer_qa_query_stream(
question: QuestionRequest,
new_message_request: NewMessageRequest,
user: User | None,
db_session: Session,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> Iterator[str]:
logger.debug(
f"Received QA query ({question.search_type.value} search): {question.query}"
f"Received QA query ({new_message_request.search_type.value} search): {new_message_request.query}"
)
logger.debug(f"Query filters: {question.filters}")
logger.debug(f"Query filters: {new_message_request.filters}")
answer_so_far: str = ""
query = question.query
offset_count = question.offset if question.offset is not None else 0
run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
run_source_filters = FunctionCall(
extract_question_source_filters, (question, db_session), {}
)
run_query_intent = FunctionCall(query_intent, (query,), {})
parallel_results = run_functions_in_parallel(
[
run_time_filters,
run_source_filters,
run_query_intent,
]
query = new_message_request.query
offset_count = (
new_message_request.offset if new_message_request.offset is not None else 0
)
time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
source_filters = parallel_results[run_source_filters.result_id]
predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
# create record for this query in Postgres
query_event_id = create_query_event(
query=new_message_request.query,
chat_session_id=new_message_request.chat_session_id,
search_type=new_message_request.search_type,
llm_answer=None,
user_id=user.id if user is not None else None,
db_session=db_session,
)
chat_session = fetch_chat_session_by_id(
chat_session_id=new_message_request.chat_session_id, db_session=db_session
)
persona = chat_session.persona
persona_skip_llm_chunk_filter = (
not persona.apply_llm_relevance_filter if persona else None
)
persona_num_chunks = persona.num_chunks if persona else None
if persona:
logger.info(f"Using persona: {persona.name}")
logger.info(
"Persona retrieval settings: skip_llm_chunk_filter: "
f"{persona_skip_llm_chunk_filter}, "
f"num_chunks: {persona_num_chunks}"
)
# Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent
question.filters.source_type = source_filters
search_generator = danswer_search_generator(
question=question,
retrieval_request, predicted_search_type, predicted_flow = retrieval_preprocessing(
new_message_request=new_message_request,
user=user,
db_session=db_session,
skip_llm_chunk_filter=persona_skip_llm_chunk_filter
if persona_skip_llm_chunk_filter is not None
else DISABLE_LLM_CHUNK_FILTER,
)
# if a persona is specified, always respond with an answer
if persona:
predicted_flow = QueryFlow.QUESTION_ANSWER
search_generator = full_chunk_search_generator(
query=retrieval_request,
document_index=get_default_document_index(),
)
@ -228,10 +248,10 @@ def answer_qa_query_stream(
# doesn't ask the user if they want to run QA over more documents
predicted_flow=QueryFlow.SEARCH
if disable_generative_answer
else predicted_flow,
predicted_search=predicted_search,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
else cast(QueryFlow, predicted_flow),
predicted_search=cast(SearchType, predicted_search_type),
time_cutoff=retrieval_request.filters.time_cutoff,
favor_recent=retrieval_request.favor_recent,
).dict()
yield get_json_line(initial_response)
@ -239,31 +259,44 @@ def answer_qa_query_stream(
logger.debug("No Documents Found")
return
# next apply the LLM filtering
# update record for this query to include top docs
update_query_event_retrieved_documents(
db_session=db_session,
retrieved_document_ids=[doc.document_id for doc in top_chunks]
if top_chunks
else [],
query_id=query_event_id,
user_id=None if user is None else user.id,
)
# next get the final chunks to be fed into the LLM
llm_chunk_selection = cast(list[bool], next(search_generator))
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
token_limit=persona_num_chunks * CHUNK_SIZE
if persona_num_chunks
else NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
batch_offset=offset_count,
)
llm_relevance_filtering_response = LLMRelevanceFilterResponse(
relevant_chunk_indices=[
index for index, value in enumerate(llm_chunk_selection) if value
]
if not retrieval_request.skip_llm_chunk_filter
else []
).dict()
yield get_json_line(llm_relevance_filtering_response)
# finally get the query ID from the search generator for updating the
# row in Postgres. This is the end of the `search_generator` - any future
# calls to `next` will raise StopIteration
query_event_id = cast(int, next(search_generator))
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return
try:
qa_model = get_default_qa_model()
if not persona:
qa_model = get_default_qa_model()
else:
qa_model = get_qa_model_for_persona(persona=persona)
except Exception as e:
logger.exception("Unable to get QA model")
error = StreamingError(error=str(e))

View File

@ -1,6 +1,8 @@
from danswer.configs.app_configs import QA_PROMPT_OVERRIDE
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.db.models import Persona
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_block import PersonaBasedQAHandler
from danswer.direct_qa.qa_block import QABlock
from danswer.direct_qa.qa_block import QAHandler
from danswer.direct_qa.qa_block import SingleMessageQAHandler
@ -44,3 +46,16 @@ def get_default_qa_model(
llm=llm,
qa_handler=qa_handler,
)
def get_qa_model_for_persona(
    persona: Persona,
    api_key: str | None = None,
    timeout: int = QA_TIMEOUT,
) -> QAModel:
    """Build a QA model whose prompts come from a persona's custom text.

    Missing persona prompt fields are treated as empty strings so the
    handler always receives valid (possibly blank) prompts.
    """
    handler = PersonaBasedQAHandler(
        system_prompt=persona.system_text or "",
        task_prompt=persona.hint_text or "",
    )
    llm = get_default_llm(api_key=api_key, timeout=timeout)
    return QABlock(llm=llm, qa_handler=handler)

View File

@ -10,6 +10,7 @@ from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
@ -24,6 +25,7 @@ from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.prompts.direct_qa_prompts import COT_PROMPT
from danswer.prompts.direct_qa_prompts import JSON_PROMPT
from danswer.prompts.direct_qa_prompts import LANGUAGE_HINT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
@ -190,6 +192,56 @@ class SingleMessageScratchpadHandler(QAHandler):
)
class PersonaBasedQAHandler(QAHandler):
    """QA handler that wraps user-supplied persona prompts.

    The persona's system / task prompt text is interpolated into
    PARAMATERIZED_PROMPT and sent as a single human message. Model output
    is passed through verbatim — no JSON parsing, no quote extraction.
    """

    def __init__(self, system_prompt: str, task_prompt: str) -> None:
        self.system_prompt = system_prompt
        self.task_prompt = task_prompt

    @property
    def is_json_output(self) -> bool:
        # Free-form output: the persona prompt does not ask for JSON.
        return False

    def build_prompt(
        self,
        query: str,
        context_chunks: list[InferenceChunk],
    ) -> list[BaseMessage]:
        """Render the persona prompt with context docs and the user query."""
        context_docs_str = build_context_str(context_chunks)

        single_message = PARAMATERIZED_PROMPT.format(
            context_docs_str=context_docs_str,
            user_query=query,
            system_prompt=self.system_prompt,
            task_prompt=self.task_prompt,
        ).strip()

        prompt: list[BaseMessage] = [HumanMessage(content=single_message)]
        return prompt

    def build_dummy_prompt(
        self,
    ) -> str:
        """Render the prompt with placeholder context/query (e.g. for preview)."""
        return PARAMATERIZED_PROMPT.format(
            context_docs_str="<CONTEXT_DOCS>",
            user_query="<USER_QUERY>",
            system_prompt=self.system_prompt,
            task_prompt=self.task_prompt,
        ).strip()

    def process_llm_output(
        self, model_output: str, context_chunks: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, DanswerQuotes]:
        """Return the raw model output as the answer; quotes are not extracted."""
        return DanswerAnswer(answer=model_output), DanswerQuotes(quotes=[])

    def process_llm_token_stream(
        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
    ) -> AnswerQuestionStreamReturn:
        """Stream tokens straight through, then emit an empty quotes sentinel."""
        for token in tokens:
            yield DanswerAnswerPiece(answer_piece=token)

        yield DanswerQuotes(quotes=[])
class QABlock(QAModel):
def __init__(self, llm: LLM, qa_handler: QAHandler) -> None:
self._llm = llm

View File

@ -45,13 +45,14 @@ from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.search.search_nlp_models import warm_up_models
from danswer.server.cc_pair.api import router as cc_pair_router
from danswer.server.chat_backend import router as chat_router
from danswer.server.chat.api import router as chat_router
from danswer.server.connector import router as connector_router
from danswer.server.credential import router as credential_router
from danswer.server.danswer_api import get_danswer_api_key
from danswer.server.danswer_api import router as danswer_api_router
from danswer.server.document_set import router as document_set_router
from danswer.server.manage import router as admin_router
from danswer.server.persona.api import router as persona_router
from danswer.server.search_backend import router as backend_router
from danswer.server.slack_bot_management import router as slack_bot_management_router
from danswer.server.state import router as state_router
@ -97,6 +98,7 @@ def get_application() -> FastAPI:
application.include_router(cc_pair_router)
application.include_router(document_set_router)
application.include_router(slack_bot_management_router)
application.include_router(persona_router)
application.include_router(state_router)
application.include_router(danswer_api_router)

View File

@ -118,6 +118,23 @@ Answer the user query based on the following document:
""".strip()
# Parameterized prompt which allows the user to specify their
# own system / task prompt
PARAMATERIZED_PROMPT = f"""
{{system_prompt}}
CONTEXT:
{GENERAL_SEP_PAT}
{{context_docs_str}}
{GENERAL_SEP_PAT}
{{task_prompt}}
{QUESTION_PAT.upper()} {{user_query}}
RESPONSE:
""".strip()
# Use the following for easy viewing of prompts
if __name__ == "__main__":
print(JSON_PROMPT) # Default prompt used in the Danswer UI flow

View File

@ -62,6 +62,9 @@ class SearchQuery(BaseModel):
# Only used if not skip_llm_chunk_filter
max_llm_filter_chunks: int = NUM_RERANKED_RESULTS
class Config:
frozen = True
class RetrievalMetricsContainer(BaseModel):
search_type: SearchType

View File

@ -0,0 +1,121 @@
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.models import User
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.models import IndexFilters
from danswer.search.models import QueryFlow
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter
from danswer.server.models import NewMessageRequest
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
def retrieval_preprocessing(
    new_message_request: NewMessageRequest,
    user: User | None,
    db_session: Session,
    bypass_acl: bool = False,
    include_query_intent: bool = True,
    skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
    skip_rerank_non_realtime: bool = SKIP_RERANKING,
    disable_llm_filter_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
) -> tuple[SearchQuery, SearchType | None, QueryFlow | None]:
    """Turn a raw NewMessageRequest into a fully-specified SearchQuery.

    Runs the optional LLM-based filter extractions (time / source) and the
    query-intent prediction in a single parallel batch, merges the results
    with any filters the user set explicitly (user values always win), and
    attaches the caller's ACL filters.

    Returns:
        (search_query, predicted_search_type, predicted_flow); the latter two
        are None when include_query_intent is False.
    """
    llm_filters_on = (
        not disable_llm_filter_extraction
        and new_message_request.enable_auto_detect_filters
    )

    # Optional LLM call: derive a hard time cutoff / recency bias from the query.
    time_filter_call = None
    if llm_filters_on:
        time_filter_call = FunctionCall(
            extract_time_filter, (new_message_request.query,), {}
        )

    # Optional LLM call: derive source-type filters, but only when the user
    # has not already pinned a source type.
    source_filter_call = None
    if llm_filters_on and not new_message_request.filters.source_type:
        source_filter_call = FunctionCall(
            extract_source_filter, (new_message_request.query, db_session), {}
        )

    # NOTE: query intent is not really part of building the retrieval request,
    # but running it here lets all model calls share one parallel batch
    # without multi-level multithreading.
    intent_call = None
    if include_query_intent:
        intent_call = FunctionCall(query_intent, (new_message_request.query,), {})

    pending_calls = [
        call
        for call in (time_filter_call, source_filter_call, intent_call)
        if call
    ]
    results = run_functions_in_parallel(pending_calls)

    time_cutoff, favor_recent = (None, None)
    if time_filter_call:
        time_cutoff, favor_recent = results[time_filter_call.result_id]

    source_filters = (
        results[source_filter_call.result_id] if source_filter_call else None
    )

    predicted_search_type, predicted_flow = (None, None)
    if intent_call:
        predicted_search_type, predicted_flow = results[intent_call.result_id]

    acl_filters = (
        None if bypass_acl else build_access_filters_for_user(user, db_session)
    )
    final_filters = IndexFilters(
        source_type=new_message_request.filters.source_type or source_filters,
        document_set=new_message_request.filters.document_set,
        time_cutoff=new_message_request.filters.time_cutoff or time_cutoff,
        access_control_list=acl_filters,
    )

    # Decide whether to skip the Transformer-based rerank of the top chunks;
    # real-time (interactive) flows have their own toggle.
    if new_message_request.real_time:
        skip_reranking = skip_rerank_realtime
    else:
        skip_reranking = skip_rerank_non_realtime

    # A user-specified favor_recent always overrides the LLM-derived one.
    if new_message_request.favor_recent is not None:
        final_favor_recent = new_message_request.favor_recent
    else:
        final_favor_recent = favor_recent or False

    search_query = SearchQuery(
        query=new_message_request.query,
        search_type=new_message_request.search_type,
        filters=final_filters,
        favor_recent=final_favor_recent,
        skip_rerank=skip_reranking,
        skip_llm_chunk_filter=skip_llm_chunk_filter,
    )
    return search_query, predicted_search_type, predicted_flow

View File

@ -7,30 +7,21 @@ import numpy
from nltk.corpus import stopwords # type:ignore
from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.app_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
@ -40,7 +31,6 @@ from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.secondary_llm_flows.query_expansion import rephrase_query
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
@ -547,112 +537,9 @@ def full_chunk_search_generator(
else None,
)
if llm_chunk_selection is not None:
yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks]
yield [
chunk.unique_id in llm_chunk_selection
for chunk in reranked_chunks or retrieved_chunks
]
else:
yield [True for _ in reranked_chunks or retrieved_chunks]
def danswer_search_generator(
    question: QuestionRequest,
    user: User | None,
    db_session: Session,
    document_index: DocumentIndex,
    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
    skip_rerank_realtime: bool = not ENABLE_RERANKING_REAL_TIME_FLOW,
    skip_rerank_non_realtime: bool = SKIP_RERANKING,
    bypass_acl: bool = False,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Iterator[list[InferenceChunk] | list[bool] | int]:
    """The main entry point for search. This fetches the relevant documents from Vespa
    based on the provided query (applying permissions / filters), does any specified
    post-processing, and returns the results. It also creates an entry in the query_event table
    for this search event.

    Yields, in this exact order (callers rely on it):
      1. list[InferenceChunk] -- the top retrieved (possibly reranked) chunks
      2. list[bool]           -- per-chunk LLM relevance-filter verdicts
      3. int                  -- the query_event row ID for later feedback
    """
    # Record the query up front so the event row exists even if later steps fail.
    query_event_id = create_query_event(
        query=question.query,
        search_type=question.search_type,
        llm_answer=None,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )

    # bypass_acl is for internal/eval flows that must see all documents.
    user_acl_filters = (
        None if bypass_acl else build_access_filters_for_user(user, db_session)
    )
    final_filters = IndexFilters(
        source_type=question.filters.source_type,
        document_set=question.filters.document_set,
        time_cutoff=question.filters.time_cutoff,
        access_control_list=user_acl_filters,
    )

    # Real-time (interactive) requests use a separate rerank toggle from batch ones.
    skip_reranking = (
        skip_rerank_realtime if question.real_time else skip_rerank_non_realtime
    )

    search_query = SearchQuery(
        query=question.query,
        search_type=question.search_type,
        filters=final_filters,
        # Still applies time decay but not magnified
        favor_recent=question.favor_recent
        if question.favor_recent is not None
        else False,
        skip_rerank=skip_reranking,
        skip_llm_chunk_filter=skip_llm_chunk_filter,
    )

    search_generator = full_chunk_search_generator(
        query=search_query,
        document_index=document_index,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    yield top_chunks

    llm_chunk_selection = cast(list[bool], next(search_generator))
    yield llm_chunk_selection

    # Persist which documents were surfaced, for feedback / analytics.
    update_query_event_retrieved_documents(
        db_session=db_session,
        retrieved_document_ids=[doc.document_id for doc in top_chunks]
        if top_chunks
        else [],
        query_id=query_event_id,
        user_id=None if user is None else user.id,
    )
    yield query_event_id
def danswer_search(
    question: QuestionRequest,
    user: User | None,
    db_session: Session,
    document_index: DocumentIndex,
    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
    bypass_acl: bool = False,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk], list[bool], int]:
    """Non-streaming wrapper around `danswer_search_generator`.

    Drains the generator's three yields and returns them as a typed tuple of
    (top chunks, LLM relevance-filter results, query event ID), so callers that
    do not stream get plain values instead of the generator protocol.
    """
    generator = danswer_search_generator(
        question=question,
        user=user,
        db_session=db_session,
        document_index=document_index,
        skip_llm_chunk_filter=skip_llm_chunk_filter,
        bypass_acl=bypass_acl,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )
    chunks = cast(list[InferenceChunk], next(generator))
    relevance_flags = cast(list[bool], next(generator))
    event_id = cast(int, next(generator))
    return chunks, relevance_flags, event_id

View File

@ -3,7 +3,6 @@ import random
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.constants import DocumentSource
from danswer.db.connector import fetch_unique_document_sources
from danswer.db.engine import get_sqlalchemy_engine
@ -13,7 +12,6 @@ from danswer.prompts.constants import SOURCES_KEY
from danswer.prompts.secondary_llm_flows import FILE_SOURCE_WARNING
from danswer.prompts.secondary_llm_flows import SOURCE_FILTER_PROMPT
from danswer.prompts.secondary_llm_flows import WEB_SOURCE_WARNING
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import extract_embedded_json
from danswer.utils.timing import log_function_time
@ -161,21 +159,6 @@ def extract_source_filter(
return _extract_source_filters_from_llm_out(model_output)
def extract_question_source_filters(
    question: QuestionRequest,
    db_session: Session,
    disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
) -> list[DocumentSource] | None:
    """Resolve source filters for a question.

    Filters the user set explicitly are returned untouched; otherwise an LLM
    extraction is attempted, unless auto-detection is off for this request or
    globally disabled.
    """
    explicit_sources = question.filters.source_type
    if explicit_sources:
        return explicit_sources

    llm_allowed = question.enable_auto_detect_filters and not disable_llm_extraction
    if not llm_allowed:
        return None

    return extract_source_filter(question.query, db_session)
if __name__ == "__main__":
# Just for testing purposes
with Session(get_sqlalchemy_engine()) as db_session:

View File

@ -5,12 +5,10 @@ from datetime import timezone
from dateutil.parser import parse
from danswer.configs.app_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.prompt_utils import get_current_llm_day_time
from danswer.prompts.secondary_llm_flows import TIME_FILTER_PROMPT
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@ -157,32 +155,6 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
return _extract_time_filter_from_llm_out(model_output)
def extract_question_time_filters(
    question: QuestionRequest,
    disable_llm_extraction: bool = DISABLE_LLM_FILTER_EXTRACTION,
) -> tuple[datetime | None, bool]:
    """Resolve (time_cutoff, favor_recent) for a question.

    Values already present on the question are never overwritten; the LLM is
    consulted only for whichever of the two is still unset, and only when
    auto-detection is enabled both per-request and globally.
    """
    time_cutoff = question.filters.time_cutoff
    favor_recent = question.favor_recent

    # Frontend needs to be able to set this flag so that if user deletes the time filter,
    # we don't automatically reapply it. The env variable is a global disable of this feature
    # for the sake of latency
    auto_detect = question.enable_auto_detect_filters and not disable_llm_extraction
    if not auto_detect:
        if favor_recent is None:
            favor_recent = False
        return time_cutoff, favor_recent

    llm_cutoff, llm_favor_recent = extract_time_filter(question.query)

    # For all extractable filters, don't overwrite the provided values if any is provided
    final_cutoff = time_cutoff if time_cutoff is not None else llm_cutoff
    final_favor_recent = (
        favor_recent if favor_recent is not None else llm_favor_recent
    )
    return final_cutoff, final_favor_recent
if __name__ == "__main__":
# Just for testing purposes, too tedious to unit test as it relies on an LLM
while True:

View File

@ -26,6 +26,7 @@ from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.llm.utils import get_default_llm_token_encode
from danswer.secondary_llm_flows.chat_helpers import get_new_chat_name
from danswer.server.chat.models import ChatSessionCreationRequest
from danswer.server.models import ChatFeedbackRequest
from danswer.server.models import ChatMessageDetail
from danswer.server.models import ChatMessageIdentifier
@ -124,15 +125,17 @@ def get_chat_session_messages(
@router.post("/create-chat-session")
def create_new_chat_session(
    chat_session_creation_request: ChatSessionCreationRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
    """Create an empty chat session (optionally bound to a persona) and return its ID.

    NOTE: the rendered source interleaved the old positional call
    (``create_chat_session("", user_id, db_session)``) with the new keyword
    call, which is not valid Python; this is the keyword form only.
    """
    user_id = user.id if user is not None else None
    new_chat_session = create_chat_session(
        db_session=db_session,
        description="",  # Leave the naming till later to prevent delay
        user_id=user_id,
        persona_id=chat_session_creation_request.persona_id,
    )
    return CreateChatSessionID(chat_session_id=new_chat_session.id)

View File

@ -0,0 +1,5 @@
from pydantic import BaseModel
class ChatSessionCreationRequest(BaseModel):
    """Request body for the create-chat-session endpoint."""

    # Persona to attach to the new session; None -> no persona specified
    persona_id: int | None = None

View File

@ -174,7 +174,9 @@ class SearchDoc(BaseModel):
return initial_dict
class QuestionRequest(BaseModel):
# TODO: rename/consolidate once the chat / QA flows are merged
class NewMessageRequest(BaseModel):
chat_session_id: int
query: str
filters: BaseFilters
collection: str = DOCUMENT_INDEX_NAME

View File

@ -0,0 +1,136 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.chat import fetch_persona_by_id
from danswer.db.chat import fetch_personas
from danswer.db.chat import mark_persona_as_deleted
from danswer.db.chat import upsert_persona
from danswer.db.document_set import get_document_sets_by_ids
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.direct_qa.qa_block import PersonaBasedQAHandler
from danswer.server.persona.models import CreatePersonaRequest
from danswer.server.persona.models import PersonaSnapshot
from danswer.server.persona.models import PromptTemplateResponse
from danswer.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter()
@router.post("/admin/persona")
def create_persona(
    create_persona_request: CreatePersonaRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
    """Create a persona (admin only) and return its snapshot.

    Raises:
        HTTPException(400): if `upsert_persona` rejects the payload.
    """
    # Resolve document set IDs to their ORM objects; no IDs -> no document sets.
    document_sets = list(
        get_document_sets_by_ids(
            db_session=db_session,
            document_set_ids=create_persona_request.document_set_ids,
        )
        if create_persona_request.document_set_ids
        else []
    )

    try:
        persona = upsert_persona(
            db_session=db_session,
            name=create_persona_request.name,
            description=create_persona_request.description,
            retrieval_enabled=True,
            datetime_aware=True,
            system_text=create_persona_request.system_prompt,
            hint_text=create_persona_request.task_prompt,
            num_chunks=create_persona_request.num_chunks,
            apply_llm_relevance_filter=create_persona_request.apply_llm_relevance_filter,
            document_sets=document_sets,
        )
    except ValueError as e:
        # Fixed: previously logged "Failed to update persona" in the create path.
        logger.exception("Failed to create persona")
        # Chain the cause so the original ValueError is preserved in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e

    return PersonaSnapshot.from_model(persona)
@router.patch("/admin/persona/{persona_id}")
def update_persona(
    persona_id: int,
    update_persona_request: CreatePersonaRequest,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
    """Overwrite an existing persona's fields (admin only) and return the snapshot.

    Raises:
        HTTPException(400): if `upsert_persona` rejects the payload.
    """
    # Resolve document set IDs to their ORM objects; no IDs -> empty list.
    if update_persona_request.document_set_ids:
        document_sets = list(
            get_document_sets_by_ids(
                db_session=db_session,
                document_set_ids=update_persona_request.document_set_ids,
            )
        )
    else:
        document_sets = []

    try:
        persona = upsert_persona(
            db_session=db_session,
            name=update_persona_request.name,
            description=update_persona_request.description,
            retrieval_enabled=True,
            datetime_aware=True,
            system_text=update_persona_request.system_prompt,
            hint_text=update_persona_request.task_prompt,
            num_chunks=update_persona_request.num_chunks,
            apply_llm_relevance_filter=update_persona_request.apply_llm_relevance_filter,
            document_sets=document_sets,
            persona_id=persona_id,
        )
    except ValueError as e:
        logger.exception("Failed to update persona")
        raise HTTPException(status_code=400, detail=str(e))

    return PersonaSnapshot.from_model(persona)
@router.delete("/admin/persona/{persona_id}")
def delete_persona(
    persona_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Delete a persona (admin only).

    Delegates to `mark_persona_as_deleted` -- presumably a soft delete
    (row flagged, not removed); confirm in the db layer.
    """
    mark_persona_as_deleted(db_session=db_session, persona_id=persona_id)
@router.get("/persona")
def list_personas(
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[PersonaSnapshot]:
    """Return snapshots of all personas; requires only a logged-in user."""
    personas = fetch_personas(db_session=db_session)
    return [PersonaSnapshot.from_model(persona) for persona in personas]
@router.get("/persona/{persona_id}")
def get_persona(
    persona_id: int,
    _: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
    """Fetch a single persona by ID and return its snapshot."""
    persona = fetch_persona_by_id(db_session=db_session, persona_id=persona_id)
    return PersonaSnapshot.from_model(persona)
@router.get("/persona-utils/prompt-explorer")
def build_final_template_prompt(
    system_prompt: str,
    task_prompt: str,
    _: User | None = Depends(current_user),
) -> PromptTemplateResponse:
    """Preview the fully-assembled prompt for the given system/task prompts."""
    handler = PersonaBasedQAHandler(
        system_prompt=system_prompt, task_prompt=task_prompt
    )
    return PromptTemplateResponse(final_prompt_template=handler.build_dummy_prompt())

View File

@ -0,0 +1,41 @@
from pydantic import BaseModel
from danswer.db.models import Persona
from danswer.server.models import DocumentSet
class CreatePersonaRequest(BaseModel):
    """Payload for the admin create / update persona endpoints."""

    name: str
    description: str
    # Document sets to associate with the persona
    document_set_ids: list[int]
    system_prompt: str
    task_prompt: str
    # Optional retrieval overrides; None leaves the value unset
    # (presumably server-side defaults apply -- confirm in upsert_persona)
    num_chunks: int | None = None
    apply_llm_relevance_filter: bool | None = None
class PersonaSnapshot(BaseModel):
    """Serializable view of a Persona returned by the persona API endpoints."""

    id: int
    name: str
    description: str
    system_prompt: str
    task_prompt: str
    document_sets: list[DocumentSet]

    @classmethod
    def from_model(cls, persona: Persona) -> "PersonaSnapshot":
        """Build a snapshot from the ORM model; nullable text fields become ""."""
        # Idiom fix: use `cls` (not the hard-coded class name) so subclasses
        # of this snapshot get instances of their own type.
        return cls(
            id=persona.id,
            name=persona.name,
            description=persona.description or "",
            system_prompt=persona.system_text or "",
            task_prompt=persona.hint_text or "",
            document_sets=[
                DocumentSet.from_model(document_set_model)
                for document_set_model in persona.document_sets
            ],
        )
class PromptTemplateResponse(BaseModel):
    """Response for the prompt-explorer endpoint: the rendered prompt template."""

    final_prompt_template: str

View File

@ -8,6 +8,7 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_feedback
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_qa_query
@ -17,25 +18,22 @@ from danswer.document_index.vespa.index import VespaIndex
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.models import IndexFilters
from danswer.search.request_preprocessing import retrieval_preprocessing
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import danswer_search
from danswer.search.search_runner import full_chunk_search
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
from danswer.server.models import AdminSearchRequest
from danswer.server.models import AdminSearchResponse
from danswer.server.models import HelperResponse
from danswer.server.models import NewMessageRequest
from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
logger = setup_logger()
@ -87,26 +85,26 @@ def admin_search(
@router.post("/search-intent")
def get_search_type(
    new_message_request: NewMessageRequest, _: User = Depends(current_user)
) -> HelperResponse:
    """Recommend a search flow for the query via `recommend_search_flow`.

    NOTE: the rendered source interleaved the old `question: QuestionRequest`
    signature/body with the new one; this keeps only the new version.
    """
    query = new_message_request.query
    return recommend_search_flow(query)
@router.post("/query-validation")
def query_validation(
    new_message_request: NewMessageRequest, _: User = Depends(current_user)
) -> QueryValidationResponse:
    """Check whether the query is answerable and return reasoning + verdict.

    NOTE: reconstructed from diff residue that duplicated old and new
    signature/body lines; only the NewMessageRequest version is kept.
    """
    query = new_message_request.query
    reasoning, answerable = get_query_answerability(query)
    return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
@router.post("/stream-query-validation")
def stream_query_validation(
    new_message_request: NewMessageRequest, _: User = Depends(current_user)
) -> StreamingResponse:
    """Streaming variant of query-validation.

    NOTE: reconstructed from diff residue that duplicated old and new
    signature/body lines; only the NewMessageRequest version is kept.
    """
    query = new_message_request.query
    return StreamingResponse(
        stream_query_answerability(query), media_type="application/json"
    )
@ -114,65 +112,68 @@ def stream_query_validation(
@router.post("/document-search")
def handle_search_request(
    new_message_request: NewMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> SearchResponse:
    """Run a document search: log the query event, preprocess the request
    (filters / rerank flags), execute retrieval, and return the top documents.

    NOTE: reconstructed from diff residue that interleaved the removed
    `danswer_search`-based implementation with the new
    `retrieval_preprocessing` + `full_chunk_search` one; only the new
    version is kept.
    """
    query = new_message_request.query
    logger.info(
        f"Received {new_message_request.search_type.value} " f"search query: {query}"
    )

    # create record for this query in Postgres
    query_event_id = create_query_event(
        query=new_message_request.query,
        chat_session_id=new_message_request.chat_session_id,
        search_type=new_message_request.search_type,
        llm_answer=None,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )

    # Query intent is skipped -- this endpoint only needs the retrieval request.
    retrieval_request, _, _ = retrieval_preprocessing(
        new_message_request=new_message_request,
        user=user,
        db_session=db_session,
        include_query_intent=False,
    )

    top_chunks, _ = full_chunk_search(
        query=retrieval_request,
        document_index=get_default_document_index(),
    )

    top_docs = chunks_to_search_docs(top_chunks)
    return SearchResponse(
        top_documents=top_docs,
        query_event_id=query_event_id,
        # Echo back the filters actually applied (user-set or LLM-derived).
        source_type=retrieval_request.filters.source_type,
        time_cutoff=retrieval_request.filters.time_cutoff,
        favor_recent=retrieval_request.favor_recent,
    )
@router.post("/direct-qa")
def direct_qa(
    new_message_request: NewMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> QAResponse:
    """One-shot question answering over documents.

    NOTE: reconstructed from diff residue that showed both the old
    `question=question` call and the new keyword call; only the new one is kept.
    """
    # Everything handled via answer_qa_query which is also used by default
    # for the DanswerBot flow
    return answer_qa_query(
        new_message_request=new_message_request, user=user, db_session=db_session
    )
@router.post("/stream-direct-qa")
def stream_direct_qa(
    new_message_request: NewMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    """Streaming variant of direct-qa.

    NOTE: reconstructed from diff residue that showed both the old
    `question=question` call and the new keyword call; only the new one is kept.
    """
    packets = answer_qa_query_stream(
        new_message_request=new_message_request, user=user, db_session=db_session
    )
    return StreamingResponse(packets, media_type="application/json")

View File

@ -3,11 +3,15 @@ from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Generic
from typing import TypeVar
from danswer.utils.logger import setup_logger
logger = setup_logger()
R = TypeVar("R")
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
@ -45,19 +49,21 @@ def run_functions_tuples_in_parallel(
return [result for index, result in results]
class FunctionCall(Generic[R]):
    """
    Container for run_functions_in_parallel, fetch the results from the output of
    run_functions_in_parallel via the FunctionCall.result_id.

    NOTE: reconstructed from diff residue that showed both the old untyped and
    the new Generic[R] signatures; only the generic version is kept.
    """

    def __init__(
        self, func: Callable[..., R], args: tuple = (), kwargs: dict | None = None
    ):
        self.func = func
        self.args = args
        # None -> fresh dict per instance (avoids a shared mutable default)
        self.kwargs = kwargs if kwargs is not None else {}
        # Unique key used to look this call's result up in the parallel-run output
        self.result_id = str(uuid.uuid4())

    def execute(self) -> R:
        """Invoke the wrapped callable with the stored args/kwargs."""
        return self.func(*self.args, **self.kwargs)

View File

@ -8,13 +8,14 @@ from typing import TextIO
import yaml
from sqlalchemy.orm import Session
from danswer.db.chat import create_chat_session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.search.models import IndexFilters
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.server.models import QuestionRequest
from danswer.server.models import NewMessageRequest
from danswer.utils.callbacks import MetricsHander
@ -81,7 +82,13 @@ def get_answer_for_question(
time_cutoff=None,
access_control_list=None,
)
question = QuestionRequest(
chat_session = create_chat_session(
db_session=db_session,
description="Regression Test Session",
user_id=None,
)
new_message_request = NewMessageRequest(
chat_session_id=chat_session.id,
query=query,
filters=filters,
real_time=False,
@ -93,7 +100,7 @@ def get_answer_for_question(
llm_metrics = MetricsHander[LLMMetricsContainer]()
answer = answer_qa_query(
question=question,
new_message_request=new_message_request,
user=None,
db_session=db_session,
answer_generation_timeout=100,

View File

@ -5,8 +5,6 @@ from contextlib import contextmanager
from typing import Any
from typing import TextIO
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index
@ -14,8 +12,9 @@ from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.search_runner import danswer_search
from danswer.server.models import QuestionRequest
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_runner import full_chunk_search
from danswer.utils.callbacks import MetricsHander
@ -74,7 +73,7 @@ def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str
def get_search_results(
query: str, enable_llm: bool, db_session: Session
query: str,
) -> tuple[
list[InferenceChunk],
RetrievalMetricsContainer | None,
@ -86,22 +85,19 @@ def get_search_results(
time_cutoff=None,
access_control_list=None,
)
question = QuestionRequest(
search_query = SearchQuery(
query=query,
search_type=SearchType.HYBRID,
filters=filters,
enable_auto_detect_filters=False,
favor_recent=False,
)
retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
rerank_metrics = MetricsHander[RerankMetricsContainer]()
top_chunks, llm_chunk_selection, query_id = danswer_search(
question=question,
user=None,
db_session=db_session,
top_chunks, llm_chunk_selection = full_chunk_search(
query=search_query,
document_index=get_default_document_index(),
bypass_acl=True,
skip_llm_chunk_filter=not enable_llm,
retrieval_metrics_callback=retrieval_metrics.record_metric,
rerank_metrics_callback=rerank_metrics.record_metric,
)
@ -177,58 +173,49 @@ def main(
with open(output_file, "w") as outfile:
with redirect_print_to_file(outfile):
print("Running Document Retrieval Test\n")
for ind, (question, targets) in enumerate(questions_info.items()):
if ind >= stop_after:
break
with Session(engine, expire_on_commit=False) as db_session:
for ind, (question, targets) in enumerate(questions_info.items()):
if ind >= stop_after:
break
print(f"\n\nQuestion: {question}")
print(f"\n\nQuestion: {question}")
(
top_chunks,
retrieval_metrics,
rerank_metrics,
) = get_search_results(query=question)
(
top_chunks,
retrieval_metrics,
rerank_metrics,
) = get_search_results(
query=question, enable_llm=enable_llm, db_session=db_session
)
assert retrieval_metrics is not None and rerank_metrics is not None
assert retrieval_metrics is not None and rerank_metrics is not None
retrieval_ids = [
metric.document_id for metric in retrieval_metrics.metrics
]
retrieval_score = calculate_score("Retrieval", retrieval_ids, targets)
running_retrieval_score += retrieval_score
print(f"Average: {running_retrieval_score / (ind + 1)}")
retrieval_ids = [
metric.document_id for metric in retrieval_metrics.metrics
]
retrieval_score = calculate_score(
"Retrieval", retrieval_ids, targets
)
running_retrieval_score += retrieval_score
print(f"Average: {running_retrieval_score / (ind + 1)}")
rerank_ids = [metric.document_id for metric in rerank_metrics.metrics]
rerank_score = calculate_score("Rerank", rerank_ids, targets)
running_rerank_score += rerank_score
print(f"Average: {running_rerank_score / (ind + 1)}")
rerank_ids = [
metric.document_id for metric in rerank_metrics.metrics
]
rerank_score = calculate_score("Rerank", rerank_ids, targets)
running_rerank_score += rerank_score
print(f"Average: {running_rerank_score / (ind + 1)}")
llm_ids = [chunk.document_id for chunk in top_chunks]
llm_score = calculate_score("LLM Filter", llm_ids, targets)
running_llm_filter_score += llm_score
print(f"Average: {running_llm_filter_score / (ind + 1)}")
if enable_llm:
llm_ids = [chunk.document_id for chunk in top_chunks]
llm_score = calculate_score("LLM Filter", llm_ids, targets)
running_llm_filter_score += llm_score
print(f"Average: {running_llm_filter_score / (ind + 1)}")
if show_details:
print("\nRetrieval Metrics:")
if retrieval_metrics is None:
print("No Retrieval Metrics Available")
else:
_print_retrieval_metrics(retrieval_metrics)
if show_details:
print("\nRetrieval Metrics:")
if retrieval_metrics is None:
print("No Retrieval Metrics Available")
else:
_print_retrieval_metrics(retrieval_metrics)
print("\nReranking Metrics:")
if rerank_metrics is None:
print("No Reranking Metrics Available")
else:
_print_reranking_metrics(rerank_metrics)
print("\nReranking Metrics:")
if rerank_metrics is None:
print("No Reranking Metrics Available")
else:
_print_reranking_metrics(rerank_metrics)
if __name__ == "__main__":

View File

@ -0,0 +1,411 @@
"use client";
import {
BooleanFormField,
TextArrayField,
TextFormField,
} from "@/components/admin/connectors/Field";
import { DocumentSet } from "@/lib/types";
import { Button, Divider, Text, Title } from "@tremor/react";
import {
ArrayHelpers,
ErrorMessage,
Field,
FieldArray,
Form,
Formik,
} from "formik";
import * as Yup from "yup";
import { buildFinalPrompt, createPersona, updatePersona } from "./lib";
import { useRouter } from "next/navigation";
import { usePopup } from "@/components/admin/connectors/Popup";
import { Persona } from "./interfaces";
import Link from "next/link";
import { useEffect, useState } from "react";
function SectionHeader({ children }: { children: string | JSX.Element }) {
return <div className="mb-4 font-bold text-lg">{children}</div>;
}
function Label({ children }: { children: string | JSX.Element }) {
return (
<div className="block font-medium text-base text-gray-200">{children}</div>
);
}
function SubLabel({ children }: { children: string | JSX.Element }) {
return <div className="text-sm text-gray-300 mb-2">{children}</div>;
}
// TODO: make this the default text input across all forms
/**
 * Labeled text input (or textarea, when `isTextArea` is set) wired into the
 * enclosing Formik form.
 *
 * Renders a Label, optional SubLabel help text, the Formik <Field> itself,
 * and a Formik <ErrorMessage> for validation feedback. Disabled inputs get a
 * darker background so they read as non-editable. Autocomplete is turned off
 * by default (`autoCompleteDisabled`).
 */
function PersonaTextInput({
  name,
  label,
  subtext,
  placeholder,
  onChange,
  type = "text",
  isTextArea = false,
  disabled = false,
  autoCompleteDisabled = true,
}: {
  name: string;
  label: string;
  subtext?: string | JSX.Element;
  placeholder?: string;
  onChange?: (e: React.ChangeEvent<HTMLInputElement>) => void;
  type?: string;
  isTextArea?: boolean;
  disabled?: boolean;
  autoCompleteDisabled?: boolean;
}) {
  // NOTE: spreading `onChange` conditionally below keeps Formik's default
  // change handler in place when no custom handler is supplied.
  return (
    <div className="mb-4">
      <Label>{label}</Label>
      {subtext && <SubLabel>{subtext}</SubLabel>}
      <Field
        as={isTextArea ? "textarea" : "input"}
        type={type}
        name={name}
        id={name}
        className={
          `
        border 
        text-gray-200 
        border-gray-600 
        rounded 
        w-full 
        py-2 
        px-3 
        mt-1
        ${isTextArea ? " h-28" : ""}
        ` + (disabled ? " bg-gray-900" : " bg-gray-800")
        }
        disabled={disabled}
        placeholder={placeholder}
        autoComplete={autoCompleteDisabled ? "off" : undefined}
        {...(onChange ? { onChange } : {})}
      />
      <ErrorMessage
        name={name}
        component="div"
        className="text-red-500 text-sm mt-1"
      />
    </div>
  );
}
/**
 * Labeled checkbox wired into the enclosing Formik form, with optional help
 * text and a Formik <ErrorMessage> for validation feedback.
 */
function PersonaBooleanInput({
  name,
  label,
  subtext,
}: {
  name: string;
  label: string;
  subtext?: string | JSX.Element;
}) {
  return (
    <div className="mb-4">
      <Label>{label}</Label>
      {subtext && <SubLabel>{subtext}</SubLabel>}
      <Field
        type="checkbox"
        name={name}
        id={name}
        className={`
            ml-2
            border 
            text-gray-200 
            border-gray-600 
            rounded 
            py-2 
            px-3 
            mt-1`}
      />
      <ErrorMessage
        name={name}
        component="div"
        className="text-red-500 text-sm mt-1"
      />
    </div>
  );
}
export function PersonaEditor({
existingPersona,
documentSets,
}: {
existingPersona?: Persona | null;
documentSets: DocumentSet[];
}) {
const router = useRouter();
const { popup, setPopup } = usePopup();
const [finalPrompt, setFinalPrompt] = useState<string | null>("");
const triggerFinalPromptUpdate = async (
systemPrompt: string,
taskPrompt: string
) => {
const response = await buildFinalPrompt(systemPrompt, taskPrompt);
if (response.ok) {
setFinalPrompt((await response.json()).final_prompt_template);
}
};
const isUpdate = existingPersona !== undefined && existingPersona !== null;
useEffect(() => {
if (isUpdate) {
triggerFinalPromptUpdate(
existingPersona.system_prompt,
existingPersona.task_prompt
);
}
}, []);
return (
<div className="dark">
{popup}
<Formik
initialValues={{
name: existingPersona?.name ?? "",
description: existingPersona?.description ?? "",
system_prompt: existingPersona?.system_prompt ?? "",
task_prompt: existingPersona?.task_prompt ?? "",
document_set_ids:
existingPersona?.document_sets?.map(
(documentSet) => documentSet.id
) ?? ([] as number[]),
num_chunks: existingPersona?.num_chunks ?? null,
apply_llm_relevance_filter:
existingPersona?.apply_llm_relevance_filter ?? false,
}}
validationSchema={Yup.object().shape({
name: Yup.string().required("Must give the Persona a name!"),
description: Yup.string().required(
"Must give the Persona a description!"
),
system_prompt: Yup.string().required(
"Must give the Persona a system prompt!"
),
task_prompt: Yup.string().required(
"Must give the Persona a task prompt!"
),
document_set_ids: Yup.array().of(Yup.number()),
num_chunks: Yup.number().max(20).nullable(),
apply_llm_relevance_filter: Yup.boolean().required(),
})}
onSubmit={async (values, formikHelpers) => {
formikHelpers.setSubmitting(true);
let response;
if (isUpdate) {
response = await updatePersona({
id: existingPersona.id,
...values,
num_chunks: values.num_chunks || null,
});
} else {
response = await createPersona({
...values,
num_chunks: values.num_chunks || null,
});
}
if (response.ok) {
router.push("/admin/personas");
return;
}
setPopup({
type: "error",
message: `Failed to create Persona - ${await response.text()}`,
});
formikHelpers.setSubmitting(false);
}}
>
{({ isSubmitting, values, setFieldValue }) => (
<Form>
<div className="pb-6">
<SectionHeader>Who am I?</SectionHeader>
<PersonaTextInput
name="name"
label="Name"
disabled={isUpdate}
subtext="Users will be able to select this Persona based on this name."
/>
<PersonaTextInput
name="description"
label="Description"
subtext="Provide a short descriptions which gives users a hint as to what they should use this Persona for."
/>
<Divider />
<SectionHeader>Customize my response style</SectionHeader>
<PersonaTextInput
name="system_prompt"
label="System Prompt"
isTextArea={true}
subtext={
'Give general info about what the Persona is about. For example, "You are an assistant for On-Call engineers. Your goal is to read the provided context documents and give recommendations as to how to resolve the issue."'
}
onChange={(e) => {
setFieldValue("system_prompt", e.target.value);
triggerFinalPromptUpdate(e.target.value, values.task_prompt);
}}
/>
<PersonaTextInput
name="task_prompt"
label="Task Prompt"
isTextArea={true}
subtext={
'Give specific instructions as to what to do with the user query. For example, "Find any relevant sections from the provided documents that can help the user resolve their issue and explain how they are relevant."'
}
onChange={(e) => {
setFieldValue("task_prompt", e.target.value);
triggerFinalPromptUpdate(
values.system_prompt,
e.target.value
);
}}
/>
<Label>Final Prompt</Label>
{finalPrompt ? (
<pre className="text-sm mt-2 whitespace-pre-wrap">
{finalPrompt.replaceAll("\\n", "\n")}
</pre>
) : (
"-"
)}
<Divider />
<SectionHeader>What data should I have access to?</SectionHeader>
<FieldArray
name="document_set_ids"
render={(arrayHelpers: ArrayHelpers) => (
<div>
<div>
<SubLabel>
<>
Select which{" "}
<Link
href="/admin/documents/sets"
className="text-blue-500"
target="_blank"
>
Document Sets
</Link>{" "}
that this Persona should search through. If none are
specified, the Persona will search through all
available documents in order to try and response to
queries.
</>
</SubLabel>
</div>
<div className="mb-3 mt-2 flex gap-2 flex-wrap text-sm">
{documentSets.map((documentSet) => {
const ind = values.document_set_ids.indexOf(
documentSet.id
);
let isSelected = ind !== -1;
return (
<div
key={documentSet.id}
className={
`
px-3
py-1
rounded-lg
border
border-gray-700
w-fit
flex
cursor-pointer ` +
(isSelected
? " bg-gray-600"
: " bg-gray-900 hover:bg-gray-700")
}
onClick={() => {
if (isSelected) {
arrayHelpers.remove(ind);
} else {
arrayHelpers.push(documentSet.id);
}
}}
>
<div className="my-auto">{documentSet.name}</div>
</div>
);
})}
</div>
</div>
)}
/>
<Divider />
<SectionHeader>[Advanced] Retrieval Customization</SectionHeader>
<PersonaTextInput
name="num_chunks"
label="Number of Chunks"
subtext={
<div>
How many chunks should we feed into the LLM when generating
the final response? Each chunk is ~400 words long. If you
are using gpt-3.5-turbo or other similar models, setting
this to a value greater than 5 will result in errors at
query time due to the model&apos;s input length limit.
<br />
<br />
If unspecified, will use 5 chunks.
</div>
}
onChange={(e) => {
const value = e.target.value;
// Allow only integer values
if (value === "" || /^[0-9]+$/.test(value)) {
setFieldValue("num_chunks", value);
}
}}
/>
<PersonaBooleanInput
name="apply_llm_relevance_filter"
label="Apply LLM Relevance Filter"
subtext={
"If enabled, the LLM will filter out chunks that are not relevant to the user query."
}
/>
<Divider />
<div className="flex">
<Button
className="mx-auto"
variant="secondary"
size="md"
type="submit"
disabled={isSubmitting}
>
{isUpdate ? "Update!" : "Create!"}
</Button>
</div>
</div>
</Form>
)}
</Formik>
</div>
);
}

View File

@ -0,0 +1,52 @@
"use client";
import {
Table,
TableHead,
TableRow,
TableHeaderCell,
TableBody,
TableCell,
} from "@tremor/react";
import { Persona } from "./interfaces";
import Link from "next/link";
import { EditButton } from "@/components/EditButton";
import { useRouter } from "next/navigation";
export function PersonasTable({ personas }: { personas: Persona[] }) {
const router = useRouter();
const sortedPersonas = [...personas];
sortedPersonas.sort((a, b) => a.name.localeCompare(b.name));
return (
<div className="dark">
<Table className="overflow-visible">
<TableHead>
<TableRow>
<TableHeaderCell>Name</TableHeaderCell>
<TableHeaderCell>Description</TableHeaderCell>
<TableHeaderCell></TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
{sortedPersonas.map((persona) => {
return (
<TableRow key={persona.id}>
<TableCell className="whitespace-normal break-all">
<p className="text font-medium">{persona.name}</p>
</TableCell>
<TableCell>{persona.description}</TableCell>
<TableCell>
<EditButton
onClick={() => router.push(`/admin/personas/${persona.id}`)}
/>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
);
}

View File

@ -0,0 +1,29 @@
"use client";
import { Button } from "@tremor/react";
import { FiTrash } from "react-icons/fi";
import { deletePersona } from "../lib";
import { useRouter } from "next/navigation";
// Red "Delete" button that removes the given Persona and, on success,
// navigates back to the Personas admin list. Failures surface via alert().
export function DeletePersonaButton({ personaId }: { personaId: number }) {
  const router = useRouter();

  const handleDelete = async () => {
    const response = await deletePersona(personaId);
    if (!response.ok) {
      alert(`Failed to delete persona - ${await response.text()}`);
      return;
    }
    router.push("/admin/personas");
  };

  return (
    <Button
      variant="secondary"
      size="xs"
      color="red"
      onClick={handleDelete}
      icon={FiTrash}
    >
      Delete
    </Button>
  );
}

View File

@ -0,0 +1,62 @@
import { ErrorCallout } from "@/components/ErrorCallout";
import { fetchSS } from "@/lib/utilsSS";
import { FaRobot } from "react-icons/fa";
import { Persona } from "../interfaces";
import { PersonaEditor } from "../PersonaEditor";
import { DocumentSet } from "@/lib/types";
import { RobotIcon } from "@/components/icons/icons";
import { BackButton } from "@/components/BackButton";
import { Card, Title, Text, Divider, Button } from "@tremor/react";
import { FiTrash } from "react-icons/fi";
import { DeletePersonaButton } from "./DeletePersonaButton";
/**
 * Admin page for editing (and deleting) a single existing Persona.
 *
 * Server component: fetches the Persona identified by the `personaId` route
 * param, then the available Document Sets (needed by the editor's data-access
 * section). Either fetch failing short-circuits into an error callout.
 */
export default async function Page({
  params,
}: {
  params: { personaId: string };
}) {
  const personaResponse = await fetchSS(`/persona/${params.personaId}`);

  if (!personaResponse.ok) {
    return (
      <ErrorCallout
        errorTitle="Something went wrong :("
        errorMsg={`Failed to fetch Persona - ${await personaResponse.text()}`}
      />
    );
  }

  const documentSetsResponse = await fetchSS("/manage/document-set");

  if (!documentSetsResponse.ok) {
    return (
      <ErrorCallout
        errorTitle="Something went wrong :("
        errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
      />
    );
  }

  const documentSets = (await documentSetsResponse.json()) as DocumentSet[];

  const persona = (await personaResponse.json()) as Persona;

  return (
    <div className="dark">
      <BackButton />
      <div className="pb-2 mb-4 flex">
        <h1 className="text-3xl font-bold pl-2">Edit Persona</h1>
      </div>

      <Card>
        <PersonaEditor existingPersona={persona} documentSets={documentSets} />
      </Card>

      <div className="mt-12">
        <Title>Delete Persona</Title>
        <div className="flex mt-6">
          <DeletePersonaButton personaId={persona.id} />
        </div>
      </div>
    </div>
  );
}

View File

@ -0,0 +1,12 @@
import { DocumentSet } from "@/lib/types";
// Mirrors the Persona model returned by the backend `/persona` endpoints.
export interface Persona {
  id: number;
  name: string;
  // Short, user-facing explanation of what this Persona should be used for.
  description: string;
  system_prompt: string;
  task_prompt: string;
  // Document Sets this Persona searches; empty means all available documents.
  document_sets: DocumentSet[];
  // How many retrieved chunks to feed the LLM; backend falls back to its
  // default (the editor UI states 5) when unset.
  num_chunks?: number;
  // Whether the LLM should filter out chunks irrelevant to the query.
  apply_llm_relevance_filter?: boolean;
}

View File

@ -0,0 +1,61 @@
// Shape of the POST body expected by the admin Persona creation endpoint.
interface PersonaCreationRequest {
  name: string;
  description: string;
  system_prompt: string;
  task_prompt: string;
  document_set_ids: number[];
  num_chunks: number | null;
  apply_llm_relevance_filter: boolean | null;
}

// Creates a brand-new Persona; resolves to the raw fetch Response.
export function createPersona(personaCreationRequest: PersonaCreationRequest) {
  const requestInit = {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(personaCreationRequest),
  };
  return fetch("/api/admin/persona", requestInit);
}
// Shape of the PATCH body for updating an existing Persona. The `id` is
// only used to build the URL and is stripped from the JSON payload.
interface PersonaUpdateRequest {
  id: number;
  description: string;
  system_prompt: string;
  task_prompt: string;
  document_set_ids: number[];
  num_chunks: number | null;
  apply_llm_relevance_filter: boolean | null;
}

// Applies changes to the Persona identified by `personaUpdateRequest.id`.
export function updatePersona(personaUpdateRequest: PersonaUpdateRequest) {
  const { id, ...requestBody } = personaUpdateRequest;

  const requestInit = {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(requestBody),
  };
  return fetch(`/api/admin/persona/${id}`, requestInit);
}
// Permanently removes the Persona with the given ID via the admin API.
export function deletePersona(personaId: number) {
  const endpoint = `/api/admin/persona/${personaId}`;
  return fetch(endpoint, { method: "DELETE" });
}
/**
 * Asks the backend to render the final prompt template produced by the
 * given system + task prompts, so the user can preview it in the editor.
 *
 * Uses URLSearchParams instead of a hand-rolled encodeURIComponent/join
 * pipeline — it percent-encodes keys and values per the
 * application/x-www-form-urlencoded rules the server decodes with.
 */
export function buildFinalPrompt(systemPrompt: string, taskPrompt: string) {
  const queryString = new URLSearchParams({
    system_prompt: systemPrompt,
    task_prompt: taskPrompt,
  }).toString();
  return fetch(`/api/persona-utils/prompt-explorer?${queryString}`);
}

View File

@ -0,0 +1,37 @@
import { FaRobot } from "react-icons/fa";
import { PersonaEditor } from "../PersonaEditor";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import { DocumentSet } from "@/lib/types";
import { RobotIcon } from "@/components/icons/icons";
import { BackButton } from "@/components/BackButton";
import { Card } from "@tremor/react";
/**
 * Admin page for creating a brand-new Persona.
 *
 * Server component: fetches the available Document Sets (needed by the
 * editor's data-access section) and renders an empty PersonaEditor. A failed
 * fetch short-circuits into an error callout.
 */
export default async function Page() {
  const documentSetsResponse = await fetchSS("/manage/document-set");

  if (!documentSetsResponse.ok) {
    return (
      <ErrorCallout
        errorTitle="Something went wrong :("
        errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
      />
    );
  }

  const documentSets = (await documentSetsResponse.json()) as DocumentSet[];

  return (
    <div className="dark">
      <BackButton />
      <div className="border-solid border-gray-600 border-b pb-2 mb-4 flex">
        <RobotIcon size={32} />
        <h1 className="text-3xl font-bold pl-2">Create a New Persona</h1>
      </div>

      <Card>
        <PersonaEditor documentSets={documentSets} />
      </Card>
    </div>
  );
}

View File

@ -0,0 +1,64 @@
import { PersonasTable } from "./PersonaTable";
import { FiPlusSquare } from "react-icons/fi";
import Link from "next/link";
import { Divider, Text, Title } from "@tremor/react";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import { Persona } from "./interfaces";
import { RobotIcon } from "@/components/icons/icons";
/**
 * Admin landing page for Personas: explains what they are, links to the
 * creation flow, and lists all existing Personas in a table.
 *
 * Server component: fetches all Personas from the backend; a failed fetch
 * short-circuits into an error callout.
 */
export default async function Page() {
  const personaResponse = await fetchSS("/persona");

  if (!personaResponse.ok) {
    return (
      <ErrorCallout
        errorTitle="Something went wrong :("
        errorMsg={`Failed to fetch personas - ${await personaResponse.text()}`}
      />
    );
  }

  const personas = (await personaResponse.json()) as Persona[];

  return (
    <div>
      <div className="border-solid border-gray-600 border-b pb-2 mb-4 flex">
        <RobotIcon size={32} />
        <h1 className="text-3xl font-bold pl-2">Personas</h1>
      </div>

      <div className="text-gray-300 text-sm mb-2">
        Personas are a way to build custom search/question-answering experiences
        for different use cases.
        <p className="mt-2">They allow you to customize:</p>
        <ul className="list-disc mt-2 ml-4">
          <li>
            The prompt used by your LLM of choice to respond to the user query
          </li>
          <li>The documents that are used as context</li>
        </ul>
      </div>

      <div className="dark">
        <Divider />

        <Title>Create a Persona</Title>
        <Link
          href="/admin/personas/new"
          className="text-gray-100 flex py-2 px-4 mt-2 border border-gray-800 h-fit cursor-pointer hover:bg-gray-800 text-sm w-36"
        >
          <div className="mx-auto flex">
            <FiPlusSquare className="my-auto mr-2" />
            New Persona
          </div>
        </Link>

        <Divider />

        <Title>Existing Personas</Title>
        <PersonasTable personas={personas} />
      </div>
    </div>
  );
}

View File

@ -8,10 +8,11 @@ import {
import { redirect } from "next/navigation";
import { HealthCheckBanner } from "@/components/health/healthcheck";
import { ApiKeyModal } from "@/components/openai/ApiKeyModal";
import { buildUrl } from "@/lib/utilsSS";
import { buildUrl, fetchSS } from "@/lib/utilsSS";
import { Connector, DocumentSet, User } from "@/lib/types";
import { cookies } from "next/headers";
import { SearchType } from "@/lib/search/interfaces";
import { Persona } from "./admin/personas/interfaces";
export default async function Home() {
const tasks = [
@ -29,6 +30,7 @@ export default async function Home() {
cookie: processCookies(cookies()),
},
}),
fetchSS("/persona"),
];
// catch cases where the backend is completely unreachable here
@ -44,6 +46,7 @@ export default async function Home() {
const user = results[1] as User | null;
const connectorsResponse = results[2] as Response | null;
const documentSetsResponse = results[3] as Response | null;
const personaResponse = results[4] as Response | null;
if (!authDisabled && !user) {
return redirect("/auth/login");
@ -65,6 +68,13 @@ export default async function Home() {
);
}
let personas: Persona[] = [];
if (personaResponse?.ok) {
personas = await personaResponse.json();
} else {
console.log(`Failed to fetch personas - ${personaResponse?.status}`);
}
// needs to be done in a non-client side component due to nextjs
const storedSearchType = cookies().get("searchType")?.value as
| string
@ -87,6 +97,7 @@ export default async function Home() {
<SearchSection
connectors={connectors}
documentSets={documentSets}
personas={personas}
defaultSearchType={searchTypeDefault}
/>
</div>

View File

@ -14,11 +14,7 @@ interface DropdownProps {
onSelect: (selected: Option) => void;
}
export const Dropdown: FC<DropdownProps> = ({
options,
selected,
onSelect,
}) => {
export const Dropdown = ({ options, selected, onSelect }: DropdownProps) => {
const [isOpen, setIsOpen] = useState(false);
const dropdownRef = useRef<HTMLDivElement>(null);

View File

@ -0,0 +1,27 @@
"use client";
import { useRouter } from "next/navigation";
import { FiChevronLeft, FiEdit } from "react-icons/fi";
/**
 * Small inline "Edit" affordance (pencil icon + label) for use in tables and
 * lists; delegates the actual action entirely to the `onClick` callback.
 */
export function EditButton({ onClick }: { onClick: () => void }) {
  return (
    <div
      className={`
        my-auto 
        flex 
        mb-1 
        hover:bg-gray-800 
        w-fit 
        p-2 
        cursor-pointer 
        rounded-lg
        border-gray-800
        text-sm`}
      onClick={onClick}
    >
      <FiEdit className="mr-1 my-auto" />
      Edit
    </div>
  );
}

View File

@ -28,6 +28,7 @@ import {
GongIcon,
ZoomInIcon,
ZendeskIcon,
RobotIcon,
} from "@/components/icons/icons";
import { getAuthDisabledSS, getCurrentUserSS } from "@/lib/userSS";
import { redirect } from "next/navigation";
@ -314,13 +315,22 @@ export async function Layout({ children }: { children: React.ReactNode }) {
],
},
{
name: "Bots",
name: "Custom Assistants",
items: [
{
name: (
<div className="flex">
<RobotIcon size={18} />
<div className="ml-1">Personas</div>
</div>
),
link: "/admin/personas",
},
{
name: (
<div className="flex">
<CPUIcon size={18} />
<div className="ml-1">Slack Bot</div>
<div className="ml-1">Slack Bots</div>
</div>
),
link: "/admin/bot",

View File

@ -49,6 +49,7 @@ import hubSpotIcon from "../../../public/HubSpot.png";
import document360Icon from "../../../public/Document360.png";
import googleSitesIcon from "../../../public/GoogleSites.png";
import zendeskIcon from "../../../public/Zendesk.svg";
import { FaRobot } from "react-icons/fa";
interface IconProps {
size?: number;
@ -281,6 +282,13 @@ export const CPUIcon = ({
return <FiCpu size={size} className={className} />;
};
export const RobotIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return <FaRobot size={size} className={className} />;
};
//
// COMPANY LOGOS
//

View File

@ -173,7 +173,7 @@ export const DocumentDisplay = ({
ml-auto
mr-2`}
>
{document.score.toFixed(2)}
{Math.abs(document.score).toFixed(2)}
</div>
</div>
)}

View File

@ -0,0 +1,120 @@
import { Persona } from "@/app/admin/personas/interfaces";
import { CustomDropdown } from "../Dropdown";
import { FiCheck, FiChevronDown } from "react-icons/fi";
import { FaRobot } from "react-icons/fa";
/**
 * A single selectable row inside the Persona dropdown.
 *
 * `isSelected` highlights the row and shows a check mark; `isFinal`
 * suppresses the bottom border on the last row of the list.
 */
function PersonaItem({
  id,
  name,
  onSelect,
  isSelected,
  isFinal,
}: {
  id: number;
  name: string;
  onSelect: (personaId: number) => void;
  isSelected: boolean;
  isFinal: boolean;
}) {
  // NOTE(review): the `key` prop on this root div looks redundant — callers
  // already supply `key` at the list level; confirm before removing.
  return (
    <div
      key={id}
      className={`
    flex
    px-3 
    text-sm 
    text-gray-200 
    py-2.5 
    select-none 
    cursor-pointer 
    ${isFinal ? "" : "border-b border-gray-800"} 
    ${
      isSelected
        ? "bg-dark-tremor-background-muted"
        : "hover:bg-dark-tremor-background-muted "
    } 
  `}
      onClick={() => {
        onSelect(id);
      }}
    >
      {name}
      {isSelected && (
        <div className="ml-auto mr-1">
          <FiCheck />
        </div>
      )}
    </div>
  );
}
/**
 * Dropdown for choosing which Persona the search should run with.
 *
 * Always prepends a synthetic "Default" entry (id -1) that maps to a null
 * Persona, i.e. the standard non-Persona search behavior. Selecting a real
 * entry passes the full Persona object to `onPersonaChange`.
 */
export function PersonaSelector({
  personas,
  selectedPersonaId,
  onPersonaChange,
}: {
  personas: Persona[];
  selectedPersonaId: number | null;
  onPersonaChange: (persona: Persona | null) => void;
}) {
  const currentlySelectedPersona = personas.find(
    (persona) => persona.id === selectedPersonaId
  );

  return (
    <CustomDropdown
      dropdown={
        <div
          className={`
            border 
            border-gray-800 
            rounded-lg 
            flex 
            flex-col 
            w-64 
            max-h-96 
            overflow-y-auto 
            flex
            overscroll-contain`}
        >
          <PersonaItem
            key={-1}
            id={-1}
            name="Default"
            onSelect={() => {
              onPersonaChange(null);
            }}
            isSelected={selectedPersonaId === null}
            isFinal={false}
          />
          {personas.map((persona, ind) => {
            const isSelected = persona.id === selectedPersonaId;
            return (
              <PersonaItem
                key={persona.id}
                id={persona.id}
                name={persona.name}
                onSelect={(clickedPersonaId) => {
                  // look up the full object so the parent gets the Persona,
                  // not just its id
                  const clickedPersona = personas.find(
                    (persona) => persona.id === clickedPersonaId
                  );
                  if (clickedPersona) {
                    onPersonaChange(clickedPersona);
                  }
                }}
                isSelected={isSelected}
                isFinal={ind === personas.length - 1}
              />
            );
          })}
        </div>
      }
    >
      <div className="select-none text-sm flex text-gray-300 px-1 py-1.5 cursor-pointer w-64">
        <FaRobot className="my-auto mr-2" />
        {currentlySelectedPersona?.name || "Default"}{" "}
        <FiChevronDown className="my-auto ml-2" />
      </div>
    </CustomDropdown>
  );
}

View File

@ -7,11 +7,7 @@ interface SearchBarProps {
onSearch: () => void;
}
export const SearchBar: React.FC<SearchBarProps> = ({
query,
setQuery,
onSearch,
}) => {
export const SearchBar = ({ query, setQuery, onSearch }: SearchBarProps) => {
const handleChange = (event: ChangeEvent<HTMLTextAreaElement>) => {
const target = event.target;
setQuery(target.value);
@ -30,7 +26,7 @@ export const SearchBar: React.FC<SearchBarProps> = ({
};
return (
<div className="flex justify-center py-3">
<div className="flex justify-center">
<div className="flex items-center w-full border-2 border-gray-600 rounded px-4 py-2 focus-within:border-blue-500">
<MagnifyingGlass className="text-gray-400" />
<textarea

View File

@ -20,6 +20,7 @@ import {
} from "@/lib/search/aiThoughtUtils";
import { ThreeDots } from "react-loader-spinner";
import { usePopup } from "../admin/connectors/Popup";
import { AlertIcon } from "../icons/icons";
const removeDuplicateDocs = (documents: DanswerDocument[]) => {
const seen = new Set<string>();
@ -49,14 +50,16 @@ interface SearchResultsDisplayProps {
validQuestionResponse: ValidQuestionResponse;
isFetching: boolean;
defaultOverrides: SearchDefaultOverrides;
personaName?: string | null;
}
export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
export const SearchResultsDisplay = ({
searchResponse,
validQuestionResponse,
isFetching,
defaultOverrides,
}) => {
personaName = null,
}: SearchResultsDisplayProps) => {
const { popup, setPopup } = usePopup();
const [isAIThoughtsOpen, setIsAIThoughtsOpen] = React.useState<boolean>(
getAIThoughtsIsOpenSavedValue()
@ -70,6 +73,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
return null;
}
const isPersona = personaName !== null;
const { answer, quotes, documents, error, queryEventId } = searchResponse;
if (isFetching && !answer && !documents) {
@ -92,6 +96,17 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
}
if (answer === null && documents === null && quotes === null) {
if (error) {
return (
<div className="text-red-500 text-sm">
<div className="flex">
<AlertIcon size={16} className="text-red-500 my-auto mr-1" />
<p className="italic">{error}</p>
</div>
</div>
);
}
return <div className="text-gray-300">No matching documents found.</div>;
}
@ -132,34 +147,38 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
<h2 className="text font-bold my-auto mb-1 w-full">AI Answer</h2>
</div>
<div className="mb-2 w-full">
<ResponseSection
status={questionValidityCheckStatus}
header={
validQuestionResponse.answerable === null ? (
<div className="flex ml-2">Evaluating question...</div>
) : (
<div className="flex ml-2">AI thoughts</div>
)
}
body={<div>{validQuestionResponse.reasoning}</div>}
desiredOpenStatus={isAIThoughtsOpen}
setDesiredOpenStatus={handleAIThoughtToggle}
/>
</div>
{!isPersona && (
<div className="mb-2 w-full">
<ResponseSection
status={questionValidityCheckStatus}
header={
validQuestionResponse.answerable === null ? (
<div className="flex ml-2">Evaluating question...</div>
) : (
<div className="flex ml-2">AI thoughts</div>
)
}
body={<div>{validQuestionResponse.reasoning}</div>}
desiredOpenStatus={isAIThoughtsOpen}
setDesiredOpenStatus={handleAIThoughtToggle}
/>
</div>
)}
<div className="mb-2 pt-1 border-t border-gray-700 w-full">
<AnswerSection
answer={answer}
quotes={quotes}
error={error}
isAnswerable={validQuestionResponse.answerable}
isAnswerable={
validQuestionResponse.answerable || (isPersona ? true : null)
}
isFetching={isFetching}
aiThoughtsIsOpen={isAIThoughtsOpen}
/>
</div>
{quotes !== null && answer && (
{quotes !== null && answer && !isPersona && (
<div className="pt-1 border-t border-gray-700 w-full">
<QuotesSection
quotes={dedupedQuotes}

View File

@ -20,8 +20,11 @@ import { SearchHelper } from "./SearchHelper";
import { CancellationToken, cancellable } from "@/lib/search/cancellable";
import { NEXT_PUBLIC_DISABLE_STREAMING } from "@/lib/constants";
import { searchRequest } from "@/lib/search/qa";
import { useFilters, useObjectState, useTimeRange } from "@/lib/hooks";
import { useFilters, useObjectState } from "@/lib/hooks";
import { questionValidationStreamed } from "@/lib/search/streamingQuestionValidation";
import { createChatSession } from "@/lib/search/chatSessions";
import { Persona } from "@/app/admin/personas/interfaces";
import { PersonaSelector } from "./PersonaSelector";
const SEARCH_DEFAULT_OVERRIDES_START: SearchDefaultOverrides = {
forceDisplayQA: false,
@ -36,14 +39,16 @@ const VALID_QUESTION_RESPONSE_DEFAULT: ValidQuestionResponse = {
interface SearchSectionProps {
connectors: Connector<any>[];
documentSets: DocumentSet[];
personas: Persona[];
defaultSearchType: SearchType;
}
export const SearchSection: React.FC<SearchSectionProps> = ({
export const SearchSection = ({
connectors,
documentSets,
personas,
defaultSearchType,
}) => {
}: SearchSectionProps) => {
// Search Bar
const [query, setQuery] = useState<string>("");
@ -63,6 +68,8 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
const [selectedSearchType, setSelectedSearchType] =
useState<SearchType>(defaultSearchType);
const [selectedPersona, setSelectedPersona] = useState<number | null>(null);
// Overrides for default behavior that only last a single query
const [defaultOverrides, setDefaultOverrides] =
useState<SearchDefaultOverrides>(SEARCH_DEFAULT_OVERRIDES_START);
@ -134,11 +141,23 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
setSearchResponse(initialSearchResponse);
setValidQuestionResponse(VALID_QUESTION_RESPONSE_DEFAULT);
const chatSessionResponse = await createChatSession(selectedPersona);
if (!chatSessionResponse.ok) {
updateError(
`Unable to create chat session - ${await chatSessionResponse.text()}`
);
setIsFetching(false);
return;
}
const chatSessionId = (await chatSessionResponse.json())
.chat_session_id as number;
const searchFn = NEXT_PUBLIC_DISABLE_STREAMING
? searchRequest
: searchRequestStreamed;
const searchFnArgs = {
query,
chatSessionId,
sources: filterManager.selectedSources,
documentSets: filterManager.selectedDocumentSets,
timeRange: filterManager.timeRange,
@ -180,6 +199,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
const questionValidationArgs = {
query,
chatSessionId,
update: setValidQuestionResponse,
};
@ -226,6 +246,20 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
</div>
</div>
<div className="w-[800px] mx-auto">
{personas.length > 0 ? (
<div className="flex mb-2 w-64">
<PersonaSelector
personas={personas}
selectedPersonaId={selectedPersona}
onPersonaChange={(persona) =>
setSelectedPersona(persona ? persona.id : null)
}
/>
</div>
) : (
<div className="pt-3" />
)}
<SearchBar
query={query}
setQuery={setQuery}
@ -241,6 +275,11 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
validQuestionResponse={validQuestionResponse}
isFetching={isFetching}
defaultOverrides={defaultOverrides}
personaName={
selectedPersona
? personas.find((p) => p.id === selectedPersona)?.name
: null
}
/>
</div>
</div>

View File

@ -0,0 +1,12 @@
// Starts a new chat session on the backend, optionally tied to a Persona.
// Resolves to the raw fetch Response; callers read `chat_session_id` from
// its JSON body.
export async function createChatSession(personaId?: number | null) {
  const payload = JSON.stringify({
    persona_id: personaId,
  });
  return fetch("/api/chat/create-chat-session", {
    method: "POST",
    body: payload,
    headers: {
      "Content-Type": "application/json",
    },
  });
}

View File

@ -92,6 +92,7 @@ export interface Filters {
export interface SearchRequestArgs {
query: string;
chatSessionId: number;
sources: Source[];
documentSets: string[];
timeRange: DateRangePickerValue | null;

View File

@ -14,6 +14,7 @@ import { buildFilters } from "./utils";
export const searchRequestStreamed = async ({
query,
chatSessionId,
sources,
documentSets,
timeRange,
@ -35,6 +36,7 @@ export const searchRequestStreamed = async ({
const response = await fetch("/api/stream-direct-qa", {
method: "POST",
body: JSON.stringify({
chat_session_id: chatSessionId,
query,
collection: "danswer_index",
filters,

View File

@ -3,11 +3,13 @@ import { processRawChunkString } from "./streamingUtils";
export interface QuestionValidationArgs {
query: string;
chatSessionId: number;
update: (update: Partial<ValidQuestionResponse>) => void;
}
export const questionValidationStreamed = async <T>({
query,
chatSessionId,
update,
}: QuestionValidationArgs) => {
const emptyFilters = {
@ -20,6 +22,7 @@ export const questionValidationStreamed = async <T>({
method: "POST",
body: JSON.stringify({
query,
chat_session_id: chatSessionId,
collection: "danswer_index",
filters: emptyFilters,
enable_auto_detect_filters: false,