Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-28 13:53:28 +02:00)
Feat: Search Eval Testing Overhaul (provide ground truth, categorize query, etc.) (#4739)
* fix: autoflake & import order
* docs: readme
* fix: mypy
* feat: eval
* docs: readme
* fix: oops forgot to remove comment
* fix: typo
* fix: rename var
* updated default config
* fix: config issue
* oops
* fix: black
* fix: eval and config
* feat: non tool calling query mod
@@ -2,7 +2,9 @@
This Python script evaluates the search results for a list of queries.

Unlike the script in answer_quality, this script is much less customizable and runs against currently ingested documents; however, it allows quick testing of search parameters on a set of test queries that do not have well-defined answers.

This script will likely be refactored into an API endpoint in the future.

In the meantime, it is used to evaluate search quality using locally ingested documents.

The key differentiating factor from `answer_quality` is that it can evaluate results without an explicit "ground truth", using the reranker as a reference.

## Usage

@@ -25,32 +27,36 @@ This can be checked/modified by opening the admin panel, going to search setting
cd path/to/onyx/backend/tests/regression/search_quality
```

5. Copy `search_queries.json.template` to `search_queries.json` and add/remove test queries in it
5. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. The possible fields are:
   - `question: str`: the query
   - `question_search: Optional[str]`: a modified query used specifically for the search step
   - `ground_truth: Optional[list[GroundTruth]]`: a ranked list of expected search results, each with the fields:
     - `doc_source: str`: the document source (e.g., Web, Drive, Linear), currently unused
     - `doc_link: str`: the link associated with the document, used to find the corresponding document in the local index
   - `categories: Optional[list[str]]`: a list of categories, used to aggregate evaluation results

6. Run `generate_search_queries.py` to generate the modified queries for the search pipeline

```
python generate_search_queries.py
```

6. Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters

7. Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters

8. Run `run_search_eval.py` to evaluate the search results against the reranked results

7. Run `run_search_eval.py` to run the search and evaluate the search results

```
python run_search_eval.py
```

9. Repeat steps 7 and 8 to test and compare different search parameters

8. Optionally, save the generated `test_queries.json` in the export folder to reuse the generated `question_search`, and rerun the search evaluation with alternative search parameters.

## Metrics

- Jaccard Similarity: the ratio of the intersection to the union of the topk search and rerank results. Higher is better
- Average Rank Change: the average absolute rank difference of the topk reranked chunks vs. the entire set of search chunks. Lower is better
- Average Missing Chunk Ratio: the number of chunks in the topk reranked chunks that are not in the topk search chunks, divided by topk. Lower is better

There are two main metrics currently implemented:

- ratio_topk: the ratio of documents in the comparison set that appear in the topk search results (higher is better, 0-1)
- avg_rank_delta: the average rank difference between the comparison set and the search results (lower is better, 0-inf)

Note that all of these metrics are affected by very narrow search results. E.g., if topk is 20 but there is only 1 relevant document, the other 19 documents could be ordered arbitrarily, resulting in a lower score.

ratio_topk gives a general idea of whether the most relevant documents appear first in the search results. Decreasing `eval_topk` makes this metric stricter, requiring relevant documents to appear in a narrower window.

avg_rank_delta can give insight into the performance of documents outside the topk search results. If none of the comparison documents are in the topk, `ratio_topk` will simply show 0, whereas `avg_rank_delta` grows the worse the search results get.

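For reference, here is a minimal sketch of how these two metrics can be computed; it mirrors the helpers in `util_eval.py` shown further below, and the rank dictionaries are hypothetical toy values.

```
# Minimal sketch of the two metrics, mirroring the helpers in util_eval.py.
# The rank dictionaries below are hypothetical (document id -> rank, 0 = best).
search_ranks = {"doc_a": 0, "doc_b": 1, "doc_c": 2, "doc_d": 3}
truth_ranks = {"doc_a": 0, "doc_c": 1, "doc_x": 2}  # comparison set (ground/soft truth)

topk = 2
search_ranks_topk = {d: r for d, r in search_ranks.items() if r < topk}

# ratio_topk: fraction of the comparison set found in the topk search results
ratio_topk = len(set(search_ranks_topk) & set(truth_ranks)) / len(truth_ranks)  # 1/3

# avg_rank_delta: average |search rank - truth rank|, where documents missing from
# the search results are penalized as if ranked just past the end of the results
out = len(search_ranks)
avg_rank_delta = sum(
    abs(search_ranks.get(doc, out) - rank) for doc, rank in truth_ranks.items()
) / len(truth_ranks)  # (|0-0| + |2-1| + |4-2|) / 3 = 1.0
```
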
To address this limitation, there are score-adjusted versions of the metrics. The score-adjusted versions do not use a fixed topk; instead, they compute an optimal topk based on the rerank scores. This generally works well for determining how many documents are relevant, although the approach isn't perfect.

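For illustration, this is a sketch of one way an adjusted topk can be derived from descending rerank scores, following the bisect-based approach used in the earlier version of this script; the scores are made up.

```
# Sketch of deriving an adjusted topk from descending rerank scores.
# Here "relevant" means scoring at least 50% of the top score; scores are toy values.
from bisect import bisect_left

rerank_scores = [0.92, 0.88, 0.61, 0.44, 0.12]  # sorted descending

# bisect over negated scores (ascending) to find the first score below half the top score
adj_topk = bisect_left(
    [-s for s in rerank_scores],
    -0.5 * rerank_scores[0],
)
# adj_topk == 3: only the first three results are treated as relevant
```
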
Furthermore, there are two versions of the metrics: ground truth and soft truth.

The ground truth comprises the documents explicitly listed as relevant in the test dataset. The ground-truth metrics are only computed if a ground truth set is provided for the question and the listed documents exist in the index.

The soft truth is built on top of the ground truth (if provided), filling the remaining entries with results from the reranker. The soft-truth metrics are only computed if `skip_rerank` is false. Computing the soft-truth metrics can be extremely slow, especially for a large `num_returned_hits`. However, it provides a good basis when there are many relevant documents in no particular order, or for running quick tests without having to explicitly specify which documents are relevant.

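The soft-truth construction itself is simple: start from the ground-truth ranks and append reranked document ids until topk entries are reached. A minimal sketch (mirroring the logic in `util_eval.py` below; the document ids are hypothetical):

```
# Build a soft truth set: ground truth first, then reranked documents, up to topk.
# Document ids here are hypothetical.
true_ranks = {"doc_a": 0, "doc_c": 1}           # from the test dataset
reranked_doc_ids = ["doc_a", "doc_b", "doc_d"]  # deduplicated, best first
topk = 3

soft_ranks = dict(true_ranks)
for doc_id in reranked_doc_ids:
    if len(soft_ranks) >= topk:
        break
    if doc_id not in soft_ranks:
        soft_ranks[doc_id] = len(soft_ranks)

# soft_ranks == {"doc_a": 0, "doc_c": 1, "doc_b": 2}
```
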
@@ -1,133 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from langgraph.types import StreamWriter
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _load_queries() -> list[str]:
|
||||
current_dir = Path(__file__).parent
|
||||
search_queries_path = current_dir / "search_queries.json"
|
||||
if not search_queries_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Search queries file not found at {search_queries_path}"
|
||||
)
|
||||
with search_queries_path.open("r") as file:
|
||||
return json.load(file)
|
||||
|
||||
|
||||
def _modify_one_query(
|
||||
query: str,
|
||||
llm: LLM,
|
||||
prompt_config: PromptConfig,
|
||||
tool_definition: dict,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> str:
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=query,
|
||||
prompt_config=prompt_config,
|
||||
files=[],
|
||||
single_message_history=None,
|
||||
),
|
||||
system_message=default_build_system_message(prompt_config, llm.config),
|
||||
message_history=[],
|
||||
llm_config=llm.config,
|
||||
raw_user_query=query,
|
||||
raw_user_uploaded_files=[],
|
||||
single_message_history=None,
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
stream = llm.stream(
|
||||
prompt=prompt,
|
||||
tools=[tool_definition],
|
||||
tool_choice="required",
|
||||
structured_response_format=None,
|
||||
)
|
||||
tool_message = process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=False,
|
||||
writer=writer,
|
||||
)
|
||||
return (
|
||||
tool_message.tool_calls[0]["args"]["query"]
|
||||
if tool_message.tool_calls
|
||||
else query
|
||||
)
|
||||
|
||||
|
||||
class SearchToolOverride(SearchTool):
|
||||
def __init__(self) -> None:
|
||||
# do nothing, the tool_definition function doesn't require variables to be initialized
|
||||
pass
|
||||
|
||||
|
||||
def generate_search_queries() -> None:
|
||||
if MULTI_TENANT:
|
||||
raise ValueError("Multi-tenant is not supported currently")
|
||||
|
||||
SqlEngine.init_engine(
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
)
|
||||
|
||||
queries = _load_queries()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
tool_definition = SearchToolOverride().tool_definition()
|
||||
|
||||
tool_call_supported = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
|
||||
if tool_call_supported:
|
||||
logger.info(
|
||||
"Tool calling is supported for the current model. Modifying queries."
|
||||
)
|
||||
modified_queries = [
|
||||
_modify_one_query(
|
||||
query=query,
|
||||
llm=llm,
|
||||
prompt_config=prompt_config,
|
||||
tool_definition=tool_definition,
|
||||
)
|
||||
for query in queries
|
||||
]
|
||||
else:
|
||||
logger.warning(
|
||||
"Tool calling is not supported for the current model. "
|
||||
"Using the original queries."
|
||||
)
|
||||
modified_queries = queries
|
||||
|
||||
with open("search_queries_modified.json", "w") as file:
|
||||
json.dump(modified_queries, file, indent=4)
|
||||
|
||||
logger.info("Exported modified queries to search_queries_modified.json")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
generate_search_queries()
|
@@ -1,242 +1,149 @@
|
||||
import csv
|
||||
import json
|
||||
from bisect import bisect_left
|
||||
from datetime import datetime
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.postprocessing.postprocessing import semantic_reranking
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.context.search.utils import remove_stop_words_and_punctuation
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from tests.regression.search_quality.util_config import load_config
|
||||
from tests.regression.search_quality.util_data import export_test_queries
|
||||
from tests.regression.search_quality.util_data import load_test_queries
|
||||
from tests.regression.search_quality.util_eval import evaluate_one_query
|
||||
from tests.regression.search_quality.util_eval import get_corresponding_document
|
||||
from tests.regression.search_quality.util_eval import metric_names
|
||||
from tests.regression.search_quality.util_retrieve import rerank_one_query
|
||||
from tests.regression.search_quality.util_retrieve import search_one_query
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class SearchEvalParameters(BaseModel):
|
||||
hybrid_alpha: float
|
||||
hybrid_alpha_keyword: float
|
||||
doc_time_decay: float
|
||||
num_returned_hits: int
|
||||
rank_profile: QueryExpansionType
|
||||
offset: int
|
||||
title_content_ratio: float
|
||||
user_email: str | None
|
||||
skip_rerank: bool
|
||||
eval_topk: int
|
||||
export_folder: str
|
||||
|
||||
|
||||
def _load_search_parameters() -> SearchEvalParameters:
|
||||
current_dir = Path(__file__).parent
|
||||
config_path = current_dir / "search_eval_config.yaml"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Search eval config file not found at {config_path}")
|
||||
with config_path.open("r") as file:
|
||||
config = yaml.safe_load(file)
|
||||
|
||||
export_folder = config.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S")
|
||||
export_folder = datetime.now().strftime(export_folder)
|
||||
|
||||
export_path = Path(export_folder)
|
||||
export_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created export folder: {export_path}")
|
||||
|
||||
search_parameters = SearchEvalParameters(
|
||||
hybrid_alpha=config.get("HYBRID_ALPHA") or HYBRID_ALPHA,
|
||||
hybrid_alpha_keyword=config.get("HYBRID_ALPHA_KEYWORD") or HYBRID_ALPHA_KEYWORD,
|
||||
doc_time_decay=config.get("DOC_TIME_DECAY") or DOC_TIME_DECAY,
|
||||
num_returned_hits=config.get("NUM_RETURNED_HITS") or NUM_RETURNED_HITS,
|
||||
rank_profile=config.get("RANK_PROFILE") or QueryExpansionType.SEMANTIC,
|
||||
offset=config.get("OFFSET") or 0,
|
||||
title_content_ratio=config.get("TITLE_CONTENT_RATIO") or TITLE_CONTENT_RATIO,
|
||||
user_email=config.get("USER_EMAIL"),
|
||||
skip_rerank=config.get("SKIP_RERANK", False),
|
||||
eval_topk=config.get("EVAL_TOPK", 20),
|
||||
export_folder=export_folder,
|
||||
)
|
||||
logger.info(f"Using search parameters: {search_parameters}")
|
||||
|
||||
config_file = export_path / "search_eval_config.yaml"
|
||||
with config_file.open("w") as file:
|
||||
search_parameters_dict = search_parameters.model_dump(mode="python")
|
||||
search_parameters_dict["rank_profile"] = search_parameters.rank_profile.value
|
||||
yaml.dump(search_parameters_dict, file, sort_keys=False)
|
||||
logger.info(f"Exported config to {config_file}")
|
||||
|
||||
return search_parameters
|
||||
|
||||
|
||||
def _load_query_pairs() -> list[tuple[str, str]]:
|
||||
current_dir = Path(__file__).parent
|
||||
search_queries_path = current_dir / "search_queries.json"
|
||||
if not search_queries_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Search queries file not found at {search_queries_path}"
|
||||
)
|
||||
with search_queries_path.open("r") as file:
|
||||
orig_queries = json.load(file)
|
||||
|
||||
alt_queries_path = current_dir / "search_queries_modified.json"
|
||||
if not alt_queries_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Modified search queries file not found at {alt_queries_path}. "
|
||||
"Try running generate_search_queries.py."
|
||||
)
|
||||
with alt_queries_path.open("r") as file:
|
||||
alt_queries = json.load(file)
|
||||
|
||||
if len(orig_queries) != len(alt_queries):
|
||||
raise ValueError(
|
||||
"Number of original and modified queries must be the same. "
|
||||
"Try running generate_search_queries.py again."
|
||||
)
|
||||
|
||||
return list(zip(orig_queries, alt_queries))
|
||||
|
||||
|
||||
def _search_one_query(
|
||||
alt_query: str,
|
||||
multilingual_expansion: list[str],
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
search_parameters: SearchEvalParameters,
|
||||
) -> list[InferenceChunk]:
|
||||
# the retrieval preprocessing is fairly stripped down so the query doesn't unexpectedly change
|
||||
query_embedding = get_query_embedding(alt_query, db_session)
|
||||
|
||||
all_query_terms = alt_query.split()
|
||||
processed_keywords = (
|
||||
remove_stop_words_and_punctuation(all_query_terms)
|
||||
if not multilingual_expansion
|
||||
else all_query_terms
|
||||
)
|
||||
|
||||
is_keyword = query_analysis(alt_query)[0]
|
||||
hybrid_alpha = (
|
||||
search_parameters.hybrid_alpha_keyword
|
||||
if is_keyword
|
||||
else search_parameters.hybrid_alpha
|
||||
)
|
||||
|
||||
access_control_list = ["PUBLIC"]
|
||||
if search_parameters.user_email:
|
||||
access_control_list.append(f"user_email:{search_parameters.user_email}")
|
||||
filters = IndexFilters(
|
||||
tags=[],
|
||||
user_file_ids=[],
|
||||
user_folder_ids=[],
|
||||
access_control_list=access_control_list,
|
||||
tenant_id=None,
|
||||
)
|
||||
|
||||
results = document_index.hybrid_retrieval(
|
||||
query=alt_query,
|
||||
query_embedding=query_embedding,
|
||||
final_keywords=processed_keywords,
|
||||
filters=filters,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
time_decay_multiplier=search_parameters.doc_time_decay,
|
||||
num_to_retrieve=search_parameters.num_returned_hits,
|
||||
ranking_profile_type=search_parameters.rank_profile,
|
||||
offset=search_parameters.offset,
|
||||
title_content_ratio=search_parameters.title_content_ratio,
|
||||
)
|
||||
|
||||
return [result.to_inference_chunk() for result in results]
|
||||
|
||||
|
||||
def _rerank_one_query(
|
||||
orig_query: str,
|
||||
retrieved_chunks: list[InferenceChunk],
|
||||
rerank_settings: RerankingDetails,
|
||||
search_parameters: SearchEvalParameters,
|
||||
) -> list[InferenceChunk]:
|
||||
assert not search_parameters.skip_rerank, "Reranking is disabled"
|
||||
return semantic_reranking(
|
||||
query_str=orig_query,
|
||||
rerank_settings=rerank_settings,
|
||||
chunks=retrieved_chunks,
|
||||
rerank_metrics_callback=None,
|
||||
)[0]
|
||||
|
||||
|
||||
def _evaluate_one_query(
|
||||
search_results: list[InferenceChunk],
|
||||
rerank_results: list[InferenceChunk],
|
||||
search_parameters: SearchEvalParameters,
|
||||
) -> list[float]:
|
||||
search_topk = search_results[: search_parameters.eval_topk]
|
||||
rerank_topk = rerank_results[: search_parameters.eval_topk]
|
||||
|
||||
# get the score adjusted topk (topk where the score is at least 50% of the top score)
|
||||
# could be more than topk if top scores are similar, may or may not be a good thing
|
||||
# can change by swapping rerank_results with rerank_topk in bisect
|
||||
adj_topk = bisect_left(
|
||||
rerank_results,
|
||||
-0.5 * cast(float, rerank_results[0].score),
|
||||
key=lambda x: -cast(float, x.score),
|
||||
)
|
||||
search_adj_topk = search_results[:adj_topk]
|
||||
rerank_adj_topk = rerank_results[:adj_topk]
|
||||
|
||||
# compute metrics
|
||||
search_ranks = {chunk.unique_id: rank for rank, chunk in enumerate(search_results)}
|
||||
return [
|
||||
*_compute_jaccard_and_missing_chunks_ratio(search_topk, rerank_topk),
|
||||
_compute_average_rank_change(search_ranks, rerank_topk),
|
||||
# score adjusted metrics
|
||||
*_compute_jaccard_and_missing_chunks_ratio(search_adj_topk, rerank_adj_topk),
|
||||
_compute_average_rank_change(search_ranks, rerank_adj_topk),
|
||||
]
|
||||
|
||||
|
||||
def _compute_jaccard_and_missing_chunks_ratio(
|
||||
search_topk: list[InferenceChunk], rerank_topk: list[InferenceChunk]
|
||||
) -> tuple[float, float]:
|
||||
search_chunkids = {chunk.unique_id for chunk in search_topk}
|
||||
rerank_chunkids = {chunk.unique_id for chunk in rerank_topk}
|
||||
jaccard_similarity = len(search_chunkids & rerank_chunkids) / len(
|
||||
search_chunkids | rerank_chunkids
|
||||
)
|
||||
missing_chunks_ratio = len(rerank_chunkids - search_chunkids) / len(rerank_chunkids)
|
||||
return jaccard_similarity, missing_chunks_ratio
|
||||
|
||||
|
||||
def _compute_average_rank_change(
|
||||
search_ranks: dict[str, int], rerank_topk: list[InferenceChunk]
|
||||
) -> float:
|
||||
rank_changes = [
|
||||
abs(search_ranks[chunk.unique_id] - rerank_rank)
|
||||
for rerank_rank, chunk in enumerate(rerank_topk)
|
||||
]
|
||||
return sum(rank_changes) / len(rank_changes)
|
||||
|
||||
|
||||
def run_search_eval() -> None:
|
||||
config = load_config()
|
||||
test_queries = load_test_queries()
|
||||
|
||||
# export related
|
||||
export_path = Path(config.export_folder)
|
||||
export_test_queries(test_queries, export_path / "test_queries.json")
|
||||
search_result_path = export_path / "search_results.csv"
|
||||
eval_path = export_path / "eval_results.csv"
|
||||
aggregate_eval_path = export_path / "aggregate_eval.csv"
|
||||
aggregate_results: dict[str, list[list[float]]] = defaultdict(
|
||||
lambda: [[] for _ in metric_names]
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
if config.skip_rerank:
|
||||
logger.warning("Reranking is disabled, evaluation will not run")
|
||||
elif rerank_settings.rerank_model_name is None:
|
||||
raise ValueError(
|
||||
"Reranking is enabled but no reranker is configured. "
|
||||
"Please set the reranker in the admin panel search settings."
|
||||
)
|
||||
|
||||
# run search and evaluate
|
||||
logger.info(
|
||||
"Running search and evaluation... "
|
||||
f"Individual search and evaluation results will be saved to {search_result_path} and {eval_path}"
|
||||
)
|
||||
with (
|
||||
search_result_path.open("w") as search_file,
|
||||
eval_path.open("w") as eval_file,
|
||||
):
|
||||
search_csv_writer = csv.writer(search_file)
|
||||
eval_csv_writer = csv.writer(eval_file)
|
||||
search_csv_writer.writerow(
|
||||
["source", "query", "rank", "score", "doc_id", "chunk_id"]
|
||||
)
|
||||
eval_csv_writer.writerow(["query", *metric_names])
|
||||
|
||||
for query in test_queries:
|
||||
# search and write results
|
||||
assert query.question_search is not None
|
||||
search_chunks = search_one_query(
|
||||
query.question_search,
|
||||
multilingual_expansion,
|
||||
document_index,
|
||||
db_session,
|
||||
config,
|
||||
)
|
||||
for rank, result in enumerate(search_chunks):
|
||||
search_csv_writer.writerow(
|
||||
[
|
||||
"search",
|
||||
query.question_search,
|
||||
rank,
|
||||
result.score,
|
||||
result.document_id,
|
||||
result.chunk_id,
|
||||
]
|
||||
)
|
||||
|
||||
rerank_chunks = []
|
||||
if not config.skip_rerank:
|
||||
# rerank and write results
|
||||
rerank_chunks = rerank_one_query(
|
||||
query.question, search_chunks, rerank_settings
|
||||
)
|
||||
for rank, result in enumerate(rerank_chunks):
|
||||
search_csv_writer.writerow(
|
||||
[
|
||||
"rerank",
|
||||
query.question,
|
||||
rank,
|
||||
result.score,
|
||||
result.document_id,
|
||||
result.chunk_id,
|
||||
]
|
||||
)
|
||||
|
||||
# evaluate and write results
|
||||
truth_documents = [
|
||||
doc
|
||||
for truth in query.ground_truth
|
||||
if (doc := get_corresponding_document(truth.doc_link, db_session))
|
||||
]
|
||||
metrics = evaluate_one_query(
|
||||
search_chunks, rerank_chunks, truth_documents, config.eval_topk
|
||||
)
|
||||
metric_vals = [getattr(metrics, field) for field in metric_names]
|
||||
eval_csv_writer.writerow([query.question, *metric_vals])
|
||||
|
||||
# add to aggregation
|
||||
for category in ["all"] + query.categories:
|
||||
for i, val in enumerate(metric_vals):
|
||||
if val is not None:
|
||||
aggregate_results[category][i].append(val)
|
||||
|
||||
# aggregate and write results
|
||||
with aggregate_eval_path.open("w") as file:
|
||||
aggregate_csv_writer = csv.writer(file)
|
||||
aggregate_csv_writer.writerow(["category", *metric_names])
|
||||
|
||||
for category, agg_metrics in aggregate_results.items():
|
||||
aggregate_csv_writer.writerow(
|
||||
[
|
||||
category,
|
||||
*(
|
||||
sum(metric) / len(metric) if metric else None
|
||||
for metric in agg_metrics
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if MULTI_TENANT:
|
||||
raise ValueError("Multi-tenant is not supported currently")
|
||||
|
||||
@@ -245,122 +152,10 @@ def run_search_eval() -> None:
|
||||
max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW,
|
||||
)
|
||||
|
||||
query_pairs = _load_query_pairs()
|
||||
search_parameters = _load_search_parameters()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
if search_parameters.skip_rerank:
|
||||
logger.warning("Reranking is disabled, evaluation will not run")
|
||||
elif rerank_settings.rerank_model_name is None:
|
||||
raise ValueError(
|
||||
"Reranking is enabled but no reranker is configured. "
|
||||
"Please set the reranker in the admin panel search settings."
|
||||
)
|
||||
|
||||
export_path = Path(search_parameters.export_folder)
|
||||
search_result_file = export_path / "search_results.csv"
|
||||
eval_result_file = export_path / "eval_results.csv"
|
||||
with (
|
||||
search_result_file.open("w") as search_file,
|
||||
eval_result_file.open("w") as eval_file,
|
||||
):
|
||||
search_csv_writer = csv.writer(search_file)
|
||||
eval_csv_writer = csv.writer(eval_file)
|
||||
search_csv_writer.writerow(
|
||||
["source", "query", "rank", "score", "doc_id", "chunk_id"]
|
||||
)
|
||||
eval_csv_writer.writerow(
|
||||
[
|
||||
"query",
|
||||
"jaccard_similarity",
|
||||
"missing_chunks_ratio",
|
||||
"average_rank_change",
|
||||
"jaccard_similarity_adj",
|
||||
"missing_chunks_ratio_adj",
|
||||
"average_rank_change_adj",
|
||||
]
|
||||
)
|
||||
|
||||
sum_metrics = [0.0] * 6
|
||||
for orig_query, alt_query in query_pairs:
|
||||
search_results = _search_one_query(
|
||||
alt_query,
|
||||
multilingual_expansion,
|
||||
document_index,
|
||||
db_session,
|
||||
search_parameters,
|
||||
)
|
||||
for rank, result in enumerate(search_results):
|
||||
search_csv_writer.writerow(
|
||||
[
|
||||
"search",
|
||||
alt_query,
|
||||
rank,
|
||||
result.score,
|
||||
result.document_id,
|
||||
result.chunk_id,
|
||||
]
|
||||
)
|
||||
|
||||
if not search_parameters.skip_rerank:
|
||||
rerank_results = _rerank_one_query(
|
||||
orig_query, search_results, rerank_settings, search_parameters
|
||||
)
|
||||
for rank, result in enumerate(rerank_results):
|
||||
search_csv_writer.writerow(
|
||||
[
|
||||
"rerank",
|
||||
orig_query,
|
||||
rank,
|
||||
result.score,
|
||||
result.document_id,
|
||||
result.chunk_id,
|
||||
]
|
||||
)
|
||||
|
||||
metrics = _evaluate_one_query(
|
||||
search_results, rerank_results, search_parameters
|
||||
)
|
||||
eval_csv_writer.writerow([orig_query, *metrics])
|
||||
sum_metrics = [
|
||||
sum_metric + metric
|
||||
for sum_metric, metric in zip(sum_metrics, metrics)
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Exported individual results to {search_result_file} and {eval_result_file}"
|
||||
)
|
||||
|
||||
if not search_parameters.skip_rerank:
|
||||
average_metrics = [metric / len(query_pairs) for metric in sum_metrics]
|
||||
logger.info(f"Jaccard similarity: {average_metrics[0]}")
|
||||
logger.info(f"Average missing chunks ratio: {average_metrics[1]}")
|
||||
logger.info(f"Average rank change: {average_metrics[2]}")
|
||||
logger.info(f"Jaccard similarity (adjusted): {average_metrics[3]}")
|
||||
logger.info(f"Average missing chunks ratio (adjusted): {average_metrics[4]}")
|
||||
logger.info(f"Average rank change (adjusted): {average_metrics[5]}")
|
||||
|
||||
aggregate_file = export_path / "aggregate_results.csv"
|
||||
with aggregate_file.open("w") as file:
|
||||
aggregate_csv_writer = csv.writer(file)
|
||||
aggregate_csv_writer.writerow(
|
||||
[
|
||||
"jaccard_similarity",
|
||||
"missing_chunks_ratio",
|
||||
"average_rank_change",
|
||||
"jaccard_similarity_adj",
|
||||
"missing_chunks_ratio_adj",
|
||||
"average_rank_change_adj",
|
||||
]
|
||||
)
|
||||
aggregate_csv_writer.writerow(average_metrics)
|
||||
logger.info(f"Exported aggregate results to {aggregate_file}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_search_eval()
|
||||
try:
|
||||
run_search_eval()
|
||||
except Exception as e:
|
||||
logger.error(f"Error running search evaluation: {e}")
|
||||
raise e
|
||||
finally:
|
||||
SqlEngine.reset_engine()
|
||||
|
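The per-category aggregation in `run_search_eval.py` above averages each metric over every query tagged with that category (plus the implicit "all" category), skipping metrics that were not computed. A toy sketch of that pattern, with hypothetical queries and values:

```
# Toy sketch of the per-category averaging used above; queries and values are hypothetical.
from collections import defaultdict

metric_names = ["ground_truth_ratio_topk", "ground_truth_avg_rank_delta"]
rows = [
    (["keyword", "broad"], [0.5, 2.0]),
    (["broad"], [1.0, None]),  # None: metric not computed for this query
]

aggregate_results: dict[str, list[list[float]]] = defaultdict(
    lambda: [[] for _ in metric_names]
)
for categories, metric_vals in rows:
    for category in ["all"] + categories:
        for i, val in enumerate(metric_vals):
            if val is not None:
                aggregate_results[category][i].append(val)

averages = {
    category: [sum(vals) / len(vals) if vals else None for vals in per_metric]
    for category, per_metric in aggregate_results.items()
}
# averages["all"] == [0.75, 2.0]; averages["broad"] == [0.75, 2.0]; averages["keyword"] == [0.5, 2.0]
```
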
@@ -1,16 +1,16 @@
# Search Parameters, null means use default
HYBRID_ALPHA: null
HYBRID_ALPHA_KEYWORD: null
DOC_TIME_DECAY: null
NUM_RETURNED_HITS: 200 # Setting to a higher value will improve evaluation quality but increase reranking time
RANK_PROFILE: null
OFFSET: null
TITLE_CONTENT_RATIO: null
USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files
# Search Parameters
HYBRID_ALPHA: 0.5
HYBRID_ALPHA_KEYWORD: 0.4
DOC_TIME_DECAY: 0.5
NUM_RETURNED_HITS: 50 # Setting to a higher value will improve evaluation quality but increase reranking time
RANK_PROFILE: 'semantic'
OFFSET: 0
TITLE_CONTENT_RATIO: 0.1
USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files

# Evaluation parameters
SKIP_RERANK: false # Whether to skip reranking; reranking must be enabled to evaluate the search results
EVAL_TOPK: 20 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation
EVAL_TOPK: 5 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation

# Export file, will export a csv file with the results and a json file with the parameters
EXPORT_FOLDER: "eval-%Y-%m-%d-%H-%M-%S"
@@ -1,4 +0,0 @@
[
    "What is Onyx?",
    "How is Onyx enterprise edition different?"
]
@@ -0,0 +1,22 @@
[
    {
        "question": "What is Onyx?",
        "ground_truth": [
            {
                "doc_source": "Web",
                "doc_link": "https://docs.onyx.app/more/use_cases/overview"
            },
            {
                "doc_source": "Web",
                "doc_link": "https://docs.onyx.app/more/use_cases/ai_platform"
            }
        ],
        "categories": [
            "keyword",
            "broad"
        ]
    },
    {
        "question": "What is the meaning of life?"
    }
]
backend/tests/regression/search_quality/util_config.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from datetime import datetime
from pathlib import Path

import yaml
from pydantic import BaseModel

from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import HYBRID_ALPHA
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.utils.logger import setup_logger

logger = setup_logger(__name__)


class SearchEvalConfig(BaseModel):
    hybrid_alpha: float
    hybrid_alpha_keyword: float
    doc_time_decay: float
    num_returned_hits: int
    rank_profile: QueryExpansionType
    offset: int
    title_content_ratio: float
    user_email: str | None
    skip_rerank: bool
    eval_topk: int
    export_folder: str


def load_config() -> SearchEvalConfig:
    """Loads the search evaluation configs from the config file."""
    # open the config file
    current_dir = Path(__file__).parent
    config_path = current_dir / "search_eval_config.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Search eval config file not found at {config_path}")
    with config_path.open("r") as file:
        config_raw = yaml.safe_load(file)

    # create the export folder
    export_folder = config_raw.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S")
    export_folder = datetime.now().strftime(export_folder)
    export_path = Path(export_folder)
    export_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Created export folder: {export_path}")

    # create the config
    config = SearchEvalConfig(
        hybrid_alpha=config_raw.get("HYBRID_ALPHA", HYBRID_ALPHA),
        hybrid_alpha_keyword=config_raw.get(
            "HYBRID_ALPHA_KEYWORD", HYBRID_ALPHA_KEYWORD
        ),
        doc_time_decay=config_raw.get("DOC_TIME_DECAY", DOC_TIME_DECAY),
        num_returned_hits=config_raw.get("NUM_RETURNED_HITS", NUM_RETURNED_HITS),
        rank_profile=config_raw.get("RANK_PROFILE", QueryExpansionType.SEMANTIC),
        offset=config_raw.get("OFFSET", 0),
        title_content_ratio=config_raw.get("TITLE_CONTENT_RATIO", TITLE_CONTENT_RATIO),
        user_email=config_raw.get("USER_EMAIL"),
        skip_rerank=config_raw.get("SKIP_RERANK", False),
        eval_topk=config_raw.get("EVAL_TOPK", 5),
        export_folder=export_folder,
    )
    logger.info(f"Using search parameters: {config}")

    # export the config
    config_file = export_path / "search_eval_config.yaml"
    with config_file.open("w") as file:
        config_dict = config.model_dump(mode="python")
        config_dict["rank_profile"] = config.rank_profile.value
        yaml.dump(config_dict, file, sort_keys=False)
    logger.info(f"Exported config to {config_file}")

    return config
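As a quick sanity check, `load_config` can be exercised on its own; this assumes the backend environment is set up and a `search_eval_config.yaml` exists next to the script, and the printed fields are arbitrary examples.

```
# Illustrative usage of load_config(); assumes search_eval_config.yaml exists
# next to this script and the backend environment is active.
from tests.regression.search_quality.util_config import load_config

config = load_config()
print(config.eval_topk, config.skip_rerank, config.export_folder)
```
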
backend/tests/regression/search_quality/util_data.py (new file, 166 lines)
@@ -0,0 +1,166 @@
import json
from pathlib import Path
from typing import cast
from typing import Optional

from langgraph.types import StreamWriter
from pydantic import BaseModel
from pydantic import ValidationError

from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.persona import get_persona_by_id
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger

logger = setup_logger()


class GroundTruth(BaseModel):
    doc_source: str
    doc_link: str


class TestQuery(BaseModel):
    question: str
    question_search: Optional[str] = None
    ground_truth: list[GroundTruth] = []
    categories: list[str] = []


def load_test_queries() -> list[TestQuery]:
    """
    Loads the test queries from the test_queries.json file.
    If `question_search` is missing, it will use the tool-calling LLM to generate it.
    """
    # open test queries file
    current_dir = Path(__file__).parent
    test_queries_path = current_dir / "test_queries.json"
    logger.info(f"Loading test queries from {test_queries_path}")
    if not test_queries_path.exists():
        raise FileNotFoundError(f"Test queries file not found at {test_queries_path}")
    with test_queries_path.open("r") as f:
        test_queries_raw: list[dict] = json.load(f)

    # setup llm for question_search generation
    with get_session_with_current_tenant() as db_session:
        persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
        llm, _ = get_llms_for_persona(persona)
        prompt_config = PromptConfig.from_model(persona.prompts[0])
        search_tool = SearchToolOverride()

    tool_call_supported = explicit_tool_calling_supported(
        llm.config.model_provider, llm.config.model_name
    )

    # validate keys and generate question_search if missing
    test_queries: list[TestQuery] = []
    for query_raw in test_queries_raw:
        try:
            test_query = TestQuery(**query_raw)
        except ValidationError as e:
            logger.error(f"Incorrectly formatted query: {e}")
            continue

        if test_query.question_search is None:
            test_query.question_search = _modify_one_query(
                query=test_query.question,
                llm=llm,
                prompt_config=prompt_config,
                tool=search_tool,
                tool_call_supported=tool_call_supported,
            )
        test_queries.append(test_query)

    return test_queries


def export_test_queries(test_queries: list[TestQuery], export_path: Path) -> None:
    """Exports the test queries to a JSON file."""
    logger.info(f"Exporting test queries to {export_path}")
    with export_path.open("w") as f:
        json.dump(
            [query.model_dump() for query in test_queries],
            f,
            indent=4,
        )


class SearchToolOverride(SearchTool):
    def __init__(self) -> None:
        # do nothing, only class variables are required for the functions we call
        pass


warned = False


def _modify_one_query(
    query: str,
    llm: LLM,
    prompt_config: PromptConfig,
    tool: SearchTool,
    tool_call_supported: bool,
    writer: StreamWriter = lambda _: None,
) -> str:
    global warned
    if not warned:
        logger.warning(
            "Generating question_search. If you do not save the question_search, "
            "it will be generated again on the next run, potentially altering the search results."
        )
        warned = True

    prompt_builder = AnswerPromptBuilder(
        user_message=default_build_user_message(
            user_query=query,
            prompt_config=prompt_config,
            files=[],
            single_message_history=None,
        ),
        system_message=default_build_system_message(prompt_config, llm.config),
        message_history=[],
        llm_config=llm.config,
        raw_user_query=query,
        raw_user_uploaded_files=[],
        single_message_history=None,
    )

    if tool_call_supported:
        prompt = prompt_builder.build()
        tool_definition = tool.tool_definition()
        stream = llm.stream(
            prompt=prompt,
            tools=[tool_definition],
            tool_choice="required",
            structured_response_format=None,
        )
        tool_message = process_llm_stream(
            messages=stream,
            should_stream_answer=False,
            writer=writer,
        )
        return (
            tool_message.tool_calls[0]["args"]["query"]
            if tool_message.tool_calls
            else query
        )

    history = prompt_builder.get_message_history()
    return cast(
        dict[str, str],
        tool.get_args_for_non_tool_calling_llm(
            query=query,
            history=history,
            llm=llm,
            force_run=True,
        ),
    )["query"]
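Since `load_test_queries` skips malformed entries with only a log message, it can help to validate an entry against the `TestQuery` model directly; a small sketch with hypothetical values:

```
# Quick sanity check of a test query against the TestQuery model (hypothetical values).
from pydantic import ValidationError

from tests.regression.search_quality.util_data import TestQuery

ok = TestQuery(question="What is Onyx?", categories=["broad"])
assert ok.question_search is None and ok.ground_truth == []

try:
    TestQuery(categories=["keyword"])  # missing the required "question" field
except ValidationError as e:
    print(e)
```
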
backend/tests/regression/search_quality/util_eval.py (new file, 94 lines)
@@ -0,0 +1,94 @@
from typing import Optional

from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.context.search.models import InferenceChunk
from onyx.db.models import Document
from onyx.utils.logger import setup_logger
from tests.regression.search_quality.util_retrieve import group_by_documents

logger = setup_logger(__name__)


class Metrics(BaseModel):
    # computed if ground truth is provided
    ground_truth_ratio_topk: Optional[float] = None
    ground_truth_avg_rank_delta: Optional[float] = None

    # computed if reranked results are provided
    soft_truth_ratio_topk: Optional[float] = None
    soft_truth_avg_rank_delta: Optional[float] = None


metric_names = list(Metrics.model_fields.keys())


def get_corresponding_document(
    doc_link: str, db_session: Session
) -> Optional[Document]:
    """Get the corresponding document from the database."""
    doc_filter = db_session.query(Document).filter(Document.link == doc_link)
    count = doc_filter.count()
    if count == 0:
        logger.warning(f"Could not find document with link {doc_link}, ignoring")
        return None
    if count > 1:
        logger.warning(f"Found multiple documents with link {doc_link}, using first")
    return doc_filter.first()


def evaluate_one_query(
    search_chunks: list[InferenceChunk],
    rerank_chunks: list[InferenceChunk],
    true_documents: list[Document],
    topk: int,
) -> Metrics:
    """Computes metrics for the search results, relative to the ground truth and reranked results."""
    metrics_dict: dict[str, float] = {}

    search_documents = group_by_documents(search_chunks)
    search_ranks = {docid: rank for rank, docid in enumerate(search_documents)}
    search_ranks_topk = {
        docid: rank for rank, docid in enumerate(search_documents[:topk])
    }
    true_ranks = {doc.id: rank for rank, doc in enumerate(true_documents)}

    if true_documents:
        metrics_dict["ground_truth_ratio_topk"] = _compute_ratio(
            search_ranks_topk, true_ranks
        )
        metrics_dict["ground_truth_avg_rank_delta"] = _compute_avg_rank_delta(
            search_ranks, true_ranks
        )

    if rerank_chunks:
        # build soft truth out of ground truth + reranked results, up to topk
        soft_ranks = true_ranks
        for docid in group_by_documents(rerank_chunks):
            if len(soft_ranks) >= topk:
                break
            if docid not in soft_ranks:
                soft_ranks[docid] = len(soft_ranks)

        metrics_dict["soft_truth_ratio_topk"] = _compute_ratio(
            search_ranks_topk, soft_ranks
        )
        metrics_dict["soft_truth_avg_rank_delta"] = _compute_avg_rank_delta(
            search_ranks, soft_ranks
        )

    return Metrics(**metrics_dict)


def _compute_ratio(search_ranks: dict[str, int], true_ranks: dict[str, int]) -> float:
    return len(set(search_ranks) & set(true_ranks)) / len(true_ranks)


def _compute_avg_rank_delta(
    search_ranks: dict[str, int], true_ranks: dict[str, int]
) -> float:
    out = len(search_ranks)
    return sum(
        abs(search_ranks.get(docid, out) - rank) for docid, rank in true_ranks.items()
    ) / len(true_ranks)
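The `metric_names` list drives the CSV output in `run_search_eval.py`: each row is built by reading these fields off a `Metrics` instance, with `None` for metrics that were not computed. A small illustration (values are hypothetical):

```
# How a Metrics instance maps onto a CSV row (values are hypothetical).
from tests.regression.search_quality.util_eval import Metrics, metric_names

m = Metrics(ground_truth_ratio_topk=0.5, ground_truth_avg_rank_delta=2.0)
row = [getattr(m, field) for field in metric_names]
# row == [0.5, 2.0, None, None]  (soft-truth metrics were not computed)
```
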
backend/tests/regression/search_quality/util_retrieve.py (new file, 88 lines)
@@ -0,0 +1,88 @@
from sqlalchemy.orm import Session

from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import semantic_reranking
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.context.search.utils import remove_stop_words_and_punctuation
from onyx.document_index.interfaces import DocumentIndex
from onyx.utils.logger import setup_logger
from tests.regression.search_quality.util_config import SearchEvalConfig

logger = setup_logger(__name__)


def search_one_query(
    question_search: str,
    multilingual_expansion: list[str],
    document_index: DocumentIndex,
    db_session: Session,
    config: SearchEvalConfig,
) -> list[InferenceChunk]:
    """Uses the search pipeline to retrieve relevant chunks for the given query."""
    # the retrieval preprocessing is fairly stripped down so the query doesn't unexpectedly change
    query_embedding = get_query_embedding(question_search, db_session)

    all_query_terms = question_search.split()
    processed_keywords = (
        remove_stop_words_and_punctuation(all_query_terms)
        if not multilingual_expansion
        else all_query_terms
    )

    is_keyword = query_analysis(question_search)[0]
    hybrid_alpha = config.hybrid_alpha_keyword if is_keyword else config.hybrid_alpha

    access_control_list = ["PUBLIC"]
    if config.user_email:
        access_control_list.append(f"user_email:{config.user_email}")
    filters = IndexFilters(
        tags=[],
        user_file_ids=[],
        user_folder_ids=[],
        access_control_list=access_control_list,
        tenant_id=None,
    )

    results = document_index.hybrid_retrieval(
        query=question_search,
        query_embedding=query_embedding,
        final_keywords=processed_keywords,
        filters=filters,
        hybrid_alpha=hybrid_alpha,
        time_decay_multiplier=config.doc_time_decay,
        num_to_retrieve=config.num_returned_hits,
        ranking_profile_type=config.rank_profile,
        offset=config.offset,
        title_content_ratio=config.title_content_ratio,
    )

    return [result.to_inference_chunk() for result in results]


def rerank_one_query(
    question: str,
    retrieved_chunks: list[InferenceChunk],
    rerank_settings: RerankingDetails,
) -> list[InferenceChunk]:
    """Uses the reranker to rerank the retrieved chunks for the given query."""
    rerank_settings.num_rerank = len(retrieved_chunks)
    return semantic_reranking(
        query_str=question,
        rerank_settings=rerank_settings,
        chunks=retrieved_chunks,
        rerank_metrics_callback=None,
    )[0]


def group_by_documents(chunks: list[InferenceChunk]) -> list[str]:
    """Groups a sorted list of chunks into a sorted list of document ids."""
    seen_docids: set[str] = set()
    retrieved_docids: list[str] = []
    for chunk in chunks:
        if chunk.document_id not in seen_docids:
            seen_docids.add(chunk.document_id)
            retrieved_docids.append(chunk.document_id)
    return retrieved_docids
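`group_by_documents` simply keeps the first occurrence of each document id, preserving rank order; a tiny illustration with stubbed chunk objects (not real `InferenceChunk`s):

```
# group_by_documents keeps the first occurrence of each document, preserving rank order.
# Chunk objects are stubbed here for illustration only.
from types import SimpleNamespace

from tests.regression.search_quality.util_retrieve import group_by_documents

chunks = [SimpleNamespace(document_id=d) for d in ["doc_a", "doc_a", "doc_b", "doc_a", "doc_c"]]
assert group_by_documents(chunks) == ["doc_a", "doc_b", "doc_c"]
```
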