Add Metrics to Regression Test (#470)

Yuhong Sun 2023-09-20 20:42:02 -07:00 committed by GitHub
parent 4912beb283
commit b416c85f0f
15 changed files with 370 additions and 50 deletions

View File

@@ -1,3 +1,5 @@
from collections.abc import Callable
from sqlalchemy.orm import Session
from danswer.chunking.models import InferenceChunk
@@ -12,10 +14,13 @@ from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
@@ -37,6 +42,10 @@ def answer_qa_query(
answer_generation_timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
enable_reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> QAResponse:
query = question.query
filters = question.filters
@@ -59,12 +68,21 @@ def answer_qa_query(
user_id = None if user is None else user.id
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query, user_id, filters, get_default_document_index()
query,
user_id,
filters,
get_default_document_index(),
retrieval_metrics_callback=retrieval_metrics_callback,
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, get_default_document_index()
query,
user_id,
filters,
get_default_document_index(),
retrieval_metrics_callback=retrieval_metrics_callback,
rerank_metrics_callback=rerank_metrics_callback,
)
if not ranked_chunks:
return QAResponse(
@@ -126,11 +144,13 @@ def answer_qa_query(
error_msg = None
try:
d_answer, quotes = qa_model.answer_question(query, usable_chunks)
d_answer, quotes = qa_model.answer_question(
query, usable_chunks, metrics_callback=llm_metrics_callback
)
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
d_answer, quotes = None, None
error_msg = f"Error occurred in call to LLM - {e}"
error_msg = f"Error occurred in call to LLM - {e}" # Used in the QAResponse
if not real_time_flow and enable_reflexion and d_answer is not None:
valid = False
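
The three new callback parameters let callers observe retrieval, rerank, and LLM metrics without changing the QAResponse return type. Below is a minimal sketch of how a caller might wire them up with simple closures; the query string is illustrative and the session setup mirrors the regression-test usage later in this commit.

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.server.models import QuestionRequest

# Stash whichever metrics containers fire during the call.
captured: dict[str, object] = {}


def record_retrieval(metrics: RetrievalMetricsContainer) -> None:
    captured["retrieval"] = metrics


def record_rerank(metrics: RerankMetricsContainer) -> None:
    captured["rerank"] = metrics


def record_llm(metrics: LLMMetricsContainer) -> None:
    captured["llm"] = metrics


question = QuestionRequest(
    query="How do I add a connector?",  # illustrative query
    collection="danswer_index",
    use_keyword=False,
    filters=None,
    offset=None,
)

with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session:
    response = answer_qa_query(
        question=question,
        user=None,
        db_session=db_session,
        real_time_flow=False,
        retrieval_metrics_callback=record_retrieval,
        rerank_metrics_callback=record_rerank,
        llm_metrics_callback=record_llm,
    )

print(response.answer)
print(captured.keys())  # whichever of retrieval / rerank / llm actually fired
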

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from typing import Any
from danswer.chunking.models import InferenceChunk
@@ -6,6 +7,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
@@ -83,7 +85,10 @@ class GPT4AllCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
@@ -100,8 +105,7 @@ class GPT4AllCompletionQA(QAModel):
logger.debug(model_output)
answer, quotes = process_answer(model_output, context_docs)
return answer, quotes
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
@@ -148,7 +152,10 @@ class GPT4AllChatCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
@@ -171,8 +178,7 @@ class GPT4AllChatCompletionQA(QAModel):
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]

View File

@@ -1,3 +1,4 @@
from collections.abc import Callable
from typing import Any
from huggingface_hub import InferenceClient # type:ignore
@@ -9,6 +10,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import FreeformProcessor
from danswer.direct_qa.qa_prompts import JsonChatProcessor
@@ -49,7 +51,10 @@ class HuggingFaceCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
@@ -62,8 +67,7 @@ class HuggingFaceCompletionQA(QAModel):
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
@@ -163,7 +167,10 @@ class HuggingFaceChatCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
model_output = self._get_hf_model_output(query, context_docs)

View File

@@ -1,8 +1,10 @@
import abc
from collections.abc import Callable
from collections.abc import Generator
from dataclasses import dataclass
from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.models import LLMMetricsContainer
@dataclass
@@ -44,6 +46,7 @@ class DanswerQuotes:
quotes: list[DanswerQuote]
# Final int is for number of output tokens
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
AnswerQuestionStreamReturn = Generator[
DanswerAnswerPiece | DanswerQuotes | None, None, None
@@ -66,6 +69,7 @@ class QAModel:
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
raise NotImplementedError

View File

@@ -1,4 +1,5 @@
import re
from collections.abc import Callable
from transformers import pipeline # type:ignore
from transformers import QuestionAnsweringPipeline # type:ignore
@@ -12,6 +13,7 @@ from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -104,7 +106,10 @@ class TransformerQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
danswer_quotes: list[DanswerQuote] = []
d_answers: list[str] = []

View File

@@ -0,0 +1,6 @@
from pydantic import BaseModel
class LLMMetricsContainer(BaseModel):
prompt_tokens: int
response_tokens: int
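
LLMMetricsContainer is a plain pydantic model, so a recorded value can be inspected or serialized directly. A tiny sketch with made-up token counts (assuming the pydantic v1 style API used elsewhere in the repo):

from danswer.direct_qa.models import LLMMetricsContainer

metrics = LLMMetricsContainer(prompt_tokens=512, response_tokens=87)  # made-up values
print(metrics.dict())  # {'prompt_tokens': 512, 'response_tokens': 87}
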

View File

@@ -24,6 +24,7 @@ from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
@@ -142,7 +143,10 @@ class OpenAICompletionQA(OpenAIQAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
@@ -168,8 +172,7 @@ class OpenAICompletionQA(OpenAIQAModel):
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
return process_answer(model_output, context_docs)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]

View File

@@ -1,6 +1,7 @@
import abc
import json
import re
from collections.abc import Callable
from collections.abc import Iterator
from copy import copy
@@ -21,13 +22,16 @@ from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.llm.llm import LLM
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import str_prompt_to_langchain_prompt
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_up_code_blocks
@@ -245,11 +249,32 @@ class QABlock(QAModel):
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
) -> AnswerQuestionReturn:
trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
model_out = self._llm.invoke(prompt)
if metrics_callback is not None:
prompt_tokens = sum(
[
check_number_of_tokens(
text=p.content, encode_fn=get_default_llm_tokenizer()
)
for p in prompt
]
)
response_tokens = check_number_of_tokens(
text=model_out, encode_fn=get_default_llm_tokenizer()
)
metrics_callback(
LLMMetricsContainer(
prompt_tokens=prompt_tokens, response_tokens=response_tokens
)
)
return self._qa_handler.process_llm_output(model_out, trimmed_context_docs)
def answer_question_stream(

View File

@@ -1,5 +1,6 @@
import abc
import json
from collections.abc import Callable
from collections.abc import Generator
import requests
@@ -15,6 +16,7 @@ from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
@@ -234,7 +236,10 @@ class RequestCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
self,
query: str,
context_docs: list[InferenceChunk],
metrics_callback: Callable[[LLMMetricsContainer], None] | None = None, # Unused
) -> AnswerQuestionReturn:
model_api_response = self._get_request_response(
query, context_docs, stream=False
@@ -245,8 +250,7 @@ class RequestCompletionQA(QAModel):
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
return process_answer(model_output, context_docs)
def answer_question_stream(
self,

View File

@@ -1,5 +1,6 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
import tiktoken
from langchain.prompts.base import StringPromptValue
@@ -17,6 +18,16 @@ from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
_LLM_TOKENIZER: Callable[[str], Any] | None = None
def get_default_llm_tokenizer() -> Callable:
global _LLM_TOKENIZER
if _LLM_TOKENIZER is None:
_LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base").encode
return _LLM_TOKENIZER
def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
if (
msg.message_type == MessageType.SYSTEM
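
get_default_llm_tokenizer above caches tiktoken's cl100k_base encode function; the prompt and response token counts reported via LLMMetricsContainer are derived from this encoder through check_number_of_tokens (presumably the length of the encoded sequence). An equivalent standalone sketch using tiktoken directly:

import tiktoken

# The same encoder the helper above caches.
encode = tiktoken.get_encoding("cl100k_base").encode
prompt_tokens = len(encode("What does Danswer use for keyword search?"))
print(prompt_tokens)
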

View File

@@ -1,4 +1,5 @@
import json
from collections.abc import Callable
from uuid import UUID
from nltk.corpus import stopwords # type:ignore
@@ -9,6 +10,9 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import IndexFilter
from danswer.search.models import ChunkMetric
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -40,13 +44,31 @@ def retrieve_keyword_documents(
filters: list[IndexFilter] | None,
datastore: DocumentIndex,
num_hits: int = NUM_RETURNED_HITS,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
) -> list[InferenceChunk] | None:
edited_query = query_processing(query)
top_chunks = datastore.keyword_retrieval(edited_query, user_id, filters, num_hits)
if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
logger.warning(
f"Keyword search returned no results - Filters: {filters_log_msg}\tEdited Query: {edited_query}"
)
return None
if retrieval_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in top_chunks
]
retrieval_metrics_callback(
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
)
return top_chunks

View File

@@ -1,9 +1,16 @@
from enum import Enum
from pydantic import BaseModel
from danswer.chunking.models import DocAwareChunk
from danswer.chunking.models import IndexChunk
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
)
class SearchType(str, Enum):
KEYWORD = "keyword" # May be better to also try keyword search if Semantic (AI Search) is on
SEMANTIC = "semantic" # Really should try Semantic (AI Search) if keyword is on
@@ -17,3 +24,22 @@ class QueryFlow(str, Enum):
class Embedder:
def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
raise NotImplementedError
class ChunkMetric(BaseModel):
document_id: str
chunk_content_start: str
first_link: str | None
score: float
class RetrievalMetricsContainer(BaseModel):
keyword_search: bool # False for Vector Search
metrics: list[ChunkMetric] # This contains the scores for retrieval as well
class RerankMetricsContainer(BaseModel):
"""The score held by this is the un-boosted, averaged score of the ensemble cross-encoders"""
metrics: list[ChunkMetric]
raw_similarity_scores: list[float]
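
For reference, a sketch of the shape of a retrieval metrics payload built from the models above, with purely illustrative values (document_id and link formats depend on the connector):

chunk_metrics = [
    ChunkMetric(
        document_id="example-doc-id",
        chunk_content_start="Danswer supports both keyword and semantic search ...",
        first_link="https://docs.example.com/quickstart",
        score=0.71,
    )
]
retrieval_metrics = RetrievalMetricsContainer(keyword_search=False, metrics=chunk_metrics)
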

View File

@@ -1,4 +1,5 @@
import json
from collections.abc import Callable
from uuid import UUID
import numpy
@@ -18,7 +19,11 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import IndexFilter
from danswer.search.models import ChunkMetric
from danswer.search.models import Embedder
from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.search_utils import get_default_embedding_model
from danswer.search.search_utils import get_default_reranking_model_ensemble
from danswer.server.models import SearchDoc
@@ -55,6 +60,7 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
def semantic_reranking(
query: str,
chunks: list[InferenceChunk],
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> list[InferenceChunk]:
model_max = 12 # These are just based on observations from model selection
model_min = -12
@@ -64,6 +70,8 @@
for encoder in cross_encoders
]
raw_sim_scores = sum(sim_scores) / len(sim_scores)
cross_models_min = numpy.min(sim_scores)
shifted_sim_scores = sum(
@@ -75,9 +83,9 @@
normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
model_max - model_min
)
scored_results = list(zip(normalized_b_s_scores, chunks))
scored_results = list(zip(normalized_b_s_scores, raw_sim_scores, chunks))
scored_results.sort(key=lambda x: x[0], reverse=True)
ranked_sim_scores, ranked_chunks = zip(*scored_results)
ranked_sim_scores, ranked_raw_scores, ranked_chunks = zip(*scored_results)
logger.debug(f"Reranked similarity scores: {ranked_sim_scores}")
@@ -86,6 +94,23 @@
for ind, chunk in enumerate(ranked_chunks):
chunk.score = ranked_sim_scores[ind]
if rerank_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in ranked_chunks
]
rerank_metrics_callback(
RerankMetricsContainer(
metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores
)
)
return list(ranked_chunks)
@@ -97,6 +122,9 @@ def retrieve_ranked_documents(
datastore: DocumentIndex,
num_hits: int = NUM_RETURNED_HITS,
num_rerank: int = NUM_RERANKED_RESULTS,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
"""Uses vector similarity to fetch the top num_hits document chunks with a distance cutoff.
Reranks the top num_rerank out of those (instead of all due to latency)"""
@@ -108,7 +136,24 @@
)
return None, None
logger.debug(top_chunks)
ranked_chunks = semantic_reranking(query, top_chunks[:num_rerank])
if retrieval_metrics_callback is not None:
chunk_metrics = [
ChunkMetric(
document_id=chunk.document_id,
chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
first_link=chunk.source_links[0] if chunk.source_links else None,
score=chunk.score if chunk.score is not None else 0,
)
for chunk in top_chunks
]
retrieval_metrics_callback(
RetrievalMetricsContainer(keyword_search=False, metrics=chunk_metrics)
)
ranked_chunks = semantic_reranking(
query, top_chunks[:num_rerank], rerank_metrics_callback=rerank_metrics_callback
)
top_docs = [
ranked_chunk.source_links[0]

View File

@@ -0,0 +1,12 @@
from typing import Generic
from typing import TypeVar
T = TypeVar("T")
class MetricsHander(Generic[T]):
def __init__(self) -> None:
self.metrics: T | None = None
def record_metric(self, metrics: T) -> None:
self.metrics = metrics
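
MetricsHander (spelled as in the source) is just a typed slot: record_metric stores the container so it can be read back after the call returns, which is exactly how the regression script below consumes it. A minimal sketch of the pattern on its own:

from danswer.search.models import RetrievalMetricsContainer
from danswer.utils.callbacks import MetricsHander

retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
# Pass retrieval_metrics.record_metric anywhere a
# Callable[[RetrievalMetricsContainer], None] is expected, e.g. as
# retrieval_metrics_callback. After the call:
if retrieval_metrics.metrics is not None:
    for metric in retrieval_metrics.metrics.metrics:
        print(metric.document_id, metric.score)
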

View File

@@ -1,24 +1,42 @@
import argparse
import builtins
from contextlib import contextmanager
from datetime import datetime
from typing import Any
from typing import TextIO
import yaml
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.server.models import QuestionRequest
from danswer.utils.callbacks import MetricsHander
engine = get_sqlalchemy_engine()
@contextmanager
def redirect_print_to_file(file: TextIO) -> Any:
original_print = builtins.print
builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
try:
yield
finally:
builtins.print = original_print
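
redirect_print_to_file monkey-patches builtins.print for the duration of the with-block, so every bare print call inside it is written to the given file, and the original print is restored afterwards. A minimal usage sketch, assuming the context manager above is in scope:

with open("results.txt", "w") as outfile:
    with redirect_print_to_file(outfile):
        print("written to results.txt, not the terminal")
print("back on stdout")
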
def load_yaml(filepath: str) -> dict:
with open(filepath, "r") as file:
data = yaml.safe_load(file)
return data
def word_wrap(s: str, max_line_size: int = 120) -> str:
def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
words = s.split()
current_line: list[str] = []
@@ -45,27 +63,85 @@ def word_wrap(s: str, max_line_size: int = 120) -> str:
if current_line:
result_lines.append(" ".join(current_line))
return "\n".join(result_lines)
return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)
def get_answer_for_question(query: str, db_session: Session) -> str | None:
def get_answer_for_question(
query: str, db_session: Session
) -> tuple[
str | None,
RetrievalMetricsContainer | None,
RerankMetricsContainer | None,
LLMMetricsContainer | None,
]:
question = QuestionRequest(
query=query,
collection="danswer_index",
use_keyword=None,
use_keyword=False,
filters=None,
offset=None,
)
retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
rerank_metrics = MetricsHander[RerankMetricsContainer]()
llm_metrics = MetricsHander[LLMMetricsContainer]()
answer = answer_qa_query(
question=question,
user=None,
db_session=db_session,
answer_generation_timeout=100,
real_time_flow=False,
enable_reflexion=False,
retrieval_metrics_callback=retrieval_metrics.record_metric,
rerank_metrics_callback=rerank_metrics.record_metric,
llm_metrics_callback=llm_metrics.record_metric,
)
return answer.answer
return (
answer.answer,
retrieval_metrics.metrics,
rerank_metrics.metrics,
llm_metrics.metrics,
)
def _print_retrieval_metrics(
metrics_container: RetrievalMetricsContainer, show_all: bool
) -> None:
for ind, metric in enumerate(metrics_container.metrics):
if not show_all and ind >= 10:
break
if ind != 0:
print() # for spacing purposes
print(f"\tDocument: {metric.document_id}")
print(f"\tLink: {metric.first_link or 'NA'}")
section_start = metric.chunk_content_start.replace("\n", " ")
print(f"\tSection Start: {section_start}")
print(f"\tSimilarity Distance Metric: {metric.score}")
def _print_reranking_metrics(
metrics_container: RerankMetricsContainer, show_all: bool
) -> None:
# Printing the raw scores as they're more informational than post-norm/boosting
for ind, metric in enumerate(metrics_container.metrics):
if not show_all and ind >= 10:
break
if ind != 0:
print() # for spacing purposes
print(f"\tDocument: {metric.document_id}")
print(f"\tLink: {metric.first_link or 'NA'}")
section_start = metric.chunk_content_start.replace("\n", " ")
print(f"\tSection Start: {section_start}")
print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")
def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None:
print(f"\tPrompt Tokens: {metrics_container.prompt_tokens}")
print(f"\tResponse Tokens: {metrics_container.response_tokens}")
if __name__ == "__main__":
@@ -80,6 +156,16 @@ if __name__ == "__main__":
parser.add_argument(
"--real-time", action="store_true", help="Set to use the real-time flow."
)
parser.add_argument(
"--discard-metrics",
action="store_true",
help="Set to not include metrics on search, rerank, and token counts.",
)
parser.add_argument(
"--all-results",
action="store_true",
help="Set to not include more than the 10 top sections for search and reranking metrics.",
)
parser.add_argument(
"--output",
type=str,
@@ -91,27 +177,65 @@ if __name__ == "__main__":
questions_data = load_yaml(args.regression_yaml)
with open(args.output, "w") as outfile:
print("Running Question Answering Flow", file=outfile)
with redirect_print_to_file(outfile):
print("Running Question Answering Flow")
print(
"Note that running metrics requires tokenizing all "
"prompts/returns and slightly slows down inference."
)
print(
"Also note that the text embedding model (bi-encoder) currently used is trained for "
"relative distances, not absolute distances. Therefore cosine similarity values may all be > 0.5 "
"even for poor matches"
)
with Session(engine, expire_on_commit=False) as db_session:
for sample in questions_data["questions"]:
# This line goes to stdout to track progress
print(f"Running Test for Question {sample['id']}: {sample['question']}")
with Session(engine, expire_on_commit=False) as db_session:
for sample in questions_data["questions"]:
print(
f"Running Test for Question {sample['id']}: {sample['question']}"
)
start_time = datetime.now()
answer = get_answer_for_question(sample["question"], db_session)
end_time = datetime.now()
start_time = datetime.now()
(
answer,
retrieval_metrics,
rerank_metrics,
llm_metrics,
) = get_answer_for_question(sample["question"], db_session)
end_time = datetime.now()
print(f"====Duration: {end_time - start_time}====", file=outfile)
print(f"Question {sample['id']}:", file=outfile)
print(sample["question"], file=outfile)
print("\nExpected Answer:", file=outfile)
print(sample["expected_answer"], file=outfile)
print("\nActual Answer:", file=outfile)
print(
word_wrap(answer)
if answer
else "Failed, either crashed or refused to answer.",
file=outfile,
)
print("\n\n", file=outfile, flush=True)
print(f"====Duration: {end_time - start_time}====")
print(f"Question {sample['id']}:")
print(f'\t{sample["question"]}')
print("\nApproximate Expected Answer:")
print(f'\t{sample["expected_answer"]}')
print("\nActual Answer:")
print(
word_wrap(answer)
if answer
else "\tFailed, either crashed or refused to answer."
)
if not args.discard_metrics:
print("\nLLM Tokens Usage:")
if llm_metrics is None:
print("No LLM Metrics Available")
else:
_print_llm_metrics(llm_metrics)
print("\nRetrieval Metrics:")
if retrieval_metrics is None:
print("No Retrieval Metrics Available")
else:
_print_retrieval_metrics(
retrieval_metrics, show_all=args.all_results
)
print("\nReranking Metrics:")
if rerank_metrics is None:
print("No Reranking Metrics Available")
else:
_print_reranking_metrics(
rerank_metrics, show_all=args.all_results
)
print("\n\n", flush=True)