LLM Chunk Filtering (#735)

Yuhong Sun 2023-11-18 17:12:24 -08:00 committed by GitHub
parent d5916e420c
commit fa0d19cc8c
24 changed files with 551 additions and 292 deletions

View File

@ -46,28 +46,32 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
## 💃 Features
* Direct QA powered by Generative AI models with answers backed by quotes and source links.
* Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs.
* An AI Helper backed by a custom Deep Learning model to interpret user intent.
* Intelligent Document Retrieval (Hybrid Search + Reranking) using the latest NLP models.
* Automatic time/source filter extraction from natural language + custom model to identify user intent.
* User authentication and document level access management.
* Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.)
* Management Dashboard to manage connectors and set up features such as live update fetching.
* Support for LLMs of your choice (GPT-4, Llama2, Orca, etc.)
* Management Dashboards to manage connectors and set up features such as live update fetching.
* One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere.
## 🔌 Connectors
Danswer currently syncs documents (every 10 minutes) from:
Efficiently pulls the latest changes from:
* Slack
* GitHub
* Google Drive
* Confluence
* Jira
* Notion
* Gong
* Slab
* Linear
* Productboard
* Guru
* Zulip
* Bookstack
* Document360
* Request Tracker
* Hubspot
* Local Files
* Websites
* With more to come...
@ -75,7 +79,9 @@ Danswer currently syncs documents (every 10 minutes) from:
## 🚧 Roadmap
* Chat/Conversation support.
* Organizational understanding.
* Ability to locate and suggest experts.
* Code Search
* Structured Query Languages (SQL, Excel formulas, etc.)
* Ability to locate and suggest experts from your team.
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@ -140,18 +140,15 @@ def danswer_chat_retrieval(
)
# Good Debug/Breakpoint
ranked_chunks, unranked_chunks = search_chunks(
top_chunks, _ = search_chunks(
query=search_query, document_index=get_default_document_index()
)
if not ranked_chunks:
if not top_chunks:
return []
if unranked_chunks:
ranked_chunks.extend(unranked_chunks)
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
chunk for chunk in top_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit

View File

@ -178,8 +178,12 @@ MINI_CHUNK_SIZE = 150
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# We feed in document chunks until we reach this token limit.
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
# may be smaller which could result in passing in more total chunks
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
# significantly smaller which could result in passing in more total chunks.
# There is also a slight bit of overhead not accounted for here, such as separator patterns
# between the docs, metadata for the docs, etc.
# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
# model token limit
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
)
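A back-of-the-envelope check of the comment above; the 4-characters-per-token figure is a rough assumption, not something taken from the codebase.

```python
# Rough arithmetic behind the "~5 full chunks" comment; 4 chars/token is an assumed heuristic.
token_budget = 512 * 5            # default NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
max_chunk_chars = 2000            # max chunk size mentioned in the comment above
approx_chunk_tokens = max_chunk_chars // 4

print(token_budget // approx_chunk_tokens)  # -> 5 full-size chunks, before prompt overhead
```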
@ -198,12 +202,14 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2
DISABLE_LLM_FILTER_EXTRACTION = (
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
DISABLE_LLM_CHUNK_FILTER = (
os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
# Keyword Search Drop Stopwords
# If the user has changed the default model, it would most likely be to use a multilingual model;
# the stopwords are NLTK English stopwords, so in that case we would not want to drop the keywords

View File

@ -1,3 +1,4 @@
import os
FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true"
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
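A quick toy check of how these flags parse (the sample values are made up): HARD_DELETE_CHATS stays on for anything except the literal string "false", while FORCE_TOOL_PROMPT only turns on for the literal "true".

```python
# Toy illustration of the env-flag parsing above; the sample values are invented.
def hard_delete(env_value: str = "True") -> bool:
    return env_value.lower() != "false"

def force_tool(env_value: str = "") -> bool:
    return env_value.lower() == "true"

print(hard_delete(), hard_delete("False"))  # True False -> hard deletes unless explicitly disabled
print(force_tool(), force_tool("TRUE"))     # False True -> disabled unless explicitly enabled
```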

View File

@ -35,6 +35,8 @@ SCORE = "score"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"
class DocumentSource(str, Enum):

View File

@ -232,7 +232,7 @@ def handle_message(
logger.debug(answer.answer)
return True
if not answer.top_ranked_docs:
if not answer.top_documents:
logger.error(f"Unable to answer question: '{msg}' - no documents found")
# Optionally, respond in thread with the error message. Used primarily
# for debugging purposes
@ -265,8 +265,17 @@ def handle_message(
favor_recent=answer.favor_recent,
)
# Get the chunks fed to the LLM only, then fill with other docs
top_docs = answer.top_documents
llm_doc_inds = answer.llm_chunks_indices or []
llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [
doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
]
priority_ordered_docs = llm_docs + remaining_docs
document_blocks = build_documents_blocks(
documents=answer.top_ranked_docs, query_event_id=answer.query_event_id
documents=priority_ordered_docs,
query_event_id=answer.query_event_id,
)
try:

View File

@ -9,7 +9,7 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from danswer.configs.app_configs import HARD_DELETE_CHATS
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession

View File

@ -1,19 +1,19 @@
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.constants import QUERY_EVENT_ID
from danswer.db.feedback import update_query_event_llm_answer
from danswer.db.models import User
from danswer.direct_qa.factory import get_default_qa_model
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index
from danswer.search.danswer_helper import query_intent
from danswer.search.models import QueryFlow
@ -24,11 +24,12 @@ from danswer.search.search_runner import danswer_search
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
from danswer.server.models import QADocsResponse
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
@ -54,24 +55,34 @@ def answer_qa_query(
offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}")
functions_to_run: dict[Callable, tuple] = {
extract_question_time_filters: (question,),
extract_question_source_filters: (question, db_session),
query_intent: (query,),
}
run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
run_source_filters = FunctionCall(
extract_question_source_filters, (question, db_session), {}
)
run_query_intent = FunctionCall(query_intent, (query,), {})
parallel_results = run_functions_in_parallel(functions_to_run)
parallel_results = run_functions_in_parallel(
[
run_time_filters,
run_source_filters,
run_query_intent,
]
)
time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
source_filters = parallel_results["extract_question_source_filters"]
predicted_search, predicted_flow = parallel_results["query_intent"]
time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
source_filters = parallel_results[run_source_filters.result_id]
predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
# Set flow as search so frontend doesn't ask the user if they want to run QA over more docs
if disable_generative_answer:
predicted_flow = QueryFlow.SEARCH
# Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent
question.filters.source_type = source_filters
ranked_chunks, unranked_chunks, query_event_id = danswer_search(
top_chunks, llm_chunk_selection, query_event_id = danswer_search(
question=question,
user=user,
db_session=db_session,
@ -80,38 +91,23 @@ def answer_qa_query(
rerank_metrics_callback=rerank_metrics_callback,
)
if not ranked_chunks:
return QAResponse(
top_docs = chunks_to_search_docs(top_chunks)
partial_response = partial(
QAResponse,
top_documents=chunks_to_search_docs(top_chunks),
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
if disable_generative_answer or not top_docs:
return partial_response(
answer=None,
quotes=None,
top_ranked_docs=None,
lower_ranked_docs=None,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return QAResponse(
answer=None,
quotes=None,
top_ranked_docs=top_docs,
lower_ranked_docs=unranked_top_docs,
# set flow as search so frontend doesn't ask the user if they want
# to run QA over more documents
predicted_flow=QueryFlow.SEARCH,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
try:
@ -119,41 +115,28 @@ def answer_qa_query(
timeout=answer_generation_timeout, real_time_flow=real_time_flow
)
except Exception as e:
return QAResponse(
return partial_response(
answer=None,
quotes=None,
top_ranked_docs=top_docs,
lower_ranked_docs=unranked_top_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
error_msg=str(e),
)
# remove chunks marked as not applicable for QA (e.g. Google Drive file
# types which can't be parsed). These chunks are useful to show in the
# search results, but not for QA.
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
batch_offset=offset_count,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
)
error_msg = None
try:
d_answer, quotes = qa_model.answer_question(
query, usable_chunks, metrics_callback=llm_metrics_callback
query, llm_chunks, metrics_callback=llm_metrics_callback
)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
@ -169,37 +152,17 @@ def answer_qa_query(
user_id=None if user is None else user.id,
)
validity = None
if not real_time_flow and enable_reflexion and d_answer is not None:
valid = False
validity = False
if d_answer.answer is not None:
valid = get_answer_validity(query, d_answer.answer)
validity = get_answer_validity(query, d_answer.answer)
return QAResponse(
answer=d_answer.answer if d_answer else None,
quotes=quotes.quotes if quotes else None,
top_ranked_docs=top_docs,
lower_ranked_docs=unranked_top_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
eval_res_valid=True if valid else False,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
error_msg=error_msg,
)
return QAResponse(
return partial_response(
answer=d_answer.answer if d_answer else None,
quotes=quotes.quotes if quotes else None,
top_ranked_docs=top_docs,
lower_ranked_docs=unranked_top_docs,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
eval_res_valid=validity,
llm_chunks_indices=llm_chunks_indices,
error_msg=error_msg,
)
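The partial_response pattern above leans on functools.partial to bind the shared response fields once and fill in the per-branch fields later; here is a toy sketch using a simplified stand-in model rather than the real QAResponse.

```python
# Toy version of the partial_response pattern: shared fields bound once, branch-specific fields later.
from dataclasses import dataclass
from functools import partial


@dataclass
class Response:  # simplified stand-in for QAResponse
    query_event_id: int
    answer: str | None = None
    error_msg: str | None = None


partial_response = partial(Response, query_event_id=42)
print(partial_response(answer=None, error_msg="LLM timed out"))
print(partial_response(answer="Danswer syncs documents every 10 minutes."))
```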
@ -220,36 +183,47 @@ def answer_qa_query_stream(
query = question.query
offset_count = question.offset if question.offset is not None else 0
functions_to_run: dict[Callable, tuple] = {
extract_question_time_filters: (question,),
extract_question_source_filters: (question, db_session),
query_intent: (query,),
}
run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
run_source_filters = FunctionCall(
extract_question_source_filters, (question, db_session), {}
)
run_query_intent = FunctionCall(query_intent, (query,), {})
parallel_results = run_functions_in_parallel(functions_to_run)
parallel_results = run_functions_in_parallel(
[
run_time_filters,
run_source_filters,
run_query_intent,
]
)
time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
source_filters = parallel_results["extract_question_source_filters"]
predicted_search, predicted_flow = parallel_results["query_intent"]
time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
source_filters = parallel_results[run_source_filters.result_id]
predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]
# Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent
question.filters.source_type = source_filters
ranked_chunks, unranked_chunks, query_event_id = danswer_search(
top_chunks, llm_chunk_selection, query_event_id = danswer_search(
question=question,
user=user,
db_session=db_session,
document_index=get_default_document_index(),
)
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
top_docs = chunks_to_search_docs(top_chunks)
initial_response = RerankedRetrievalDocs(
llm_chunks_indices = get_chunks_for_qa(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
batch_offset=offset_count,
)
initial_response = QADocsResponse(
top_documents=top_docs,
unranked_top_documents=unranked_top_docs,
llm_chunks_indices=llm_chunks_indices,
# if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents
predicted_flow=QueryFlow.SEARCH
@ -260,10 +234,9 @@ def answer_qa_query_stream(
favor_recent=favor_recent,
).dict()
logger.debug(f"Sending Initial Retrival Results: {initial_response}")
yield get_json_line(initial_response)
if not ranked_chunks:
if not top_chunks:
logger.debug("No Documents Found")
return
@ -279,25 +252,13 @@ def answer_qa_query_stream(
yield get_json_line(error.dict())
return
# remove chunks marked as not applicable for QA (e.g. Google Drive file
# types which can't be parsed). These chunks are useful to show in the
# search results, but not for QA.
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
)
try:
for response_packet in qa_model.answer_question_stream(query, usable_chunks):
for response_packet in qa_model.answer_question_stream(query, llm_chunks):
if response_packet is None:
continue
if (
@ -321,4 +282,4 @@ def answer_qa_query_stream(
user_id=None if user is None else user.id,
)
yield get_json_line({"query_event_id": query_event_id})
yield get_json_line({QUERY_EVENT_ID: query_event_id})

View File

@ -11,6 +11,7 @@ import regex
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
@ -316,3 +317,57 @@ def get_usable_chunks(
offset_into_chunks += len(usable_chunks)
return usable_chunks
def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
batch_offset: int = 0,
) -> list[int]:
"""
Gives back indices of chunks to pass into the LLM for Q&A.
Only selects chunks viable for Q&A, within the token limit, and prioritizes those selected
by the LLM in a separate flow (this can be turned off).
Note: the batch_offset calculation has to count the batches from the beginning each time, as
there's no way to know which chunks were included in the prior batches without recounting.
This is somewhat slow since it requires tokenizing all the chunks again.
"""
batch_index = 0
latest_batch_indices: list[int] = []
token_count = 0
# First iterate the LLM selected chunks, then iterate the rest if tokens remaining
for selection_target in [True, False]:
for ind, chunk in enumerate(chunks):
if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
IGNORE_FOR_QA
):
continue
# We calculate it live in case the user uses a different LLM + tokenizer
chunk_token = check_number_of_tokens(chunk.content)
token_count += chunk_token
# Always use at least 1 chunk
if token_count <= token_limit or not latest_batch_indices:
latest_batch_indices.append(ind)
current_chunk_unused = False
else:
current_chunk_unused = True
if token_count >= token_limit:
if batch_index < batch_offset:
batch_index += 1
if current_chunk_unused:
latest_batch_indices = [ind]
token_count = chunk_token
else:
latest_batch_indices = []
token_count = 0
else:
return latest_batch_indices
return latest_batch_indices
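A minimal usage sketch of get_chunks_for_qa, mirroring the call sites earlier in this commit; it assumes top_chunks and llm_chunk_selection came back from danswer_search in the same order, and the wrapper function name here is invented for illustration.

```python
# Hypothetical call-site sketch; get_chunks_for_qa and its parameters come from this commit.
from danswer.direct_qa.qa_utils import get_chunks_for_qa


def select_llm_chunks(top_chunks, llm_chunk_selection, batch_num: int = 0):
    # Indices of chunks that fit the token budget, with LLM-approved chunks considered first
    llm_chunks_indices = get_chunks_for_qa(
        chunks=top_chunks,
        llm_chunk_selection=llm_chunk_selection,
        batch_offset=batch_num,  # 0 = first batch; later batches recount from the start
    )
    return [top_chunks[i] for i in llm_chunks_indices]
```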

View File

@ -33,11 +33,6 @@ class LangChainChatLLM(LLM, abc.ABC):
def llm(self) -> BaseChatModel:
raise NotImplementedError
def _log_model_config(self) -> None:
logger.debug(
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
)
@staticmethod
def _log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
@ -46,8 +41,12 @@ class LangChainChatLLM(LLM, abc.ABC):
if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}")
def log_model_configs(self) -> None:
logger.debug(
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
)
def invoke(self, prompt: LanguageModelInput) -> str:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt)
@ -58,7 +57,6 @@ class LangChainChatLLM(LLM, abc.ABC):
return model_raw
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt)

View File

@ -9,6 +9,10 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.utils.logger import setup_logger
logger = setup_logger()
class CustomModelServer(LLM):
@ -65,6 +69,9 @@ class CustomModelServer(LLM):
response.raise_for_status()
return json.loads(response.content).get("generated_text", "")
def log_model_configs(self) -> None:
logger.debug(f"Custom model at: {self._endpoint}")
def invoke(self, prompt: LanguageModelInput) -> str:
return self._execute(prompt)

View File

@ -61,6 +61,11 @@ class DanswerGPT4All(LLM):
self.temperature = temperature
self.gpt4all_model = GPT4All(model_version)
def log_model_configs(self) -> None:
logger.debug(
f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}"
)
def invoke(self, prompt: LanguageModelInput) -> str:
prompt_basic = convert_lm_input_to_basic_string(prompt)
return self.gpt4all_model.generate(prompt_basic)

View File

@ -22,6 +22,10 @@ class LLM(abc.ABC):
def requires_api_key(self) -> bool:
return True
@abc.abstractmethod
def log_model_configs(self) -> None:
raise NotImplementedError
@abc.abstractmethod
def invoke(self, prompt: LanguageModelInput) -> str:
raise NotImplementedError

View File

@ -36,6 +36,7 @@ from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.factory import get_default_qa_model
from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.server.cc_pair.api import router as cc_pair_router
from danswer.server.chat_backend import router as chat_router
from danswer.server.connector import router as connector_router
@ -197,7 +198,7 @@ def get_application() -> FastAPI:
warm_up_models()
# This is for the LLM, most LLMs will not need warming up
# It logs for itself
get_default_llm().log_model_configs()
get_default_qa_model().warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")

View File

@ -132,6 +132,29 @@ Note: The "file" source only applies to when the user refers to uploaded files i
""".strip()
USEFUL_PAT = "Yes useful"
NONUSEFUL_PAT = "Not useful"
CHUNK_FILTER_PROMPT = f"""
Determine if the reference section is USEFUL for answering the user query.
It is NOT enough for the section to be related to the query, \
it must contain information that is USEFUL for answering the query.
If the section contains ANY useful information, that is good enough, \
it does not need to fully answer every part of the user query.
Reference Section:
```
{{chunk_text}}
```
User Query:
```
{{user_query}}
```
Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
""".strip()
# Use the following for easy viewing of prompts
if __name__ == "__main__":
print(ANSWERABLE_PROMPT)

View File

@ -3,6 +3,7 @@ from enum import Enum
from pydantic import BaseModel
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
@ -57,6 +58,9 @@ class SearchQuery(BaseModel):
skip_rerank: bool = SKIP_RERANKING
# Only used if not skip_rerank
num_rerank: int | None = NUM_RERANKED_RESULTS
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER
# Only used if not skip_llm_chunk_filter
max_llm_filter_chunks: int = NUM_RERANKED_RESULTS
class RetrievalMetricsContainer(BaseModel):

View File

@ -8,7 +8,9 @@ from nltk.tokenize import word_tokenize # type:ignore
from sentence_transformers import SentenceTransformer # type: ignore
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
@ -33,9 +35,12 @@ from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
@ -147,7 +152,12 @@ def semantic_reranking(
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
) -> list[InferenceChunk]:
) -> tuple[list[InferenceChunk], list[int]]:
"""Reranks chunks based on cross-encoder models. Additionally provides the original indices
of the chunks in their new sorted order.
Note: this updates the chunks in place, overwriting the scores that came from retrieval
"""
cross_encoders = CrossEncoderEnsembleModel()
passages = [chunk.content for chunk in chunks]
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
@ -168,16 +178,20 @@ def semantic_reranking(
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
model_max - model_min
)
scored_results = list(zip(normalized_b_s_scores, raw_sim_scores, chunks))
orig_indices = [i for i in range(len(normalized_b_s_scores))]
scored_results = list(
zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
)
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_raw_scores, ranked_chunks = zip(*scored_results)
ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
*scored_results
)
logger.debug(
f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
)
# Assign new chunk scores based on reranking
# TODO if pagination is added, the scores won't make sense with respect to the non-reranked hits
for ind, chunk in enumerate(ranked_chunks):
chunk.score = ranked_sim_scores[ind]
@ -198,7 +212,7 @@ def semantic_reranking(
)
)
return list(ranked_chunks)
return list(ranked_chunks), list(ranked_indices)
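A small illustration (toy values, not Danswer objects) of why the original indices are returned: any list that was parallel to the pre-rerank chunk order, such as the LLM chunk selection, can be realigned to the new order.

```python
# Toy example of realigning a parallel list using indices returned in the new sorted order.
pre_rerank_flags = [True, False, True, False]   # e.g. LLM usefulness per chunk, retrieval order
ranked_indices = [2, 0, 3, 1]                   # hypothetical output order from reranking

reranked_flags = [pre_rerank_flags[ind] for ind in ranked_indices]
print(reranked_flags)  # [True, True, False, False]
```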
def apply_boost_legacy(
@ -257,6 +271,9 @@ def apply_boost_legacy(
def apply_boost(
chunks: list[InferenceChunk],
# Need the range of values to not be too spread out for applying boost
# therefore norm across only the top few results
norm_cutoff: int = NUM_RERANKED_RESULTS,
norm_min: float = SIM_SCORE_RANGE_LOW,
norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
@ -266,13 +283,13 @@ def apply_boost(
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
recency_multiplier = [chunk.recency_bias for chunk in chunks]
norm_min = min(norm_min, min(scores))
norm_max = max(norm_max, max(scores))
norm_min = min(norm_min, min(scores[:norm_cutoff]))
norm_max = max(norm_max, max(scores[:norm_cutoff]))
# This should never be 0 unless user has done some weird/wrong settings
norm_range = norm_max - norm_min
boosted_scores = [
(score - norm_min) * boost * recency / norm_range
max(0, (score - norm_min) * boost * recency / norm_range)
for score, boost, recency in zip(scores, boosts, recency_multiplier)
]
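A toy numeric walk-through of the boosted score formula above (values are made up): normalizing over only the top norm_cutoff scores keeps the range tight, and the max(0, ...) clamp keeps scores that fall below the norm minimum from going negative.

```python
# Toy numbers for the boost formula; not real retrieval scores, norm_cutoff assumed to be 2.
scores = [0.92, 0.88, 0.55]
boosts = [1.0, 1.5, 1.0]     # assumed translate_boost_count_to_multiplier outputs
recency = [1.0, 1.0, 0.8]

norm_min, norm_max = min(scores[:2]), max(scores[:2])  # norm across only the top few results
norm_range = norm_max - norm_min

boosted = [
    max(0, (score - norm_min) * boost * rec / norm_range)
    for score, boost, rec in zip(scores, boosts, recency)
]
print(boosted)  # last score clamps to 0 because it falls below norm_min
```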
@ -299,7 +316,14 @@ def search_chunks(
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
) -> tuple[list[InferenceChunk], list[bool]]:
"""Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
For sake of speed, the system cannot rerank all retrieved chunks
Also pass the chunks through LLM to determine if they are relevant (binary for speed)
Only the first max_llm_filter_chunks
"""
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
top_links = [
c.source_links[0] if c.source_links is not None else "No Link"
@ -316,7 +340,7 @@ def search_chunks(
f"{query.search_type.value.capitalize()} search returned no results "
f"with filters: {query.filters}"
)
return None, None
return [], []
if retrieval_metrics_callback is not None:
chunk_metrics = [
@ -332,27 +356,62 @@ def search_chunks(
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
)
# Keyword Search should never do reranking, no transformers involved in this flow
if query.search_type == SearchType.KEYWORD:
functions_to_run: list[FunctionCall] = []
# Keyword Search should not do reranking
if query.search_type == SearchType.KEYWORD or query.skip_rerank:
_log_top_chunk_links(query.search_type.value, top_chunks)
return top_chunks, None
run_rerank_id: str | None = None
else:
run_rerank = FunctionCall(
semantic_reranking,
(query.query, top_chunks[: query.num_rerank]),
{"rerank_metrics_callback": rerank_metrics_callback},
)
functions_to_run.append(run_rerank)
run_rerank_id = run_rerank.result_id
if query.skip_rerank:
# Need the range of values to not be too spread out for applying boost
# Therefore pass in smaller set of chunks to limit the range for norm-ing
boosted_chunks = apply_boost(top_chunks[: query.num_rerank])
_log_top_chunk_links(query.search_type.value, boosted_chunks)
return boosted_chunks, top_chunks[query.num_rerank :]
run_llm_filter_id = None
if not query.skip_llm_chunk_filter:
run_llm_filter = FunctionCall(
llm_batch_eval_chunks,
(
query.query,
[chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
),
{},
)
functions_to_run.append(run_llm_filter)
run_llm_filter_id = run_llm_filter.result_id
ranked_chunks = semantic_reranking(
query.query,
top_chunks[: query.num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
parallel_results = run_functions_in_parallel(functions_to_run)
ranked_results = parallel_results.get(str(run_rerank_id))
if ranked_results is None:
ranked_chunks = top_chunks
sorted_indices = [i for i in range(len(top_chunks))]
else:
ranked_chunks, orig_indices = ranked_results
sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
lower_chunks = top_chunks[query.num_rerank :]
# Scores from rerank cannot be meaningfully combined with scores without rerank
for lower_chunk in lower_chunks:
lower_chunk.score = None
ranked_chunks.extend(lower_chunks)
llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
if llm_chunk_selection is None:
reranked_llm_chunk_selection = [True for _ in top_chunks]
else:
llm_chunk_selection.extend(
[False for _ in top_chunks[query.max_llm_filter_chunks :]]
)
reranked_llm_chunk_selection = [
llm_chunk_selection[ind] for ind in sorted_indices
]
_log_top_chunk_links(query.search_type.value, ranked_chunks)
return ranked_chunks, top_chunks[query.num_rerank :]
return ranked_chunks, reranked_llm_chunk_selection
def danswer_search(
@ -360,10 +419,11 @@ def danswer_search(
user: User | None,
db_session: Session,
document_index: DocumentIndex,
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None, int]:
) -> tuple[list[InferenceChunk], list[bool], int]:
query_event_id = create_query_event(
query=question.query,
search_type=question.search_type,
@ -384,17 +444,21 @@ def danswer_search(
query=question.query,
search_type=question.search_type,
filters=final_filters,
favor_recent=True if question.favor_recent is None else question.favor_recent,
# Still applies time decay but not magnified
favor_recent=question.favor_recent
if question.favor_recent is not None
else False,
skip_llm_chunk_filter=skip_llm_chunk_filter,
)
ranked_chunks, unranked_chunks = search_chunks(
top_chunks, llm_chunk_selection = search_chunks(
query=search_query,
document_index=document_index,
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
retrieved_ids = [doc.document_id for doc in ranked_chunks] if ranked_chunks else []
retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []
update_query_event_retrieved_documents(
db_session=db_session,
@ -403,4 +467,4 @@ def danswer_search(
user_id=None if user is None else user.id,
)
return ranked_chunks, unranked_chunks, query_event_id
return top_chunks, llm_chunk_selection, query_event_id

View File

@ -0,0 +1,65 @@
from collections.abc import Callable
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.secondary_llm_flows import CHUNK_FILTER_PROMPT
from danswer.prompts.secondary_llm_flows import NONUSEFUL_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def llm_eval_chunk(query: str, chunk_content: str) -> bool:
def _get_usefulness_messages() -> list[dict[str, str]]:
messages = [
{
"role": "user",
"content": CHUNK_FILTER_PROMPT.format(
chunk_text=chunk_content, user_query=query
),
},
]
return messages
def _extract_usefulness(model_output: str) -> bool:
"""Default useful if the LLM doesn't match pattern exactly
This is because it's better to trust the (re)ranking if LLM fails"""
if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
return False
return True
messages = _get_usefulness_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
# When running in a batch, it takes as long as the longest thread.
# In a large batch, one call may fail and hold things up for the whole timeout,
# so instead cap each call to 5 seconds.
model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
logger.debug(model_output)
return _extract_usefulness(model_output)
def llm_batch_eval_chunks(
query: str, chunk_contents: list[str], use_threads: bool = True
) -> list[bool]:
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents
]
logger.debug(
"Running LLM usefulness eval in parallel (following logging may be out of order)"
)
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
# In case of failure/timeout, don't throw out the chunk
return [True if item is None else item for item in parallel_results]
else:
return [
llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents
]
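A brief usage sketch (the query and chunk contents are invented, and a configured default LLM is assumed): the batch call fans out one LLM evaluation per chunk and returns a parallel list of booleans.

```python
# Hypothetical usage; assumes Danswer's default LLM is configured and reachable.
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks

keep_flags = llm_batch_eval_chunks(
    query="How do I deploy Danswer on Kubernetes?",
    chunk_contents=[
        "Danswer ships Kubernetes deployment files alongside the Docker Compose setup.",
        "The bumblebee is an important pollinator of many wildflower species.",
    ],
)
# e.g. [True, False]; a failed or timed-out evaluation comes back as True so the chunk isn't thrown out
print(keep_flags)
```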

View File

@ -171,12 +171,58 @@ class SearchDoc(BaseModel):
return initial_dict
class QuestionRequest(BaseModel):
query: str
collection: str
filters: BaseFilters
offset: int | None
enable_auto_detect_filters: bool
favor_recent: bool | None = None
search_type: SearchType = SearchType.HYBRID
class QAFeedbackRequest(BaseModel):
query_id: int
feedback: QAFeedbackType
class SearchFeedbackRequest(BaseModel):
query_id: int
document_id: str
document_rank: int
click: bool
search_feedback: SearchFeedbackType
class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool
class RetrievalDocs(BaseModel):
top_documents: list[SearchDoc]
class RerankedRetrievalDocs(RetrievalDocs):
unranked_top_documents: list[SearchDoc]
class SearchResponse(RetrievalDocs):
query_event_id: int
source_type: list[DocumentSource] | None
time_cutoff: datetime | None
favor_recent: bool
class QAResponse(SearchResponse):
answer: str | None # DanswerAnswer
quotes: list[DanswerQuote] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
llm_chunks_indices: list[int] | None = None
error_msg: str | None = None
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
llm_chunks_indices: list[int]
predicted_flow: QueryFlow
predicted_search: SearchType
time_cutoff: datetime | None
@ -194,21 +240,6 @@ class CreateChatSessionID(BaseModel):
chat_session_id: int
class QuestionRequest(BaseModel):
query: str
collection: str
filters: BaseFilters
offset: int | None
enable_auto_detect_filters: bool
favor_recent: bool | None = None
search_type: SearchType = SearchType.HYBRID
class QAFeedbackRequest(BaseModel):
query_id: int
feedback: QAFeedbackType
class ChatFeedbackRequest(BaseModel):
chat_session_id: int
message_number: int
@ -217,14 +248,6 @@ class ChatFeedbackRequest(BaseModel):
feedback_text: str | None = None
class SearchFeedbackRequest(BaseModel):
query_id: int
document_id: str
document_rank: int
click: bool
search_feedback: SearchFeedbackType
class CreateChatMessageRequest(BaseModel):
chat_session_id: int
message_number: int
@ -280,30 +303,6 @@ class ChatSessionDetailResponse(BaseModel):
messages: list[ChatMessageDetail]
class QueryValidationResponse(BaseModel):
reasoning: str
answerable: bool
class SearchResponse(BaseModel):
# For semantic search, top docs are reranked, the remaining are as ordered from retrieval
top_ranked_docs: list[SearchDoc] | None
lower_ranked_docs: list[SearchDoc] | None
query_event_id: int
source_type: list[DocumentSource] | None
time_cutoff: datetime | None
favor_recent: bool
class QAResponse(SearchResponse):
answer: str | None # DanswerAnswer
quotes: list[DanswerQuote] | None
predicted_flow: QueryFlow
predicted_search: SearchType
eval_res_valid: bool | None = None
error_msg: str | None = None
class UserByEmail(BaseModel):
user_email: str

View File

@ -1,5 +1,3 @@
from collections.abc import Callable
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@ -36,6 +34,7 @@ from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
logger = setup_logger()
@ -131,10 +130,10 @@ def handle_search_request(
query = question.query
logger.info(f"Received {question.search_type.value} " f"search query: {query}")
functions_to_run: dict[Callable, tuple] = {
extract_question_time_filters: (question,),
extract_question_source_filters: (question, db_session),
}
functions_to_run = [
FunctionCall(extract_question_time_filters, (question,), {}),
FunctionCall(extract_question_source_filters, (question, db_session), {}),
]
parallel_results = run_functions_in_parallel(functions_to_run)
@ -145,29 +144,18 @@ def handle_search_request(
question.favor_recent = favor_recent
question.filters.source_type = source_filters
ranked_chunks, unranked_chunks, query_event_id = danswer_search(
top_chunks, _, query_event_id = danswer_search(
question=question,
user=user,
db_session=db_session,
document_index=get_default_document_index(),
skip_llm_chunk_filter=True,
)
if not ranked_chunks:
return SearchResponse(
top_ranked_docs=None,
lower_ranked_docs=None,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
)
top_docs = chunks_to_search_docs(ranked_chunks)
lower_top_docs = chunks_to_search_docs(unranked_chunks)
top_docs = chunks_to_search_docs(top_chunks)
return SearchResponse(
top_ranked_docs=top_docs,
lower_ranked_docs=lower_top_docs or None,
top_documents=top_docs,
query_event_id=query_event_id,
source_type=source_filters,
time_cutoff=time_cutoff,

View File

@ -37,16 +37,20 @@ def optional_telemetry(record_type: RecordType, data: dict) -> None:
try:
def telemetry_logic() -> None:
payload = {
"data": data,
"record": record_type,
"customer_uuid": get_or_generate_uuid(),
}
requests.post(
DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},
json=payload,
)
try:
payload = {
"data": data,
"record": record_type,
"customer_uuid": get_or_generate_uuid(),
}
requests.post(
DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},
json=payload,
)
except Exception:
# This way it silences all thread level logging as well
pass
# Run in separate thread to have minimal overhead in main flows
thread = threading.Thread(target=telemetry_logic, daemon=True)

View File

@ -1,3 +1,4 @@
import uuid
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
@ -8,31 +9,82 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def run_functions_in_parallel(
functions_with_args: dict[Callable, tuple]
) -> dict[str, Any]:
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
allow_failures: bool = False,
) -> list[Any]:
"""
Executes multiple functions in parallel and returns a dictionary with the results.
Executes multiple functions in parallel and returns a list of the results for each function.
Args:
functions_with_args (dict): A dictionary mapping functions to a tuple of arguments.
functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
allow_failures: if set to True, then the function result will just be None
Returns:
dict: A dictionary mapping function names to their results or error messages.
"""
results = {}
results = []
with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor:
future_to_function = {
executor.submit(func, *args): func.__name__
for func, args in functions_with_args.items()
future_to_index = {
executor.submit(func, *args): i
for i, (func, args) in enumerate(functions_with_args)
}
for future in as_completed(future_to_function):
function_name = future_to_function[future]
for future in as_completed(future_to_index):
index = future_to_index[future]
try:
results[function_name] = future.result()
results.append((index, future.result()))
except Exception as e:
logger.exception(f"Function {function_name} failed due to {e}")
raise
logger.exception(f"Function at index {index} failed due to {e}")
results.append((index, None))
if not allow_failures:
raise
results.sort(key=lambda x: x[0])
return [result for index, result in results]
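A quick sketch of the tuple-based helper with toy functions: results come back in input order, and with allow_failures=True a failed call yields None instead of raising.

```python
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel


def add(a: int, b: int) -> int:
    return a + b


def fail(_: int) -> int:
    raise ValueError("boom")


results = run_functions_tuples_in_parallel(
    [(add, (1, 2)), (fail, (0,)), (add, (3, 4))],
    allow_failures=True,
)
print(results)  # [3, None, 7]; order matches the input list
```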
class FunctionCall:
"""
Container for run_functions_in_parallel, fetch the results from the output of
run_functions_in_parallel via the FunctionCall.result_id.
"""
def __init__(self, func: Callable, args: tuple = (), kwargs: dict | None = None):
self.func = func
self.args = args
self.kwargs = kwargs if kwargs is not None else {}
self.result_id = str(uuid.uuid4())
def execute(self) -> Any:
return self.func(*self.args, **self.kwargs)
def run_functions_in_parallel(
function_calls: list[FunctionCall],
allow_failures: bool = False,
) -> dict[str, Any]:
"""
Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
are the result_id of the FunctionCall and the values are the results of the call.
"""
results = {}
with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
future_to_id = {
executor.submit(func_call.execute): func_call.result_id
for func_call in function_calls
}
for future in as_completed(future_to_id):
result_id = future_to_id[future]
try:
results[result_id] = future.result()
except Exception as e:
logger.exception(f"Function with ID {result_id} failed due to {e}")
results[result_id] = None
if not allow_failures:
raise
return results
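And a matching sketch for the FunctionCall variant with toy functions, mirroring how answer_qa_query looks results up by result_id.

```python
from danswer.utils.threadpool_concurrency import FunctionCall, run_functions_in_parallel


def add(a: int, b: int) -> int:
    return a + b


run_add = FunctionCall(add, (1, 2), {})
run_pow = FunctionCall(pow, (2, 10), {})

parallel_results = run_functions_in_parallel([run_add, run_pow])
print(parallel_results[run_add.result_id])  # 3
print(parallel_results[run_pow.result_id])  # 1024
```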

View File

@ -32,6 +32,7 @@ services:
- DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
- NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
- DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
- DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
# Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
- DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
# Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)

View File

@ -116,6 +116,11 @@ export const DocumentDisplay = ({
}: DocumentDisplayProps) => {
const [isHovered, setIsHovered] = useState(false);
// Consider reintroducing null scored docs in the future
if (document.score === null) {
return null;
}
return (
<div
key={document.semantic_identifier}
@ -126,23 +131,25 @@ export const DocumentDisplay = ({
onMouseLeave={() => setIsHovered(false)}
>
<div className="flex relative">
<div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
<div
className={`
text-xs
text-gray-200
bg-gray-800
rounded
p-0.5
w-fit
my-auto
select-none
ml-auto
mr-2`}
>
{document.score.toFixed(2)}
{document.score !== null && (
<div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
<div
className={`
text-xs
text-gray-200
bg-gray-800
rounded
p-0.5
w-fit
my-auto
select-none
ml-auto
mr-2`}
>
{document.score.toFixed(2)}
</div>
</div>
</div>
)}
<a
className={
"rounded-lg flex font-bold " +