Make Cross Encoders Optional (#476)

Repository: https://github.com/danswer-ai/danswer.git
Commit 8b95e2631d (parent 3c65317538)
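In short: this commit makes the cross-encoder reranking pass optional. A SKIP_RERANKING environment variable disables the ensemble both at model warm-up and in the search flow; when reranking is skipped, a new apply_boost step min-max normalizes the raw similarity scores into the SIM_SCORE_RANGE_LOW/SIM_SCORE_RANGE_HIGH display range and applies document boost multipliers instead. The embedding model name, embedding normalization, and the asymmetric query/passage prefixes also become environment-configurable, and all of the new settings are plumbed through docker-compose.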
@@ -66,7 +66,7 @@ def mark_run_failed(
     connector_credential_pair` to reflect that the run failed"""
     logger.warning(
         f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
-        f"credential: {index_attempt.credential_id}' as failed"
+        f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
     )
     mark_attempt_failed(
         index_attempt=index_attempt,
@@ -3,36 +3,50 @@ import os
 from danswer.configs.constants import DanswerGenAIModel
 from danswer.configs.constants import ModelHostType


 #####
 # Embedding/Reranking Model Configs
 #####
 # Important considerations when choosing models
 # Max tokens count needs to be high considering use case (at least 512)
 # Models used must be MIT or Apache license
 # Inference/Indexing speed

-# https://huggingface.co/thenlper/gte-small
-DOCUMENT_ENCODER_MODEL = "thenlper/gte-small"
-DOC_EMBEDDING_DIM = 384  # Depends on the document encoder model
-NORMALIZE_EMBEDDINGS = False
-# Certain models like BGE use a prefix for asymmetric retrievals (query generally shorter than docs)
-ASYMMETRIC_PREFIX = ""
+# https://huggingface.co/DOCUMENT_ENCODER_MODEL
+# The useable models configured as below must be SentenceTransformer compatible
+DOCUMENT_ENCODER_MODEL = (
+    os.environ.get("DOCUMENT_ENCODER_MODEL") or "thenlper/gte-small"
+)
+# If the below is changed, Vespa deployment must also be changed
+DOC_EMBEDDING_DIM = 384
+# Model should be chosen with 512 context size, ideally don't change this
+DOC_EMBEDDING_CONTEXT_SIZE = 512
+NORMALIZE_EMBEDDINGS = (os.environ.get("NORMALIZE_EMBEDDINGS") or "False").lower() == "true"
+# These are only used if reranking is turned off, to normalize the direct retrieval scores for display
+SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
+SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
+# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
+ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "")
+ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "")
+# Purely an optimization, memory limitation consideration
+BATCH_SIZE_ENCODE_CHUNKS = 8
+
+# Cross Encoder Settings
+SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
 # https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
 CROSS_ENCODER_MODEL_ENSEMBLE = [
     "cross-encoder/ms-marco-MiniLM-L-4-v2",
     "cross-encoder/ms-marco-TinyBERT-L-2-v2",
 ]
+CROSS_EMBED_CONTEXT_SIZE = 512


 # Better to keep it loose, surfacing more results better than missing results
 # Currently unused by Vespa
 SEARCH_DISTANCE_CUTOFF = 0.1  # Cosine similarity (currently), range of -1 to 1 with -1 being completely opposite

 # Intent model max context size
 QUERY_MAX_CONTEXT_SIZE = 256
-# The below is correlated with CHUNK_SIZE in app_configs but not strictly calculated
-# To avoid extra overhead of tokenizing for chunking during indexing.
-DOC_EMBEDDING_CONTEXT_SIZE = 512
-CROSS_EMBED_CONTEXT_SIZE = 512
-
-# Purely an optimization, memory limitation consideration
-BATCH_SIZE_ENCODE_CHUNKS = 8


 #####
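For context, these settings are consumed wherever the embedding model is loaded. A minimal sketch of that consumption, assuming a SentenceTransformer-compatible model as the comment above requires (the sample query and the standalone re-reading of the env vars are illustrative, not code from this commit):

import os

from sentence_transformers import SentenceTransformer

# Resolve configuration the same way model_configs.py does
DOCUMENT_ENCODER_MODEL = os.environ.get("DOCUMENT_ENCODER_MODEL") or "thenlper/gte-small"
NORMALIZE_EMBEDDINGS = (os.environ.get("NORMALIZE_EMBEDDINGS") or "False").lower() == "true"
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "")

model = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
# Queries get the query prefix; passages would get ASYM_PASSAGE_PREFIX instead
query_embedding = model.encode(
    ASYM_QUERY_PREFIX + "how do I connect a Slack workspace?",
    normalize_embeddings=NORMALIZE_EMBEDDINGS,
)
print(query_embedding.shape)  # (384,) for gte-small, matching DOC_EMBEDDING_DIM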
@@ -243,6 +243,9 @@ class QABlock(QAModel):
         """This is called during server start up to load the models into memory
         in case the chosen LLM is not accessed via API"""
+        if self._llm.requires_warm_up:
+            logger.info(
+                "Warming up LLM, this should only run for in memory LLMs like GPT4All"
+            )
+            self._llm.invoke("Ignore this!")

     def answer_question(
@@ -26,8 +26,12 @@ from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.model_configs import API_BASE_OPENAI
 from danswer.configs.model_configs import API_TYPE_OPENAI
+from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
+from danswer.configs.model_configs import ASYM_QUERY_PREFIX
+from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
+from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.datastores.document_index import get_default_document_index
 from danswer.db.credentials import create_initial_public_credential
 from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -179,6 +183,13 @@ def get_application() -> FastAPI:
     else:
         logger.debug("OAuth is turned on")

+    if SKIP_RERANKING:
+        logger.info("Reranking step of search flow is disabled")
+
+    logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"')
+    logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
+    logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
+
     logger.info("Warming up local NLP models.")
     warm_up_models()
     qa_model = get_default_qa_model()
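For a local, non-Docker run the toggle is just an environment variable; something like the following would surface the new startup log lines (the uvicorn invocation is an assumption about how the app is served, not part of this diff):

SKIP_RERANKING=true LOG_LEVEL=debug uvicorn danswer.main:app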
@@ -9,6 +9,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.configs.model_configs import INTENT_MODEL_VERSION
 from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
+from danswer.configs.model_configs import SKIP_RERANKING


 _TOKENIZER: None | AutoTokenizer = None
@@ -61,7 +62,9 @@ def get_default_intent_model() -> TFDistilBertForSequenceClassification:
     return _INTENT_MODEL


-def warm_up_models(indexer_only: bool = False) -> None:
+def warm_up_models(
+    indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
+) -> None:
     warm_up_str = "Danswer is amazing"
     get_default_tokenizer()(warm_up_str)
     get_default_embedding_model().encode(warm_up_str)
@@ -69,11 +72,13 @@ def warm_up_models(indexer_only: bool = False) -> None:
     if indexer_only:
         return

-    cross_encoders = get_default_reranking_model_ensemble()
-    [
-        cross_encoder.predict((warm_up_str, warm_up_str))
-        for cross_encoder in cross_encoders
-    ]
+    if not skip_cross_encoders:
+        cross_encoders = get_default_reranking_model_ensemble()
+        [
+            cross_encoder.predict((warm_up_str, warm_up_str))
+            for cross_encoder in cross_encoders
+        ]

     intent_tokenizer = get_default_intent_model_tokenizer()
     inputs = intent_tokenizer(
         warm_up_str, return_tensors="tf", truncation=True, padding=True
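skip_cross_encoders only short-circuits loading of the ensemble. For readers without the rest of this file, a plausible shape for get_default_reranking_model_ensemble, assuming sentence-transformers' CrossEncoder; the lazy-caching detail is a sketch, not quoted from the commit:

from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from sentence_transformers import CrossEncoder

_RERANK_MODELS: list[CrossEncoder] | None = None


def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
    # Lazily load and cache one CrossEncoder per configured model name
    global _RERANK_MODELS
    if _RERANK_MODELS is None:
        _RERANK_MODELS = [
            CrossEncoder(model_name, max_length=CROSS_EMBED_CONTEXT_SIZE)
            for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
        ]
    return _RERANK_MODELS

With SKIP_RERANKING=true, warm_up_models never touches this path, so neither ms-marco model is downloaded or loaded at startup.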
@@ -13,9 +13,13 @@ from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
 from danswer.configs.app_configs import NUM_RERANKED_RESULTS
 from danswer.configs.app_configs import NUM_RETURNED_HITS
-from danswer.configs.model_configs import ASYMMETRIC_PREFIX
+from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
+from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
+from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
+from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
+from danswer.configs.model_configs import SKIP_RERANKING
 from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier
 from danswer.datastores.interfaces import DocumentIndex
 from danswer.datastores.interfaces import IndexFilter
@@ -114,6 +118,52 @@ def semantic_reranking(
     return list(ranked_chunks)


+def apply_boost(
+    chunks: list[InferenceChunk],
+    norm_min: float = SIM_SCORE_RANGE_LOW,
+    norm_max: float = SIM_SCORE_RANGE_HIGH,
+) -> list[InferenceChunk]:
+    scores = [chunk.score or 0 for chunk in chunks]
+    boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
+
+    logger.debug(f"Raw similarity scores: {scores}")
+
+    score_min = min(scores)
+    score_max = max(scores)
+    score_range = score_max - score_min
+
+    boosted_scores = [
+        ((score - score_min) / score_range) * boost
+        for score, boost in zip(scores, boosts)
+    ]
+
+    unnormed_boosted_scores = [
+        score * score_range + score_min for score in boosted_scores
+    ]
+
+    norm_min = min(norm_min, min(scores))
+    norm_max = max(norm_max, max(scores))
+
+    # For score display purposes
+    re_normed_scores = [
+        ((score - norm_min) / (norm_max - norm_min))
+        for score in unnormed_boosted_scores
+    ]
+
+    rescored_chunks = list(zip(re_normed_scores, chunks))
+    rescored_chunks.sort(key=lambda x: x[0], reverse=True)
+    sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)
+
+    final_chunks = list(boost_sorted_chunks)
+    final_scores = list(sorted_boosted_scores)
+    for ind, chunk in enumerate(final_chunks):
+        chunk.score = final_scores[ind]
+
+    logger.debug(f"Boost sorted similarity scores: {list(final_scores)}")
+
+    return final_chunks
+
+
 @log_function_time()
 def retrieve_ranked_documents(
     query: str,
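To make apply_boost's normalize-boost-denormalize round trip concrete, here is a standalone rundown with made-up numbers (illustrative only):

# Three chunks; the middle one carries a 1.5x boost multiplier
scores = [0.9, 0.5, 0.1]
boosts = [1.0, 1.5, 1.0]

score_min, score_max = min(scores), max(scores)
score_range = score_max - score_min  # 0.8

# Min-max normalize to [0, 1], then apply the boost
boosted = [((s - score_min) / score_range) * b for s, b in zip(scores, boosts)]
# boosted ~= [1.0, 0.75, 0.0] -- the boost narrows the gap to the top chunk

# Map back onto the raw score scale
unnormed = [s * score_range + score_min for s in boosted]
# unnormed ~= [0.9, 0.7, 0.1]

# Re-normalize against SIM_SCORE_RANGE_LOW/HIGH (defaults 0.0/1.0) for display
norm_min, norm_max = min(0.0, min(scores)), max(1.0, max(scores))
display = [(s - norm_min) / (norm_max - norm_min) for s in unnormed]
# display ~= [0.9, 0.7, 0.1]

Note the last step: because the display range defaults to [0, 1] while raw cosine scores usually cluster in a narrower band, setting SIM_SCORE_RANGE_LOW/HIGH close to the band the chosen encoder actually produces keeps the displayed scores meaningful.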
@@ -122,12 +172,20 @@ def retrieve_ranked_documents(
     datastore: DocumentIndex,
     num_hits: int = NUM_RETURNED_HITS,
     num_rerank: int = NUM_RERANKED_RESULTS,
+    skip_rerank: bool = SKIP_RERANKING,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
     | None = None,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
 ) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
     """Uses vector similarity to fetch the top num_hits document chunks with a distance cutoff.
     Reranks the top num_rerank out of those (instead of all due to latency)"""
+
+    def _log_top_chunk_links(chunks: list[InferenceChunk]) -> None:
+        doc_links = [c.source_links[0] for c in chunks if c.source_links is not None]
+
+        files_log_msg = f"Top links from semantic search: {', '.join(doc_links)}"
+        logger.info(files_log_msg)
+
     top_chunks = datastore.semantic_retrieval(query, user_id, filters, num_hits)
     if not top_chunks:
         filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
@@ -151,20 +209,23 @@ def retrieve_ranked_documents(
             RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
         )

-    ranked_chunks = semantic_reranking(
-        query, top_chunks[:num_rerank], rerank_metrics_callback=rerank_metrics_callback
+    if skip_rerank:
+        # Need the range of values to not be too spread out for applying boost
+        boosted_chunks = apply_boost(top_chunks[:num_rerank])
+        _log_top_chunk_links(boosted_chunks)
+        return boosted_chunks, top_chunks[num_rerank:]
+
+    ranked_chunks = (
+        semantic_reranking(
+            query,
+            top_chunks[:num_rerank],
+            rerank_metrics_callback=rerank_metrics_callback,
+        )
+        if not skip_rerank
+        else []
     )

-    top_docs = [
-        ranked_chunk.source_links[0]
-        for ranked_chunk in ranked_chunks
-        if ranked_chunk.source_links is not None
-    ]
-
-    files_log_msg = (
-        f"Top links from semantic search: {', '.join(list(dict.fromkeys(top_docs)))}"
-    )
-    logger.info(files_log_msg)
+    _log_top_chunk_links(ranked_chunks)

     return ranked_chunks, top_chunks[num_rerank:]
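A hypothetical call site for the new early-return path; user_id and filters are placeholders here, and only skip_rerank is the point:

ranked_chunks, remaining_chunks = retrieve_ranked_documents(
    query="how do I rotate an API key?",
    user_id=None,
    filters=None,
    datastore=get_default_document_index(),
    skip_rerank=True,  # apply_boost ordering, no cross-encoders involved
)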
@@ -175,6 +236,7 @@ def encode_chunks(
     embedding_model: SentenceTransformer | None = None,
     batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
     enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
+    passage_prefix: str = ASYM_PASSAGE_PREFIX,
 ) -> list[IndexChunk]:
     embedded_chunks: list[IndexChunk] = []
     if embedding_model is None:
@@ -183,14 +245,15 @@ def encode_chunks(
     chunk_texts = []
     chunk_mini_chunks_count = {}
     for chunk_ind, chunk in enumerate(chunks):
-        chunk_texts.append(chunk.content)
+        chunk_texts.append(passage_prefix + chunk.content)
         mini_chunk_texts = (
             split_chunk_text_into_mini_chunks(chunk.content)
             if enable_mini_chunk
             else []
         )
-        chunk_texts.extend(mini_chunk_texts)
-        chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
+        prefixed_mini_chunk_texts = [passage_prefix + text for text in mini_chunk_texts]
+        chunk_texts.extend(prefixed_mini_chunk_texts)
+        chunk_mini_chunks_count[chunk_ind] = 1 + len(prefixed_mini_chunk_texts)

     text_batches = [
         chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
@@ -228,7 +291,7 @@ def encode_chunks(
 def embed_query(
     query: str,
     embedding_model: SentenceTransformer | None = None,
-    prefix: str = ASYMMETRIC_PREFIX,
+    prefix: str = ASYM_QUERY_PREFIX,
     normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
     model = embedding_model or get_default_embedding_model()
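Splitting the old ASYMMETRIC_PREFIX into distinct query and passage prefixes is what e5-family models expect. A quick illustration; the model choice and prefix strings follow e5's documented convention and are not defaults shipped by this commit:

from sentence_transformers import SentenceTransformer

QUERY_PREFIX = "query: "      # would go in ASYM_QUERY_PREFIX
PASSAGE_PREFIX = "passage: "  # would go in ASYM_PASSAGE_PREFIX

model = SentenceTransformer("intfloat/e5-small-v2")  # 384-dim, same as gte-small
query_vec = model.encode(QUERY_PREFIX + "reset a user password", normalize_embeddings=True)
passage_vec = model.encode(
    PASSAGE_PREFIX + "Administrators can reset passwords from the user settings page.",
    normalize_embeddings=True,
)
print(float(query_vec @ passage_vec))  # cosine similarity, since both vectors are unit length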
@@ -24,7 +24,6 @@ services:
       - NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
       - POSTGRES_HOST=relational_db
       - VESPA_HOST=index
-      - LOG_LEVEL=${LOG_LEVEL:-info}
       - DISABLE_AUTH=${DISABLE_AUTH:-True}
       - QA_TIMEOUT=${QA_TIMEOUT:-}
       - VALID_EMAIL_DOMAINS=${VALID_EMAIL_DOMAINS:-}
@@ -37,6 +36,16 @@ services:
       - API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
       - API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
       - AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
+      # Don't change the NLP model configs unless you know what you're doing
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - SIM_SCORE_RANGE_LOW=${SIM_SCORE_RANGE_LOW:-}
+      - SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
+      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
+      - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
+      - SKIP_RERANKING=${SKIP_RERANKING:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage
@@ -71,6 +80,15 @@ services:
       - DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-}
       - DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-}
      - DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-}
+      # Don't change the NLP model configs unless you know what you're doing
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - SIM_SCORE_RANGE_LOW=${SIM_SCORE_RANGE_LOW:-}
+      - SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
+      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
+      - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
+      - SKIP_RERANKING=${SKIP_RERANKING:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:
       - local_dynamic_storage:/home/storage
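All of these variables pass straight through docker compose, so a minimal .env next to docker-compose.yml is enough to exercise the new toggle. The values below are examples, not shipped defaults; note that swapping the encoder must respect the 384-dimension constraint called out in model_configs.py:

# Example .env -- illustrative values only
SKIP_RERANKING=true
SIM_SCORE_RANGE_LOW=0.4
SIM_SCORE_RANGE_HIGH=0.9
# Optional e5-style encoder swap; e5-small-v2 is also 384-dim
DOCUMENT_ENCODER_MODEL=intfloat/e5-small-v2
NORMALIZE_EMBEDDINGS=true
ASYM_QUERY_PREFIX="query: "
ASYM_PASSAGE_PREFIX="passage: "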