LLM Chunk Filtering (#735)

Yuhong Sun 2023-11-18 17:12:24 -08:00 committed by GitHub
parent d5916e420c
commit fa0d19cc8c
24 changed files with 551 additions and 292 deletions

View File

@ -46,28 +46,32 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
## 💃 Features ## 💃 Features
* Direct QA powered by Generative AI models with answers backed by quotes and source links. * Direct QA powered by Generative AI models with answers backed by quotes and source links.
* Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs. * Intelligent Document Retrieval (Hybrid Search + Reranking) using the latest NLP models.
* An AI Helper backed by a custom Deep Learning model to interpret user intent. * Automatic time/source filter extraction from natural language + custom model to identify user intent.
* User authentication and document level access management. * User authentication and document level access management.
* Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.) * Support for LLMs of your choice (GPT-4, Llama2, Orca, etc.)
* Management Dashboard to manage connectors and set up features such as live update fetching. * Management Dashboards to manage connectors and set up features such as live update fetching.
* One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere. * One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere.
## 🔌 Connectors ## 🔌 Connectors
Danswer currently syncs documents (every 10 minutes) from: Efficiently pulls the latest changes from:
* Slack * Slack
* GitHub * GitHub
* Google Drive * Google Drive
* Confluence * Confluence
* Jira * Jira
* Notion * Notion
* Gong
* Slab * Slab
* Linear * Linear
* Productboard * Productboard
* Guru * Guru
* Zulip * Zulip
* Bookstack * Bookstack
* Document360
* Request Tracker
* Hubspot
* Local Files * Local Files
* Websites * Websites
* With more to come... * With more to come...
@ -75,7 +79,9 @@ Danswer currently syncs documents (every 10 minutes) from:
## 🚧 Roadmap ## 🚧 Roadmap
* Chat/Conversation support. * Chat/Conversation support.
* Organizational understanding. * Organizational understanding.
* Ability to locate and suggest experts. * Code Search
* Structured Query Languages (SQL, Excel formulas, etc.)
* Ability to locate and suggest experts from your team.
## 💡 Contributing ## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details. Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@ -140,18 +140,15 @@ def danswer_chat_retrieval(
) )
# Good Debug/Breakpoint # Good Debug/Breakpoint
ranked_chunks, unranked_chunks = search_chunks( top_chunks, _ = search_chunks(
query=search_query, document_index=get_default_document_index() query=search_query, document_index=get_default_document_index()
) )
if not ranked_chunks: if not top_chunks:
return [] return []
if unranked_chunks:
ranked_chunks.extend(unranked_chunks)
filtered_ranked_chunks = [ filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA) chunk for chunk in top_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
] ]
# get all chunks that fit into the token limit # get all chunks that fit into the token limit

View File

@ -178,8 +178,12 @@ MINI_CHUNK_SIZE = 150
NUM_RETURNED_HITS = 50 NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15 NUM_RERANKED_RESULTS = 15
# We feed in document chunks until we reach this token limit. # We feed in document chunks until we reach this token limit.
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks # Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
# may be smaller which could result in passing in more total chunks # significantly smaller which could result in passing in more total chunks.
# There is also a slight bit of overhead not accounted for here, such as separator patterns
# between the docs, metadata for the docs, etc.
# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
# model token limit
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int( NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5) os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
) )
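For reference, a minimal back-of-the-envelope sketch (assuming roughly 4 characters per token, which is only a heuristic and depends on the tokenizer) of why the `512 * 5` default corresponds to about five full chunks:

```python
# Rough sketch, not part of the diff: how the default token budget relates to chunk size.
CHUNK_MAX_CHARS = 2000          # max chunk size in characters (per the comment above)
APPROX_CHARS_PER_TOKEN = 4      # rough heuristic, tokenizer dependent
token_budget = 512 * 5          # 2560 tokens, the default budget for document chunks

tokens_per_full_chunk = CHUNK_MAX_CHARS / APPROX_CHARS_PER_TOKEN   # ~500 tokens
print(token_budget / tokens_per_full_chunk)                        # ~5.12 full chunks
```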
@ -198,12 +202,14 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2
DISABLE_LLM_FILTER_EXTRACTION = ( DISABLE_LLM_FILTER_EXTRACTION = (
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true" os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
) )
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow DISABLE_LLM_CHUNK_FILTER = (
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05 QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI # Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False INCLUDE_METADATA = False
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
# Keyword Search Drop Stopwords # Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual # If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords # model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords

View File

@ -1,3 +1,4 @@
import os import os
FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true" FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true"
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"

View File

@ -35,6 +35,8 @@ SCORE = "score"
ID_SEPARATOR = ":;:" ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0 DEFAULT_BOOST = 0
SESSION_KEY = "session" SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
class DocumentSource(str, Enum): class DocumentSource(str, Enum):

View File

@ -232,7 +232,7 @@ def handle_message(
logger.debug(answer.answer) logger.debug(answer.answer)
return True return True
if not answer.top_ranked_docs: if not answer.top_documents:
logger.error(f"Unable to answer question: '{msg}' - no documents found") logger.error(f"Unable to answer question: '{msg}' - no documents found")
# Optionally, respond in thread with the error message, Used primarily # Optionally, respond in thread with the error message, Used primarily
# for debugging purposes # for debugging purposes
@ -265,8 +265,17 @@ def handle_message(
favor_recent=answer.favor_recent, favor_recent=answer.favor_recent,
) )
# Get the chunks fed to the LLM only, then fill with other docs
top_docs = answer.top_documents
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = build_documents_blocks( document_blocks = build_documents_blocks(
documents=answer.top_ranked_docs, query_event_id=answer.query_event_id documents=priority_ordered_docs,
query_event_id=answer.query_event_id,
) )
try: try:

View File

@ -9,7 +9,7 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.configs.app_configs import HARD_DELETE_CHATS from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession from danswer.db.models import ChatSession

View File

@ -1,19 +1,19 @@
from collections.abc import Callable from collections.abc import Callable
from collections.abc import Iterator from collections.abc import Iterator
from functools import partial
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.constants import QUERY_EVENT_ID
from danswer.db.feedback import update_query_event_llm_answer from danswer.db.feedback import update_query_event_llm_answer
from danswer.db.models import User from danswer.db.models import User
from danswer.direct_qa.factory import get_default_qa_model from danswer.direct_qa.factory import get_default_qa_model
from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.models import LLMMetricsContainer from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.search.danswer_helper import query_intent from danswer.search.danswer_helper import query_intent
from danswer.search.models import QueryFlow from danswer.search.models import QueryFlow
@ -24,11 +24,12 @@ from danswer.search.search_runner import danswer_search
from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
from danswer.server.models import QADocsResponse
from danswer.server.models import QAResponse from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.utils import get_json_line from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time from danswer.utils.timing import log_generator_function_time
@ -54,24 +55,34 @@ def answer_qa_query(
offset_count = question.offset if question.offset is not None else 0 offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}") logger.info(f"Received QA query: {query}")
functions_to_run: dict[Callable, tuple] = { run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
extract_question_time_filters: (question,), run_source_filters = FunctionCall(
extract_question_source_filters: (question, db_session), extract_question_source_filters, (question, db_session), {}
query_intent: (query,), )
} run_query_intent = FunctionCall(query_intent, (query,), {})
parallel_results = run_functions_in_parallel(functions_to_run) parallel_results = run_functions_in_parallel(
[
run_time_filters,
run_source_filters,
run_query_intent,
]
)
time_cutoff, favor_recent = parallel_results["extract_question_time_filters"] time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
source_filters = parallel_results["extract_question_source_filters"] source_filters = parallel_results[run_source_filters.result_id]
predicted_search, predicted_flow = parallel_results["query_intent"] predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
# Set flow as search so frontend doesn't ask the user if they want to run QA over more docs
if disable_generative_answer:
predicted_flow = QueryFlow.SEARCH
# Modifies the question object but nothing upstream uses it # Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent question.favor_recent = favor_recent
question.filters.source_type = source_filters question.filters.source_type = source_filters
ranked_chunks, unranked_chunks, query_event_id = danswer_search( top_chunks, llm_chunk_selection, query_event_id = danswer_search(
question=question, question=question,
user=user, user=user,
db_session=db_session, db_session=db_session,
@ -80,38 +91,23 @@ def answer_qa_query(
rerank_metrics_callback=rerank_metrics_callback, rerank_metrics_callback=rerank_metrics_callback,
) )
if not ranked_chunks: top_docs = chunks_to_search_docs(top_chunks)
return QAResponse(
partial_response = partial(
QAResponse,
top_documents=chunks_to_search_docs(top_chunks),
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
if disable_generative_answer or not top_docs:
return partial_response(
answer=None, answer=None,
quotes=None, quotes=None,
top_ranked_docs=None,
lower_ranked_docs=None,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return QAResponse(
answer=None,
quotes=None,
top_ranked_docs=top_docs,
lower_ranked_docs=unranked_top_docs,
# set flow as search so frontend doesn't ask the user if they want
# to run QA over more documents
predicted_flow=QueryFlow.SEARCH,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
) )
try: try:
@ -119,41 +115,28 @@ def answer_qa_query(
timeout=answer_generation_timeout, real_time_flow=real_time_flow timeout=answer_generation_timeout, real_time_flow=real_time_flow
) )
except Exception as e: except Exception as e:
return QAResponse( return partial_response(
answer=None, answer=None,
quotes=None, quotes=None,
top_ranked_docs=top_docs,
lower_ranked_docs=unranked_top_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
error_msg=str(e), error_msg=str(e),
) )
# remove chunks marked as not applicable for QA (e.g. Google Drive file llm_chunks_indices = get_chunks_for_qa(
# types which can't be parsed). These chunks are useful to show in the chunks=top_chunks,
# search results, but not for QA. llm_chunk_selection=llm_chunk_selection,
filtered_ranked_chunks = [ batch_offset=offset_count,
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
) )
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
logger.debug( logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}" f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
) )
error_msg = None error_msg = None
try: try:
d_answer, quotes = qa_model.answer_question( d_answer, quotes = qa_model.answer_question(
query, usable_chunks, metrics_callback=llm_metrics_callback query, llm_chunks, metrics_callback=llm_metrics_callback
) )
except Exception as e: except Exception as e:
# exception is logged in the answer_question method, no need to re-log # exception is logged in the answer_question method, no need to re-log
@ -169,37 +152,17 @@ def answer_qa_query(
user_id=None if user is None else user.id, user_id=None if user is None else user.id,
) )
validity = None
if not real_time_flow and enable_reflexion and d_answer is not None: if not real_time_flow and enable_reflexion and d_answer is not None:
valid = False validity = False
if d_answer.answer is not None: if d_answer.answer is not None:
valid = get_answer_validity(query, d_answer.answer) validity = get_answer_validity(query, d_answer.answer)
return QAResponse( return partial_response(
answer=d_answer.answer if d_answer else None,
quotes=quotes.quotes if quotes else None,
top_ranked_docs=top_docs,
lower_ranked_docs=unranked_top_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
eval_res_valid=True if valid else False,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
error_msg=error_msg,
)
return QAResponse(
answer=d_answer.answer if d_answer else None, answer=d_answer.answer if d_answer else None,
quotes=quotes.quotes if quotes else None, quotes=quotes.quotes if quotes else None,
top_ranked_docs=top_docs, eval_res_valid=validity,
lower_ranked_docs=unranked_top_docs, llm_chunks_indices=llm_chunks_indices,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
error_msg=error_msg, error_msg=error_msg,
) )
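The refactor above replaces several near-identical `QAResponse(...)` constructions with a `functools.partial` that pre-binds the fields shared by every return path. A minimal sketch of the same pattern with a hypothetical stand-in model (the field names here are illustrative, not the real `QAResponse` schema):

```python
from functools import partial

from pydantic import BaseModel


class Response(BaseModel):  # hypothetical stand-in for QAResponse
    query_event_id: int
    answer: str | None = None
    error_msg: str | None = None


# Bind the fields shared by every return path once...
partial_response = partial(Response, query_event_id=42)

# ...then fill in only what differs per path.
early_exit = partial_response(answer=None)
failure = partial_response(answer=None, error_msg="LLM timed out")
success = partial_response(answer="It works like this ...")
```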
@ -220,36 +183,47 @@ def answer_qa_query_stream(
query = question.query query = question.query
offset_count = question.offset if question.offset is not None else 0 offset_count = question.offset if question.offset is not None else 0
functions_to_run: dict[Callable, tuple] = { run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
extract_question_time_filters: (question,), run_source_filters = FunctionCall(
extract_question_source_filters: (question, db_session), extract_question_source_filters, (question, db_session), {}
query_intent: (query,), )
} run_query_intent = FunctionCall(query_intent, (query,), {})
parallel_results = run_functions_in_parallel(functions_to_run) parallel_results = run_functions_in_parallel(
[
run_time_filters,
run_source_filters,
run_query_intent,
]
)
time_cutoff, favor_recent = parallel_results["extract_question_time_filters"] time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
source_filters = parallel_results["extract_question_source_filters"] source_filters = parallel_results[run_source_filters.result_id]
predicted_search, predicted_flow = parallel_results["query_intent"] predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
# Modifies the question object but nothing upstream uses it # Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent question.favor_recent = favor_recent
question.filters.source_type = source_filters question.filters.source_type = source_filters
ranked_chunks, unranked_chunks, query_event_id = danswer_search( top_chunks, llm_chunk_selection, query_event_id = danswer_search(
question=question, question=question,
user=user, user=user,
db_session=db_session, db_session=db_session,
document_index=get_default_document_index(), document_index=get_default_document_index(),
) )
top_docs = chunks_to_search_docs(ranked_chunks) top_docs = chunks_to_search_docs(top_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
initial_response = RerankedRetrievalDocs( llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
batch_offset=offset_count,
)
initial_response = QADocsResponse(
top_documents=top_docs, top_documents=top_docs,
unranked_top_documents=unranked_top_docs, llm_chunks_indices=llm_chunks_indices,
# if generative AI is disabled, set flow as search so frontend # if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents # doesn't ask the user if they want to run QA over more documents
predicted_flow=QueryFlow.SEARCH predicted_flow=QueryFlow.SEARCH
@ -260,10 +234,9 @@ def answer_qa_query_stream(
favor_recent=favor_recent, favor_recent=favor_recent,
).dict() ).dict()
logger.debug(f"Sending Initial Retrival Results: {initial_response}")
yield get_json_line(initial_response) yield get_json_line(initial_response)
if not ranked_chunks: if not top_chunks:
logger.debug("No Documents Found") logger.debug("No Documents Found")
return return
@ -279,25 +252,13 @@ def answer_qa_query_stream(
yield get_json_line(error.dict()) yield get_json_line(error.dict())
return return
# remove chunks marked as not applicable for QA (e.g. Google Drive file llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
# types which can't be parsed). These chunks are useful to show in the
# search results, but not for QA.
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
)
logger.debug( logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}" f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
) )
try: try:
for response_packet in qa_model.answer_question_stream(query, usable_chunks): for response_packet in qa_model.answer_question_stream(query, llm_chunks):
if response_packet is None: if response_packet is None:
continue continue
if ( if (
@ -321,4 +282,4 @@ def answer_qa_query_stream(
user_id=None if user is None else user.id, user_id=None if user is None else user.id,
) )
yield get_json_line({"query_event_id": query_event_id}) yield get_json_line({QUERY_EVENT_ID: query_event_id})
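The streaming handler above emits newline-delimited JSON: first a `QADocsResponse`-shaped packet with `top_documents` and `llm_chunks_indices`, then answer pieces/quotes, and finally a packet keyed by `QUERY_EVENT_ID`. A hypothetical client sketch, assuming a plain HTTP endpoint (the URL and request payload are assumptions, not taken from this diff):

```python
import json

import requests

with requests.post(
    "http://localhost:8080/stream-direct-qa",  # assumed endpoint path
    json={"query": "How do I deploy Danswer?", "collection": "danswer_index", "filters": {}},
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        packet = json.loads(line)
        if "top_documents" in packet:
            print("retrieved docs:", len(packet["top_documents"]))
            print("indices fed to LLM:", packet.get("llm_chunks_indices"))
        elif "query_event_id" in packet:
            print("done, query_event_id =", packet["query_event_id"])
        else:
            print(packet)  # answer pieces, quotes, or errors
```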

View File

@ -11,6 +11,7 @@ import regex
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.direct_qa.interfaces import DanswerAnswer from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote from danswer.direct_qa.interfaces import DanswerQuote
@ -316,3 +317,57 @@ def get_usable_chunks(
offset_into_chunks += len(usable_chunks) offset_into_chunks += len(usable_chunks)
return usable_chunks return usable_chunks
def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
batch_offset: int = 0,
) -> list[int]:
"""
Gives back indices of chunks to pass into the LLM for Q&A.
Only selects chunks viable for Q&A, within the token limit, and prioritizes those selected
by the LLM in a separate flow (this can be turned off).
Note: the batch_offset calculation has to count the batches from the beginning each time, as
there is currently no way to know which chunks were included in prior batches without recounting.
This is somewhat slow since it requires tokenizing all the chunks again.
"""
batch_index = 0
latest_batch_indices: list[int] = []
token_count = 0
# First iterate the LLM selected chunks, then iterate the rest if tokens remaining
for selection_target in [True, False]:
for ind, chunk in enumerate(chunks):
if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
IGNORE_FOR_QA
):
continue
# We calculate it live in case the user uses a different LLM + tokenizer
chunk_token = check_number_of_tokens(chunk.content)
token_count += chunk_token
# Always use at least 1 chunk
if token_count <= token_limit or not latest_batch_indices:
latest_batch_indices.append(ind)
current_chunk_unused = False
else:
current_chunk_unused = True
if token_count >= token_limit:
if batch_index < batch_offset:
batch_index += 1
if current_chunk_unused:
latest_batch_indices = [ind]
token_count = chunk_token
else:
latest_batch_indices = []
token_count = 0
else:
return latest_batch_indices
return latest_batch_indices
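A simplified, self-contained illustration of the selection order implemented above: LLM-approved chunks are taken first, the rest backfill until the token budget is hit, and at least one chunk is always kept. The `batch_offset` paging and the `IGNORE_FOR_QA` filter are omitted for brevity; this is a toy, not the actual function.

```python
def toy_chunk_selection(
    chunk_tokens: list[int],   # token count per chunk, in ranked order
    llm_approved: list[bool],  # LLM usefulness verdict per chunk
    token_limit: int,
) -> list[int]:
    selected: list[int] = []
    used = 0
    for target in (True, False):  # LLM-approved chunks first, then the rest
        for ind, tokens in enumerate(chunk_tokens):
            if llm_approved[ind] is not target:
                continue
            if used + tokens > token_limit and selected:
                continue  # over budget, but always keep at least one chunk
            selected.append(ind)
            used += tokens
    return selected


# Chunks 1 and 3 were approved by the LLM, so they are taken before chunk 0.
print(toy_chunk_selection([400, 300, 500, 200], [False, True, False, True], 900))  # [1, 3, 0]
```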

View File

@ -33,11 +33,6 @@ class LangChainChatLLM(LLM, abc.ABC):
def llm(self) -> BaseChatModel: def llm(self) -> BaseChatModel:
raise NotImplementedError raise NotImplementedError
def _log_model_config(self) -> None:
logger.debug(
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
)
@staticmethod @staticmethod
def _log_prompt(prompt: LanguageModelInput) -> None: def _log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list): if isinstance(prompt, list):
@ -46,8 +41,12 @@ class LangChainChatLLM(LLM, abc.ABC):
if isinstance(prompt, str): if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}") logger.debug(f"Prompt:\n{prompt}")
def log_model_configs(self) -> None:
logger.debug(
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
)
def invoke(self, prompt: LanguageModelInput) -> str: def invoke(self, prompt: LanguageModelInput) -> str:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS: if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt) self._log_prompt(prompt)
@ -58,7 +57,6 @@ class LangChainChatLLM(LLM, abc.ABC):
return model_raw return model_raw
def stream(self, prompt: LanguageModelInput) -> Iterator[str]: def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS: if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt) self._log_prompt(prompt)

View File

@ -9,6 +9,10 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.utils.logger import setup_logger
logger = setup_logger()
class CustomModelServer(LLM): class CustomModelServer(LLM):
@ -65,6 +69,9 @@ class CustomModelServer(LLM):
response.raise_for_status() response.raise_for_status()
return json.loads(response.content).get("generated_text", "") return json.loads(response.content).get("generated_text", "")
def log_model_configs(self) -> None:
logger.debug(f"Custom model at: {self._endpoint}")
def invoke(self, prompt: LanguageModelInput) -> str: def invoke(self, prompt: LanguageModelInput) -> str:
return self._execute(prompt) return self._execute(prompt)

View File

@ -61,6 +61,11 @@ class DanswerGPT4All(LLM):
self.temperature = temperature self.temperature = temperature
self.gpt4all_model = GPT4All(model_version) self.gpt4all_model = GPT4All(model_version)
def log_model_configs(self) -> None:
logger.debug(
f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}"
)
def invoke(self, prompt: LanguageModelInput) -> str: def invoke(self, prompt: LanguageModelInput) -> str:
prompt_basic = convert_lm_input_to_basic_string(prompt) prompt_basic = convert_lm_input_to_basic_string(prompt)
return self.gpt4all_model.generate(prompt_basic) return self.gpt4all_model.generate(prompt_basic)

View File

@ -22,6 +22,10 @@ class LLM(abc.ABC):
def requires_api_key(self) -> bool: def requires_api_key(self) -> bool:
return True return True
@abc.abstractmethod
def log_model_configs(self) -> None:
raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def invoke(self, prompt: LanguageModelInput) -> str: def invoke(self, prompt: LanguageModelInput) -> str:
raise NotImplementedError raise NotImplementedError

View File

@ -36,6 +36,7 @@ from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.factory import get_default_qa_model from danswer.direct_qa.factory import get_default_qa_model
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.server.cc_pair.api import router as cc_pair_router from danswer.server.cc_pair.api import router as cc_pair_router
from danswer.server.chat_backend import router as chat_router from danswer.server.chat_backend import router as chat_router
from danswer.server.connector import router as connector_router from danswer.server.connector import router as connector_router
@ -197,7 +198,7 @@ def get_application() -> FastAPI:
warm_up_models() warm_up_models()
# This is for the LLM, most LLMs will not need warming up # This is for the LLM, most LLMs will not need warming up
# It logs for itself get_default_llm().log_model_configs()
get_default_qa_model().warm_up_model() get_default_qa_model().warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded") logger.info("Verifying query preprocessing (NLTK) data is downloaded")

View File

@ -132,6 +132,29 @@ Note: The "file" source only applies to when the user refers to uploaded files i
""".strip() """.strip()
USEFUL_PAT = "Yes useful"
NONUSEFUL_PAT = "Not useful"
CHUNK_FILTER_PROMPT = f"""
Determine if the reference section is USEFUL for answering the user query.
It is NOT enough for the section to be related to the query, \
it must contain information that is USEFUL for answering the query.
If the section contains ANY useful information, that is good enough, \
it does not need to fully answer every part of the user query.
Reference Section:
```
{{chunk_text}}
```
User Query:
```
{{user_query}}
```
Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
""".strip()
# Use the following for easy viewing of prompts # Use the following for easy viewing of prompts
if __name__ == "__main__": if __name__ == "__main__":
print(ANSWERABLE_PROMPT) print(ANSWERABLE_PROMPT)

View File

@ -3,6 +3,7 @@ from enum import Enum
from pydantic import BaseModel from pydantic import BaseModel
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource from danswer.configs.constants import DocumentSource
@ -57,6 +58,9 @@ class SearchQuery(BaseModel):
skip_rerank: bool = SKIP_RERANKING skip_rerank: bool = SKIP_RERANKING
# Only used if not skip_rerank # Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS num_rerank: int | None = NUM_RERANKED_RESULTS
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER
# Only used if not skip_llm_chunk_filter
max_llm_filter_chunks: int = NUM_RERANKED_RESULTS
class RetrievalMetricsContainer(BaseModel): class RetrievalMetricsContainer(BaseModel):

View File

@ -8,7 +8,9 @@ from nltk.tokenize import word_tokenize # type:ignore
from sentence_transformers import SentenceTransformer # type: ignore from sentence_transformers import SentenceTransformer # type: ignore
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import HYBRID_ALPHA from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
@ -33,9 +35,12 @@ from danswer.search.models import SearchQuery
from danswer.search.models import SearchType from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel from danswer.search.search_nlp_models import EmbeddingModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.server.models import QuestionRequest from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc from danswer.server.models import SearchDoc
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time from danswer.utils.timing import log_function_time
@ -147,7 +152,12 @@ def semantic_reranking(
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
model_min: int = CROSS_ENCODER_RANGE_MIN, model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX, model_max: int = CROSS_ENCODER_RANGE_MAX,
) -> list[InferenceChunk]: ) -> tuple[list[InferenceChunk], list[int]]:
"""Reranks chunks based on cross-encoder models. Additionally provides the original indices
of the chunks in their new sorted order.
Note: this updates the chunks in place, overwriting the chunk scores that came from retrieval
"""
cross_encoders = CrossEncoderEnsembleModel() cross_encoders = CrossEncoderEnsembleModel()
passages = [chunk.content for chunk in chunks] passages = [chunk.content for chunk in chunks]
sim_scores_floats = cross_encoders.predict(query=query, passages=passages) sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
@ -168,16 +178,20 @@ def semantic_reranking(
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / ( normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
model_max - model_min model_max - model_min
) )
scored_results = list(zip(normalized_b_s_scores, raw_sim_scores, chunks)) orig_indices = [i for i in range(len(normalized_b_s_scores))]
scored_results = list(
zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
)
scored_results.sort(key=lambda x: x[0], reverse=True) scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_raw_scores, ranked_chunks = zip(*scored_results) ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
*scored_results
)
logger.debug( logger.debug(
f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}" f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
) )
# Assign new chunk scores based on reranking # Assign new chunk scores based on reranking
# TODO if pagination is added, the scores won't make sense with respect to the non-reranked hits
for ind, chunk in enumerate(ranked_chunks): for ind, chunk in enumerate(ranked_chunks):
chunk.score = ranked_sim_scores[ind] chunk.score = ranked_sim_scores[ind]
@ -198,7 +212,7 @@ def semantic_reranking(
) )
) )
return list(ranked_chunks) return list(ranked_chunks), list(ranked_indices)
def apply_boost_legacy( def apply_boost_legacy(
@ -257,6 +271,9 @@ def apply_boost_legacy(
def apply_boost( def apply_boost(
chunks: list[InferenceChunk], chunks: list[InferenceChunk],
# Need the range of values to not be too spread out for applying boost
# therefore norm across only the top few results
norm_cutoff: int = NUM_RERANKED_RESULTS,
norm_min: float = SIM_SCORE_RANGE_LOW, norm_min: float = SIM_SCORE_RANGE_LOW,
norm_max: float = SIM_SCORE_RANGE_HIGH, norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]: ) -> list[InferenceChunk]:
@ -266,13 +283,13 @@ def apply_boost(
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
recency_multiplier = [chunk.recency_bias for chunk in chunks] recency_multiplier = [chunk.recency_bias for chunk in chunks]
norm_min = min(norm_min, min(scores)) norm_min = min(norm_min, min(scores[:norm_cutoff]))
norm_max = max(norm_max, max(scores)) norm_max = max(norm_max, max(scores[:norm_cutoff]))
# This should never be 0 unless user has done some weird/wrong settings # This should never be 0 unless user has done some weird/wrong settings
norm_range = norm_max - norm_min norm_range = norm_max - norm_min
boosted_scores = [ boosted_scores = [
(score - norm_min) * boost * recency / norm_range max(0, (score - norm_min) * boost * recency / norm_range)
for score, boost, recency in zip(scores, boosts, recency_multiplier) for score, boost, recency in zip(scores, boosts, recency_multiplier)
] ]
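A worked numeric example of the boosted-score formula above (all values made up), showing why the `max(0, ...)` clamp matters once normalization is done over only the top few scores:

```python
scores = [0.92, 0.81, 0.15]  # third score falls below the normalization floor
boosts = [1.0, 2.0, 1.0]     # translate_boost_count_to_multiplier output
recency = [1.0, 1.0, 1.5]

norm_min, norm_max = 0.6, 1.0  # derived from the top results only (norm_cutoff)
norm_range = norm_max - norm_min

boosted = [
    max(0, (score - norm_min) * boost * rec / norm_range)
    for score, boost, rec in zip(scores, boosts, recency)
]
print(boosted)  # approx [0.80, 1.05, 0.0]; the clamp keeps the low score from going negative
```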
@ -299,7 +316,14 @@ def search_chunks(
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None, | None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]: ) -> tuple[list[InferenceChunk], list[bool]]:
"""Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
For sake of speed, the system cannot rerank all retrieved chunks
Also pass the chunks through LLM to determine if they are relevant (binary for speed)
Only the first max_llm_filter_chunks
"""
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
top_links = [ top_links = [
c.source_links[0] if c.source_links is not None else "No Link" c.source_links[0] if c.source_links is not None else "No Link"
@ -316,7 +340,7 @@ def search_chunks(
f"{query.search_type.value.capitalize()} search returned no results " f"{query.search_type.value.capitalize()} search returned no results "
f"with filters: {query.filters}" f"with filters: {query.filters}"
) )
return None, None return [], []
if retrieval_metrics_callback is not None: if retrieval_metrics_callback is not None:
chunk_metrics = [ chunk_metrics = [
@ -332,27 +356,62 @@ def search_chunks(
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics) RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
) )
# Keyword Search should never do reranking, no transformers involved in this flow functions_to_run: list[FunctionCall] = []
if query.search_type == SearchType.KEYWORD:
# Keyword Search should not do reranking
if query.search_type == SearchType.KEYWORD or query.skip_rerank:
_log_top_chunk_links(query.search_type.value, top_chunks) _log_top_chunk_links(query.search_type.value, top_chunks)
return top_chunks, None run_rerank_id: str | None = None
else:
run_rerank = FunctionCall(
semantic_reranking,
(query.query, top_chunks[: query.num_rerank]),
{"rerank_metrics_callback": rerank_metrics_callback},
)
functions_to_run.append(run_rerank)
run_rerank_id = run_rerank.result_id
if query.skip_rerank: run_llm_filter_id = None
# Need the range of values to not be too spread out for applying boost if not query.skip_llm_chunk_filter:
# Therefore pass in smaller set of chunks to limit the range for norm-ing run_llm_filter = FunctionCall(
boosted_chunks = apply_boost(top_chunks[: query.num_rerank]) llm_batch_eval_chunks,
_log_top_chunk_links(query.search_type.value, boosted_chunks) (
return boosted_chunks, top_chunks[query.num_rerank :] query.query,
[chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
),
{},
)
functions_to_run.append(run_llm_filter)
run_llm_filter_id = run_llm_filter.result_id
ranked_chunks = semantic_reranking( parallel_results = run_functions_in_parallel(functions_to_run)
query.query,
top_chunks[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
ranked_results = parallel_results.get(str(run_rerank_id))
if ranked_results is None:
ranked_chunks = top_chunks
sorted_indices = [i for i in range(len(top_chunks))]
else:
ranked_chunks, orig_indices = ranked_results
sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
lower_chunks = top_chunks[query.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
if llm_chunk_selection is None:
reranked_llm_chunk_selection = [True for _ in top_chunks]
else:
llm_chunk_selection.extend(
[False for _ in top_chunks[query.max_llm_filter_chunks :]]
)
reranked_llm_chunk_selection = [
llm_chunk_selection[ind] for ind in sorted_indices
]
_log_top_chunk_links(query.search_type.value, ranked_chunks) _log_top_chunk_links(query.search_type.value, ranked_chunks)
return ranked_chunks, top_chunks[query.num_rerank :] return ranked_chunks, reranked_llm_chunk_selection
def danswer_search( def danswer_search(
@ -360,10 +419,11 @@ def danswer_search(
user: User | None, user: User | None,
db_session: Session, db_session: Session,
document_index: DocumentIndex, document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None, | None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None, int]: ) -> tuple[list[InferenceChunk], list[bool], int]:
query_event_id = create_query_event( query_event_id = create_query_event(
query=question.query, query=question.query,
search_type=question.search_type, search_type=question.search_type,
@ -384,17 +444,21 @@ def danswer_search(
query=question.query, query=question.query,
search_type=question.search_type, search_type=question.search_type,
filters=final_filters, filters=final_filters,
favor_recent=True if question.favor_recent is None else question.favor_recent, # Still applies time decay but not magnified
favor_recent=question.favor_recent
if question.favor_recent is not None
else False,
skip_llm_chunk_filter=skip_llm_chunk_filter,
) )
ranked_chunks, unranked_chunks = search_chunks( top_chunks, llm_chunk_selection = search_chunks(
query=search_query, query=search_query,
document_index=document_index, document_index=document_index,
retrieval_metrics_callback=retrieval_metrics_callback, retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback, rerank_metrics_callback=rerank_metrics_callback,
) )
retrieved_ids = [doc.document_id for doc in ranked_chunks] if ranked_chunks else [] retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []
update_query_event_retrieved_documents( update_query_event_retrieved_documents(
db_session=db_session, db_session=db_session,
@ -403,4 +467,4 @@ def danswer_search(
user_id=None if user is None else user.id, user_id=None if user is None else user.id,
) )
return ranked_chunks, unranked_chunks, query_event_id return top_chunks, llm_chunk_selection, query_event_id
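A toy illustration of the index bookkeeping in `search_chunks` above: the LLM verdicts are produced in the original retrieval order, then re-mapped into the reranked order via `sorted_indices` (the values below are made up):

```python
top_chunks = ["chunk_a", "chunk_b", "chunk_c", "chunk_d"]
llm_chunk_selection = [True, False, True, True]  # verdicts in retrieval order

# Suppose reranking covered only the first 3 chunks and produced this order:
orig_indices = [2, 0, 1]
sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))  # [2, 0, 1, 3]

reranked_chunks = [top_chunks[i] for i in orig_indices] + top_chunks[len(orig_indices):]
reranked_llm_chunk_selection = [llm_chunk_selection[i] for i in sorted_indices]

print(reranked_chunks)               # ['chunk_c', 'chunk_a', 'chunk_b', 'chunk_d']
print(reranked_llm_chunk_selection)  # [True, True, False, True]
```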

View File

@ -0,0 +1,65 @@
from collections.abc import Callable
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.secondary_llm_flows import CHUNK_FILTER_PROMPT
from danswer.prompts.secondary_llm_flows import NONUSEFUL_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def llm_eval_chunk(query: str, chunk_content: str) -> bool:
def _get_usefulness_messages() -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": CHUNK_FILTER_PROMPT.format(
chunk_text=chunk_content, user_query=query
),
},
]
return messages
def _extract_usefulness(model_output: str) -> bool:
"""Default useful if the LLM doesn't match pattern exactly
This is because it's better to trust the (re)ranking if LLM fails"""
if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
return False
return True
messages = _get_usefulness_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
# When running in a batch, the batch takes as long as the longest thread,
# and in a large batch one call may hang and take the whole timeout,
# so instead cap each call at 5 seconds
model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
logger.debug(model_output)
return _extract_usefulness(model_output)
def llm_batch_eval_chunks(
query: str, chunk_contents: list[str], use_threads: bool = True
) -> list[bool]:
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents
]
logger.debug(
"Running LLM usefulness eval in parallel (following logging may be out of order)"
)
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
# In case of failure/timeout, don't throw out the chunk
return [True if item is None else item for item in parallel_results]
else:
return [
llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents
]
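A hypothetical usage sketch of the helper above (assumes a default LLM is configured; the query and chunk texts are made up):

```python
chunk_contents = [
    "Danswer deploys with one docker compose command.",
    "Minutes from the 2019 offsite in Tahoe.",
]
useful_flags = llm_batch_eval_chunks(
    query="How do I deploy Danswer?",
    chunk_contents=chunk_contents,
)
print(useful_flags)  # e.g. [True, False]; failed/timed-out evals default to True
```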

View File

@ -171,12 +171,58 @@ class SearchDoc(BaseModel):
return initial_dict return initial_dict
class QuestionRequest(BaseModel):
query: str
collection: str
filters: BaseFilters
offset: int | None
enable_auto_detect_filters: bool
favor_recent: bool | None = None
search_type: SearchType = SearchType.HYBRID
class QAFeedbackRequest(BaseModel):
query_id: int
feedback: QAFeedbackType
class SearchFeedbackRequest(BaseModel):
query_id: int
document_id: str
document_rank: int
click: bool
search_feedback: SearchFeedbackType
class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool
class RetrievalDocs(BaseModel): class RetrievalDocs(BaseModel):
top_documents: list[SearchDoc] top_documents: list[SearchDoc]
class RerankedRetrievalDocs(RetrievalDocs): class SearchResponse(RetrievalDocs):
unranked_top_documents: list[SearchDoc] query_event_id: int
source_type: list[DocumentSource] | None
time_cutoff: datetime | None
favor_recent: bool
class QAResponse(SearchResponse):
answer: str | None # DanswerAnswer
quotes: list[DanswerQuote] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_chunks_indices: list[int] | None = None
error_msg: str | None = None
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
llm_chunks_indices: list[int]
predicted_flow: QueryFlow predicted_flow: QueryFlow
predicted_search: SearchType predicted_search: SearchType
time_cutoff: datetime | None time_cutoff: datetime | None
@ -194,21 +240,6 @@ class CreateChatSessionID(BaseModel):
chat_session_id: int chat_session_id: int
class QuestionRequest(BaseModel):
query: str
collection: str
filters: BaseFilters
offset: int | None
enable_auto_detect_filters: bool
favor_recent: bool | None = None
search_type: SearchType = SearchType.HYBRID
class QAFeedbackRequest(BaseModel):
query_id: int
feedback: QAFeedbackType
class ChatFeedbackRequest(BaseModel): class ChatFeedbackRequest(BaseModel):
chat_session_id: int chat_session_id: int
message_number: int message_number: int
@ -217,14 +248,6 @@ class ChatFeedbackRequest(BaseModel):
feedback_text: str | None = None feedback_text: str | None = None
class SearchFeedbackRequest(BaseModel):
query_id: int
document_id: str
document_rank: int
click: bool
search_feedback: SearchFeedbackType
class CreateChatMessageRequest(BaseModel): class CreateChatMessageRequest(BaseModel):
chat_session_id: int chat_session_id: int
message_number: int message_number: int
@ -280,30 +303,6 @@ class ChatSessionDetailResponse(BaseModel):
messages: list[ChatMessageDetail] messages: list[ChatMessageDetail]
class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool
class SearchResponse(BaseModel):
# For semantic search, top docs are reranked, the remaining are as ordered from retrieval
top_ranked_docs: list[SearchDoc] | None
lower_ranked_docs: list[SearchDoc] | None
query_event_id: int
source_type: list[DocumentSource] | None
time_cutoff: datetime | None
favor_recent: bool
class QAResponse(SearchResponse):
answer: str | None # DanswerAnswer
quotes: list[DanswerQuote] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
error_msg: str | None = None
class UserByEmail(BaseModel): class UserByEmail(BaseModel):
user_email: str user_email: str

View File

@ -1,5 +1,3 @@
from collections.abc import Callable
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import Depends from fastapi import Depends
from fastapi import HTTPException from fastapi import HTTPException
@ -36,6 +34,7 @@ from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.threadpool_concurrency import run_functions_in_parallel
logger = setup_logger() logger = setup_logger()
@ -131,10 +130,10 @@ def handle_search_request(
query = question.query query = question.query
logger.info(f"Received {question.search_type.value} " f"search query: {query}") logger.info(f"Received {question.search_type.value} " f"search query: {query}")
functions_to_run: dict[Callable, tuple] = { functions_to_run = [
extract_question_time_filters: (question,), FunctionCall(extract_question_time_filters, (question,), {}),
extract_question_source_filters: (question, db_session), FunctionCall(extract_question_source_filters, (question, db_session), {}),
} ]
parallel_results = run_functions_in_parallel(functions_to_run) parallel_results = run_functions_in_parallel(functions_to_run)
@ -145,29 +144,18 @@ def handle_search_request(
question.favor_recent = favor_recent question.favor_recent = favor_recent
question.filters.source_type = source_filters question.filters.source_type = source_filters
ranked_chunks, unranked_chunks, query_event_id = danswer_search( top_chunks, _, query_event_id = danswer_search(
question=question, question=question,
user=user, user=user,
db_session=db_session, db_session=db_session,
document_index=get_default_document_index(), document_index=get_default_document_index(),
skip_llm_chunk_filter=True,
) )
if not ranked_chunks: top_docs = chunks_to_search_docs(top_chunks)
return SearchResponse(
top_ranked_docs=None,
lower_ranked_docs=None,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
top_docs = chunks_to_search_docs(ranked_chunks)
lower_top_docs = chunks_to_search_docs(unranked_chunks)
return SearchResponse( return SearchResponse(
top_ranked_docs=top_docs, top_documents=top_docs,
lower_ranked_docs=lower_top_docs or None,
query_event_id=query_event_id, query_event_id=query_event_id,
source_type=source_filters, source_type=source_filters,
time_cutoff=time_cutoff, time_cutoff=time_cutoff,

View File

@ -37,16 +37,20 @@ def optional_telemetry(record_type: RecordType, data: dict) -> None:
try: try:
def telemetry_logic() -> None: def telemetry_logic() -> None:
payload = { try:
"data": data, payload = {
"record": record_type, "data": data,
"customer_uuid": get_or_generate_uuid(), "record": record_type,
} "customer_uuid": get_or_generate_uuid(),
requests.post( }
DANSWER_TELEMETRY_ENDPOINT, requests.post(
headers={"Content-Type": "application/json"}, DANSWER_TELEMETRY_ENDPOINT,
json=payload, headers={"Content-Type": "application/json"},
) json=payload,
)
except Exception:
# This way it silences all thread level logging as well
pass
# Run in separate thread to have minimal overhead in main flows # Run in separate thread to have minimal overhead in main flows
thread = threading.Thread(target=telemetry_logic, daemon=True) thread = threading.Thread(target=telemetry_logic, daemon=True)

View File

@ -1,3 +1,4 @@
import uuid
from collections.abc import Callable from collections.abc import Callable
from concurrent.futures import as_completed from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
@ -8,31 +9,82 @@ from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
def run_functions_in_parallel( def run_functions_tuples_in_parallel(
functions_with_args: dict[Callable, tuple] functions_with_args: list[tuple[Callable, tuple]],
) -> dict[str, Any]: allow_failures: bool = False,
) -> list[Any]:
""" """
Executes multiple functions in parallel and returns a dictionary with the results. Executes multiple functions in parallel and returns a list of the results for each function.
Args: Args:
functions_with_args (dict): A dictionary mapping functions to a tuple of arguments. functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
allow_failures: if set to True, then the function result will just be None
Returns: Returns:
list: The results in the same order as the input function tuples (None for a failed call when allow_failures is True).
""" """
results = {} results = []
with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor: with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor:
future_to_function = { future_to_index = {
executor.submit(func, *args): func.__name__ executor.submit(func, *args): i
for func, args in functions_with_args.items() for i, (func, args) in enumerate(functions_with_args)
} }
for future in as_completed(future_to_function): for future in as_completed(future_to_index):
function_name = future_to_function[future] index = future_to_index[future]
try: try:
results[function_name] = future.result() results.append((index, future.result()))
except Exception as e: except Exception as e:
logger.exception(f"Function {function_name} failed due to {e}") logger.exception(f"Function at index {index} failed due to {e}")
raise results.append((index, None))
if not allow_failures:
raise
results.sort(key=lambda x: x[0])
return [result for index, result in results]
class FunctionCall:
"""
Container for run_functions_in_parallel; fetch the result from the output of
run_functions_in_parallel via FunctionCall.result_id.
"""
def __init__(self, func: Callable, args: tuple = (), kwargs: dict | None = None):
self.func = func
self.args = args
self.kwargs = kwargs if kwargs is not None else {}
self.result_id = str(uuid.uuid4())
def execute(self) -> Any:
return self.func(*self.args, **self.kwargs)
def run_functions_in_parallel(
function_calls: list[FunctionCall],
allow_failures: bool = False,
) -> dict[str, Any]:
"""
Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
are the result_id of the FunctionCall and the values are the results of the call.
"""
results = {}
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(func_call.execute): func_call.result_id
for func_call in function_calls
}
for future in as_completed(future_to_id):
result_id = future_to_id[future]
try:
results[result_id] = future.result()
except Exception as e:
logger.exception(f"Function with ID {result_id} failed due to {e}")
results[result_id] = None
if not allow_failures:
raise
return results return results
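A short sketch of both parallel helpers above with trivial functions (assumes the helpers are importable from `danswer.utils.threadpool_concurrency`):

```python
import time


def slow_add(a: int, b: int) -> int:
    time.sleep(0.1)
    return a + b


def slow_upper(text: str) -> str:
    time.sleep(0.1)
    return text.upper()


# Positional style: results come back in input order.
print(run_functions_tuples_in_parallel([(slow_add, (1, 2)), (slow_add, (3, 4))]))  # [3, 7]

# FunctionCall style: results are looked up by result_id.
add_call = FunctionCall(slow_add, (1, 2))
upper_call = FunctionCall(slow_upper, ("danswer",))
results = run_functions_in_parallel([add_call, upper_call])
print(results[add_call.result_id], results[upper_call.result_id])  # 3 DANSWER
```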

View File

@ -32,6 +32,7 @@ services:
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-} - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
# Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years) # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-} - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
# Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector) # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)

View File

@ -116,6 +116,11 @@ export const DocumentDisplay = ({
}: DocumentDisplayProps) => { }: DocumentDisplayProps) => {
const [isHovered, setIsHovered] = useState(false); const [isHovered, setIsHovered] = useState(false);
// Consider reintroducing null scored docs in the future
if (document.score === null) {
return null;
}
return ( return (
<div <div
key={document.semantic_identifier} key={document.semantic_identifier}
@ -126,23 +131,25 @@ export const DocumentDisplay = ({
onMouseLeave={() => setIsHovered(false)} onMouseLeave={() => setIsHovered(false)}
> >
<div className="flex relative"> <div className="flex relative">
<div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex"> {document.score !== null && (
<div <div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
className={` <div
text-xs className={`
text-gray-200 text-xs
bg-gray-800 text-gray-200
rounded bg-gray-800
p-0.5 rounded
w-fit p-0.5
my-auto w-fit
select-none my-auto
ml-auto select-none
mr-2`} ml-auto
> mr-2`}
{document.score.toFixed(2)} >
{document.score.toFixed(2)}
</div>
</div> </div>
</div> )}
<a <a
className={ className={
"rounded-lg flex font-bold " + "rounded-lg flex font-bold " +