From fda377a2fae486187f03f352777e11abc667b5f3 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Wed, 22 Nov 2023 19:33:28 -0800 Subject: [PATCH] Regression Script for Search quality (#760) --- backend/danswer/configs/app_configs.py | 2 +- .../slack/handlers/handle_message.py | 6 +- backend/danswer/search/models.py | 2 +- backend/danswer/search/search_runner.py | 15 +- backend/danswer/server/models.py | 9 +- .../answer_quality/eval_direct_qa.py | 6 +- .../regression/search_quality/eval_search.py | 195 ++++++++++++++++++ 7 files changed, 216 insertions(+), 19 deletions(-) create mode 100644 backend/tests/regression/search_quality/eval_search.py diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index b55ae7886..90604fd28 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -81,7 +81,7 @@ SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password") ##### # DB Configs ##### -DOCUMENT_INDEX_NAME = "danswer_index" # Shared by vector/keyword indices +DOCUMENT_INDEX_NAME = "danswer_index" # Vespa is now the default document index store for both keyword and vector DOCUMENT_INDEX_TYPE = os.environ.get( "DOCUMENT_INDEX_TYPE", DocumentIndexType.COMBINED.value diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py index cf9fb348d..85d71d2c8 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_message.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py @@ -6,7 +6,6 @@ from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from sqlalchemy.orm import Session -from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS @@ -204,11 +203,8 @@ def handle_message( answer = _get_answer( QuestionRequest( query=msg, - collection=DOCUMENT_INDEX_NAME, - enable_auto_detect_filters=not disable_auto_detect_filters, filters=filters, - favor_recent=None, - offset=None, + enable_auto_detect_filters=not disable_auto_detect_filters, ) ) except Exception as e: diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 06258ea43..330ab2290 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -64,7 +64,7 @@ class SearchQuery(BaseModel): class RetrievalMetricsContainer(BaseModel): - keyword_search: bool # False for Vector Search + search_type: SearchType metrics: list[ChunkMetric] # This contains the scores for retrieval as well diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py index 08fc06e00..f02a794b0 100644 --- a/backend/danswer/search/search_runner.py +++ b/backend/danswer/search/search_runner.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from collections.abc import Generator +from collections.abc import Iterator from copy import deepcopy from typing import cast @@ -383,7 +383,9 @@ def retrieve_chunks( for chunk in top_chunks ] retrieval_metrics_callback( - RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics) + RetrievalMetricsContainer( + search_type=query.search_type, metrics=chunk_metrics + ) ) return top_chunks @@ -468,7 +470,7 @@ def full_chunk_search_generator( retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], 
None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> Generator[list[InferenceChunk] | list[bool], None, None]: +) -> Iterator[list[InferenceChunk] | list[bool]]: """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result.""" chunks_yielded = False @@ -480,6 +482,11 @@ def full_chunk_search_generator( retrieval_metrics_callback=retrieval_metrics_callback, ) + if not retrieved_chunks: + yield cast(list[InferenceChunk], []) + yield cast(list[bool], []) + return + post_processing_tasks: list[FunctionCall] = [] rerank_task_id = None @@ -549,7 +556,7 @@ def danswer_search_generator( retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] | None = None, rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, -) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]: +) -> Iterator[list[InferenceChunk] | list[bool] | int]: """The main entry point for search. This fetches the relevant documents from Vespa based on the provided query (applying permissions / filters), does any specified post-processing, and returns the results. It also create an entry in the query_event table diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 343633f2c..d901c40e1 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -10,6 +10,7 @@ from pydantic import validator from pydantic.generics import GenericModel from danswer.auth.schemas import UserRole +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX from danswer.configs.constants import AuthType from danswer.configs.constants import DocumentSource @@ -173,12 +174,12 @@ class SearchDoc(BaseModel): class QuestionRequest(BaseModel): query: str - collection: str filters: BaseFilters - offset: int | None - enable_auto_detect_filters: bool - favor_recent: bool | None = None + collection: str = DOCUMENT_INDEX_NAME search_type: SearchType = SearchType.HYBRID + enable_auto_detect_filters: bool = True + favor_recent: bool | None = None + offset: int | None = None class QAFeedbackRequest(BaseModel): diff --git a/backend/tests/regression/answer_quality/eval_direct_qa.py b/backend/tests/regression/answer_quality/eval_direct_qa.py index 462e87a72..081e0ae74 100644 --- a/backend/tests/regression/answer_quality/eval_direct_qa.py +++ b/backend/tests/regression/answer_quality/eval_direct_qa.py @@ -8,7 +8,6 @@ from typing import TextIO import yaml from sqlalchemy.orm import Session -from danswer.access.access import get_acl_for_user from danswer.db.engine import get_sqlalchemy_engine from danswer.direct_qa.answer_question import answer_qa_query from danswer.direct_qa.models import LLMMetricsContainer @@ -80,14 +79,12 @@ def get_answer_for_question( source_type=None, document_set=None, time_cutoff=None, - access_control_list=list(get_acl_for_user(user=None)), + access_control_list=None, ) question = QuestionRequest( query=query, - collection="danswer_index", filters=filters, enable_auto_detect_filters=False, - offset=None, ) retrieval_metrics = MetricsHander[RetrievalMetricsContainer]() @@ -101,6 +98,7 @@ def get_answer_for_question( answer_generation_timeout=100, real_time_flow=False, enable_reflexion=False, + bypass_acl=True, retrieval_metrics_callback=retrieval_metrics.record_metric, rerank_metrics_callback=rerank_metrics.record_metric, llm_metrics_callback=llm_metrics.record_metric, diff --git 
a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py new file mode 100644 index 000000000..2de4f02f5 --- /dev/null +++ b/backend/tests/regression/search_quality/eval_search.py @@ -0,0 +1,195 @@ +import argparse +import builtins +import json +from contextlib import contextmanager +from typing import Any +from typing import TextIO + +from sqlalchemy.orm import Session + +from danswer.db.engine import get_sqlalchemy_engine +from danswer.document_index.factory import get_default_document_index +from danswer.indexing.models import InferenceChunk +from danswer.search.models import IndexFilters +from danswer.search.models import RerankMetricsContainer +from danswer.search.models import RetrievalMetricsContainer +from danswer.search.search_runner import danswer_search +from danswer.server.models import QuestionRequest +from danswer.utils.callbacks import MetricsHander + + +engine = get_sqlalchemy_engine() + + +@contextmanager +def redirect_print_to_file(file: TextIO) -> Any: + original_print = builtins.print + builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs) + try: + yield + finally: + builtins.print = original_print + + +def read_json(file_path: str) -> dict: + with open(file_path, "r") as file: + return json.load(file) + + +def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str: + words = s.split() + + current_line: list[str] = [] + result_lines: list[str] = [] + current_length = 0 + for word in words: + if len(word) > max_line_size: + if current_line: + result_lines.append(" ".join(current_line)) + current_line = [] + current_length = 0 + + result_lines.append(word) + continue + + if current_length + len(word) + len(current_line) > max_line_size: + result_lines.append(" ".join(current_line)) + current_line = [] + current_length = 0 + + current_line.append(word) + current_length += len(word) + + if current_line: + result_lines.append(" ".join(current_line)) + + return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines) + + +def get_search_results( + query: str, db_session: Session +) -> tuple[ + list[InferenceChunk], + list[bool], + RetrievalMetricsContainer | None, + RerankMetricsContainer | None, +]: + filters = IndexFilters( + source_type=None, + document_set=None, + time_cutoff=None, + access_control_list=None, + ) + question = QuestionRequest( + query=query, + filters=filters, + enable_auto_detect_filters=False, + ) + + retrieval_metrics = MetricsHander[RetrievalMetricsContainer]() + rerank_metrics = MetricsHander[RerankMetricsContainer]() + + chunks, llm_filter, query_id = danswer_search( + question=question, + user=None, + db_session=db_session, + document_index=get_default_document_index(), + bypass_acl=True, + retrieval_metrics_callback=retrieval_metrics.record_metric, + rerank_metrics_callback=rerank_metrics.record_metric, + ) + + return ( + chunks, + llm_filter, + retrieval_metrics.metrics, + rerank_metrics.metrics, + ) + + +def _print_retrieval_metrics( + metrics_container: RetrievalMetricsContainer, show_all: bool +) -> None: + for ind, metric in enumerate(metrics_container.metrics): + if not show_all and ind >= 10: + break + + if ind != 0: + print() # for spacing purposes + print(f"\tDocument: {metric.document_id}") + print(f"\tLink: {metric.first_link or 'NA'}") + section_start = metric.chunk_content_start.replace("\n", " ") + print(f"\tSection Start: {section_start}") + print(f"\tSimilarity Distance Metric: {metric.score}") + + +def 
_print_reranking_metrics(
+    metrics_container: RerankMetricsContainer, show_all: bool
+) -> None:
+    # Printing the raw scores as they're more informational than post-norm/boosting
+    for ind, metric in enumerate(metrics_container.metrics):
+        if not show_all and ind >= 10:
+            break
+
+        if ind != 0:
+            print()  # for spacing purposes
+        print(f"\tDocument: {metric.document_id}")
+        print(f"\tLink: {metric.first_link or 'NA'}")
+        section_start = metric.chunk_content_start.replace("\n", " ")
+        print(f"\tSection Start: {section_start}")
+        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")
+
+
+def main(questions_json: str, output_file: str) -> None:
+    questions_info = read_json(questions_json)
+
+    with open(output_file, "w") as outfile:
+        with redirect_print_to_file(outfile):
+            print("Running Document Retrieval Test\n")
+
+            with Session(engine, expire_on_commit=False) as db_session:
+                for question, targets in questions_info.items():
+                    print(f"Question: {question}")
+
+                    (
+                        chunks,
+                        llm_filters,
+                        retrieval_metrics,
+                        rerank_metrics,
+                    ) = get_search_results(query=question, db_session=db_session)
+
+                    print("\nRetrieval Metrics:")
+                    if retrieval_metrics is None:
+                        print("No Retrieval Metrics Available")
+                    else:
+                        _print_retrieval_metrics(
+                            retrieval_metrics, show_all=args.all_results
+                        )
+
+                    print("\nReranking Metrics:")
+                    if rerank_metrics is None:
+                        print("No Reranking Metrics Available")
+                    else:
+                        _print_reranking_metrics(
+                            rerank_metrics, show_all=args.all_results
+                        )
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "regression_questions_json",
+        type=str,
+        help="Path to the Questions JSON file.",
+        default="./tests/regression/search_quality/test_questions.json",
+        nargs="?",
+    )
+    parser.add_argument(
+        "--output_file",
+        type=str,
+        help="Path to the output results file.",
+        default="./tests/regression/search_quality/regression_results.txt",
+    )
+    parser.add_argument(
+        "--all_results",
+        action="store_true",
+        help="If set, print metrics for every retrieved chunk instead of only the first 10.",
+    )
+    args = parser.parse_args()
+
+    main(args.regression_questions_json, args.output_file)
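
Usage sketch (assumptions: the script is run from the backend directory of a deployment that already has documents indexed, with the danswer package importable; the questions file name and contents shown here are hypothetical, not taken from the repository):

    python tests/regression/search_quality/eval_search.py \
        ./tests/regression/search_quality/test_questions.json \
        --output_file ./tests/regression/search_quality/regression_results.txt \
        --all_results

The questions JSON is read as a flat object mapping each query string to its expected targets. main() iterates over questions_info.items() but does not yet score against the targets, so the values are placeholders for now, e.g.:

    {
        "How do I set up the Danswer Slack bot?": ["<expected document link>"],
        "Which document index does Danswer use by default?": ["<expected document link>"]
    }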