Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-20 13:05:49 +02:00)
Option to turn off LLM for eval script (#769)
This commit is contained in:
@@ -25,7 +25,13 @@ engine = get_sqlalchemy_engine()
|
||||
@contextmanager
|
||||
def redirect_print_to_file(file: TextIO) -> Any:
|
||||
original_print = builtins.print
|
||||
builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
|
||||
|
||||
def new_print(*args, **kwargs):
|
||||
kwargs["file"] = file
|
||||
original_print(*args, **kwargs)
|
||||
|
||||
builtins.print = new_print
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
@@ -68,7 +74,7 @@ def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str
|
||||
|
||||
|
||||
def get_search_results(
|
||||
query: str, db_session: Session
|
||||
query: str, enable_llm: bool, db_session: Session
|
||||
) -> tuple[
|
||||
list[InferenceChunk],
|
||||
RetrievalMetricsContainer | None,
|
||||
@@ -95,6 +101,7 @@ def get_search_results(
|
||||
db_session=db_session,
|
||||
document_index=get_default_document_index(),
|
||||
bypass_acl=True,
|
||||
skip_llm_chunk_filter=not enable_llm,
|
||||
retrieval_metrics_callback=retrieval_metrics.record_metric,
|
||||
rerank_metrics_callback=rerank_metrics.record_metric,
|
||||
)
|
||||
@@ -155,7 +162,11 @@ def calculate_score(
|
||||
|
||||
|
||||
def main(
|
||||
questions_json: str, output_file: str, show_details: bool, stop_after: int
|
||||
questions_json: str,
|
||||
output_file: str,
|
||||
show_details: bool,
|
||||
enable_llm: bool,
|
||||
stop_after: int,
|
||||
) -> None:
|
||||
questions_info = read_json(questions_json)
|
||||
|
||||
@@ -178,7 +189,9 @@ def main(
|
||||
top_chunks,
|
||||
retrieval_metrics,
|
||||
rerank_metrics,
|
||||
) = get_search_results(query=question, db_session=db_session)
|
||||
) = get_search_results(
|
||||
query=question, enable_llm=enable_llm, db_session=db_session
|
||||
)
|
||||
|
||||
assert retrieval_metrics is not None and rerank_metrics is not None
|
||||
|
||||
@@ -198,10 +211,11 @@ def main(
|
||||
running_rerank_score += rerank_score
|
||||
print(f"Average: {running_rerank_score / (ind + 1)}")
|
||||
|
||||
llm_ids = [chunk.document_id for chunk in top_chunks]
|
||||
llm_score = calculate_score("LLM Filter", llm_ids, targets)
|
||||
running_llm_filter_score += llm_score
|
||||
print(f"Average: {running_llm_filter_score / (ind + 1)}")
|
||||
if enable_llm:
|
||||
llm_ids = [chunk.document_id for chunk in top_chunks]
|
||||
llm_score = calculate_score("LLM Filter", llm_ids, targets)
|
||||
running_llm_filter_score += llm_score
|
||||
print(f"Average: {running_llm_filter_score / (ind + 1)}")
|
||||
|
||||
if show_details:
|
||||
print("\nRetrieval Metrics:")
|
||||
@@ -236,13 +250,19 @@ if __name__ == "__main__":
|
||||
"--show_details",
|
||||
action="store_true",
|
||||
help="If set, show details of the retrieved chunks.",
|
||||
default=True,
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable_llm",
|
||||
action="store_true",
|
||||
help="If set, use LLM chunk filtering (this can get very expensive).",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--stop_after",
|
||||
type=int,
|
||||
help="Stop processing after this many iterations.",
|
||||
default=10,
|
||||
default=100,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -250,5 +270,6 @@ if __name__ == "__main__":
|
||||
args.regression_questions_json,
|
||||
args.output_file,
|
||||
args.show_details,
|
||||
args.enable_llm,
|
||||
args.stop_after,
|
||||
)
|
||||
|
Reference in New Issue
Block a user