Make Cross Encoders Optional (#476)

This commit is contained in:
Yuhong Sun 2023-09-23 17:17:54 -07:00 committed by GitHub
parent 3c65317538
commit 8b95e2631d
7 changed files with 152 additions and 38 deletions

View File

@@ -66,7 +66,7 @@ def mark_run_failed(
connector_credential_pair` to reflect that the run failed"""
logger.warning(
f"Marking in-progress attempt 'connector: {index_attempt.connector_id}, "
f"credential: {index_attempt.credential_id}' as failed"
f"credential: {index_attempt.credential_id}' as failed due to {failure_reason}"
)
mark_attempt_failed(
index_attempt=index_attempt,

View File

@@ -3,36 +3,50 @@ import os
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
#####
# Embedding/Reranking Model Configs
#####
# Important considerations when choosing models
# Max tokens count needs to be high considering use case (at least 512)
# Models used must be MIT or Apache license
# Inference/Indexing speed
# https://huggingface.co/thenlper/gte-small
DOCUMENT_ENCODER_MODEL = "thenlper/gte-small"
DOC_EMBEDDING_DIM = 384 # Depends on the document encoder model
NORMALIZE_EMBEDDINGS = False
# Certain models like BGE use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYMMETRIC_PREFIX = ""
# https://huggingface.co/DOCUMENT_ENCODER_MODEL
# The usable models configured below must be SentenceTransformer compatible
DOCUMENT_ENCODER_MODEL = (
os.environ.get("DOCUMENT_ENCODER_MODEL") or "thenlper/gte-small"
)
# If the below is changed, Vespa deployment must also be changed
DOC_EMBEDDING_DIM = 384
# Model should be chosen with 512 context size, ideally don't change this
DOC_EMBEDDING_CONTEXT_SIZE = 512
NORMALIZE_EMBEDDINGS = (os.environ.get("SKIP_RERANKING") or "False").lower() == "true"
# These are only used if reranking is turned off, to normalize the direct retrieval scores for display
SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0)
SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0)
# Certain models like e5, BGE, etc use a prefix for asymmetric retrievals (query generally shorter than docs)
ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# Cross Encoder Settings
SKIP_RERANKING = os.environ.get("SKIP_RERANKING", "").lower() == "true"
# https://www.sbert.net/docs/pretrained-models/ce-msmarco.html
CROSS_ENCODER_MODEL_ENSEMBLE = [
"cross-encoder/ms-marco-MiniLM-L-4-v2",
"cross-encoder/ms-marco-TinyBERT-L-2-v2",
]
CROSS_EMBED_CONTEXT_SIZE = 512
# Better to keep it loose; surfacing more results is better than missing results
# Currently unused by Vespa
SEARCH_DISTANCE_CUTOFF = 0.1 # Cosine similarity (currently), range of -1 to 1 with -1 being completely opposite
# Intent model max context size
QUERY_MAX_CONTEXT_SIZE = 256
# The below is correlated with CHUNK_SIZE in app_configs but is not strictly calculated from it,
# to avoid the extra overhead of tokenizing for chunking during indexing.
DOC_EMBEDDING_CONTEXT_SIZE = 512
CROSS_EMBED_CONTEXT_SIZE = 512
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
#####
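For orientation, a minimal sketch (not code from this commit) of how the configs above are typically exercised, assuming the sentence_transformers API; the query and passage strings are invented:

from sentence_transformers import SentenceTransformer, util

# DOCUMENT_ENCODER_MODEL default from the diff above
model = SentenceTransformer("thenlper/gte-small")

# NORMALIZE_EMBEDDINGS defaults to False in this commit
normalize = False

query_emb = model.encode("how do I reset my password", normalize_embeddings=normalize)
doc_emb = model.encode("To reset your password, open Settings > Account.", normalize_embeddings=normalize)

# Cosine similarity falls in [-1, 1]; when reranking is skipped these raw scores
# are what get surfaced, and SIM_SCORE_RANGE_LOW/HIGH (0.0 and 1.0 by default)
# define the range used to re-normalize them for display.
print(util.cos_sim(query_emb, doc_emb).item())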

View File

@@ -243,6 +243,9 @@ class QABlock(QAModel):
"""This is called during server start up to load the models into memory
in case the chosen LLM is not accessed via API"""
if self._llm.requires_warm_up:
logger.info(
"Warming up LLM, this should only run for in memory LLMs like GPT4All"
)
self._llm.invoke("Ignore this!")
def answer_question(

View File

@@ -26,8 +26,12 @@ from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.datastores.document_index import get_default_document_index
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_qa_model
@@ -179,6 +183,13 @@ def get_application() -> FastAPI:
else:
logger.debug("OAuth is turned on")
if SKIP_RERANKING:
logger.info("Reranking step of search flow is disabled")
logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"')
logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
logger.info("Warming up local NLP models.")
warm_up_models()
qa_model = get_default_qa_model()

View File

@@ -9,6 +9,7 @@ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from danswer.configs.model_configs import SKIP_RERANKING
_TOKENIZER: None | AutoTokenizer = None
@@ -61,7 +62,9 @@ def get_default_intent_model() -> TFDistilBertForSequenceClassification:
return _INTENT_MODEL
def warm_up_models(indexer_only: bool = False) -> None:
def warm_up_models(
indexer_only: bool = False, skip_cross_encoders: bool = SKIP_RERANKING
) -> None:
warm_up_str = "Danswer is amazing"
get_default_tokenizer()(warm_up_str)
get_default_embedding_model().encode(warm_up_str)
@@ -69,11 +72,13 @@ def warm_up_models(indexer_only: bool = False) -> None:
if indexer_only:
return
cross_encoders = get_default_reranking_model_ensemble()
[
cross_encoder.predict((warm_up_str, warm_up_str))
for cross_encoder in cross_encoders
]
if not skip_cross_encoders:
cross_encoders = get_default_reranking_model_ensemble()
[
cross_encoder.predict((warm_up_str, warm_up_str))
for cross_encoder in cross_encoders
]
intent_tokenizer = get_default_intent_model_tokenizer()
inputs = intent_tokenizer(
warm_up_str, return_tensors="tf", truncation=True, padding=True
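The warm-up above runs each cross encoder once so the first real query does not pay model-load latency. As a rough sketch of how such an ensemble can be used to rerank (the semantic_reranking body itself is not part of this diff), assuming the sentence_transformers CrossEncoder API and a simple score average, with invented query/passage strings:

from sentence_transformers import CrossEncoder

# Models listed in CROSS_ENCODER_MODEL_ENSEMBLE
ensemble = [
    CrossEncoder("cross-encoder/ms-marco-MiniLM-L-4-v2", max_length=512),
    CrossEncoder("cross-encoder/ms-marco-TinyBERT-L-2-v2", max_length=512),
]

query = "how do I rotate my API key"
passages = [
    "Rotate keys under Settings -> Security -> API keys.",
    "Our office dog is named Biscuit.",
]

# Each cross encoder scores (query, passage) pairs jointly; averaging the
# ensemble's scores is one simple way to combine them (an assumption here).
per_model = [ce.predict([(query, p) for p in passages]) for ce in ensemble]
avg_scores = [float(sum(s)) / len(per_model) for s in zip(*per_model)]
reranked = [p for _, p in sorted(zip(avg_scores, passages), reverse=True)]
print(reranked)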

View File

@@ -13,9 +13,13 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import ASYMMETRIC_PREFIX
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH
from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import IndexFilter
@@ -114,6 +118,52 @@ def semantic_reranking(
return list(ranked_chunks)
def apply_boost(
chunks: list[InferenceChunk],
norm_min: float = SIM_SCORE_RANGE_LOW,
norm_max: float = SIM_SCORE_RANGE_HIGH,
) -> list[InferenceChunk]:
scores = [chunk.score or 0 for chunk in chunks]
boosts = [translate_boost_count_to_multiplier(chunk.boost) for chunk in chunks]
logger.debug(f"Raw similarity scores: {scores}")
score_min = min(scores)
score_max = max(scores)
score_range = score_max - score_min
boosted_scores = [
((score - score_min) / score_range) * boost
for score, boost in zip(scores, boosts)
]
unnormed_boosted_scores = [
score * score_range + score_min for score in boosted_scores
]
norm_min = min(norm_min, min(scores))
norm_max = max(norm_max, max(scores))
# For score display purposes
re_normed_scores = [
((score - norm_min) / (norm_max - norm_min))
for score in unnormed_boosted_scores
]
rescored_chunks = list(zip(re_normed_scores, chunks))
rescored_chunks.sort(key=lambda x: x[0], reverse=True)
sorted_boosted_scores, boost_sorted_chunks = zip(*rescored_chunks)
final_chunks = list(boost_sorted_chunks)
final_scores = list(sorted_boosted_scores)
for ind, chunk in enumerate(final_chunks):
chunk.score = final_scores[ind]
logger.debug(f"Boost sorted similary scores: {list(final_scores)}")
return final_chunks
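Since apply_boost is new here, a worked example with made-up numbers helps: scores are min-max scaled, multiplied by their boost, mapped back into the original score range, and finally re-normalized against SIM_SCORE_RANGE_LOW/HIGH for display.

# Hypothetical values, purely to walk through apply_boost's arithmetic
scores = [0.62, 0.58, 0.55]  # raw similarity scores
boosts = [1.0, 1.5, 1.0]     # boost multipliers

score_min, score_max = min(scores), max(scores)
score_range = score_max - score_min  # 0.07

# Min-max scale, then apply the boost
boosted = [((s - score_min) / score_range) * b for s, b in zip(scores, boosts)]

# Map back into the original score range
unnormed = [s * score_range + score_min for s in boosted]  # ~[0.62, 0.595, 0.55]

# Re-normalize against SIM_SCORE_RANGE_LOW/HIGH (defaults 0.0 and 1.0) for display
norm_min, norm_max = min(0.0, min(scores)), max(1.0, max(scores))
display = [(s - norm_min) / (norm_max - norm_min) for s in unnormed]
print(display)  # the boosted middle chunk rises from 0.58 to roughly 0.595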
@log_function_time()
def retrieve_ranked_documents(
query: str,
@@ -122,12 +172,20 @@ def retrieve_ranked_documents(
datastore: DocumentIndex,
num_hits: int = NUM_RETURNED_HITS,
num_rerank: int = NUM_RERANKED_RESULTS,
skip_rerank: bool = SKIP_RERANKING,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
"""Uses vector similarity to fetch the top num_hits document chunks with a distance cutoff.
Reranks the top num_rerank out of those (instead of all due to latency)"""
def _log_top_chunk_links(chunks: list[InferenceChunk]) -> None:
doc_links = [c.source_links[0] for c in chunks if c.source_links is not None]
files_log_msg = f"Top links from semantic search: {', '.join(doc_links)}"
logger.info(files_log_msg)
top_chunks = datastore.semantic_retrieval(query, user_id, filters, num_hits)
if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
@@ -151,20 +209,23 @@ def retrieve_ranked_documents(
RetrievalMetricsContainer(keyword_search=True, metrics=chunk_metrics)
)
ranked_chunks = semantic_reranking(
query, top_chunks[:num_rerank], rerank_metrics_callback=rerank_metrics_callback
if skip_rerank:
# Need the range of values to not be too spread out for applying boost
boosted_chunks = apply_boost(top_chunks[:num_rerank])
_log_top_chunk_links(boosted_chunks)
return boosted_chunks, top_chunks[num_rerank:]
ranked_chunks = (
semantic_reranking(
query,
top_chunks[:num_rerank],
rerank_metrics_callback=rerank_metrics_callback,
)
if not skip_rerank
else []
)
top_docs = [
ranked_chunk.source_links[0]
for ranked_chunk in ranked_chunks
if ranked_chunk.source_links is not None
]
files_log_msg = (
f"Top links from semantic search: {', '.join(list(dict.fromkeys(top_docs)))}"
)
logger.info(files_log_msg)
_log_top_chunk_links(ranked_chunks)
return ranked_chunks, top_chunks[num_rerank:]
@@ -175,6 +236,7 @@ def encode_chunks(
embedding_model: SentenceTransformer | None = None,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
passage_prefix: str = ASYM_PASSAGE_PREFIX,
) -> list[IndexChunk]:
embedded_chunks: list[IndexChunk] = []
if embedding_model is None:
@@ -183,14 +245,15 @@
chunk_texts = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
chunk_texts.append(chunk.content)
chunk_texts.append(passage_prefix + chunk.content)
mini_chunk_texts = (
split_chunk_text_into_mini_chunks(chunk.content)
if enable_mini_chunk
else []
)
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
prefixed_mini_chunk_texts = [passage_prefix + text for text in mini_chunk_texts]
chunk_texts.extend(prefixed_mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(prefixed_mini_chunk_texts)
text_batches = [
chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
@@ -228,7 +291,7 @@ def encode_chunks(
def embed_query(
query: str,
embedding_model: SentenceTransformer | None = None,
prefix: str = ASYMMETRIC_PREFIX,
prefix: str = ASYM_QUERY_PREFIX,
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[float]:
model = embedding_model or get_default_embedding_model()
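The prefixing is deliberately asymmetric: encode_chunks prepends ASYM_PASSAGE_PREFIX to every chunk and mini-chunk at indexing time, while embed_query prepends ASYM_QUERY_PREFIX at query time. A minimal sketch of that pattern, assuming an e5-style SentenceTransformer model; the model name and prefix strings below are illustrative, not the defaults shipped in this commit (both prefixes default to ""):

from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/e5-small-v2")  # example asymmetric model
ASYM_PASSAGE_PREFIX = "passage: "  # example e5-style prefix
ASYM_QUERY_PREFIX = "query: "      # example e5-style prefix

# Index time: chunks (and mini-chunks) get the passage prefix
chunks = ["Reset your password under Settings.", "Billing runs monthly."]
doc_embs = model.encode(
    [ASYM_PASSAGE_PREFIX + c for c in chunks], normalize_embeddings=True
)

# Query time: the (usually shorter) query gets the query prefix instead
query_emb = model.encode(
    ASYM_QUERY_PREFIX + "how do I reset my password", normalize_embeddings=True
)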

View File

@@ -24,7 +24,6 @@ services:
- NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL=${NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL:-}
- POSTGRES_HOST=relational_db
- VESPA_HOST=index
- LOG_LEVEL=${LOG_LEVEL:-info}
- DISABLE_AUTH=${DISABLE_AUTH:-True}
- QA_TIMEOUT=${QA_TIMEOUT:-}
- VALID_EMAIL_DOMAINS=${VALID_EMAIL_DOMAINS:-}
@@ -37,6 +36,16 @@ services:
- API_TYPE_OPENAI=${API_TYPE_OPENAI:-}
- API_VERSION_OPENAI=${API_VERSION_OPENAI:-}
- AZURE_DEPLOYMENT_ID=${AZURE_DEPLOYMENT_ID:-}
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- SIM_SCORE_RANGE_LOW=${SIM_SCORE_RANGE_LOW:-}
- SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- SKIP_RERANKING=${SKIP_RERANKING:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
@@ -71,6 +80,15 @@ services:
- DANSWER_BOT_SLACK_BOT_TOKEN=${DANSWER_BOT_SLACK_BOT_TOKEN:-}
- DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER=${DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER:-}
- DANSWER_BOT_DISPLAY_ERROR_MSGS=${DANSWER_BOT_DISPLAY_ERROR_MSGS:-}
# Don't change the NLP model configs unless you know what you're doing
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- SIM_SCORE_RANGE_LOW=${SIM_SCORE_RANGE_LOW:-}
- SIM_SCORE_RANGE_HIGH=${SIM_SCORE_RANGE_HIGH:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- SKIP_RERANKING=${SKIP_RERANKING:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- local_dynamic_storage:/home/storage