mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-07-28 05:43:33 +02:00
* fix: autoflake & import order * docs: readme * fix: mypy * feat: eval * docs: readme * fix: oops forgot to remove comment * fix: typo * fix: rename var * updated default config * fix: config issue * oops * fix: black * fix: eval and config * feat: non tool calling query mod
162 lines
6.5 KiB
Python
162 lines
6.5 KiB
Python
import csv
|
|
from collections import defaultdict
|
|
from pathlib import Path
|
|
|
|
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
|
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
|
from onyx.context.search.models import RerankingDetails
|
|
from onyx.db.engine import get_session_with_current_tenant
|
|
from onyx.db.engine import SqlEngine
|
|
from onyx.db.search_settings import get_current_search_settings
|
|
from onyx.db.search_settings import get_multilingual_expansion
|
|
from onyx.document_index.factory import get_default_document_index
|
|
from onyx.utils.logger import setup_logger
|
|
from shared_configs.configs import MULTI_TENANT
|
|
from tests.regression.search_quality.util_config import load_config
|
|
from tests.regression.search_quality.util_data import export_test_queries
|
|
from tests.regression.search_quality.util_data import load_test_queries
|
|
from tests.regression.search_quality.util_eval import evaluate_one_query
|
|
from tests.regression.search_quality.util_eval import get_corresponding_document
|
|
from tests.regression.search_quality.util_eval import metric_names
|
|
from tests.regression.search_quality.util_retrieve import rerank_one_query
|
|
from tests.regression.search_quality.util_retrieve import search_one_query
|
|
|
|
# Module-level logger for this regression-test script (project logging helper)
logger = setup_logger(__name__)
|
|
|
|
|
|
def _write_chunk_rows(csv_writer, source: str, question: str, chunks) -> None:
    """Write one CSV row per retrieved chunk: source label, query, rank, score, and chunk identity."""
    for rank, result in enumerate(chunks):
        csv_writer.writerow(
            [
                source,
                question,
                rank,
                result.score,
                result.document_id,
                result.chunk_id,
            ]
        )


def run_search_eval() -> None:
    """Run the search-quality regression evaluation end to end.

    Loads the configured test queries, runs a search for each query (and a
    rerank pass unless disabled in config), evaluates results against the
    ground-truth documents, and writes per-query and per-category aggregate
    metrics as CSV files under the configured export folder.

    Raises:
        ValueError: if reranking is enabled but no reranker is configured.
    """
    config = load_config()
    test_queries = load_test_queries()

    # export related
    export_path = Path(config.export_folder)
    export_test_queries(test_queries, export_path / "test_queries.json")
    search_result_path = export_path / "search_results.csv"
    eval_path = export_path / "eval_results.csv"
    aggregate_eval_path = export_path / "aggregate_eval.csv"

    # category -> one list of observed values per metric in metric_names
    aggregate_results: dict[str, list[list[float]]] = defaultdict(
        lambda: [[] for _ in metric_names]
    )

    with get_session_with_current_tenant() as db_session:
        multilingual_expansion = get_multilingual_expansion(db_session)
        search_settings = get_current_search_settings(db_session)
        document_index = get_default_document_index(search_settings, None)
        rerank_settings = RerankingDetails.from_db_model(search_settings)

        if config.skip_rerank:
            logger.warning("Reranking is disabled, evaluation will not run")
        elif rerank_settings.rerank_model_name is None:
            raise ValueError(
                "Reranking is enabled but no reranker is configured. "
                "Please set the reranker in the admin panel search settings."
            )

        # run search and evaluate
        logger.info(
            "Running search and evaluation... "
            f"Individual search and evaluation results will be saved to {search_result_path} and {eval_path}"
        )
        # newline="" is required by the csv module to avoid spurious blank
        # rows on platforms with \r\n line endings (see csv docs)
        with (
            search_result_path.open("w", newline="") as search_file,
            eval_path.open("w", newline="") as eval_file,
        ):
            search_csv_writer = csv.writer(search_file)
            eval_csv_writer = csv.writer(eval_file)
            search_csv_writer.writerow(
                ["source", "query", "rank", "score", "doc_id", "chunk_id"]
            )
            eval_csv_writer.writerow(["query", *metric_names])

            for query in test_queries:
                # search and write results
                assert query.question_search is not None
                search_chunks = search_one_query(
                    query.question_search,
                    multilingual_expansion,
                    document_index,
                    db_session,
                    config,
                )
                _write_chunk_rows(
                    search_csv_writer, "search", query.question_search, search_chunks
                )

                rerank_chunks = []
                if not config.skip_rerank:
                    # rerank and write results
                    rerank_chunks = rerank_one_query(
                        query.question, search_chunks, rerank_settings
                    )
                    _write_chunk_rows(
                        search_csv_writer, "rerank", query.question, rerank_chunks
                    )

                # evaluate and write results; ground-truth links that cannot be
                # resolved to a document are silently skipped
                truth_documents = [
                    doc
                    for truth in query.ground_truth
                    if (doc := get_corresponding_document(truth.doc_link, db_session))
                ]
                metrics = evaluate_one_query(
                    search_chunks, rerank_chunks, truth_documents, config.eval_topk
                )
                metric_vals = [getattr(metrics, field) for field in metric_names]
                eval_csv_writer.writerow([query.question, *metric_vals])

                # add to aggregation ("all" collects every query's metrics)
                for category in ["all"] + query.categories:
                    for i, val in enumerate(metric_vals):
                        if val is not None:
                            aggregate_results[category][i].append(val)

    # aggregate and write results (mean per metric; None when no values)
    with aggregate_eval_path.open("w", newline="") as file:
        aggregate_csv_writer = csv.writer(file)
        aggregate_csv_writer.writerow(["category", *metric_names])

        for category, agg_metrics in aggregate_results.items():
            aggregate_csv_writer.writerow(
                [
                    category,
                    *(
                        sum(metric) / len(metric) if metric else None
                        for metric in agg_metrics
                    ),
                ]
            )
|
|
|
|
|
|
if __name__ == "__main__":
    # Single-tenant only: the script uses the current-tenant session helper
    if MULTI_TENANT:
        raise ValueError("Multi-tenant is not supported currently")

    SqlEngine.init_engine(
        pool_size=POSTGRES_API_SERVER_POOL_SIZE,
        max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
    )

    try:
        run_search_eval()
    except Exception:
        # logger.exception records the message with the full traceback;
        # bare raise re-raises without disturbing the original traceback
        logger.exception("Error running search evaluation")
        raise
    finally:
        # always release the engine's connection pool
        SqlEngine.reset_engine()