Regression Script for Search quality (#760)

Yuhong Sun
2023-11-22 19:33:28 -08:00
committed by GitHub
parent bdfb894507
commit fda377a2fa
7 changed files with 216 additions and 19 deletions

View File

@@ -81,7 +81,7 @@ SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
 #####
 # DB Configs
 #####
-DOCUMENT_INDEX_NAME = "danswer_index"  # Shared by vector/keyword indices
+DOCUMENT_INDEX_NAME = "danswer_index"
 # Vespa is now the default document index store for both keyword and vector
 DOCUMENT_INDEX_TYPE = os.environ.get(
     "DOCUMENT_INDEX_TYPE", DocumentIndexType.COMBINED.value

View File

@@ -6,7 +6,6 @@ from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
 from sqlalchemy.orm import Session

-from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
 from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
 from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
 from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
@@ -204,11 +203,8 @@ def handle_message(
         answer = _get_answer(
             QuestionRequest(
                 query=msg,
-                collection=DOCUMENT_INDEX_NAME,
-                enable_auto_detect_filters=not disable_auto_detect_filters,
                 filters=filters,
-                favor_recent=None,
-                offset=None,
+                enable_auto_detect_filters=not disable_auto_detect_filters,
             )
         )
     except Exception as e:

View File

@@ -64,7 +64,7 @@ class SearchQuery(BaseModel):
 class RetrievalMetricsContainer(BaseModel):
-    keyword_search: bool  # False for Vector Search
+    search_type: SearchType
     metrics: list[ChunkMetric]  # This contains the scores for retrieval as well
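For context, a small sketch (not part of this diff) of how the container is read after the change; the MetricsHander wiring mirrors the regression scripts touched later in this commit, and only fields shown in these hunks are used:

from danswer.search.models import RetrievalMetricsContainer
from danswer.utils.callbacks import MetricsHander

retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
# retrieval_metrics.record_metric is passed to the search flow as
# retrieval_metrics_callback; once it has fired, the container reports which
# SearchType actually ran instead of the old keyword_search bool.
container = retrieval_metrics.metrics
if container is not None:
    print(container.search_type)
    for metric in container.metrics:
        print(metric.document_id, metric.score)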

View File

@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from collections.abc import Generator
+from collections.abc import Iterator
 from copy import deepcopy
 from typing import cast
@@ -383,7 +383,9 @@ def retrieve_chunks(
             for chunk in top_chunks
         ]
         retrieval_metrics_callback(
-            RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
+            RetrievalMetricsContainer(
+                search_type=query.search_type, metrics=chunk_metrics
+            )
         )

     return top_chunks
@@ -468,7 +470,7 @@ def full_chunk_search_generator(
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> Generator[list[InferenceChunk] | list[bool], None, None]:
+) -> Iterator[list[InferenceChunk] | list[bool]]:
     """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result."""
     chunks_yielded = False
@@ -480,6 +482,11 @@ def full_chunk_search_generator(
         retrieval_metrics_callback=retrieval_metrics_callback,
     )

+    if not retrieved_chunks:
+        yield cast(list[InferenceChunk], [])
+        yield cast(list[bool], [])
+        return
+
     post_processing_tasks: list[FunctionCall] = []
     rerank_task_id = None
@@ -549,7 +556,7 @@ def danswer_search_generator(
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]:
+) -> Iterator[list[InferenceChunk] | list[bool] | int]:
     """The main entry point for search. This fetches the relevant documents from Vespa
     based on the provided query (applying permissions / filters), does any specified
     post-processing, and returns the results. It also create an entry in the query_event table
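As a hedged illustration (not from the commit; the helper name is hypothetical and only the yield order and element types come from the hunks above), the two-yield contract of full_chunk_search_generator can be consumed like this:

from collections.abc import Iterator
from typing import cast

from danswer.indexing.models import InferenceChunk


def consume_chunk_search(
    search_generator: Iterator[list[InferenceChunk] | list[bool]],
) -> tuple[list[InferenceChunk], list[bool]]:
    # First yield: the selected chunks (now an empty list when retrieval finds
    # nothing, per the early return added above); second yield: the LLM
    # relevance filter result.
    top_chunks = cast(list[InferenceChunk], next(search_generator))
    llm_chunk_selection = cast(list[bool], next(search_generator))
    return top_chunks, llm_chunk_selection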

View File

@@ -10,6 +10,7 @@ from pydantic import validator
 from pydantic.generics import GenericModel

 from danswer.auth.schemas import UserRole
+from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
 from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
 from danswer.configs.constants import AuthType
 from danswer.configs.constants import DocumentSource
@@ -173,12 +174,12 @@ class SearchDoc(BaseModel):
 class QuestionRequest(BaseModel):
     query: str
-    collection: str
     filters: BaseFilters
-    offset: int | None
-    enable_auto_detect_filters: bool
-    favor_recent: bool | None = None
+    collection: str = DOCUMENT_INDEX_NAME
     search_type: SearchType = SearchType.HYBRID
+    enable_auto_detect_filters: bool = True
+    favor_recent: bool | None = None
+    offset: int | None = None


 class QAFeedbackRequest(BaseModel):
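To illustrate the slimmer request model, a sketch grounded only in the fields above and the call sites elsewhere in this commit (the query text is made up):

from danswer.search.models import IndexFilters
from danswer.server.models import QuestionRequest

question = QuestionRequest(
    query="How do I configure the Slack bot?",
    filters=IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        access_control_list=None,
    ),
    # collection, search_type, enable_auto_detect_filters, favor_recent and
    # offset all fall back to the new defaults shown above.
)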

View File

@@ -8,7 +8,6 @@ from typing import TextIO
 import yaml
 from sqlalchemy.orm import Session

-from danswer.access.access import get_acl_for_user
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.direct_qa.answer_question import answer_qa_query
 from danswer.direct_qa.models import LLMMetricsContainer
@@ -80,14 +79,12 @@ def get_answer_for_question(
         source_type=None,
         document_set=None,
         time_cutoff=None,
-        access_control_list=list(get_acl_for_user(user=None)),
+        access_control_list=None,
     )
     question = QuestionRequest(
         query=query,
-        collection="danswer_index",
         filters=filters,
         enable_auto_detect_filters=False,
-        offset=None,
     )

     retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
@@ -101,6 +98,7 @@ def get_answer_for_question(
         answer_generation_timeout=100,
         real_time_flow=False,
         enable_reflexion=False,
+        bypass_acl=True,
         retrieval_metrics_callback=retrieval_metrics.record_metric,
         rerank_metrics_callback=rerank_metrics.record_metric,
         llm_metrics_callback=llm_metrics.record_metric,

View File

@@ -0,0 +1,195 @@
import argparse
import builtins
import json
from contextlib import contextmanager
from typing import Any
from typing import TextIO

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.search_runner import danswer_search
from danswer.server.models import QuestionRequest
from danswer.utils.callbacks import MetricsHander


engine = get_sqlalchemy_engine()


@contextmanager
def redirect_print_to_file(file: TextIO) -> Any:
    original_print = builtins.print
    builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
    try:
        yield
    finally:
        builtins.print = original_print


def read_json(file_path: str) -> dict:
    with open(file_path, "r") as file:
        return json.load(file)


def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
    words = s.split()

    current_line: list[str] = []
    result_lines: list[str] = []
    current_length = 0
    for word in words:
        if len(word) > max_line_size:
            if current_line:
                result_lines.append(" ".join(current_line))
                current_line = []
                current_length = 0
            result_lines.append(word)
            continue

        if current_length + len(word) + len(current_line) > max_line_size:
            result_lines.append(" ".join(current_line))
            current_line = []
            current_length = 0

        current_line.append(word)
        current_length += len(word)

    if current_line:
        result_lines.append(" ".join(current_line))

    return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)


def get_search_results(
    query: str, db_session: Session
) -> tuple[
    list[InferenceChunk],
    list[bool],
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
]:
    filters = IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        access_control_list=None,
    )
    question = QuestionRequest(
        query=query,
        filters=filters,
        enable_auto_detect_filters=False,
    )

    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()

    chunks, llm_filter, query_id = danswer_search(
        question=question,
        user=None,
        db_session=db_session,
        document_index=get_default_document_index(),
        bypass_acl=True,
        retrieval_metrics_callback=retrieval_metrics.record_metric,
        rerank_metrics_callback=rerank_metrics.record_metric,
    )

    return (
        chunks,
        llm_filter,
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
    )


def _print_retrieval_metrics(
    metrics_container: RetrievalMetricsContainer, show_all: bool
) -> None:
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Distance Metric: {metric.score}")


def _print_reranking_metrics(
    metrics_container: RerankMetricsContainer, show_all: bool
) -> None:
    # Printing the raw scores as they're more informational than post-norm/boosting
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")


def main(questions_json: str, output_file: str) -> None:
    questions_info = read_json(questions_json)

    with open(output_file, "w") as outfile:
        with redirect_print_to_file(outfile):
            print("Running Document Retrieval Test\n")

            with Session(engine, expire_on_commit=False) as db_session:
                for question, targets in questions_info.items():
                    print(f"Question: {question}")

                    (
                        chunks,
                        llm_filters,
                        retrieval_metrics,
                        rerank_metrics,
                    ) = get_search_results(query=question, db_session=db_session)

                    print("\nRetrieval Metrics:")
                    if retrieval_metrics is None:
                        print("No Retrieval Metrics Available")
                    else:
                        _print_retrieval_metrics(
                            retrieval_metrics, show_all=args.all_results
                        )

                    print("\nReranking Metrics:")
                    if rerank_metrics is None:
                        print("No Reranking Metrics Available")
                    else:
                        _print_reranking_metrics(
                            rerank_metrics, show_all=args.all_results
                        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "regression_questions_json",
        type=str,
        help="Path to the Questions JSON file.",
        default="./tests/regression/search_quality/test_questions.json",
        nargs="?",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="Path to the output results file.",
        default="./tests/regression/search_quality/regression_results.txt",
    )
    # Assumed flag (not visible in the hunk): main() reads args.all_results, so a
    # matching argument is needed for the script to run.
    parser.add_argument(
        "--all_results",
        action="store_true",
        help="If set, show all retrieved sections instead of only the first 10.",
    )
    args = parser.parse_args()

    main(args.regression_questions_json, args.output_file)
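
For reference, a usage sketch for the new regression script; the script path is inferred from the default paths above, and the JSON layout from main(), which iterates question-to-targets pairs and currently uses only the question keys, so both are assumptions:

# assumed location: tests/regression/search_quality/eval_search.py, run with the
# danswer package importable
python tests/regression/search_quality/eval_search.py \
    ./tests/regression/search_quality/test_questions.json \
    --output_file ./tests/regression/search_quality/regression_results.txt

# hypothetical test_questions.json contents (question -> expected target documents)
{
    "How do I connect a Slack workspace?": ["expected-document-id"],
    "Where is the document index name set?": ["expected-document-id"]
}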