No Null Embeddings (#1982)

Author: Yuhong Sun
Committed by: GitHub on 2024-07-30 19:54:49 -07:00
parent 60a87d9472
commit 036d5c737e
18 changed files with 132 additions and 146 deletions

View File

@@ -12,8 +12,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "08a1eda20fe1"
 down_revision = "8a87bd6ec550"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None
 def upgrade() -> None:

View File

@@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
 # revision identifiers, used by Alembic.
 revision = "473a1a7ca408"
 down_revision = "325975216eb3"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None
 default_models_by_provider = {
     "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],

View File

@@ -12,8 +12,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "4ea2c93919c1"
 down_revision = "473a1a7ca408"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None
 def upgrade() -> None:

View File

@@ -11,8 +11,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "8a87bd6ec550"
 down_revision = "4ea2c93919c1"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None
 def upgrade() -> None:
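Note on the four migration hunks above: annotating the module-level attributes as `: None` tells a strict type checker (e.g. mypy with `disallow_untyped_globals`) that these Alembic identifiers are always None, with no runtime effect. A minimal sketch of the resulting revision header; the revision ids and function bodies here are illustrative only:

revision = "0123456789ab"  # illustrative id, not a real migration
down_revision = "ba9876543210"  # illustrative parent revision
branch_labels: None = None  # annotated module-level global, still just None
depends_on: None = None


def upgrade() -> None:
    pass  # schema changes would go here


def downgrade() -> None:
    pass  # inverse of upgrade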

View File

@@ -7,6 +7,7 @@ from danswer.access.models import DocumentAccess
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.search.models import IndexFilters
 from danswer.search.models import InferenceChunkUncleaned
+from shared_configs.model_server_models import Embedding
 @dataclass(frozen=True)
@@ -257,7 +258,7 @@ class VectorCapable(abc.ABC):
     def semantic_retrieval(
         self,
         query: str,  # Needed for matching purposes
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,
@@ -292,7 +293,7 @@ class HybridCapable(abc.ABC):
     def hybrid_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,
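`Embedding` here is the plain alias this commit adds to shared_configs/model_server_models.py (`Embedding = list[float]`, shown near the end of this diff). Since it no longer admits None, interface consumers can use the vector directly. A small sketch of what that buys at a call site; the function name is hypothetical:

Embedding = list[float]  # alias as defined in shared_configs.model_server_models

def cosine_numerator(query_embedding: Embedding, doc_embedding: Embedding) -> float:
    # Hypothetical consumer: with the non-optional alias there is no need
    # for a None-check or a cast before iterating over the vector.
    return sum(q * d for q, d in zip(query_embedding, doc_embedding))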

View File

@@ -69,6 +69,7 @@ from danswer.search.retrieval.search_runner import query_processing
 from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
+from shared_configs.model_server_models import Embedding
 logger = setup_logger()
@@ -329,20 +330,16 @@ def _index_vespa_chunk(
         "Content-Type": "application/json",
     }
     document = chunk.source_document
     # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
     vespa_chunk_id = str(get_uuid_from_chunk(chunk))
     embeddings = chunk.embeddings
-    if chunk.embeddings.full_embedding is None:
-        embeddings.full_embedding = chunk.title_embedding
     embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
     if embeddings.mini_chunk_embeddings:
         for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
-            if m_c_embed is None:
-                embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
-            else:
-                embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
+            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
     title = document.get_title_for_document_index()
@@ -1035,7 +1032,7 @@ class VespaIndex(DocumentIndex):
     def semantic_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int = NUM_RETURNED_HITS,
@@ -1077,7 +1074,7 @@ class VespaIndex(DocumentIndex):
     def hybrid_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,
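With null embeddings gone, `_index_vespa_chunk` above drops the fallback that substituted the title embedding for missing vectors. A self-contained sketch of the surviving field-map construction, using hypothetical vector values:

# Hypothetical stand-ins for chunk.embeddings.full_embedding and
# chunk.embeddings.mini_chunk_embeddings; both are guaranteed non-null now.
full_embedding: list[float] = [0.1, 0.2, 0.3]
mini_chunk_embeddings: list[list[float]] = [[0.1, 0.0, 0.2], [0.0, 0.3, 0.1]]

embeddings_name_vector_map = {"full_chunk": full_embedding}
for ind, m_c_embed in enumerate(mini_chunk_embeddings):
    # No None branch anymore: every mini chunk carries a real vector
    embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed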

View File

@@ -1,5 +1,6 @@
 import abc
 from collections.abc import Callable
+from typing import Optional
 from typing import TYPE_CHECKING
 from danswer.configs.app_configs import BLURB_SIZE
@@ -50,9 +51,10 @@ def chunk_large_section(
     start_chunk_id: int,
     blurb: str,
     chunk_splitter: "SentenceSplitter",
-    title_prefix: str = "",
-    metadata_suffix_semantic: str = "",
-    metadata_suffix_keyword: str = "",
+    mini_chunk_splitter: Optional["SentenceSplitter"],
+    title_prefix: str,
+    metadata_suffix_semantic: str,
+    metadata_suffix_keyword: str,
 ) -> list[DocAwareChunk]:
     split_texts = chunk_splitter.split_text(section_text)
@@ -61,14 +63,17 @@ def chunk_large_section(
             source_document=document,
             chunk_id=start_chunk_id + chunk_ind,
             blurb=blurb,
-            content=chunk_str,
+            content=chunk_text,
             source_links={0: section_link_text},
             section_continuation=(chunk_ind != 0),
             title_prefix=title_prefix,
             metadata_suffix_semantic=metadata_suffix_semantic,
             metadata_suffix_keyword=metadata_suffix_keyword,
+            mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+            if mini_chunk_splitter and chunk_text.strip()
+            else None,
         )
-        for chunk_ind, chunk_str in enumerate(split_texts)
+        for chunk_ind, chunk_text in enumerate(split_texts)
     ]
     return chunks
@@ -114,49 +119,6 @@ def _get_metadata_suffix_for_document_index(
     return metadata_semantic, metadata_keyword
-def _split_chunk_text_into_mini_chunks(
-    chunk_text: str, embedder: IndexingEmbedder, mini_chunk_size: int = MINI_CHUNK_SIZE
-) -> list[str]:
-    """The minichunks won't all have the title prefix or metadata suffix
-    It could be a significant percentage of every minichunk so better to not include it
-    """
-    from llama_index.text_splitter import SentenceSplitter
-    token_count_func = get_tokenizer(
-        model_name=embedder.model_name,
-        provider_type=embedder.provider_type,
-    ).tokenize
-    sentence_aware_splitter = SentenceSplitter(
-        tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
-    )
-    return sentence_aware_splitter.split_text(chunk_text)
-def _extract_chunk_texts_from_doc_aware_chunk(
-    chunks: list[DocAwareChunk],
-    embedder: IndexingEmbedder,
-    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
-) -> list[DocAwareChunk]:
-    # Create Mini Chunks for more precise matching of details
-    # Off by default with unedited settings
-    chunks_with_texts: list[DocAwareChunk] = []
-    for chunk in chunks:
-        # The whole chunk including the prefix/suffix is included in the overall vector representation
-        mini_chunk_texts = (
-            _split_chunk_text_into_mini_chunks(
-                chunk.content,
-                embedder=embedder,
-            )
-            if enable_mini_chunk
-            else []
-        )
-        chunk.mini_chunk_texts = mini_chunk_texts
-        chunks_with_texts.append(chunk)
-    return chunks_with_texts
 def chunk_document(
     document: Document,
     embedder: IndexingEmbedder,
@@ -164,6 +126,8 @@ def chunk_document(
     subsection_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,  # Used for both title and content
     include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
+    mini_chunk_size: int = MINI_CHUNK_SIZE,
+    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
 ) -> list[DocAwareChunk]:
     from llama_index.text_splitter import SentenceSplitter
@@ -182,6 +146,12 @@ def chunk_document(
         chunk_overlap=subsection_overlap,
     )
+    mini_chunk_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize,
+        chunk_size=mini_chunk_size,
+        chunk_overlap=0,
+    )
     title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
     title_prefix = title + RETURN_SEPARATOR if title else ""
     title_tokens = len(tokenizer.tokenize(title_prefix))
@@ -238,6 +208,9 @@ def chunk_document(
                     title_prefix=title_prefix,
                     metadata_suffix_semantic=metadata_suffix_semantic,
                     metadata_suffix_keyword=metadata_suffix_keyword,
+                    mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                    if enable_mini_chunk and chunk_text.strip()
+                    else None,
                 )
             )
             link_offsets = {}
@@ -249,6 +222,9 @@ def chunk_document(
                 document=document,
                 start_chunk_id=len(chunks),
                 chunk_splitter=chunk_splitter,
+                mini_chunk_splitter=mini_chunk_splitter
+                if enable_mini_chunk and chunk_text.strip()
+                else None,
                 blurb=extract_blurb(section_text, blurb_splitter),
                 title_prefix=title_prefix,
                 metadata_suffix_semantic=metadata_suffix_semantic,
@@ -280,14 +256,17 @@ def chunk_document(
                 title_prefix=title_prefix,
                 metadata_suffix_semantic=metadata_suffix_semantic,
                 metadata_suffix_keyword=metadata_suffix_keyword,
+                mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                if enable_mini_chunk and chunk_text.strip()
+                else None,
             )
         )
         link_offsets = {0: section_link_text}
         chunk_text = section_text
-    # Once we hit the end, if we're still in the process of building a chunk, add what we have
-    # NOTE: if it's just whitespace, ignore it.
-    if chunk_text.strip():
+    # Once we hit the end, if we're still in the process of building a chunk, add what we have. If there is only whitespace left
+    # then don't include it. If there are no chunks at all from the doc, we can just create a single chunk with the title.
+    if chunk_text.strip() or not chunks:
         chunks.append(
             DocAwareChunk(
                 source_document=document,
@@ -299,13 +278,14 @@ def chunk_document(
                 title_prefix=title_prefix,
                 metadata_suffix_semantic=metadata_suffix_semantic,
                 metadata_suffix_keyword=metadata_suffix_keyword,
+                mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                if enable_mini_chunk and chunk_text.strip()
+                else None,
            )
        )
-    chunks_with_texts: list[DocAwareChunk] = _extract_chunk_texts_from_doc_aware_chunk(
-        chunks=chunks, embedder=embedder
-    )
-    return chunks_with_texts
+    # If the chunk does not have any useable content, it will not be indexed
+    return chunks
 class Chunker:
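The net effect of the chunker rewrite above: mini-chunk texts are produced during chunking via the new `mini_chunk_splitter` parameter rather than in a separate post-pass, and a chunk whose text is all whitespace gets `mini_chunk_texts=None` instead of an empty list. A minimal sketch of that inline decision, wrapped in a hypothetical helper for clarity:

from llama_index.text_splitter import SentenceSplitter  # same import as above

def build_mini_chunk_texts(
    chunk_text: str, mini_chunk_splitter: SentenceSplitter | None
) -> list[str] | None:
    # Mirrors the inline conditional used in the diff: no splitter (mini
    # chunks disabled) or a whitespace-only chunk yields None, not [].
    if mini_chunk_splitter and chunk_text.strip():
        return mini_chunk_splitter.split_text(chunk_text)
    return None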

View File

@@ -16,6 +16,7 @@ from danswer.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
 from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding
 logger = setup_logger()
@@ -78,13 +79,22 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         self,
         chunks: list[DocAwareChunk],
     ) -> list[IndexChunk]:
+        # All chunks at this point must have some non-empty content
         flat_chunk_texts: list[str] = []
         for chunk in chunks:
             chunk_text = (
                 f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
-            )
+            ) or chunk.source_document.get_title_for_document_index()
+            if not chunk_text:
+                # This should never happen, the document would have been dropped
+                # before getting to this point
+                raise ValueError(f"Chunk has no content: {chunk.to_short_descriptor()}")
             flat_chunk_texts.append(chunk_text)
-            flat_chunk_texts.extend(chunk.mini_chunk_texts)
+            if chunk.mini_chunk_texts:
+                flat_chunk_texts.extend(chunk.mini_chunk_texts)
         embeddings = self.embedding_model.encode(
             flat_chunk_texts, text_type=EmbedTextType.PASSAGE
@@ -95,10 +105,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         }
         # Drop any None or empty strings
+        # If there is no title or the title is empty, the title embedding field will be null
+        # which is ok, it just won't contribute at all to the scoring.
         chunk_titles_list = [title for title in chunk_titles if title]
         # Cache the Title embeddings to only have to do it once
-        title_embed_dict: dict[str, list[float] | None] = {}
+        title_embed_dict: dict[str, Embedding] = {}
         if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
                 chunk_titles_list, text_type=EmbedTextType.PASSAGE
@@ -114,7 +126,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         embedded_chunks: list[IndexChunk] = []
         embedding_ind_start = 0
         for chunk in chunks:
-            num_embeddings = 1 + len(chunk.mini_chunk_texts)
+            num_embeddings = 1 + (
+                len(chunk.mini_chunk_texts) if chunk.mini_chunk_texts else 0
+            )
             chunk_embeddings = embeddings[
                 embedding_ind_start : embedding_ind_start + num_embeddings
             ]
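Each chunk contributes one full-text embedding plus one per mini chunk to the flat batch, so the slicing above must treat `mini_chunk_texts=None` as zero extra embeddings. A self-contained sketch of the reassembly arithmetic with hypothetical data:

# Chunk A has two mini chunks, chunk B has none; four embeddings total.
mini_chunk_texts_per_chunk: list[list[str] | None] = [["m1", "m2"], None]
flat_embeddings: list[list[float]] = [[0.1], [0.2], [0.3], [0.4]]

embedding_ind_start = 0
for mini_texts in mini_chunk_texts_per_chunk:
    num_embeddings = 1 + (len(mini_texts) if mini_texts else 0)
    chunk_embeddings = flat_embeddings[
        embedding_ind_start : embedding_ind_start + num_embeddings
    ]
    embedding_ind_start += num_embeddings
    # chunk A gets 3 vectors (full + 2 mini), chunk B gets 1 (full only)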

View File

@@ -124,10 +124,15 @@ def index_doc_batch(
     Note that the documents should already be batched at this point so that it does not inflate the
     memory requirements"""
     # Skip documents that have neither title nor content
+    # If the document doesn't have either, then there is no useful information in it
+    # This is again verified later in the pipeline after chunking but at that point there should
+    # already be no documents that are empty.
     documents_to_process = []
     for document in documents:
-        if not document.title and not any(
-            section.text.strip() for section in document.sections
+        if (
+            not document.title
+            or not document.title.strip()
+            and not any(section.text.strip() for section in document.sections)
         ):
             logger.warning(
                 f"Skipping document with ID {document.id} as it has neither title nor content"

View File

@@ -5,6 +5,7 @@ from pydantic import BaseModel
 from danswer.access.models import DocumentAccess
 from danswer.connectors.models import Document
 from danswer.utils.logger import setup_logger
+from shared_configs.model_server_models import Embedding
 if TYPE_CHECKING:
     from danswer.db.models import EmbeddingModel
@@ -13,9 +14,6 @@ if TYPE_CHECKING:
 logger = setup_logger()
-Embedding = list[float] | None
 class ChunkEmbedding(BaseModel):
     full_embedding: Embedding
     mini_chunk_embeddings: list[Embedding]
@@ -36,6 +34,8 @@ class DocAwareChunk(BaseChunk):
     # During inference we only have access to the document id and do not reconstruct the Document
     source_document: Document
+    # This could be an empty string if the title is too long and taking up too much of the chunk
+    # This does not mean necessarily that the document does not have a title
     title_prefix: str
     # During indexing we also (optionally) build a metadata string from the metadata dict
@@ -44,8 +44,7 @@ class DocAwareChunk(BaseChunk):
     metadata_suffix_semantic: str
     metadata_suffix_keyword: str
-    # give these default values so they can be set after the rest of the chunk is created
-    mini_chunk_texts: list[str] = []
+    mini_chunk_texts: list[str] | None
     def to_short_descriptor(self) -> str:
         """Used when logging the identity of a chunk"""

View File

@@ -13,6 +13,7 @@ from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import IntentRequest
@@ -73,10 +74,9 @@ class EmbeddingModel:
         texts: list[str],
         text_type: EmbedTextType,
         batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
-    ) -> list[list[float] | None]:
-        if not texts:
-            logger.warning("No texts to be embedded")
-            return []
+    ) -> list[Embedding]:
+        if not texts or not all(texts):
+            raise ValueError(f"Empty or missing text for embedding: {texts}")
         if self.retrim_content:
             # This is applied during indexing as a catchall for overly long titles (or other uncapped fields)
@@ -116,13 +116,12 @@ class EmbeddingModel:
                 raise HTTPError(f"HTTP error occurred: {error_detail}") from e
             except requests.RequestException as e:
                 raise HTTPError(f"Request failed: {str(e)}") from e
-            EmbedResponse(**response.json()).embeddings
             return EmbedResponse(**response.json()).embeddings
         # Batching for local embedding
         text_batches = batch_list(texts, batch_size)
-        embeddings: list[list[float] | None] = []
+        embeddings: list[Embedding] = []
         logger.debug(
             f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
         )
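The client-side `encode` now rejects empty input outright instead of logging and returning `[]`; note that `all(texts)` is False whenever any element is the empty string. A short sketch of the new contract:

def check_texts(texts: list[str]) -> None:
    # Same guard as EmbeddingModel.encode above
    if not texts or not all(texts):
        raise ValueError(f"Empty or missing text for embedding: {texts}")

check_texts(["a passage to embed"])  # passes
# check_texts([]) and check_texts(["ok", ""]) would both raise ValueError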

View File

@@ -126,6 +126,7 @@ class InferenceChunk(BaseChunk):
     document_id: str
     source_type: DocumentSource
     semantic_identifier: str
+    title: str | None  # Separate from Semantic Identifier though often same
     boost: int
     recency_bias: float
     score: float | None
@@ -193,16 +194,16 @@ class InferenceChunk(BaseChunk):
 class InferenceChunkUncleaned(InferenceChunk):
-    title: str | None  # Separate from Semantic Identifier though often same
     metadata_suffix: str | None
     def to_inference_chunk(self) -> InferenceChunk:
-        # Create a dict of all fields except 'title' and 'metadata_suffix'
+        # Create a dict of all fields except 'metadata_suffix'
         # Assumes the cleaning has already been applied and just needs to translate to the right type
         inference_chunk_data = {
             k: v
             for k, v in self.dict().items()
-            if k not in ["title", "metadata_suffix"]
+            if k
+            not in ["metadata_suffix"]  # May be other fields to throw out in the future
         }
         return InferenceChunk(**inference_chunk_data)
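With `title` promoted onto `InferenceChunk`, the uncleaned subclass only strips `metadata_suffix` during conversion. A runnable sketch of the same field-filtering pattern, assuming pydantic v1's `.dict()` as used in the diff; the class names here are illustrative:

from pydantic import BaseModel

class CleanChunk(BaseModel):
    title: str | None  # now lives on the base inference chunk
    content: str

class UncleanedChunk(CleanChunk):
    metadata_suffix: str | None

    def to_clean(self) -> CleanChunk:
        # Keep every field except the ones the cleaned type does not carry
        data = {k: v for k, v in self.dict().items() if k not in ["metadata_suffix"]}
        return CleanChunk(**data)

chunk = UncleanedChunk(title="T", content="body", metadata_suffix="meta")
assert chunk.to_clean() == CleanChunk(title="T", content="body")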

View File

@@ -99,7 +99,11 @@ def semantic_reranking(
     Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
     """
     cross_encoders = CrossEncoderEnsembleModel()
-    passages = [chunk.content for chunk in chunks]
+    passages = [
+        f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
+        for chunk in chunks
+    ]
     sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
     sim_scores = [numpy.array(scores) for scores in sim_scores_floats]

View File

@@ -1,6 +1,5 @@
 import string
 from collections.abc import Callable
-from typing import cast
 import nltk  # type:ignore
 from nltk.corpus import stopwords  # type:ignore
@@ -144,9 +143,7 @@ def doc_index_retrieval(
     if query.search_type == SearchType.SEMANTIC:
         top_chunks = document_index.semantic_retrieval(
             query=query.query,
-            query_embedding=cast(
-                list[float], query_embedding
-            ),  # query embeddings should always have vector representations
+            query_embedding=query_embedding,
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,
@@ -155,9 +152,7 @@ def doc_index_retrieval(
     elif query.search_type == SearchType.HYBRID:
         top_chunks = document_index.hybrid_retrieval(
             query=query.query,
-            query_embedding=cast(
-                list[float], query_embedding
-            ),  # query embeddings should always have vector representations
+            query_embedding=query_embedding,
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,

View File

@@ -29,6 +29,7 @@ from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
 from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest
@@ -80,9 +81,7 @@ class CloudEmbedding:
             raise ValueError(f"Unsupported provider: {provider}")
         self.client = _initialize_client(api_key, self.provider, model)
-    def _embed_openai(
-        self, texts: list[str], model: str | None
-    ) -> list[list[float] | None]:
+    def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
         if model is None:
             model = DEFAULT_OPENAI_MODEL
@@ -104,7 +103,7 @@ class CloudEmbedding:
     def _embed_cohere(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_COHERE_MODEL
@@ -120,7 +119,7 @@ class CloudEmbedding:
     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_VOYAGE_MODEL
@@ -136,7 +135,7 @@ class CloudEmbedding:
     def _embed_vertex(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_VERTEX_MODEL
@@ -159,7 +158,7 @@ class CloudEmbedding:
         texts: list[str],
         text_type: EmbedTextType,
         model_name: str | None = None,
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         try:
             if self.provider == EmbeddingProvider.OPENAI:
                 return self._embed_openai(texts, model_name)
@@ -247,19 +246,13 @@ def embed_text(
     api_key: str | None,
     provider_type: str | None,
     prefix: str | None,
-) -> list[list[float] | None]:
-    non_empty_texts = []
-    empty_indices = []
-    for idx, text in enumerate(texts):
-        if text.strip():
-            non_empty_texts.append(text)
-        else:
-            empty_indices.append(idx)
+) -> list[Embedding]:
+    if not all(texts):
+        raise ValueError("Empty strings are not allowed for embedding.")
     # Third party API based embedding model
-    if not non_empty_texts:
-        embeddings = []
+    if not texts:
+        raise ValueError("No texts provided for embedding.")
     elif provider_type is not None:
         logger.debug(f"Embedding text with provider: {provider_type}")
         if api_key is None:
@@ -277,47 +270,36 @@ def embed_text(
             api_key=api_key, provider=provider_type, model=model_name
         )
         embeddings = cloud_model.embed(
-            texts=non_empty_texts,
+            texts=texts,
             model_name=model_name,
             text_type=text_type,
         )
+        # Check for None values in embeddings
+        if any(embedding is None for embedding in embeddings):
+            error_message = "Embeddings contain None values\n"
+            error_message += "Corresponding texts:\n"
+            error_message += "\n".join(texts)
+            raise ValueError(error_message)
     elif model_name is not None:
-        prefixed_texts = (
-            [f"{prefix}{text}" for text in non_empty_texts]
-            if prefix
-            else non_empty_texts
-        )
+        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
         local_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
        )
-        embeddings = local_model.encode(
+        embeddings_vectors = local_model.encode(
             prefixed_texts, normalize_embeddings=normalize_embeddings
         )
+        embeddings = [
+            embedding if isinstance(embedding, list) else embedding.tolist()
+            for embedding in embeddings_vectors
+        ]
     else:
         raise ValueError(
             "Either model name or provider must be provided to run embeddings."
         )
-    if embeddings is None:
-        raise RuntimeError("Failed to create Embeddings")
-    embeddings_with_nulls: list[list[float] | None] = []
-    current_embedding_index = 0
-    for idx in range(len(texts)):
-        if idx in empty_indices:
-            embeddings_with_nulls.append(None)
-        else:
-            embedding = embeddings[current_embedding_index]
-            if isinstance(embedding, list) or embedding is None:
-                embeddings_with_nulls.append(embedding)
-            else:
-                embeddings_with_nulls.append(embedding.tolist())
-            current_embedding_index += 1
-    embeddings = embeddings_with_nulls
     return embeddings
@@ -337,6 +319,8 @@ async def process_embed_request(
 ) -> EmbedResponse:
     if not embed_request.texts:
         raise HTTPException(status_code=400, detail="No texts to be embedded")
+    elif not all(embed_request.texts):
+        raise ValueError("Empty strings are not allowed for embedding.")
     try:
         if embed_request.text_type == EmbedTextType.QUERY:
@@ -371,8 +355,10 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse
     if not embed_request.documents or not embed_request.query:
         raise HTTPException(
-            status_code=400, detail="No documents or query to be reranked"
+            status_code=400, detail="Missing documents or query for reranking"
         )
+    if not all(embed_request.documents):
+        raise ValueError("Empty documents cannot be reranked.")
     try:
         sim_scores = calc_sim_scores(
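On the local-model branch of `embed_text`, the encoder typically returns numpy arrays, so the rewrite normalizes each vector to a plain list before returning (previously this conversion happened inside the now-deleted null-reinsertion loop). A minimal sketch of that conversion with a hypothetical encoder output:

import numpy as np

# Hypothetical local-model output: a 2D numpy array of vectors
embeddings_vectors = np.array([[0.1, 0.2], [0.3, 0.4]])

embeddings: list[list[float]] = [
    embedding if isinstance(embedding, list) else embedding.tolist()
    for embedding in embeddings_vectors
]
assert embeddings == [[0.1, 0.2], [0.3, 0.4]]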

View File

@@ -2,6 +2,8 @@ from pydantic import BaseModel
 from shared_configs.enums import EmbedTextType
+Embedding = list[float]
 class EmbedRequest(BaseModel):
     texts: list[str]
@@ -17,7 +19,7 @@ class EmbedRequest(BaseModel):
 class EmbedResponse(BaseModel):
-    embeddings: list[list[float] | None]
+    embeddings: list[Embedding]
 class RerankRequest(BaseModel):

View File

@@ -114,6 +114,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
         },
         blurb="anything",
         semantic_identifier="anything",
+        title="whatever",
         section_continuation=False,
         recency_bias=1,
         boost=0,
@@ -131,6 +132,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
         source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"},
         blurb="whatever",
         semantic_identifier="whatever",
+        title="whatever",
         section_continuation=False,
         recency_bias=1,
         boost=0,

View File

@@ -24,6 +24,7 @@ def create_inference_chunk(
         chunk_id=chunk_id,
         document_id=document_id,
         semantic_identifier=f"{document_id}_{chunk_id}",
+        title="whatever",
         blurb=f"{document_id}_{chunk_id}",
         content=content,
         source_links={0: "fake_link"},