From 30d9ce131011aa5f91228967479ce82823babc20 Mon Sep 17 00:00:00 2001 From: Rei Meguro <36625832+Orbital-Web@users.noreply.github.com> Date: Thu, 15 May 2025 16:44:33 -0700 Subject: [PATCH] feat: search quality eval (#4720) * fix: import order * test examples * fix: import * wip: reranker based eval * fix: import order * feat: adjuted score * fix: mypy * fix: suggestions * sorry cvs, you must go Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * fix: mypy * fix: suggestions --------- Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- backend/onyx/context/search/utils.py | 6 +- .../tests/regression/search_quality/README.md | 56 +++ .../search_quality/generate_search_queries.py | 124 ++++++ .../search_quality/run_search_eval.py | 353 ++++++++++++++++++ .../search_eval_config.yaml.template | 16 + .../search_queries.json.template | 4 + 6 files changed, 558 insertions(+), 1 deletion(-) create mode 100644 backend/tests/regression/search_quality/README.md create mode 100644 backend/tests/regression/search_quality/generate_search_queries.py create mode 100644 backend/tests/regression/search_quality/run_search_eval.py create mode 100644 backend/tests/regression/search_quality/search_eval_config.yaml.template create mode 100644 backend/tests/regression/search_quality/search_queries.json.template diff --git a/backend/onyx/context/search/utils.py b/backend/onyx/context/search/utils.py index 22e6b0f8df6..b0fc7c7e2b7 100644 --- a/backend/onyx/context/search/utils.py +++ b/backend/onyx/context/search/utils.py @@ -12,6 +12,9 @@ from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SavedSearchDocWithContent from onyx.context.search.models import SearchDoc from onyx.db.models import SearchDoc as DBSearchDoc +from onyx.utils.logger import setup_logger + +logger = setup_logger() T = TypeVar( @@ -154,5 +157,6 @@ def remove_stop_words_and_punctuation(keywords: 
list[str]) -> list[str]: if (word.casefold() not in stop_words and word not in string.punctuation) ] return text_trimmed or word_tokens - except Exception: + except Exception as e: + logger.warning(f"Error removing stop words and punctuation: {e}") return keywords diff --git a/backend/tests/regression/search_quality/README.md b/backend/tests/regression/search_quality/README.md new file mode 100644 index 00000000000..edf44f745f7 --- /dev/null +++ b/backend/tests/regression/search_quality/README.md @@ -0,0 +1,56 @@ +# Search Quality Test Script + +This Python script evaluates the search results for a list of queries. + +Unlike the script in answer_quality, this script is much less customizable and runs using currently ingested documents, though it allows for quick testing of search parameters on a bunch of test queries that don't have well-defined answers. + +## Usage + +1. Ensure you have the required dependencies installed and onyx running. + +2. Ensure a reranker model is configured in the search settings. +This can be checked/modified by opening the admin panel, going to search settings, and ensuring a reranking model is set. + +3. Set up the PYTHONPATH permanently: + Add the following line to your shell configuration file (e.g., `~/.bashrc`, `~/.zshrc`, or `~/.bash_profile`): + ``` + export PYTHONPATH=$PYTHONPATH:/path/to/onyx/backend + ``` + Replace `/path/to/onyx` with the actual path to your Onyx repository. + After adding this line, restart your terminal or run `source ~/.bashrc` (or the appropriate config file) to apply the changes. + +4. Navigate to Onyx repo, search_quality folder: + +``` +cd path/to/onyx/backend/tests/regression/search_quality +``` + +5. Copy `search_queries.json.template` to `search_queries.json` and add/remove test queries in it + +6. Run `generate_search_queries.py` to generate the modified queries for the search pipeline + +``` +python generate_search_queries.py +``` + +7. 
Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters +8. Run `run_search_eval.py` to evaluate the search results against the reranked results + +``` +python run_search_eval.py +``` + +9. Repeat steps 7 and 8 to test and compare different search parameters + +## Metrics +- Jaccard Similarity: The ratio of the intersection to the union of the topk search and rerank results. Higher is better +- Average Rank Change: The average absolute difference between the rank of each topk reranked chunk and its rank in the full search results. Lower is better +- Average Missing Chunk Ratio: The number of topk reranked chunks that are not in the topk search chunks, divided by the number of topk reranked chunks. Lower is better + +Note that all of these metrics are affected by very narrow search results. +E.g., if topk is 20 but there is only 1 relevant document, the other 19 documents could be ordered arbitrarily, resulting in a worse score. + + +To address this limitation, there are score-adjusted versions of the metrics. +The score-adjusted version does not use a fixed topk, but computes the optimum topk based on the rerank scores (the number of reranked chunks scoring above 50% of the top rerank score). +This generally works in determining how many documents are relevant, although note that this approach isn't perfect. 
\ No newline at end of file diff --git a/backend/tests/regression/search_quality/generate_search_queries.py b/backend/tests/regression/search_quality/generate_search_queries.py new file mode 100644 index 00000000000..ac6fb0cfdd9 --- /dev/null +++ b/backend/tests/regression/search_quality/generate_search_queries.py @@ -0,0 +1,124 @@ +import json +from pathlib import Path + +from langgraph.types import StreamWriter + +from onyx.agents.agent_search.basic.utils import process_llm_stream +from onyx.chat.models import PromptConfig +from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder +from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message +from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message +from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW +from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE +from onyx.configs.constants import DEFAULT_PERSONA_ID +from onyx.db.engine import get_session_with_current_tenant +from onyx.db.engine import SqlEngine +from onyx.db.persona import get_persona_by_id +from onyx.llm.factory import get_llms_for_persona +from onyx.llm.interfaces import LLM +from onyx.tools.tool_implementations.search.search_tool import SearchTool +from onyx.tools.utils import explicit_tool_calling_supported +from onyx.utils.logger import setup_logger + +logger = setup_logger() + + +def _load_queries() -> list[str]: + current_dir = Path(__file__).parent + with open(current_dir / "search_queries.json", "r") as file: + return json.load(file) + + +def _modify_one_query( + query: str, + llm: LLM, + prompt_config: PromptConfig, + tool_definition: dict, + writer: StreamWriter = lambda _: None, +) -> str: + prompt_builder = AnswerPromptBuilder( + user_message=default_build_user_message( + user_query=query, + prompt_config=prompt_config, + files=[], + single_message_history=None, + ), + system_message=default_build_system_message(prompt_config, 
llm.config), + message_history=[], + llm_config=llm.config, + raw_user_query=query, + raw_user_uploaded_files=[], + single_message_history=None, + ) + prompt = prompt_builder.build() + + stream = llm.stream( + prompt=prompt, + tools=[tool_definition], + tool_choice="required", + structured_response_format=None, + ) + tool_message = process_llm_stream( + messages=stream, + should_stream_answer=False, + writer=writer, + ) + return ( + tool_message.tool_calls[0]["args"]["query"] + if tool_message.tool_calls + else query + ) + + +class SearchToolOverride(SearchTool): + def __init__(self) -> None: + # do nothing, the tool_definition function doesn't require variables to be initialized + pass + + +def generate_search_queries() -> None: + SqlEngine.init_engine( + pool_size=POSTGRES_API_SERVER_POOL_SIZE, + max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW, + ) + + queries = _load_queries() + + with get_session_with_current_tenant() as db_session: + persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session) + llm, _ = get_llms_for_persona(persona) + prompt_config = PromptConfig.from_model(persona.prompts[0]) + tool_definition = SearchToolOverride().tool_definition() + + tool_call_supported = explicit_tool_calling_supported( + llm.config.model_provider, llm.config.model_name + ) + + if tool_call_supported: + logger.info( + "Tool calling is supported for the current model. Modifying queries." + ) + modified_queries = [ + _modify_one_query( + query=query, + llm=llm, + prompt_config=prompt_config, + tool_definition=tool_definition, + ) + for query in queries + ] + else: + logger.warning( + "Tool calling is not supported for the current model. " + "Using the original queries." 
+ ) + modified_queries = queries + + with open("search_queries_modified.json", "w") as file: + json.dump(modified_queries, file, indent=4) + + logger.info("Exported modified queries to search_queries_modified.json") + + +if __name__ == "__main__": + generate_search_queries() diff --git a/backend/tests/regression/search_quality/run_search_eval.py b/backend/tests/regression/search_quality/run_search_eval.py new file mode 100644 index 00000000000..578c2e84caa --- /dev/null +++ b/backend/tests/regression/search_quality/run_search_eval.py @@ -0,0 +1,353 @@ +import csv +import json +import os +from bisect import bisect_left +from datetime import datetime +from pathlib import Path +from typing import cast + +import yaml +from pydantic import BaseModel +from sqlalchemy.orm import Session + +from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType +from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW +from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE +from onyx.configs.chat_configs import DOC_TIME_DECAY +from onyx.configs.chat_configs import HYBRID_ALPHA +from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD +from onyx.configs.chat_configs import NUM_RETURNED_HITS +from onyx.configs.chat_configs import TITLE_CONTENT_RATIO +from onyx.context.search.models import IndexFilters +from onyx.context.search.models import InferenceChunk +from onyx.context.search.models import RerankingDetails +from onyx.context.search.postprocessing.postprocessing import semantic_reranking +from onyx.context.search.preprocessing.preprocessing import query_analysis +from onyx.context.search.retrieval.search_runner import get_query_embedding +from onyx.context.search.utils import remove_stop_words_and_punctuation +from onyx.db.engine import get_session_with_current_tenant +from onyx.db.engine import SqlEngine +from onyx.db.search_settings import get_current_search_settings +from onyx.db.search_settings import get_multilingual_expansion 
+from onyx.document_index.factory import get_default_document_index +from onyx.document_index.interfaces import DocumentIndex +from onyx.utils.logger import setup_logger + +logger = setup_logger(__name__) + + +class SearchEvalParameters(BaseModel): + hybrid_alpha: float + hybrid_alpha_keyword: float + doc_time_decay: float + num_returned_hits: int + rank_profile: QueryExpansionType + offset: int + title_content_ratio: float + user_email: str | None + skip_rerank: bool + eval_topk: int + export_folder: str + + +def _load_search_parameters() -> SearchEvalParameters: + current_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join(current_dir, "search_eval_config.yaml") + with open(config_path, "r") as file: + config = yaml.safe_load(file) + + export_folder = config.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S") + export_folder = datetime.now().strftime(export_folder) + + export_path = Path(export_folder) + export_path.mkdir(parents=True, exist_ok=True) + logger.info(f"Created export folder: {export_path}") + + search_parameters = SearchEvalParameters( + hybrid_alpha=config.get("HYBRID_ALPHA") or HYBRID_ALPHA, + hybrid_alpha_keyword=config.get("HYBRID_ALPHA_KEYWORD") or HYBRID_ALPHA_KEYWORD, + doc_time_decay=config.get("DOC_TIME_DECAY") or DOC_TIME_DECAY, + num_returned_hits=config.get("NUM_RETURNED_HITS") or NUM_RETURNED_HITS, + rank_profile=config.get("RANK_PROFILE") or QueryExpansionType.SEMANTIC, + offset=config.get("OFFSET") or 0, + title_content_ratio=config.get("TITLE_CONTENT_RATIO") or TITLE_CONTENT_RATIO, + user_email=config.get("USER_EMAIL"), + skip_rerank=config.get("SKIP_RERANK", False), + eval_topk=config.get("EVAL_TOPK", 20), + export_folder=export_folder, + ) + logger.info(f"Using search parameters: {search_parameters}") + + config_file = export_path / "search_eval_config.yaml" + with config_file.open("w") as file: + search_parameters_dict = search_parameters.model_dump(mode="python") + search_parameters_dict["rank_profile"] = 
search_parameters.rank_profile.value + yaml.dump(search_parameters_dict, file, sort_keys=False) + logger.info(f"Exported config to {config_file}") + + return search_parameters + + +def _load_query_pairs() -> list[tuple[str, str]]: + current_dir = Path(__file__).parent + + with open(current_dir / "search_queries.json", "r") as file: + orig_queries = json.load(file) + + with open(current_dir / "search_queries_modified.json", "r") as file: + alt_queries = json.load(file) + + return list(zip(orig_queries, alt_queries)) + + +def _search_one_query( + alt_query: str, + multilingual_expansion: list[str], + document_index: DocumentIndex, + db_session: Session, + search_parameters: SearchEvalParameters, +) -> list[InferenceChunk]: + # the retrieval preprocessing is fairly stripped down so the query doesn't unexpectly change + query_embedding = get_query_embedding(alt_query, db_session) + + all_query_terms = alt_query.split() + processed_keywords = ( + remove_stop_words_and_punctuation(all_query_terms) + if not multilingual_expansion + else all_query_terms + ) + + is_keyword = query_analysis(alt_query)[0] + hybrid_alpha = ( + search_parameters.hybrid_alpha_keyword + if is_keyword + else search_parameters.hybrid_alpha + ) + + access_control_list = ["PUBLIC"] + if search_parameters.user_email: + access_control_list.append(f"user_email:{search_parameters.user_email}") + filters = IndexFilters( + tags=[], + user_file_ids=[], + user_folder_ids=[], + access_control_list=access_control_list, + tenant_id=None, + ) + + results = document_index.hybrid_retrieval( + query=alt_query, + query_embedding=query_embedding, + final_keywords=processed_keywords, + filters=filters, + hybrid_alpha=hybrid_alpha, + time_decay_multiplier=search_parameters.doc_time_decay, + num_to_retrieve=search_parameters.num_returned_hits, + ranking_profile_type=search_parameters.rank_profile, + offset=search_parameters.offset, + title_content_ratio=search_parameters.title_content_ratio, + ) + + return 
[result.to_inference_chunk() for result in results] + + +def _rerank_one_query( + orig_query: str, + retrieved_chunks: list[InferenceChunk], + rerank_settings: RerankingDetails, + search_parameters: SearchEvalParameters, +) -> list[InferenceChunk]: + assert not search_parameters.skip_rerank, "Reranking is disabled" + return semantic_reranking( + query_str=orig_query, + rerank_settings=rerank_settings, + chunks=retrieved_chunks, + rerank_metrics_callback=None, + )[0] + + +def _evaluate_one_query( + search_results: list[InferenceChunk], + rerank_results: list[InferenceChunk], + search_parameters: SearchEvalParameters, +) -> list[float]: + search_topk = search_results[: search_parameters.eval_topk] + rerank_topk = rerank_results[: search_parameters.eval_topk] + + # get the score adjusted topk (topk where the score is at least 50% of the top score) + # could be more than topk if top scores are similar, may or may not be a good thing + # can change by swapping rerank_results with rerank_topk in bisect + adj_topk = bisect_left( + rerank_results, + -0.5 * cast(float, rerank_results[0].score), + key=lambda x: -cast(float, x.score), + ) + search_adj_topk = search_results[:adj_topk] + rerank_adj_topk = rerank_results[:adj_topk] + + # compute metrics + search_ranks = {chunk.unique_id: rank for rank, chunk in enumerate(search_results)} + return [ + _compute_jaccard_similarity(search_topk, rerank_topk), + _compute_average_rank_change(search_ranks, rerank_topk), + _compute_average_missing_chunk_ratio(search_topk, rerank_topk), + # score adjusted metrics + _compute_jaccard_similarity(search_adj_topk, rerank_adj_topk), + _compute_average_rank_change(search_ranks, rerank_adj_topk), + _compute_average_missing_chunk_ratio(search_adj_topk, rerank_adj_topk), + ] + + +def _compute_jaccard_similarity( + search_topk: list[InferenceChunk], rerank_topk: list[InferenceChunk] +) -> float: + search_chunkids = {chunk.unique_id for chunk in search_topk} + rerank_chunkids = {chunk.unique_id for 
chunk in rerank_topk} + return len(search_chunkids.intersection(rerank_chunkids)) / len( + search_chunkids.union(rerank_chunkids) + ) + + +def _compute_average_rank_change( + search_ranks: dict[str, int], rerank_topk: list[InferenceChunk] +) -> float: + rank_changes = [ + abs(search_ranks[chunk.unique_id] - rerank_rank) + for rerank_rank, chunk in enumerate(rerank_topk) + ] + return sum(rank_changes) / len(rank_changes) + + +def _compute_average_missing_chunk_ratio( + search_topk: list[InferenceChunk], rerank_topk: list[InferenceChunk] +) -> float: + search_chunkids = {chunk.unique_id for chunk in search_topk} + rerank_chunkids = {chunk.unique_id for chunk in rerank_topk} + return len(rerank_chunkids.difference(search_chunkids)) / len(rerank_chunkids) + + +def run_search_eval() -> None: + SqlEngine.init_engine( + pool_size=POSTGRES_API_SERVER_POOL_SIZE, + max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW, + ) + + search_parameters = _load_search_parameters() + query_pairs = _load_query_pairs() + + with get_session_with_current_tenant() as db_session: + multilingual_expansion = get_multilingual_expansion(db_session) + search_settings = get_current_search_settings(db_session) + document_index = get_default_document_index(search_settings, None) + rerank_settings = RerankingDetails.from_db_model(search_settings) + + if search_parameters.skip_rerank: + logger.warning("Reranking is disabled, evaluation will not run") + elif rerank_settings.rerank_model_name is None: + raise ValueError( + "Reranking is enabled but no reranker is configured. " + "Please set the reranker in the admin panel search settings." 
+ ) + + export_path = Path(search_parameters.export_folder) + search_result_file = export_path / "search_results.csv" + eval_result_file = export_path / "eval_results.csv" + with ( + search_result_file.open("w") as search_file, + eval_result_file.open("w") as eval_file, + ): + search_csv_writer = csv.writer(search_file) + eval_csv_writer = csv.writer(eval_file) + search_csv_writer.writerow( + ["source", "query", "rank", "score", "doc_id", "chunk_id"] + ) + eval_csv_writer.writerow( + [ + "query", + "jaccard_similarity", + "average_rank_change", + "missing_chunks_ratio", + "jaccard_similarity_adj", + "average_rank_change_adj", + "missing_chunks_ratio_adj", + ] + ) + + sum_metrics = [0.0] * 6 + for orig_query, alt_query in query_pairs: + search_results = _search_one_query( + alt_query, + multilingual_expansion, + document_index, + db_session, + search_parameters, + ) + for rank, result in enumerate(search_results): + search_csv_writer.writerow( + [ + "search", + alt_query, + rank, + result.score, + result.document_id, + result.chunk_id, + ] + ) + + if not search_parameters.skip_rerank: + rerank_results = _rerank_one_query( + orig_query, search_results, rerank_settings, search_parameters + ) + for rank, result in enumerate(rerank_results): + search_csv_writer.writerow( + [ + "rerank", + orig_query, + rank, + result.score, + result.document_id, + result.chunk_id, + ] + ) + + metrics = _evaluate_one_query( + search_results, rerank_results, search_parameters + ) + eval_csv_writer.writerow([orig_query, *metrics]) + sum_metrics = [ + sum_metric + metric + for sum_metric, metric in zip(sum_metrics, metrics) + ] + + logger.info( + f"Exported individual results to {search_result_file} and {eval_result_file}" + ) + + if not search_parameters.skip_rerank: + average_metrics = [metric / len(query_pairs) for metric in sum_metrics] + logger.info(f"Jaccard similarity: {average_metrics[0]}") + logger.info(f"Average rank change: {average_metrics[1]}") + logger.info(f"Average missing 
chunks ratio: {average_metrics[2]}") + logger.info(f"Jaccard similarity (adjusted): {average_metrics[3]}") + logger.info(f"Average rank change (adjusted): {average_metrics[4]}") + logger.info(f"Average missing chunks ratio (adjusted): {average_metrics[5]}") + + aggregate_file = export_path / "aggregate_results.csv" + with aggregate_file.open("w") as file: + aggregate_csv_writer = csv.writer(file) + aggregate_csv_writer.writerow( + [ + "jaccard_similarity", + "average_rank_change", + "missing_chunks_ratio", + "jaccard_similarity_adj", + "average_rank_change_adj", + "missing_chunks_ratio_adj", + ] + ) + aggregate_csv_writer.writerow(average_metrics) + logger.info(f"Exported aggregate results to {aggregate_file}") + + +if __name__ == "__main__": + run_search_eval() diff --git a/backend/tests/regression/search_quality/search_eval_config.yaml.template b/backend/tests/regression/search_quality/search_eval_config.yaml.template new file mode 100644 index 00000000000..14447ec5d6b --- /dev/null +++ b/backend/tests/regression/search_quality/search_eval_config.yaml.template @@ -0,0 +1,16 @@ +# Search Parameters, null means use default +HYBRID_ALPHA: null +HYBRID_ALPHA_KEYWORD: null +DOC_TIME_DECAY: null +NUM_RETURNED_HITS: 200 # Setting to a higher value will improve evaluation quality but increase reranking time +RANK_PROFILE: null +OFFSET: null +TITLE_CONTENT_RATIO: null +USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files + +# Evaluation parameters +SKIP_RERANK: false # Whether to skip reranking, reranking must be enabled to evaluate the search results +EVAL_TOPK: 20 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation + +# Export folder, will contain csv files with the results and a yaml file with the parameters +EXPORT_FOLDER: "eval-%Y-%m-%d-%H-%M-%S" diff --git a/backend/tests/regression/search_quality/search_queries.json.template 
b/backend/tests/regression/search_quality/search_queries.json.template new file mode 100644 index 00000000000..bb31ef52c07 --- /dev/null +++ b/backend/tests/regression/search_quality/search_queries.json.template @@ -0,0 +1,4 @@ +[ + "What is Onyx?", + "How is Onyx enterprise edition different?" +] \ No newline at end of file