Add Metrics to Regression Test (#470)
Repository: https://github.com/danswer-ai/danswer.git
Commit: b416c85f0f
Parent: 4912beb283
@@ -1,3 +1,5 @@
+from collections.abc import Callable
+
 from sqlalchemy.orm import Session

 from danswer.chunking.models import InferenceChunk
@@ -12,10 +14,13 @@ from danswer.db.models import User
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
 from danswer.direct_qa.llm_utils import get_default_qa_model
+from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_utils import get_usable_chunks
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
+from danswer.search.models import RerankMetricsContainer
+from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
@@ -37,6 +42,10 @@ def answer_qa_query(
     answer_generation_timeout: int = QA_TIMEOUT,
     real_time_flow: bool = True,
     enable_reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
+    llm_metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
 ) -> QAResponse:
     query = question.query
     filters = question.filters
@@ -59,12 +68,21 @@ def answer_qa_query(
     user_id = None if user is None else user.id
     if use_keyword:
         ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
-            query, user_id, filters, get_default_document_index()
+            query,
+            user_id,
+            filters,
+            get_default_document_index(),
+            retrieval_metrics_callback=retrieval_metrics_callback,
         )
         unranked_chunks: list[InferenceChunk] | None = []
     else:
         ranked_chunks, unranked_chunks = retrieve_ranked_documents(
-            query, user_id, filters, get_default_document_index()
+            query,
+            user_id,
+            filters,
+            get_default_document_index(),
+            retrieval_metrics_callback=retrieval_metrics_callback,
+            rerank_metrics_callback=rerank_metrics_callback,
         )
     if not ranked_chunks:
         return QAResponse(
@@ -126,11 +144,13 @@ def answer_qa_query(

     error_msg = None
     try:
-        d_answer, quotes = qa_model.answer_question(query, usable_chunks)
+        d_answer, quotes = qa_model.answer_question(
+            query, usable_chunks, metrics_callback=llm_metrics_callback
+        )
     except Exception as e:
         # exception is logged in the answer_question method, no need to re-log
         d_answer, quotes = None, None
-        error_msg = f"Error occurred in call to LLM - {e}"
+        error_msg = f"Error occurred in call to LLM - {e}"  # Used in the QAResponse

     if not real_time_flow and enable_reflexion and d_answer is not None:
         valid = False
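Note on usage: the three new callback parameters default to None, so existing call sites are unaffected. A minimal sketch of how a caller might capture the new LLM token metrics; the question and db_session objects are assumed to already exist in the caller's scope:

from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.models import LLMMetricsContainer


def log_llm_metrics(metrics: LLMMetricsContainer) -> None:
    # Receives prompt/response token counts when the active QA model reports them
    print(f"prompt_tokens={metrics.prompt_tokens} response_tokens={metrics.response_tokens}")


response = answer_qa_query(
    question=question,  # a QuestionRequest built by the caller (assumed)
    user=None,
    db_session=db_session,  # an open SQLAlchemy Session (assumed)
    llm_metrics_callback=log_llm_metrics,
)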
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from typing import Any

 from danswer.chunking.models import InferenceChunk
@@ -6,6 +7,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
 from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
@@ -83,7 +85,10 @@ class GPT4AllCompletionQA(QAModel):

     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
     ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
@@ -100,8 +105,7 @@ class GPT4AllCompletionQA(QAModel):

         logger.debug(model_output)

-        answer, quotes = process_answer(model_output, context_docs)
-        return answer, quotes
+        return process_answer(model_output, context_docs)

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
@@ -148,7 +152,10 @@ class GPT4AllChatCompletionQA(QAModel):

     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
     ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
@@ -171,8 +178,7 @@ class GPT4AllChatCompletionQA(QAModel):

         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        return process_answer(model_output, context_docs)

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
@@ -1,3 +1,4 @@
+from collections.abc import Callable
 from typing import Any

 from huggingface_hub import InferenceClient  # type:ignore
@@ -9,6 +10,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import FreeformProcessor
 from danswer.direct_qa.qa_prompts import JsonChatProcessor
@@ -49,7 +51,10 @@ class HuggingFaceCompletionQA(QAModel):

     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
     ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
@@ -62,8 +67,7 @@ class HuggingFaceCompletionQA(QAModel):
         )
         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        return process_answer(model_output, context_docs)

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
@@ -163,7 +167,10 @@ class HuggingFaceChatCompletionQA(QAModel):

     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
     ) -> AnswerQuestionReturn:
         model_output = self._get_hf_model_output(query, context_docs)

@@ -1,8 +1,10 @@
 import abc
+from collections.abc import Callable
 from collections.abc import Generator
 from dataclasses import dataclass

 from danswer.chunking.models import InferenceChunk
+from danswer.direct_qa.models import LLMMetricsContainer


 @dataclass
@@ -44,6 +46,7 @@ class DanswerQuotes:
     quotes: list[DanswerQuote]


+# Final int is for number of output tokens
 AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
 AnswerQuestionStreamReturn = Generator[
     DanswerAnswerPiece | DanswerQuotes | None, None, None
@@ -66,6 +69,7 @@ class QAModel:
         self,
         query: str,
         context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
     ) -> AnswerQuestionReturn:
         raise NotImplementedError

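Every QAModel implementation in this commit accepts the new metrics_callback keyword, even when it ignores it (hence the "# Unused" comments), so answer_qa_query can always pass it through. A minimal, hedged sketch of the consumer side of that interface; answer_with_token_logging and log_metrics are hypothetical names introduced only for illustration:

from collections.abc import Callable

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.models import LLMMetricsContainer


def answer_with_token_logging(
    qa_model: QAModel,
    query: str,
    context_docs: list[InferenceChunk],
) -> AnswerQuestionReturn:
    """Hypothetical helper: call any QAModel and print its token usage if reported."""

    def log_metrics(metrics: LLMMetricsContainer) -> None:
        print(f"LLM tokens: {metrics.prompt_tokens} prompt / {metrics.response_tokens} response")

    # Works with every implementation; models that ignore the callback simply never call it
    return qa_model.answer_question(query, context_docs, metrics_callback=log_metrics)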
@@ -1,4 +1,5 @@
 import re
+from collections.abc import Callable

 from transformers import pipeline  # type:ignore
 from transformers import QuestionAnsweringPipeline  # type:ignore
@@ -12,6 +13,7 @@ from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_function_time

@@ -104,7 +106,10 @@ class TransformerQA(QAModel):

     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
     ) -> AnswerQuestionReturn:
         danswer_quotes: list[DanswerQuote] = []
         d_answers: list[str] = []
backend/danswer/direct_qa/models.py (new file)
@@ -0,0 +1,6 @@
+from pydantic import BaseModel
+
+
+class LLMMetricsContainer(BaseModel):
+    prompt_tokens: int
+    response_tokens: int
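LLMMetricsContainer is a plain pydantic model, so a callback can read or aggregate its fields directly. A small illustrative sketch (the numbers are made up):

from danswer.direct_qa.models import LLMMetricsContainer

metrics = LLMMetricsContainer(prompt_tokens=512, response_tokens=64)
total_tokens = metrics.prompt_tokens + metrics.response_tokens
print(f"LLM used {total_tokens} tokens in total")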
@@ -24,6 +24,7 @@ from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_prompts import JsonProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
@@ -142,7 +143,10 @@ class OpenAICompletionQA(OpenAIQAModel):

     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
     ) -> AnswerQuestionReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

@@ -168,8 +172,7 @@ class OpenAICompletionQA(OpenAIQAModel):
         logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        return process_answer(model_output, context_docs)

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
@@ -1,6 +1,7 @@
 import abc
 import json
 import re
+from collections.abc import Callable
 from collections.abc import Iterator
 from copy import copy

@@ -21,13 +22,16 @@ from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
 from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON
 from danswer.direct_qa.qa_prompts import JsonChatProcessor
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import process_answer
 from danswer.direct_qa.qa_utils import process_model_tokens
 from danswer.llm.llm import LLM
+from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
+from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import str_prompt_to_langchain_prompt
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_up_code_blocks
@@ -245,11 +249,32 @@ class QABlock(QAModel):
         self,
         query: str,
         context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,
     ) -> AnswerQuestionReturn:
         trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
         prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
         model_out = self._llm.invoke(prompt)

+        if metrics_callback is not None:
+            prompt_tokens = sum(
+                [
+                    check_number_of_tokens(
+                        text=p.content, encode_fn=get_default_llm_tokenizer()
+                    )
+                    for p in prompt
+                ]
+            )
+
+            response_tokens = check_number_of_tokens(
+                text=model_out, encode_fn=get_default_llm_tokenizer()
+            )
+
+            metrics_callback(
+                LLMMetricsContainer(
+                    prompt_tokens=prompt_tokens, response_tokens=response_tokens
+                )
+            )
+
         return self._qa_handler.process_llm_output(model_out, trimmed_context_docs)

     def answer_question_stream(
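The metrics block above sums token counts over every prompt message's content and counts the response tokens on the raw model output. A rough standalone equivalent using tiktoken directly, assuming check_number_of_tokens simply returns the length of the encoded text (the sample strings are made up):

import tiktoken

# get_default_llm_tokenizer() caches tiktoken's cl100k_base encoder (see the llm/utils hunk below)
encode = tiktoken.get_encoding("cl100k_base").encode

prompt_messages = [
    "You are a question answering system...",  # illustrative stand-ins for p.content
    "Answer the query based on the documents below.",
]
prompt_tokens = sum(len(encode(text)) for text in prompt_messages)
response_tokens = len(encode('{"answer": "...", "quotes": []}'))  # illustrative model output
print(prompt_tokens, response_tokens)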
@@ -1,5 +1,6 @@
 import abc
 import json
+from collections.abc import Callable
 from collections.abc import Generator

 import requests
@@ -15,6 +16,7 @@ from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_prompts import JsonProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
 from danswer.direct_qa.qa_utils import process_answer
@@ -234,7 +236,10 @@ class RequestCompletionQA(QAModel):

     @log_function_time()
     def answer_question(
-        self, query: str, context_docs: list[InferenceChunk]
+        self,
+        query: str,
+        context_docs: list[InferenceChunk],
+        metrics_callback: Callable[[LLMMetricsContainer], None] | None = None,  # Unused
     ) -> AnswerQuestionReturn:
         model_api_response = self._get_request_response(
             query, context_docs, stream=False
@@ -245,8 +250,7 @@ class RequestCompletionQA(QAModel):
         )
         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        return process_answer(model_output, context_docs)

     def answer_question_stream(
         self,
@@ -1,5 +1,6 @@
+from collections.abc import Callable
 from collections.abc import Iterator
 from typing import Any

 import tiktoken
 from langchain.prompts.base import StringPromptValue
@@ -17,6 +18,16 @@ from danswer.configs.constants import MessageType
 from danswer.db.models import ChatMessage


+_LLM_TOKENIZER: Callable[[str], Any] | None = None
+
+
+def get_default_llm_tokenizer() -> Callable:
+    global _LLM_TOKENIZER
+    if _LLM_TOKENIZER is None:
+        _LLM_TOKENIZER = tiktoken.get_encoding("cl100k_base").encode
+    return _LLM_TOKENIZER
+
+
 def translate_danswer_msg_to_langchain(msg: ChatMessage) -> BaseMessage:
     if (
         msg.message_type == MessageType.SYSTEM
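get_default_llm_tokenizer lazily builds the cl100k_base encoder once and caches it in the module-level _LLM_TOKENIZER, so repeated metric computations do not re-load the encoding. Usage sketch (the query string is made up):

from danswer.llm.utils import get_default_llm_tokenizer

tokenize = get_default_llm_tokenizer()  # first call loads cl100k_base, later calls reuse it
token_ids = tokenize("How many vacation days do we get?")
print(len(token_ids))  # a token count like the ones reported in the LLM metrics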
@@ -1,4 +1,5 @@
 import json
+from collections.abc import Callable
 from uuid import UUID

 from nltk.corpus import stopwords  # type:ignore
@@ -9,6 +10,9 @@ from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import NUM_RETURNED_HITS
 from danswer.datastores.interfaces import DocumentIndex
 from danswer.datastores.interfaces import IndexFilter
+from danswer.search.models import ChunkMetric
+from danswer.search.models import MAX_METRICS_CONTENT
+from danswer.search.models import RetrievalMetricsContainer
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_function_time

@@ -40,13 +44,31 @@ def retrieve_keyword_documents(
     filters: list[IndexFilter] | None,
     datastore: DocumentIndex,
     num_hits: int = NUM_RETURNED_HITS,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
 ) -> list[InferenceChunk] | None:
     edited_query = query_processing(query)
     top_chunks = datastore.keyword_retrieval(edited_query, user_id, filters, num_hits)
+
     if not top_chunks:
         filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
         logger.warning(
             f"Keyword search returned no results - Filters: {filters_log_msg}\tEdited Query: {edited_query}"
         )
         return None
+
+    if retrieval_metrics_callback is not None:
+        chunk_metrics = [
+            ChunkMetric(
+                document_id=chunk.document_id,
+                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
+                first_link=chunk.source_links[0] if chunk.source_links else None,
+                score=chunk.score if chunk.score is not None else 0,
+            )
+            for chunk in top_chunks
+        ]
+        retrieval_metrics_callback(
+            RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
+        )
+
     return top_chunks
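A retrieval_metrics_callback receives one ChunkMetric per returned chunk, in retrieval order. One way an evaluation harness might consume it, checking whether an expected document was retrieved; the expected document id is a made-up value:

from danswer.search.models import RetrievalMetricsContainer


def check_expected_doc(metrics: RetrievalMetricsContainer) -> None:
    expected_doc_id = "confluence__HR_Vacation_Policy"  # illustrative value
    ranks = [
        rank
        for rank, chunk_metric in enumerate(metrics.metrics)
        if chunk_metric.document_id == expected_doc_id
    ]
    search_type = "keyword" if metrics.keyword_search else "vector"
    print(f"{search_type} search: expected doc first seen at rank {ranks[0] if ranks else None}")


# retrieve_keyword_documents(..., retrieval_metrics_callback=check_expected_doc)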
@@ -1,9 +1,16 @@
 from enum import Enum

+from pydantic import BaseModel
+
 from danswer.chunking.models import DocAwareChunk
 from danswer.chunking.models import IndexChunk


+MAX_METRICS_CONTENT = (
+    200  # Just need enough characters to identify where in the doc the chunk is
+)
+
+
 class SearchType(str, Enum):
     KEYWORD = "keyword"  # May be better to also try keyword search if Semantic (AI Search) is on
     SEMANTIC = "semantic"  # Really should try Semantic (AI Search) if keyword is on
@@ -17,3 +24,22 @@ class QueryFlow(str, Enum):
 class Embedder:
     def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
         raise NotImplementedError
+
+
+class ChunkMetric(BaseModel):
+    document_id: str
+    chunk_content_start: str
+    first_link: str | None
+    score: float
+
+
+class RetrievalMetricsContainer(BaseModel):
+    keyword_search: bool  # False for Vector Search
+    metrics: list[ChunkMetric]  # This contains the scores for retrieval as well
+
+
+class RerankMetricsContainer(BaseModel):
+    """The score held by this is the un-boosted, averaged score of the ensemble cross-encoders"""
+
+    metrics: list[ChunkMetric]
+    raw_similarity_scores: list[float]
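Because these containers are pydantic models, a regression run could also persist them for later comparison. A small sketch, assuming the project's pydantic version exposes the standard .json() serialization (all values below are made up):

from danswer.search.models import ChunkMetric
from danswer.search.models import RerankMetricsContainer

container = RerankMetricsContainer(
    metrics=[
        ChunkMetric(
            document_id="web__docs.danswer.dev/quickstart",  # illustrative id
            chunk_content_start="Danswer is an open source question answering tool...",
            first_link="https://docs.danswer.dev/quickstart",
            score=0.72,
        )
    ],
    raw_similarity_scores=[4.3],
)
print(container.json())  # handy for saving run metrics alongside the answer output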
@@ -1,4 +1,5 @@
 import json
+from collections.abc import Callable
 from uuid import UUID

 import numpy
@@ -18,7 +19,11 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier
 from danswer.datastores.interfaces import DocumentIndex
 from danswer.datastores.interfaces import IndexFilter
+from danswer.search.models import ChunkMetric
 from danswer.search.models import Embedder
+from danswer.search.models import MAX_METRICS_CONTENT
+from danswer.search.models import RerankMetricsContainer
+from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.search_utils import get_default_embedding_model
 from danswer.search.search_utils import get_default_reranking_model_ensemble
 from danswer.server.models import SearchDoc
@@ -55,6 +60,7 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
 def semantic_reranking(
     query: str,
     chunks: list[InferenceChunk],
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
 ) -> list[InferenceChunk]:
     model_max = 12  # These are just based on observations from model selection
     model_min = -12
@@ -64,6 +70,8 @@ def semantic_reranking(
         for encoder in cross_encoders
     ]

+    raw_sim_scores = sum(sim_scores) / len(sim_scores)
+
     cross_models_min = numpy.min(sim_scores)

     shifted_sim_scores = sum(
@@ -75,9 +83,9 @@ def semantic_reranking(
     normalized_b_s_scores = (boosted_sim_scores + cross_models_min - model_min) / (
         model_max - model_min
     )
-    scored_results = list(zip(normalized_b_s_scores, chunks))
+    scored_results = list(zip(normalized_b_s_scores, raw_sim_scores, chunks))
     scored_results.sort(key=lambda x: x[0], reverse=True)
-    ranked_sim_scores, ranked_chunks = zip(*scored_results)
+    ranked_sim_scores, ranked_raw_scores, ranked_chunks = zip(*scored_results)

     logger.debug(f"Reranked similarity scores: {ranked_sim_scores}")

@@ -86,6 +94,23 @@ def semantic_reranking(
     for ind, chunk in enumerate(ranked_chunks):
         chunk.score = ranked_sim_scores[ind]

+    if rerank_metrics_callback is not None:
+        chunk_metrics = [
+            ChunkMetric(
+                document_id=chunk.document_id,
+                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
+                first_link=chunk.source_links[0] if chunk.source_links else None,
+                score=chunk.score if chunk.score is not None else 0,
+            )
+            for chunk in ranked_chunks
+        ]
+
+        rerank_metrics_callback(
+            RerankMetricsContainer(
+                metrics=chunk_metrics, raw_similarity_scores=ranked_raw_scores
+            )
+        )
+
     return list(ranked_chunks)


@@ -97,6 +122,9 @@ def retrieve_ranked_documents(
     datastore: DocumentIndex,
     num_hits: int = NUM_RETURNED_HITS,
     num_rerank: int = NUM_RERANKED_RESULTS,
+    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
+    | None = None,
+    rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
 ) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
     """Uses vector similarity to fetch the top num_hits document chunks with a distance cutoff.
     Reranks the top num_rerank out of those (instead of all due to latency)"""
@@ -108,7 +136,24 @@ def retrieve_ranked_documents(
         )
         return None, None
     logger.debug(top_chunks)
-    ranked_chunks = semantic_reranking(query, top_chunks[:num_rerank])
+
+    if retrieval_metrics_callback is not None:
+        chunk_metrics = [
+            ChunkMetric(
+                document_id=chunk.document_id,
+                chunk_content_start=chunk.content[:MAX_METRICS_CONTENT],
+                first_link=chunk.source_links[0] if chunk.source_links else None,
+                score=chunk.score if chunk.score is not None else 0,
+            )
+            for chunk in top_chunks
+        ]
+        retrieval_metrics_callback(
+            RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
+        )
+
+    ranked_chunks = semantic_reranking(
+        query, top_chunks[:num_rerank], rerank_metrics_callback=rerank_metrics_callback
+    )

     top_docs = [
         ranked_chunk.source_links[0]
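The rerank callback fires after scores have been normalized and written back onto the chunks, so each ChunkMetric.score is the final ranking score while raw_similarity_scores holds the unboosted cross-encoder averages in the same ranked order. A sketch of a callback that prints both side by side:

from danswer.search.models import RerankMetricsContainer


def show_rerank_scores(metrics: RerankMetricsContainer) -> None:
    for chunk_metric, raw_score in zip(metrics.metrics, metrics.raw_similarity_scores):
        print(f"{chunk_metric.document_id}: final={chunk_metric.score:.3f} raw={raw_score:.3f}")


# retrieve_ranked_documents(..., rerank_metrics_callback=show_rerank_scores)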
backend/danswer/utils/callbacks.py (new file)
@@ -0,0 +1,12 @@
+from typing import Generic
+from typing import TypeVar
+
+T = TypeVar("T")
+
+
+class MetricsHander(Generic[T]):
+    def __init__(self) -> None:
+        self.metrics: T | None = None
+
+    def record_metric(self, metrics: T) -> None:
+        self.metrics = metrics
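MetricsHander[T] is a tiny adapter: its bound record_metric method can be handed anywhere a Callable[[T], None] is expected, and the last value passed in is kept on .metrics (None if the callback never fired). Usage sketch with made-up token counts:

from danswer.direct_qa.models import LLMMetricsContainer
from danswer.utils.callbacks import MetricsHander

llm_metrics = MetricsHander[LLMMetricsContainer]()

# e.g. answer_qa_query(..., llm_metrics_callback=llm_metrics.record_metric)
llm_metrics.record_metric(LLMMetricsContainer(prompt_tokens=512, response_tokens=64))

print(llm_metrics.metrics)  # most recently recorded container, or None if never called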
@@ -1,24 +1,42 @@
 import argparse
+import builtins
+from contextlib import contextmanager
 from datetime import datetime
+from typing import Any
+from typing import TextIO

 import yaml
 from sqlalchemy.orm import Session

 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.direct_qa.answer_question import answer_qa_query
+from danswer.direct_qa.models import LLMMetricsContainer
+from danswer.search.models import RerankMetricsContainer
+from danswer.search.models import RetrievalMetricsContainer
 from danswer.server.models import QuestionRequest
+from danswer.utils.callbacks import MetricsHander


 engine = get_sqlalchemy_engine()


+@contextmanager
+def redirect_print_to_file(file: TextIO) -> Any:
+    original_print = builtins.print
+    builtins.print = lambda *args, **kwargs: original_print(*args, file=file, **kwargs)
+    try:
+        yield
+    finally:
+        builtins.print = original_print
+
+
 def load_yaml(filepath: str) -> dict:
     with open(filepath, "r") as file:
         data = yaml.safe_load(file)
     return data


-def word_wrap(s: str, max_line_size: int = 120) -> str:
+def word_wrap(s: str, max_line_size: int = 100, prepend_tab: bool = True) -> str:
     words = s.split()

     current_line: list[str] = []
@@ -45,27 +63,85 @@ def word_wrap(s: str, max_line_size: int = 120) -> str:
     if current_line:
         result_lines.append(" ".join(current_line))

-    return "\n".join(result_lines)
+    return "\t" + "\n\t".join(result_lines) if prepend_tab else "\n".join(result_lines)


-def get_answer_for_question(query: str, db_session: Session) -> str | None:
+def get_answer_for_question(
+    query: str, db_session: Session
+) -> tuple[
+    str | None,
+    RetrievalMetricsContainer | None,
+    RerankMetricsContainer | None,
+    LLMMetricsContainer | None,
+]:
     question = QuestionRequest(
         query=query,
         collection="danswer_index",
-        use_keyword=None,
+        use_keyword=False,
         filters=None,
         offset=None,
     )

+    retrieval_metrics = MetricsHander[RetrievalMetricsContainer]()
+    rerank_metrics = MetricsHander[RerankMetricsContainer]()
+    llm_metrics = MetricsHander[LLMMetricsContainer]()
+
     answer = answer_qa_query(
         question=question,
         user=None,
         db_session=db_session,
         answer_generation_timeout=100,
         real_time_flow=False,
+        enable_reflexion=False,
+        retrieval_metrics_callback=retrieval_metrics.record_metric,
+        rerank_metrics_callback=rerank_metrics.record_metric,
+        llm_metrics_callback=llm_metrics.record_metric,
     )

-    return answer.answer
+    return (
+        answer.answer,
+        retrieval_metrics.metrics,
+        rerank_metrics.metrics,
+        llm_metrics.metrics,
+    )
+
+
+def _print_retrieval_metrics(
+    metrics_container: RetrievalMetricsContainer, show_all: bool
+) -> None:
+    for ind, metric in enumerate(metrics_container.metrics):
+        if not show_all and ind >= 10:
+            break
+
+        if ind != 0:
+            print()  # for spacing purposes
+        print(f"\tDocument: {metric.document_id}")
+        print(f"\tLink: {metric.first_link or 'NA'}")
+        section_start = metric.chunk_content_start.replace("\n", " ")
+        print(f"\tSection Start: {section_start}")
+        print(f"\tSimilarity Distance Metric: {metric.score}")
+
+
+def _print_reranking_metrics(
+    metrics_container: RerankMetricsContainer, show_all: bool
+) -> None:
+    # Printing the raw scores as they're more informational than post-norm/boosting
+    for ind, metric in enumerate(metrics_container.metrics):
+        if not show_all and ind >= 10:
+            break
+
+        if ind != 0:
+            print()  # for spacing purposes
+        print(f"\tDocument: {metric.document_id}")
+        print(f"\tLink: {metric.first_link or 'NA'}")
+        section_start = metric.chunk_content_start.replace("\n", " ")
+        print(f"\tSection Start: {section_start}")
+        print(f"\tSimilarity Score: {metrics_container.raw_similarity_scores[ind]}")
+
+
+def _print_llm_metrics(metrics_container: LLMMetricsContainer) -> None:
+    print(f"\tPrompt Tokens: {metrics_container.prompt_tokens}")
+    print(f"\tResponse Tokens: {metrics_container.response_tokens}")


 if __name__ == "__main__":
@@ -80,6 +156,16 @@ if __name__ == "__main__":
     parser.add_argument(
         "--real-time", action="store_true", help="Set to use the real-time flow."
     )
+    parser.add_argument(
+        "--discard-metrics",
+        action="store_true",
+        help="Set to not include metrics on search, rerank, and token counts.",
+    )
+    parser.add_argument(
+        "--all-results",
+        action="store_true",
+        help="Set to not include more than the 10 top sections for search and reranking metrics.",
+    )
     parser.add_argument(
         "--output",
         type=str,
@@ -91,27 +177,65 @@ if __name__ == "__main__":
     questions_data = load_yaml(args.regression_yaml)

     with open(args.output, "w") as outfile:
-        print("Running Question Answering Flow", file=outfile)
+        with redirect_print_to_file(outfile):
+            print("Running Question Answering Flow")
+            print(
+                "Note that running metrics requires tokenizing all "
+                "prompts/returns and slightly slows down inference."
+            )
+            print(
+                "Also note that the text embedding model (bi-encoder) currently used is trained for "
+                "relative distances, not absolute distances. Therefore cosine similarity values may all be > 0.5 "
+                "even for poor matches"
+            )

-        with Session(engine, expire_on_commit=False) as db_session:
-            for sample in questions_data["questions"]:
-                # This line goes to stdout to track progress
-                print(f"Running Test for Question {sample['id']}: {sample['question']}")
+            with Session(engine, expire_on_commit=False) as db_session:
+                for sample in questions_data["questions"]:
+                    print(
+                        f"Running Test for Question {sample['id']}: {sample['question']}"
+                    )

-                start_time = datetime.now()
-                answer = get_answer_for_question(sample["question"], db_session)
-                end_time = datetime.now()
+                    start_time = datetime.now()
+                    (
+                        answer,
+                        retrieval_metrics,
+                        rerank_metrics,
+                        llm_metrics,
+                    ) = get_answer_for_question(sample["question"], db_session)
+                    end_time = datetime.now()

-                print(f"====Duration: {end_time - start_time}====", file=outfile)
-                print(f"Question {sample['id']}:", file=outfile)
-                print(sample["question"], file=outfile)
-                print("\nExpected Answer:", file=outfile)
-                print(sample["expected_answer"], file=outfile)
-                print("\nActual Answer:", file=outfile)
-                print(
-                    word_wrap(answer)
-                    if answer
-                    else "Failed, either crashed or refused to answer.",
-                    file=outfile,
-                )
-                print("\n\n", file=outfile, flush=True)
+                    print(f"====Duration: {end_time - start_time}====")
+                    print(f"Question {sample['id']}:")
+                    print(f'\t{sample["question"]}')
+                    print("\nApproximate Expected Answer:")
+                    print(f'\t{sample["expected_answer"]}')
+                    print("\nActual Answer:")
+                    print(
+                        word_wrap(answer)
+                        if answer
+                        else "\tFailed, either crashed or refused to answer."
+                    )
+                    if not args.discard_metrics:
+                        print("\nLLM Tokens Usage:")
+                        if llm_metrics is None:
+                            print("No LLM Metrics Available")
+                        else:
+                            _print_llm_metrics(llm_metrics)
+
+                        print("\nRetrieval Metrics:")
+                        if retrieval_metrics is None:
+                            print("No Retrieval Metrics Available")
+                        else:
+                            _print_retrieval_metrics(
+                                retrieval_metrics, show_all=args.all_results
+                            )
+
+                        print("\nReranking Metrics:")
+                        if rerank_metrics is None:
+                            print("No Reranking Metrics Available")
+                        else:
+                            _print_reranking_metrics(
+                                rerank_metrics, show_all=args.all_results
+                            )
+
+                    print("\n\n", flush=True)
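For reference, the script iterates over questions_data["questions"] and reads id, question, and expected_answer from each entry; this shape is inferred from the key accesses above, and the sample values below are made up. The loaded regression YAML would deserialize to something like:

# Shape of the dict load_yaml() must return for this script (illustrative values)
questions_data = {
    "questions": [
        {
            "id": 1,
            "question": "How many vacation days do new employees get?",
            "expected_answer": "New employees get 20 paid vacation days per year.",
        },
        {
            "id": 2,
            "question": "Where is the on-call runbook?",
            "expected_answer": "The on-call runbook is in the Engineering Confluence space.",
        },
    ]
}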