COMPLETE USER EXPERIENCE OVERHAUL (#1822)

pablodanswer
2024-07-17 19:44:21 -07:00
committed by GitHub
parent 2b07c102f9
commit 87fadb07ea
122 changed files with 6814 additions and 2204 deletions

View File

@ -0,0 +1,32 @@
"""add search doc relevance details
Revision ID: 05c07bf07c00
Revises: b896bbd0d5a7
Create Date: 2024-07-10 17:48:15.886653
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "05c07bf07c00"
down_revision = "b896bbd0d5a7"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"search_doc",
sa.Column("is_relevant", sa.Boolean(), nullable=True),
)
op.add_column(
"search_doc",
sa.Column("relevance_explanation", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("search_doc", "relevance_explanation")
op.drop_column("search_doc", "is_relevant")

View File

@ -42,11 +42,19 @@ class QADocsResponse(RetrievalDocs):
return initial_dict
# Second chunk of info for streaming QA
class LLMRelevanceFilterResponse(BaseModel):
relevant_chunk_indices: list[int]
class RelevanceChunk(BaseModel):
relevant: bool | None = None
content: str | None = None
class LLMRelevanceSummaryResponse(BaseModel):
relevance_summaries: dict[str, RelevanceChunk]
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
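
As a rough illustration of how the new streaming models nest: the "{document_id}-{chunk_id}" key format matches what the agentic evaluation code later in this commit produces, while the values below are invented.

from danswer.chat.models import LLMRelevanceSummaryResponse, RelevanceChunk

summary = LLMRelevanceSummaryResponse(
    relevance_summaries={
        # invented document id / chunk index, real key format
        "wiki/Python-0": RelevanceChunk(
            relevant=True,
            content="Directly defines the term the user asked about.",
        )
    }
)
print(summary.json())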

View File

@ -75,6 +75,17 @@ LANGUAGE_CHAT_NAMING_HINT = (
or "The name of the conversation must be in the same language as the user query."
)
# Agentic search takes significantly more tokens and therefore has much higher cost.
# This configuration allows users to get a search-only experience with instant results
# and no involvement from the LLM.
# Additionally, some LLM providers have strict rate limits which may prohibit
# sending many API requests at once (as is done in agentic search).
DISABLE_AGENTIC_SEARCH = (
os.environ.get("DISABLE_AGENTIC_SEARCH") or "false"
).lower() == "true"
# Stops streaming answers back to the UI if this pattern is seen:
STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None
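
A small sketch of how the DISABLE_AGENTIC_SEARCH flag above parses (standard library only; values invented): anything other than a case-insensitive "true", including an unset variable, leaves agentic search enabled.

import os

os.environ["DISABLE_AGENTIC_SEARCH"] = "TRUE"
assert (os.environ.get("DISABLE_AGENTIC_SEARCH") or "false").lower() == "true"

os.environ.pop("DISABLE_AGENTIC_SEARCH")
assert (os.environ.get("DISABLE_AGENTIC_SEARCH") or "false").lower() != "true"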

View File

@ -3,15 +3,20 @@ from datetime import datetime
from datetime import timedelta
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from danswer.auth.schemas import UserRole
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
@ -34,6 +39,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -81,17 +87,46 @@ def get_chat_sessions_by_slack_thread_id(
return db_session.scalars(stmt).all()
def get_first_messages_for_chat_sessions(
chat_session_ids: list[int], db_session: Session
) -> dict[int, str]:
subquery = (
select(ChatMessage.chat_session_id, func.min(ChatMessage.id).label("min_id"))
.where(
and_(
ChatMessage.chat_session_id.in_(chat_session_ids),
ChatMessage.message_type == MessageType.USER, # Select USER messages
)
)
.group_by(ChatMessage.chat_session_id)
.subquery()
)
query = select(ChatMessage.chat_session_id, ChatMessage.message).join(
subquery,
(ChatMessage.chat_session_id == subquery.c.chat_session_id)
& (ChatMessage.id == subquery.c.min_id),
)
first_messages = db_session.execute(query).all()
return {row.chat_session_id: row.message for row in first_messages}
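
Roughly, the helper above compiles to the following SQL (a hedged reconstruction from the ORM models; the subquery picks the lowest USER-message id per session):

# Hedged reconstruction of the emitted SQL:
#   SELECT chat_message.chat_session_id, chat_message.message
#   FROM chat_message
#   JOIN (SELECT chat_session_id, MIN(id) AS min_id
#         FROM chat_message
#         WHERE chat_session_id IN (:ids) AND message_type = 'USER'
#         GROUP BY chat_session_id) AS firsts
#     ON chat_message.chat_session_id = firsts.chat_session_id
#    AND chat_message.id = firsts.min_id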
def get_chat_sessions_by_user(
user_id: UUID | None,
deleted: bool | None,
db_session: Session,
include_one_shot: bool = False,
only_one_shot: bool = False,
) -> list[ChatSession]:
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
if not include_one_shot:
if only_one_shot:
stmt = stmt.where(ChatSession.one_shot.is_(True))
else:
stmt = stmt.where(ChatSession.one_shot.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@ -275,6 +310,20 @@ def get_chat_messages_by_sessions(
return db_session.execute(stmt).scalars().all()
def get_search_docs_for_chat_message(
chat_message_id: int, db_session: Session
) -> list[SearchDoc]:
stmt = (
select(SearchDoc)
.join(
ChatMessage__SearchDoc, ChatMessage__SearchDoc.search_doc_id == SearchDoc.id
)
.where(ChatMessage__SearchDoc.chat_message_id == chat_message_id)
)
return list(db_session.scalars(stmt).all())
def get_chat_messages_by_session(
chat_session_id: int,
user_id: UUID | None,
@ -295,8 +344,6 @@ def get_chat_messages_by_session(
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
if prefetch_tool_calls:
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@ -484,6 +531,27 @@ def get_doc_query_identifiers_from_model(
return doc_query_identifiers
def update_search_docs_table_with_relevance(
db_session: Session,
reference_db_search_docs: list[SearchDoc],
relevance_summary: LLMRelevanceSummaryResponse,
) -> None:
for search_doc in reference_db_search_docs:
relevance_data = relevance_summary.relevance_summaries.get(
f"{search_doc.document_id}-{search_doc.chunk_ind}"
)
if relevance_data is not None:
db_session.execute(
update(SearchDoc)
.where(SearchDoc.id == search_doc.id)
.values(
is_relevant=relevance_data.relevant,
relevance_explanation=relevance_data.content,
)
)
db_session.commit()
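
The loop above issues one UPDATE per matched document and commits once at the end. A hedged alternative doing the same writes in a single round trip (a sketch using SQLAlchemy's bulk_update_mappings, not the committed code):

mappings = [
    {
        "id": doc.id,
        "is_relevant": rel.relevant,
        "relevance_explanation": rel.content,
    }
    for doc in reference_db_search_docs
    if (rel := relevance_summary.relevance_summaries.get(
        f"{doc.document_id}-{doc.chunk_ind}"
    ))
    is not None
]
db_session.bulk_update_mappings(SearchDoc, mappings)
db_session.commit()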
def create_db_search_doc(
server_search_doc: ServerSearchDoc,
db_session: Session,
@ -498,6 +566,8 @@ def create_db_search_doc(
boost=server_search_doc.boost,
hidden=server_search_doc.hidden,
doc_metadata=server_search_doc.metadata,
is_relevant=server_search_doc.is_relevant,
relevance_explanation=server_search_doc.relevance_explanation,
# For docs further down that aren't reranked, we can't use the retrieval score
score=server_search_doc.score or 0.0,
match_highlights=server_search_doc.match_highlights,
@ -509,7 +579,6 @@ def create_db_search_doc(
db_session.add(db_search_doc)
db_session.commit()
return db_search_doc
@ -538,6 +607,8 @@ def translate_db_search_doc_to_server_search_doc(
match_highlights=(
db_search_doc.match_highlights if not remove_doc_content else []
),
relevance_explanation=db_search_doc.relevance_explanation,
is_relevant=db_search_doc.is_relevant,
updated_at=db_search_doc.updated_at if not remove_doc_content else None,
primary_owners=db_search_doc.primary_owners if not remove_doc_content else [],
secondary_owners=(
@ -561,9 +632,11 @@ def get_retrieval_docs_from_chat_message(
def translate_db_message_to_chat_message_detail(
chat_message: ChatMessage,
remove_doc_content: bool = False,
) -> ChatMessageDetail:
chat_msg_detail = ChatMessageDetail(
chat_session_id=chat_message.chat_session_id,
message_id=chat_message.id,
parent_message=chat_message.parent_message,
latest_child_message=chat_message.latest_child_message,

View File

@ -671,6 +671,9 @@ class SearchDoc(Base):
)
is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)
is_relevant: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
relevance_explanation: Mapped[str | None] = mapped_column(String, nullable=True)
chat_messages = relationship(
"ChatMessage",
secondary="chat_message__search_doc",

View File

@ -89,6 +89,9 @@ def _get_answer_stream_processor(
AnswerStream = Iterator[AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse]
logger = setup_logger()
class Answer:
def __init__(
self,
@ -112,6 +115,7 @@ class Answer:
skip_explicit_tool_calling: bool = False,
# Returns the full document sections text from the search tool
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False,
) -> None:
if single_message_history and message_history:
raise ValueError(
@ -140,11 +144,12 @@ class Answer:
self._final_prompt: list[BaseMessage] | None = None
self._streamed_output: list[str] | None = None
self._processed_stream: (
list[AnswerQuestionPossibleReturn | ToolResponse | ToolCallKickoff] | None
) = None
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
@ -403,8 +408,9 @@ class Answer:
)
)
)
final = tool_runner.tool_final_result()
yield final
prompt = prompt_builder.build()
yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
@ -467,22 +473,23 @@ class Answer:
# assumes all tool responses will come first, then the final answer
break
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
def _stream() -> Iterator[str]:
if message:
yield cast(str, message)
yield from cast(Iterator[str], stream)
yield from process_answer_stream_fn(_stream())
processed_stream = []
for processed_packet in _process_stream(output_generator):

View File

@ -265,6 +265,7 @@ def prune_sections(
max_tokens=document_pruning_config.max_tokens,
tool_token_count=document_pruning_config.tool_num_tokens,
)
return _apply_pruning(
sections=sections,
section_relevance_list=section_relevance_list,

View File

@ -10,6 +10,7 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import LLMRelevanceSummaryResponse
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@ -21,6 +22,7 @@ from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.chat import update_search_docs_table_with_relevance
from danswer.db.engine import get_session_context_manager
from danswer.db.models import User
from danswer.db.persona import get_prompt_by_id
@ -48,6 +50,7 @@ from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.utils import get_json_line
from danswer.tools.force import ForceUseTool
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_EVALUATION_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
@ -57,6 +60,7 @@ from danswer.tools.tool_runner import ToolCallKickoff
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
AnswerObjectIterator = Iterator[
@ -70,6 +74,7 @@ AnswerObjectIterator = Iterator[
| ChatMessageDetail
| CitationInfo
| ToolCallKickoff
| LLMRelevanceSummaryResponse
]
@ -88,8 +93,9 @@ def stream_answer_objects(
bypass_acl: bool = False,
use_citations: bool = False,
danswerbot_flow: bool = False,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> AnswerObjectIterator:
"""Streams in order:
@ -127,6 +133,7 @@ def stream_answer_objects(
user_query=query_msg.message,
history_str=history_str,
)
# Given back ahead of the documents for latency reasons
# In chat flow it's given back along with the documents
yield QueryRephrase(rephrased_query=rephrased_query)
@ -182,6 +189,7 @@ def stream_answer_objects(
chunks_below=query_req.chunks_below,
full_doc=query_req.full_doc,
bypass_acl=bypass_acl,
evaluate_response=query_req.evaluate_response,
)
answer_config = AnswerStyleConfig(
@ -189,6 +197,7 @@ def stream_answer_objects(
quotes_config=QuotesConfig() if not use_citations else None,
document_pruning_config=document_pruning_config,
)
answer = Answer(
question=query_msg.message,
answer_style_config=answer_config,
@ -204,12 +213,16 @@ def stream_answer_objects(
# tested quotes with tool calling too much yet
skip_explicit_tool_calling=True,
return_contexts=query_req.return_contexts,
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
)
# won't be any ImageGenerationDisplay responses since that tool is never passed in
dropped_inds: list[int] = []
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
# for one-shot flow, don't currently do anything with these
if isinstance(packet, ToolResponse):
# (likely fine that it comes after the initial creation of the search docs)
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, packet.response)
@ -242,6 +255,7 @@ def stream_answer_objects(
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
)
yield initial_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
chunk_indices = packet.response
@ -253,8 +267,21 @@ def stream_answer_objects(
)
yield LLMRelevanceFilterResponse(relevant_chunk_indices=packet.response)
elif packet.id == SEARCH_DOC_CONTENT_ID:
yield packet.response
elif packet.id == SEARCH_EVALUATION_ID:
evaluation_response = LLMRelevanceSummaryResponse(
relevance_summaries=packet.response
)
if reference_db_search_docs is not None:
update_search_docs_table_with_relevance(
db_session=db_session,
reference_db_search_docs=reference_db_search_docs,
relevance_summary=evaluation_response,
)
yield evaluation_response
else:
yield packet
@ -275,7 +302,6 @@ def stream_answer_objects(
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
@ -309,8 +335,9 @@ def get_search_answer(
bypass_acl: bool = False,
use_citations: bool = False,
danswerbot_flow: bool = False,
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> OneShotQAResponse:
"""Collects the streamed one shot answer responses into a single object"""

View File

@ -27,12 +27,19 @@ class DirectQARequest(ChunkContext):
messages: list[ThreadMessage]
prompt_id: int | None
persona_id: int
agentic: bool | None = None
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
# This is to forcibly skip (or run) the step; if None, the system defaults are used
skip_rerank: bool | None = None
skip_llm_chunk_filter: bool | None = None
chain_of_thought: bool = False
return_contexts: bool = False
# This is to toggle agentic evaluation:
# 1. Evaluates whether each response is relevant or not
# 2. Provides a summary of the document's relevance in the results
evaluate_response: bool = False
# If True, skips generating an AI response to the search query
skip_gen_ai_answer_generation: bool = False
@root_validator
def check_chain_of_thought_and_prompt_id(
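
A hedged example of building a search-only, agentically evaluated request with the new flags (import path and ThreadMessage defaults are assumptions; other field values are invented):

from danswer.one_shot_answer.models import DirectQARequest, ThreadMessage

req = DirectQARequest(
    messages=[ThreadMessage(message="What is our refund policy?")],
    prompt_id=None,
    persona_id=0,
    evaluate_response=True,              # run the agentic relevance pass
    skip_gen_ai_answer_generation=True,  # search-only: no generated answer
)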

View File

@ -24,6 +24,25 @@ Query:
""".strip()
AGENTIC_SEARCH_EVALUATION_PROMPT = """
1. Chain of Thought Analysis:
Provide a chain of thought analysis considering:
- The main purpose and content of the document
- What the user is searching for
- How the document's topic relates to the query
- Potential uses of the document for the given query
Be thorough, but avoid unnecessary repetition. Think step by step.
2. Useful Analysis:
[ANALYSIS_START]
State the most important point from the chain of thought.
DO NOT refer to "the document" (describe it as "this") - ONLY state the core point in a description.
[ANALYSIS_END]
3. Relevance Determination:
RESULT: True (if potentially relevant)
RESULT: False (if not relevant)
""".strip()
# Use the following for easy viewing of prompts
if __name__ == "__main__":
print(LANGUAGE_REPHRASE_PROMPT)

View File

@ -130,11 +130,14 @@ class InferenceChunk(BaseChunk):
recency_bias: float
score: float | None
hidden: bool
is_relevant: bool | None = None
relevance_explanation: str | None = None
metadata: dict[str, str | list[str]]
# Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
# when the doc was last updated
updated_at: datetime | None
primary_owners: list[str] | None = None
@ -227,6 +230,8 @@ class SearchDoc(BaseModel):
hidden: bool
metadata: dict[str, str | list[str]]
score: float | None
is_relevant: bool | None = None
relevance_explanation: str | None = None
# Matched sections in the doc. Uses Vespa syntax e.g. <hi>TEXT</hi>
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "the answer is <hi>42</hi>""]

View File

@ -5,10 +5,14 @@ from typing import cast
from sqlalchemy.orm import Session
from danswer.chat.models import RelevanceChunk
from danswer.configs.chat_configs import DISABLE_AGENTIC_SEARCH
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
@ -25,7 +29,10 @@ from danswer.search.postprocessing.postprocessing import search_postprocessing
from danswer.search.preprocessing.preprocessing import retrieval_preprocessing
from danswer.search.retrieval.search_runner import retrieve_chunks
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
@ -40,9 +47,12 @@ class SearchPipeline:
fast_llm: LLM,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: (
Callable[[RetrievalMetricsContainer], None] | None
) = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
pruning_config: DocumentPruningConfig | None = None,
):
self.search_request = search_request
self.user = user
@ -58,6 +68,8 @@ class SearchPipeline:
primary_index_name=self.embedding_model.index_name,
secondary_index_name=None,
)
self.prompt_config: PromptConfig | None = prompt_config
self.pruning_config: DocumentPruningConfig | None = pruning_config
# Preprocessing steps generate this
self._search_query: SearchQuery | None = None
@ -74,9 +86,9 @@ class SearchPipeline:
self._relevant_section_indices: list[int] | None = None
# Generates reranked chunks and LLM selections
self._postprocessing_generator: (
Iterator[list[InferenceSection] | list[int]] | None
) = None
"""Pre-processing"""
@ -323,6 +335,32 @@ class SearchPipeline:
)
return self._relevant_section_indices
@property
def relevance_summaries(self) -> dict[str, RelevanceChunk]:
if DISABLE_AGENTIC_SEARCH:
raise ValueError(
"Agentic saerch operation called while DISABLE_AGENTIC_SEARCH is toggled"
)
if len(self.reranked_sections) == 0:
logger.warning(
"No sections found in agentic search evalution. Returning empty dict."
)
return {}
sections = self.reranked_sections
functions = [
FunctionCall(
evaluate_inference_section, (section, self.search_query.query, self.llm)
)
for section in sections
]
results = run_functions_in_parallel(function_calls=functions)
return {
next(iter(value)): value[next(iter(value))] for value in results.values()
}
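
The closing dict comprehension is equivalent to this plainer merge: each parallel call returns a one-entry dict keyed "{document_id}-{chunk_id}", and the entries are folded into a single mapping (sketch):

from danswer.chat.models import RelevanceChunk

def merge_relevance_results(
    results: dict[str, dict[str, RelevanceChunk]],
) -> dict[str, RelevanceChunk]:
    merged: dict[str, RelevanceChunk] = {}
    for single_result in results.values():
        merged.update(single_result)
    return merged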
@property
def section_relevance_list(self) -> list[bool]:
return [

View File

@ -0,0 +1,71 @@
from danswer.chat.models import RelevanceChunk
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_to_string
from danswer.prompts.miscellaneous_prompts import AGENTIC_SEARCH_EVALUATION_PROMPT
from danswer.search.models import InferenceSection
from danswer.utils.logger import setup_logger
logger = setup_logger()
def evaluate_inference_section(
document: InferenceSection, query: str, llm: LLM
) -> dict[str, RelevanceChunk]:
relevance: RelevanceChunk = RelevanceChunk()
results = {}
# At least for now, this is the same document ID across all chunks
document_id = document.center_chunk.document_id
chunk_id = document.center_chunk.chunk_id
prompt = f"""
Analyze the relevance of this document to the search query:
Title: {document_id.split("/")[-1]}
Blurb: {document.combined_content}
Query: {query}
{AGENTIC_SEARCH_EVALUATION_PROMPT}
"""
content = message_to_string(llm.invoke(prompt=prompt))
analysis = ""
relevant = False
chain_of_thought = ""
parts = content.split("[ANALYSIS_START]", 1)
if len(parts) == 2:
chain_of_thought, rest = parts
else:
logger.warning(f"Missing [ANALYSIS_START] tag for document {document_id}")
rest = content
parts = rest.split("[ANALYSIS_END]", 1)
if len(parts) == 2:
analysis, result = parts
else:
logger.warning(f"Missing [ANALYSIS_END] tag for document {document_id}")
result = rest
chain_of_thought = chain_of_thought.strip()
analysis = analysis.strip()
result = result.strip().lower()
# Determine relevance
if "result: true" in result:
relevant = True
elif "result: false" in result:
relevant = False
else:
logger.warning(f"Invalid result format for document {document_id}")
if not analysis:
logger.warning(
f"Couldn't extract proper analysis for document {document_id}. Using full content."
)
analysis = content
relevance.content = analysis
relevance.relevant = relevant
results[f"{document_id}-{chunk_id}"] = relevance
return results
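
A quick, self-contained check of the marker parsing above against an invented completion:

content = (
    "Step-by-step reasoning...\n"
    "[ANALYSIS_START]\nThis covers the exact flag in question.\n[ANALYSIS_END]\n"
    "RESULT: True"
)
chain_of_thought, rest = content.split("[ANALYSIS_START]", 1)
analysis, result = rest.split("[ANALYSIS_END]", 1)
assert analysis.strip() == "This covers the exact flag in question."
assert "result: true" in result.strip().lower()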

View File

@ -303,7 +303,6 @@ def handle_new_chat_message(
request.headers
),
)
return StreamingResponse(packets, media_type="application/json")

View File

@ -171,7 +171,6 @@ class SearchFeedbackRequest(BaseModel):
if click is False and feedback is None:
raise ValueError("Empty feedback received.")
return values
@ -186,6 +185,7 @@ class ChatMessageDetail(BaseModel):
time_sent: datetime
alternate_assistant_id: str | None
# Dict mapping citation number to db_doc_id
chat_session_id: int | None = None
citations: dict[int, int] | None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
@ -196,6 +196,13 @@ class ChatMessageDetail(BaseModel):
return initial_dict
class SearchSessionDetailResponse(BaseModel):
search_session_id: int
description: str
documents: list[SearchDoc]
messages: list[ChatMessageDetail]
class ChatSessionDetailResponse(BaseModel):
chat_session_id: int
description: str

View File

@ -7,6 +7,14 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_chat_sessions_by_user
from danswer.db.chat import get_first_messages_for_chat_sessions
from danswer.db.chat import get_search_docs_for_chat_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
@ -24,8 +32,11 @@ from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
from danswer.server.query_and_chat.models import AdminSearchRequest
from danswer.server.query_and_chat.models import AdminSearchResponse
from danswer.server.query_and_chat.models import ChatSessionDetails
from danswer.server.query_and_chat.models import ChatSessionsResponse
from danswer.server.query_and_chat.models import HelperResponse
from danswer.server.query_and_chat.models import QueryValidationResponse
from danswer.server.query_and_chat.models import SearchSessionDetailResponse
from danswer.server.query_and_chat.models import SimpleQueryRequest
from danswer.server.query_and_chat.models import SourceTag
from danswer.server.query_and_chat.models import TagResponse
@ -46,7 +57,6 @@ def admin_search(
) -> AdminSearchResponse:
query = question.query
logger.info(f"Received admin search query: {query}")
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
source_type=question.filters.source_type,
@ -55,19 +65,15 @@ def admin_search(
tags=question.filters.tags,
access_control_list=user_acl_filters,
)
embedding_model = get_current_db_embedding_model(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
)
if not isinstance(document_index, VespaIndex):
raise HTTPException(
status_code=400,
detail="Cannot use admin-search when using a non-Vespa document index",
)
matching_chunks = document_index.admin_retrieval(query=query, filters=final_filters)
documents = chunks_or_sections_to_search_docs(matching_chunks)
@ -136,6 +142,103 @@ def query_validation(
return QueryValidationResponse(reasoning=reasoning, answerable=answerable)
@basic_router.get("/user-searches")
def get_user_search_sessions(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
user_id = user.id if user is not None else None
try:
search_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session, only_one_shot=True
)
except ValueError:
raise HTTPException(
status_code=404, detail="Chat session does not exist or has been deleted"
)
search_session_ids = [chat.id for chat in search_sessions]
first_messages_dict = get_first_messages_for_chat_sessions(
search_session_ids, db_session
)
response = ChatSessionsResponse(
sessions=[
ChatSessionDetails(
id=search.id,
name=first_messages_dict.get(search.id, search.description),
persona_id=search.persona_id,
time_created=search.time_created.isoformat(),
shared_status=search.shared_status,
folder_id=search.folder_id,
current_alternate_model=search.current_alternate_model,
)
for search in search_sessions
]
)
return response
@basic_router.get("/search-session/{session_id}")
def get_search_session(
session_id: int,
is_shared: bool = False,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchSessionDetailResponse:
user_id = user.id if user is not None else None
try:
search_session = get_chat_session_by_id(
chat_session_id=session_id,
user_id=user_id,
db_session=db_session,
is_shared=is_shared,
)
except ValueError:
raise ValueError("Search session does not exist or has been deleted")
session_messages = get_chat_messages_by_session(
chat_session_id=session_id,
user_id=user_id,
db_session=db_session,
# we already did a permission check above with the call to
# `get_chat_session_by_id`, so we can skip it here
skip_permission_check=True,
# we need the tool call objs anyways, so just fetch them in a single call
prefetch_tool_calls=True,
)
docs_response: list[SearchDoc] = []
for message in session_messages:
if (
message.message_type == MessageType.ASSISTANT
or message.message_type == MessageType.SYSTEM
):
docs = get_search_docs_for_chat_message(
db_session=db_session, chat_message_id=message.id
)
for doc in docs:
server_doc = translate_db_search_doc_to_server_search_doc(doc)
docs_response.append(server_doc)
response = SearchSessionDetailResponse(
search_session_id=session_id,
description=search_session.description,
documents=docs_response,
messages=[
translate_db_message_to_chat_message_detail(
msg, remove_doc_content=is_shared # if shared, don't leak doc content
)
for msg in session_messages
],
)
return response
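
Hedged client-side sketch of the two new endpoints (the base URL and the router's "/query" path prefix are assumptions, not taken from this diff):

import requests

BASE = "http://localhost:8080"  # assumed deployment address

sessions = requests.get(f"{BASE}/query/user-searches").json()
first_id = sessions["sessions"][0]["id"]

detail = requests.get(f"{BASE}/query/search-session/{first_id}").json()
print(detail["description"], len(detail["documents"]))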
# NOTE: No longer used after the search/chat redesign.
# No search responses are answered with a conversational generative AI response
@basic_router.post("/stream-query-validation")
def stream_query_validation(
simple_query: SimpleQueryRequest, _: User = Depends(current_user)
@ -156,6 +259,7 @@ def get_answer_with_quote(
_: None = Depends(check_token_rate_limits),
) -> StreamingResponse:
query = query_request.messages[0].message
logger.info(f"Received query for one shot answer with quotes: {query}")
packets = stream_search_answer(
query_req=query_request,

View File

@ -10,6 +10,7 @@ from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import DISABLE_AGENTIC_SEARCH
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.dynamic_configs.interface import JSON_ro
@ -30,11 +31,15 @@ from danswer.secondary_llm_flows.query_expansion import history_based_query_reph
from danswer.tools.search.search_utils import llm_doc_to_dict
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
logger = setup_logger()
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
FINAL_CONTEXT_DOCUMENTS = "final_context_documents"
SEARCH_EVALUATION_ID = "evaluate_response"
class SearchResponseSummary(BaseModel):
@ -80,6 +85,7 @@ class SearchTool(Tool):
chunks_below: int = 0,
full_doc: bool = False,
bypass_acl: bool = False,
evaluate_response: bool = False,
) -> None:
self.user = user
self.persona = persona
@ -96,6 +102,7 @@ class SearchTool(Tool):
self.full_doc = full_doc
self.bypass_acl = bypass_acl
self.db_session = db_session
self.evaluate_response = evaluate_response
@property
def name(self) -> str:
@ -218,23 +225,28 @@ class SearchTool(Tool):
self.retrieval_options.filters if self.retrieval_options else None
),
persona=self.persona,
offset=(
self.retrieval_options.offset if self.retrieval_options else None
),
limit=self.retrieval_options.limit if self.retrieval_options else None,
chunks_above=self.chunks_above,
chunks_below=self.chunks_below,
full_doc=self.full_doc,
enable_auto_detect_filters=(
self.retrieval_options.enable_auto_detect_filters
if self.retrieval_options
else None
),
),
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
bypass_acl=self.bypass_acl,
db_session=self.db_session,
prompt_config=self.prompt_config,
pruning_config=self.pruning_config,
)
yield ToolResponse(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
@ -246,6 +258,7 @@ class SearchTool(Tool):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
),
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=DanswerContexts(
@ -260,6 +273,7 @@ class SearchTool(Tool):
]
),
)
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=search_pipeline.relevant_section_indices,
@ -281,6 +295,11 @@ class SearchTool(Tool):
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS, response=llm_docs)
if self.evaluate_response and not DISABLE_AGENTIC_SEARCH:
yield ToolResponse(
id=SEARCH_EVALUATION_ID, response=search_pipeline.relevance_summaries
)
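
Consumer-side sketch: pulling the new evaluation packet out of the tool's response stream (the helper is illustrative, not part of this commit):

def extract_relevance_summaries(packets: list[ToolResponse]) -> dict | None:
    # Returns the agentic-evaluation payload if the tool emitted one;
    # the payload maps "{document_id}-{chunk_id}" -> RelevanceChunk.
    for packet in packets:
        if packet.id == SEARCH_EVALUATION_ID:
            return packet.response
    return None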
def final_result(self, *args: ToolResponse) -> JSON_ro:
final_docs = cast(
list[LlmDoc],

View File

@ -87,6 +87,7 @@ def run_functions_in_parallel(
are the result_id of the FunctionCall and the values are the results of the call.
"""
results = {}
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(func_call.execute): func_call.result_id

View File

@ -1 +1 @@
f1f2 1 1718910083.03085 wikipedia:en
f1f2 2 1721064549.902656 wikipedia:en