Feat: Search Eval Testing Overhaul (provide ground truth, categorize query, etc.) (#4739)

* fix: autoflake & import order

* docs: readme

* fix: mypy

* feat: eval

* docs: readme

* fix: oops forgot to remove comment

* fix: typo

* fix: rename var

* updated default config

* fix: config issue

* oops

* fix: black

* fix: eval and config

* feat: non tool calling query mod
Rei Meguro
2025-05-21 12:25:10 -07:00
committed by GitHub
parent e78637d632
commit 9dbe12cea8
10 changed files with 614 additions and 505 deletions

View File

@@ -2,7 +2,9 @@
This Python script evaluates the search results for a list of queries.
Unlike the script in answer_quality, this script is much less customizable and runs against currently ingested documents, but it allows quick testing of search parameters on a set of test queries that don't have well-defined answers.
This script will likely be refactored into an API endpoint in the future.
In the meantime, it is used to evaluate search quality against locally ingested documents.
The key difference from `answer_quality` is that it can evaluate results without explicit "ground truth" by using the reranker as a reference.
## Usage
@@ -25,32 +27,36 @@ This can be checked/modified by opening the admin panel, going to search setting
cd path/to/onyx/backend/tests/regression/search_quality
```
5. Copy `search_queries.json.template` to `search_queries.json` and add/remove test queries in it
5. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. The possible fields are listed below (a minimal loading sketch follows these usage steps):
6. Run `generate_search_queries.py` to generate the modified queries for the search pipeline
- `question: str` the query
- `question_search: Optional[str]` modified query specifically for the search step
- `ground_truth: Optional[list[GroundTruth]]` a ranked list of expected search results with fields:
- `doc_source: str` document source (e.g., Web, Drive, Linear), currently unused
- `doc_link: str` link associated with document, used to find corresponding document in local index
- `categories: Optional[list[str]]` list of categories, used to aggregate evaluation results
```
python generate_search_queries.py
```
6. Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters
7. Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters
8. Run `run_search_eval.py` to evaluate the search results against the reranked results
7. Run `run_search_eval.py` to run the search and evaluate the search results
```
python run_search_eval.py
```
9. Repeat steps 7 and 8 to test and compare different search parameters
8. Optionally, save the generated `test_queries.json` in the export folder to reuse the generated `question_search`, and rerun the search evaluation with alternative search parameters.
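For reference, each entry in `test_queries.json` maps onto the pydantic models defined in `util_data.py`. Below is a minimal loading sketch; it validates the file but skips the `question_search` generation that the real loader performs:
```
import json
from pathlib import Path
from typing import Optional

from pydantic import BaseModel


class GroundTruth(BaseModel):
    doc_source: str  # e.g., Web, Drive, Linear (currently unused)
    doc_link: str  # used to find the corresponding document in the local index


class TestQuery(BaseModel):
    question: str
    question_search: Optional[str] = None  # generated by the eval script if missing
    ground_truth: list[GroundTruth] = []
    categories: list[str] = []


# load and validate the test queries (path assumed relative to this folder)
entries = json.loads(Path("test_queries.json").read_text())
test_queries = [TestQuery(**entry) for entry in entries]
```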
## Metrics
- Jaccard Similarity: the ratio of the intersection to the union of the topk search and rerank results. Higher is better
- Average Rank Change: the average absolute rank difference between the topk reranked chunks and their positions in the full search results. Lower is better
- Average Missing Chunk Ratio: the number of topk reranked chunks missing from the topk search chunks, divided by topk. Lower is better
There are two main metrics currently implemented:
- ratio_topk: the fraction of documents in the comparison set that appear in the topk search results (higher is better, range 0-1)
- avg_rank_delta: the average rank difference between the comparison set and the search results (lower is better, range 0-inf)
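For intuition, here is a minimal sketch of how the two metrics can be computed over document ranks (it mirrors the helpers in `util_eval.py`; documents are keyed by id with 0-based ranks):
```
def ratio_topk(search_ranks_topk: dict[str, int], truth_ranks: dict[str, int]) -> float:
    # fraction of comparison documents that appear in the topk search results
    return len(set(search_ranks_topk) & set(truth_ranks)) / len(truth_ranks)


def avg_rank_delta(search_ranks: dict[str, int], truth_ranks: dict[str, int]) -> float:
    # documents missing from the search results are treated as ranked just past the end
    missing_rank = len(search_ranks)
    return sum(
        abs(search_ranks.get(doc_id, missing_rank) - rank)
        for doc_id, rank in truth_ranks.items()
    ) / len(truth_ranks)
```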
Note that all of these metrics are affected by very narrow search results.
For example, if topk is 20 but there is only 1 relevant document, the other 19 documents could be ordered arbitrarily, resulting in a lower score.
Ratio topk gives a general idea of whether the most relevant documents appear first in the search results. Decreasing `eval_topk` makes this metric stricter, requiring relevant documents to appear in a narrower window.
Avg rank delta gives insight into the placement of documents outside the topk search results. If none of the comparison documents are in the topk, `ratio_topk` will simply show 0, whereas `avg_rank_delta` will grow larger the worse the search results get.
To address this limitation, there are score adjusted versions of the metrics.
The score adjusted version does not use a fixed topk, but computes the optimum topk based on the rerank scores.
This generally works for determining how many documents are relevant, although the approach isn't perfect.
Furthermore, there are two versions of each metric: ground truth and soft truth.
The ground truth includes documents explicitly listed as relevant in the test dataset. The ground truth metrics are only computed if a ground truth set is provided for the question and the documents exist in the index.
The soft truth is built on top of the ground truth (if provided), filling the remaining entries with results from the reranker. The soft truth metrics are only computed if `skip_rerank` is false. Computing the soft truth metrics can be extremely slow, especially for large `num_returned_hits`. However, they provide a useful reference when there are many relevant documents in no particular order, or for running quick tests without having to explicitly list which documents are relevant.
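Roughly, the soft truth set is built by starting from the ground truth and filling up to `eval_topk` with reranked documents, as in this sketch of the logic in `util_eval.py` (the function name here is illustrative):
```
def build_soft_truth(
    ground_truth_ranks: dict[str, int],
    reranked_doc_ids: list[str],
    topk: int,
) -> dict[str, int]:
    # start from the explicit ground truth, then top up to topk with reranked documents
    soft_ranks = dict(ground_truth_ranks)
    for doc_id in reranked_doc_ids:
        if len(soft_ranks) >= topk:
            break
        if doc_id not in soft_ranks:
            soft_ranks[doc_id] = len(soft_ranks)
    return soft_ranks
```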

View File

@@ -1,133 +0,0 @@
import json
from pathlib import Path
from langgraph.types import StreamWriter
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import SqlEngine
from onyx.db.persona import get_persona_by_id
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
def _load_queries() -> list[str]:
current_dir = Path(__file__).parent
search_queries_path = current_dir / "search_queries.json"
if not search_queries_path.exists():
raise FileNotFoundError(
f"Search queries file not found at {search_queries_path}"
)
with search_queries_path.open("r") as file:
return json.load(file)
def _modify_one_query(
query: str,
llm: LLM,
prompt_config: PromptConfig,
tool_definition: dict,
writer: StreamWriter = lambda _: None,
) -> str:
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=query,
prompt_config=prompt_config,
files=[],
single_message_history=None,
),
system_message=default_build_system_message(prompt_config, llm.config),
message_history=[],
llm_config=llm.config,
raw_user_query=query,
raw_user_uploaded_files=[],
single_message_history=None,
)
prompt = prompt_builder.build()
stream = llm.stream(
prompt=prompt,
tools=[tool_definition],
tool_choice="required",
structured_response_format=None,
)
tool_message = process_llm_stream(
messages=stream,
should_stream_answer=False,
writer=writer,
)
return (
tool_message.tool_calls[0]["args"]["query"]
if tool_message.tool_calls
else query
)
class SearchToolOverride(SearchTool):
def __init__(self) -> None:
# do nothing, the tool_definition function doesn't require variables to be initialized
pass
def generate_search_queries() -> None:
if MULTI_TENANT:
raise ValueError("Multi-tenant is not supported currently")
SqlEngine.init_engine(
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
)
queries = _load_queries()
with get_session_with_current_tenant() as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
llm, _ = get_llms_for_persona(persona)
prompt_config = PromptConfig.from_model(persona.prompts[0])
tool_definition = SearchToolOverride().tool_definition()
tool_call_supported = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
if tool_call_supported:
logger.info(
"Tool calling is supported for the current model. Modifying queries."
)
modified_queries = [
_modify_one_query(
query=query,
llm=llm,
prompt_config=prompt_config,
tool_definition=tool_definition,
)
for query in queries
]
else:
logger.warning(
"Tool calling is not supported for the current model. "
"Using the original queries."
)
modified_queries = queries
with open("search_queries_modified.json", "w") as file:
json.dump(modified_queries, file, indent=4)
logger.info("Exported modified queries to search_queries_modified.json")
if __name__ == "__main__":
generate_search_queries()

View File

@@ -1,242 +1,149 @@
import csv
import json
from bisect import bisect_left
from datetime import datetime
from collections import defaultdict
from pathlib import Path
from typing import cast
import yaml
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import HYBRID_ALPHA
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import semantic_reranking
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.context.search.utils import remove_stop_words_and_punctuation
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import SqlEngine
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from tests.regression.search_quality.util_config import load_config
from tests.regression.search_quality.util_data import export_test_queries
from tests.regression.search_quality.util_data import load_test_queries
from tests.regression.search_quality.util_eval import evaluate_one_query
from tests.regression.search_quality.util_eval import get_corresponding_document
from tests.regression.search_quality.util_eval import metric_names
from tests.regression.search_quality.util_retrieve import rerank_one_query
from tests.regression.search_quality.util_retrieve import search_one_query
logger = setup_logger(__name__)
class SearchEvalParameters(BaseModel):
hybrid_alpha: float
hybrid_alpha_keyword: float
doc_time_decay: float
num_returned_hits: int
rank_profile: QueryExpansionType
offset: int
title_content_ratio: float
user_email: str | None
skip_rerank: bool
eval_topk: int
export_folder: str
def _load_search_parameters() -> SearchEvalParameters:
current_dir = Path(__file__).parent
config_path = current_dir / "search_eval_config.yaml"
if not config_path.exists():
raise FileNotFoundError(f"Search eval config file not found at {config_path}")
with config_path.open("r") as file:
config = yaml.safe_load(file)
export_folder = config.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S")
export_folder = datetime.now().strftime(export_folder)
export_path = Path(export_folder)
export_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Created export folder: {export_path}")
search_parameters = SearchEvalParameters(
hybrid_alpha=config.get("HYBRID_ALPHA") or HYBRID_ALPHA,
hybrid_alpha_keyword=config.get("HYBRID_ALPHA_KEYWORD") or HYBRID_ALPHA_KEYWORD,
doc_time_decay=config.get("DOC_TIME_DECAY") or DOC_TIME_DECAY,
num_returned_hits=config.get("NUM_RETURNED_HITS") or NUM_RETURNED_HITS,
rank_profile=config.get("RANK_PROFILE") or QueryExpansionType.SEMANTIC,
offset=config.get("OFFSET") or 0,
title_content_ratio=config.get("TITLE_CONTENT_RATIO") or TITLE_CONTENT_RATIO,
user_email=config.get("USER_EMAIL"),
skip_rerank=config.get("SKIP_RERANK", False),
eval_topk=config.get("EVAL_TOPK", 20),
export_folder=export_folder,
)
logger.info(f"Using search parameters: {search_parameters}")
config_file = export_path / "search_eval_config.yaml"
with config_file.open("w") as file:
search_parameters_dict = search_parameters.model_dump(mode="python")
search_parameters_dict["rank_profile"] = search_parameters.rank_profile.value
yaml.dump(search_parameters_dict, file, sort_keys=False)
logger.info(f"Exported config to {config_file}")
return search_parameters
def _load_query_pairs() -> list[tuple[str, str]]:
current_dir = Path(__file__).parent
search_queries_path = current_dir / "search_queries.json"
if not search_queries_path.exists():
raise FileNotFoundError(
f"Search queries file not found at {search_queries_path}"
)
with search_queries_path.open("r") as file:
orig_queries = json.load(file)
alt_queries_path = current_dir / "search_queries_modified.json"
if not alt_queries_path.exists():
raise FileNotFoundError(
f"Modified search queries file not found at {alt_queries_path}. "
"Try running generate_search_queries.py."
)
with alt_queries_path.open("r") as file:
alt_queries = json.load(file)
if len(orig_queries) != len(alt_queries):
raise ValueError(
"Number of original and modified queries must be the same. "
"Try running generate_search_queries.py again."
)
return list(zip(orig_queries, alt_queries))
def _search_one_query(
alt_query: str,
multilingual_expansion: list[str],
document_index: DocumentIndex,
db_session: Session,
search_parameters: SearchEvalParameters,
) -> list[InferenceChunk]:
# the retrieval preprocessing is fairly stripped down so the query doesn't unexpectedly change
query_embedding = get_query_embedding(alt_query, db_session)
all_query_terms = alt_query.split()
processed_keywords = (
remove_stop_words_and_punctuation(all_query_terms)
if not multilingual_expansion
else all_query_terms
)
is_keyword = query_analysis(alt_query)[0]
hybrid_alpha = (
search_parameters.hybrid_alpha_keyword
if is_keyword
else search_parameters.hybrid_alpha
)
access_control_list = ["PUBLIC"]
if search_parameters.user_email:
access_control_list.append(f"user_email:{search_parameters.user_email}")
filters = IndexFilters(
tags=[],
user_file_ids=[],
user_folder_ids=[],
access_control_list=access_control_list,
tenant_id=None,
)
results = document_index.hybrid_retrieval(
query=alt_query,
query_embedding=query_embedding,
final_keywords=processed_keywords,
filters=filters,
hybrid_alpha=hybrid_alpha,
time_decay_multiplier=search_parameters.doc_time_decay,
num_to_retrieve=search_parameters.num_returned_hits,
ranking_profile_type=search_parameters.rank_profile,
offset=search_parameters.offset,
title_content_ratio=search_parameters.title_content_ratio,
)
return [result.to_inference_chunk() for result in results]
def _rerank_one_query(
orig_query: str,
retrieved_chunks: list[InferenceChunk],
rerank_settings: RerankingDetails,
search_parameters: SearchEvalParameters,
) -> list[InferenceChunk]:
assert not search_parameters.skip_rerank, "Reranking is disabled"
return semantic_reranking(
query_str=orig_query,
rerank_settings=rerank_settings,
chunks=retrieved_chunks,
rerank_metrics_callback=None,
)[0]
def _evaluate_one_query(
search_results: list[InferenceChunk],
rerank_results: list[InferenceChunk],
search_parameters: SearchEvalParameters,
) -> list[float]:
search_topk = search_results[: search_parameters.eval_topk]
rerank_topk = rerank_results[: search_parameters.eval_topk]
# get the score adjusted topk (topk where the score is at least 50% of the top score)
# could be more than topk if top scores are similar, may or may not be a good thing
# can change by swapping rerank_results with rerank_topk in bisect
adj_topk = bisect_left(
rerank_results,
-0.5 * cast(float, rerank_results[0].score),
key=lambda x: -cast(float, x.score),
)
search_adj_topk = search_results[:adj_topk]
rerank_adj_topk = rerank_results[:adj_topk]
# compute metrics
search_ranks = {chunk.unique_id: rank for rank, chunk in enumerate(search_results)}
return [
*_compute_jaccard_and_missing_chunks_ratio(search_topk, rerank_topk),
_compute_average_rank_change(search_ranks, rerank_topk),
# score adjusted metrics
*_compute_jaccard_and_missing_chunks_ratio(search_adj_topk, rerank_adj_topk),
_compute_average_rank_change(search_ranks, rerank_adj_topk),
]
def _compute_jaccard_and_missing_chunks_ratio(
search_topk: list[InferenceChunk], rerank_topk: list[InferenceChunk]
) -> tuple[float, float]:
search_chunkids = {chunk.unique_id for chunk in search_topk}
rerank_chunkids = {chunk.unique_id for chunk in rerank_topk}
jaccard_similarity = len(search_chunkids & rerank_chunkids) / len(
search_chunkids | rerank_chunkids
)
missing_chunks_ratio = len(rerank_chunkids - search_chunkids) / len(rerank_chunkids)
return jaccard_similarity, missing_chunks_ratio
def _compute_average_rank_change(
search_ranks: dict[str, int], rerank_topk: list[InferenceChunk]
) -> float:
rank_changes = [
abs(search_ranks[chunk.unique_id] - rerank_rank)
for rerank_rank, chunk in enumerate(rerank_topk)
]
return sum(rank_changes) / len(rank_changes)
def run_search_eval() -> None:
config = load_config()
test_queries = load_test_queries()
# export related
export_path = Path(config.export_folder)
export_test_queries(test_queries, export_path / "test_queries.json")
search_result_path = export_path / "search_results.csv"
eval_path = export_path / "eval_results.csv"
aggregate_eval_path = export_path / "aggregate_eval.csv"
aggregate_results: dict[str, list[list[float]]] = defaultdict(
lambda: [[] for _ in metric_names]
)
with get_session_with_current_tenant() as db_session:
multilingual_expansion = get_multilingual_expansion(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
rerank_settings = RerankingDetails.from_db_model(search_settings)
if config.skip_rerank:
logger.warning("Reranking is disabled, evaluation will not run")
elif rerank_settings.rerank_model_name is None:
raise ValueError(
"Reranking is enabled but no reranker is configured. "
"Please set the reranker in the admin panel search settings."
)
# run search and evaluate
logger.info(
"Running search and evaluation... "
f"Individual search and evaluation results will be saved to {search_result_path} and {eval_path}"
)
with (
search_result_path.open("w") as search_file,
eval_path.open("w") as eval_file,
):
search_csv_writer = csv.writer(search_file)
eval_csv_writer = csv.writer(eval_file)
search_csv_writer.writerow(
["source", "query", "rank", "score", "doc_id", "chunk_id"]
)
eval_csv_writer.writerow(["query", *metric_names])
for query in test_queries:
# search and write results
assert query.question_search is not None
search_chunks = search_one_query(
query.question_search,
multilingual_expansion,
document_index,
db_session,
config,
)
for rank, result in enumerate(search_chunks):
search_csv_writer.writerow(
[
"search",
query.question_search,
rank,
result.score,
result.document_id,
result.chunk_id,
]
)
rerank_chunks = []
if not config.skip_rerank:
# rerank and write results
rerank_chunks = rerank_one_query(
query.question, search_chunks, rerank_settings
)
for rank, result in enumerate(rerank_chunks):
search_csv_writer.writerow(
[
"rerank",
query.question,
rank,
result.score,
result.document_id,
result.chunk_id,
]
)
# evaluate and write results
truth_documents = [
doc
for truth in query.ground_truth
if (doc := get_corresponding_document(truth.doc_link, db_session))
]
metrics = evaluate_one_query(
search_chunks, rerank_chunks, truth_documents, config.eval_topk
)
metric_vals = [getattr(metrics, field) for field in metric_names]
eval_csv_writer.writerow([query.question, *metric_vals])
# add to aggregation
for category in ["all"] + query.categories:
for i, val in enumerate(metric_vals):
if val is not None:
aggregate_results[category][i].append(val)
# aggregate and write results
with aggregate_eval_path.open("w") as file:
aggregate_csv_writer = csv.writer(file)
aggregate_csv_writer.writerow(["category", *metric_names])
for category, agg_metrics in aggregate_results.items():
aggregate_csv_writer.writerow(
[
category,
*(
sum(metric) / len(metric) if metric else None
for metric in agg_metrics
),
]
)
if __name__ == "__main__":
if MULTI_TENANT:
raise ValueError("Multi-tenant is not supported currently")
@@ -245,122 +152,10 @@ def run_search_eval() -> None:
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
)
query_pairs = _load_query_pairs()
search_parameters = _load_search_parameters()
with get_session_with_current_tenant() as db_session:
multilingual_expansion = get_multilingual_expansion(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
rerank_settings = RerankingDetails.from_db_model(search_settings)
if search_parameters.skip_rerank:
logger.warning("Reranking is disabled, evaluation will not run")
elif rerank_settings.rerank_model_name is None:
raise ValueError(
"Reranking is enabled but no reranker is configured. "
"Please set the reranker in the admin panel search settings."
)
export_path = Path(search_parameters.export_folder)
search_result_file = export_path / "search_results.csv"
eval_result_file = export_path / "eval_results.csv"
with (
search_result_file.open("w") as search_file,
eval_result_file.open("w") as eval_file,
):
search_csv_writer = csv.writer(search_file)
eval_csv_writer = csv.writer(eval_file)
search_csv_writer.writerow(
["source", "query", "rank", "score", "doc_id", "chunk_id"]
)
eval_csv_writer.writerow(
[
"query",
"jaccard_similarity",
"missing_chunks_ratio",
"average_rank_change",
"jaccard_similarity_adj",
"missing_chunks_ratio_adj",
"average_rank_change_adj",
]
)
sum_metrics = [0.0] * 6
for orig_query, alt_query in query_pairs:
search_results = _search_one_query(
alt_query,
multilingual_expansion,
document_index,
db_session,
search_parameters,
)
for rank, result in enumerate(search_results):
search_csv_writer.writerow(
[
"search",
alt_query,
rank,
result.score,
result.document_id,
result.chunk_id,
]
)
if not search_parameters.skip_rerank:
rerank_results = _rerank_one_query(
orig_query, search_results, rerank_settings, search_parameters
)
for rank, result in enumerate(rerank_results):
search_csv_writer.writerow(
[
"rerank",
orig_query,
rank,
result.score,
result.document_id,
result.chunk_id,
]
)
metrics = _evaluate_one_query(
search_results, rerank_results, search_parameters
)
eval_csv_writer.writerow([orig_query, *metrics])
sum_metrics = [
sum_metric + metric
for sum_metric, metric in zip(sum_metrics, metrics)
]
logger.info(
f"Exported individual results to {search_result_file} and {eval_result_file}"
)
if not search_parameters.skip_rerank:
average_metrics = [metric / len(query_pairs) for metric in sum_metrics]
logger.info(f"Jaccard similarity: {average_metrics[0]}")
logger.info(f"Average missing chunks ratio: {average_metrics[1]}")
logger.info(f"Average rank change: {average_metrics[2]}")
logger.info(f"Jaccard similarity (adjusted): {average_metrics[3]}")
logger.info(f"Average missing chunks ratio (adjusted): {average_metrics[4]}")
logger.info(f"Average rank change (adjusted): {average_metrics[5]}")
aggregate_file = export_path / "aggregate_results.csv"
with aggregate_file.open("w") as file:
aggregate_csv_writer = csv.writer(file)
aggregate_csv_writer.writerow(
[
"jaccard_similarity",
"missing_chunks_ratio",
"average_rank_change",
"jaccard_similarity_adj",
"missing_chunks_ratio_adj",
"average_rank_change_adj",
]
)
aggregate_csv_writer.writerow(average_metrics)
logger.info(f"Exported aggregate results to {aggregate_file}")
if __name__ == "__main__":
run_search_eval()
try:
run_search_eval()
except Exception as e:
logger.error(f"Error running search evaluation: {e}")
raise e
finally:
SqlEngine.reset_engine()

View File

@@ -1,16 +1,16 @@
# Search Parameters, null means use default
HYBRID_ALPHA: null
HYBRID_ALPHA_KEYWORD: null
DOC_TIME_DECAY: null
NUM_RETURNED_HITS: 200 # Setting to a higher value will improve evaluation quality but increase reranking time
RANK_PROFILE: null
OFFSET: null
TITLE_CONTENT_RATIO: null
USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files
# Search Parameters
HYBRID_ALPHA: 0.5
HYBRID_ALPHA_KEYWORD: 0.4
DOC_TIME_DECAY: 0.5
NUM_RETURNED_HITS: 50 # Setting to a higher value will improve evaluation quality but increase reranking time
RANK_PROFILE: 'semantic'
OFFSET: 0
TITLE_CONTENT_RATIO: 0.1
USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files
# Evaluation parameters
SKIP_RERANK: false # Whether to skip reranking, reranking must be enabled to evaluate the search results
EVAL_TOPK: 20 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation
SKIP_RERANK: false # Whether to skip reranking, reranking must be enabled to evaluate the search results
EVAL_TOPK: 5 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation
# Export file, will export a csv file with the results and a json file with the parameters
EXPORT_FOLDER: "eval-%Y-%m-%d-%H-%M-%S"

View File

@@ -1,4 +0,0 @@
[
"What is Onyx?",
"How is Onyx enterprise edition different?"
]

View File

@@ -0,0 +1,22 @@
[
{
"question": "What is Onyx?",
"ground_truth": [
{
"doc_source": "Web",
"doc_link": "https://docs.onyx.app/more/use_cases/overview"
},
{
"doc_source": "Web",
"doc_link": "https://docs.onyx.app/more/use_cases/ai_platform"
}
],
"categories": [
"keyword",
"broad"
]
},
{
"question": "What is the meaning of life?"
}
]

View File

@@ -0,0 +1,75 @@
from datetime import datetime
from pathlib import Path
import yaml
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import HYBRID_ALPHA
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.utils.logger import setup_logger
logger = setup_logger(__name__)
class SearchEvalConfig(BaseModel):
hybrid_alpha: float
hybrid_alpha_keyword: float
doc_time_decay: float
num_returned_hits: int
rank_profile: QueryExpansionType
offset: int
title_content_ratio: float
user_email: str | None
skip_rerank: bool
eval_topk: int
export_folder: str
def load_config() -> SearchEvalConfig:
"""Loads the search evaluation configs from the config file."""
# open the config file
current_dir = Path(__file__).parent
config_path = current_dir / "search_eval_config.yaml"
if not config_path.exists():
raise FileNotFoundError(f"Search eval config file not found at {config_path}")
with config_path.open("r") as file:
config_raw = yaml.safe_load(file)
# create the export folder
export_folder = config_raw.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S")
export_folder = datetime.now().strftime(export_folder)
export_path = Path(export_folder)
export_path.mkdir(parents=True, exist_ok=True)
logger.info(f"Created export folder: {export_path}")
# create the config
config = SearchEvalConfig(
hybrid_alpha=config_raw.get("HYBRID_ALPHA", HYBRID_ALPHA),
hybrid_alpha_keyword=config_raw.get(
"HYBRID_ALPHA_KEYWORD", HYBRID_ALPHA_KEYWORD
),
doc_time_decay=config_raw.get("DOC_TIME_DECAY", DOC_TIME_DECAY),
num_returned_hits=config_raw.get("NUM_RETURNED_HITS", NUM_RETURNED_HITS),
rank_profile=config_raw.get("RANK_PROFILE", QueryExpansionType.SEMANTIC),
offset=config_raw.get("OFFSET", 0),
title_content_ratio=config_raw.get("TITLE_CONTENT_RATIO", TITLE_CONTENT_RATIO),
user_email=config_raw.get("USER_EMAIL"),
skip_rerank=config_raw.get("SKIP_RERANK", False),
eval_topk=config_raw.get("EVAL_TOPK", 5),
export_folder=export_folder,
)
logger.info(f"Using search parameters: {config}")
# export the config
config_file = export_path / "search_eval_config.yaml"
with config_file.open("w") as file:
config_dict = config.model_dump(mode="python")
config_dict["rank_profile"] = config.rank_profile.value
yaml.dump(config_dict, file, sort_keys=False)
logger.info(f"Exported config to {config_file}")
return config

View File

@@ -0,0 +1,166 @@
import json
from pathlib import Path
from typing import cast
from typing import Optional
from langgraph.types import StreamWriter
from pydantic import BaseModel
from pydantic import ValidationError
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.persona import get_persona_by_id
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GroundTruth(BaseModel):
doc_source: str
doc_link: str
class TestQuery(BaseModel):
question: str
question_search: Optional[str] = None
ground_truth: list[GroundTruth] = []
categories: list[str] = []
def load_test_queries() -> list[TestQuery]:
"""
Loads the test queries from the test_queries.json file.
If `question_search` is missing, it will use the tool-calling LLM to generate it.
"""
# open test queries file
current_dir = Path(__file__).parent
test_queries_path = current_dir / "test_queries.json"
logger.info(f"Loading test queries from {test_queries_path}")
if not test_queries_path.exists():
raise FileNotFoundError(f"Test queries file not found at {test_queries_path}")
with test_queries_path.open("r") as f:
test_queries_raw: list[dict] = json.load(f)
# setup llm for question_search generation
with get_session_with_current_tenant() as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
llm, _ = get_llms_for_persona(persona)
prompt_config = PromptConfig.from_model(persona.prompts[0])
search_tool = SearchToolOverride()
tool_call_supported = explicit_tool_calling_supported(
llm.config.model_provider, llm.config.model_name
)
# validate keys and generate question_search if missing
test_queries: list[TestQuery] = []
for query_raw in test_queries_raw:
try:
test_query = TestQuery(**query_raw)
except ValidationError as e:
logger.error(f"Incorrectly formatted query: {e}")
continue
if test_query.question_search is None:
test_query.question_search = _modify_one_query(
query=test_query.question,
llm=llm,
prompt_config=prompt_config,
tool=search_tool,
tool_call_supported=tool_call_supported,
)
test_queries.append(test_query)
return test_queries
def export_test_queries(test_queries: list[TestQuery], export_path: Path) -> None:
"""Exports the test queries to a JSON file."""
logger.info(f"Exporting test queries to {export_path}")
with export_path.open("w") as f:
json.dump(
[query.model_dump() for query in test_queries],
f,
indent=4,
)
class SearchToolOverride(SearchTool):
def __init__(self) -> None:
# do nothing, only class variables are required for the functions we call
pass
warned = False
def _modify_one_query(
query: str,
llm: LLM,
prompt_config: PromptConfig,
tool: SearchTool,
tool_call_supported: bool,
writer: StreamWriter = lambda _: None,
) -> str:
global warned
if not warned:
logger.warning(
"Generating question_search. If you do not save the question_search, "
"it will be generated again on the next run, potentially altering the search results."
)
warned = True
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=query,
prompt_config=prompt_config,
files=[],
single_message_history=None,
),
system_message=default_build_system_message(prompt_config, llm.config),
message_history=[],
llm_config=llm.config,
raw_user_query=query,
raw_user_uploaded_files=[],
single_message_history=None,
)
if tool_call_supported:
prompt = prompt_builder.build()
tool_definition = tool.tool_definition()
stream = llm.stream(
prompt=prompt,
tools=[tool_definition],
tool_choice="required",
structured_response_format=None,
)
tool_message = process_llm_stream(
messages=stream,
should_stream_answer=False,
writer=writer,
)
return (
tool_message.tool_calls[0]["args"]["query"]
if tool_message.tool_calls
else query
)
history = prompt_builder.get_message_history()
return cast(
dict[str, str],
tool.get_args_for_non_tool_calling_llm(
query=query,
history=history,
llm=llm,
force_run=True,
),
)["query"]

View File

@@ -0,0 +1,94 @@
from typing import Optional
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Document
from onyx.utils.logger import setup_logger
from tests.regression.search_quality.util_retrieve import group_by_documents
logger = setup_logger(__name__)
class Metrics(BaseModel):
# computed if ground truth is provided
ground_truth_ratio_topk: Optional[float] = None
ground_truth_avg_rank_delta: Optional[float] = None
# computed if reranked results are provided
soft_truth_ratio_topk: Optional[float] = None
soft_truth_avg_rank_delta: Optional[float] = None
metric_names = list(Metrics.model_fields.keys())
def get_corresponding_document(
doc_link: str, db_session: Session
) -> Optional[Document]:
"""Get the corresponding document from the database."""
doc_filter = db_session.query(Document).filter(Document.link == doc_link)
count = doc_filter.count()
if count == 0:
logger.warning(f"Could not find document with link {doc_link}, ignoring")
return None
if count > 1:
logger.warning(f"Found multiple documents with link {doc_link}, using first")
return doc_filter.first()
def evaluate_one_query(
search_chunks: list[InferenceChunk],
rerank_chunks: list[InferenceChunk],
true_documents: list[Document],
topk: int,
) -> Metrics:
"""Computes metrics for the search results, relative to the ground truth and reranked results."""
metrics_dict: dict[str, float] = {}
search_documents = group_by_documents(search_chunks)
search_ranks = {docid: rank for rank, docid in enumerate(search_documents)}
search_ranks_topk = {
docid: rank for rank, docid in enumerate(search_documents[:topk])
}
true_ranks = {doc.id: rank for rank, doc in enumerate(true_documents)}
if true_documents:
metrics_dict["ground_truth_ratio_topk"] = _compute_ratio(
search_ranks_topk, true_ranks
)
metrics_dict["ground_truth_avg_rank_delta"] = _compute_avg_rank_delta(
search_ranks, true_ranks
)
if rerank_chunks:
# build soft truth out of ground truth + reranked results, up to topk
soft_ranks = true_ranks
for docid in group_by_documents(rerank_chunks):
if len(soft_ranks) >= topk:
break
if docid not in soft_ranks:
soft_ranks[docid] = len(soft_ranks)
metrics_dict["soft_truth_ratio_topk"] = _compute_ratio(
search_ranks_topk, soft_ranks
)
metrics_dict["soft_truth_avg_rank_delta"] = _compute_avg_rank_delta(
search_ranks, soft_ranks
)
return Metrics(**metrics_dict)
def _compute_ratio(search_ranks: dict[str, int], true_ranks: dict[str, int]) -> float:
return len(set(search_ranks) & set(true_ranks)) / len(true_ranks)
def _compute_avg_rank_delta(
search_ranks: dict[str, int], true_ranks: dict[str, int]
) -> float:
out = len(search_ranks)
return sum(
abs(search_ranks.get(docid, out) - rank) for docid, rank in true_ranks.items()
) / len(true_ranks)

View File

@@ -0,0 +1,88 @@
from sqlalchemy.orm import Session
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import semantic_reranking
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.context.search.utils import remove_stop_words_and_punctuation
from onyx.document_index.interfaces import DocumentIndex
from onyx.utils.logger import setup_logger
from tests.regression.search_quality.util_config import SearchEvalConfig
logger = setup_logger(__name__)
def search_one_query(
question_search: str,
multilingual_expansion: list[str],
document_index: DocumentIndex,
db_session: Session,
config: SearchEvalConfig,
) -> list[InferenceChunk]:
"""Uses the search pipeline to retrieve relevant chunks for the given query."""
# the retrieval preprocessing is fairly stripped down so the query doesn't unexpectedly change
query_embedding = get_query_embedding(question_search, db_session)
all_query_terms = question_search.split()
processed_keywords = (
remove_stop_words_and_punctuation(all_query_terms)
if not multilingual_expansion
else all_query_terms
)
is_keyword = query_analysis(question_search)[0]
hybrid_alpha = config.hybrid_alpha_keyword if is_keyword else config.hybrid_alpha
access_control_list = ["PUBLIC"]
if config.user_email:
access_control_list.append(f"user_email:{config.user_email}")
filters = IndexFilters(
tags=[],
user_file_ids=[],
user_folder_ids=[],
access_control_list=access_control_list,
tenant_id=None,
)
results = document_index.hybrid_retrieval(
query=question_search,
query_embedding=query_embedding,
final_keywords=processed_keywords,
filters=filters,
hybrid_alpha=hybrid_alpha,
time_decay_multiplier=config.doc_time_decay,
num_to_retrieve=config.num_returned_hits,
ranking_profile_type=config.rank_profile,
offset=config.offset,
title_content_ratio=config.title_content_ratio,
)
return [result.to_inference_chunk() for result in results]
def rerank_one_query(
question: str,
retrieved_chunks: list[InferenceChunk],
rerank_settings: RerankingDetails,
) -> list[InferenceChunk]:
"""Uses the reranker to rerank the retrieved chunks for the given query."""
rerank_settings.num_rerank = len(retrieved_chunks)
return semantic_reranking(
query_str=question,
rerank_settings=rerank_settings,
chunks=retrieved_chunks,
rerank_metrics_callback=None,
)[0]
def group_by_documents(chunks: list[InferenceChunk]) -> list[str]:
"""Groups a sorted list of chunks into a sorted list of document ids."""
seen_docids: set[str] = set()
retrieved_docids: list[str] = []
for chunk in chunks:
if chunk.document_id not in seen_docids:
seen_docids.add(chunk.document_id)
retrieved_docids.append(chunk.document_id)
return retrieved_docids