Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-21 13:30:59 +02:00)

LLM Chunk Filtering (#735)

parent d5916e420c
commit fa0d19cc8c

README.md — 18 changed lines
@@ -46,28 +46,32 @@ We also have built-in support for deployment on Kubernetes. Files for that can b
## 💃 Features
* Direct QA powered by Generative AI models with answers backed by quotes and source links.
- * Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs.
+ * Intelligent Document Retrieval (Hybrid Search + Reranking) using the latest NLP models.
- * An AI Helper backed by a custom Deep Learning model to interpret user intent.
+ * Automatic time/source filter extraction from natural language + custom model to identify user intent.
* User authentication and document level access management.
- * Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.)
+ * Support for LLMs of your choice (GPT-4, Llama2, Orca, etc.)
- * Management Dashboard to manage connectors and set up features such as live update fetching.
+ * Management Dashboards to manage connectors and set up features such as live update fetching.
* One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere.

## 🔌 Connectors

- Danswer currently syncs documents (every 10 minutes) from:
+ Efficiently pulls the latest changes from:
* Slack
* GitHub
* Google Drive
* Confluence
* Jira
* Notion
+ * Gong
* Slab
* Linear
* Productboard
* Guru
* Zulip
* Bookstack
+ * Document360
+ * Request Tracker
+ * Hubspot
* Local Files
* Websites
* With more to come...
@@ -75,7 +79,9 @@ Danswer currently syncs documents (every 10 minutes) from:
## 🚧 Roadmap
* Chat/Conversation support.
* Organizational understanding.
- * Ability to locate and suggest experts.
+ * Code Search
+ * Structured Query Languages (SQL, Excel formulas, etc.)
+ * Ability to locate and suggest experts from your team.

## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
@@ -140,18 +140,15 @@ def danswer_chat_retrieval(
)

# Good Debug/Breakpoint
- ranked_chunks, unranked_chunks = search_chunks(
+ top_chunks, _ = search_chunks(
query=search_query, document_index=get_default_document_index()
)

- if not ranked_chunks:
+ if not top_chunks:
return []

- if unranked_chunks:
- ranked_chunks.extend(unranked_chunks)

filtered_ranked_chunks = [
- chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
+ chunk for chunk in top_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]

# get all chunks that fit into the token limit
@@ -178,8 +178,12 @@ MINI_CHUNK_SIZE = 150
NUM_RETURNED_HITS = 50
NUM_RERANKED_RESULTS = 15
# We feed in document chunks until we reach this token limit.
- # Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks
+ # Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be
- # may be smaller which could result in passing in more total chunks
+ # significantly smaller which could result in passing in more total chunks.
+ # There is also a slight bit of overhead, not accounted for here such as separator patterns
+ # between the docs, metadata for the docs, etc.
+ # Finally, this is combined with the rest of the QA prompt, so don't set this too close to the
+ # model token limit
NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int(
os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5)
)
@@ -198,12 +202,14 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2
DISABLE_LLM_FILTER_EXTRACTION = (
os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true"
)
- # 1 edit per 2 characters, currently unused due to fuzzy match being too slow
+ DISABLE_LLM_CHUNK_FILTER = (
+ os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
+ )
+ # 1 edit per 20 characters, currently unused due to fuzzy match being too slow
QUOTE_ALLOWED_ERROR_PERCENT = 0.05
QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60")  # 60 seconds
# Include additional document/chunk metadata in prompt to GenerativeAI
INCLUDE_METADATA = False
- HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
# Keyword Search Drop Stopwords
# If user has changed the default model, would most likely be to use a multilingual
# model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords
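Not part of the commit, just a quick sketch of how these boolean flags behave once set in the deployment environment: a feature is disabled only when the variable is exactly the string "true" (case-insensitive); unset or any other value leaves it enabled.

import os

# Simulate the deployment setting the variable
os.environ["DISABLE_LLM_CHUNK_FILTER"] = "True"

# Same parsing rule as the config lines above
disable_llm_chunk_filter = (
    os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
print(disable_llm_chunk_filter)  # True -> the LLM chunk filter is skipped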
@@ -1,3 +1,4 @@
import os

FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true"
+ HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false"
@@ -35,6 +35,8 @@ SCORE = "score"
ID_SEPARATOR = ":;:"
DEFAULT_BOOST = 0
SESSION_KEY = "session"
+ QUERY_EVENT_ID = "query_event_id"
+ LLM_CHUNKS = "llm_chunks"


class DocumentSource(str, Enum):
@@ -232,7 +232,7 @@ def handle_message(
logger.debug(answer.answer)
return True

- if not answer.top_ranked_docs:
+ if not answer.top_documents:
logger.error(f"Unable to answer question: '{msg}' - no documents found")
# Optionally, respond in thread with the error message, Used primarily
# for debugging purposes
@@ -265,8 +265,17 @@ def handle_message(
favor_recent=answer.favor_recent,
)

+ # Get the chunks fed to the LLM only, then fill with other docs
+ top_docs = answer.top_documents
+ llm_doc_inds = answer.llm_chunks_indices or []
+ llm_docs = [top_docs[i] for i in llm_doc_inds]
+ remaining_docs = [
+ doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds
+ ]
+ priority_ordered_docs = llm_docs + remaining_docs
document_blocks = build_documents_blocks(
- documents=answer.top_ranked_docs, query_event_id=answer.query_event_id
+ documents=priority_ordered_docs,
+ query_event_id=answer.query_event_id,
)

try:
@@ -9,7 +9,7 @@ from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

- from danswer.configs.app_configs import HARD_DELETE_CHATS
+ from danswer.configs.chat_configs import HARD_DELETE_CHATS
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
@@ -1,19 +1,19 @@
from collections.abc import Callable
from collections.abc import Iterator
+ from functools import partial

from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
- from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
- from danswer.configs.constants import IGNORE_FOR_QA
+ from danswer.configs.constants import QUERY_EVENT_ID
from danswer.db.feedback import update_query_event_llm_answer
from danswer.db.models import User
from danswer.direct_qa.factory import get_default_qa_model
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.models import LLMMetricsContainer
- from danswer.direct_qa.qa_utils import get_usable_chunks
+ from danswer.direct_qa.qa_utils import get_chunks_for_qa
from danswer.document_index.factory import get_default_document_index
from danswer.search.danswer_helper import query_intent
from danswer.search.models import QueryFlow
@@ -24,11 +24,12 @@ from danswer.search.search_runner import danswer_search
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.source_filter import extract_question_source_filters
from danswer.secondary_llm_flows.time_filter import extract_question_time_filters
+ from danswer.server.models import QADocsResponse
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
- from danswer.server.models import RerankedRetrievalDocs
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
+ from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
@@ -54,24 +55,34 @@ def answer_qa_query(
offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}")

- functions_to_run: dict[Callable, tuple] = {
+ run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
- extract_question_time_filters: (question,),
+ run_source_filters = FunctionCall(
- extract_question_source_filters: (question, db_session),
+ extract_question_source_filters, (question, db_session), {}
- query_intent: (query,),
+ )
- }
+ run_query_intent = FunctionCall(query_intent, (query,), {})

- parallel_results = run_functions_in_parallel(functions_to_run)
+ parallel_results = run_functions_in_parallel(
+ [
+ run_time_filters,
+ run_source_filters,
+ run_query_intent,
+ ]
+ )

- time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
+ time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
- source_filters = parallel_results["extract_question_source_filters"]
+ source_filters = parallel_results[run_source_filters.result_id]
- predicted_search, predicted_flow = parallel_results["query_intent"]
+ predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]

+ # Set flow as search so frontend doesn't ask the user if they want to run QA over more docs
+ if disable_generative_answer:
+ predicted_flow = QueryFlow.SEARCH

# Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent
question.filters.source_type = source_filters

- ranked_chunks, unranked_chunks, query_event_id = danswer_search(
+ top_chunks, llm_chunk_selection, query_event_id = danswer_search(
question=question,
user=user,
db_session=db_session,
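The hunk above swaps the old dict-of-callables pattern for explicit FunctionCall objects whose results are fetched by result_id. The real FunctionCall and run_functions_in_parallel helpers live in danswer.utils.threadpool_concurrency and are not shown in this commit, so the following is only a minimal, self-contained sketch of how such an interface could work; the class internals here are assumptions made for illustration.

import uuid
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any


class FunctionCall:
    def __init__(self, func: Callable, args: tuple = (), kwargs: dict | None = None):
        self.func = func
        self.args = args
        self.kwargs = kwargs or {}
        # Unique key used to look up this call's result afterwards
        self.result_id = f"{func.__name__}_{uuid.uuid4()}"

    def execute(self) -> Any:
        return self.func(*self.args, **self.kwargs)


def run_functions_in_parallel(function_calls: list[FunctionCall]) -> dict[str, Any]:
    # Run every call on its own thread and key the results by result_id
    with ThreadPoolExecutor(max_workers=len(function_calls) or 1) as pool:
        futures = {fc.result_id: pool.submit(fc.execute) for fc in function_calls}
        return {rid: fut.result() for rid, fut in futures.items()}


if __name__ == "__main__":
    double = FunctionCall(lambda x: x * 2, (21,), {})
    shout = FunctionCall(str.upper, ("hello",), {})
    results = run_functions_in_parallel([double, shout])
    print(results[double.result_id], results[shout.result_id])  # 42 HELLO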
@@ -80,38 +91,23 @@ def answer_qa_query(
rerank_metrics_callback=rerank_metrics_callback,
)

- if not ranked_chunks:
+ top_docs = chunks_to_search_docs(top_chunks)
- return QAResponse(
+ partial_response = partial(
+ QAResponse,
+ top_documents=chunks_to_search_docs(top_chunks),
+ predicted_flow=predicted_flow,
+ predicted_search=predicted_search,
+ query_event_id=query_event_id,
+ source_type=source_filters,
+ time_cutoff=time_cutoff,
+ favor_recent=favor_recent,
+ )

+ if disable_generative_answer or not top_docs:
+ return partial_response(
answer=None,
quotes=None,
- top_ranked_docs=None,
- lower_ranked_docs=None,
- predicted_flow=predicted_flow,
- predicted_search=predicted_search,
- query_event_id=query_event_id,
- source_type=source_filters,
- time_cutoff=time_cutoff,
- favor_recent=favor_recent,
- )

- top_docs = chunks_to_search_docs(ranked_chunks)
- unranked_top_docs = chunks_to_search_docs(unranked_chunks)

- if disable_generative_answer:
- logger.debug("Skipping QA because generative AI is disabled")
- return QAResponse(
- answer=None,
- quotes=None,
- top_ranked_docs=top_docs,
- lower_ranked_docs=unranked_top_docs,
- # set flow as search so frontend doesn't ask the user if they want
- # to run QA over more documents
- predicted_flow=QueryFlow.SEARCH,
- predicted_search=predicted_search,
- query_event_id=query_event_id,
- source_type=source_filters,
- time_cutoff=time_cutoff,
- favor_recent=favor_recent,
)

try:
@@ -119,41 +115,28 @@ def answer_qa_query(
timeout=answer_generation_timeout, real_time_flow=real_time_flow
)
except Exception as e:
- return QAResponse(
+ return partial_response(
answer=None,
quotes=None,
- top_ranked_docs=top_docs,
- lower_ranked_docs=unranked_top_docs,
- predicted_flow=predicted_flow,
- predicted_search=predicted_search,
- query_event_id=query_event_id,
- source_type=source_filters,
- time_cutoff=time_cutoff,
- favor_recent=favor_recent,
error_msg=str(e),
)

- # remove chunks marked as not applicable for QA (e.g. Google Drive file
+ llm_chunks_indices = get_chunks_for_qa(
- # types which can't be parsed). These chunks are useful to show in the
+ chunks=top_chunks,
- # search results, but not for QA.
+ llm_chunk_selection=llm_chunk_selection,
- filtered_ranked_chunks = [
+ batch_offset=offset_count,
- chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
- ]

- # get all chunks that fit into the token limit
- usable_chunks = get_usable_chunks(
- chunks=filtered_ranked_chunks,
- token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
- offset=offset_count,
)

+ llm_chunks = [top_chunks[i] for i in llm_chunks_indices]

logger.debug(
- f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
+ f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
)

error_msg = None
try:
d_answer, quotes = qa_model.answer_question(
- query, usable_chunks, metrics_callback=llm_metrics_callback
+ query, llm_chunks, metrics_callback=llm_metrics_callback
)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
@@ -169,37 +152,17 @@ def answer_qa_query(
user_id=None if user is None else user.id,
)

+ validity = None
if not real_time_flow and enable_reflexion and d_answer is not None:
- valid = False
+ validity = False
if d_answer.answer is not None:
- valid = get_answer_validity(query, d_answer.answer)
+ validity = get_answer_validity(query, d_answer.answer)

- return QAResponse(
+ return partial_response(
- answer=d_answer.answer if d_answer else None,
- quotes=quotes.quotes if quotes else None,
- top_ranked_docs=top_docs,
- lower_ranked_docs=unranked_top_docs,
- predicted_flow=predicted_flow,
- predicted_search=predicted_search,
- eval_res_valid=True if valid else False,
- query_event_id=query_event_id,
- source_type=source_filters,
- time_cutoff=time_cutoff,
- favor_recent=favor_recent,
- error_msg=error_msg,
- )

- return QAResponse(
answer=d_answer.answer if d_answer else None,
quotes=quotes.quotes if quotes else None,
- top_ranked_docs=top_docs,
+ eval_res_valid=validity,
- lower_ranked_docs=unranked_top_docs,
+ llm_chunks_indices=llm_chunks_indices,
- predicted_flow=predicted_flow,
- predicted_search=predicted_search,
- query_event_id=query_event_id,
- source_type=source_filters,
- time_cutoff=time_cutoff,
- favor_recent=favor_recent,
error_msg=error_msg,
)

@@ -220,36 +183,47 @@ def answer_qa_query_stream(
query = question.query
offset_count = question.offset if question.offset is not None else 0

- functions_to_run: dict[Callable, tuple] = {
+ run_time_filters = FunctionCall(extract_question_time_filters, (question,), {})
- extract_question_time_filters: (question,),
+ run_source_filters = FunctionCall(
- extract_question_source_filters: (question, db_session),
+ extract_question_source_filters, (question, db_session), {}
- query_intent: (query,),
+ )
- }
+ run_query_intent = FunctionCall(query_intent, (query,), {})

- parallel_results = run_functions_in_parallel(functions_to_run)
+ parallel_results = run_functions_in_parallel(
+ [
+ run_time_filters,
+ run_source_filters,
+ run_query_intent,
+ ]
+ )

- time_cutoff, favor_recent = parallel_results["extract_question_time_filters"]
+ time_cutoff, favor_recent = parallel_results[run_time_filters.result_id]
- source_filters = parallel_results["extract_question_source_filters"]
+ source_filters = parallel_results[run_source_filters.result_id]
- predicted_search, predicted_flow = parallel_results["query_intent"]
+ predicted_search, predicted_flow = parallel_results[run_query_intent.result_id]

# Modifies the question object but nothing upstream uses it
question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent
question.filters.source_type = source_filters

- ranked_chunks, unranked_chunks, query_event_id = danswer_search(
+ top_chunks, llm_chunk_selection, query_event_id = danswer_search(
question=question,
user=user,
db_session=db_session,
document_index=get_default_document_index(),
)

- top_docs = chunks_to_search_docs(ranked_chunks)
+ top_docs = chunks_to_search_docs(top_chunks)
- unranked_top_docs = chunks_to_search_docs(unranked_chunks)

- initial_response = RerankedRetrievalDocs(
+ llm_chunks_indices = get_chunks_for_qa(
+ chunks=top_chunks,
+ llm_chunk_selection=llm_chunk_selection,
+ batch_offset=offset_count,
+ )

+ initial_response = QADocsResponse(
top_documents=top_docs,
- unranked_top_documents=unranked_top_docs,
+ llm_chunks_indices=llm_chunks_indices,
# if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents
predicted_flow=QueryFlow.SEARCH
@@ -260,10 +234,9 @@ def answer_qa_query_stream(
favor_recent=favor_recent,
).dict()

- logger.debug(f"Sending Initial Retrival Results: {initial_response}")
yield get_json_line(initial_response)

- if not ranked_chunks:
+ if not top_chunks:
logger.debug("No Documents Found")
return

@@ -279,25 +252,13 @@ def answer_qa_query_stream(
yield get_json_line(error.dict())
return

- # remove chunks marked as not applicable for QA (e.g. Google Drive file
+ llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
- # types which can't be parsed). These chunks are useful to show in the
- # search results, but not for QA.
- filtered_ranked_chunks = [
- chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
- ]

- # get all chunks that fit into the token limit
- usable_chunks = get_usable_chunks(
- chunks=filtered_ranked_chunks,
- token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
- offset=offset_count,
- )
logger.debug(
- f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
+ f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}"
)

try:
- for response_packet in qa_model.answer_question_stream(query, usable_chunks):
+ for response_packet in qa_model.answer_question_stream(query, llm_chunks):
if response_packet is None:
continue
if (
@@ -321,4 +282,4 @@ def answer_qa_query_stream(
user_id=None if user is None else user.id,
)

- yield get_json_line({"query_event_id": query_event_id})
+ yield get_json_line({QUERY_EVENT_ID: query_event_id})
@@ -11,6 +11,7 @@ import regex

from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
+ from danswer.configs.constants import IGNORE_FOR_QA
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
@@ -316,3 +317,57 @@ def get_usable_chunks(
offset_into_chunks += len(usable_chunks)

return usable_chunks


+ def get_chunks_for_qa(
+ chunks: list[InferenceChunk],
+ llm_chunk_selection: list[bool],
+ token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+ batch_offset: int = 0,
+ ) -> list[int]:
+ """
+ Gives back indices of chunks to pass into the LLM for Q&A.
+
+ Only selects chunks viable for Q&A, within the token limit, and prioritize those selected
+ by the LLM in a separate flow (this can be turned off)
+
+ Note, the batch_offset calculation has to count the batches from the beginning each time as
+ there's no way to know which chunks were included in the prior batches without recounting atm,
+ this is somewhat slow as it requires tokenizing all the chunks again
+ """
+ batch_index = 0
+ latest_batch_indices: list[int] = []
+ token_count = 0
+
+ # First iterate the LLM selected chunks, then iterate the rest if tokens remaining
+ for selection_target in [True, False]:
+ for ind, chunk in enumerate(chunks):
+ if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get(
+ IGNORE_FOR_QA
+ ):
+ continue
+
+ # We calculate it live in case the user uses a different LLM + tokenizer
+ chunk_token = check_number_of_tokens(chunk.content)
+ token_count += chunk_token
+
+ # Always use at least 1 chunk
+ if token_count <= token_limit or not latest_batch_indices:
+ latest_batch_indices.append(ind)
+ current_chunk_unused = False
+ else:
+ current_chunk_unused = True
+
+ if token_count >= token_limit:
+ if batch_index < batch_offset:
+ batch_index += 1
+ if current_chunk_unused:
+ latest_batch_indices = [ind]
+ token_count = chunk_token
+ else:
+ latest_batch_indices = []
+ token_count = 0
+ else:
+ return latest_batch_indices
+
+ return latest_batch_indices
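To make the selection order of get_chunks_for_qa concrete, here is a simplified, self-contained sketch of the same idea: LLM-approved chunks are considered before the rest, a token budget caps the batch, and at least one chunk is always kept. Token counting is faked with a word count and the batch_offset handling is omitted, so this is an illustration rather than the real InferenceChunk-based implementation.

def pick_chunks_for_qa(
    chunk_texts: list[str],
    llm_chunk_selection: list[bool],
    token_limit: int = 20,
) -> list[int]:
    """Toy version of get_chunks_for_qa (batch offsets omitted, word count as tokens)."""
    chosen: list[int] = []
    token_count = 0
    # LLM-approved chunks are considered first, then the rest if budget remains
    for selection_target in (True, False):
        for ind, text in enumerate(chunk_texts):
            if llm_chunk_selection[ind] is not selection_target:
                continue
            token_count += len(text.split())  # stand-in for a real tokenizer
            if token_count <= token_limit or not chosen:  # always keep at least one
                chosen.append(ind)
            if token_count >= token_limit:
                return chosen
    return chosen


chunks = ["alpha " * 8, "beta " * 8, "gamma " * 8]
print(pick_chunks_for_qa(chunks, [False, True, False]))  # [1, 0] -> LLM pick goes first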
@@ -33,11 +33,6 @@ class LangChainChatLLM(LLM, abc.ABC):
def llm(self) -> BaseChatModel:
raise NotImplementedError

- def _log_model_config(self) -> None:
- logger.debug(
- f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
- )

@staticmethod
def _log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
|
|||||||
if isinstance(prompt, str):
|
if isinstance(prompt, str):
|
||||||
logger.debug(f"Prompt:\n{prompt}")
|
logger.debug(f"Prompt:\n{prompt}")
|
||||||
|
|
||||||
|
def log_model_configs(self) -> None:
|
||||||
|
logger.debug(
|
||||||
|
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
|
||||||
|
)
|
||||||
|
|
||||||
def invoke(self, prompt: LanguageModelInput) -> str:
|
def invoke(self, prompt: LanguageModelInput) -> str:
|
||||||
self._log_model_config()
|
|
||||||
if LOG_ALL_MODEL_INTERACTIONS:
|
if LOG_ALL_MODEL_INTERACTIONS:
|
||||||
self._log_prompt(prompt)
|
self._log_prompt(prompt)
|
||||||
|
|
||||||
@ -58,7 +57,6 @@ class LangChainChatLLM(LLM, abc.ABC):
|
|||||||
return model_raw
|
return model_raw
|
||||||
|
|
||||||
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
|
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
|
||||||
self._log_model_config()
|
|
||||||
if LOG_ALL_MODEL_INTERACTIONS:
|
if LOG_ALL_MODEL_INTERACTIONS:
|
||||||
self._log_prompt(prompt)
|
self._log_prompt(prompt)
|
||||||
|
|
||||||
|
@@ -9,6 +9,10 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.interfaces import LLM
from danswer.llm.utils import convert_lm_input_to_basic_string
+ from danswer.utils.logger import setup_logger


+ logger = setup_logger()


class CustomModelServer(LLM):
|
|||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
return json.loads(response.content).get("generated_text", "")
|
return json.loads(response.content).get("generated_text", "")
|
||||||
|
|
||||||
|
def log_model_configs(self) -> None:
|
||||||
|
logger.debug(f"Custom model at: {self._endpoint}")
|
||||||
|
|
||||||
def invoke(self, prompt: LanguageModelInput) -> str:
|
def invoke(self, prompt: LanguageModelInput) -> str:
|
||||||
return self._execute(prompt)
|
return self._execute(prompt)
|
||||||
|
|
||||||
|
@@ -61,6 +61,11 @@ class DanswerGPT4All(LLM):
self.temperature = temperature
self.gpt4all_model = GPT4All(model_version)

+ def log_model_configs(self) -> None:
+ logger.debug(
+ f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}"
+ )

def invoke(self, prompt: LanguageModelInput) -> str:
prompt_basic = convert_lm_input_to_basic_string(prompt)
return self.gpt4all_model.generate(prompt_basic)
@@ -22,6 +22,10 @@ class LLM(abc.ABC):
def requires_api_key(self) -> bool:
return True

+ @abc.abstractmethod
+ def log_model_configs(self) -> None:
+ raise NotImplementedError

@abc.abstractmethod
def invoke(self, prompt: LanguageModelInput) -> str:
raise NotImplementedError
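Since log_model_configs is now an abstract method on the LLM interface, every provider class has to implement it. A toy sketch of what a custom subclass might look like; the EchoLLM class is invented for illustration, the stand-in base class keeps only the methods visible in this commit, and prompts are simplified to plain strings rather than LanguageModelInput.

import abc
import logging

logger = logging.getLogger(__name__)


class LLMSketch(abc.ABC):
    """Trimmed stand-in for danswer.llm.interfaces.LLM (only methods from this diff)."""

    @abc.abstractmethod
    def log_model_configs(self) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def invoke(self, prompt: str) -> str:
        raise NotImplementedError


class EchoLLM(LLMSketch):
    """Hypothetical provider that just echoes the prompt back."""

    def log_model_configs(self) -> None:
        logger.debug("EchoLLM: no external model, nothing to configure")

    def invoke(self, prompt: str) -> str:
        return prompt


llm = EchoLLM()
llm.log_model_configs()  # called once at startup, like get_default_llm().log_model_configs()
print(llm.invoke("hello"))  # hello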
@@ -36,6 +36,7 @@ from danswer.configs.model_configs import SKIP_RERANKING
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.factory import get_default_qa_model
from danswer.document_index.factory import get_default_document_index
+ from danswer.llm.factory import get_default_llm
from danswer.server.cc_pair.api import router as cc_pair_router
from danswer.server.chat_backend import router as chat_router
from danswer.server.connector import router as connector_router
|
|||||||
warm_up_models()
|
warm_up_models()
|
||||||
|
|
||||||
# This is for the LLM, most LLMs will not need warming up
|
# This is for the LLM, most LLMs will not need warming up
|
||||||
# It logs for itself
|
get_default_llm().log_model_configs()
|
||||||
get_default_qa_model().warm_up_model()
|
get_default_qa_model().warm_up_model()
|
||||||
|
|
||||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||||
|
@@ -132,6 +132,29 @@ Note: The "file" source only applies to when the user refers to uploaded files i
""".strip()


+ USEFUL_PAT = "Yes useful"
+ NONUSEFUL_PAT = "Not useful"
+ CHUNK_FILTER_PROMPT = f"""
+ Determine if the reference section is USEFUL for answering the user query.
+ It is NOT enough for the section to be related to the query, \
+ it must contain information that is USEFUL for answering the query.
+ If the section contains ANY useful information, that is good enough, \
+ it does not need to fully answer the every part of the user query.
+
+ Reference Section:
+ ```
+ {{chunk_text}}
+ ```
+
+ User Query:
+ ```
+ {{user_query}}
+ ```
+
+ Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}"
+ """.strip()


# User the following for easy viewing of prompts
if __name__ == "__main__":
print(ANSWERABLE_PROMPT)
@@ -3,6 +3,7 @@ from enum import Enum

from pydantic import BaseModel

+ from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
|
|||||||
skip_rerank: bool = SKIP_RERANKING
|
skip_rerank: bool = SKIP_RERANKING
|
||||||
# Only used if not skip_rerank
|
# Only used if not skip_rerank
|
||||||
num_rerank: int | None = NUM_RERANKED_RESULTS
|
num_rerank: int | None = NUM_RERANKED_RESULTS
|
||||||
|
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER
|
||||||
|
# Only used if not skip_llm_chunk_filter
|
||||||
|
max_llm_filter_chunks: int = NUM_RERANKED_RESULTS
|
||||||
|
|
||||||
|
|
||||||
class RetrievalMetricsContainer(BaseModel):
|
class RetrievalMetricsContainer(BaseModel):
|
||||||
|
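A stripped-down pydantic sketch of the new per-query switches, not part of the commit: it mirrors only the SearchQuery fields visible in this hunk, with the app_configs constants replaced by hard-coded defaults, and shows how a caller could turn the LLM chunk filter off for a single query.

from pydantic import BaseModel


class SearchQuerySketch(BaseModel):
    # Stand-in for danswer.search.models.SearchQuery; defaults hard-coded here
    # instead of coming from app_configs.
    skip_rerank: bool = False
    num_rerank: int | None = 15  # NUM_RERANKED_RESULTS
    skip_llm_chunk_filter: bool = False  # DISABLE_LLM_CHUNK_FILTER
    max_llm_filter_chunks: int = 15  # only used if not skip_llm_chunk_filter


query = SearchQuerySketch(skip_llm_chunk_filter=True)
print(query.dict())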
@@ -8,7 +8,9 @@ from nltk.tokenize import word_tokenize  # type:ignore
from sentence_transformers import SentenceTransformer  # type: ignore
from sqlalchemy.orm import Session

+ from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.app_configs import HYBRID_ALPHA
+ from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
|
|||||||
from danswer.search.models import SearchType
|
from danswer.search.models import SearchType
|
||||||
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||||
from danswer.search.search_nlp_models import EmbeddingModel
|
from danswer.search.search_nlp_models import EmbeddingModel
|
||||||
|
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks
|
||||||
from danswer.server.models import QuestionRequest
|
from danswer.server.models import QuestionRequest
|
||||||
from danswer.server.models import SearchDoc
|
from danswer.server.models import SearchDoc
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
|
from danswer.utils.threadpool_concurrency import FunctionCall
|
||||||
|
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
|
||||||
from danswer.utils.timing import log_function_time
|
from danswer.utils.timing import log_function_time
|
||||||
|
|
||||||
|
|
||||||
@ -147,7 +152,12 @@ def semantic_reranking(
|
|||||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||||
model_min: int = CROSS_ENCODER_RANGE_MIN,
|
model_min: int = CROSS_ENCODER_RANGE_MIN,
|
||||||
model_max: int = CROSS_ENCODER_RANGE_MAX,
|
model_max: int = CROSS_ENCODER_RANGE_MAX,
|
||||||
) -> list[InferenceChunk]:
|
) -> tuple[list[InferenceChunk], list[int]]:
|
||||||
|
"""Reranks chunks based on cross-encoder models. Additionally provides the original indices
|
||||||
|
of the chunks in their new sorted order.
|
||||||
|
|
||||||
|
Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
|
||||||
|
"""
|
||||||
cross_encoders = CrossEncoderEnsembleModel()
|
cross_encoders = CrossEncoderEnsembleModel()
|
||||||
passages = [chunk.content for chunk in chunks]
|
passages = [chunk.content for chunk in chunks]
|
||||||
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
|
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
|
||||||
@ -168,16 +178,20 @@ def semantic_reranking(
|
|||||||
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
|
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
|
||||||
model_max - model_min
|
model_max - model_min
|
||||||
)
|
)
|
||||||
scored_results = list(zip(normalized_b_s_scores, raw_sim_scores, chunks))
|
orig_indices = [i for i in range(len(normalized_b_s_scores))]
|
||||||
|
scored_results = list(
|
||||||
|
zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices)
|
||||||
|
)
|
||||||
scored_results.sort(key=lambda x: x[0], reverse=True)
|
scored_results.sort(key=lambda x: x[0], reverse=True)
|
||||||
ranked_sim_scores, ranked_raw_scores, ranked_chunks = zip(*scored_results)
|
ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip(
|
||||||
|
*scored_results
|
||||||
|
)
|
||||||
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
|
f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assign new chunk scores based on reranking
|
# Assign new chunk scores based on reranking
|
||||||
# TODO if pagination is added, the scores won't make sense with respect to the non-reranked hits
|
|
||||||
for ind, chunk in enumerate(ranked_chunks):
|
for ind, chunk in enumerate(ranked_chunks):
|
||||||
chunk.score = ranked_sim_scores[ind]
|
chunk.score = ranked_sim_scores[ind]
|
||||||
|
|
||||||
@ -198,7 +212,7 @@ def semantic_reranking(
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
return list(ranked_chunks)
|
return list(ranked_chunks), list(ranked_indices)
|
||||||
|
|
||||||
|
|
||||||
def apply_boost_legacy(
|
def apply_boost_legacy(
|
||||||
@ -257,6 +271,9 @@ def apply_boost_legacy(
|
|||||||
|
|
||||||
def apply_boost(
|
def apply_boost(
|
||||||
chunks: list[InferenceChunk],
|
chunks: list[InferenceChunk],
|
||||||
|
# Need the range of values to not be too spread out for applying boost
|
||||||
|
# therefore norm across only the top few results
|
||||||
|
norm_cutoff: int = NUM_RERANKED_RESULTS,
|
||||||
norm_min: float = SIM_SCORE_RANGE_LOW,
|
norm_min: float = SIM_SCORE_RANGE_LOW,
|
||||||
norm_max: float = SIM_SCORE_RANGE_HIGH,
|
norm_max: float = SIM_SCORE_RANGE_HIGH,
|
||||||
) -> list[InferenceChunk]:
|
) -> list[InferenceChunk]:
|
||||||
@ -266,13 +283,13 @@ def apply_boost(
|
|||||||
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
|
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
|
||||||
recency_multiplier = [chunk.recency_bias for chunk in chunks]
|
recency_multiplier = [chunk.recency_bias for chunk in chunks]
|
||||||
|
|
||||||
norm_min = min(norm_min, min(scores))
|
norm_min = min(norm_min, min(scores[:norm_cutoff]))
|
||||||
norm_max = max(norm_max, max(scores))
|
norm_max = max(norm_max, max(scores[:norm_cutoff]))
|
||||||
# This should never be 0 unless user has done some weird/wrong settings
|
# This should never be 0 unless user has done some weird/wrong settings
|
||||||
norm_range = norm_max - norm_min
|
norm_range = norm_max - norm_min
|
||||||
|
|
||||||
boosted_scores = [
|
boosted_scores = [
|
||||||
(score - norm_min) * boost * recency / norm_range
|
max(0, (score - norm_min) * boost * recency / norm_range)
|
||||||
for score, boost, recency in zip(scores, boosts, recency_multiplier)
|
for score, boost, recency in zip(scores, boosts, recency_multiplier)
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -299,7 +316,14 @@ def search_chunks(
|
|||||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||||
| None = None,
|
| None = None,
|
||||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||||
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
|
) -> tuple[list[InferenceChunk], list[bool]]:
|
||||||
|
"""Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM.
|
||||||
|
For sake of speed, the system cannot rerank all retrieved chunks
|
||||||
|
Also pass the chunks through LLM to determine if they are relevant (binary for speed)
|
||||||
|
|
||||||
|
Only the first max_llm_filter_chunks
|
||||||
|
"""
|
||||||
|
|
||||||
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
|
def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None:
|
||||||
top_links = [
|
top_links = [
|
||||||
c.source_links[0] if c.source_links is not None else "No Link"
|
c.source_links[0] if c.source_links is not None else "No Link"
|
||||||
@ -316,7 +340,7 @@ def search_chunks(
|
|||||||
f"{query.search_type.value.capitalize()} search returned no results "
|
f"{query.search_type.value.capitalize()} search returned no results "
|
||||||
f"with filters: {query.filters}"
|
f"with filters: {query.filters}"
|
||||||
)
|
)
|
||||||
return None, None
|
return [], []
|
||||||
|
|
||||||
if retrieval_metrics_callback is not None:
|
if retrieval_metrics_callback is not None:
|
||||||
chunk_metrics = [
|
chunk_metrics = [
|
||||||
@ -332,27 +356,62 @@ def search_chunks(
|
|||||||
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
|
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Keyword Search should never do reranking, no transformers involved in this flow
|
functions_to_run: list[FunctionCall] = []
|
||||||
if query.search_type == SearchType.KEYWORD:
|
|
||||||
|
# Keyword Search should not do reranking
|
||||||
|
if query.search_type == SearchType.KEYWORD or query.skip_rerank:
|
||||||
_log_top_chunk_links(query.search_type.value, top_chunks)
|
_log_top_chunk_links(query.search_type.value, top_chunks)
|
||||||
return top_chunks, None
|
run_rerank_id: str | None = None
|
||||||
|
else:
|
||||||
|
run_rerank = FunctionCall(
|
||||||
|
semantic_reranking,
|
||||||
|
(query.query, top_chunks[: query.num_rerank]),
|
||||||
|
{"rerank_metrics_callback": rerank_metrics_callback},
|
||||||
|
)
|
||||||
|
functions_to_run.append(run_rerank)
|
||||||
|
run_rerank_id = run_rerank.result_id
|
||||||
|
|
||||||
if query.skip_rerank:
|
run_llm_filter_id = None
|
||||||
# Need the range of values to not be too spread out for applying boost
|
if not query.skip_llm_chunk_filter:
|
||||||
# Therefore pass in smaller set of chunks to limit the range for norm-ing
|
run_llm_filter = FunctionCall(
|
||||||
boosted_chunks = apply_boost(top_chunks[: query.num_rerank])
|
llm_batch_eval_chunks,
|
||||||
_log_top_chunk_links(query.search_type.value, boosted_chunks)
|
(
|
||||||
return boosted_chunks, top_chunks[query.num_rerank :]
|
query.query,
|
||||||
|
[chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]],
|
||||||
|
),
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
functions_to_run.append(run_llm_filter)
|
||||||
|
run_llm_filter_id = run_llm_filter.result_id
|
||||||
|
|
||||||
ranked_chunks = semantic_reranking(
|
parallel_results = run_functions_in_parallel(functions_to_run)
|
||||||
query.query,
|
|
||||||
top_chunks[: query.num_rerank],
|
|
||||||
rerank_metrics_callback=rerank_metrics_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
ranked_results = parallel_results.get(str(run_rerank_id))
|
||||||
|
if ranked_results is None:
|
||||||
|
ranked_chunks = top_chunks
|
||||||
|
sorted_indices = [i for i in range(len(top_chunks))]
|
||||||
|
else:
|
||||||
|
ranked_chunks, orig_indices = ranked_results
|
||||||
|
sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks)))
|
||||||
|
lower_chunks = top_chunks[query.num_rerank :]
|
||||||
|
# Scores from rerank cannot be meaningfully combined with scores without rerank
|
||||||
|
for lower_chunk in lower_chunks:
|
||||||
|
lower_chunk.score = None
|
||||||
|
ranked_chunks.extend(lower_chunks)
|
||||||
|
|
||||||
|
llm_chunk_selection = parallel_results.get(str(run_llm_filter_id))
|
||||||
|
if llm_chunk_selection is None:
|
||||||
|
reranked_llm_chunk_selection = [True for _ in top_chunks]
|
||||||
|
else:
|
||||||
|
llm_chunk_selection.extend(
|
||||||
|
[False for _ in top_chunks[query.max_llm_filter_chunks :]]
|
||||||
|
)
|
||||||
|
reranked_llm_chunk_selection = [
|
||||||
|
llm_chunk_selection[ind] for ind in sorted_indices
|
||||||
|
]
|
||||||
_log_top_chunk_links(query.search_type.value, ranked_chunks)
|
_log_top_chunk_links(query.search_type.value, ranked_chunks)
|
||||||
|
|
||||||
return ranked_chunks, top_chunks[query.num_rerank :]
|
return ranked_chunks, reranked_llm_chunk_selection
|
||||||
|
|
||||||
|
|
||||||
def danswer_search(
|
def danswer_search(
|
||||||
@ -360,10 +419,11 @@ def danswer_search(
|
|||||||
user: User | None,
|
user: User | None,
|
||||||
db_session: Session,
|
db_session: Session,
|
||||||
document_index: DocumentIndex,
|
document_index: DocumentIndex,
|
||||||
|
skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER,
|
||||||
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
|
||||||
| None = None,
|
| None = None,
|
||||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||||
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None, int]:
|
) -> tuple[list[InferenceChunk], list[bool], int]:
|
||||||
query_event_id = create_query_event(
|
query_event_id = create_query_event(
|
||||||
query=question.query,
|
query=question.query,
|
||||||
search_type=question.search_type,
|
search_type=question.search_type,
|
||||||
@ -384,17 +444,21 @@ def danswer_search(
|
|||||||
query=question.query,
|
query=question.query,
|
||||||
search_type=question.search_type,
|
search_type=question.search_type,
|
||||||
filters=final_filters,
|
filters=final_filters,
|
||||||
favor_recent=True if question.favor_recent is None else question.favor_recent,
|
# Still applies time decay but not magnified
|
||||||
|
favor_recent=question.favor_recent
|
||||||
|
if question.favor_recent is not None
|
||||||
|
else False,
|
||||||
|
skip_llm_chunk_filter=skip_llm_chunk_filter,
|
||||||
)
|
)
|
||||||
|
|
||||||
ranked_chunks, unranked_chunks = search_chunks(
|
top_chunks, llm_chunk_selection = search_chunks(
|
||||||
query=search_query,
|
query=search_query,
|
||||||
document_index=document_index,
|
document_index=document_index,
|
||||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||||
rerank_metrics_callback=rerank_metrics_callback,
|
rerank_metrics_callback=rerank_metrics_callback,
|
||||||
)
|
)
|
||||||
|
|
||||||
retrieved_ids = [doc.document_id for doc in ranked_chunks] if ranked_chunks else []
|
retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else []
|
||||||
|
|
||||||
update_query_event_retrieved_documents(
|
update_query_event_retrieved_documents(
|
||||||
db_session=db_session,
|
db_session=db_session,
|
||||||
@ -403,4 +467,4 @@ def danswer_search(
|
|||||||
user_id=None if user is None else user.id,
|
user_id=None if user is None else user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return ranked_chunks, unranked_chunks, query_event_id
|
return top_chunks, llm_chunk_selection, query_event_id
|
||||||
|
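Not from the diff, but a small illustration of the new contract: search_chunks and danswer_search now return the chunks together with a same-length list of booleans marking which chunks the LLM judged useful, so callers pair the two lists positionally. Plain strings stand in for InferenceChunk objects here.

# Hypothetical values, in place of a real search_chunks(...) call
top_chunk_ids = ["doc-a", "doc-b", "doc-c"]
llm_chunk_selection = [True, False, True]

llm_approved = [c for c, keep in zip(top_chunk_ids, llm_chunk_selection) if keep]
fallback_only = [c for c, keep in zip(top_chunk_ids, llm_chunk_selection) if not keep]
print(llm_approved, fallback_only)  # ['doc-a', 'doc-c'] ['doc-b']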
65
backend/danswer/secondary_llm_flows/chunk_usefulness.py
Normal file
65
backend/danswer/secondary_llm_flows/chunk_usefulness.py
Normal file
@ -0,0 +1,65 @@
from collections.abc import Callable

from danswer.llm.factory import get_default_llm
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.prompts.secondary_llm_flows import CHUNK_FILTER_PROMPT
from danswer.prompts.secondary_llm_flows import NONUSEFUL_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


def llm_eval_chunk(query: str, chunk_content: str) -> bool:
    def _get_usefulness_messages() -> list[dict[str, str]]:
        messages = [
            {
                "role": "user",
                "content": CHUNK_FILTER_PROMPT.format(
                    chunk_text=chunk_content, user_query=query
                ),
            },
        ]

        return messages

    def _extract_usefulness(model_output: str) -> bool:
        """Default useful if the LLM doesn't match pattern exactly
        This is because it's better to trust the (re)ranking if LLM fails"""
        if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower():
            return False
        return True

    messages = _get_usefulness_messages()
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    # When running in a batch, it takes as long as the longest thread
    # And when running a large batch, one may fail and take the whole timeout
    # instead cap it to 5 seconds
    model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
    logger.debug(model_output)

    return _extract_usefulness(model_output)


def llm_batch_eval_chunks(
    query: str, chunk_contents: list[str], use_threads: bool = True
) -> list[bool]:
    if use_threads:
        functions_with_args: list[tuple[Callable, tuple]] = [
            (llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents
        ]

        logger.debug(
            "Running LLM usefulness eval in parallel (following logging may be out of order)"
        )
        parallel_results = run_functions_tuples_in_parallel(
            functions_with_args, allow_failures=True
        )

        # In case of failure/timeout, don't throw out the chunk
        return [True if item is None else item for item in parallel_results]

    else:
        return [
            llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents
        ]
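A minimal usage sketch for the new chunk-usefulness module (the query and chunk texts here are invented; the import path follows the file path above):

```python
from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks

query = "How do I configure the Notion connector?"
chunk_texts = [
    "The Notion connector requires an internal integration token...",
    "Quarterly sales figures and revenue projections...",
]

# One LLM call per chunk, run in parallel threads. Failures or timeouts default
# to "useful" so a flaky LLM call never silently drops a chunk.
usefulness = llm_batch_eval_chunks(query=query, chunk_contents=chunk_texts)
kept_chunks = [text for text, useful in zip(chunk_texts, usefulness) if useful]
```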
@ -171,12 +171,58 @@ class SearchDoc(BaseModel):
        return initial_dict


+class QuestionRequest(BaseModel):
+    query: str
+    collection: str
+    filters: BaseFilters
+    offset: int | None
+    enable_auto_detect_filters: bool
+    favor_recent: bool | None = None
+    search_type: SearchType = SearchType.HYBRID
+
+
+class QAFeedbackRequest(BaseModel):
+    query_id: int
+    feedback: QAFeedbackType
+
+
+class SearchFeedbackRequest(BaseModel):
+    query_id: int
+    document_id: str
+    document_rank: int
+    click: bool
+    search_feedback: SearchFeedbackType
+
+
+class QueryValidationResponse(BaseModel):
+    reasoning: str
+    answerable: bool
+
+
class RetrievalDocs(BaseModel):
    top_documents: list[SearchDoc]


-class RerankedRetrievalDocs(RetrievalDocs):
-    unranked_top_documents: list[SearchDoc]
+class SearchResponse(RetrievalDocs):
+    query_event_id: int
+    source_type: list[DocumentSource] | None
+    time_cutoff: datetime | None
+    favor_recent: bool
+
+
+class QAResponse(SearchResponse):
+    answer: str | None  # DanswerAnswer
+    quotes: list[DanswerQuote] | None
+    predicted_flow: QueryFlow
+    predicted_search: SearchType
+    eval_res_valid: bool | None = None
+    llm_chunks_indices: list[int] | None = None
+    error_msg: str | None = None
+
+
+# First chunk of info for streaming QA
+class QADocsResponse(RetrievalDocs):
+    llm_chunks_indices: list[int]
    predicted_flow: QueryFlow
    predicted_search: SearchType
    time_cutoff: datetime | None
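For orientation, `SearchResponse` now inherits `top_documents` from `RetrievalDocs` instead of carrying the old `top_ranked_docs`/`lower_ranked_docs` split. A sketch with invented values, assuming the models are importable from `danswer.server.models` as the import lines later in this diff suggest:

```python
from danswer.server.models import SearchResponse

response = SearchResponse(
    top_documents=[],    # would be list[SearchDoc] in practice
    query_event_id=123,  # invented id
    source_type=None,
    time_cutoff=None,
    favor_recent=False,
)
print(response.dict())
```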
@ -194,21 +240,6 @@ class CreateChatSessionID(BaseModel):
    chat_session_id: int


-class QuestionRequest(BaseModel):
-    query: str
-    collection: str
-    filters: BaseFilters
-    offset: int | None
-    enable_auto_detect_filters: bool
-    favor_recent: bool | None = None
-    search_type: SearchType = SearchType.HYBRID
-
-
-class QAFeedbackRequest(BaseModel):
-    query_id: int
-    feedback: QAFeedbackType
-
-
class ChatFeedbackRequest(BaseModel):
    chat_session_id: int
    message_number: int
@ -217,14 +248,6 @@ class ChatFeedbackRequest(BaseModel):
    feedback_text: str | None = None


-class SearchFeedbackRequest(BaseModel):
-    query_id: int
-    document_id: str
-    document_rank: int
-    click: bool
-    search_feedback: SearchFeedbackType
-
-
class CreateChatMessageRequest(BaseModel):
    chat_session_id: int
    message_number: int
@ -280,30 +303,6 @@ class ChatSessionDetailResponse(BaseModel):
    messages: list[ChatMessageDetail]


-class QueryValidationResponse(BaseModel):
-    reasoning: str
-    answerable: bool
-
-
-class SearchResponse(BaseModel):
-    # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
-    top_ranked_docs: list[SearchDoc] | None
-    lower_ranked_docs: list[SearchDoc] | None
-    query_event_id: int
-    source_type: list[DocumentSource] | None
-    time_cutoff: datetime | None
-    favor_recent: bool
-
-
-class QAResponse(SearchResponse):
-    answer: str | None  # DanswerAnswer
-    quotes: list[DanswerQuote] | None
-    predicted_flow: QueryFlow
-    predicted_search: SearchType
-    eval_res_valid: bool | None = None
-    error_msg: str | None = None
-
-
class UserByEmail(BaseModel):
    user_email: str
@ -1,5 +1,3 @@
-from collections.abc import Callable
-
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException

@ -36,6 +34,7 @@ from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
+from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel

logger = setup_logger()
@ -131,10 +130,10 @@ def handle_search_request(
    query = question.query
    logger.info(f"Received {question.search_type.value} " f"search query: {query}")

-    functions_to_run: dict[Callable, tuple] = {
-        extract_question_time_filters: (question,),
-        extract_question_source_filters: (question, db_session),
-    }
+    functions_to_run = [
+        FunctionCall(extract_question_time_filters, (question,), {}),
+        FunctionCall(extract_question_source_filters, (question, db_session), {}),
+    ]

    parallel_results = run_functions_in_parallel(functions_to_run)
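Note: the filter-extraction results are presumably unpacked from `parallel_results` by each call's `result_id` just below this hunk (those lines are not shown). The pattern would look roughly like the following; the exact variable names and the return shape of the extract functions are assumptions:

```python
# Sketch only: keep references to the FunctionCall objects so their result_id
# can be used to look results up in the dict from run_functions_in_parallel.
time_filter_call = FunctionCall(extract_question_time_filters, (question,), {})
source_filter_call = FunctionCall(extract_question_source_filters, (question, db_session), {})

parallel_results = run_functions_in_parallel([time_filter_call, source_filter_call])

time_cutoff, favor_recent = parallel_results[time_filter_call.result_id]  # assumed shape
source_filters = parallel_results[source_filter_call.result_id]
```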
@ -145,29 +144,18 @@ def handle_search_request(
    question.favor_recent = favor_recent
    question.filters.source_type = source_filters

-    ranked_chunks, unranked_chunks, query_event_id = danswer_search(
+    top_chunks, _, query_event_id = danswer_search(
        question=question,
        user=user,
        db_session=db_session,
        document_index=get_default_document_index(),
+        skip_llm_chunk_filter=True,
    )

-    if not ranked_chunks:
-        return SearchResponse(
-            top_ranked_docs=None,
-            lower_ranked_docs=None,
-            query_event_id=query_event_id,
-            source_type=source_filters,
-            time_cutoff=time_cutoff,
-            favor_recent=favor_recent,
-        )
-
-    top_docs = chunks_to_search_docs(ranked_chunks)
-    lower_top_docs = chunks_to_search_docs(unranked_chunks)
+    top_docs = chunks_to_search_docs(top_chunks)

    return SearchResponse(
-        top_ranked_docs=top_docs,
-        lower_ranked_docs=lower_top_docs or None,
+        top_documents=top_docs,
        query_event_id=query_event_id,
        source_type=source_filters,
        time_cutoff=time_cutoff,
@ -37,16 +37,20 @@ def optional_telemetry(record_type: RecordType, data: dict) -> None:
    try:

        def telemetry_logic() -> None:
-            payload = {
-                "data": data,
-                "record": record_type,
-                "customer_uuid": get_or_generate_uuid(),
-            }
-            requests.post(
-                DANSWER_TELEMETRY_ENDPOINT,
-                headers={"Content-Type": "application/json"},
-                json=payload,
-            )
+            try:
+                payload = {
+                    "data": data,
+                    "record": record_type,
+                    "customer_uuid": get_or_generate_uuid(),
+                }
+                requests.post(
+                    DANSWER_TELEMETRY_ENDPOINT,
+                    headers={"Content-Type": "application/json"},
+                    json=payload,
+                )
+            except Exception:
+                # This way it silences all thread level logging as well
+                pass

        # Run in separate thread to have minimal overhead in main flows
        thread = threading.Thread(target=telemetry_logic, daemon=True)
@ -1,3 +1,4 @@
+import uuid
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor

@ -8,31 +9,82 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()


-def run_functions_in_parallel(
-    functions_with_args: dict[Callable, tuple]
-) -> dict[str, Any]:
+def run_functions_tuples_in_parallel(
+    functions_with_args: list[tuple[Callable, tuple]],
+    allow_failures: bool = False,
+) -> list[Any]:
    """
-    Executes multiple functions in parallel and returns a dictionary with the results.
+    Executes multiple functions in parallel and returns a list of the results for each function.

    Args:
-        functions_with_args (dict): A dictionary mapping functions to a tuple of arguments.
+        functions_with_args: List of tuples each containing the function callable and a tuple of arguments.
+        allow_failures: if set to True, then the function result will just be None

    Returns:
        dict: A dictionary mapping function names to their results or error messages.
    """
-    results = {}
+    results = []
    with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor:
-        future_to_function = {
-            executor.submit(func, *args): func.__name__
-            for func, args in functions_with_args.items()
+        future_to_index = {
+            executor.submit(func, *args): i
+            for i, (func, args) in enumerate(functions_with_args)
        }

-        for future in as_completed(future_to_function):
-            function_name = future_to_function[future]
+        for future in as_completed(future_to_index):
+            index = future_to_index[future]
            try:
-                results[function_name] = future.result()
+                results.append((index, future.result()))
            except Exception as e:
-                logger.exception(f"Function {function_name} failed due to {e}")
-                raise
+                logger.exception(f"Function at index {index} failed due to {e}")
+                results.append((index, None))
+
+                if not allow_failures:
+                    raise
+
+    results.sort(key=lambda x: x[0])
+    return [result for index, result in results]
+
+
+class FunctionCall:
+    """
+    Container for run_functions_in_parallel, fetch the results from the output of
+    run_functions_in_parallel via the FunctionCall.result_id.
+    """
+
+    def __init__(self, func: Callable, args: tuple = (), kwargs: dict | None = None):
+        self.func = func
+        self.args = args
+        self.kwargs = kwargs if kwargs is not None else {}
+        self.result_id = str(uuid.uuid4())
+
+    def execute(self) -> Any:
+        return self.func(*self.args, **self.kwargs)
+
+
+def run_functions_in_parallel(
+    function_calls: list[FunctionCall],
+    allow_failures: bool = False,
+) -> dict[str, Any]:
+    """
+    Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys
+    are the result_id of the FunctionCall and the values are the results of the call.
+    """
+    results = {}
+    with ThreadPoolExecutor(max_workers=len(function_calls)) as executor:
+        future_to_id = {
+            executor.submit(func_call.execute): func_call.result_id
+            for func_call in function_calls
+        }
+
+        for future in as_completed(future_to_id):
+            result_id = future_to_id[future]
+            try:
+                results[result_id] = future.result()
+            except Exception as e:
+                logger.exception(f"Function with ID {result_id} failed due to {e}")
+                results[result_id] = None
+
+                if not allow_failures:
+                    raise
+
    return results
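A self-contained sanity check of the new tuple-based helper and its `allow_failures` behavior (illustrative only, not part of the commit):

```python
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel


def double(x: int) -> int:
    return x * 2


def boom(x: int) -> int:
    raise ValueError("intentional failure")


# Results come back ordered by submission index, not completion order.
assert run_functions_tuples_in_parallel([(double, (1,)), (double, (2,))]) == [2, 4]

# With allow_failures=True a failing call yields None instead of raising; this is
# how llm_batch_eval_chunks avoids dropping chunks when an LLM call times out.
assert run_functions_tuples_in_parallel(
    [(double, (3,)), (boom, (0,))], allow_failures=True
) == [6, None]
```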
@ -32,6 +32,7 @@ services:
      - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-}
      - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
      - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-}
+      - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-}
      # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years)
      - DOC_TIME_DECAY=${DOC_TIME_DECAY:-}
      # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector)
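The new `DISABLE_LLM_CHUNK_FILTER` variable presumably maps to a boolean setting in the backend config; the exact module and setting name are not part of this diff, but the usual pattern would be something like:

```python
import os

# Assumed consumption of the compose-level env var; the real setting name and
# location in the backend are not shown in this commit.
DISABLE_LLM_CHUNK_FILTER = (
    os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true"
)
```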
@ -116,6 +116,11 @@ export const DocumentDisplay = ({
}: DocumentDisplayProps) => {
  const [isHovered, setIsHovered] = useState(false);

+  // Consider reintroducing null scored docs in the future
+  if (document.score === null) {
+    return null;
+  }
+
  return (
    <div
      key={document.semantic_identifier}

@ -126,23 +131,25 @@ export const DocumentDisplay = ({
      onMouseLeave={() => setIsHovered(false)}
    >
      <div className="flex relative">
-        <div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
-          <div
-            className={`
-            text-xs
-            text-gray-200
-            bg-gray-800
-            rounded
-            p-0.5
-            w-fit
-            my-auto
-            select-none
-            ml-auto
-            mr-2`}
-          >
-            {document.score.toFixed(2)}
-          </div>
-        </div>
+        {document.score !== null && (
+          <div className="absolute -left-10 top-2/4 -translate-y-2/4 w-10 flex">
+            <div
+              className={`
+              text-xs
+              text-gray-200
+              bg-gray-800
+              rounded
+              p-0.5
+              w-fit
+              my-auto
+              select-none
+              ml-auto
+              mr-2`}
+            >
+              {document.score.toFixed(2)}
+            </div>
+          </div>
+        )}
        <a
          className={
            "rounded-lg flex font-bold " +