import argparse
import builtins
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import TextIO

import yaml
from sqlalchemy.orm import Session

from danswer.access.access import get_acl_for_user
from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.server.models import IndexFilters
from danswer.server.models import QuestionRequest
from danswer.utils.callbacks import MetricsHander


engine = get_sqlalchemy_engine()


@contextmanager
def redirect_print_to_file(file: TextIO) -> Any:
    # Temporarily route all print() output to the given file, restoring the
    # builtin print when the context exits
    original_print = builtins.print
    builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
    try:
        yield
    finally:
        builtins.print = original_print


def load_yaml(filepath: str) -> dict:
    with open(filepath, "r") as file:
        data = yaml.safe_load(file)
    return data


def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
    # Wrap long answers to max_line_size characters per line for readable output
    words = s.split()

    current_line: list[str] = []
    result_lines: list[str] = []
    current_length = 0
    for word in words:
        if len(word) > max_line_size:
            if current_line:
                result_lines.append(" ".join(current_line))
                current_line = []
                current_length = 0

            result_lines.append(word)
            continue

        if current_length + len(word) + len(current_line) > max_line_size:
            result_lines.append(" ".join(current_line))
            current_line = []
            current_length = 0

        current_line.append(word)
        current_length += len(word)

    if current_line:
        result_lines.append(" ".join(current_line))

    return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)


def get_answer_for_question(
    query: str, db_session: Session
) -> tuple[
    str | None,
    RetrievalMetricsContainer | None,
    RerankMetricsContainer | None,
    LLMMetricsContainer | None,
]:
    filters = IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        access_control_list=list(get_acl_for_user(user=None)),
    )
    question = QuestionRequest(
        query=query,
        collection="danswer_index",
        use_keyword=False,
        filters=filters,
        enable_auto_detect_filters=False,
        offset=None,
    )

    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
    rerank_metrics = MetricsHander[RerankMetricsContainer]()
    llm_metrics = MetricsHander[LLMMetricsContainer]()

    answer = answer_qa_query(
        question=question,
        user=None,
        db_session=db_session,
        answer_generation_timeout=100,
        real_time_flow=False,
        enable_reflexion=False,
        retrieval_metrics_callback=retrieval_metrics.record_metric,
        rerank_metrics_callback=rerank_metrics.record_metric,
        llm_metrics_callback=llm_metrics.record_metric,
    )

    return (
        answer.answer,
        retrieval_metrics.metrics,
        rerank_metrics.metrics,
        llm_metrics.metrics,
    )


def _print_retrieval_metrics(
    metrics_container: RetrievalMetricsContainer, show_all: bool
) -> None:
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Distance Metric: {metric.score}")


def _print_reranking_metrics(
    metrics_container: RerankMetricsContainer, show_all: bool
) -> None:
    # Printing the raw scores as they're more informational than post-norm/boosting
    for ind, metric in enumerate(metrics_container.metrics):
        if not show_all and ind >= 10:
            break

        if ind != 0:
            print()  # for spacing purposes
        print(f"\tDocument: {metric.document_id}")
        print(f"\tLink: {metric.first_link or 'NA'}")
        section_start = metric.chunk_content_start.replace("\n", " ")
        print(f"\tSection Start: {section_start}")
        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")


def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None:
    print(f"\tPrompt Tokens: {metrics_container.prompt_tokens}")
    print(f"\tResponse Tokens: {metrics_container.response_tokens}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "regression_yaml",
        type=str,
        help="Path to the Questions YAML file.",
        default="./tests/regression/answer_quality/sample_questions.yaml",
        nargs="?",
    )
    parser.add_argument(
        "--real-time", action="store_true", help="Set to use the real-time flow."
    )
    parser.add_argument(
        "--discard-metrics",
        action="store_true",
        help="Set to not include metrics on search, rerank, and token counts.",
    )
    parser.add_argument(
        "--all-results",
        action="store_true",
        help="Set to include all sections, not just the top 10, in the search and reranking metrics.",
    )
    parser.add_argument(
        "--output",
        type=str,
        help="Path to the output results file.",
        default="./tests/regression/answer_quality/regression_results.txt",
    )
    args = parser.parse_args()

    questions_data = load_yaml(args.regression_yaml)

    with open(args.output, "w") as outfile:
        with redirect_print_to_file(outfile):
            print("Running Question Answering Flow")
            print(
                "Note that running metrics requires tokenizing all "
                "prompts/returns and slightly slows down inference."
            )
            print(
                "Also note that the text embedding model (bi-encoder) currently used is trained for "
                "relative distances, not absolute distances. Therefore cosine similarity values may all be > 0.5 "
                "even for poor matches."
            )

            with Session(engine, expire_on_commit=False) as db_session:
                for sample in questions_data["questions"]:
                    print(
                        f"Running Test for Question {sample['id']}: {sample['question']}"
                    )

                    start_time = datetime.now()
                    (
                        answer,
                        retrieval_metrics,
                        rerank_metrics,
                        llm_metrics,
                    ) = get_answer_for_question(sample["question"], db_session)
                    end_time = datetime.now()
                    print(f"====Duration: {end_time - start_time}====")

                    print(f"Question {sample['id']}:")
                    print(f'\t{sample["question"]}')
                    print("\nApproximate Expected Answer:")
                    print(f'\t{sample["expected_answer"]}')
                    print("\nActual Answer:")
                    print(
                        word_wrap(answer)
                        if answer
                        else "\tFailed, either crashed or refused to answer."
                    )

                    if not args.discard_metrics:
                        print("\nLLM Tokens Usage:")
                        if llm_metrics is None:
                            print("No LLM Metrics Available")
                        else:
                            _print_llm_metrics(llm_metrics)

                        print("\nRetrieval Metrics:")
                        if retrieval_metrics is None:
                            print("No Retrieval Metrics Available")
                        else:
                            _print_retrieval_metrics(
                                retrieval_metrics, show_all=args.all_results
                            )

                        print("\nReranking Metrics:")
                        if rerank_metrics is None:
                            print("No Reranking Metrics Available")
                        else:
                            _print_reranking_metrics(
                                rerank_metrics, show_all=args.all_results
                            )

                    print("\n\n", flush=True)
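

# For reference, a minimal sketch of what the questions YAML consumed above is assumed
# to look like. The top-level "questions" key and the per-entry "id", "question", and
# "expected_answer" fields are inferred from how `sample` is accessed in the loop; the
# example values themselves are hypothetical:
#
# questions:
#   - id: 1
#     question: "What is Danswer?"
#     expected_answer: "Danswer is an open source workplace question answering tool."
#   - id: 2
#     question: "How do I add a new connector?"
#     expected_answer: "Connectors can be added from the admin panel."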