danswer/backend/tests/regression/search_quality/run_search_eval.py
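"""Search quality regression test script.

Runs each test query through document retrieval (and reranking, unless
disabled via skip_rerank in the config), writes the per-query search and
evaluation results to CSV files, and aggregates the evaluation metrics per
query category into a final CSV.
"""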

import csv
from collections import defaultdict
from pathlib import Path

from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.context.search.models import RerankingDetails
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import SqlEngine
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
from onyx.document_index.factory import get_default_document_index
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from tests.regression.search_quality.util_config import load_config
from tests.regression.search_quality.util_data import export_test_queries
from tests.regression.search_quality.util_data import load_test_queries
from tests.regression.search_quality.util_eval import evaluate_one_query
from tests.regression.search_quality.util_eval import get_corresponding_document
from tests.regression.search_quality.util_eval import metric_names
from tests.regression.search_quality.util_retrieve import rerank_one_query
from tests.regression.search_quality.util_retrieve import search_one_query

logger = setup_logger(__name__)

def run_search_eval() -> None:
    config = load_config()
    test_queries = load_test_queries()

    # export related
    export_path = Path(config.export_folder)
    export_test_queries(test_queries, export_path / "test_queries.json")
    search_result_path = export_path / "search_results.csv"
    eval_path = export_path / "eval_results.csv"
    aggregate_eval_path = export_path / "aggregate_eval.csv"
    aggregate_results: dict[str, list[list[float]]] = defaultdict(
        lambda: [[] for _ in metric_names]
    )
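    # aggregate_results maps each query category to one list of values per
    # metric in metric_names; per-query metric values are appended below and
    # averaged into aggregate_eval.csv at the end of the run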

    with get_session_with_current_tenant() as db_session:
        multilingual_expansion = get_multilingual_expansion(db_session)
        search_settings = get_current_search_settings(db_session)
        document_index = get_default_document_index(search_settings, None)
        rerank_settings = RerankingDetails.from_db_model(search_settings)
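
        # validate the rerank configuration up front so a misconfigured
        # reranker fails before any queries are run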
        if config.skip_rerank:
            logger.warning("Reranking is disabled, evaluation will not run")
        elif rerank_settings.rerank_model_name is None:
            raise ValueError(
                "Reranking is enabled but no reranker is configured. "
                "Please set the reranker in the admin panel search settings."
            )

        # run search and evaluate
        logger.info(
            "Running search and evaluation... "
            f"Individual search and evaluation results will be saved to {search_result_path} and {eval_path}"
        )
        with (
            search_result_path.open("w") as search_file,
            eval_path.open("w") as eval_file,
        ):
            search_csv_writer = csv.writer(search_file)
            eval_csv_writer = csv.writer(eval_file)
            search_csv_writer.writerow(
                ["source", "query", "rank", "score", "doc_id", "chunk_id"]
            )
            eval_csv_writer.writerow(["query", *metric_names])
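
            # search and rerank results for every query share one CSV file,
            # distinguished by the "source" column; evaluation metrics are
            # written one row per query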
            for query in test_queries:
                # search and write results
                assert query.question_search is not None
                search_chunks = search_one_query(
                    query.question_search,
                    multilingual_expansion,
                    document_index,
                    db_session,
                    config,
                )
                for rank, result in enumerate(search_chunks):
                    search_csv_writer.writerow(
                        [
                            "search",
                            query.question_search,
                            rank,
                            result.score,
                            result.document_id,
                            result.chunk_id,
                        ]
                    )
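
                # note: retrieval runs on question_search (the search-ready
                # form of the question) while reranking scores chunks against
                # the original question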
                rerank_chunks = []
                if not config.skip_rerank:
                    # rerank and write results
                    rerank_chunks = rerank_one_query(
                        query.question, search_chunks, rerank_settings
                    )
                    for rank, result in enumerate(rerank_chunks):
                        search_csv_writer.writerow(
                            [
                                "rerank",
                                query.question,
                                rank,
                                result.score,
                                result.document_id,
                                result.chunk_id,
                            ]
                        )

                # evaluate and write results
                truth_documents = [
                    doc
                    for truth in query.ground_truth
                    if (doc := get_corresponding_document(truth.doc_link, db_session))
                ]
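                # ground-truth links that cannot be resolved to an indexed
                # document are filtered out by the walrus-guarded condition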
                metrics = evaluate_one_query(
                    search_chunks, rerank_chunks, truth_documents, config.eval_topk
                )
                metric_vals = [getattr(metrics, field) for field in metric_names]
                eval_csv_writer.writerow([query.question, *metric_vals])

                # add to aggregation
                for category in ["all"] + query.categories:
                    for i, val in enumerate(metric_vals):
                        if val is not None:
                            aggregate_results[category][i].append(val)

    # aggregate and write results
    with aggregate_eval_path.open("w") as file:
        aggregate_csv_writer = csv.writer(file)
        aggregate_csv_writer.writerow(["category", *metric_names])
        for category, agg_metrics in aggregate_results.items():
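            # write the mean of each metric over the category's queries;
            # an empty cell means no query in the category produced a value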
            aggregate_csv_writer.writerow(
                [
                    category,
                    *(
                        sum(metric) / len(metric) if metric else None
                        for metric in agg_metrics
                    ),
                ]
            )


if __name__ == "__main__":
    if MULTI_TENANT:
        raise ValueError("Multi-tenant is not supported currently")

    SqlEngine.init_engine(
        pool_size=POSTGRES_API_SERVER_POOL_SIZE,
        max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
    )

    try:
        run_search_eval()
    except Exception as e:
        logger.error(f"Error running search evaluation: {e}")
        raise
    finally:
        SqlEngine.reset_engine()
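
# Assumed invocation (adjust to your setup): run from the backend directory
# so the tests.* imports resolve, e.g.
#   python -m tests.regression.search_quality.run_search_eval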