Option to turn off LLM for eval script (#769)

Yuhong Sun
2023-11-26 15:31:03 -08:00
committed by GitHub
parent 65d38ac8c3
commit 2665bff78e


@@ -25,7 +25,13 @@ engine = get_sqlalchemy_engine()
 @contextmanager
 def redirect_print_to_file(file: TextIO) -> Any:
     original_print = builtins.print
-    builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
+
+    def new_print(*args, **kwargs):
+        kwargs["file"] = file
+        original_print(*args, **kwargs)
+
+    builtins.print = new_print
     try:
         yield
     finally:
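
Aside: the lambda being removed always injects file=file as a keyword, so any caller that also passes file= raises a duplicate-keyword TypeError; the new inner function overwrites the caller's value instead. A minimal standalone sketch (not part of this commit) showing the difference:

import builtins
import io

buf = io.StringIO()
original_print = builtins.print

# Old approach: file= is injected unconditionally, so a caller that also
# passes file= produces a duplicate keyword argument.
old_print = lambda *args, **kwargs: original_print(*args, file=buf, **kwargs)
try:
    old_print("hello", file=buf)
except TypeError as err:
    original_print(err)  # ... got multiple values for keyword argument 'file'

# New approach: overwrite whatever file= the caller supplied.
def new_print(*args, **kwargs):
    kwargs["file"] = buf
    original_print(*args, **kwargs)

new_print("hello", file=None)  # no error; "hello" lands in buf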
@@ -68,7 +74,7 @@ def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str
 def get_search_results(
-    query: str, db_session: Session
+    query: str, enable_llm: bool, db_session: Session
 ) -> tuple[
     list[InferenceChunk],
     RetrievalMetricsContainer | None,
@@ -95,6 +101,7 @@ def get_search_results(
         db_session=db_session,
         document_index=get_default_document_index(),
         bypass_acl=True,
+        skip_llm_chunk_filter=not enable_llm,
         retrieval_metrics_callback=retrieval_metrics.record_metric,
         rerank_metrics_callback=rerank_metrics.record_metric,
     )
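
The search pipeline exposes a skip switch while the script's CLI exposes an enable switch, hence the negation at the call site. A tiny illustrative sketch (run_eval_search is a made-up name; only skip_llm_chunk_filter comes from the diff):

# Hypothetical wrapper showing the enable -> skip inversion at the boundary.
def run_eval_search(query: str, enable_llm: bool) -> dict:
    return {
        "query": query,
        "skip_llm_chunk_filter": not enable_llm,  # pipeline wants "skip"
    }

assert run_eval_search("refund policy", enable_llm=False)["skip_llm_chunk_filter"] is True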
@@ -155,7 +162,11 @@ def calculate_score(
 def main(
-    questions_json: str, output_file: str, show_details: bool, stop_after: int
+    questions_json: str,
+    output_file: str,
+    show_details: bool,
+    enable_llm: bool,
+    stop_after: int,
 ) -> None:
     questions_info = read_json(questions_json)
@@ -178,7 +189,9 @@ def main(
             top_chunks,
             retrieval_metrics,
             rerank_metrics,
-        ) = get_search_results(query=question, db_session=db_session)
+        ) = get_search_results(
+            query=question, enable_llm=enable_llm, db_session=db_session
+        )

         assert retrieval_metrics is not None and rerank_metrics is not None
@@ -198,10 +211,11 @@ def main(
         running_rerank_score += rerank_score
         print(f"Average: {running_rerank_score / (ind + 1)}")

-        llm_ids = [chunk.document_id for chunk in top_chunks]
-        llm_score = calculate_score("LLM Filter", llm_ids, targets)
-        running_llm_filter_score += llm_score
-        print(f"Average: {running_llm_filter_score / (ind + 1)}")
+        if enable_llm:
+            llm_ids = [chunk.document_id for chunk in top_chunks]
+            llm_score = calculate_score("LLM Filter", llm_ids, targets)
+            running_llm_filter_score += llm_score
+            print(f"Average: {running_llm_filter_score / (ind + 1)}")

         if show_details:
             print("\nRetrieval Metrics:")
@@ -236,13 +250,19 @@ if __name__ == "__main__":
         "--show_details",
         action="store_true",
         help="If set, show details of the retrieved chunks.",
-        default=True,
+        default=False,
     )
+    parser.add_argument(
+        "--enable_llm",
+        action="store_true",
+        help="If set, use LLM chunk filtering (this can get very expensive).",
+        default=False,
+    )
     parser.add_argument(
         "--stop_after",
         type=int,
         help="Stop processing after this many iterations.",
-        default=10,
+        default=100,
     )

     args = parser.parse_args()
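
Note that with action="store_true", argparse already defaults a flag to False, so the explicit default=False is redundant but harmless documentation of intent. A standalone sketch (mirroring only the flags shown above) to sanity-check the defaults:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--show_details", action="store_true", default=False)
parser.add_argument("--enable_llm", action="store_true", default=False)
parser.add_argument("--stop_after", type=int, default=100)

print(parser.parse_args([]))
# Namespace(show_details=False, enable_llm=False, stop_after=100)
print(parser.parse_args(["--enable_llm"]))
# Namespace(show_details=False, enable_llm=True, stop_after=100)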
@@ -250,5 +270,6 @@ if __name__ == "__main__":
         args.regression_questions_json,
         args.output_file,
         args.show_details,
+        args.enable_llm,
         args.stop_after,
     )