Mirror of https://github.com/danswer-ai/danswer.git
Added support for multiple Eval Pipeline UIs (#1830)
This commit is contained in:
parent dae4f6a0bd
commit 1b864a00e4
@@ -16,59 +16,75 @@ This Python script automates the process of running search quality tests for a b

 1. Ensure you have the required dependencies installed.
 2. Configure the `search_test_config.yaml` file based on the `search_test_config.yaml.template` file.
-3. Configure the `.env_eval` file with the correct environment variables.
-4. Navigate to the answer_quality folder:
+3. Configure the `.env_eval` file in `deployment/docker_compose` with the correct environment variables.
+4. Navigate to the Danswer repo:
 ```
-cd danswer/backend/tests/regression/answer_quality
+cd path/to/danswer
 ```
-4. Run the script:
+5. Set the PYTHONPATH variable:
 ```
-python search_quality_test.py
+export PYTHONPATH=$PYTHONPATH:$PWD/backend
 ```
+6. Navigate to the answer_quality folder:
+```
+cd backend/tests/regression/answer_quality
+```
+7. Run the script:
+```
+python run_eval_pipeline.py
+```
+
+Note: All data will be saved even after the containers are shut down. There are instructions below for re-launching docker containers using this data.
+
+If you decide to run multiple UIs at the same time, the ports will increment upwards from 3000 (e.g. http://localhost:3001).
+
+To see which port the desired instance is on, look at the ports on the nginx container by running `docker ps` or using Docker Desktop.

 Docker daemon must be running for this to work.

 ## Configuration

 Edit `search_test_config.yaml` to set:

 - output_folder
-  This is the folder where the folders for each test will go
-  These folders will contain the postgres/vespa data as well as the results for each test
+  - This is the folder where the folders for each test will go
+  - These folders will contain the postgres/vespa data as well as the results for each test
 - zipped_documents_file
-  The path to the zip file containing the files you'd like to test against
+  - The path to the zip file containing the files you'd like to test against
 - questions_file
-  The path to the yaml containing the questions you'd like to test with
+  - The path to the yaml containing the questions you'd like to test with
 - branch
-  Set the branch to null if you want it to just use the code as is
+  - Set the branch to null if you want it to just use the code as is
 - clean_up_docker_containers
-  Set this to true to automatically delete all docker containers, networks and volumes after the test
+  - Set this to true to automatically delete all docker containers, networks and volumes after the test
 - launch_web_ui
-  Set this to true if you want to use the UI during/after the testing process
+  - Set this to true if you want to use the UI during/after the testing process
 - use_cloud_gpu
-  Set to true or false depending on if you want to use the remote gpu
-  Only need to set this if use_cloud_gpu is true
+  - Set to true or false depending on whether you want to use the remote gpu
+  - Only need to set this if use_cloud_gpu is true
 - model_server_ip
-  This is the ip of the remote model server
-  Only need to set this if use_cloud_gpu is true
+  - This is the ip of the remote model server
+  - Only need to set this if use_cloud_gpu is true
 - model_server_port
-  This is the port of the remote model server
-  Only need to set this if use_cloud_gpu is true
+  - This is the port of the remote model server
+  - Only need to set this if use_cloud_gpu is true
 - existing_test_suffix
-  Use this if you would like to relaunch a previous test instance
-  Input the suffix of the test you'd like to re-launch
-  (E.g. to use the data from folder "test_1234_5678" put "_1234_5678")
-  No new files will automatically be uploaded
-  Leave empty to run a new test
+  - Use this if you would like to relaunch a previous test instance
+  - Input the suffix of the test you'd like to re-launch
+  - (E.g. to use the data from folder "test_1234_5678" put "_1234_5678")
+  - No new files will automatically be uploaded
+  - Leave empty to run a new test
 - limit
-  Max number of questions you'd like to ask against the dataset
-  Set to null for no limit
+  - Max number of questions you'd like to ask against the dataset
+  - Set to null for no limit
 - llm
-  Fill this out according to the normal LLM seeding
+  - Fill this out according to the normal LLM seeding
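
For illustration only, a filled-in `search_test_config.yaml` using the keys above might look like the dump produced by this sketch. Every value here is a placeholder rather than something taken from the commit, and the `llm` block in particular depends on your normal LLM seeding:

```
import yaml

# Placeholder values only; adjust to your environment. Key names follow the
# Configuration list above and search_test_config.yaml.template.
example_config = {
    "output_folder": "/path/to/eval_output",            # parent folder for per-test folders
    "zipped_documents_file": "/path/to/documents.zip",
    "questions_file": "/path/to/questions.yaml",
    "branch": None,                                      # null = use the checked-out code as is
    "clean_up_docker_containers": True,
    "launch_web_ui": True,
    "use_cloud_gpu": False,
    "model_server_ip": "",                               # only needed if use_cloud_gpu is true
    "model_server_port": "",                             # only needed if use_cloud_gpu is true
    "existing_test_suffix": "",                          # e.g. "_1234_5678" to relaunch; empty = new test
    "limit": None,                                       # null = no limit on questions asked
    "llm": {},                                           # fill out per the normal LLM seeding
}

print(yaml.safe_dump(example_config, sort_keys=False))
```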

-To restart the evaluation using a particular index, set the suffix and turn off clean_up_docker_containers.
-This also will skip running the evaluation questions, in this case, the relari.py script can be run manually.
+## Relaunching From Existing Data
+
+To launch an existing set of containers that has already completed indexing, set the existing_test_suffix variable. This will launch the docker containers mounted on the volumes of the indicated suffix and will not automatically index any documents or run any QA.
+
+Docker daemon must be running for this to work.
+
+Each script can be run individually to upload additional docs or run additional tests.
+Once these containers are launched you can run file_uploader.py or run_qa.py (assuming you have run the steps in the Usage section above).
+- file_uploader.py will upload and index additional zipped files located at the zipped_documents_file path.
+- run_qa.py will ask questions located at the questions_file path against the indexed documents.
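
As a purely illustrative sketch of that relaunch workflow, calling the functions whose imports and signatures appear in the hunks below: the paths and suffix value are placeholders, and the final limit argument is an assumption.

```
# Assumes the containers for a previous run with suffix "_1234_5678" have been
# relaunched via existing_test_suffix, and PYTHONPATH is set as in the Usage steps.
from tests.regression.answer_quality.file_uploader import upload_test_files
from tests.regression.answer_quality.run_qa import run_qa_test_and_save_results

run_suffix = "_1234_5678"  # matches folder "test_1234_5678" from the earlier run

# Upload and index additional zipped documents into the relaunched stack.
upload_test_files("/path/to/more_documents.zip", run_suffix)

# Ask the questions in the YAML file against the already-indexed documents.
run_qa_test_and_save_results(
    "/path/to/questions.yaml",   # questions_file
    "/path/to/output_folder",    # folder to write results into
    run_suffix,
    None,                        # limit (assumed): None/null for no cap on questions
)
```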

@@ -1,5 +1,6 @@
 import json
 import os
+import socket
 import subprocess
 import sys
 from threading import Thread
@@ -108,6 +109,11 @@ def set_env_variables(
     print(f"Set {env_var_name} to: {env_var}")


+def _is_port_in_use(port: int) -> bool:
+    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
+        return s.connect_ex(("localhost", port)) == 0
+
+
 def start_docker_compose(
     run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool
 ) -> None:
@@ -117,12 +123,18 @@ def start_docker_compose(
     command = f"docker compose -f docker-compose.search-testing.yml -p danswer-stack{run_suffix} up -d"
     command += " --build"
     command += " --force-recreate"
-    if not launch_web_ui:
-        command += " --scale web_server=0"
-        command += " --scale nginx=0"
     if use_cloud_gpu:
         command += " --scale indexing_model_server=0"
         command += " --scale inference_model_server=0"
+    if launch_web_ui:
+        web_ui_port = 3000
+        while _is_port_in_use(web_ui_port):
+            web_ui_port += 1
+        print(f"UI will be launched at http://localhost:{web_ui_port}")
+        os.environ["NGINX_PORT"] = str(web_ui_port)
+    else:
+        command += " --scale web_server=0"
+        command += " --scale nginx=0"

     print("Docker Command:\n", command)
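
To make the new control flow above easier to follow, here is a condensed, self-contained restatement of `start_docker_compose` after this change (a sketch, not the repository code verbatim):

```
import os
import socket


def _is_port_in_use(port: int) -> bool:
    # A successful connect on localhost means something is already listening there.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        return s.connect_ex(("localhost", port)) == 0


def build_compose_command(run_suffix: str, launch_web_ui: bool, use_cloud_gpu: bool) -> str:
    command = (
        "docker compose -f docker-compose.search-testing.yml "
        f"-p danswer-stack{run_suffix} up -d --build --force-recreate"
    )
    if use_cloud_gpu:
        # Model servers run remotely, so the local ones are scaled to zero.
        command += " --scale indexing_model_server=0 --scale inference_model_server=0"
    if launch_web_ui:
        # Pick the first free port at or above 3000 so several UIs can run side by side;
        # nginx publishes on whatever NGINX_PORT ends up being.
        web_ui_port = 3000
        while _is_port_in_use(web_ui_port):
            web_ui_port += 1
        print(f"UI will be launched at http://localhost:{web_ui_port}")
        os.environ["NGINX_PORT"] = str(web_ui_port)
    else:
        # No UI requested: skip the web server and nginx entirely.
        command += " --scale web_server=0 --scale nginx=0"
    return command


print(build_compose_command("_1234_5678", launch_web_ui=True, use_cloud_gpu=False))
```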

@@ -1,252 +0,0 @@
import argparse
import builtins
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import TextIO

import yaml
from sqlalchemy.orm import Session

from danswer.chat.models import LLMMetricsContainer
from danswer.configs.constants import MessageType
from danswer.db.engine import get_sqlalchemy_engine
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.models import IndexFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalDetails
from danswer.search.models import RetrievalMetricsContainer
from danswer.utils.callbacks import MetricsHander


engine = get_sqlalchemy_engine()


@contextmanager
def redirect_print_to_file(file: TextIO) -> Any:
    original_print = builtins.print
    builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
    try:
        yield
    finally:
        builtins.print = original_print


def load_yaml(filepath: str) -> dict:
    with open(filepath, "r") as file:
        data = yaml.safe_load(file)
    return data


def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
    words = s.split()

    current_line: list[str] = []
    result_lines: list[str] = []
    current_length = 0
    for word in words:
        if len(word) > max_line_size:
            if current_line:
                result_lines.append(" ".join(current_line))
                current_line = []
                current_length = 0

            result_lines.append(word)
            continue

        if current_length + len(word) + len(current_line) > max_line_size:
            result_lines.append(" ".join(current_line))
            current_line = []
            current_length = 0

        current_line.append(word)
        current_length += len(word)

    if current_line:
        result_lines.append(" ".join(current_line))

    return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)


def get_answer_for_question(
    query: str, db_session: Session
) -> tuple[
    str | None,
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
]:
    filters = IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        tags=None,
        access_control_list=None,
    )

    messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]

    new_message_request = DirectQARequest(
        messages=messages,
        prompt_id=0,
        persona_id=0,
        retrieval_options=RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=True,
            filters=filters,
            enable_auto_detect_filters=False,
        ),
        chain_of_thought=False,
    )

    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()

    answer = get_search_answer(
        query_req=new_message_request,
        user=None,
        max_document_tokens=None,
        max_history_tokens=None,
        db_session=db_session,
        answer_generation_timeout=100,
        enable_reflexion=False,
        bypass_acl=True,
        retrieval_metrics_callback=retrieval_metrics.record_metric,
        rerank_metrics_callback=rerank_metrics.record_metric,
    )

    return (
        answer.answer,
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
    )


def _print_retrieval_metrics(
    metrics_container: RetrievalMetricsContainer, show_all: bool
) -> None:
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Distance Metric: {metric.score}")


def _print_reranking_metrics(
    metrics_container: RerankMetricsContainer, show_all: bool
) -> None:
    # Printing the raw scores as they're more informational than post-norm/boosting
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")


def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None:
    print(f"\tPrompt Tokens: {metrics_container.prompt_tokens}")
    print(f"\tResponse Tokens: {metrics_container.response_tokens}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "regression_yaml",
        type=str,
        help="Path to the Questions YAML file.",
        default="./tests/regression/answer_quality/sample_questions.yaml",
        nargs="?",
    )
    parser.add_argument(
        "--real-time", action="store_true", help="Set to use the real-time flow."
    )
    parser.add_argument(
        "--discard-metrics",
        action="store_true",
        help="Set to not include metrics on search, rerank, and token counts.",
    )
    parser.add_argument(
        "--all-results",
        action="store_true",
        help="Set to not include more than the 10 top sections for search and reranking metrics.",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Path to the output results file.",
        default="./tests/regression/answer_quality/regression_results.txt",
    )
    args = parser.parse_args()

    questions_data = load_yaml(args.regression_yaml)

    with open(args.output, "w") as outfile:
        with redirect_print_to_file(outfile):
            print("Running Question Answering Flow")
            print(
                "Note that running metrics requires tokenizing all "
                "prompts/returns and slightly slows down inference."
            )
            print(
                "Also note that the text embedding model (bi-encoder) currently used is trained for "
                "relative distances, not absolute distances. Therefore cosine similarity values may all be > 0.5 "
                "even for poor matches"
            )

            with Session(engine, expire_on_commit=False) as db_session:
                for sample in questions_data["questions"]:
                    print(
                        f"Running Test for Question {sample['id']}: {sample['question']}"
                    )

                    start_time = datetime.now()
                    (
                        answer,
                        retrieval_metrics,
                        rerank_metrics,
                    ) = get_answer_for_question(sample["question"], db_session)
                    end_time = datetime.now()

                    print(f"====Duration: {end_time - start_time}====")
                    print(f"Question {sample['id']}:")
                    print(f'\t{sample["question"]}')
                    print("\nApproximate Expected Answer:")
                    print(f'\t{sample["expected_answer"]}')
                    print("\nActual Answer:")
                    print(
                        word_wrap(answer)
                        if answer
                        else "\tFailed, either crashed or refused to answer."
                    )
                    if not args.discard_metrics:
                        print("\nRetrieval Metrics:")
                        if retrieval_metrics is None:
                            print("No Retrieval Metrics Available")
                        else:
                            _print_retrieval_metrics(
                                retrieval_metrics, show_all=args.all_results
                            )

                        print("\nReranking Metrics:")
                        if rerank_metrics is None:
                            print("No Reranking Metrics Available")
                        else:
                            _print_reranking_metrics(
                                rerank_metrics, show_all=args.all_results
                            )

                    print("\n\n", flush=True)

@@ -10,7 +10,7 @@ from tests.regression.answer_quality.cli_utils import set_env_variables
 from tests.regression.answer_quality.cli_utils import start_docker_compose
 from tests.regression.answer_quality.cli_utils import switch_to_branch
 from tests.regression.answer_quality.file_uploader import upload_test_files
-from tests.regression.answer_quality.relari import answer_relari_questions
+from tests.regression.answer_quality.run_qa import run_qa_test_and_save_results


 def load_config(config_filename: str) -> SimpleNamespace:
@@ -46,12 +46,12 @@ def main() -> None:
     if not config.existing_test_suffix:
         upload_test_files(config.zipped_documents_file, run_suffix)

-    answer_relari_questions(
+    run_qa_test_and_save_results(
         config.questions_file, relari_output_folder_path, run_suffix, config.limit
     )

-    if config.clean_up_docker_containers:
-        cleanup_docker(run_suffix)
+    if config.clean_up_docker_containers:
+        cleanup_docker(run_suffix)


 if __name__ == "__main__":

@@ -10,7 +10,7 @@ from tests.regression.answer_quality.api_utils import get_answer_from_query
 from tests.regression.answer_quality.cli_utils import get_current_commit_sha


-def _get_and_write_relari_outputs(
+def _process_and_write_query_results(
     samples: list[dict], run_suffix: str, output_file_path: str
 ) -> None:
     while not check_if_query_ready(run_suffix):
@@ -62,7 +62,7 @@ def _read_questions_jsonl(questions_file_path: str) -> list[dict]:
     return questions


-def answer_relari_questions(
+def run_qa_test_and_save_results(
     questions_file_path: str,
     results_folder_path: str,
     run_suffix: str,
@@ -91,7 +91,7 @@ def answer_relari_questions(

     print("saving question results to:", output_file_path)
     _write_metadata_file(run_suffix, metadata_file_path)
-    _get_and_write_relari_outputs(
+    _process_and_write_query_results(
         samples=samples, run_suffix=run_suffix, output_file_path=output_file_path
     )

@@ -110,7 +110,7 @@ def main() -> None:
     else:
         current_output_folder = os.path.join(current_output_folder, "no_defined_suffix")

-    answer_relari_questions(
+    run_qa_test_and_save_results(
         config.questions_file,
         current_output_folder,
         config.existing_test_suffix,

@@ -1,96 +0,0 @@
# This YAML file contains regression questions for Danswer.
# The sources mentioned are the same ones to power the DanswerBot for the community's use
# The regression flow assumes the data from the sources listed are already indexed

metadata:
  version: v0.0.1
  date: 2023-09-10
  sources:
    - name: web
      detail: https://www.danswer.ai/
    - name: web
      detail: https://docs.danswer.dev/
    - name: github issues
      detail: danswer-ai/danswer
    - name: github pull-requests
      detail: danswer-ai/danswer
    - name: slack
      workspace: danswer.slack.com
    - name: file
      detail: Markdown files from Danswer repo

questions:
  - id: 1
    question: "What is Danswer?"
    expected_answer: "Danswer is an open source question-answering system."
    notes: "This comes directly from the docs, the actual answer should be more informative"

  - id: 2
    question: "What is Danswer licensed under?"
    expected_answer: "Danswer is MIT licensed"
    notes: "This info can be found in many places"

  - id: 3
    question: "What are the required variables to set to use GPT-4?"
    expected_answer: "Set the environment variables INTERNAL_MODEL_VERSION=openai-chat-completion and GEN_AI_MODEL_VERSION=gpt-4"
    notes: "Two env vars are must have, the third (the key) is optional"

  - id: 4
    question: "Why might I want to use the deberta model for QnA?"
    expected_answer: "This kind of model can run on CPU and are less likely to produce hallucinations"
    notes: "https://docs.danswer.dev/gen_ai_configs/transformers, this is a pretty hard question"

  - id: 5
    question: "What auth related tokens do I need for BookStack?"
    expected_answer: "You will need the API Token ID and the API Token Secret"
    notes: "https://docs.danswer.dev/connectors/bookstack"

  - id: 6
    question: "ValueError: invalid literal for int() with base 10"
    expected_answer: "This was a bug that was fixed shortly after the issue was filed. Try updating the code."
    notes: "This question is in Github Issue #290"

  - id: 7
    question: "Is there support for knowledge sets or document sets?"
    expected_answer: "This was requested and approved however it is not clear if the feature is implemented yet."
    notes: "This question is in Github Issue #338"

  - id: 8
    question: "nginx returning 502"
    expected_answer: "Google OAuth must be configured for Danswer backend to work. A PR was created to fix it"
    notes: "This question is in Github Issue #260"

  - id: 9
    question: "Why isn't GPT4All enabled by default"
    expected_answer: "There is no recent version of GPT4All that is compatible with M1 Mac."
    notes: "This question is in Github Issue #232 but also mentioned in several other places"

  - id: 10
    question: "Why isn't GPT4All enabled by default"
    expected_answer: "There is no recent version of GPT4All that is compatible with M1 Mac."
    notes: "This question is in Github Issue #232 but also mentioned in several other places"

  - id: 11
    question: "Why are the models warmed up on server start"
    expected_answer: "This ensures that the first indexing isn't really slow."
    notes: "This is in Github PR #333"

  - id: 12
    question: "Why are the models warmed up on server start"
    expected_answer: "This ensures that the first indexing isn't really slow."
    notes: "This is in Github PR #333"

  - id: 13
    question: "What text from the Alation Connector is used to generate the docs?"
    expected_answer: "Articles are used with the body contents. Schemas, Tables, and Columns use Description"
    notes: "This is in Github PR #161"

  - id: 14
    question: "Does Danswer support PDFs in Google Drive?"
    expected_answer: "Yes"
    notes: "This question is in Slack, if the message expires due to using free slack version, the info may be gone as well"

  - id: 15
    question: "I deleted a connector in Danswer but some deleted docs are still showing in search"
    expected_answer: "The issue was fixed via a code change, it should go away after pulling the latest code"
    notes: "This question is in Slack, if the message expires due to using free slack version, the info may be gone as well"

@@ -1,251 +0,0 @@
import argparse
import builtins
import json
from contextlib import contextmanager
from typing import Any
from typing import TextIO

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.llm.answering.prune_and_merge import reorder_sections
from danswer.llm.factory import get_default_llms
from danswer.search.models import InferenceSection
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.utils.callbacks import MetricsHander


engine = get_sqlalchemy_engine()


@contextmanager
def redirect_print_to_file(file: TextIO) -> Any:
    original_print = builtins.print

    def new_print(*args: Any, **kwargs: Any) -> Any:
        kwargs["file"] = file
        original_print(*args, **kwargs)

    builtins.print = new_print

    try:
        yield
    finally:
        builtins.print = original_print


def read_json(file_path: str) -> dict:
    with open(file_path, "r") as file:
        return json.load(file)


def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
    words = s.split()

    current_line: list[str] = []
    result_lines: list[str] = []
    current_length = 0
    for word in words:
        if len(word) > max_line_size:
            if current_line:
                result_lines.append(" ".join(current_line))
                current_line = []
                current_length = 0

            result_lines.append(word)
            continue

        if current_length + len(word) + len(current_line) > max_line_size:
            result_lines.append(" ".join(current_line))
            current_line = []
            current_length = 0

        current_line.append(word)
        current_length += len(word)

    if current_line:
        result_lines.append(" ".join(current_line))

    return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)


def get_search_results(
    query: str,
) -> tuple[
    list[InferenceSection],
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
]:
    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()

    with Session(get_sqlalchemy_engine()) as db_session:
        llm, fast_llm = get_default_llms()
        search_pipeline = SearchPipeline(
            search_request=SearchRequest(
                query=query,
            ),
            user=None,
            llm=llm,
            fast_llm=fast_llm,
            db_session=db_session,
            retrieval_metrics_callback=retrieval_metrics.record_metric,
            rerank_metrics_callback=rerank_metrics.record_metric,
        )

        top_sections = search_pipeline.reranked_sections
        llm_section_selection = search_pipeline.section_relevance_list

    return (
        reorder_sections(top_sections, llm_section_selection),
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
    )


def _print_retrieval_metrics(
    metrics_container: RetrievalMetricsContainer, show_all: bool = False
) -> None:
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Distance Metric: {metric.score}")


def _print_reranking_metrics(
    metrics_container: RerankMetricsContainer, show_all: bool = False
) -> None:
    # Printing the raw scores as they're more informational than post-norm/boosting
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")


def calculate_score(
    log_prefix: str, chunk_ids: list[str], targets: list[str], max_chunks: int = 5
) -> float:
    top_ids = chunk_ids[:max_chunks]
    matches = [top_id for top_id in top_ids if top_id in targets]
    print(f"{log_prefix} Hits: {len(matches)}/{len(targets)}", end="\t")
    return len(matches) / min(len(targets), max_chunks)


def main(
    questions_json: str,
    output_file: str,
    show_details: bool,
    enable_llm: bool,
    stop_after: int,
) -> None:
    questions_info = read_json(questions_json)

    running_retrieval_score = 0.0
    running_rerank_score = 0.0
    running_llm_filter_score = 0.0

    with open(output_file, "w") as outfile:
        with redirect_print_to_file(outfile):
            print("Running Document Retrieval Test\n")
            for ind, (question, targets) in enumerate(questions_info.items()):
                if ind >= stop_after:
                    break

                print(f"\n\nQuestion: {question}")

                (
                    top_sections,
                    retrieval_metrics,
                    rerank_metrics,
                ) = get_search_results(query=question)

                assert retrieval_metrics is not None and rerank_metrics is not None

                retrieval_ids = [
                    metric.document_id for metric in retrieval_metrics.metrics
                ]
                retrieval_score = calculate_score("Retrieval", retrieval_ids, targets)
                running_retrieval_score += retrieval_score
                print(f"Average: {running_retrieval_score / (ind + 1)}")

                rerank_ids = [metric.document_id for metric in rerank_metrics.metrics]
                rerank_score = calculate_score("Rerank", rerank_ids, targets)
                running_rerank_score += rerank_score
                print(f"Average: {running_rerank_score / (ind + 1)}")

                llm_ids = [section.center_chunk.document_id for section in top_sections]
                llm_score = calculate_score("LLM Filter", llm_ids, targets)
                running_llm_filter_score += llm_score
                print(f"Average: {running_llm_filter_score / (ind + 1)}")

                if show_details:
                    print("\nRetrieval Metrics:")
                    if retrieval_metrics is None:
                        print("No Retrieval Metrics Available")
                    else:
                        _print_retrieval_metrics(retrieval_metrics)

                    print("\nReranking Metrics:")
                    if rerank_metrics is None:
                        print("No Reranking Metrics Available")
                    else:
                        _print_reranking_metrics(rerank_metrics)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "regression_questions_json",
        type=str,
        help="Path to the Questions JSON file.",
        default="./tests/regression/search_quality/test_questions.json",
        nargs="?",
    )
    parser.add_argument(
        "--output_file",
        type=str,
        help="Path to the output results file.",
        default="./tests/regression/search_quality/regression_results.txt",
    )
    parser.add_argument(
        "--show_details",
        action="store_true",
        help="If set, show details of the retrieved chunks.",
        default=False,
    )
    parser.add_argument(
        "--enable_llm",
        action="store_true",
        help="If set, use LLM chunk filtering (this can get very expensive).",
        default=False,
    )
    parser.add_argument(
        "--stop_after",
        type=int,
        help="Stop processing after this many iterations.",
        default=100,
    )
    args = parser.parse_args()

    main(
        args.regression_questions_json,
        args.output_file,
        args.show_details,
        args.enable_llm,
        args.stop_after,
    )

@@ -179,9 +179,8 @@ services:
       - web_server
     environment:
       - DOMAIN=localhost
-    ports:
-      - "80:80"
-      - "3000:80"  # allow for localhost:3000 usage, since that is the norm
+    ports:
+      - "${NGINX_PORT:-3000}:80"  # allow for localhost:3000 usage, since that is the norm
     volumes:
       - ../data/nginx:/etc/nginx/conf.d
     logging:
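
The `${NGINX_PORT:-3000}` mapping is what lets the value exported by `start_docker_compose` (see the cli_utils.py hunk above) choose the published UI port, with 3000 as the fallback when the variable is unset. As a rough sketch of bringing up an extra stack by hand under that scheme (the project suffix and port here are made up, and the command would need to run from the directory containing `docker-compose.search-testing.yml`):

```
import os
import subprocess

# Illustrative only: publish a second stack's UI on port 3001 instead of 3000.
# The eval harness normally does this automatically via start_docker_compose().
env = {**os.environ, "NGINX_PORT": "3001"}
subprocess.run(
    "docker compose -f docker-compose.search-testing.yml -p danswer-stack_manual up -d",
    shell=True,
    check=True,
    env=env,
)
```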