diff --git a/README.md b/README.md index 867769887..9f4a01eff 100644 --- a/README.md +++ b/README.md @@ -46,28 +46,32 @@ We also have built-in support for deployment on Kubernetes. Files for that can b ## 💃 Features * Direct QA powered by Generative AI models with answers backed by quotes and source links. -* Intelligent Document Retrieval (Semantic Search/Reranking) using the latest LLMs. -* An AI Helper backed by a custom Deep Learning model to interpret user intent. +* Intelligent Document Retrieval (Hybrid Search + Reranking) using the latest NLP models. +* Automatic time/source filter extraction from natural language + custom model to identify user intent. * User authentication and document level access management. -* Support for an LLM of your choice (GPT-4, Llama2, Orca, etc.) -* Management Dashboard to manage connectors and set up features such as live update fetching. +* Support for LLMs of your choice (GPT-4, Llama2, Orca, etc.) +* Management Dashboards to manage connectors and set up features such as live update fetching. * One line Docker Compose (or Kubernetes) deployment to host Danswer anywhere. ## 🔌 Connectors -Danswer currently syncs documents (every 10 minutes) from: +Efficiently pulls the latest changes from: * Slack * GitHub * Google Drive * Confluence * Jira * Notion + * Gong * Slab * Linear * Productboard * Guru * Zulip * Bookstack + * Document360 + * Request Tracker + * Hubspot * Local Files * Websites * With more to come... @@ -75,7 +79,9 @@ Danswer currently syncs documents (every 10 minutes) from: ## 🚧 Roadmap * Chat/Conversation support. * Organizational understanding. -* Ability to locate and suggest experts. +* Code Search +* Structured Query Languages (SQL, Excel formulas, etc.) +* Ability to locate and suggest experts from your team. ## 💡 Contributing Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details. diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index 695e6ad6f..f3b4d7cd6 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -140,18 +140,15 @@ def danswer_chat_retrieval( ) # Good Debug/Breakpoint - ranked_chunks, unranked_chunks = search_chunks( + top_chunks, _ = search_chunks( query=search_query, document_index=get_default_document_index() ) - if not ranked_chunks: + if not top_chunks: return [] - if unranked_chunks: - ranked_chunks.extend(unranked_chunks) - filtered_ranked_chunks = [ - chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA) + chunk for chunk in top_chunks if not chunk.metadata.get(IGNORE_FOR_QA) ] # get all chunks that fit into the token limit diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index dc7faeb1f..16b200d1d 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -178,8 +178,12 @@ MINI_CHUNK_SIZE = 150 NUM_RETURNED_HITS = 50 NUM_RERANKED_RESULTS = 15 # We feed in document chunks until we reach this token limit. -# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks -# may be smaller which could result in passing in more total chunks +# Default is ~5 full chunks (max chunk size is 2000 chars), although some chunks may be +# significantly smaller which could result in passing in more total chunks. +# There is also a slight bit of overhead, not accounted for here such as separator patterns +# between the docs, metadata for the docs, etc. 
+# Finally, this is combined with the rest of the QA prompt, so don't set this too close to the +# model token limit NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL = int( os.environ.get("NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL") or (512 * 5) ) @@ -198,12 +202,14 @@ FAVOR_RECENT_DECAY_MULTIPLIER = 2 DISABLE_LLM_FILTER_EXTRACTION = ( os.environ.get("DISABLE_LLM_FILTER_EXTRACTION", "").lower() == "true" ) -# 1 edit per 2 characters, currently unused due to fuzzy match being too slow +DISABLE_LLM_CHUNK_FILTER = ( + os.environ.get("DISABLE_LLM_CHUNK_FILTER", "").lower() == "true" +) +# 1 edit per 20 characters, currently unused due to fuzzy match being too slow QUOTE_ALLOWED_ERROR_PERCENT = 0.05 QA_TIMEOUT = int(os.environ.get("QA_TIMEOUT") or "60") # 60 seconds # Include additional document/chunk metadata in prompt to GenerativeAI INCLUDE_METADATA = False -HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false" # Keyword Search Drop Stopwords # If user has changed the default model, would most likely be to use a multilingual # model, the stopwords are NLTK english stopwords so then we would want to not drop the keywords diff --git a/backend/danswer/configs/chat_configs.py b/backend/danswer/configs/chat_configs.py index 553566f9e..1ca9fc38d 100644 --- a/backend/danswer/configs/chat_configs.py +++ b/backend/danswer/configs/chat_configs.py @@ -1,3 +1,4 @@ import os FORCE_TOOL_PROMPT = os.environ.get("FORCE_TOOL_PROMPT", "").lower() == "true" +HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "True").lower() != "false" diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 23362a576..122449a10 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -35,6 +35,8 @@ SCORE = "score" ID_SEPARATOR = ":;:" DEFAULT_BOOST = 0 SESSION_KEY = "session" +QUERY_EVENT_ID = "query_event_id" +LLM_CHUNKS = "llm_chunks" class DocumentSource(str, Enum): diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index a8c0c6f06..b19ea8d9f 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -232,7 +232,7 @@ def handle_message( logger.debug(answer.answer) return True - if not answer.top_ranked_docs: + if not answer.top_documents: logger.error(f"Unable to answer question: '{msg}' - no documents found") # Optionally, respond in thread with the error message, Used primarily # for debugging purposes @@ -265,8 +265,17 @@ def handle_message( favor_recent=answer.favor_recent, ) + # Get the chunks fed to the LLM only, then fill with other docs + top_docs = answer.top_documents + llm_doc_inds = answer.llm_chunks_indices or [] + llm_docs = [top_docs[i] for i in llm_doc_inds] + remaining_docs = [ + doc for idx, doc in enumerate(top_docs) if idx not in llm_doc_inds + ] + priority_ordered_docs = llm_docs + remaining_docs document_blocks = build_documents_blocks( - documents=answer.top_ranked_docs, query_event_id=answer.query_event_id + documents=priority_ordered_docs, + query_event_id=answer.query_event_id, ) try: diff --git a/backend/danswer/db/chat.py b/backend/danswer/db/chat.py index 3b0502c00..f39d75adf 100644 --- a/backend/danswer/db/chat.py +++ b/backend/danswer/db/chat.py @@ -9,7 +9,7 @@ from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session -from danswer.configs.app_configs import 
HARD_DELETE_CHATS +from danswer.configs.chat_configs import HARD_DELETE_CHATS from danswer.configs.constants import MessageType from danswer.db.models import ChatMessage from danswer.db.models import ChatSession diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index bb04cb659..b8c02d0e1 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -1,19 +1,19 @@ from collections.abc import Callable from collections.abc import Iterator +from functools import partial from sqlalchemy.orm import Session from danswer.configs.app_configs import DISABLE_GENERATIVE_AI -from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.app_configs import QA_TIMEOUT -from danswer.configs.constants import IGNORE_FOR_QA +from danswer.configs.constants import QUERY_EVENT_ID from danswer.db.feedback import update_query_event_llm_answer from danswer.db.models import User from danswer.direct_qa.factory import get_default_qa_model from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import StreamingError from danswer.direct_qa.models import LLMMetricsContainer -from danswer.direct_qa.qa_utils import get_usable_chunks +from danswer.direct_qa.qa_utils import get_chunks_for_qa from danswer.document_index.factory import get_default_document_index from danswer.search.danswer_helper import query_intent from danswer.search.models import QueryFlow @@ -24,11 +24,12 @@ from danswer.search.search_runner import danswer_search from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.source_filter import extract_question_source_filters from danswer.secondary_llm_flows.time_filter import extract_question_time_filters +from danswer.server.models import QADocsResponse from danswer.server.models import QAResponse from danswer.server.models import QuestionRequest -from danswer.server.models import RerankedRetrievalDocs from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.timing import log_function_time from danswer.utils.timing import log_generator_function_time @@ -54,24 +55,34 @@ def answer_qa_query( offset_count = question.offset if question.offset is not None else 0 logger.info(f"Received QA query: {query}") - functions_to_run: dict[Callable, tuple] = { - extract_question_time_filters: (question,), - extract_question_source_filters: (question, db_session), - query_intent: (query,), - } + run_time_filters = FunctionCall(extract_question_time_filters, (question,), {}) + run_source_filters = FunctionCall( + extract_question_source_filters, (question, db_session), {} + ) + run_query_intent = FunctionCall(query_intent, (query,), {}) - parallel_results = run_functions_in_parallel(functions_to_run) + parallel_results = run_functions_in_parallel( + [ + run_time_filters, + run_source_filters, + run_query_intent, + ] + ) - time_cutoff, favor_recent = parallel_results["extract_question_time_filters"] - source_filters = parallel_results["extract_question_source_filters"] - predicted_search, predicted_flow = parallel_results["query_intent"] + time_cutoff, favor_recent = parallel_results[run_time_filters.result_id] + source_filters = parallel_results[run_source_filters.result_id] + predicted_search, 
predicted_flow = parallel_results[run_query_intent.result_id] + + # Set flow as search so frontend doesn't ask the user if they want to run QA over more docs + if disable_generative_answer: + predicted_flow = QueryFlow.SEARCH # Modifies the question object but nothing upstream uses it question.filters.time_cutoff = time_cutoff question.favor_recent = favor_recent question.filters.source_type = source_filters - ranked_chunks, unranked_chunks, query_event_id = danswer_search( + top_chunks, llm_chunk_selection, query_event_id = danswer_search( question=question, user=user, db_session=db_session, @@ -80,38 +91,23 @@ def answer_qa_query( rerank_metrics_callback=rerank_metrics_callback, ) - if not ranked_chunks: - return QAResponse( + top_docs = chunks_to_search_docs(top_chunks) + + partial_response = partial( + QAResponse, + top_documents=chunks_to_search_docs(top_chunks), + predicted_flow=predicted_flow, + predicted_search=predicted_search, + query_event_id=query_event_id, + source_type=source_filters, + time_cutoff=time_cutoff, + favor_recent=favor_recent, + ) + + if disable_generative_answer or not top_docs: + return partial_response( answer=None, quotes=None, - top_ranked_docs=None, - lower_ranked_docs=None, - predicted_flow=predicted_flow, - predicted_search=predicted_search, - query_event_id=query_event_id, - source_type=source_filters, - time_cutoff=time_cutoff, - favor_recent=favor_recent, - ) - - top_docs = chunks_to_search_docs(ranked_chunks) - unranked_top_docs = chunks_to_search_docs(unranked_chunks) - - if disable_generative_answer: - logger.debug("Skipping QA because generative AI is disabled") - return QAResponse( - answer=None, - quotes=None, - top_ranked_docs=top_docs, - lower_ranked_docs=unranked_top_docs, - # set flow as search so frontend doesn't ask the user if they want - # to run QA over more documents - predicted_flow=QueryFlow.SEARCH, - predicted_search=predicted_search, - query_event_id=query_event_id, - source_type=source_filters, - time_cutoff=time_cutoff, - favor_recent=favor_recent, ) try: @@ -119,41 +115,28 @@ def answer_qa_query( timeout=answer_generation_timeout, real_time_flow=real_time_flow ) except Exception as e: - return QAResponse( + return partial_response( answer=None, quotes=None, - top_ranked_docs=top_docs, - lower_ranked_docs=unranked_top_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search, - query_event_id=query_event_id, - source_type=source_filters, - time_cutoff=time_cutoff, - favor_recent=favor_recent, error_msg=str(e), ) - # remove chunks marked as not applicable for QA (e.g. Google Drive file - # types which can't be parsed). These chunks are useful to show in the - # search results, but not for QA. 
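The `partial_response` construction above is the main simplification in `answer_qa_query`: `functools.partial` pre-binds the response fields shared by every return path (top documents, predicted flow, query event id, filters), so the early-exit, error, and success branches only pass what actually differs. A minimal sketch of the same pattern with a stand-in dataclass (field names here are illustrative, not the full `QAResponse` schema):

```python
from dataclasses import dataclass
from functools import partial


@dataclass
class Response:
    # Illustrative stand-in for QAResponse; only a subset of fields.
    top_documents: list[str]
    predicted_flow: str
    answer: str | None = None
    error_msg: str | None = None


# Bind the fields that are identical across every return path once...
partial_response = partial(
    Response, top_documents=["doc-1", "doc-2"], predicted_flow="question-answer"
)

# ...then each branch only supplies what actually differs.
no_answer = partial_response(answer=None)
with_answer = partial_response(answer="42", error_msg=None)
```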
- filtered_ranked_chunks = [ - chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA) - ] - - # get all chunks that fit into the token limit - usable_chunks = get_usable_chunks( - chunks=filtered_ranked_chunks, - token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL, - offset=offset_count, + llm_chunks_indices = get_chunks_for_qa( + chunks=top_chunks, + llm_chunk_selection=llm_chunk_selection, + batch_offset=offset_count, ) + + llm_chunks = [top_chunks[i] for i in llm_chunks_indices] + logger.debug( - f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}" + f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}" ) error_msg = None try: d_answer, quotes = qa_model.answer_question( - query, usable_chunks, metrics_callback=llm_metrics_callback + query, llm_chunks, metrics_callback=llm_metrics_callback ) except Exception as e: # exception is logged in the answer_question method, no need to re-log @@ -169,37 +152,17 @@ def answer_qa_query( user_id=None if user is None else user.id, ) + validity = None if not real_time_flow and enable_reflexion and d_answer is not None: - valid = False + validity = False if d_answer.answer is not None: - valid = get_answer_validity(query, d_answer.answer) + validity = get_answer_validity(query, d_answer.answer) - return QAResponse( - answer=d_answer.answer if d_answer else None, - quotes=quotes.quotes if quotes else None, - top_ranked_docs=top_docs, - lower_ranked_docs=unranked_top_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search, - eval_res_valid=True if valid else False, - query_event_id=query_event_id, - source_type=source_filters, - time_cutoff=time_cutoff, - favor_recent=favor_recent, - error_msg=error_msg, - ) - - return QAResponse( + return partial_response( answer=d_answer.answer if d_answer else None, quotes=quotes.quotes if quotes else None, - top_ranked_docs=top_docs, - lower_ranked_docs=unranked_top_docs, - predicted_flow=predicted_flow, - predicted_search=predicted_search, - query_event_id=query_event_id, - source_type=source_filters, - time_cutoff=time_cutoff, - favor_recent=favor_recent, + eval_res_valid=validity, + llm_chunks_indices=llm_chunks_indices, error_msg=error_msg, ) @@ -220,36 +183,47 @@ def answer_qa_query_stream( query = question.query offset_count = question.offset if question.offset is not None else 0 - functions_to_run: dict[Callable, tuple] = { - extract_question_time_filters: (question,), - extract_question_source_filters: (question, db_session), - query_intent: (query,), - } + run_time_filters = FunctionCall(extract_question_time_filters, (question,), {}) + run_source_filters = FunctionCall( + extract_question_source_filters, (question, db_session), {} + ) + run_query_intent = FunctionCall(query_intent, (query,), {}) - parallel_results = run_functions_in_parallel(functions_to_run) + parallel_results = run_functions_in_parallel( + [ + run_time_filters, + run_source_filters, + run_query_intent, + ] + ) - time_cutoff, favor_recent = parallel_results["extract_question_time_filters"] - source_filters = parallel_results["extract_question_source_filters"] - predicted_search, predicted_flow = parallel_results["query_intent"] + time_cutoff, favor_recent = parallel_results[run_time_filters.result_id] + source_filters = parallel_results[run_source_filters.result_id] + predicted_search, predicted_flow = parallel_results[run_query_intent.result_id] # Modifies the question object but nothing upstream uses it question.filters.time_cutoff = time_cutoff 
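Both QA entry points now schedule their pre-processing (time-filter extraction, source-filter extraction, query intent) through the new `FunctionCall` wrapper and fetch results by `result_id` instead of by `__name__`, which also allows the same callable to be scheduled more than once. A rough usage sketch with toy callables, assuming the `FunctionCall`/`run_functions_in_parallel` definitions added later in this diff:

```python
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel


def extract_filters(question: str) -> list[str]:
    return ["source: slack"]  # stand-in for the real filter extraction


def classify_intent(query: str) -> str:
    return "question-answer"  # stand-in for the real intent model


query = "when did we ship the new connector?"
run_filters = FunctionCall(extract_filters, (query,), {})
run_intent = FunctionCall(classify_intent, (query,), {})

# Results come back in a dict keyed by each call's unique result_id, so two
# FunctionCalls wrapping the same function no longer clobber each other.
results = run_functions_in_parallel([run_filters, run_intent])
source_filters = results[run_filters.result_id]
predicted_intent = results[run_intent.result_id]
```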
question.favor_recent = favor_recent question.filters.source_type = source_filters - ranked_chunks, unranked_chunks, query_event_id = danswer_search( + top_chunks, llm_chunk_selection, query_event_id = danswer_search( question=question, user=user, db_session=db_session, document_index=get_default_document_index(), ) - top_docs = chunks_to_search_docs(ranked_chunks) - unranked_top_docs = chunks_to_search_docs(unranked_chunks) + top_docs = chunks_to_search_docs(top_chunks) - initial_response = RerankedRetrievalDocs( + llm_chunks_indices = get_chunks_for_qa( + chunks=top_chunks, + llm_chunk_selection=llm_chunk_selection, + batch_offset=offset_count, + ) + + initial_response = QADocsResponse( top_documents=top_docs, - unranked_top_documents=unranked_top_docs, + llm_chunks_indices=llm_chunks_indices, # if generative AI is disabled, set flow as search so frontend # doesn't ask the user if they want to run QA over more documents predicted_flow=QueryFlow.SEARCH @@ -260,10 +234,9 @@ def answer_qa_query_stream( favor_recent=favor_recent, ).dict() - logger.debug(f"Sending Initial Retrival Results: {initial_response}") yield get_json_line(initial_response) - if not ranked_chunks: + if not top_chunks: logger.debug("No Documents Found") return @@ -279,25 +252,13 @@ def answer_qa_query_stream( yield get_json_line(error.dict()) return - # remove chunks marked as not applicable for QA (e.g. Google Drive file - # types which can't be parsed). These chunks are useful to show in the - # search results, but not for QA. - filtered_ranked_chunks = [ - chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA) - ] - - # get all chunks that fit into the token limit - usable_chunks = get_usable_chunks( - chunks=filtered_ranked_chunks, - token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL, - offset=offset_count, - ) + llm_chunks = [top_chunks[i] for i in llm_chunks_indices] logger.debug( - f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}" + f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in llm_chunks]}" ) try: - for response_packet in qa_model.answer_question_stream(query, usable_chunks): + for response_packet in qa_model.answer_question_stream(query, llm_chunks): if response_packet is None: continue if ( @@ -321,4 +282,4 @@ def answer_qa_query_stream( user_id=None if user is None else user.id, ) - yield get_json_line({"query_event_id": query_event_id}) + yield get_json_line({QUERY_EVENT_ID: query_event_id}) diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py index 7f45fceee..0a41f0f26 100644 --- a/backend/danswer/direct_qa/qa_utils.py +++ b/backend/danswer/direct_qa/qa_utils.py @@ -11,6 +11,7 @@ import regex from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT +from danswer.configs.constants import IGNORE_FOR_QA from danswer.direct_qa.interfaces import DanswerAnswer from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerQuote @@ -316,3 +317,57 @@ def get_usable_chunks( offset_into_chunks += len(usable_chunks) return usable_chunks + + +def get_chunks_for_qa( + chunks: list[InferenceChunk], + llm_chunk_selection: list[bool], + token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL, + batch_offset: int = 0, +) -> list[int]: + """ + Gives back indices of chunks to pass into the LLM for Q&A. 
+ + Only selects chunks viable for Q&A, within the token limit, and prioritize those selected + by the LLM in a separate flow (this can be turned off) + + Note, the batch_offset calculation has to count the batches from the beginning each time as + there's no way to know which chunks were included in the prior batches without recounting atm, + this is somewhat slow as it requires tokenizing all the chunks again + """ + batch_index = 0 + latest_batch_indices: list[int] = [] + token_count = 0 + + # First iterate the LLM selected chunks, then iterate the rest if tokens remaining + for selection_target in [True, False]: + for ind, chunk in enumerate(chunks): + if llm_chunk_selection[ind] is not selection_target or chunk.metadata.get( + IGNORE_FOR_QA + ): + continue + + # We calculate it live in case the user uses a different LLM + tokenizer + chunk_token = check_number_of_tokens(chunk.content) + token_count += chunk_token + + # Always use at least 1 chunk + if token_count <= token_limit or not latest_batch_indices: + latest_batch_indices.append(ind) + current_chunk_unused = False + else: + current_chunk_unused = True + + if token_count >= token_limit: + if batch_index < batch_offset: + batch_index += 1 + if current_chunk_unused: + latest_batch_indices = [ind] + token_count = chunk_token + else: + latest_batch_indices = [] + token_count = 0 + else: + return latest_batch_indices + + return latest_batch_indices diff --git a/backend/danswer/llm/chat_llm.py b/backend/danswer/llm/chat_llm.py index ef25473e5..23011b0ab 100644 --- a/backend/danswer/llm/chat_llm.py +++ b/backend/danswer/llm/chat_llm.py @@ -33,11 +33,6 @@ class LangChainChatLLM(LLM, abc.ABC): def llm(self) -> BaseChatModel: raise NotImplementedError - def _log_model_config(self) -> None: - logger.debug( - f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}" - ) - @staticmethod def _log_prompt(prompt: LanguageModelInput) -> None: if isinstance(prompt, list): @@ -46,8 +41,12 @@ class LangChainChatLLM(LLM, abc.ABC): if isinstance(prompt, str): logger.debug(f"Prompt:\n{prompt}") + def log_model_configs(self) -> None: + logger.debug( + f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}" + ) + def invoke(self, prompt: LanguageModelInput) -> str: - self._log_model_config() if LOG_ALL_MODEL_INTERACTIONS: self._log_prompt(prompt) @@ -58,7 +57,6 @@ class LangChainChatLLM(LLM, abc.ABC): return model_raw def stream(self, prompt: LanguageModelInput) -> Iterator[str]: - self._log_model_config() if LOG_ALL_MODEL_INTERACTIONS: self._log_prompt(prompt) diff --git a/backend/danswer/llm/custom_llm.py b/backend/danswer/llm/custom_llm.py index 5de9b0d97..4c11a29a4 100644 --- a/backend/danswer/llm/custom_llm.py +++ b/backend/danswer/llm/custom_llm.py @@ -9,6 +9,10 @@ from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.llm.interfaces import LLM from danswer.llm.utils import convert_lm_input_to_basic_string +from danswer.utils.logger import setup_logger + + +logger = setup_logger() class CustomModelServer(LLM): @@ -65,6 +69,9 @@ class CustomModelServer(LLM): response.raise_for_status() return json.loads(response.content).get("generated_text", "") + def log_model_configs(self) -> None: + logger.debug(f"Custom model at: {self._endpoint}") + def invoke(self, prompt: LanguageModelInput) -> str: return self._execute(prompt) diff --git a/backend/danswer/llm/gpt_4_all.py b/backend/danswer/llm/gpt_4_all.py index 
316c4e7aa..d2307eb78 100644 --- a/backend/danswer/llm/gpt_4_all.py +++ b/backend/danswer/llm/gpt_4_all.py @@ -61,6 +61,11 @@ class DanswerGPT4All(LLM): self.temperature = temperature self.gpt4all_model = GPT4All(model_version) + def log_model_configs(self) -> None: + logger.debug( + f"GPT4All Model: {self.gpt4all_model}, Temperature: {self.temperature}" + ) + def invoke(self, prompt: LanguageModelInput) -> str: prompt_basic = convert_lm_input_to_basic_string(prompt) return self.gpt4all_model.generate(prompt_basic) diff --git a/backend/danswer/llm/interfaces.py b/backend/danswer/llm/interfaces.py index 9e56b0934..41fe428bb 100644 --- a/backend/danswer/llm/interfaces.py +++ b/backend/danswer/llm/interfaces.py @@ -22,6 +22,10 @@ class LLM(abc.ABC): def requires_api_key(self) -> bool: return True + @abc.abstractmethod + def log_model_configs(self) -> None: + raise NotImplementedError + @abc.abstractmethod def invoke(self, prompt: LanguageModelInput) -> str: raise NotImplementedError diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 7e37bb097..4990fc0f6 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -36,6 +36,7 @@ from danswer.configs.model_configs import SKIP_RERANKING from danswer.db.credentials import create_initial_public_credential from danswer.direct_qa.factory import get_default_qa_model from danswer.document_index.factory import get_default_document_index +from danswer.llm.factory import get_default_llm from danswer.server.cc_pair.api import router as cc_pair_router from danswer.server.chat_backend import router as chat_router from danswer.server.connector import router as connector_router @@ -197,7 +198,7 @@ def get_application() -> FastAPI: warm_up_models() # This is for the LLM, most LLMs will not need warming up - # It logs for itself + get_default_llm().log_model_configs() get_default_qa_model().warm_up_model() logger.info("Verifying query preprocessing (NLTK) data is downloaded") diff --git a/backend/danswer/prompts/secondary_llm_flows.py b/backend/danswer/prompts/secondary_llm_flows.py index 5bce628b1..484c468a6 100644 --- a/backend/danswer/prompts/secondary_llm_flows.py +++ b/backend/danswer/prompts/secondary_llm_flows.py @@ -132,6 +132,29 @@ Note: The "file" source only applies to when the user refers to uploaded files i """.strip() +USEFUL_PAT = "Yes useful" +NONUSEFUL_PAT = "Not useful" +CHUNK_FILTER_PROMPT = f""" +Determine if the reference section is USEFUL for answering the user query. +It is NOT enough for the section to be related to the query, \ +it must contain information that is USEFUL for answering the query. +If the section contains ANY useful information, that is good enough, \ +it does not need to fully answer the every part of the user query. 
+ +Reference Section: +``` +{{chunk_text}} +``` + +User Query: +``` +{{user_query}} +``` + +Respond with EXACTLY AND ONLY: "{USEFUL_PAT}" or "{NONUSEFUL_PAT}" +""".strip() + + # User the following for easy viewing of prompts if __name__ == "__main__": print(ANSWERABLE_PROMPT) diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 97c4db6f6..ef0d4c2cb 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -3,6 +3,7 @@ from enum import Enum from pydantic import BaseModel +from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentSource @@ -57,6 +58,9 @@ class SearchQuery(BaseModel): skip_rerank: bool = SKIP_RERANKING # Only used if not skip_rerank num_rerank: int | None = NUM_RERANKED_RESULTS + skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER + # Only used if not skip_llm_chunk_filter + max_llm_filter_chunks: int = NUM_RERANKED_RESULTS class RetrievalMetricsContainer(BaseModel): diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py index 2ba000471..e5c78cce7 100644 --- a/backend/danswer/search/search_runner.py +++ b/backend/danswer/search/search_runner.py @@ -8,7 +8,9 @@ from nltk.tokenize import word_tokenize # type:ignore from sentence_transformers import SentenceTransformer # type: ignore from sqlalchemy.orm import Session +from danswer.configs.app_configs import DISABLE_LLM_CHUNK_FILTER from danswer.configs.app_configs import HYBRID_ALPHA +from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN @@ -33,9 +35,12 @@ from danswer.search.models import SearchQuery from danswer.search.models import SearchType from danswer.search.search_nlp_models import CrossEncoderEnsembleModel from danswer.search.search_nlp_models import EmbeddingModel +from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks from danswer.server.models import QuestionRequest from danswer.server.models import SearchDoc from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import FunctionCall +from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.timing import log_function_time @@ -147,7 +152,12 @@ def semantic_reranking( rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, model_min: int = CROSS_ENCODER_RANGE_MIN, model_max: int = CROSS_ENCODER_RANGE_MAX, -) -> list[InferenceChunk]: +) -> tuple[list[InferenceChunk], list[int]]: + """Reranks chunks based on cross-encoder models. Additionally provides the original indices + of the chunks in their new sorted order. 
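Returning the original indices matters because the LLM usefulness flags are computed against the pre-rerank order and have to be re-mapped once the chunks are re-sorted (see the `reranked_llm_chunk_selection` logic in `search_chunks` below). A toy, pure-Python version of that bookkeeping:

```python
# Sort by score while remembering each item's original position, then use
# those positions to re-order a parallel list (the LLM usefulness flags).
scores = [0.2, 0.9, 0.5]
chunks = ["chunk-a", "chunk-b", "chunk-c"]
llm_flags = [False, True, True]  # computed against the original order

scored = sorted(
    zip(scores, chunks, range(len(chunks))), key=lambda x: x[0], reverse=True
)
_, ranked_chunks, orig_indices = (list(col) for col in zip(*scored))

reranked_flags = [llm_flags[i] for i in orig_indices]
# ranked_chunks  -> ['chunk-b', 'chunk-c', 'chunk-a']
# reranked_flags -> [True, True, False]
```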
+ + Note: this updates the chunks in place, it updates the chunk scores which came from retrieval + """ cross_encoders = CrossEncoderEnsembleModel() passages = [chunk.content for chunk in chunks] sim_scores_floats = cross_encoders.predict(query=query, passages=passages) @@ -168,16 +178,20 @@ def semantic_reranking( normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / ( model_max - model_min ) - scored_results = list(zip(normalized_b_s_scores, raw_sim_scores, chunks)) + orig_indices = [i for i in range(len(normalized_b_s_scores))] + scored_results = list( + zip(normalized_b_s_scores, raw_sim_scores, chunks, orig_indices) + ) scored_results.sort(key=lambda x: x[0], reverse=True) - ranked_sim_scores, ranked_raw_scores, ranked_chunks = zip(*scored_results) + ranked_sim_scores, ranked_raw_scores, ranked_chunks, ranked_indices = zip( + *scored_results + ) logger.debug( f"Reranked (Boosted + Time Weighted) similarity scores: {ranked_sim_scores}" ) # Assign new chunk scores based on reranking - # TODO if pagination is added, the scores won't make sense with respect to the non-reranked hits for ind, chunk in enumerate(ranked_chunks): chunk.score = ranked_sim_scores[ind] @@ -198,7 +212,7 @@ def semantic_reranking( ) ) - return list(ranked_chunks) + return list(ranked_chunks), list(ranked_indices) def apply_boost_legacy( @@ -257,6 +271,9 @@ def apply_boost_legacy( def apply_boost( chunks: list[InferenceChunk], + # Need the range of values to not be too spread out for applying boost + # therefore norm across only the top few results + norm_cutoff: int = NUM_RERANKED_RESULTS, norm_min: float = SIM_SCORE_RANGE_LOW, norm_max: float = SIM_SCORE_RANGE_HIGH, ) -> list[InferenceChunk]: @@ -266,13 +283,13 @@ def apply_boost( boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks] recency_multiplier = [chunk.recency_bias for chunk in chunks] - norm_min = min(norm_min, min(scores)) - norm_max = max(norm_max, max(scores)) + norm_min = min(norm_min, min(scores[:norm_cutoff])) + norm_max = max(norm_max, max(scores[:norm_cutoff])) # This should never be 0 unless user has done some weird/wrong settings norm_range = norm_max - norm_min boosted_scores = [ - (score - norm_min) * boost * recency / norm_range + max(0, (score - norm_min) * boost * recency / norm_range) for score, boost, recency in zip(scores, boosts, recency_multiplier) ] @@ -299,7 +316,14 @@ def search_chunks( retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]: +) -> tuple[list[InferenceChunk], list[bool]]: + """Returns a list of the best chunks from search/reranking and if the chunks are relevant via LLM. 
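The "relevant via LLM" flags come from the new `chunk_usefulness` flow: each candidate chunk is scored yes/no against `CHUNK_FILTER_PROMPT`, and anything that does not exactly match the not-useful pattern is kept, so a flaky model falls back to trusting the reranker. A sketch of that round trip with a mocked model call (the real flow goes through `get_default_llm().invoke(...)` with a short timeout):

```python
from danswer.prompts.secondary_llm_flows import CHUNK_FILTER_PROMPT
from danswer.prompts.secondary_llm_flows import NONUSEFUL_PAT


def mock_llm_invoke(prompt: str) -> str:
    return "Not useful"  # stand-in for the real LLM response


filled_prompt = CHUNK_FILTER_PROMPT.format(
    chunk_text="Our Slack connector polls for new messages every 10 minutes.",
    user_query="How do I rotate my AWS keys?",
)

# Only an exact "Not useful" drops the chunk; any other output keeps it.
model_output = mock_llm_invoke(filled_prompt)
chunk_is_useful = (
    model_output.strip().strip('"').lower() != NONUSEFUL_PAT.lower()
)
```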
+ For sake of speed, the system cannot rerank all retrieved chunks + Also pass the chunks through LLM to determine if they are relevant (binary for speed) + + Only the first max_llm_filter_chunks + """ + def _log_top_chunk_links(search_flow: str, chunks: list[InferenceChunk]) -> None: top_links = [ c.source_links[0] if c.source_links is not None else "No Link" @@ -316,7 +340,7 @@ def search_chunks( f"{query.search_type.value.capitalize()} search returned no results " f"with filters: {query.filters}" ) - return None, None + return [], [] if retrieval_metrics_callback is not None: chunk_metrics = [ @@ -332,27 +356,62 @@ def search_chunks( RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics) ) - # Keyword Search should never do reranking, no transformers involved in this flow - if query.search_type == SearchType.KEYWORD: + functions_to_run: list[FunctionCall] = [] + + # Keyword Search should not do reranking + if query.search_type == SearchType.KEYWORD or query.skip_rerank: _log_top_chunk_links(query.search_type.value, top_chunks) - return top_chunks, None + run_rerank_id: str | None = None + else: + run_rerank = FunctionCall( + semantic_reranking, + (query.query, top_chunks[: query.num_rerank]), + {"rerank_metrics_callback": rerank_metrics_callback}, + ) + functions_to_run.append(run_rerank) + run_rerank_id = run_rerank.result_id - if query.skip_rerank: - # Need the range of values to not be too spread out for applying boost - # Therefore pass in smaller set of chunks to limit the range for norm-ing - boosted_chunks = apply_boost(top_chunks[: query.num_rerank]) - _log_top_chunk_links(query.search_type.value, boosted_chunks) - return boosted_chunks, top_chunks[query.num_rerank :] + run_llm_filter_id = None + if not query.skip_llm_chunk_filter: + run_llm_filter = FunctionCall( + llm_batch_eval_chunks, + ( + query.query, + [chunk.content for chunk in top_chunks[: query.max_llm_filter_chunks]], + ), + {}, + ) + functions_to_run.append(run_llm_filter) + run_llm_filter_id = run_llm_filter.result_id - ranked_chunks = semantic_reranking( - query.query, - top_chunks[: query.num_rerank], - rerank_metrics_callback=rerank_metrics_callback, - ) + parallel_results = run_functions_in_parallel(functions_to_run) + ranked_results = parallel_results.get(str(run_rerank_id)) + if ranked_results is None: + ranked_chunks = top_chunks + sorted_indices = [i for i in range(len(top_chunks))] + else: + ranked_chunks, orig_indices = ranked_results + sorted_indices = orig_indices + list(range(len(orig_indices), len(top_chunks))) + lower_chunks = top_chunks[query.num_rerank :] + # Scores from rerank cannot be meaningfully combined with scores without rerank + for lower_chunk in lower_chunks: + lower_chunk.score = None + ranked_chunks.extend(lower_chunks) + + llm_chunk_selection = parallel_results.get(str(run_llm_filter_id)) + if llm_chunk_selection is None: + reranked_llm_chunk_selection = [True for _ in top_chunks] + else: + llm_chunk_selection.extend( + [False for _ in top_chunks[query.max_llm_filter_chunks :]] + ) + reranked_llm_chunk_selection = [ + llm_chunk_selection[ind] for ind in sorted_indices + ] _log_top_chunk_links(query.search_type.value, ranked_chunks) - return ranked_chunks, top_chunks[query.num_rerank :] + return ranked_chunks, reranked_llm_chunk_selection def danswer_search( @@ -360,10 +419,11 @@ def danswer_search( user: User | None, db_session: Session, document_index: DocumentIndex, + skip_llm_chunk_filter: bool = DISABLE_LLM_CHUNK_FILTER, retrieval_metrics_callback: 
Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None, int]: +) -> tuple[list[InferenceChunk], list[bool], int]: query_event_id = create_query_event( query=question.query, search_type=question.search_type, @@ -384,17 +444,21 @@ def danswer_search( query=question.query, search_type=question.search_type, filters=final_filters, - favor_recent=True if question.favor_recent is None else question.favor_recent, + # Still applies time decay but not magnified + favor_recent=question.favor_recent + if question.favor_recent is not None + else False, + skip_llm_chunk_filter=skip_llm_chunk_filter, ) - ranked_chunks, unranked_chunks = search_chunks( + top_chunks, llm_chunk_selection = search_chunks( query=search_query, document_index=document_index, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) - retrieved_ids = [doc.document_id for doc in ranked_chunks] if ranked_chunks else [] + retrieved_ids = [doc.document_id for doc in top_chunks] if top_chunks else [] update_query_event_retrieved_documents( db_session=db_session, @@ -403,4 +467,4 @@ def danswer_search( user_id=None if user is None else user.id, ) - return ranked_chunks, unranked_chunks, query_event_id + return top_chunks, llm_chunk_selection, query_event_id diff --git a/backend/danswer/secondary_llm_flows/chunk_usefulness.py b/backend/danswer/secondary_llm_flows/chunk_usefulness.py new file mode 100644 index 000000000..057aa7f63 --- /dev/null +++ b/backend/danswer/secondary_llm_flows/chunk_usefulness.py @@ -0,0 +1,65 @@ +from collections.abc import Callable + +from danswer.llm.factory import get_default_llm +from danswer.llm.utils import dict_based_prompt_to_langchain_prompt +from danswer.prompts.secondary_llm_flows import CHUNK_FILTER_PROMPT +from danswer.prompts.secondary_llm_flows import NONUSEFUL_PAT +from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel + +logger = setup_logger() + + +def llm_eval_chunk(query: str, chunk_content: str) -> bool: + def _get_usefulness_messages() -> list[dict[str, str]]: + messages = [ + { + "role": "user", + "content": CHUNK_FILTER_PROMPT.format( + chunk_text=chunk_content, user_query=query + ), + }, + ] + + return messages + + def _extract_usefulness(model_output: str) -> bool: + """Default useful if the LLM doesn't match pattern exactly + This is because it's better to trust the (re)ranking if LLM fails""" + if model_output.strip().strip('"').lower() == NONUSEFUL_PAT.lower(): + return False + return True + + messages = _get_usefulness_messages() + filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) + # When running in a batch, it takes as long as the longest thread + # And when running a large batch, one may fail and take the whole timeout + # instead cap it to 5 seconds + model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt) + logger.debug(model_output) + + return _extract_usefulness(model_output) + + +def llm_batch_eval_chunks( + query: str, chunk_contents: list[str], use_threads: bool = True +) -> list[bool]: + if use_threads: + functions_with_args: list[tuple[Callable, tuple]] = [ + (llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents + ] + + logger.debug( + "Running LLM usefulness eval in parallel (following logging may be out of order)" + ) + parallel_results = 
run_functions_tuples_in_parallel( + functions_with_args, allow_failures=True + ) + + # In case of failure/timeout, don't throw out the chunk + return [True if item is None else item for item in parallel_results] + + else: + return [ + llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents + ] diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index d9b81c250..30f3e41af 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -171,12 +171,58 @@ class SearchDoc(BaseModel): return initial_dict +class QuestionRequest(BaseModel): + query: str + collection: str + filters: BaseFilters + offset: int | None + enable_auto_detect_filters: bool + favor_recent: bool | None = None + search_type: SearchType = SearchType.HYBRID + + +class QAFeedbackRequest(BaseModel): + query_id: int + feedback: QAFeedbackType + + +class SearchFeedbackRequest(BaseModel): + query_id: int + document_id: str + document_rank: int + click: bool + search_feedback: SearchFeedbackType + + +class QueryValidationResponse(BaseModel): + reasoning: str + answerable: bool + + class RetrievalDocs(BaseModel): top_documents: list[SearchDoc] -class RerankedRetrievalDocs(RetrievalDocs): - unranked_top_documents: list[SearchDoc] +class SearchResponse(RetrievalDocs): + query_event_id: int + source_type: list[DocumentSource] | None + time_cutoff: datetime | None + favor_recent: bool + + +class QAResponse(SearchResponse): + answer: str | None # DanswerAnswer + quotes: list[DanswerQuote] | None + predicted_flow: QueryFlow + predicted_search: SearchType + eval_res_valid: bool | None = None + llm_chunks_indices: list[int] | None = None + error_msg: str | None = None + + +# First chunk of info for streaming QA +class QADocsResponse(RetrievalDocs): + llm_chunks_indices: list[int] predicted_flow: QueryFlow predicted_search: SearchType time_cutoff: datetime | None @@ -194,21 +240,6 @@ class CreateChatSessionID(BaseModel): chat_session_id: int -class QuestionRequest(BaseModel): - query: str - collection: str - filters: BaseFilters - offset: int | None - enable_auto_detect_filters: bool - favor_recent: bool | None = None - search_type: SearchType = SearchType.HYBRID - - -class QAFeedbackRequest(BaseModel): - query_id: int - feedback: QAFeedbackType - - class ChatFeedbackRequest(BaseModel): chat_session_id: int message_number: int @@ -217,14 +248,6 @@ class ChatFeedbackRequest(BaseModel): feedback_text: str | None = None -class SearchFeedbackRequest(BaseModel): - query_id: int - document_id: str - document_rank: int - click: bool - search_feedback: SearchFeedbackType - - class CreateChatMessageRequest(BaseModel): chat_session_id: int message_number: int @@ -280,30 +303,6 @@ class ChatSessionDetailResponse(BaseModel): messages: list[ChatMessageDetail] -class QueryValidationResponse(BaseModel): - reasoning: str - answerable: bool - - -class SearchResponse(BaseModel): - # For semantic search, top docs are reranked, the remaining are as ordered from retrieval - top_ranked_docs: list[SearchDoc] | None - lower_ranked_docs: list[SearchDoc] | None - query_event_id: int - source_type: list[DocumentSource] | None - time_cutoff: datetime | None - favor_recent: bool - - -class QAResponse(SearchResponse): - answer: str | None # DanswerAnswer - quotes: list[DanswerQuote] | None - predicted_flow: QueryFlow - predicted_search: SearchType - eval_res_valid: bool | None = None - error_msg: str | None = None - - class UserByEmail(BaseModel): user_email: str diff --git 
a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 030aae881..2f33bd274 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -1,5 +1,3 @@ -from collections.abc import Callable - from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException @@ -36,6 +34,7 @@ from danswer.server.models import SearchDoc from danswer.server.models import SearchFeedbackRequest from danswer.server.models import SearchResponse from danswer.utils.logger import setup_logger +from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel logger = setup_logger() @@ -131,10 +130,10 @@ def handle_search_request( query = question.query logger.info(f"Received {question.search_type.value} " f"search query: {query}") - functions_to_run: dict[Callable, tuple] = { - extract_question_time_filters: (question,), - extract_question_source_filters: (question, db_session), - } + functions_to_run = [ + FunctionCall(extract_question_time_filters, (question,), {}), + FunctionCall(extract_question_source_filters, (question, db_session), {}), + ] parallel_results = run_functions_in_parallel(functions_to_run) @@ -145,29 +144,18 @@ def handle_search_request( question.favor_recent = favor_recent question.filters.source_type = source_filters - ranked_chunks, unranked_chunks, query_event_id = danswer_search( + top_chunks, _, query_event_id = danswer_search( question=question, user=user, db_session=db_session, document_index=get_default_document_index(), + skip_llm_chunk_filter=True, ) - if not ranked_chunks: - return SearchResponse( - top_ranked_docs=None, - lower_ranked_docs=None, - query_event_id=query_event_id, - source_type=source_filters, - time_cutoff=time_cutoff, - favor_recent=favor_recent, - ) - - top_docs = chunks_to_search_docs(ranked_chunks) - lower_top_docs = chunks_to_search_docs(unranked_chunks) + top_docs = chunks_to_search_docs(top_chunks) return SearchResponse( - top_ranked_docs=top_docs, - lower_ranked_docs=lower_top_docs or None, + top_documents=top_docs, query_event_id=query_event_id, source_type=source_filters, time_cutoff=time_cutoff, diff --git a/backend/danswer/utils/telemetry.py b/backend/danswer/utils/telemetry.py index 8311ce282..39790face 100644 --- a/backend/danswer/utils/telemetry.py +++ b/backend/danswer/utils/telemetry.py @@ -37,16 +37,20 @@ def optional_telemetry(record_type: RecordType, data: dict) -> None: try: def telemetry_logic() -> None: - payload = { - "data": data, - "record": record_type, - "customer_uuid": get_or_generate_uuid(), - } - requests.post( - DANSWER_TELEMETRY_ENDPOINT, - headers={"Content-Type": "application/json"}, - json=payload, - ) + try: + payload = { + "data": data, + "record": record_type, + "customer_uuid": get_or_generate_uuid(), + } + requests.post( + DANSWER_TELEMETRY_ENDPOINT, + headers={"Content-Type": "application/json"}, + json=payload, + ) + except Exception: + # This way it silences all thread level logging as well + pass # Run in separate thread to have minimal overhead in main flows thread = threading.Thread(target=telemetry_logic, daemon=True) diff --git a/backend/danswer/utils/threadpool_concurrency.py b/backend/danswer/utils/threadpool_concurrency.py index 692714801..ac868e43a 100644 --- a/backend/danswer/utils/threadpool_concurrency.py +++ b/backend/danswer/utils/threadpool_concurrency.py @@ -1,3 +1,4 @@ +import uuid from collections.abc import Callable from 
concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor @@ -8,31 +9,82 @@ from danswer.utils.logger import setup_logger logger = setup_logger() -def run_functions_in_parallel( - functions_with_args: dict[Callable, tuple] -) -> dict[str, Any]: +def run_functions_tuples_in_parallel( + functions_with_args: list[tuple[Callable, tuple]], + allow_failures: bool = False, +) -> list[Any]: """ - Executes multiple functions in parallel and returns a dictionary with the results. + Executes multiple functions in parallel and returns a list of the results for each function. Args: - functions_with_args (dict): A dictionary mapping functions to a tuple of arguments. + functions_with_args: List of tuples each containing the function callable and a tuple of arguments. + allow_failures: if set to True, then the function result will just be None Returns: dict: A dictionary mapping function names to their results or error messages. """ - results = {} + results = [] with ThreadPoolExecutor(max_workers=len(functions_with_args)) as executor: - future_to_function = { - executor.submit(func, *args): func.__name__ - for func, args in functions_with_args.items() + future_to_index = { + executor.submit(func, *args): i + for i, (func, args) in enumerate(functions_with_args) } - for future in as_completed(future_to_function): - function_name = future_to_function[future] + for future in as_completed(future_to_index): + index = future_to_index[future] try: - results[function_name] = future.result() + results.append((index, future.result())) except Exception as e: - logger.exception(f"Function {function_name} failed due to {e}") - raise + logger.exception(f"Function at index {index} failed due to {e}") + results.append((index, None)) + + if not allow_failures: + raise + + results.sort(key=lambda x: x[0]) + return [result for index, result in results] + + +class FunctionCall: + """ + Container for run_functions_in_parallel, fetch the results from the output of + run_functions_in_parallel via the FunctionCall.result_id. + """ + + def __init__(self, func: Callable, args: tuple = (), kwargs: dict | None = None): + self.func = func + self.args = args + self.kwargs = kwargs if kwargs is not None else {} + self.result_id = str(uuid.uuid4()) + + def execute(self) -> Any: + return self.func(*self.args, **self.kwargs) + + +def run_functions_in_parallel( + function_calls: list[FunctionCall], + allow_failures: bool = False, +) -> dict[str, Any]: + """ + Executes a list of FunctionCalls in parallel and stores the results in a dictionary where the keys + are the result_id of the FunctionCall and the values are the results of the call. 
+ """ + results = {} + with ThreadPoolExecutor(max_workers=len(function_calls)) as executor: + future_to_id = { + executor.submit(func_call.execute): func_call.result_id + for func_call in function_calls + } + + for future in as_completed(future_to_id): + result_id = future_to_id[future] + try: + results[result_id] = future.result() + except Exception as e: + logger.exception(f"Function with ID {result_id} failed due to {e}") + results[result_id] = None + + if not allow_failures: + raise return results diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 28c780483..7dd763563 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -32,6 +32,7 @@ services: - DISABLE_GENERATIVE_AI=${DISABLE_GENERATIVE_AI:-} - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-} - DISABLE_LLM_FILTER_EXTRACTION=${DISABLE_LLM_FILTER_EXTRACTION:-} + - DISABLE_LLM_CHUNK_FILTER=${DISABLE_LLM_CHUNK_FILTER:-} # Recency Bias for search results, decay at 1 / (1 + DOC_TIME_DECAY * x years) - DOC_TIME_DECAY=${DOC_TIME_DECAY:-} # Hybrid Search Alpha (0 for entirely keyword, 1 for entirely vector) diff --git a/web/src/components/search/DocumentDisplay.tsx b/web/src/components/search/DocumentDisplay.tsx index c78cf8e83..ccf14f0c5 100644 --- a/web/src/components/search/DocumentDisplay.tsx +++ b/web/src/components/search/DocumentDisplay.tsx @@ -116,6 +116,11 @@ export const DocumentDisplay = ({ }: DocumentDisplayProps) => { const [isHovered, setIsHovered] = useState(false); + // Consider reintroducing null scored docs in the future + if (document.score === null) { + return null; + } + return (