danswer/backend/tests/unit/onyx/indexing/test_embedder.py
pablonyx ddec239fef
Improved indexing (#3594)
* nit

* k

* add steps

* main util functions

* functioning fully

* quick nit

* k

* typing fix

* k

* address comments
2025-01-05 23:31:53 +00:00

92 lines
2.8 KiB
Python

from collections.abc import Generator
from unittest.mock import Mock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@pytest.fixture
def mock_embedding_model() -> Generator[Mock, None, None]:
    """Patch ``EmbeddingModel`` in ``onyx.indexing.embedder`` for one test.

    Yields the mock that stands in for the patched name; the patch is undone
    once the consuming test is finished with the fixture.
    """
    patcher = patch("onyx.indexing.embedder.EmbeddingModel")
    embedding_model_mock = patcher.start()
    try:
        yield embedding_model_mock
    finally:
        # Always restore the real attribute, mirroring the exit of a `with` block.
        patcher.stop()
def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
    """Check that ``embed_chunks`` turns one ``DocAwareChunk`` into one ``IndexChunk``.

    The patched embedding model's ``encode`` is given two canned return values.
    The test then asserts that the resulting ``IndexChunk`` carries the expected
    full embedding and title embedding, and that ``encode`` was invoked with the
    expected arguments for both the chunk text and the document title.
    """
    # Embedder under test (relies on the mock_embedding_model fixture's patch).
    indexing_embedder = DefaultIndexingEmbedder(
        model_name="test-model",
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        provider_type=EmbeddingProvider.OPENAI,
    )

    # Canned results for successive `encode` calls, consumed in order.
    encode_mock = mock_embedding_model.return_value.encode
    chunk_call_vectors = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    title_call_vectors = [[7.0, 8.0, 9.0]]
    encode_mock.side_effect = [chunk_call_vectors, title_call_vectors]

    # A single-section document and the one chunk derived from it.
    document = Document(
        id="test_doc",
        source=DocumentSource.WEB,
        semantic_identifier="Test Document",
        metadata={"tags": ["tag1", "tag2"]},
        doc_updated_at=None,
        sections=[
            Section(text="This is a short section.", link="link1"),
        ],
    )
    only_chunk = DocAwareChunk(
        chunk_id=1,
        blurb="This is a short section.",
        content="Test chunk",
        source_links={0: "link1"},
        section_continuation=False,
        source_document=document,
        title_prefix="Title: ",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        mini_chunk_texts=None,
        large_chunk_reference_ids=[],
        large_chunk_id=None,
    )
    input_chunks: list[DocAwareChunk] = [only_chunk]

    # Run the code under test.
    indexed_chunks: list[IndexChunk] = indexing_embedder.embed_chunks(input_chunks)

    # Exactly one output chunk, with the content and embeddings we expect.
    assert len(indexed_chunks) == 1
    indexed_chunk = indexed_chunks[0]
    assert isinstance(indexed_chunk, IndexChunk)
    assert indexed_chunk.content == "Test chunk"
    expected_embeddings = ChunkEmbedding(
        full_embedding=[1.0, 2.0, 3.0],
        mini_chunk_embeddings=[],
    )
    assert indexed_chunk.embeddings == expected_embeddings
    assert indexed_chunk.title_embedding == [7.0, 8.0, 9.0]

    # Chunk-text call: title prefix + content, passed by keyword.
    encode_mock.assert_any_call(
        texts=["Title: Test chunk"],
        text_type=EmbedTextType.PASSAGE,
        large_chunks_present=False,
    )
    # Title-only call: the document title is passed positionally.
    encode_mock.assert_any_call(
        ["Test Document"],
        text_type=EmbedTextType.PASSAGE,
    )