LLM Chunk Filtering (#735)

This commit is contained in:
parent d5916e420c
commit fa0d19cc8c

README.md (18 changed lines)
@@ -46,28 +46,32 @@ We also have built-in support for deployment on Kubernetes. Files for that can b

## 💃 Features
* Direct QA powered by Generative AI models with answers backed by quotes and source links.
* Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs.
* An AI Helper backed by a custom Deep Learning model to interpret user intent.
* Intelligent Document Retrieval (Hybrid Search + Reranking) using the latest NLP models.
* Automatic time/source filter extraction from natural language + custom model to identify user intent.
* User authentication and document level access management.
* Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.)
* Management Dashboard to manage connectors and set up features such as live update fetching.
* Support for LLMs of your choice (GPT-4, Llama2, Orca, etc.)
* Management Dashboards to manage connectors and set up features such as live update fetching.
* One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere.

## 🔌 Connectors

Danswer currently syncs documents (every 10 minutes) from:
Efficiently pulls the latest changes from:
* Slack
* GitHub
* Google Drive
* Confluence
* Jira
* Notion
* Gong
* Slab
* Linear
* Productboard
* Guru
* Zulip
* Bookstack
* Document360
* Request Tracker
* Hubspot
* Local Files
* Websites
* With more to come...

@@ -75,7 +79,9 @@ Danswer currently syncs documents (every 10 minutes) from:
## 🚧 Roadmap
* Chat/Conversation support.
* Organizational understanding.
* Ability to locate and suggest experts.
* Code Search
* Structured Query Languages (SQL, Excel formulas, etc.)
* Ability to locate and suggest experts from your team.

## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
@@ -140,18 +140,15 @@ def danswer_chat_retrieval(
    )

    # Good Debug/Breakpoint
    ranked_chunks, unranked_chunks = search_chunks(
    top_chunks, _ = search_chunks(
        query=search_query, document_index=get_default_document_index()
    )

    if not ranked_chunks:
    if not top_chunks:
        return []

    if unranked_chunks:
        ranked_chunks.extend(unranked_chunks)

    filtered_ranked_chunks = [
        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
        chunk for chunk in top_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]

    # get all chunks that fit into the token limit
@@ -178,8 +178,12 @@ MINI_CHUNK_SIZE = 150
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# We feed in document chunks until we reach this token limit.
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
# may be smaller which could result in passing in more total chunks
# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
# significantly smaller which could result in passing in more total chunks.
# There is also a slight bit of overhead, not accounted for here such as separator patterns
# between the docs, metadata for the docs, etc.
# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
# model token limit
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
    os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
)
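For context, the default above works out to 512 * 5 = 2560 tokens. A rough back-of-the-envelope sketch of why the comment calls this "~5 full chunks"; the 4-characters-per-token ratio is an assumption for illustration, not something stated in the diff:

```python
# Rough sanity check of the default token budget (assumes ~4 chars per token).
MAX_CHUNK_CHARS = 2000          # max chunk size mentioned in the comment above
APPROX_CHARS_PER_TOKEN = 4      # assumption; varies by tokenizer

token_budget = 512 * 5          # default NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
tokens_per_full_chunk = MAX_CHUNK_CHARS // APPROX_CHARS_PER_TOKEN  # ~500

print(token_budget // tokens_per_full_chunk)  # -> 5, i.e. roughly five full chunks
```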
@@ -198,12 +202,14 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2
DISABLE_LLM_FILTER_EXTRACTION = (
    os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
# 1 edit per 2 characters, currently unused due to fuzzy match being too slow
DISABLE_LLM_CHUNK_FILTER = (
    os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
# 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
# Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
@@ -1,3 +1,4 @@
import os

FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true"
HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
@@ -35,6 +35,8 @@ SCORE = "score"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
QUERY_EVENT_ID = "query_event_id"
LLM_CHUNKS = "llm_chunks"


class DocumentSource(str, Enum):
@@ -232,7 +232,7 @@ def handle_message(
        logger.debug(answer.answer)
        return True

    if not answer.top_ranked_docs:
    if not answer.top_documents:
        logger.error(f"Unable to answer question: '{msg}' - no documents found")
        # Optionally, respond in thread with the error message, Used primarily
        # for debugging purposes

@@ -265,8 +265,17 @@ def handle_message(
        favor_recent=answer.favor_recent,
    )

    # Get the chunks fed to the LLM only, then fill with other docs
    top_docs = answer.top_documents
    llm_doc_inds = answer.llm_chunks_indices or []
    llm_docs = [top_docs[i] for i in llm_doc_inds]
    remaining_docs = [
        doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
    ]
    priority_ordered_docs = llm_docs + remaining_docs
    document_blocks = build_documents_blocks(
        documents=answer.top_ranked_docs, query_event_id=answer.query_event_id
        documents=priority_ordered_docs,
        query_event_id=answer.query_event_id,
    )

    try:
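A toy illustration of the reordering introduced above (LLM-selected documents first, then the rest); the document names and indices are made up:

```python
# Illustrative only: mirrors the llm_docs / remaining_docs split above.
top_docs = ["doc_a", "doc_b", "doc_c", "doc_d"]
llm_doc_inds = [2, 0]  # indices the LLM chunk filter picked

llm_docs = [top_docs[i] for i in llm_doc_inds]
remaining_docs = [doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds]

print(llm_docs + remaining_docs)  # ['doc_c', 'doc_a', 'doc_b', 'doc_d']
```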
@@ -9,7 +9,7 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from danswer.configs.app_configs import HARD_DELETE_CHATS
from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
@@ -1,19 +1,19 @@
from collections.abc import Callable
from collections.abc import Iterator
from functools import partial

from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.constants import QUERY_EVENT_ID
from danswer.db.feedback import update_query_event_llm_answer
from danswer.db.models import User
from danswer.direct_qa.factory import get_default_qa_model
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index
from danswer.search.danswer_helper import query_intent
from danswer.search.models import QueryFlow

@@ -24,11 +24,12 @@ from danswer.search.search_runner import danswer_search
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
from danswer.server.models import QADocsResponse
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
@@ -54,24 +55,34 @@ def answer_qa_query(
    offset_count = question.offset if question.offset is not None else 0
    logger.info(f"Received QA query: {query}")

    functions_to_run: dict[Callable, tuple] = {
        extract_question_time_filters: (question,),
        extract_question_source_filters: (question, db_session),
        query_intent: (query,),
    }
    run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
    run_source_filters = FunctionCall(
        extract_question_source_filters, (question, db_session), {}
    )
    run_query_intent = FunctionCall(query_intent, (query,), {})

    parallel_results = run_functions_in_parallel(functions_to_run)
    parallel_results = run_functions_in_parallel(
        [
            run_time_filters,
            run_source_filters,
            run_query_intent,
        ]
    )

    time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
    source_filters = parallel_results["extract_question_source_filters"]
    predicted_search, predicted_flow = parallel_results["query_intent"]
    time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
    source_filters = parallel_results[run_source_filters.result_id]
    predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]

    # Set flow as search so frontend doesn't ask the user if they want to run QA over more docs
    if disable_generative_answer:
        predicted_flow = QueryFlow.SEARCH

    # Modifies the question object but nothing upstream uses it
    question.filters.time_cutoff = time_cutoff
    question.favor_recent = favor_recent
    question.filters.source_type = source_filters

    ranked_chunks, unranked_chunks, query_event_id = danswer_search(
    top_chunks, llm_chunk_selection, query_event_id = danswer_search(
        question=question,
        user=user,
        db_session=db_session,

@@ -80,38 +91,23 @@ def answer_qa_query(
        rerank_metrics_callback=rerank_metrics_callback,
    )

    if not ranked_chunks:
        return QAResponse(
    top_docs = chunks_to_search_docs(top_chunks)

    partial_response = partial(
        QAResponse,
        top_documents=chunks_to_search_docs(top_chunks),
        predicted_flow=predicted_flow,
        predicted_search=predicted_search,
        query_event_id=query_event_id,
        source_type=source_filters,
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
    )

    if disable_generative_answer or not top_docs:
        return partial_response(
            answer=None,
            quotes=None,
            top_ranked_docs=None,
            lower_ranked_docs=None,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            query_event_id=query_event_id,
            source_type=source_filters,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
        )

    top_docs = chunks_to_search_docs(ranked_chunks)
    unranked_top_docs = chunks_to_search_docs(unranked_chunks)

    if disable_generative_answer:
        logger.debug("Skipping QA because generative AI is disabled")
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=top_docs,
            lower_ranked_docs=unranked_top_docs,
            # set flow as search so frontend doesn't ask the user if they want
            # to run QA over more documents
            predicted_flow=QueryFlow.SEARCH,
            predicted_search=predicted_search,
            query_event_id=query_event_id,
            source_type=source_filters,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
        )

    try:

@@ -119,41 +115,28 @@ def answer_qa_query(
            timeout=answer_generation_timeout, real_time_flow=real_time_flow
        )
    except Exception as e:
        return QAResponse(
        return partial_response(
            answer=None,
            quotes=None,
            top_ranked_docs=top_docs,
            lower_ranked_docs=unranked_top_docs,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            query_event_id=query_event_id,
            source_type=source_filters,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
            error_msg=str(e),
        )

    # remove chunks marked as not applicable for QA (e.g. Google Drive file
    # types which can't be parsed). These chunks are useful to show in the
    # search results, but not for QA.
    filtered_ranked_chunks = [
        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]

    # get all chunks that fit into the token limit
    usable_chunks = get_usable_chunks(
        chunks=filtered_ranked_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
        offset=offset_count,
    llm_chunks_indices = get_chunks_for_qa(
        chunks=top_chunks,
        llm_chunk_selection=llm_chunk_selection,
        batch_offset=offset_count,
    )

    llm_chunks = [top_chunks[i] for i in llm_chunks_indices]

    logger.debug(
        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
    )

    error_msg = None
    try:
        d_answer, quotes = qa_model.answer_question(
            query, usable_chunks, metrics_callback=llm_metrics_callback
            query, llm_chunks, metrics_callback=llm_metrics_callback
        )
    except Exception as e:
        # exception is logged in the answer_question method, no need to re-log

@@ -169,37 +152,17 @@ def answer_qa_query(
        user_id=None if user is None else user.id,
    )

    validity = None
    if not real_time_flow and enable_reflexion and d_answer is not None:
        valid = False
        validity = False
        if d_answer.answer is not None:
            valid = get_answer_validity(query, d_answer.answer)
            validity = get_answer_validity(query, d_answer.answer)

        return QAResponse(
            answer=d_answer.answer if d_answer else None,
            quotes=quotes.quotes if quotes else None,
            top_ranked_docs=top_docs,
            lower_ranked_docs=unranked_top_docs,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            eval_res_valid=True if valid else False,
            query_event_id=query_event_id,
            source_type=source_filters,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
            error_msg=error_msg,
        )

    return QAResponse(
    return partial_response(
        answer=d_answer.answer if d_answer else None,
        quotes=quotes.quotes if quotes else None,
        top_ranked_docs=top_docs,
        lower_ranked_docs=unranked_top_docs,
        predicted_flow=predicted_flow,
        predicted_search=predicted_search,
        query_event_id=query_event_id,
        source_type=source_filters,
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
        eval_res_valid=validity,
        llm_chunks_indices=llm_chunks_indices,
        error_msg=error_msg,
    )
@@ -220,36 +183,47 @@ def answer_qa_query_stream(
    query = question.query
    offset_count = question.offset if question.offset is not None else 0

    functions_to_run: dict[Callable, tuple] = {
        extract_question_time_filters: (question,),
        extract_question_source_filters: (question, db_session),
        query_intent: (query,),
    }
    run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
    run_source_filters = FunctionCall(
        extract_question_source_filters, (question, db_session), {}
    )
    run_query_intent = FunctionCall(query_intent, (query,), {})

    parallel_results = run_functions_in_parallel(functions_to_run)
    parallel_results = run_functions_in_parallel(
        [
            run_time_filters,
            run_source_filters,
            run_query_intent,
        ]
    )

    time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
    source_filters = parallel_results["extract_question_source_filters"]
    predicted_search, predicted_flow = parallel_results["query_intent"]
    time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
    source_filters = parallel_results[run_source_filters.result_id]
    predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]

    # Modifies the question object but nothing upstream uses it
    question.filters.time_cutoff = time_cutoff
    question.favor_recent = favor_recent
    question.filters.source_type = source_filters

    ranked_chunks, unranked_chunks, query_event_id = danswer_search(
    top_chunks, llm_chunk_selection, query_event_id = danswer_search(
        question=question,
        user=user,
        db_session=db_session,
        document_index=get_default_document_index(),
    )

    top_docs = chunks_to_search_docs(ranked_chunks)
    unranked_top_docs = chunks_to_search_docs(unranked_chunks)
    top_docs = chunks_to_search_docs(top_chunks)

    initial_response = RerankedRetrievalDocs(
    llm_chunks_indices = get_chunks_for_qa(
        chunks=top_chunks,
        llm_chunk_selection=llm_chunk_selection,
        batch_offset=offset_count,
    )

    initial_response = QADocsResponse(
        top_documents=top_docs,
        unranked_top_documents=unranked_top_docs,
        llm_chunks_indices=llm_chunks_indices,
        # if generative AI is disabled, set flow as search so frontend
        # doesn't ask the user if they want to run QA over more documents
        predicted_flow=QueryFlow.SEARCH

@@ -260,10 +234,9 @@ def answer_qa_query_stream(
        favor_recent=favor_recent,
    ).dict()

    logger.debug(f"Sending Initial Retrival Results: {initial_response}")
    yield get_json_line(initial_response)

    if not ranked_chunks:
    if not top_chunks:
        logger.debug("No Documents Found")
        return

@@ -279,25 +252,13 @@ def answer_qa_query_stream(
        yield get_json_line(error.dict())
        return

    # remove chunks marked as not applicable for QA (e.g. Google Drive file
    # types which can't be parsed). These chunks are useful to show in the
    # search results, but not for QA.
    filtered_ranked_chunks = [
        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]

    # get all chunks that fit into the token limit
    usable_chunks = get_usable_chunks(
        chunks=filtered_ranked_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
        offset=offset_count,
    )
    llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
    logger.debug(
        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
    )

    try:
        for response_packet in qa_model.answer_question_stream(query, usable_chunks):
        for response_packet in qa_model.answer_question_stream(query, llm_chunks):
            if response_packet is None:
                continue
            if (

@@ -321,4 +282,4 @@ def answer_qa_query_stream(
        user_id=None if user is None else user.id,
    )

    yield get_json_line({"query_event_id": query_event_id})
    yield get_json_line({QUERY_EVENT_ID: query_event_id})
@@ -11,6 +11,7 @@ import regex

from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote

@@ -316,3 +317,57 @@ def get_usable_chunks(
        offset_into_chunks += len(usable_chunks)

    return usable_chunks


def get_chunks_for_qa(
    chunks: list[InferenceChunk],
    llm_chunk_selection: list[bool],
    token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
    batch_offset: int = 0,
) -> list[int]:
    """
    Gives back indices of chunks to pass into the LLM for Q&A.

    Only selects chunks viable for Q&A, within the token limit, and prioritize those selected
    by the LLM in a separate flow (this can be turned off)

    Note, the batch_offset calculation has to count the batches from the beginning each time as
    there's no way to know which chunks were included in the prior batches without recounting atm,
    this is somewhat slow as it requires tokenizing all the chunks again
    """
    batch_index = 0
    latest_batch_indices: list[int] = []
    token_count = 0

    # First iterate the LLM selected chunks, then iterate the rest if tokens remaining
    for selection_target in [True, False]:
        for ind, chunk in enumerate(chunks):
            if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
                IGNORE_FOR_QA
            ):
                continue

            # We calculate it live in case the user uses a different LLM + tokenizer
            chunk_token = check_number_of_tokens(chunk.content)
            token_count += chunk_token

            # Always use at least 1 chunk
            if token_count <= token_limit or not latest_batch_indices:
                latest_batch_indices.append(ind)
                current_chunk_unused = False
            else:
                current_chunk_unused = True

            if token_count >= token_limit:
                if batch_index < batch_offset:
                    batch_index += 1
                    if current_chunk_unused:
                        latest_batch_indices = [ind]
                        token_count = chunk_token
                    else:
                        latest_batch_indices = []
                        token_count = 0
                else:
                    return latest_batch_indices

    return latest_batch_indices
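For reference, this is the intended call pattern for the new helper, mirroring the call sites elsewhere in this commit; the variable names below are illustrative:

```python
# Sketch of the intended call pattern; `top_chunks` and `llm_selection` would
# come from danswer_search, exactly as in the call sites in this commit.
chunk_indices = get_chunks_for_qa(
    chunks=top_chunks,                  # reranked chunks from search
    llm_chunk_selection=llm_selection,  # parallel list[bool] from the LLM filter
    batch_offset=0,                     # 0 = first batch that fits the token limit
)
llm_chunks = [top_chunks[i] for i in chunk_indices]
```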
@@ -33,11 +33,6 @@ class LangChainChatLLM(LLM, abc.ABC):
    def llm(self) -> BaseChatModel:
        raise NotImplementedError

    def _log_model_config(self) -> None:
        logger.debug(
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    @staticmethod
    def _log_prompt(prompt: LanguageModelInput) -> None:
        if isinstance(prompt, list):

@@ -46,8 +41,12 @@ class LangChainChatLLM(LLM, abc.ABC):
        if isinstance(prompt, str):
            logger.debug(f"Prompt:\n{prompt}")

    def log_model_configs(self) -> None:
        logger.debug(
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    def invoke(self, prompt: LanguageModelInput) -> str:
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)

@@ -58,7 +57,6 @@ class LangChainChatLLM(LLM, abc.ABC):
        return model_raw

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)
@@ -9,6 +9,10 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
from danswer.utils.logger import setup_logger


logger = setup_logger()


class CustomModelServer(LLM):

@@ -65,6 +69,9 @@ class CustomModelServer(LLM):
        response.raise_for_status()
        return json.loads(response.content).get("generated_text", "")

    def log_model_configs(self) -> None:
        logger.debug(f"Custom model at: {self._endpoint}")

    def invoke(self, prompt: LanguageModelInput) -> str:
        return self._execute(prompt)
@@ -61,6 +61,11 @@ class DanswerGPT4All(LLM):
        self.temperature = temperature
        self.gpt4all_model = GPT4All(model_version)

    def log_model_configs(self) -> None:
        logger.debug(
            f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}"
        )

    def invoke(self, prompt: LanguageModelInput) -> str:
        prompt_basic = convert_lm_input_to_basic_string(prompt)
        return self.gpt4all_model.generate(prompt_basic)
@@ -22,6 +22,10 @@ class LLM(abc.ABC):
    def requires_api_key(self) -> bool:
        return True

    @abc.abstractmethod
    def log_model_configs(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def invoke(self, prompt: LanguageModelInput) -> str:
        raise NotImplementedError
@@ -36,6 +36,7 @@ from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.factory import get_default_qa_model
from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.server.cc_pair.api import router as cc_pair_router
from danswer.server.chat_backend import router as chat_router
from danswer.server.connector import router as connector_router

@@ -197,7 +198,7 @@ def get_application() -> FastAPI:
        warm_up_models()

        # This is for the LLM, most LLMs will not need warming up
        # It logs for itself
        get_default_llm().log_model_configs()
        get_default_qa_model().warm_up_model()

        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
@@ -132,6 +132,29 @@ Note: The "file" source only applies to when the user refers to uploaded files i
""".strip()


USEFUL_PAT = "Yes useful"
NONUSEFUL_PAT = "Not useful"
CHUNK_FILTER_PROMPT = f"""
Determine if the reference section is USEFUL for answering the user query.
It is NOT enough for the section to be related to the query, \
it must contain information that is USEFUL for answering the query.
If the section contains ANY useful information, that is good enough, \
it does not need to fully answer the every part of the user query.

Reference Section:
```
{{chunk_text}}
```

User Query:
```
{{user_query}}
```

Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
""".strip()


# User the following for easy viewing of prompts
if __name__ == "__main__":
    print(ANSWERABLE_PROMPT)
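A minimal sketch of how this filter prompt is filled in and how its output is interpreted downstream (the query and section text are invented; the lenient parse rule mirrors chunk_usefulness.py added in this commit):

```python
# Illustration only: fill the filter prompt and apply the lenient parse rule.
filled_prompt = CHUNK_FILTER_PROMPT.format(
    chunk_text="Danswer syncs documents from connectors every 10 minutes.",
    user_query="How often does Danswer pull from Slack?",
)

model_output = '"Yes useful"'  # pretend LLM reply
# Anything that is not an exact NONUSEFUL_PAT match counts as useful,
# so the reranking is trusted when the LLM answer is malformed.
is_useful = model_output.strip().strip('"').lower() != NONUSEFUL_PAT.lower()
```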
@@ -3,6 +3,7 @@ from enum import Enum

from pydantic import BaseModel

from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource

@@ -57,6 +58,9 @@ class SearchQuery(BaseModel):
    skip_rerank: bool = SKIP_RERANKING
    # Only used if not skip_rerank
    num_rerank: int | None = NUM_RERANKED_RESULTS
    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER
    # Only used if not skip_llm_chunk_filter
    max_llm_filter_chunks: int = NUM_RERANKED_RESULTS


class RetrievalMetricsContainer(BaseModel):
@@ -8,7 +8,9 @@ from nltk.tokenize import word_tokenize  # type:ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import HYBRID_ALPHA
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN

@@ -33,9 +35,12 @@ from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time

@@ -147,7 +152,12 @@ def semantic_reranking(
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
    model_min: int = CROSS_ENCODER_RANGE_MIN,
    model_max: int = CROSS_ENCODER_RANGE_MAX,
) -> list[InferenceChunk]:
) -> tuple[list[InferenceChunk], list[int]]:
    """Reranks chunks based on cross-encoder models. Additionally provides the original indices
    of the chunks in their new sorted order.

    Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
    """
    cross_encoders = CrossEncoderEnsembleModel()
    passages = [chunk.content for chunk in chunks]
    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)

@@ -168,16 +178,20 @@ def semantic_reranking(
    normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
        model_max - model_min
    )
    scored_results = list(zip(normalized_b_s_scores, raw_sim_scores, chunks))
    orig_indices = [i for i in range(len(normalized_b_s_scores))]
    scored_results = list(
        zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
    )
    scored_results.sort(key=lambda x: x[0], reverse=True)
    ranked_sim_scores, ranked_raw_scores, ranked_chunks = zip(*scored_results)
    ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
        *scored_results
    )

    logger.debug(
        f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
    )

    # Assign new chunk scores based on reranking
    # TODO if pagination is added, the scores won't make sense with respect to the non-reranked hits
    for ind, chunk in enumerate(ranked_chunks):
        chunk.score = ranked_sim_scores[ind]

@@ -198,7 +212,7 @@ def semantic_reranking(
            )
        )

    return list(ranked_chunks)
    return list(ranked_chunks), list(ranked_indices)


def apply_boost_legacy(
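The point of also returning the original indices from semantic_reranking is that lists kept parallel to the retrieval order (such as the LLM chunk selection) can be re-sorted into the reranked order. A toy illustration with made-up values:

```python
# Toy data: selection flags are aligned with the original retrieval order.
llm_selection = [True, False, True]   # for chunks c0, c1, c2
ranked_chunks = ["c2", "c0", "c1"]    # pretend rerank output
ranked_indices = [2, 0, 1]            # original positions, as now returned

reranked_selection = [llm_selection[i] for i in ranked_indices]
print(reranked_selection)  # [True, True, False] - aligned with ranked_chunks
```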
@@ -257,6 +271,9 @@ def apply_boost_legacy(

def apply_boost(
    chunks: list[InferenceChunk],
    # Need the range of values to not be too spread out for applying boost
    # therefore norm across only the top few results
    norm_cutoff: int = NUM_RERANKED_RESULTS,
    norm_min: float = SIM_SCORE_RANGE_LOW,
    norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:

@@ -266,13 +283,13 @@ def apply_boost(
    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
    recency_multiplier = [chunk.recency_bias for chunk in chunks]

    norm_min = min(norm_min, min(scores))
    norm_max = max(norm_max, max(scores))
    norm_min = min(norm_min, min(scores[:norm_cutoff]))
    norm_max = max(norm_max, max(scores[:norm_cutoff]))
    # This should never be 0 unless user has done some weird/wrong settings
    norm_range = norm_max - norm_min

    boosted_scores = [
        (score - norm_min) * boost * recency / norm_range
        max(0, (score - norm_min) * boost * recency / norm_range)
        for score, boost, recency in zip(scores, boosts, recency_multiplier)
    ]
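A small worked example of the adjusted normalization with invented scores (boost and recency multipliers omitted, and the range constants are stand-ins), showing why clamping at zero matters once the norm range is taken from only the top norm_cutoff results:

```python
# Invented scores: normalizing over the top-2 only, a lower score can go negative
# without the max(0, ...) clamp added above.
scores = [0.92, 0.85, 0.40]
norm_cutoff = 2
norm_min = min(0.7, min(scores[:norm_cutoff]))  # 0.7 (stand-in for SIM_SCORE_RANGE_LOW)
norm_max = max(1.0, max(scores[:norm_cutoff]))  # 1.0 (stand-in for SIM_SCORE_RANGE_HIGH)
norm_range = norm_max - norm_min                # ~0.3

raw = (scores[2] - norm_min) / norm_range       # (0.40 - 0.7) / 0.3 ~= -1.0
clamped = max(0, raw)                           # 0 after the new clamp
```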
@@ -299,7 +316,14 @@ def search_chunks(
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
) -> tuple[list[InferenceChunk], list[bool]]:
    """Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
    For sake of speed, the system cannot rerank all retrieved chunks
    Also pass the chunks through LLM to determine if they are relevant (binary for speed)

    Only the first max_llm_filter_chunks
    """

    def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
        top_links = [
            c.source_links[0] if c.source_links is not None else "No Link"

@@ -316,7 +340,7 @@ def search_chunks(
            f"{query.search_type.value.capitalize()} search returned no results "
            f"with filters: {query.filters}"
        )
        return None, None
        return [], []

    if retrieval_metrics_callback is not None:
        chunk_metrics = [

@@ -332,27 +356,62 @@ def search_chunks(
            RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
        )

    # Keyword Search should never do reranking, no transformers involved in this flow
    if query.search_type == SearchType.KEYWORD:
    functions_to_run: list[FunctionCall] = []

    # Keyword Search should not do reranking
    if query.search_type == SearchType.KEYWORD or query.skip_rerank:
        _log_top_chunk_links(query.search_type.value, top_chunks)
        return top_chunks, None
        run_rerank_id: str | None = None
    else:
        run_rerank = FunctionCall(
            semantic_reranking,
            (query.query, top_chunks[: query.num_rerank]),
            {"rerank_metrics_callback": rerank_metrics_callback},
        )
        functions_to_run.append(run_rerank)
        run_rerank_id = run_rerank.result_id

    if query.skip_rerank:
        # Need the range of values to not be too spread out for applying boost
        # Therefore pass in smaller set of chunks to limit the range for norm-ing
        boosted_chunks = apply_boost(top_chunks[: query.num_rerank])
        _log_top_chunk_links(query.search_type.value, boosted_chunks)
        return boosted_chunks, top_chunks[query.num_rerank :]
    run_llm_filter_id = None
    if not query.skip_llm_chunk_filter:
        run_llm_filter = FunctionCall(
            llm_batch_eval_chunks,
            (
                query.query,
                [chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
            ),
            {},
        )
        functions_to_run.append(run_llm_filter)
        run_llm_filter_id = run_llm_filter.result_id

    ranked_chunks = semantic_reranking(
        query.query,
        top_chunks[: query.num_rerank],
        rerank_metrics_callback=rerank_metrics_callback,
    )
    parallel_results = run_functions_in_parallel(functions_to_run)

    ranked_results = parallel_results.get(str(run_rerank_id))
    if ranked_results is None:
        ranked_chunks = top_chunks
        sorted_indices = [i for i in range(len(top_chunks))]
    else:
        ranked_chunks, orig_indices = ranked_results
        sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
        lower_chunks = top_chunks[query.num_rerank :]
        # Scores from rerank cannot be meaningfully combined with scores without rerank
        for lower_chunk in lower_chunks:
            lower_chunk.score = None
        ranked_chunks.extend(lower_chunks)

    llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
    if llm_chunk_selection is None:
        reranked_llm_chunk_selection = [True for _ in top_chunks]
    else:
        llm_chunk_selection.extend(
            [False for _ in top_chunks[query.max_llm_filter_chunks :]]
        )
        reranked_llm_chunk_selection = [
            llm_chunk_selection[ind] for ind in sorted_indices
        ]
    _log_top_chunk_links(query.search_type.value, ranked_chunks)

    return ranked_chunks, top_chunks[query.num_rerank :]
    return ranked_chunks, reranked_llm_chunk_selection


def danswer_search(
@@ -360,10 +419,11 @@ def danswer_search(
    user: User | None,
    db_session: Session,
    document_index: DocumentIndex,
    skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
    | None = None,
    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None, int]:
) -> tuple[list[InferenceChunk], list[bool], int]:
    query_event_id = create_query_event(
        query=question.query,
        search_type=question.search_type,

@@ -384,17 +444,21 @@ def danswer_search(
        query=question.query,
        search_type=question.search_type,
        filters=final_filters,
        favor_recent=True if question.favor_recent is None else question.favor_recent,
        # Still applies time decay but not magnified
        favor_recent=question.favor_recent
        if question.favor_recent is not None
        else False,
        skip_llm_chunk_filter=skip_llm_chunk_filter,
    )

    ranked_chunks, unranked_chunks = search_chunks(
    top_chunks, llm_chunk_selection = search_chunks(
        query=search_query,
        document_index=document_index,
        retrieval_metrics_callback=retrieval_metrics_callback,
        rerank_metrics_callback=rerank_metrics_callback,
    )

    retrieved_ids = [doc.document_id for doc in ranked_chunks] if ranked_chunks else []
    retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []

    update_query_event_retrieved_documents(
        db_session=db_session,

@@ -403,4 +467,4 @@ def danswer_search(
        user_id=None if user is None else user.id,
    )

    return ranked_chunks, unranked_chunks, query_event_id
    return top_chunks, llm_chunk_selection, query_event_id
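In short, danswer_search now returns the reranked chunks together with a parallel boolean LLM-selection list instead of ranked/unranked chunk lists. A hedged sketch of the updated call contract, mirroring the call sites in this commit:

```python
# Sketch only: mirrors the call sites updated in this commit.
top_chunks, llm_chunk_selection, query_event_id = danswer_search(
    question=question,
    user=user,
    db_session=db_session,
    document_index=get_default_document_index(),
    skip_llm_chunk_filter=False,  # True skips the per-chunk LLM usefulness check
)
assert len(top_chunks) == len(llm_chunk_selection)  # selection is parallel to chunks
```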
backend/danswer/secondary_llm_flows/chunk_usefulness.py (new file, 65 lines)

@@ -0,0 +1,65 @@
from collections.abc import Callable

from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.secondary_llm_flows import CHUNK_FILTER_PROMPT
from danswer.prompts.secondary_llm_flows import NONUSEFUL_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


def llm_eval_chunk(query: str, chunk_content: str) -> bool:
    def _get_usefulness_messages() -> list[dict[str, str]]:
        messages = [
            {
                "role": "user",
                "content": CHUNK_FILTER_PROMPT.format(
                    chunk_text=chunk_content, user_query=query
                ),
            },
        ]

        return messages

    def _extract_usefulness(model_output: str) -> bool:
        """Default useful if the LLM doesn't match pattern exactly
        This is because it's better to trust the (re)ranking if LLM fails"""
        if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
            return False
        return True

    messages = _get_usefulness_messages()
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    # When running in a batch, it takes as long as the longest thread
    # And when running a large batch, one may fail and take the whole timeout
    # instead cap it to 5 seconds
    model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
    logger.debug(model_output)

    return _extract_usefulness(model_output)


def llm_batch_eval_chunks(
    query: str, chunk_contents: list[str], use_threads: bool = True
) -> list[bool]:
    if use_threads:
        functions_with_args: list[tuple[Callable, tuple]] = [
            (llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents
        ]

        logger.debug(
            "Running LLM usefulness eval in parallel (following logging may be out of order)"
        )
        parallel_results = run_functions_tuples_in_parallel(
            functions_with_args, allow_failures=True
        )

        # In case of failure/timeout, don't throw out the chunk
        return [True if item is None else item for item in parallel_results]

    else:
        return [
            llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents
        ]
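A short usage sketch of the new helper, assuming a default LLM is configured; the query and chunk texts are invented:

```python
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks

useful_mask = llm_batch_eval_chunks(
    query="How often does Danswer sync connectors?",
    chunk_contents=[
        "Danswer currently syncs documents every 10 minutes.",
        "The web frontend is built with Next.js.",
    ],
)
# e.g. [True, False]; failed or timed-out evaluations default to True so no chunk is dropped
```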
@@ -171,12 +171,58 @@ class SearchDoc(BaseModel):
        return initial_dict


class QuestionRequest(BaseModel):
    query: str
    collection: str
    filters: BaseFilters
    offset: int | None
    enable_auto_detect_filters: bool
    favor_recent: bool | None = None
    search_type: SearchType = SearchType.HYBRID


class QAFeedbackRequest(BaseModel):
    query_id: int
    feedback: QAFeedbackType


class SearchFeedbackRequest(BaseModel):
    query_id: int
    document_id: str
    document_rank: int
    click: bool
    search_feedback: SearchFeedbackType


class QueryValidationResponse(BaseModel):
    reasoning: str
    answerable: bool


class RetrievalDocs(BaseModel):
    top_documents: list[SearchDoc]


class RerankedRetrievalDocs(RetrievalDocs):
    unranked_top_documents: list[SearchDoc]
class SearchResponse(RetrievalDocs):
    query_event_id: int
    source_type: list[DocumentSource] | None
    time_cutoff: datetime | None
    favor_recent: bool


class QAResponse(SearchResponse):
    answer: str | None  # DanswerAnswer
    quotes: list[DanswerQuote] | None
    predicted_flow: QueryFlow
    predicted_search: SearchType
    eval_res_valid: bool | None = None
    llm_chunks_indices: list[int] | None = None
    error_msg: str | None = None


# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs):
    llm_chunks_indices: list[int]
    predicted_flow: QueryFlow
    predicted_search: SearchType
    time_cutoff: datetime | None

@@ -194,21 +240,6 @@ class CreateChatSessionID(BaseModel):
    chat_session_id: int


class QuestionRequest(BaseModel):
    query: str
    collection: str
    filters: BaseFilters
    offset: int | None
    enable_auto_detect_filters: bool
    favor_recent: bool | None = None
    search_type: SearchType = SearchType.HYBRID


class QAFeedbackRequest(BaseModel):
    query_id: int
    feedback: QAFeedbackType


class ChatFeedbackRequest(BaseModel):
    chat_session_id: int
    message_number: int

@@ -217,14 +248,6 @@ class ChatFeedbackRequest(BaseModel):
    feedback_text: str | None = None


class SearchFeedbackRequest(BaseModel):
    query_id: int
    document_id: str
    document_rank: int
    click: bool
    search_feedback: SearchFeedbackType


class CreateChatMessageRequest(BaseModel):
    chat_session_id: int
    message_number: int

@@ -280,30 +303,6 @@ class ChatSessionDetailResponse(BaseModel):
    messages: list[ChatMessageDetail]


class QueryValidationResponse(BaseModel):
    reasoning: str
    answerable: bool


class SearchResponse(BaseModel):
    # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
    top_ranked_docs: list[SearchDoc] | None
    lower_ranked_docs: list[SearchDoc] | None
    query_event_id: int
    source_type: list[DocumentSource] | None
    time_cutoff: datetime | None
    favor_recent: bool


class QAResponse(SearchResponse):
    answer: str | None  # DanswerAnswer
    quotes: list[DanswerQuote] | None
    predicted_flow: QueryFlow
    predicted_search: SearchType
    eval_res_valid: bool | None = None
    error_msg: str | None = None


class UserByEmail(BaseModel):
    user_email: str
@@ -1,5 +1,3 @@
from collections.abc import Callable

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException

@@ -36,6 +34,7 @@ from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel

logger = setup_logger()

@@ -131,10 +130,10 @@ def handle_search_request(
    query = question.query
    logger.info(f"Received {question.search_type.value} " f"search query: {query}")

    functions_to_run: dict[Callable, tuple] = {
        extract_question_time_filters: (question,),
        extract_question_source_filters: (question, db_session),
    }
    functions_to_run = [
        FunctionCall(extract_question_time_filters, (question,), {}),
        FunctionCall(extract_question_source_filters, (question, db_session), {}),
    ]

    parallel_results = run_functions_in_parallel(functions_to_run)

@@ -145,29 +144,18 @@ def handle_search_request(
    question.favor_recent = favor_recent
    question.filters.source_type = source_filters

    ranked_chunks, unranked_chunks, query_event_id = danswer_search(
    top_chunks, _, query_event_id = danswer_search(
        question=question,
        user=user,
        db_session=db_session,
        document_index=get_default_document_index(),
        skip_llm_chunk_filter=True,
    )

    if not ranked_chunks:
        return SearchResponse(
            top_ranked_docs=None,
            lower_ranked_docs=None,
            query_event_id=query_event_id,
            source_type=source_filters,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
        )

    top_docs = chunks_to_search_docs(ranked_chunks)
    lower_top_docs = chunks_to_search_docs(unranked_chunks)
    top_docs = chunks_to_search_docs(top_chunks)

    return SearchResponse(
        top_ranked_docs=top_docs,
        lower_ranked_docs=lower_top_docs or None,
        top_documents=top_docs,
        query_event_id=query_event_id,
        source_type=source_filters,
        time_cutoff=time_cutoff,
@@ -37,16 +37,20 @@ def optional_telemetry(record_type: RecordType, data: dict) -> None:
    try:

        def telemetry_logic() -> None:
            payload = {
                "data": data,
                "record": record_type,
                "customer_uuid": get_or_generate_uuid(),
            }
            requests.post(
                DANSWER_TELEMETRY_ENDPOINT,
                headers={"Content-Type": "application/json"},
                json=payload,
            )
            try:
                payload = {
                    "data": data,
                    "record": record_type,
                    "customer_uuid": get_or_generate_uuid(),
                }
                requests.post(
                    DANSWER_TELEMETRY_ENDPOINT,
                    headers={"Content-Type": "application/json"},
                    json=payload,
                )
            except Exception:
                # This way it silences all thread level logging as well
                pass

        # Run in separate thread to have minimal overhead in main flows
        thread = threading.Thread(target=telemetry_logic, daemon=True)
@@ -1,3 +1,4 @@
import uuid
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor

@@ -8,31 +9,82 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()


def run_functions_in_parallel(
    functions_with_args: dict[Callable, tuple]
) -> dict[str, Any]:
def run_functions_tuples_in_parallel(
    functions_with_args: list[tuple[Callable, tuple]],
    allow_failures: bool = False,
) -> list[Any]:
    """
    Executes multiple functions in parallel and returns a dictionary with the results.
    Executes multiple functions in parallel and returns a list of the results for each function.

    Args:
        functions_with_args (dict): A dictionary mapping functions to a tuple of arguments.
        functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
        allow_failures: if set to True, then the function result will just be None

    Returns:
        dict: A dictionary mapping function names to their results or error messages.
    """
    results = {}
    results = []
    with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor:
        future_to_function = {
            executor.submit(func, *args): func.__name__
            for func, args in functions_with_args.items()
        future_to_index = {
            executor.submit(func, *args): i
            for i, (func, args) in enumerate(functions_with_args)
        }

        for future in as_completed(future_to_function):
            function_name = future_to_function[future]
        for future in as_completed(future_to_index):
            index = future_to_index[future]
            try:
                results[function_name] = future.result()
                results.append((index, future.result()))
            except Exception as e:
                logger.exception(f"Function {function_name} failed due to {e}")
                raise
                logger.exception(f"Function at index {index} failed due to {e}")
                results.append((index, None))

                if not allow_failures:
                    raise

    results.sort(key=lambda x: x[0])
    return [result for index, result in results]


class FunctionCall:
    """
    Container for run_functions_in_parallel, fetch the results from the output of
    run_functions_in_parallel via the FunctionCall.result_id.
    """

    def __init__(self, func: Callable, args: tuple = (), kwargs: dict | None = None):
        self.func = func
        self.args = args
        self.kwargs = kwargs if kwargs is not None else {}
        self.result_id = str(uuid.uuid4())

    def execute(self) -> Any:
        return self.func(*self.args, **self.kwargs)


def run_functions_in_parallel(
    function_calls: list[FunctionCall],
    allow_failures: bool = False,
) -> dict[str, Any]:
    """
    Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
    are the result_id of the FunctionCall and the values are the results of the call.
    """
    results = {}
    with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
        future_to_id = {
            executor.submit(func_call.execute): func_call.result_id
            for func_call in function_calls
        }

        for future in as_completed(future_to_id):
            result_id = future_to_id[future]
            try:
                results[result_id] = future.result()
            except Exception as e:
                logger.exception(f"Function with ID {result_id} failed due to {e}")
                results[result_id] = None

                if not allow_failures:
                    raise

    return results
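A minimal usage sketch of the new FunctionCall / run_functions_in_parallel pair; the toy function is invented:

```python
import time


def slow_add(a: int, b: int) -> int:
    time.sleep(0.1)  # stand-in for real work
    return a + b


call_one = FunctionCall(slow_add, (1, 2), {})
call_two = FunctionCall(slow_add, (3, 4), {})

results = run_functions_in_parallel([call_one, call_two])
print(results[call_one.result_id], results[call_two.result_id])  # 3 7
```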
@@ -32,6 +32,7 @@ services:
      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
      - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
      - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
      - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
      # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
      - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
      # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
@@ -116,6 +116,11 @@ export const DocumentDisplay = ({
}: DocumentDisplayProps) => {
  const [isHovered, setIsHovered] = useState(false);

  // Consider reintroducing null scored docs in the future
  if (document.score === null) {
    return null;
  }

  return (
    <div
      key={document.semantic_identifier}

@@ -126,23 +131,25 @@ export const DocumentDisplay = ({
      onMouseLeave={() => setIsHovered(false)}
    >
      <div className="flex relative">
        <div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
          <div
            className={`
            text-xs
            text-gray-200
            bg-gray-800
            rounded
            p-0.5
            w-fit
            my-auto
            select-none
            ml-auto
            mr-2`}
          >
            {document.score.toFixed(2)}
        {document.score !== null && (
          <div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
            <div
              className={`
              text-xs
              text-gray-200
              bg-gray-800
              rounded
              p-0.5
              w-fit
              my-auto
              select-none
              ml-auto
              mr-2`}
            >
              {document.score.toFixed(2)}
            </div>
          </div>
          </div>
        )}
        <a
          className={
            "rounded-lg flex font-bold " +