mirror of https://github.com/danswer-ai/danswer.git
synced 2025-04-10 21:09:51 +02:00

Fix empty / reverted embeddings (#1910)

parent 6ff8e6c0ea
commit 48a0d29a5c
@@ -331,12 +331,18 @@ def _index_vespa_chunk(
     document = chunk.source_document
     # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
     vespa_chunk_id = str(get_uuid_from_chunk(chunk))
 
     embeddings = chunk.embeddings
+
+    if chunk.embeddings.full_embedding is None:
+        embeddings.full_embedding = chunk.title_embedding
     embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
 
     if embeddings.mini_chunk_embeddings:
         for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
-            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
+            if m_c_embed is None:
+                embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
+            else:
+                embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
 
     title = document.get_title_for_document_index()
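For context, this hunk folds a title-embedding fallback into the Vespa vector map: any chunk or mini-chunk that lost its vector is indexed under the document's title embedding instead. A minimal runnable sketch of the same pattern, using stand-in dataclasses in place of Danswer's chunk models (only the fields the hunk shows):

from dataclasses import dataclass, field

Embedding = list[float] | None  # matches the widened alias introduced below


@dataclass
class ChunkEmbedding:
    full_embedding: Embedding = None
    mini_chunk_embeddings: list[Embedding] = field(default_factory=list)


def build_vector_map(
    embeddings: ChunkEmbedding, title_embedding: Embedding
) -> dict[str, Embedding]:
    # Any missing vector falls back to the (cached) title embedding
    if embeddings.full_embedding is None:
        embeddings.full_embedding = title_embedding
    vector_map: dict[str, Embedding] = {"full_chunk": embeddings.full_embedding}
    for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
        vector_map[f"mini_chunk_{ind}"] = (
            title_embedding if m_c_embed is None else m_c_embed
        )
    return vector_map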
@@ -73,7 +73,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
     ) -> list[IndexChunk]:
         # Cache the Title embeddings to only have to do it once
-        title_embed_dict: dict[str, list[float]] = {}
+        title_embed_dict: dict[str, list[float] | None] = {}
         embedded_chunks: list[IndexChunk] = []
 
         # Create Mini Chunks for more precise matching of details
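Widening the cache's value type lets a title that produced no vector be cached as None rather than retried for every chunk of the same document. A hypothetical helper sketching that caching pattern (the embed_fn parameter is illustrative, not Danswer's API):

from collections.abc import Callable

title_embed_dict: dict[str, list[float] | None] = {}


def get_title_embedding(
    title: str, embed_fn: Callable[[str], list[float] | None]
) -> list[float] | None:
    # Embed each distinct title once; None results are cached as well,
    # so a failed title embedding is not recomputed per chunk.
    if title not in title_embed_dict:
        title_embed_dict[title] = embed_fn(title)
    return title_embed_dict[title]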
@@ -168,4 +168,6 @@ def get_embedding_model_from_db_embedding_model(
         normalize=db_embedding_model.normalize,
         query_prefix=db_embedding_model.query_prefix,
         passage_prefix=db_embedding_model.passage_prefix,
+        provider_type=db_embedding_model.provider_type,
+        api_key=db_embedding_model.api_key,
     )
@@ -13,7 +13,7 @@ if TYPE_CHECKING:
 logger = setup_logger()
 
 
-Embedding = list[float]
+Embedding = list[float] | None
 
 
 class ChunkEmbedding(BaseModel):
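Widening the alias means every consumer of Embedding must now tolerate None. A sketch of how the alias plays with the Pydantic model named in the hunk (fields beyond those shown elsewhere in this diff are omitted):

from pydantic import BaseModel

Embedding = list[float] | None


class ChunkEmbedding(BaseModel):
    full_embedding: Embedding
    mini_chunk_embeddings: list[Embedding]


# A chunk whose first mini-chunk failed to embed is now representable directly
ChunkEmbedding(full_embedding=[0.1, 0.2], mini_chunk_embeddings=[None, [0.3, 0.4]])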
@@ -72,7 +72,7 @@ class EmbeddingModel:
         texts: list[str],
         text_type: EmbedTextType,
         batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
-    ) -> list[list[float]]:
+    ) -> list[list[float] | None]:
         if not texts:
             logger.warning("No texts to be embedded")
             return []
@@ -112,11 +112,13 @@ class EmbeddingModel:
                 raise HTTPError(f"HTTP error occurred: {error_detail}") from e
             except requests.RequestException as e:
                 raise HTTPError(f"Request failed: {str(e)}") from e
-            EmbedResponse(**response.json()).embeddings
+
+            return EmbedResponse(**response.json()).embeddings
 
         # Batching for local embedding
         text_batches = batch_list(texts, batch_size)
-        embeddings: list[list[float]] = []
+        embeddings: list[list[float] | None] = []
         for idx, text_batch in enumerate(text_batches, start=1):
             embed_request = EmbedRequest(
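The first change in this hunk is more than cosmetic: in Python a bare expression statement evaluates its value and discards it, so the parsed remote response was previously thrown away and control fell through to the local-model path below, a plausible source of the "reverted embeddings" in the commit title. A minimal illustration of the difference:

def without_return() -> list[int]:
    [1, 2, 3]  # expression statement: computed, then discarded
    return []  # callers silently get the fallback value


def with_return() -> list[int]:
    return [1, 2, 3]


assert without_return() == []
assert with_return() == [1, 2, 3]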
@@ -143,7 +145,6 @@ class EmbeddingModel:
             # Normalize embeddings is only configured via model_configs.py, be sure to use right
             # value for the set loss
             embeddings.extend(EmbedResponse(**response.json()).embeddings)
-
         return embeddings
 
 
@@ -156,7 +157,7 @@ class CrossEncoderEnsembleModel:
         model_server_url = build_model_server_url(model_server_host, model_server_port)
         self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
 
-    def predict(self, query: str, passages: list[str]) -> list[list[float]]:
+    def predict(self, query: str, passages: list[str]) -> list[list[float] | None]:
         rerank_request = RerankRequest(query=query, documents=passages)
 
         response = requests.post(
@@ -1,5 +1,6 @@
 import string
 from collections.abc import Callable
+from typing import cast
 
 import nltk  # type:ignore
 from nltk.corpus import stopwords  # type:ignore
@@ -143,7 +144,9 @@ def doc_index_retrieval(
     if query.search_type == SearchType.SEMANTIC:
         top_chunks = document_index.semantic_retrieval(
             query=query.query,
-            query_embedding=query_embedding,
+            query_embedding=cast(
+                list[float], query_embedding
+            ),  # query embeddings should always have vector representations
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,
@@ -152,7 +155,9 @@ def doc_index_retrieval(
     elif query.search_type == SearchType.HYBRID:
         top_chunks = document_index.hybrid_retrieval(
             query=query.query,
-            query_embedding=query_embedding,
+            query_embedding=cast(
+                list[float], query_embedding
+            ),  # query embeddings should always have vector representations
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,
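Both call sites use typing.cast, which is a no-op at runtime and exists purely to narrow the widened list[float] | None annotation for the type checker; the inline comment records the invariant that justifies it. A small sketch (the function name here is illustrative):

from typing import cast


def require_vector(query_embedding: list[float] | None) -> list[float]:
    # cast() performs no check and has no runtime cost; it only asserts to
    # the type checker that queries always carry a vector representation.
    return cast(list[float], query_embedding)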
@@ -96,6 +96,8 @@ def upsert_ingestion_doc(
         normalize=db_embedding_model.normalize,
         query_prefix=db_embedding_model.query_prefix,
         passage_prefix=db_embedding_model.passage_prefix,
+        api_key=db_embedding_model.api_key,
+        provider_type=db_embedding_model.provider_type,
     )
 
     indexing_pipeline = build_indexing_pipeline(
@@ -132,6 +134,8 @@ def upsert_ingestion_doc(
         normalize=sec_db_embedding_model.normalize,
         query_prefix=sec_db_embedding_model.query_prefix,
         passage_prefix=sec_db_embedding_model.passage_prefix,
+        api_key=sec_db_embedding_model.api_key,
+        provider_type=sec_db_embedding_model.provider_type,
     )
 
     sec_ind_pipeline = build_indexing_pipeline(
@@ -80,7 +80,9 @@ class CloudEmbedding:
             raise ValueError(f"Unsupported provider: {provider}")
         self.client = _initialize_client(api_key, self.provider, model)
 
-    def _embed_openai(self, texts: list[str], model: str | None) -> list[list[float]]:
+    def _embed_openai(
+        self, texts: list[str], model: str | None
+    ) -> list[list[float] | None]:
         if model is None:
             model = DEFAULT_OPENAI_MODEL
 
@@ -92,7 +94,7 @@ class CloudEmbedding:
 
     def _embed_cohere(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float]]:
+    ) -> list[list[float] | None]:
         if model is None:
             model = DEFAULT_COHERE_MODEL
 
@@ -108,7 +110,7 @@ class CloudEmbedding:
 
     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float]]:
+    ) -> list[list[float] | None]:
         if model is None:
             model = DEFAULT_VOYAGE_MODEL
 
@@ -124,7 +126,7 @@ class CloudEmbedding:
 
     def _embed_vertex(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float]]:
+    ) -> list[list[float] | None]:
         if model is None:
             model = DEFAULT_VERTEX_MODEL
 
@@ -147,7 +149,7 @@ class CloudEmbedding:
         texts: list[str],
         text_type: EmbedTextType,
         model_name: str | None = None,
-    ) -> list[list[float]]:
+    ) -> list[list[float] | None]:
         try:
             if self.provider == EmbeddingProvider.OPENAI:
                 return self._embed_openai(texts, model_name)
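All of the provider-specific _embed_* methods and the embed dispatcher now share the widened return type, so a row that could not be embedded can surface as None from any branch. A stub sketch of that dispatch shape under the same contract (the enum values and placeholder vectors are illustrative, not Danswer's):

from enum import Enum


class EmbeddingProvider(str, Enum):
    OPENAI = "openai"
    COHERE = "cohere"


def embed(provider: EmbeddingProvider, texts: list[str]) -> list[list[float] | None]:
    # Every provider branch returns the same widened type, keeping the
    # dispatcher's annotation honest about possible None rows.
    if provider == EmbeddingProvider.OPENAI:
        rows: list[list[float] | None] = [[0.0] * 3 for _ in texts]  # placeholder
        return rows
    raise ValueError(f"Unsupported provider: {provider}")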
@@ -235,9 +237,20 @@ def embed_text(
     api_key: str | None,
     provider_type: str | None,
     prefix: str | None,
-) -> list[list[float]]:
+) -> list[list[float] | None]:
+    non_empty_texts = []
+    empty_indices = []
+
+    for idx, text in enumerate(texts):
+        if text.strip():
+            non_empty_texts.append(text)
+        else:
+            empty_indices.append(idx)
+
     # Third party API based embedding model
-    if provider_type is not None:
+    if not non_empty_texts:
+        embeddings = []
+    elif provider_type is not None:
         logger.debug(f"Embedding text with provider: {provider_type}")
         if api_key is None:
             raise RuntimeError("API key not provided for cloud model")
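The new preamble splits the inputs once up front: whitespace-only strings are remembered by index and excluded from the embedding call entirely. A self-contained sketch of that split (the helper name is hypothetical):

def split_empty_texts(texts: list[str]) -> tuple[list[str], list[int]]:
    # Returns the embeddable texts plus the original indices of empty ones,
    # so None placeholders can be re-inserted after embedding.
    non_empty_texts: list[str] = []
    empty_indices: list[int] = []
    for idx, text in enumerate(texts):
        if text.strip():
            non_empty_texts.append(text)
        else:
            empty_indices.append(idx)
    return non_empty_texts, empty_indices


assert split_empty_texts(["a", "", " ", "b"]) == (["a", "b"], [1, 2])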
@@ -254,14 +267,17 @@ def embed_text(
             api_key=api_key, provider=provider_type, model=model_name
         )
         embeddings = cloud_model.embed(
-            texts=texts,
+            texts=non_empty_texts,
             model_name=model_name,
             text_type=text_type,
         )
-
     # Locally running model
     elif model_name is not None:
-        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
+        prefixed_texts = (
+            [f"{prefix}{text}" for text in non_empty_texts]
+            if prefix
+            else non_empty_texts
+        )
         local_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
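Both branches now embed only non_empty_texts, and the local path applies the optional model prefix to the filtered list rather than the raw input. A quick check of the prefixing expression with an illustrative E5-style prefix (the prefix value is an example, not a Danswer default):

prefix = "passage: "  # illustrative value
non_empty_texts = ["hello", "world"]

prefixed_texts = (
    [f"{prefix}{text}" for text in non_empty_texts] if prefix else non_empty_texts
)
assert prefixed_texts == ["passage: hello", "passage: world"]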
@@ -277,14 +293,26 @@ def embed_text(
     if embeddings is None:
         raise RuntimeError("Failed to create Embeddings")
 
-    if not isinstance(embeddings, list):
-        embeddings = embeddings.tolist()
-
+    embeddings_with_nulls: list[list[float] | None] = []
+    current_embedding_index = 0
+
+    for idx in range(len(texts)):
+        if idx in empty_indices:
+            embeddings_with_nulls.append(None)
+        else:
+            embedding = embeddings[current_embedding_index]
+            if isinstance(embedding, list) or embedding is None:
+                embeddings_with_nulls.append(embedding)
+            else:
+                embeddings_with_nulls.append(embedding.tolist())
+            current_embedding_index += 1
+
+    embeddings = embeddings_with_nulls
     return embeddings
 
 
 @simple_log_function_time()
-def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
+def calc_sim_scores(query: str, docs: list[str]) -> list[list[float] | None]:
     cross_encoders = get_local_reranking_model_ensemble()
     sim_scores = [
         encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
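The merge loop walks the original positions, emitting None for indices that were filtered out and consuming computed vectors in order for the rest, with .tolist() normalizing any numpy rows. A compact sketch of the same merge (the function name is hypothetical):

def reinsert_nulls(
    computed: list[list[float]], empty_indices: list[int], total: int
) -> list[list[float] | None]:
    # Empty slots become None; other slots consume the next computed row.
    merged: list[list[float] | None] = []
    it = iter(computed)
    for idx in range(total):
        merged.append(None if idx in empty_indices else next(it))
    return merged


assert reinsert_nulls([[1.0]], [0, 2], 3) == [None, [1.0], None]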
@@ -17,7 +17,7 @@ class EmbedRequest(BaseModel):
 
 
 class EmbedResponse(BaseModel):
-    embeddings: list[list[float]]
+    embeddings: list[list[float] | None]
 
 
 class RerankRequest(BaseModel):
@@ -26,7 +26,7 @@ class RerankRequest(BaseModel):
 
 
 class RerankResponse(BaseModel):
-    scores: list[list[float]]
+    scores: list[list[float] | None]
 
 
 class IntentRequest(BaseModel):
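With the response schemas widened, None rows survive the model-server JSON round trip instead of failing validation. A minimal check against the schema shown above:

from pydantic import BaseModel


class EmbedResponse(BaseModel):
    embeddings: list[list[float] | None]


resp = EmbedResponse(embeddings=[[0.1, 0.2], None])
assert resp.embeddings[1] is None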