No Null Embeddings (#1982)

Yuhong Sun
2024-07-30 19:54:49 -07:00
committed by GitHub
parent 60a87d9472
commit 036d5c737e
18 changed files with 132 additions and 146 deletions

View File

@@ -12,8 +12,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "08a1eda20fe1"
 down_revision = "8a87bd6ec550"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None

 def upgrade() -> None:

View File

@@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
 # revision identifiers, used by Alembic.
 revision = "473a1a7ca408"
 down_revision = "325975216eb3"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None

 default_models_by_provider = {
     "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],

View File

@@ -12,8 +12,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "4ea2c93919c1"
 down_revision = "473a1a7ca408"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None

 def upgrade() -> None:

View File

@@ -11,8 +11,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "8a87bd6ec550"
 down_revision = "4ea2c93919c1"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None

 def upgrade() -> None:

View File

@@ -7,6 +7,7 @@ from danswer.access.models import DocumentAccess
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.search.models import IndexFilters
 from danswer.search.models import InferenceChunkUncleaned
+from shared_configs.model_server_models import Embedding

 @dataclass(frozen=True)
@@ -257,7 +258,7 @@ class VectorCapable(abc.ABC):
     def semantic_retrieval(
         self,
         query: str,  # Needed for matching purposes
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,
@@ -292,7 +293,7 @@ class HybridCapable(abc.ABC):
     def hybrid_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,

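The query vector argument in both interfaces is now the non-nullable `Embedding` alias (`list[float]`). A minimal sketch of what a conforming implementation looks like after this change; the class name and body are purely illustrative:

from shared_configs.model_server_models import Embedding

class ToyVectorIndex(VectorCapable):
    def semantic_retrieval(
        self,
        query: str,  # Needed for matching purposes
        query_embedding: Embedding,
        filters: IndexFilters,
        time_decay_multiplier: float,
        num_to_retrieve: int,
    ) -> list[InferenceChunkUncleaned]:
        # query_embedding is always a concrete list[float] after this commit,
        # so implementations no longer need an `is None` branch
        raise NotImplementedError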
View File

@@ -69,6 +69,7 @@ from danswer.search.retrieval.search_runner import query_processing
 from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
+from shared_configs.model_server_models import Embedding

 logger = setup_logger()
@@ -329,19 +330,15 @@ def _index_vespa_chunk(
         "Content-Type": "application/json",
     }
     document = chunk.source_document
     # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
     vespa_chunk_id = str(get_uuid_from_chunk(chunk))

     embeddings = chunk.embeddings
-    if chunk.embeddings.full_embedding is None:
-        embeddings.full_embedding = chunk.title_embedding
     embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}

     if embeddings.mini_chunk_embeddings:
         for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
-            if m_c_embed is None:
-                embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
-            else:
-                embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
+            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed

     title = document.get_title_for_document_index()
@@ -1035,7 +1032,7 @@ class VespaIndex(DocumentIndex):
     def semantic_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int = NUM_RETURNED_HITS,
@@ -1077,7 +1074,7 @@ class VespaIndex(DocumentIndex):
     def hybrid_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,

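The deleted branches above were the only places that substituted title embeddings for null vectors; with embeddings now guaranteed non-null, the mapping reduces to the following (condensed from the lines above):

embeddings = chunk.embeddings
# full_embedding can no longer be None, so the title-embedding fallback is gone
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
if embeddings.mini_chunk_embeddings:
    for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
        # mini-chunk vectors are likewise always real vectors now
        embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed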
View File

@@ -1,5 +1,6 @@
 import abc
 from collections.abc import Callable
+from typing import Optional
 from typing import TYPE_CHECKING

 from danswer.configs.app_configs import BLURB_SIZE
@@ -50,9 +51,10 @@ def chunk_large_section(
     start_chunk_id: int,
     blurb: str,
     chunk_splitter: "SentenceSplitter",
-    title_prefix: str = "",
-    metadata_suffix_semantic: str = "",
-    metadata_suffix_keyword: str = "",
+    mini_chunk_splitter: Optional["SentenceSplitter"],
+    title_prefix: str,
+    metadata_suffix_semantic: str,
+    metadata_suffix_keyword: str,
 ) -> list[DocAwareChunk]:
     split_texts = chunk_splitter.split_text(section_text)
@@ -61,14 +63,17 @@ def chunk_large_section(
             source_document=document,
             chunk_id=start_chunk_id + chunk_ind,
             blurb=blurb,
-            content=chunk_str,
+            content=chunk_text,
             source_links={0: section_link_text},
             section_continuation=(chunk_ind != 0),
             title_prefix=title_prefix,
             metadata_suffix_semantic=metadata_suffix_semantic,
             metadata_suffix_keyword=metadata_suffix_keyword,
+            mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+            if mini_chunk_splitter and chunk_text.strip()
+            else None,
         )
-        for chunk_ind, chunk_str in enumerate(split_texts)
+        for chunk_ind, chunk_text in enumerate(split_texts)
     ]
     return chunks
@@ -114,49 +119,6 @@ def _get_metadata_suffix_for_document_index(
     return metadata_semantic, metadata_keyword

-def _split_chunk_text_into_mini_chunks(
-    chunk_text: str, embedder: IndexingEmbedder, mini_chunk_size: int = MINI_CHUNK_SIZE
-) -> list[str]:
-    """The minichunks won't all have the title prefix or metadata suffix
-    It could be a significant percentage of every minichunk so better to not include it
-    """
-    from llama_index.text_splitter import SentenceSplitter
-
-    token_count_func = get_tokenizer(
-        model_name=embedder.model_name,
-        provider_type=embedder.provider_type,
-    ).tokenize
-    sentence_aware_splitter = SentenceSplitter(
-        tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
-    )
-    return sentence_aware_splitter.split_text(chunk_text)
-
-def _extract_chunk_texts_from_doc_aware_chunk(
-    chunks: list[DocAwareChunk],
-    embedder: IndexingEmbedder,
-    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
-) -> list[DocAwareChunk]:
-    # Create Mini Chunks for more precise matching of details
-    # Off by default with unedited settings
-    chunks_with_texts: list[DocAwareChunk] = []
-    for chunk in chunks:
-        # The whole chunk including the prefix/suffix is included in the overall vector representation
-        mini_chunk_texts = (
-            _split_chunk_text_into_mini_chunks(
-                chunk.content,
-                embedder=embedder,
-            )
-            if enable_mini_chunk
-            else []
-        )
-        chunk.mini_chunk_texts = mini_chunk_texts
-        chunks_with_texts.append(chunk)
-    return chunks_with_texts

 def chunk_document(
     document: Document,
     embedder: IndexingEmbedder,
@@ -164,6 +126,8 @@ def chunk_document(
     subsection_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,  # Used for both title and content
     include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
+    mini_chunk_size: int = MINI_CHUNK_SIZE,
+    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
 ) -> list[DocAwareChunk]:
     from llama_index.text_splitter import SentenceSplitter
@@ -182,6 +146,12 @@ def chunk_document(
         chunk_overlap=subsection_overlap,
     )
+    mini_chunk_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize,
+        chunk_size=mini_chunk_size,
+        chunk_overlap=0,
+    )

     title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
     title_prefix = title + RETURN_SEPARATOR if title else ""
     title_tokens = len(tokenizer.tokenize(title_prefix))
@@ -238,6 +208,9 @@ def chunk_document(
                     title_prefix=title_prefix,
                     metadata_suffix_semantic=metadata_suffix_semantic,
                     metadata_suffix_keyword=metadata_suffix_keyword,
+                    mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                    if enable_mini_chunk and chunk_text.strip()
+                    else None,
                 )
             )
             link_offsets = {}
@@ -249,6 +222,9 @@ def chunk_document(
                     document=document,
                     start_chunk_id=len(chunks),
                     chunk_splitter=chunk_splitter,
+                    mini_chunk_splitter=mini_chunk_splitter
+                    if enable_mini_chunk and chunk_text.strip()
+                    else None,
                     blurb=extract_blurb(section_text, blurb_splitter),
                     title_prefix=title_prefix,
                     metadata_suffix_semantic=metadata_suffix_semantic,
@@ -280,14 +256,17 @@ def chunk_document(
                     title_prefix=title_prefix,
                     metadata_suffix_semantic=metadata_suffix_semantic,
                     metadata_suffix_keyword=metadata_suffix_keyword,
+                    mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                    if enable_mini_chunk and chunk_text.strip()
+                    else None,
                 )
             )
             link_offsets = {0: section_link_text}
             chunk_text = section_text

-    # Once we hit the end, if we're still in the process of building a chunk, add what we have
-    # NOTE: if it's just whitespace, ignore it.
-    if chunk_text.strip():
+    # Once we hit the end, if we're still in the process of building a chunk, add what we have. If there is only whitespace left
+    # then don't include it. If there are no chunks at all from the doc, we can just create a single chunk with the title.
+    if chunk_text.strip() or not chunks:
         chunks.append(
             DocAwareChunk(
                 source_document=document,
@@ -299,13 +278,14 @@ def chunk_document(
                 title_prefix=title_prefix,
                 metadata_suffix_semantic=metadata_suffix_semantic,
                 metadata_suffix_keyword=metadata_suffix_keyword,
+                mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                if enable_mini_chunk and chunk_text.strip()
+                else None,
             )
         )

-    chunks_with_texts: list[DocAwareChunk] = _extract_chunk_texts_from_doc_aware_chunk(
-        chunks=chunks, embedder=embedder
-    )
-    return chunks_with_texts
+    # If the chunk does not have any useable content, it will not be indexed
+    return chunks

 class Chunker:

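The chunker now builds mini-chunk texts eagerly at chunking time, replacing the deleted `_split_chunk_text_into_mini_chunks` post-pass. A standalone sketch of the same idea, assuming the legacy llama_index SentenceSplitter used in the diff and a stand-in whitespace tokenizer (the real code passes the embedding model's tokenizer):

from llama_index.text_splitter import SentenceSplitter

def make_mini_chunk_texts(chunk_text: str, mini_chunk_size: int = 150) -> list[str] | None:
    # Mini-chunks intentionally exclude the title prefix and metadata suffix,
    # which could otherwise dominate such short texts
    mini_chunk_splitter = SentenceSplitter(
        tokenizer=str.split,  # illustrative only; not the production tokenizer
        chunk_size=mini_chunk_size,
        chunk_overlap=0,
    )
    return mini_chunk_splitter.split_text(chunk_text) if chunk_text.strip() else None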
View File

@@ -16,6 +16,7 @@ from danswer.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
 from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding

 logger = setup_logger()
@@ -78,12 +79,21 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         self,
         chunks: list[DocAwareChunk],
     ) -> list[IndexChunk]:
+        # All chunks at this point must have some non-empty content
         flat_chunk_texts: list[str] = []
         for chunk in chunks:
             chunk_text = (
                 f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
-            )
+            ) or chunk.source_document.get_title_for_document_index()
+
+            if not chunk_text:
+                # This should never happen, the document would have been dropped
+                # before getting to this point
+                raise ValueError(f"Chunk has no content: {chunk.to_short_descriptor()}")
+
             flat_chunk_texts.append(chunk_text)

             if chunk.mini_chunk_texts:
                 flat_chunk_texts.extend(chunk.mini_chunk_texts)

         embeddings = self.embedding_model.encode(
@@ -95,10 +105,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         }

         # Drop any None or empty strings
+        # If there is no title or the title is empty, the title embedding field will be null
+        # which is ok, it just won't contribute at all to the scoring.
         chunk_titles_list = [title for title in chunk_titles if title]

         # Cache the Title embeddings to only have to do it once
-        title_embed_dict: dict[str, list[float] | None] = {}
+        title_embed_dict: dict[str, Embedding] = {}
         if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
                 chunk_titles_list, text_type=EmbedTextType.PASSAGE
@@ -114,7 +126,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         embedded_chunks: list[IndexChunk] = []
         embedding_ind_start = 0
         for chunk in chunks:
-            num_embeddings = 1 + len(chunk.mini_chunk_texts)
+            num_embeddings = 1 + (
+                len(chunk.mini_chunk_texts) if chunk.mini_chunk_texts else 0
+            )
             chunk_embeddings = embeddings[
                 embedding_ind_start : embedding_ind_start + num_embeddings
             ]

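The index arithmetic above assumes each chunk contributes exactly 1 + len(mini_chunk_texts or []) vectors to the flat embeddings list. A small self-contained illustration of that regrouping, with made-up names:

def regroup_embeddings(
    embeddings: list[list[float]], counts: list[int]
) -> list[list[list[float]]]:
    # counts[i] = 1 (the full chunk) + number of mini-chunks for chunk i
    grouped = []
    start = 0
    for count in counts:
        grouped.append(embeddings[start : start + count])
        start += count
    return grouped

# e.g. chunk 0 has 2 mini-chunks, chunk 1 has none:
# regroup_embeddings(vectors, [3, 1]) -> [[v0, v1, v2], [v3]]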
View File

@@ -124,10 +124,15 @@ def index_doc_batch(
     Note that the documents should already be batched at this point so that it does not inflate the
     memory requirements"""

     # Skip documents that have neither title nor content
+    # If the document doesn't have either, then there is no useful information in it
+    # This is again verified later in the pipeline after chunking but at that point there should
+    # already be no documents that are empty.
     documents_to_process = []
     for document in documents:
-        if not document.title and not any(
-            section.text.strip() for section in document.sections
+        if (
+            not document.title
+            or not document.title.strip()
+            and not any(section.text.strip() for section in document.sections)
         ):
             logger.warning(
                 f"Skipping document with ID {document.id} as it has neither title nor content"

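One subtlety in the new condition: Python's `and` binds tighter than `or`, so as written it reads `not title or (not title.strip() and no non-blank sections)`, which would also skip untitled documents that do have content. A sketch with the grouping made explicit, assuming the intent matches the warning message (skip only when both title and content are missing):

def should_skip(document: "Document") -> bool:
    # True only when the title is missing/blank AND every section is blank
    has_title = bool(document.title and document.title.strip())
    has_content = any(section.text.strip() for section in document.sections)
    return not has_title and not has_content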
View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
 from danswer.access.models import DocumentAccess
 from danswer.connectors.models import Document
 from danswer.utils.logger import setup_logger
+from shared_configs.model_server_models import Embedding

 if TYPE_CHECKING:
     from danswer.db.models import EmbeddingModel
@@ -13,9 +14,6 @@ if TYPE_CHECKING:
 logger = setup_logger()

-Embedding = list[float] | None

 class ChunkEmbedding(BaseModel):
     full_embedding: Embedding
     mini_chunk_embeddings: list[Embedding]
@@ -36,6 +34,8 @@ class DocAwareChunk(BaseChunk):
     # During inference we only have access to the document id and do not reconstruct the Document
     source_document: Document

+    # This could be an empty string if the title is too long and taking up too much of the chunk
+    # This does not mean necessarily that the document does not have a title
     title_prefix: str

     # During indexing we also (optionally) build a metadata string from the metadata dict
@@ -44,8 +44,7 @@ class DocAwareChunk(BaseChunk):
     metadata_suffix_semantic: str
     metadata_suffix_keyword: str

-    # give these default values so they can be set after the rest of the chunk is created
-    mini_chunk_texts: list[str] = []
+    mini_chunk_texts: list[str] | None

     def to_short_descriptor(self) -> str:
         """Used when logging the identity of a chunk"""

View File

@@ -13,6 +13,7 @@ from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import IntentRequest
@@ -73,10 +74,9 @@ class EmbeddingModel:
         texts: list[str],
         text_type: EmbedTextType,
         batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
-    ) -> list[list[float] | None]:
-        if not texts:
-            logger.warning("No texts to be embedded")
-            return []
+    ) -> list[Embedding]:
+        if not texts or not all(texts):
+            raise ValueError(f"Empty or missing text for embedding: {texts}")

         if self.retrim_content:
             # This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
@@ -116,13 +116,12 @@ class EmbeddingModel:
                 raise HTTPError(f"HTTP error occurred: {error_detail}") from e
             except requests.RequestException as e:
                 raise HTTPError(f"Request failed: {str(e)}") from e

-            EmbedResponse(**response.json()).embeddings
+            return EmbedResponse(**response.json()).embeddings

         # Batching for local embedding
         text_batches = batch_list(texts, batch_size)

-        embeddings: list[list[float] | None] = []
+        embeddings: list[Embedding] = []
         logger.debug(
             f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
         )

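With this change, encode either returns one vector per input or raises; it no longer returns [] for empty input (and the missing `return` on the remote path is fixed). A hedged usage sketch, with construction details omitted:

from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding

def embed_passages(model: "EmbeddingModel", texts: list[str]) -> list[Embedding]:
    vectors = model.encode(texts, text_type=EmbedTextType.PASSAGE)
    # One Embedding (list[float]) per input, never None; empty or blank
    # inputs raise ValueError instead of silently yielding []
    assert len(vectors) == len(texts)
    return vectors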
View File

@@ -126,6 +126,7 @@ class InferenceChunk(BaseChunk):
     document_id: str
     source_type: DocumentSource
     semantic_identifier: str
+    title: str | None  # Separate from Semantic Identifier though often same
     boost: int
     recency_bias: float
     score: float | None
@@ -193,16 +194,16 @@ class InferenceChunk(BaseChunk):

 class InferenceChunkUncleaned(InferenceChunk):
-    title: str | None  # Separate from Semantic Identifier though often same
     metadata_suffix: str | None

     def to_inference_chunk(self) -> InferenceChunk:
-        # Create a dict of all fields except 'title' and 'metadata_suffix'
+        # Create a dict of all fields except 'metadata_suffix'
+        # Assumes the cleaning has already been applied and just needs to translate to the right type
         inference_chunk_data = {
             k: v
             for k, v in self.dict().items()
-            if k not in ["title", "metadata_suffix"]
+            if k
+            not in ["metadata_suffix"]  # May be other fields to throw out in the future
         }
         return InferenceChunk(**inference_chunk_data)

View File

@@ -99,7 +99,11 @@ def semantic_reranking(
     Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
     """
     cross_encoders = CrossEncoderEnsembleModel()
-    passages = [chunk.content for chunk in chunks]
+    passages = [
+        f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
+        for chunk in chunks
+    ]
     sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
     sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
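
Prepending the semantic identifier (falling back to the title) gives the cross-encoder document-level context that bare chunk content lacked. A condensed view of the new flow, using only names from the diff:

passages = [
    f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
    for chunk in chunks
]
# One relevance score per (query, passage) pair from each cross-encoder in the ensemble
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)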

View File

@@ -1,6 +1,5 @@
 import string
 from collections.abc import Callable
-from typing import cast

 import nltk  # type:ignore
 from nltk.corpus import stopwords  # type:ignore
@@ -144,9 +143,7 @@ def doc_index_retrieval(
     if query.search_type == SearchType.SEMANTIC:
         top_chunks = document_index.semantic_retrieval(
             query=query.query,
-            query_embedding=cast(
-                list[float], query_embedding
-            ),  # query embeddings should always have vector representations
+            query_embedding=query_embedding,
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,
@@ -155,9 +152,7 @@ def doc_index_retrieval(
     elif query.search_type == SearchType.HYBRID:
         top_chunks = document_index.hybrid_retrieval(
             query=query.query,
-            query_embedding=cast(
-                list[float], query_embedding
-            ),  # query embeddings should always have vector representations
+            query_embedding=query_embedding,
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,

View File

@@ -29,6 +29,7 @@ from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
 from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest
@@ -80,9 +81,7 @@ class CloudEmbedding:
             raise ValueError(f"Unsupported provider: {provider}")
         self.client = _initialize_client(api_key, self.provider, model)

-    def _embed_openai(
-        self, texts: list[str], model: str | None
-    ) -> list[list[float] | None]:
+    def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
         if model is None:
             model = DEFAULT_OPENAI_MODEL
@@ -104,7 +103,7 @@ class CloudEmbedding:
     def _embed_cohere(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_COHERE_MODEL
@@ -120,7 +119,7 @@ class CloudEmbedding:
     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_VOYAGE_MODEL
@@ -136,7 +135,7 @@ class CloudEmbedding:
     def _embed_vertex(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_VERTEX_MODEL
@@ -159,7 +158,7 @@ class CloudEmbedding:
         texts: list[str],
         text_type: EmbedTextType,
         model_name: str | None = None,
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         try:
             if self.provider == EmbeddingProvider.OPENAI:
                 return self._embed_openai(texts, model_name)
@@ -247,19 +246,13 @@ def embed_text(
     api_key: str | None,
     provider_type: str | None,
     prefix: str | None,
-) -> list[list[float] | None]:
-    non_empty_texts = []
-    empty_indices = []
-    for idx, text in enumerate(texts):
-        if text.strip():
-            non_empty_texts.append(text)
-        else:
-            empty_indices.append(idx)
+) -> list[Embedding]:
+    if not all(texts):
+        raise ValueError("Empty strings are not allowed for embedding.")

     # Third party API based embedding model
-    if not non_empty_texts:
-        embeddings = []
+    if not texts:
+        raise ValueError("No texts provided for embedding.")
     elif provider_type is not None:
         logger.debug(f"Embedding text with provider: {provider_type}")
         if api_key is None:
@@ -277,47 +270,36 @@ def embed_text(
             api_key=api_key, provider=provider_type, model=model_name
         )
         embeddings = cloud_model.embed(
-            texts=non_empty_texts,
+            texts=texts,
             model_name=model_name,
             text_type=text_type,
         )

+        # Check for None values in embeddings
+        if any(embedding is None for embedding in embeddings):
+            error_message = "Embeddings contain None values\n"
+            error_message += "Corresponding texts:\n"
+            error_message += "\n".join(texts)
+            raise ValueError(error_message)

     elif model_name is not None:
-        prefixed_texts = (
-            [f"{prefix}{text}" for text in non_empty_texts]
-            if prefix
-            else non_empty_texts
-        )
+        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts

         local_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
-        embeddings = local_model.encode(
+        embeddings_vectors = local_model.encode(
             prefixed_texts, normalize_embeddings=normalize_embeddings
         )
+        embeddings = [
+            embedding if isinstance(embedding, list) else embedding.tolist()
+            for embedding in embeddings_vectors
+        ]

     else:
         raise ValueError(
             "Either model name or provider must be provided to run embeddings."
         )

-    if embeddings is None:
-        raise RuntimeError("Failed to create Embeddings")
-
-    embeddings_with_nulls: list[list[float] | None] = []
-    current_embedding_index = 0
-    for idx in range(len(texts)):
-        if idx in empty_indices:
-            embeddings_with_nulls.append(None)
-        else:
-            embedding = embeddings[current_embedding_index]
-            if isinstance(embedding, list) or embedding is None:
-                embeddings_with_nulls.append(embedding)
-            else:
-                embeddings_with_nulls.append(embedding.tolist())
-            current_embedding_index += 1
-
-    embeddings = embeddings_with_nulls
     return embeddings
@@ -337,6 +319,8 @@ async def process_embed_request(
 ) -> EmbedResponse:
     if not embed_request.texts:
         raise HTTPException(status_code=400, detail="No texts to be embedded")
+    elif not all(embed_request.texts):
+        raise ValueError("Empty strings are not allowed for embedding.")

     try:
         if embed_request.text_type == EmbedTextType.QUERY:
@@ -371,8 +355,10 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
     if not embed_request.documents or not embed_request.query:
         raise HTTPException(
-            status_code=400, detail="No documents or query to be reranked"
+            status_code=400, detail="Missing documents or query for reranking"
         )
+    if not all(embed_request.documents):
+        raise ValueError("Empty documents cannot be reranked.")

     try:
         sim_scores = calc_sim_scores(

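Taken together, the server now fails fast instead of quietly substituting nulls: the endpoints reject empty inputs, and cloud results are checked for None before being returned. A condensed sketch of the input validation ordering in embed_text (the helper name is hypothetical):

def validate_embed_inputs(texts: list[str]) -> None:
    # Blank strings are rejected first; note that all([]) is True, so a
    # fully empty list falls through to the second check
    if not all(texts):
        raise ValueError("Empty strings are not allowed for embedding.")
    if not texts:
        raise ValueError("No texts provided for embedding.")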
View File

@@ -2,6 +2,8 @@ from pydantic import BaseModel
 from shared_configs.enums import EmbedTextType

+Embedding = list[float]

 class EmbedRequest(BaseModel):
     texts: list[str]
@@ -17,7 +19,7 @@ class EmbedRequest(BaseModel):

 class EmbedResponse(BaseModel):
-    embeddings: list[list[float] | None]
+    embeddings: list[Embedding]

 class RerankRequest(BaseModel):

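The new module-level alias is the single source of truth for the embedding type shared by the model server and the main application, and pydantic enforces it at the API boundary. A minimal sketch of the effect:

from pydantic import BaseModel

Embedding = list[float]  # a vector is always a concrete list of floats

class EmbedResponse(BaseModel):
    embeddings: list[Embedding]

EmbedResponse(embeddings=[[0.1, 0.2], [0.3]])  # validates
# EmbedResponse(embeddings=[[0.1, 0.2], None])  # now a ValidationError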
View File

@ -114,6 +114,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
},
blurb="anything",
semantic_identifier="anything",
title="whatever",
section_continuation=False,
recency_bias=1,
boost=0,
@ -131,6 +132,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"},
blurb="whatever",
semantic_identifier="whatever",
title="whatever",
section_continuation=False,
recency_bias=1,
boost=0,

View File

@@ -24,6 +24,7 @@ def create_inference_chunk(
         chunk_id=chunk_id,
         document_id=document_id,
         semantic_identifier=f"{document_id}_{chunk_id}",
+        title="whatever",
         blurb=f"{document_id}_{chunk_id}",
         content=content,
         source_links={0: "fake_link"},