Regression Script for Search quality (#760)

Yuhong Sun 2023-11-22 19:33:28 -08:00 committed by GitHub
parent bdfb894507
commit fda377a2fa
7 changed files with 216 additions and 19 deletions

View File

@@ -81,7 +81,7 @@ SMTP_PASS = os.environ.get("SMTP_PASS", "your-gmail-password")
 #####
 # DB Configs
 #####
-DOCUMENT_INDEX_NAME = "danswer_index"  # Shared by vector/keyword indices
+DOCUMENT_INDEX_NAME = "danswer_index"
 # Vespa is now the default document index store for both keyword and vector
 DOCUMENT_INDEX_TYPE = os.environ.get(
     "DOCUMENT_INDEX_TYPE", DocumentIndexType.COMBINED.value

View File

@@ -6,7 +6,6 @@ from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
 from sqlalchemy.orm import Session

-from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
 from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
 from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
 from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
@@ -204,11 +203,8 @@ def handle_message(
         answer = _get_answer(
             QuestionRequest(
                 query=msg,
-                collection=DOCUMENT_INDEX_NAME,
-                enable_auto_detect_filters=not disable_auto_detect_filters,
                 filters=filters,
-                favor_recent=None,
-                offset=None,
+                enable_auto_detect_filters=not disable_auto_detect_filters,
             )
         )
     except Exception as e:

View File

@@ -64,7 +64,7 @@ class SearchQuery(BaseModel):
 class RetrievalMetricsContainer(BaseModel):
-    keyword_search: bool  # False for Vector Search
+    search_type: SearchType
     metrics: list[ChunkMetric]  # This contains the scores for retrieval as well

View File

@@ -1,5 +1,5 @@
 from collections.abc import Callable
-from collections.abc import Generator
+from collections.abc import Iterator
 from copy import deepcopy
 from typing import cast
@@ -383,7 +383,9 @@ def retrieve_chunks(
             for chunk in top_chunks
         ]
         retrieval_metrics_callback(
-            RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
+            RetrievalMetricsContainer(
+                search_type=query.search_type, metrics=chunk_metrics
+            )
         )

     return top_chunks
@@ -468,7 +470,7 @@ def full_chunk_search_generator(
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> Generator[list[InferenceChunk] | list[bool], None, None]:
+) -> Iterator[list[InferenceChunk] | list[bool]]:
     """Always yields twice. Once with the selected chunks and once with the LLM relevance filter result."""
     chunks_yielded = False
@@ -480,6 +482,11 @@ def full_chunk_search_generator(
         retrieval_metrics_callback=retrieval_metrics_callback,
     )

+    if not retrieved_chunks:
+        yield cast(list[InferenceChunk], [])
+        yield cast(list[bool], [])
+        return
+
     post_processing_tasks: list[FunctionCall] = []

     rerank_task_id = None
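
The early-exit added above preserves the documented two-yield contract even when retrieval returns nothing. A minimal consumer sketch (the helper name is illustrative, not part of the commit):

from collections.abc import Iterator
from typing import cast

from danswer.indexing.models import InferenceChunk


def unpack_chunk_search(
    gen: Iterator[list[InferenceChunk] | list[bool]],
) -> tuple[list[InferenceChunk], list[bool]]:
    top_chunks = cast(list[InferenceChunk], next(gen))  # first yield: selected chunks
    llm_chunk_selection = cast(list[bool], next(gen))  # second yield: relevance flags
    return top_chunks, llm_chunk_selection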
@@ -549,7 +556,7 @@ def danswer_search_generator(
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
-) -> Generator[list[InferenceChunk] | list[bool] | int, None, None]:
+) -> Iterator[list[InferenceChunk] | list[bool] | int]:
     """The main entry point for search. This fetches the relevant documents from Vespa
     based on the provided query (applying permissions / filters), does any specified
     post-processing, and returns the results. It also creates an entry in the query_event table
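
The regression script added below unpacks three values from danswer_search; judging by this return annotation, a wrapper over the generator plausibly drains it like so (helper name and exact yield order are assumptions):

from collections.abc import Iterator
from typing import cast

from danswer.indexing.models import InferenceChunk


def drain_search_generator(
    gen: Iterator[list[InferenceChunk] | list[bool] | int],
) -> tuple[list[InferenceChunk], list[bool], int]:
    top_chunks = cast(list[InferenceChunk], next(gen))
    llm_chunk_selection = cast(list[bool], next(gen))
    query_event_id = cast(int, next(gen))  # id of the query_event row
    return top_chunks, llm_chunk_selection, query_event_id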

View File

@@ -10,6 +10,7 @@ from pydantic import validator
 from pydantic.generics import GenericModel

 from danswer.auth.schemas import UserRole
+from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
 from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
 from danswer.configs.constants import AuthType
 from danswer.configs.constants import DocumentSource
@@ -173,12 +174,12 @@ class SearchDoc(BaseModel):
 class QuestionRequest(BaseModel):
     query: str
-    collection: str
     filters: BaseFilters
-    offset: int | None
-    enable_auto_detect_filters: bool
-    favor_recent: bool | None = None
+    collection: str = DOCUMENT_INDEX_NAME
+    search_type: SearchType = SearchType.HYBRID
+    enable_auto_detect_filters: bool = True
+    favor_recent: bool | None = None
+    offset: int | None = None


 class QAFeedbackRequest(BaseModel):
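
With these defaults, callers need only supply the query and filters (as the Slack handler hunk above now does). A construction sketch, assuming BaseFilters lives in danswer.search.models and exposes the three fields seen elsewhere in this commit:

from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.search.models import BaseFilters  # assumed import path
from danswer.server.models import QuestionRequest

request = QuestionRequest(
    query="How do I configure the document index?",  # illustrative query
    filters=BaseFilters(source_type=None, document_set=None, time_cutoff=None),
)
assert request.collection == DOCUMENT_INDEX_NAME  # defaulted
assert request.enable_auto_detect_filters is True  # defaulted
assert request.offset is None  # defaulted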

View File

@@ -8,7 +8,6 @@ from typing import TextIO
 import yaml
 from sqlalchemy.orm import Session

-from danswer.access.access import get_acl_for_user
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.direct_qa.answer_question import answer_qa_query
 from danswer.direct_qa.models import LLMMetricsContainer
@@ -80,14 +79,12 @@ def get_answer_for_question(
         source_type=None,
         document_set=None,
         time_cutoff=None,
-        access_control_list=list(get_acl_for_user(user=None)),
+        access_control_list=None,
     )
     question = QuestionRequest(
         query=query,
-        collection="danswer_index",
         filters=filters,
         enable_auto_detect_filters=False,
-        offset=None,
     )

     retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
@@ -101,6 +98,7 @@ def get_answer_for_question(
         answer_generation_timeout=100,
         real_time_flow=False,
         enable_reflexion=False,
+        bypass_acl=True,
         retrieval_metrics_callback=retrieval_metrics.record_metric,
         rerank_metrics_callback=rerank_metrics.record_metric,
         llm_metrics_callback=llm_metrics.record_metric,

View File

@@ -0,0 +1,195 @@
import argparse
import builtins
import json
from contextlib import contextmanager
from typing import Any
from typing import TextIO

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import InferenceChunk
from danswer.search.models import IndexFilters
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.search_runner import danswer_search
from danswer.server.models import QuestionRequest
from danswer.utils.callbacks import MetricsHander

engine = get_sqlalchemy_engine()


@contextmanager
def redirect_print_to_file(file: TextIO) -> Any:
    """Temporarily monkeypatch builtins.print so all output lands in `file`."""
    original_print = builtins.print
    builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
    try:
        yield
    finally:
        builtins.print = original_print


def read_json(file_path: str) -> dict:
    with open(file_path, "r") as file:
        return json.load(file)


def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
    words = s.split()

    current_line: list[str] = []
    result_lines: list[str] = []
    current_length = 0
    for word in words:
        if len(word) > max_line_size:
            # Oversized words get a line of their own rather than being split.
            if current_line:
                result_lines.append(" ".join(current_line))
                current_line = []
                current_length = 0
            result_lines.append(word)
            continue

        if current_length + len(word) + len(current_line) > max_line_size:
            result_lines.append(" ".join(current_line))
            current_line = []
            current_length = 0

        current_line.append(word)
        current_length += len(word)

    if current_line:
        result_lines.append(" ".join(current_line))

    return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)

def get_search_results(
    query: str, db_session: Session
) -> tuple[
    list[InferenceChunk],
    list[bool],
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
]:
    filters = IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        access_control_list=None,
    )
    question = QuestionRequest(
        query=query,
        filters=filters,
        enable_auto_detect_filters=False,
    )

    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()

    # ACL is bypassed so results reflect pure search quality, independent of
    # document permissions.
    chunks, llm_filter, query_id = danswer_search(
        question=question,
        user=None,
        db_session=db_session,
        document_index=get_default_document_index(),
        bypass_acl=True,
        retrieval_metrics_callback=retrieval_metrics.record_metric,
        rerank_metrics_callback=rerank_metrics.record_metric,
    )

    return (
        chunks,
        llm_filter,
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
    )

def _print_retrieval_metrics(
    metrics_container: RetrievalMetricsContainer, show_all: bool
) -> None:
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break
        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Distance Metric: {metric.score}")


def _print_reranking_metrics(
    metrics_container: RerankMetricsContainer, show_all: bool
) -> None:
    # Printing the raw scores as they're more informational than post-norm/boosting
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break
        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")

def main(questions_json: str, output_file: str) -> None:
    # NOTE: relies on the module-level `args` (parsed in the __main__ block)
    # for the show_all flag.
    questions_info = read_json(questions_json)

    with open(output_file, "w") as outfile:
        with redirect_print_to_file(outfile):
            print("Running Document Retrieval Test\n")

            with Session(engine, expire_on_commit=False) as db_session:
                for question, targets in questions_info.items():
                    print(f"Question: {question}")

                    (
                        chunks,
                        llm_filters,
                        retrieval_metrics,
                        rerank_metrics,
                    ) = get_search_results(query=question, db_session=db_session)

                    print("\nRetrieval Metrics:")
                    if retrieval_metrics is None:
                        print("No Retrieval Metrics Available")
                    else:
                        _print_retrieval_metrics(
                            retrieval_metrics, show_all=args.all_results
                        )

                    print("\nReranking Metrics:")
                    if rerank_metrics is None:
                        print("No Reranking Metrics Available")
                    else:
                        _print_reranking_metrics(
                            rerank_metrics, show_all=args.all_results
                        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "regression_questions_json",
        type=str,
        help="Path to the Questions JSON file.",
        default="./tests/regression/search_quality/test_questions.json",
        nargs="?",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="Path to the output results file.",
        default="./tests/regression/search_quality/regression_results.txt",
    )
    # main() reads args.all_results from module scope.
    parser.add_argument(
        "--all_results",
        action="store_true",
        help="Show metrics for all results instead of only the top 10.",
    )
    args = parser.parse_args()

    main(args.regression_questions_json, args.output_file)
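
Given the argument defaults above, a typical invocation (the script's filename is not shown in this diff) would be:

python tests/regression/search_quality/<script_name>.py --all_results

with results written to ./tests/regression/search_quality/regression_results.txt unless --output_file overrides it.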