No Null Embeddings (#1982)
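The change in one line: the shared `Embedding` type drops its `| None` arm, so every vector the pipeline hands around must actually exist; empty inputs are rejected up front instead of being back-filled with nulls. A minimal sketch of the idea distilled from the hunks below (the guard mirrors the new checks in `embed_text`; an illustration, not the full module):

# Before this commit (danswer/indexing/models.py):
# Embedding = list[float] | None

# After (shared_configs/model_server_models.py):
Embedding = list[float]


def validate_texts_for_embedding(texts: list[str]) -> None:
    # Fail fast: no more silently skipping empty strings and
    # re-inserting None placeholders into the output.
    if not texts:
        raise ValueError("No texts provided for embedding.")
    if not all(texts):
        raise ValueError("Empty strings are not allowed for embedding.")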
@@ -12,8 +12,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "08a1eda20fe1"
 down_revision = "8a87bd6ec550"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None


 def upgrade() -> None:

@@ -13,8 +13,8 @@ from sqlalchemy.dialects import postgresql
 # revision identifiers, used by Alembic.
 revision = "473a1a7ca408"
 down_revision = "325975216eb3"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None

 default_models_by_provider = {
     "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"],

@@ -12,8 +12,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "4ea2c93919c1"
 down_revision = "473a1a7ca408"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None


 def upgrade() -> None:

@@ -11,8 +11,8 @@ import sqlalchemy as sa
 # revision identifiers, used by Alembic.
 revision = "8a87bd6ec550"
 down_revision = "4ea2c93919c1"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None


 def upgrade() -> None:
@@ -7,6 +7,7 @@ from danswer.access.models import DocumentAccess
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.search.models import IndexFilters
 from danswer.search.models import InferenceChunkUncleaned
+from shared_configs.model_server_models import Embedding


 @dataclass(frozen=True)

@@ -257,7 +258,7 @@ class VectorCapable(abc.ABC):
     def semantic_retrieval(
         self,
         query: str,  # Needed for matching purposes
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,

@@ -292,7 +293,7 @@ class HybridCapable(abc.ABC):
     def hybrid_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,
@@ -69,6 +69,7 @@ from danswer.search.retrieval.search_runner import query_processing
 from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
+from shared_configs.model_server_models import Embedding

 logger = setup_logger()

@@ -329,19 +330,15 @@ def _index_vespa_chunk(
         "Content-Type": "application/json",
     }
     document = chunk.source_document

     # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
     vespa_chunk_id = str(get_uuid_from_chunk(chunk))
     embeddings = chunk.embeddings

-    if chunk.embeddings.full_embedding is None:
-        embeddings.full_embedding = chunk.title_embedding
     embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}

     if embeddings.mini_chunk_embeddings:
         for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
-            if m_c_embed is None:
-                embeddings_name_vector_map[f"mini_chunk_{ind}"] = chunk.title_embedding
-            else:
-                embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
+            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed

     title = document.get_title_for_document_index()

@@ -1035,7 +1032,7 @@ class VespaIndex(DocumentIndex):
     def semantic_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int = NUM_RETURNED_HITS,

@@ -1077,7 +1074,7 @@ class VespaIndex(DocumentIndex):
     def hybrid_retrieval(
         self,
         query: str,
-        query_embedding: list[float],
+        query_embedding: Embedding,
         filters: IndexFilters,
         time_decay_multiplier: float,
         num_to_retrieve: int,
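With nulls gone, `_index_vespa_chunk` can build its field-name-to-vector map without the title-embedding fallback removed above. A self-contained sketch of the surviving mapping logic (the dataclass is a stand-in for the real chunk embedding model):

from dataclasses import dataclass, field


@dataclass
class ChunkEmbeddings:
    full_embedding: list[float]
    mini_chunk_embeddings: list[list[float]] = field(default_factory=list)


def build_vector_map(embeddings: ChunkEmbeddings) -> dict[str, list[float]]:
    # Every vector is guaranteed non-null now, so no per-vector None checks
    vector_map = {"full_chunk": embeddings.full_embedding}
    for ind, mini_embed in enumerate(embeddings.mini_chunk_embeddings):
        vector_map[f"mini_chunk_{ind}"] = mini_embed
    return vector_map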
@@ -1,5 +1,6 @@
 import abc
 from collections.abc import Callable
+from typing import Optional
 from typing import TYPE_CHECKING

 from danswer.configs.app_configs import BLURB_SIZE

@@ -50,9 +51,10 @@ def chunk_large_section(
     start_chunk_id: int,
     blurb: str,
     chunk_splitter: "SentenceSplitter",
-    title_prefix: str = "",
-    metadata_suffix_semantic: str = "",
-    metadata_suffix_keyword: str = "",
+    mini_chunk_splitter: Optional["SentenceSplitter"],
+    title_prefix: str,
+    metadata_suffix_semantic: str,
+    metadata_suffix_keyword: str,
 ) -> list[DocAwareChunk]:
     split_texts = chunk_splitter.split_text(section_text)

@@ -61,14 +63,17 @@ def chunk_large_section(
            source_document=document,
            chunk_id=start_chunk_id + chunk_ind,
            blurb=blurb,
-           content=chunk_str,
+           content=chunk_text,
            source_links={0: section_link_text},
            section_continuation=(chunk_ind != 0),
            title_prefix=title_prefix,
            metadata_suffix_semantic=metadata_suffix_semantic,
            metadata_suffix_keyword=metadata_suffix_keyword,
+           mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+           if mini_chunk_splitter and chunk_text.strip()
+           else None,
        )
-       for chunk_ind, chunk_str in enumerate(split_texts)
+       for chunk_ind, chunk_text in enumerate(split_texts)
    ]
    return chunks
@@ -114,49 +119,6 @@ def _get_metadata_suffix_for_document_index(
     return metadata_semantic, metadata_keyword


-def _split_chunk_text_into_mini_chunks(
-    chunk_text: str, embedder: IndexingEmbedder, mini_chunk_size: int = MINI_CHUNK_SIZE
-) -> list[str]:
-    """The minichunks won't all have the title prefix or metadata suffix
-    It could be a significant percentage of every minichunk so better to not include it
-    """
-    from llama_index.text_splitter import SentenceSplitter
-
-    token_count_func = get_tokenizer(
-        model_name=embedder.model_name,
-        provider_type=embedder.provider_type,
-    ).tokenize
-    sentence_aware_splitter = SentenceSplitter(
-        tokenizer=token_count_func, chunk_size=mini_chunk_size, chunk_overlap=0
-    )
-
-    return sentence_aware_splitter.split_text(chunk_text)
-
-
-def _extract_chunk_texts_from_doc_aware_chunk(
-    chunks: list[DocAwareChunk],
-    embedder: IndexingEmbedder,
-    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
-) -> list[DocAwareChunk]:
-    # Create Mini Chunks for more precise matching of details
-    # Off by default with unedited settings
-    chunks_with_texts: list[DocAwareChunk] = []
-    for chunk in chunks:
-        # The whole chunk including the prefix/suffix is included in the overall vector representation
-        mini_chunk_texts = (
-            _split_chunk_text_into_mini_chunks(
-                chunk.content,
-                embedder=embedder,
-            )
-            if enable_mini_chunk
-            else []
-        )
-        chunk.mini_chunk_texts = mini_chunk_texts
-        chunks_with_texts.append(chunk)
-
-    return chunks_with_texts
-
-
 def chunk_document(
     document: Document,
     embedder: IndexingEmbedder,
@@ -164,6 +126,8 @@ def chunk_document(
     subsection_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,  # Used for both title and content
     include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
+    mini_chunk_size: int = MINI_CHUNK_SIZE,
+    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
 ) -> list[DocAwareChunk]:
     from llama_index.text_splitter import SentenceSplitter

@@ -182,6 +146,12 @@ def chunk_document(
         chunk_overlap=subsection_overlap,
     )

+    mini_chunk_splitter = SentenceSplitter(
+        tokenizer=tokenizer.tokenize,
+        chunk_size=mini_chunk_size,
+        chunk_overlap=0,
+    )
+
     title = extract_blurb(document.get_title_for_document_index() or "", blurb_splitter)
     title_prefix = title + RETURN_SEPARATOR if title else ""
     title_tokens = len(tokenizer.tokenize(title_prefix))
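Mini-chunking now happens inline during `chunk_document` with a second `SentenceSplitter` rather than in a separate post-pass. A sketch of the splitter setup as used above (the llama_index import matches the diff; `tokenize` stands in for whatever tokenizer the embedder provides):

from llama_index.text_splitter import SentenceSplitter


def make_mini_chunk_splitter(tokenize, mini_chunk_size: int) -> "SentenceSplitter":
    # chunk_overlap=0: mini chunks tile the parent chunk with no repetition
    return SentenceSplitter(
        tokenizer=tokenize,
        chunk_size=mini_chunk_size,
        chunk_overlap=0,
    )


# Per chunk: only split non-whitespace text when mini-chunking is enabled;
# otherwise record None so the chunk is embedded without mini-chunk vectors:
# mini_chunk_texts = splitter.split_text(chunk_text)
#     if enable_mini_chunk and chunk_text.strip() else None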
@@ -238,6 +208,9 @@ def chunk_document(
                    title_prefix=title_prefix,
                    metadata_suffix_semantic=metadata_suffix_semantic,
                    metadata_suffix_keyword=metadata_suffix_keyword,
+                   mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                   if enable_mini_chunk and chunk_text.strip()
+                   else None,
                )
            )
            link_offsets = {}

@@ -249,6 +222,9 @@ def chunk_document(
                document=document,
                start_chunk_id=len(chunks),
                chunk_splitter=chunk_splitter,
+               mini_chunk_splitter=mini_chunk_splitter
+               if enable_mini_chunk and chunk_text.strip()
+               else None,
                blurb=extract_blurb(section_text, blurb_splitter),
                title_prefix=title_prefix,
                metadata_suffix_semantic=metadata_suffix_semantic,
@@ -280,14 +256,17 @@ def chunk_document(
                    title_prefix=title_prefix,
                    metadata_suffix_semantic=metadata_suffix_semantic,
                    metadata_suffix_keyword=metadata_suffix_keyword,
+                   mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+                   if enable_mini_chunk and chunk_text.strip()
+                   else None,
                )
            )
            link_offsets = {0: section_link_text}
            chunk_text = section_text

-    # Once we hit the end, if we're still in the process of building a chunk, add what we have
-    # NOTE: if it's just whitespace, ignore it.
-    if chunk_text.strip():
+    # Once we hit the end, if we're still in the process of building a chunk, add what we have. If there is only whitespace left
+    # then don't include it. If there are no chunks at all from the doc, we can just create a single chunk with the title.
+    if chunk_text.strip() or not chunks:
         chunks.append(
             DocAwareChunk(
                 source_document=document,

@@ -299,13 +278,14 @@ def chunk_document(
                title_prefix=title_prefix,
                metadata_suffix_semantic=metadata_suffix_semantic,
                metadata_suffix_keyword=metadata_suffix_keyword,
+               mini_chunk_texts=mini_chunk_splitter.split_text(chunk_text)
+               if enable_mini_chunk and chunk_text.strip()
+               else None,
            )
        )

-    chunks_with_texts: list[DocAwareChunk] = _extract_chunk_texts_from_doc_aware_chunk(
-        chunks=chunks, embedder=embedder
-    )
-    return chunks_with_texts
+    # If the chunk does not have any useable content, it will not be indexed
+    return chunks


 class Chunker:
@@ -16,6 +16,7 @@ from danswer.utils.logger import setup_logger
 from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
 from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding


 logger = setup_logger()

@@ -78,12 +79,21 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         self,
         chunks: list[DocAwareChunk],
     ) -> list[IndexChunk]:
+        # All chunks at this point must have some non-empty content
         flat_chunk_texts: list[str] = []
         for chunk in chunks:
             chunk_text = (
                 f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
-            )
+            ) or chunk.source_document.get_title_for_document_index()
+
+            if not chunk_text:
+                # This should never happen, the document would have been dropped
+                # before getting to this point
+                raise ValueError(f"Chunk has no content: {chunk.to_short_descriptor()}")
+
             flat_chunk_texts.append(chunk_text)
+
             if chunk.mini_chunk_texts:
                 flat_chunk_texts.extend(chunk.mini_chunk_texts)

         embeddings = self.embedding_model.encode(

@@ -95,10 +105,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         }

         # Drop any None or empty strings
+        # If there is no title or the title is empty, the title embedding field will be null
+        # which is ok, it just won't contribute at all to the scoring.
         chunk_titles_list = [title for title in chunk_titles if title]

         # Cache the Title embeddings to only have to do it once
-        title_embed_dict: dict[str, list[float] | None] = {}
+        title_embed_dict: dict[str, Embedding] = {}
         if chunk_titles_list:
             title_embeddings = self.embedding_model.encode(
                 chunk_titles_list, text_type=EmbedTextType.PASSAGE

@@ -114,7 +126,9 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         embedded_chunks: list[IndexChunk] = []
         embedding_ind_start = 0
         for chunk in chunks:
-            num_embeddings = 1 + len(chunk.mini_chunk_texts)
+            num_embeddings = 1 + (
+                len(chunk.mini_chunk_texts) if chunk.mini_chunk_texts else 0
+            )
             chunk_embeddings = embeddings[
                 embedding_ind_start : embedding_ind_start + num_embeddings
             ]
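Chunk texts and their mini-chunk texts go to the embedding model as one flat batch, then get sliced back out per chunk; the `num_embeddings` change above just accounts for `mini_chunk_texts` now being nullable. A standalone sketch of the slicing arithmetic:

def slice_embeddings(
    embeddings: list[list[float]],
    mini_chunk_counts: list[int],
) -> list[list[list[float]]]:
    # Each chunk owns one full-chunk embedding plus one per mini chunk,
    # laid out contiguously in the flat batch output
    per_chunk: list[list[list[float]]] = []
    embedding_ind_start = 0
    for mini_count in mini_chunk_counts:
        num_embeddings = 1 + mini_count
        per_chunk.append(
            embeddings[embedding_ind_start : embedding_ind_start + num_embeddings]
        )
        embedding_ind_start += num_embeddings
    return per_chunk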
@@ -124,10 +124,15 @@ def index_doc_batch(
     Note that the documents should already be batched at this point so that it does not inflate the
     memory requirements"""
     # Skip documents that have neither title nor content
+    # If the document doesn't have either, then there is no useful information in it
+    # This is again verified later in the pipeline after chunking but at that point there should
+    # already be no documents that are empty.
     documents_to_process = []
     for document in documents:
-        if not document.title and not any(
-            section.text.strip() for section in document.sections
+        if (
+            not document.title
+            or not document.title.strip()
+            and not any(section.text.strip() for section in document.sections)
         ):
             logger.warning(
                 f"Skipping document with ID {document.id} as it has neither title nor content"
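Per the comments above, the intent of the gate is that a document must offer a non-blank title or some non-blank section text to be indexed. A small sketch of that predicate's intent (simplified; the in-place condition above is the authoritative form):

def has_usable_content(title: str | None, section_texts: list[str]) -> bool:
    # Indexable if the document has a real title or any real section text
    return bool(title and title.strip()) or any(t.strip() for t in section_texts)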
@@ -5,6 +5,7 @@ from pydantic import BaseModel
 from danswer.access.models import DocumentAccess
 from danswer.connectors.models import Document
 from danswer.utils.logger import setup_logger
+from shared_configs.model_server_models import Embedding

 if TYPE_CHECKING:
     from danswer.db.models import EmbeddingModel

@@ -13,9 +14,6 @@ if TYPE_CHECKING:
 logger = setup_logger()


-Embedding = list[float] | None
-
-
 class ChunkEmbedding(BaseModel):
     full_embedding: Embedding
     mini_chunk_embeddings: list[Embedding]

@@ -36,6 +34,8 @@ class DocAwareChunk(BaseChunk):
     # During inference we only have access to the document id and do not reconstruct the Document
     source_document: Document

+    # This could be an empty string if the title is too long and taking up too much of the chunk
+    # This does not mean necessarily that the document does not have a title
     title_prefix: str

     # During indexing we also (optionally) build a metadata string from the metadata dict

@@ -44,8 +44,7 @@ class DocAwareChunk(BaseChunk):
     metadata_suffix_semantic: str
     metadata_suffix_keyword: str

-    # give these default values so they can be set after the rest of the chunk is created
-    mini_chunk_texts: list[str] = []
+    mini_chunk_texts: list[str] | None

     def to_short_descriptor(self) -> str:
         """Used when logging the identity of a chunk"""
@@ -13,6 +13,7 @@ from danswer.utils.logger import setup_logger
 from shared_configs.configs import MODEL_SERVER_HOST
 from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import IntentRequest

@@ -73,10 +74,9 @@ class EmbeddingModel:
         texts: list[str],
         text_type: EmbedTextType,
         batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
-    ) -> list[list[float] | None]:
-        if not texts:
-            logger.warning("No texts to be embedded")
-            return []
+    ) -> list[Embedding]:
+        if not texts or not all(texts):
+            raise ValueError(f"Empty or missing text for embedding: {texts}")

         if self.retrim_content:
             # This is applied during indexing as a catchall for overly long titles (or other uncapped fields)

@@ -116,13 +116,12 @@ class EmbeddingModel:
                 raise HTTPError(f"HTTP error occurred: {error_detail}") from e
             except requests.RequestException as e:
                 raise HTTPError(f"Request failed: {str(e)}") from e
-            EmbedResponse(**response.json()).embeddings

             return EmbedResponse(**response.json()).embeddings

         # Batching for local embedding
         text_batches = batch_list(texts, batch_size)
-        embeddings: list[list[float] | None] = []
+        embeddings: list[Embedding] = []
         logger.debug(
             f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
         )
@@ -126,6 +126,7 @@ class InferenceChunk(BaseChunk):
     document_id: str
     source_type: DocumentSource
     semantic_identifier: str
+    title: str | None  # Separate from Semantic Identifier though often same
     boost: int
     recency_bias: float
     score: float | None

@@ -193,16 +194,16 @@ class InferenceChunk(BaseChunk):


 class InferenceChunkUncleaned(InferenceChunk):
-    title: str | None  # Separate from Semantic Identifier though often same
     metadata_suffix: str | None

     def to_inference_chunk(self) -> InferenceChunk:
-        # Create a dict of all fields except 'title' and 'metadata_suffix'
+        # Create a dict of all fields except 'metadata_suffix'
+        # Assumes the cleaning has already been applied and just needs to translate to the right type
         inference_chunk_data = {
             k: v
             for k, v in self.dict().items()
-            if k not in ["title", "metadata_suffix"]
+            if k
+            not in ["metadata_suffix"]  # May be other fields to throw out in the future
         }
         return InferenceChunk(**inference_chunk_data)
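Since `title` now lives on `InferenceChunk` itself, the uncleaned subclass only has to strip `metadata_suffix` when converting. A minimal runnable sketch of that field-filtering pattern (simplified stand-in models, not the full classes):

from pydantic import BaseModel


class InferenceChunkSketch(BaseModel):
    content: str
    title: str | None


class InferenceChunkUncleanedSketch(InferenceChunkSketch):
    metadata_suffix: str | None

    def to_inference_chunk(self) -> InferenceChunkSketch:
        # Keep every field except the ones the cleaned type does not carry
        data = {k: v for k, v in self.dict().items() if k not in ["metadata_suffix"]}
        return InferenceChunkSketch(**data)


print(
    InferenceChunkUncleanedSketch(
        content="body", title="Doc", metadata_suffix=None
    ).to_inference_chunk()
)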
@@ -99,7 +99,11 @@ def semantic_reranking(
     Note: this updates the chunks in place, it updates the chunk scores which came from retrieval
     """
     cross_encoders = CrossEncoderEnsembleModel()
-    passages = [chunk.content for chunk in chunks]
+
+    passages = [
+        f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
+        for chunk in chunks
+    ]
     sim_scores_floats = cross_encoders.predict(query=query, passages=passages)

     sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
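With `title` available on every `InferenceChunk`, the cross-encoder can see the document name alongside the body. A sketch of the passage construction used above, with a toy chunk type:

from dataclasses import dataclass


@dataclass
class RerankChunk:
    semantic_identifier: str
    title: str | None
    content: str


def build_passages(chunks: list[RerankChunk]) -> list[str]:
    # Prefer the semantic identifier, fall back to the (nullable) title
    return [
        f"{chunk.semantic_identifier or chunk.title or ''}\n{chunk.content}"
        for chunk in chunks
    ]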
@@ -1,6 +1,5 @@
 import string
 from collections.abc import Callable
-from typing import cast

 import nltk  # type:ignore
 from nltk.corpus import stopwords  # type:ignore

@@ -144,9 +143,7 @@ def doc_index_retrieval(
     if query.search_type == SearchType.SEMANTIC:
         top_chunks = document_index.semantic_retrieval(
             query=query.query,
-            query_embedding=cast(
-                list[float], query_embedding
-            ),  # query embeddings should always have vector representations
+            query_embedding=query_embedding,
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,

@@ -155,9 +152,7 @@ def doc_index_retrieval(
     elif query.search_type == SearchType.HYBRID:
         top_chunks = document_index.hybrid_retrieval(
             query=query.query,
-            query_embedding=cast(
-                list[float], query_embedding
-            ),  # query embeddings should always have vector representations
+            query_embedding=query_embedding,
             filters=query.filters,
             time_decay_multiplier=query.recency_bias_multiplier,
             num_to_retrieve=query.num_hits,
@@ -29,6 +29,7 @@ from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
 from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from shared_configs.configs import INDEXING_ONLY
 from shared_configs.enums import EmbedTextType
+from shared_configs.model_server_models import Embedding
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest

@@ -80,9 +81,7 @@ class CloudEmbedding:
             raise ValueError(f"Unsupported provider: {provider}")
         self.client = _initialize_client(api_key, self.provider, model)

-    def _embed_openai(
-        self, texts: list[str], model: str | None
-    ) -> list[list[float] | None]:
+    def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
         if model is None:
             model = DEFAULT_OPENAI_MODEL

@@ -104,7 +103,7 @@ class CloudEmbedding:

     def _embed_cohere(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_COHERE_MODEL

@@ -120,7 +119,7 @@ class CloudEmbedding:

     def _embed_voyage(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_VOYAGE_MODEL

@@ -136,7 +135,7 @@ class CloudEmbedding:

     def _embed_vertex(
         self, texts: list[str], model: str | None, embedding_type: str
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         if model is None:
             model = DEFAULT_VERTEX_MODEL

@@ -159,7 +158,7 @@ class CloudEmbedding:
         texts: list[str],
         text_type: EmbedTextType,
         model_name: str | None = None,
-    ) -> list[list[float] | None]:
+    ) -> list[Embedding]:
         try:
             if self.provider == EmbeddingProvider.OPENAI:
                 return self._embed_openai(texts, model_name)
@@ -247,19 +246,13 @@ def embed_text(
     api_key: str | None,
     provider_type: str | None,
     prefix: str | None,
-) -> list[list[float] | None]:
-    non_empty_texts = []
-    empty_indices = []
-
-    for idx, text in enumerate(texts):
-        if text.strip():
-            non_empty_texts.append(text)
-        else:
-            empty_indices.append(idx)
+) -> list[Embedding]:
+    if not all(texts):
+        raise ValueError("Empty strings are not allowed for embedding.")

     # Third party API based embedding model
-    if not non_empty_texts:
-        embeddings = []
+    if not texts:
+        raise ValueError("No texts provided for embedding.")
     elif provider_type is not None:
         logger.debug(f"Embedding text with provider: {provider_type}")
         if api_key is None:

@@ -277,47 +270,36 @@ def embed_text(
             api_key=api_key, provider=provider_type, model=model_name
         )
         embeddings = cloud_model.embed(
-            texts=non_empty_texts,
+            texts=texts,
             model_name=model_name,
             text_type=text_type,
         )

+        # Check for None values in embeddings
+        if any(embedding is None for embedding in embeddings):
+            error_message = "Embeddings contain None values\n"
+            error_message += "Corresponding texts:\n"
+            error_message += "\n".join(texts)
+            raise ValueError(error_message)
+
     elif model_name is not None:
-        prefixed_texts = (
-            [f"{prefix}{text}" for text in non_empty_texts]
-            if prefix
-            else non_empty_texts
-        )
+        prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
         local_model = get_embedding_model(
             model_name=model_name, max_context_length=max_context_length
         )
-        embeddings = local_model.encode(
+        embeddings_vectors = local_model.encode(
             prefixed_texts, normalize_embeddings=normalize_embeddings
         )
+        embeddings = [
+            embedding if isinstance(embedding, list) else embedding.tolist()
+            for embedding in embeddings_vectors
+        ]

     else:
         raise ValueError(
             "Either model name or provider must be provided to run embeddings."
         )

-    if embeddings is None:
-        raise RuntimeError("Failed to create Embeddings")
-
-    embeddings_with_nulls: list[list[float] | None] = []
-    current_embedding_index = 0
-
-    for idx in range(len(texts)):
-        if idx in empty_indices:
-            embeddings_with_nulls.append(None)
-        else:
-            embedding = embeddings[current_embedding_index]
-            if isinstance(embedding, list) or embedding is None:
-                embeddings_with_nulls.append(embedding)
-            else:
-                embeddings_with_nulls.append(embedding.tolist())
-            current_embedding_index += 1
-
-    embeddings = embeddings_with_nulls
     return embeddings
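`embed_text` previously stripped out empty strings, embedded the rest, and stitched `None` back into the output at the empty positions; now it refuses empty input outright and always returns plain float lists. A condensed, runnable sketch of the new control flow (the two helpers are hypothetical stand-ins, not real danswer functions):

def embed_text_flow(
    texts: list[str],
    provider_type: str | None,
    model_name: str | None,
    prefix: str | None,
) -> list[list[float]]:
    # Fail fast instead of tracking empty_indices and back-filling None
    if not texts:
        raise ValueError("No texts provided for embedding.")
    if not all(texts):
        raise ValueError("Empty strings are not allowed for embedding.")

    if provider_type is not None:
        # Cloud providers may still return None entries; surface that loudly
        embeddings = _cloud_embed(texts)
        if any(embedding is None for embedding in embeddings):
            raise ValueError("Embeddings contain None values")
    elif model_name is not None:
        prefixed = [f"{prefix}{text}" for text in texts] if prefix else texts
        # Local models return numpy arrays; normalize them to plain lists
        vectors = _local_encode(prefixed)
        embeddings = [v if isinstance(v, list) else v.tolist() for v in vectors]
    else:
        raise ValueError(
            "Either model name or provider must be provided to run embeddings."
        )
    return embeddings


def _cloud_embed(texts: list[str]) -> list[list[float] | None]:
    # Hypothetical stand-in for CloudEmbedding.embed
    return [[0.0, 0.0] for _ in texts]


def _local_encode(texts: list[str]) -> list[list[float]]:
    # Hypothetical stand-in for the local model's encode call
    return [[0.0, 0.0] for _ in texts]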
@@ -337,6 +319,8 @@ async def process_embed_request(
 ) -> EmbedResponse:
     if not embed_request.texts:
         raise HTTPException(status_code=400, detail="No texts to be embedded")
+    elif not all(embed_request.texts):
+        raise ValueError("Empty strings are not allowed for embedding.")

     try:
         if embed_request.text_type == EmbedTextType.QUERY:

@@ -371,8 +355,10 @@ async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse

     if not embed_request.documents or not embed_request.query:
         raise HTTPException(
-            status_code=400, detail="No documents or query to be reranked"
+            status_code=400, detail="Missing documents or query for reranking"
         )
+    if not all(embed_request.documents):
+        raise ValueError("Empty documents cannot be reranked.")

     try:
         sim_scores = calc_sim_scores(
@@ -2,6 +2,8 @@ from pydantic import BaseModel

 from shared_configs.enums import EmbedTextType

+Embedding = list[float]
+

 class EmbedRequest(BaseModel):
     texts: list[str]

@@ -17,7 +19,7 @@ class EmbedRequest(BaseModel):


 class EmbedResponse(BaseModel):
-    embeddings: list[list[float] | None]
+    embeddings: list[Embedding]


 class RerankRequest(BaseModel):
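With the alias exported from `shared_configs`, the model-server response contract itself now promises non-null vectors, and pydantic enforces it at the API boundary. A short usage sketch:

from pydantic import BaseModel, ValidationError

Embedding = list[float]


class EmbedResponse(BaseModel):
    embeddings: list[Embedding]


EmbedResponse(embeddings=[[0.1, 0.2], [0.3, 0.4]])  # accepted

try:
    EmbedResponse(embeddings=[[0.1, 0.2], None])  # type: ignore[list-item]
except ValidationError:
    print("null embedding rejected at the response boundary")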
@@ -114,6 +114,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
         },
         blurb="anything",
         semantic_identifier="anything",
+        title="whatever",
         section_continuation=False,
         recency_bias=1,
         boost=0,

@@ -131,6 +132,7 @@ def test_fuzzy_match_quotes_to_docs() -> None:
         source_links={0: "doc 1 base", 36: "2nd line link", 82: "last link"},
         blurb="whatever",
         semantic_identifier="whatever",
+        title="whatever",
         section_continuation=False,
         recency_bias=1,
         boost=0,

@@ -24,6 +24,7 @@ def create_inference_chunk(
         chunk_id=chunk_id,
         document_id=document_id,
         semantic_identifier=f"{document_id}_{chunk_id}",
+        title="whatever",
         blurb=f"{document_id}_{chunk_id}",
         content=content,
         source_links={0: "fake_link"},