Search Regression Test and Save/Load State updates (#761)
@@ -248,7 +248,9 @@ def answer_qa_query_stream(
             batch_offset=offset_count,
         )
         llm_relevance_filtering_response = LLMRelevanceFilterResponse(
-            relevant_chunk_indices=llm_chunks_indices
+            relevant_chunk_indices=[
+                index for index, value in enumerate(llm_chunk_selection) if value
+            ]
         ).dict()
         yield get_json_line(llm_relevance_filtering_response)

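Note: the response now derives the relevant indices directly from the boolean selection mask instead of a precomputed `llm_chunks_indices` list. As a standalone illustration of the mask-to-indices conversion (the values here are made up):

    llm_chunk_selection = [True, False, True, True, False]
    relevant_chunk_indices = [
        index for index, value in enumerate(llm_chunk_selection) if value
    ]
    assert relevant_chunk_indices == [0, 2, 3]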
@@ -322,7 +322,7 @@ def get_usable_chunks(
 def get_chunks_for_qa(
     chunks: list[InferenceChunk],
     llm_chunk_selection: list[bool],
-    token_limit: int = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+    token_limit: int | None = NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
     batch_offset: int = 0,
 ) -> list[int]:
     """
@@ -353,13 +353,17 @@ def get_chunks_for_qa(
         token_count += chunk_token + 50

         # Always use at least 1 chunk
-        if token_count <= token_limit or not latest_batch_indices:
+        if (
+            token_limit is None
+            or token_count <= token_limit
+            or not latest_batch_indices
+        ):
             latest_batch_indices.append(ind)
             current_chunk_unused = False
         else:
             current_chunk_unused = True

-        if token_count >= token_limit:
+        if token_limit is not None and token_count >= token_limit:
             if batch_index < batch_offset:
                 batch_index += 1
                 if current_chunk_unused:
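With `token_limit` widened to `int | None`, a `None` budget now short-circuits both checks, so every LLM-approved chunk is kept (the regression test below relies on this). A minimal runnable sketch of the rule, simplified to ignore `batch_offset` and using the same +50 per-chunk overhead:

    def pick_chunk_indices(
        chunk_tokens: list[int],
        llm_chunk_selection: list[bool],
        token_limit: int | None,
    ) -> list[int]:
        # Keep LLM-approved chunks until the optional budget is spent;
        # always keep at least one so the answer has some context.
        indices: list[int] = []
        token_count = 0
        for ind, tokens in enumerate(chunk_tokens):
            if not llm_chunk_selection[ind]:
                continue
            token_count += tokens + 50
            if token_limit is None or token_count <= token_limit or not indices:
                indices.append(ind)
            if token_limit is not None and token_count >= token_limit:
                break
        return indices

    # No budget: both selected chunks survive regardless of size.
    assert pick_chunk_indices([100, 4000], [True, True], None) == [0, 1]
    # Tight budget: the first chunk is still kept even though it overflows.
    assert pick_chunk_indices([100, 4000], [True, True], 50) == [0]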
@@ -540,7 +540,7 @@ def full_chunk_search_generator(
         if llm_filter_task_id
         else None,
     )
-    if llm_chunk_selection:
+    if llm_chunk_selection is not None:
         yield [chunk.unique_id in llm_chunk_selection for chunk in retrieved_chunks]
     else:
         yield [True for _ in reranked_chunks or retrieved_chunks]

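The switch from truthiness to `is not None` matters because an empty list is falsy: if the LLM filter ran and approved nothing, the old check would fall into the `else` branch and mark every chunk relevant instead of none. For example:

    llm_chunk_selection: list[str] = []     # filter ran, kept nothing
    assert not llm_chunk_selection          # old check: wrongly takes the else branch
    assert llm_chunk_selection is not None  # new check: yields an all-False mask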
@@ -1,4 +1,5 @@
 # This file is purely for development use, not included in any builds
+# Remember to first send over the schema information (run API Server)
 import argparse
 import json
 import os
@@ -19,25 +20,40 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()


-def save_postgres(filename: str) -> None:
+def save_postgres(filename: str, container_name: str) -> None:
     logger.info("Attempting to take Postgres snapshot")
-    cmd = f"pg_dump -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -F t {POSTGRES_DB} > {filename}"
-    subprocess.run(
-        cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True
-    )
+    cmd = f"docker exec {container_name} pg_dump -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -F t {POSTGRES_DB}"
+    with open(filename, "w") as file:
+        subprocess.run(
+            cmd,
+            shell=True,
+            check=True,
+            stdout=file,
+            text=True,
+            input=f"{POSTGRES_PASSWORD}\n",
+        )


-def load_postgres(filename: str) -> None:
+def load_postgres(filename: str, container_name: str) -> None:
     logger.info("Attempting to load Postgres snapshot")
     try:
         alembic_cfg = Config("alembic.ini")
         command.upgrade(alembic_cfg, "head")
-    except Exception:
-        logger.info("Alembic upgrade failed, maybe already has run")
-    cmd = f"pg_restore --clean -U {POSTGRES_USER} -h {POSTGRES_HOST} -p {POSTGRES_PORT} -W -d {POSTGRES_DB} -1 {filename}"
-    subprocess.run(
-        cmd, shell=True, check=True, input=f"{POSTGRES_PASSWORD}\n", text=True
-    )
+    except Exception as e:
+        logger.error(f"Alembic upgrade failed: {e}")
+
+    host_file_path = os.path.abspath(filename)
+
+    copy_cmd = f"docker cp {host_file_path} {container_name}:/tmp/"
+    subprocess.run(copy_cmd, shell=True, check=True)
+
+    container_file_path = f"/tmp/{os.path.basename(filename)}"
+
+    restore_cmd = (
+        f"docker exec {container_name} pg_restore --clean -U {POSTGRES_USER} "
+        f"-h localhost -p {POSTGRES_PORT} -d {POSTGRES_DB} -1 -F t {container_file_path}"
+    )
+    subprocess.run(restore_cmd, shell=True, check=True)


 def save_vespa(filename: str) -> None:
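Two mechanics worth noting in the rewritten functions: the dump no longer uses shell `>` redirection inside the command string (which the host shell would apply) but passes an open host file as `stdout=`, and the restore first `docker cp`s the snapshot into the container, since `pg_restore` running there cannot see host paths. A minimal sketch of the stdout-capture pattern, with a stand-in command:

    import subprocess

    # Stand-in for "docker exec <container> pg_dump ...": the child's stdout
    # streams straight into a file opened on the host.
    with open("snapshot.tar", "w") as file:
        subprocess.run(
            "echo 'dump bytes would stream here'",
            shell=True,
            check=True,
            stdout=file,
            text=True,
        )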
@@ -85,6 +101,12 @@ if __name__ == "__main__":
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--load", action="store_true", help="Load Danswer state from save directory."
|
"--load", action="store_true", help="Load Danswer state from save directory."
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--postgres_container_name",
|
||||||
|
type=str,
|
||||||
|
default="danswer-stack-relational_db-1",
|
||||||
|
help="Name of the postgres container to dump",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--checkpoint_dir",
|
"--checkpoint_dir",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -94,6 +116,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
checkpoint_dir = args.checkpoint_dir
|
checkpoint_dir = args.checkpoint_dir
|
||||||
|
postgres_container = args.postgres_container_name
|
||||||
|
|
||||||
if not os.path.exists(checkpoint_dir):
|
if not os.path.exists(checkpoint_dir):
|
||||||
os.makedirs(checkpoint_dir)
|
os.makedirs(checkpoint_dir)
|
||||||
@@ -102,8 +125,12 @@ if __name__ == "__main__":
|
|||||||
raise ValueError("Must specify --save or --load")
|
raise ValueError("Must specify --save or --load")
|
||||||
|
|
||||||
if args.load:
|
if args.load:
|
||||||
load_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
|
load_postgres(
|
||||||
|
os.path.join(checkpoint_dir, "postgres_snapshot.tar"), postgres_container
|
||||||
|
)
|
||||||
load_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
|
load_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
|
||||||
else:
|
else:
|
||||||
save_postgres(os.path.join(checkpoint_dir, "postgres_snapshot.tar"))
|
save_postgres(
|
||||||
|
os.path.join(checkpoint_dir, "postgres_snapshot.tar"), postgres_container
|
||||||
|
)
|
||||||
save_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
|
save_vespa(os.path.join(checkpoint_dir, "vespa_snapshot.jsonl"))
|
||||||
|
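With the container name plumbed through, a round trip might look like this (the script filename is an assumption, not shown in the diff; the container name is the new flag's default):

    python save_load_state.py --save --checkpoint_dir ./checkpoint
    python save_load_state.py --load --checkpoint_dir ./checkpoint \
        --postgres_container_name danswer-stack-relational_db-1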
@@ -8,6 +8,7 @@ from typing import TextIO
 from sqlalchemy.orm import Session

 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.direct_qa.qa_utils import get_chunks_for_qa
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.models import InferenceChunk
 from danswer.search.models import IndexFilters
@@ -70,7 +71,6 @@ def get_search_results(
     query: str, db_session: Session
 ) -> tuple[
     list[InferenceChunk],
-    list[bool],
     RetrievalMetricsContainer | None,
     RerankMetricsContainer | None,
 ]:
@@ -89,7 +89,7 @@ def get_search_results(
     retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
     rerank_metrics = MetricsHander[RerankMetricsContainer]()

-    chunks, llm_filter, query_id = danswer_search(
+    top_chunks, llm_chunk_selection, query_id = danswer_search(
         question=question,
         user=None,
         db_session=db_session,
@@ -99,16 +99,23 @@ def get_search_results(
         rerank_metrics_callback=rerank_metrics.record_metric,
     )

+    llm_chunks_indices = get_chunks_for_qa(
+        chunks=top_chunks,
+        llm_chunk_selection=llm_chunk_selection,
+        token_limit=None,
+    )
+
+    llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
+
     return (
-        chunks,
-        llm_filter,
+        llm_chunks,
         retrieval_metrics.metrics,
         rerank_metrics.metrics,
     )


 def _print_retrieval_metrics(
-    metrics_container: RetrievalMetricsContainer, show_all: bool
+    metrics_container: RetrievalMetricsContainer, show_all: bool = False
 ) -> None:
     for ind, metric in enumerate(metrics_container.metrics):
         if not show_all and ind >= 10:
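`get_search_results` now returns the already-filtered chunks instead of a separate `list[bool]` mask: calling `get_chunks_for_qa` with `token_limit=None` reuses the production selection path to materialize exactly the chunks the LLM filter kept. In effect (stand-in values, not real `InferenceChunk`s):

    top_chunks = ["chunk_a", "chunk_b", "chunk_c"]
    llm_chunk_selection = [True, False, True]
    llm_chunks_indices = [i for i, keep in enumerate(llm_chunk_selection) if keep]
    llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
    assert llm_chunks == ["chunk_a", "chunk_c"]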
@@ -117,14 +124,13 @@ def _print_retrieval_metrics(
         if ind != 0:
             print()  # for spacing purposes
         print(f"\tDocument: {metric.document_id}")
-        print(f"\tLink: {metric.first_link or 'NA'}")
         section_start = metric.chunk_content_start.replace("\n", " ")
         print(f"\tSection Start: {section_start}")
         print(f"\tSimilarity Distance Metric: {metric.score}")


 def _print_reranking_metrics(
-    metrics_container: RerankMetricsContainer, show_all: bool
+    metrics_container: RerankMetricsContainer, show_all: bool = False
 ) -> None:
     # Printing the raw scores as they're more informational than post-norm/boosting
     for ind, metric in enumerate(metrics_container.metrics):
@@ -134,45 +140,81 @@ def _print_reranking_metrics(
         if ind != 0:
             print()  # for spacing purposes
         print(f"\tDocument: {metric.document_id}")
-        print(f"\tLink: {metric.first_link or 'NA'}")
         section_start = metric.chunk_content_start.replace("\n", " ")
         print(f"\tSection Start: {section_start}")
         print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")


-def main(questions_json: str, output_file: str) -> None:
+def calculate_score(
+    log_prefix: str, chunk_ids: list[str], targets: list[str], max_chunks: int = 5
+) -> float:
+    top_ids = chunk_ids[:max_chunks]
+    matches = [top_id for top_id in top_ids if top_id in targets]
+    print(f"{log_prefix} Hits: {len(matches)}/{len(targets)}", end="\t")
+    return len(matches) / min(len(targets), max_chunks)
+
+
+def main(
+    questions_json: str, output_file: str, show_details: bool, stop_after: int
+) -> None:
     questions_info = read_json(questions_json)

+    running_retrieval_score = 0.0
+    running_rerank_score = 0.0
+    running_llm_filter_score = 0.0
+
     with open(output_file, "w") as outfile:
         with redirect_print_to_file(outfile):
             print("Running Document Retrieval Test\n")

             with Session(engine, expire_on_commit=False) as db_session:
-                for question, targets in questions_info.items():
-                    print(f"Question: {question}")
+                for ind, (question, targets) in enumerate(questions_info.items()):
+                    if ind >= stop_after:
+                        break
+
+                    print(f"\n\nQuestion: {question}")
+
                     (
-                        chunks,
-                        llm_filters,
+                        top_chunks,
                         retrieval_metrics,
                         rerank_metrics,
                     ) = get_search_results(query=question, db_session=db_session)

+                    assert retrieval_metrics is not None and rerank_metrics is not None
+
+                    retrieval_ids = [
+                        metric.document_id for metric in retrieval_metrics.metrics
+                    ]
+                    retrieval_score = calculate_score(
+                        "Retrieval", retrieval_ids, targets
+                    )
+                    running_retrieval_score += retrieval_score
+                    print(f"Average: {running_retrieval_score / (ind + 1)}")
+
+                    rerank_ids = [
+                        metric.document_id for metric in rerank_metrics.metrics
+                    ]
+                    rerank_score = calculate_score("Rerank", rerank_ids, targets)
+                    running_rerank_score += rerank_score
+                    print(f"Average: {running_rerank_score / (ind + 1)}")
+
+                    llm_ids = [chunk.document_id for chunk in top_chunks]
+                    llm_score = calculate_score("LLM Filter", llm_ids, targets)
+                    running_llm_filter_score += llm_score
+                    print(f"Average: {running_llm_filter_score / (ind + 1)}")
+
+                    if show_details:
                         print("\nRetrieval Metrics:")
                         if retrieval_metrics is None:
                             print("No Retrieval Metrics Available")
                         else:
-                            _print_retrieval_metrics(
-                                retrieval_metrics, show_all=args.all_results
-                            )
+                            _print_retrieval_metrics(retrieval_metrics)

                         print("\nReranking Metrics:")
                         if rerank_metrics is None:
                             print("No Reranking Metrics Available")
                         else:
-                            _print_reranking_metrics(
-                                rerank_metrics, show_all=args.all_results
-                            )
+                            _print_reranking_metrics(rerank_metrics)


 if __name__ == "__main__":
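`calculate_score` is a hit rate at `max_chunks`: the fraction of target documents that appear in the top `max_chunks` results, with the denominator capped so a question with more targets than `max_chunks` can still reach 1.0. Reproduced with a worked example (document IDs are made up):

    def calculate_score(
        log_prefix: str, chunk_ids: list[str], targets: list[str], max_chunks: int = 5
    ) -> float:
        top_ids = chunk_ids[:max_chunks]
        matches = [top_id for top_id in top_ids if top_id in targets]
        print(f"{log_prefix} Hits: {len(matches)}/{len(targets)}", end="\t")
        return len(matches) / min(len(targets), max_chunks)

    # "d3" sits at rank 6, outside the top 5, so only 2 of 3 targets hit.
    score = calculate_score(
        "Retrieval", ["d1", "d9", "d2", "d8", "d7", "d3"], ["d1", "d2", "d3"]
    )
    assert abs(score - 2 / 3) < 1e-9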
@@ -190,6 +232,23 @@ if __name__ == "__main__":
|
|||||||
help="Path to the output results file.",
|
help="Path to the output results file.",
|
||||||
default="./tests/regression/search_quality/regression_results.txt",
|
default="./tests/regression/search_quality/regression_results.txt",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--show_details",
|
||||||
|
action="store_true",
|
||||||
|
help="If set, show details of the retrieved chunks.",
|
||||||
|
default=True,
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--stop_after",
|
||||||
|
type=int,
|
||||||
|
help="Stop processing after this many iterations.",
|
||||||
|
default=10,
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
main(args.regression_questions_json, args.output_file)
|
main(
|
||||||
|
args.regression_questions_json,
|
||||||
|
args.output_file,
|
||||||
|
args.show_details,
|
||||||
|
args.stop_after,
|
||||||
|
)
|
||||||
|
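One review note on the new flags: `--show_details` pairs `action="store_true"` with `default=True`, so the option is always on and the flag itself is a no-op; there is no way to disable details from the command line. A conventional fix (a sketch, not part of this commit) is `BooleanOptionalAction`, available since Python 3.9:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--show_details",
        action=argparse.BooleanOptionalAction,  # also generates --no-show_details
        default=True,
        help="If set, show details of the retrieved chunks.",
    )
    assert parser.parse_args(["--no-show_details"]).show_details is False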