Merge pull request #2631 from danswer-ai/hotfix/v0.6-heartbeat

Hotfix/v0.6 heartbeat

Adds an indexing heartbeat: long-running chunking and embedding work now periodically bumps index_attempt.time_updated, so in-progress index attempts are not mistaken for stalled ones.
rkuo-danswer authored 2024-09-30 12:25:48 -07:00; committed by GitHub
commit 3cafedcf22
10 changed files with 324 additions and 54 deletions

View File

@@ -29,6 +29,7 @@ from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.document_index.factory import get_default_document_index
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import IndexAttemptSingleton
 from danswer.utils.logger import setup_logger
@@ -103,15 +104,24 @@ def _run_indexing(
     )

     embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
-        search_settings=search_settings
+        search_settings=search_settings,
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=index_attempt.id,
+            db_session=db_session,
+            # let the world know we're still making progress after
+            # every 10 batches
+            freq=10,
+        ),
     )

     indexing_pipeline = build_indexing_pipeline(
         attempt_id=index_attempt.id,
         embedder=embedding_model,
         document_index=document_index,
-        ignore_time_skip=index_attempt.from_beginning
-        or (search_settings.status == IndexModelStatus.FUTURE),
+        ignore_time_skip=(
+            index_attempt.from_beginning
+            or (search_settings.status == IndexModelStatus.FUTURE)
+        ),
         db_session=db_session,
     )
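
For orientation (an editorial sketch, not part of the commit): the heartbeat constructed here is handed to the embedder and, as later hunks show, threaded down into EmbeddingModel, which calls heartbeat() once per embedding batch. With freq=10, only every tenth call touches the database. A minimal sketch of that cadence, where attempt, session, batches, and embed are placeholder names:

    heartbeat = IndexingHeartbeat(index_attempt_id=attempt.id, db_session=session, freq=10)
    for batch in batches:
        embed(batch)           # one model-server request per batch
        heartbeat.heartbeat()  # commits index_attempt.time_updated on calls 10, 20, ...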

View File

@@ -10,6 +10,7 @@ from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
     get_metadata_keys_to_ignore,
 )
 from danswer.connectors.models import Document
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.natural_language_processing.utils import BaseTokenizer
 from danswer.utils.logger import setup_logger
@@ -123,6 +124,7 @@ class Chunker:
         chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
         chunk_overlap: int = CHUNK_OVERLAP,
         mini_chunk_size: int = MINI_CHUNK_SIZE,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         from llama_index.text_splitter import SentenceSplitter
@@ -131,6 +133,7 @@ class Chunker:
         self.enable_multipass = enable_multipass
         self.enable_large_chunks = enable_large_chunks
         self.tokenizer = tokenizer
+        self.heartbeat = heartbeat

         self.blurb_splitter = SentenceSplitter(
             tokenizer=tokenizer.tokenize,
@@ -255,7 +258,7 @@ class Chunker:
             # If the chunk does not have any useable content, it will not be indexed
             return chunks

-    def chunk(self, document: Document) -> list[DocAwareChunk]:
+    def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
         # Specifically for reproducing an issue with gmail
         if document.source == DocumentSource.GMAIL:
             logger.debug(f"Chunking {document.semantic_identifier}")
@@ -302,3 +305,13 @@ class Chunker:
             normal_chunks.extend(large_chunks)

         return normal_chunks
+
+    def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
+        final_chunks: list[DocAwareChunk] = []
+        for document in documents:
+            final_chunks.extend(self._handle_single_document(document))
+
+            if self.heartbeat:
+                self.heartbeat.heartbeat()
+
+        return final_chunks
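
Note the call-signature change here: chunk() now takes the full list of documents, and the old per-document entry point survives as _handle_single_document, with the heartbeat firing after each document. Any caller that previously passed a single Document must wrap it in a list, as the updated test later in this diff does:

    chunks = chunker.chunk([document])  # previously: chunker.chunk(document=document)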

View File

@@ -1,12 +1,8 @@
 from abc import ABC
 from abc import abstractmethod
-from sqlalchemy.orm import Session
-from danswer.db.models import IndexModelStatus
 from danswer.db.models import SearchSettings
-from danswer.db.search_settings import get_current_search_settings
-from danswer.db.search_settings import get_secondary_search_settings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
@@ -24,6 +20,9 @@ logger = setup_logger()
 class IndexingEmbedder(ABC):
+    """Converts chunks into chunks with embeddings. Note that one chunk may have
+    multiple embeddings associated with it."""
+
     def __init__(
         self,
         model_name: str,
@@ -33,6 +32,7 @@ class IndexingEmbedder(ABC):
         provider_type: EmbeddingProvider | None,
         api_key: str | None,
         api_url: str | None,
+        heartbeat: Heartbeat | None,
     ):
         self.model_name = model_name
         self.normalize = normalize
@@ -54,6 +54,7 @@ class IndexingEmbedder(ABC):
             server_host=INDEXING_MODEL_SERVER_HOST,
             server_port=INDEXING_MODEL_SERVER_PORT,
             retrim_content=True,
+            heartbeat=heartbeat,
         )

     @abstractmethod
@@ -74,6 +75,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         provider_type: EmbeddingProvider | None = None,
         api_key: str | None = None,
         api_url: str | None = None,
+        heartbeat: Heartbeat | None = None,
     ):
         super().__init__(
             model_name,
@@ -83,6 +85,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             provider_type,
             api_key,
             api_url,
+            heartbeat,
         )

     @log_function_time()
@@ -166,7 +169,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
                 title_embed_dict[title] = title_embedding

             new_embedded_chunk = IndexChunk(
-                **chunk.dict(),
+                **chunk.model_dump(),
                 embeddings=ChunkEmbedding(
                     full_embedding=chunk_embeddings[0],
                     mini_chunk_embeddings=chunk_embeddings[1:],
@@ -180,7 +183,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
     @classmethod
     def from_db_search_settings(
-        cls, search_settings: SearchSettings
+        cls, search_settings: SearchSettings, heartbeat: Heartbeat | None = None
     ) -> "DefaultIndexingEmbedder":
         return cls(
             model_name=search_settings.model_name,
@@ -190,28 +193,5 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             provider_type=search_settings.provider_type,
             api_key=search_settings.api_key,
             api_url=search_settings.api_url,
+            heartbeat=heartbeat,
         )
-
-
-def get_embedding_model_from_search_settings(
-    db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
-) -> IndexingEmbedder:
-    search_settings: SearchSettings | None
-    if index_model_status == IndexModelStatus.PRESENT:
-        search_settings = get_current_search_settings(db_session)
-    elif index_model_status == IndexModelStatus.FUTURE:
-        search_settings = get_secondary_search_settings(db_session)
-        if not search_settings:
-            raise RuntimeError("No secondary index configured")
-    else:
-        raise RuntimeError("Not supporting embedding model rollbacks")
-
-    return DefaultIndexingEmbedder(
-        model_name=search_settings.model_name,
-        normalize=search_settings.normalize,
-        query_prefix=search_settings.query_prefix,
-        passage_prefix=search_settings.passage_prefix,
-        provider_type=search_settings.provider_type,
-        api_key=search_settings.api_key,
-        api_url=search_settings.api_url,
-    )
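
With get_embedding_model_from_search_settings deleted, from_db_search_settings is the constructor path that remains. A hypothetical migration for a former caller, where the settings lookup shown is the same one the deleted helper used:

    from danswer.db.search_settings import get_current_search_settings

    search_settings = get_current_search_settings(db_session)
    embedder = DefaultIndexingEmbedder.from_db_search_settings(
        search_settings=search_settings,
        heartbeat=None,  # or an IndexingHeartbeat when running inside an index attempt
    )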

View File

@@ -0,0 +1,41 @@
import abc
from typing import Any

from sqlalchemy import func
from sqlalchemy.orm import Session

from danswer.db.index_attempt import get_index_attempt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class Heartbeat(abc.ABC):
    """Useful for any long-running work that goes through a bunch of items
    and needs to occasionally give updates on progress.
    e.g. chunking, embedding, updating vespa, etc."""

    @abc.abstractmethod
    def heartbeat(self, metadata: Any = None) -> None:
        raise NotImplementedError


class IndexingHeartbeat(Heartbeat):
    def __init__(self, index_attempt_id: int, db_session: Session, freq: int):
        self.cnt = 0

        self.index_attempt_id = index_attempt_id
        self.db_session = db_session
        self.freq = freq

    def heartbeat(self, metadata: Any = None) -> None:
        self.cnt += 1
        if self.cnt % self.freq == 0:
            index_attempt = get_index_attempt(
                db_session=self.db_session, index_attempt_id=self.index_attempt_id
            )
            if index_attempt:
                index_attempt.time_updated = func.now()
                self.db_session.commit()
            else:
                logger.error("Index attempt not found, this should not happen!")
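
Heartbeat is deliberately a minimal ABC, so other long-running work can plug in its own reporter (the unit tests below do exactly that with a counting mock). As a sketch of the extension point, a purely hypothetical subclass that logs instead of touching the database, reusing this module's logger and Any import:

    class LoggingHeartbeat(Heartbeat):
        """Hypothetical example: report liveness via logs every freq items."""

        def __init__(self, label: str, freq: int = 100) -> None:
            self.label = label
            self.freq = freq
            self.cnt = 0

        def heartbeat(self, metadata: Any = None) -> None:
            self.cnt += 1
            if self.cnt % self.freq == 0:
                logger.info(f"{self.label}: still alive after {self.cnt} items")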

View File

@@ -31,6 +31,7 @@ from danswer.document_index.interfaces import DocumentIndex
 from danswer.document_index.interfaces import DocumentMetadata
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import IndexingEmbedder
+from danswer.indexing.indexing_heartbeat import IndexingHeartbeat
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import DocMetadataAwareIndexChunk
 from danswer.utils.logger import setup_logger
@@ -283,18 +284,10 @@ def index_doc_batch(
         return 0, 0

     logger.debug("Starting chunking")
-    chunks: list[DocAwareChunk] = []
-    for document in ctx.updatable_docs:
-        chunks.extend(chunker.chunk(document=document))
+    chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)

     logger.debug("Starting embedding")
-    chunks_with_embeddings = (
-        embedder.embed_chunks(
-            chunks=chunks,
-        )
-        if chunks
-        else []
-    )
+    chunks_with_embeddings = embedder.embed_chunks(chunks) if chunks else []

     updatable_ids = [doc.id for doc in ctx.updatable_docs]
@@ -406,6 +399,13 @@ def build_indexing_pipeline(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=multipass,
         enable_large_chunks=enable_large_chunks,
+        # after every doc, update status in case there are a bunch of
+        # really long docs
+        heartbeat=IndexingHeartbeat(
+            index_attempt_id=attempt_id, db_session=db_session, freq=1
+        )
+        if attempt_id
+        else None,
     )

     return partial(
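
Two different cadences are chosen in this PR, which is easy to miss: the chunker's heartbeat uses freq=1 because, as the comment above says, a single document can take a long time to chunk, while _run_indexing wires the embedder's heartbeat with freq=10 since embedding batches complete far more often. And when attempt_id is falsy, i.e. a pipeline built outside any index attempt, heartbeat is simply None and the Chunker skips the calls.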

View File

@@ -16,6 +16,7 @@ from danswer.configs.model_configs import (
 )
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.models import SearchSettings
+from danswer.indexing.indexing_heartbeat import Heartbeat
 from danswer.natural_language_processing.utils import get_tokenizer
 from danswer.natural_language_processing.utils import tokenizer_trim_content
 from danswer.utils.logger import setup_logger
@@ -95,6 +96,7 @@ class EmbeddingModel:
         api_url: str | None,
         provider_type: EmbeddingProvider | None,
         retrim_content: bool = False,
+        heartbeat: Heartbeat | None = None,
     ) -> None:
         self.api_key = api_key
         self.provider_type = provider_type
@@ -107,6 +109,7 @@ class EmbeddingModel:
         self.tokenizer = get_tokenizer(
             model_name=model_name, provider_type=provider_type
         )
+        self.heartbeat = heartbeat

         model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
@@ -166,6 +169,9 @@ class EmbeddingModel:
             response = self._make_model_server_request(embed_request)
             embeddings.extend(response.embeddings)

+            if self.heartbeat:
+                self.heartbeat.heartbeat()
+
         return embeddings

     def encode(
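
This hook sits inside EmbeddingModel's batching loop, so each heartbeat() corresponds to one model-server request. Combined with the freq=10 wiring in _run_indexing, that works out to one time_updated commit per ten embedding requests, keeping the extra database write traffic negligible.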

View File

@@ -0,0 +1,18 @@
from typing import Any

import pytest

from danswer.indexing.indexing_heartbeat import Heartbeat


class MockHeartbeat(Heartbeat):
    def __init__(self) -> None:
        self.call_count = 0

    def heartbeat(self, metadata: Any = None) -> None:
        self.call_count += 1


@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
    return MockHeartbeat()
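
Since this lives in conftest.py, pytest injects mock_heartbeat by name into any test in the package. A minimal, purely illustrative consumer (test_counts_heartbeats is hypothetical, not in this commit):

    def test_counts_heartbeats(mock_heartbeat: MockHeartbeat) -> None:
        mock_heartbeat.heartbeat()
        mock_heartbeat.heartbeat()
        assert mock_heartbeat.call_count == 2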

View File

@@ -1,11 +1,24 @@
+import pytest
+
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import DefaultIndexingEmbedder
+from tests.unit.danswer.indexing.conftest import MockHeartbeat


-def test_chunk_document() -> None:
+@pytest.fixture
+def embedder() -> DefaultIndexingEmbedder:
+    return DefaultIndexingEmbedder(
+        model_name="intfloat/e5-base-v2",
+        normalize=True,
+        query_prefix=None,
+        passage_prefix=None,
+    )
+
+
+def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
     short_section_1 = "This is a short section."
     long_section = (
         "This is a long section that should be split into multiple chunks. " * 100
@@ -30,18 +43,11 @@ def test_chunk_document() -> None:
         ],
     )

-    embedder = DefaultIndexingEmbedder(
-        model_name="intfloat/e5-base-v2",
-        normalize=True,
-        query_prefix=None,
-        passage_prefix=None,
-    )
-
     chunker = Chunker(
         tokenizer=embedder.embedding_model.tokenizer,
         enable_multipass=False,
     )
-    chunks = chunker.chunk(document)
+    chunks = chunker.chunk([document])

     assert len(chunks) == 5
     assert short_section_1 in chunks[0].content
@@ -49,3 +55,29 @@ def test_chunk_document() -> None:
     assert short_section_4 in chunks[-1].content
     assert "tag1" in chunks[0].metadata_suffix_keyword
     assert "tag2" in chunks[0].metadata_suffix_semantic
+
+
+def test_chunker_heartbeat(
+    embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
+) -> None:
+    document = Document(
+        id="test_doc",
+        source=DocumentSource.WEB,
+        semantic_identifier="Test Document",
+        metadata={"tags": ["tag1", "tag2"]},
+        doc_updated_at=None,
+        sections=[
+            Section(text="This is a short section.", link="link1"),
+        ],
+    )
+
+    chunker = Chunker(
+        tokenizer=embedder.embedding_model.tokenizer,
+        enable_multipass=False,
+        heartbeat=mock_heartbeat,
+    )
+
+    chunks = chunker.chunk([document])
+
+    assert mock_heartbeat.call_count == 1
+    assert len(chunks) > 0

View File

@@ -0,0 +1,90 @@
from collections.abc import Generator
from unittest.mock import Mock
from unittest.mock import patch

import pytest

from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.indexing.embedder import DefaultIndexingEmbedder
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType


@pytest.fixture
def mock_embedding_model() -> Generator[Mock, None, None]:
    with patch("danswer.indexing.embedder.EmbeddingModel") as mock:
        yield mock


def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
    # Setup
    embedder = DefaultIndexingEmbedder(
        model_name="test-model",
        normalize=True,
        query_prefix=None,
        passage_prefix=None,
        provider_type=EmbeddingProvider.OPENAI,
    )

    # Mock the encode method of the embedding model
    mock_embedding_model.return_value.encode.side_effect = [
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],  # Main chunk embeddings
        [[7.0, 8.0, 9.0]],  # Title embedding
    ]

    # Create test input
    source_doc = Document(
        id="test_doc",
        source=DocumentSource.WEB,
        semantic_identifier="Test Document",
        metadata={"tags": ["tag1", "tag2"]},
        doc_updated_at=None,
        sections=[
            Section(text="This is a short section.", link="link1"),
        ],
    )
    chunks: list[DocAwareChunk] = [
        DocAwareChunk(
            chunk_id=1,
            blurb="This is a short section.",
            content="Test chunk",
            source_links={0: "link1"},
            section_continuation=False,
            source_document=source_doc,
            title_prefix="Title: ",
            metadata_suffix_semantic="",
            metadata_suffix_keyword="",
            mini_chunk_texts=None,
            large_chunk_reference_ids=[],
        )
    ]

    # Execute
    result: list[IndexChunk] = embedder.embed_chunks(chunks)

    # Assert
    assert len(result) == 1
    assert isinstance(result[0], IndexChunk)
    assert result[0].content == "Test chunk"
    assert result[0].embeddings == ChunkEmbedding(
        full_embedding=[1.0, 2.0, 3.0],
        mini_chunk_embeddings=[],
    )
    assert result[0].title_embedding == [7.0, 8.0, 9.0]

    # Verify the embedding model was called correctly
    mock_embedding_model.return_value.encode.assert_any_call(
        texts=["Title: Test chunk"],
        text_type=EmbedTextType.PASSAGE,
        large_chunks_present=False,
    )
    # title only embedding call
    mock_embedding_model.return_value.encode.assert_any_call(
        ["Test Document"],
        text_type=EmbedTextType.PASSAGE,
    )
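
A subtlety worth noting: side_effect hands out return values in call order, so this test depends on embed_chunks encoding the chunk texts before the title, as the embedder diff above shows it does: the first stubbed list becomes the chunk embeddings and the second becomes the title embedding.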

View File

@@ -0,0 +1,80 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from sqlalchemy.orm import Session

from danswer.db.index_attempt import IndexAttempt
from danswer.indexing.indexing_heartbeat import IndexingHeartbeat


@pytest.fixture
def mock_db_session() -> MagicMock:
    return MagicMock(spec=Session)


@pytest.fixture
def mock_index_attempt() -> MagicMock:
    return MagicMock(spec=IndexAttempt)


def test_indexing_heartbeat(
    mock_db_session: MagicMock, mock_index_attempt: MagicMock
) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt:
        mock_get_index_attempt.return_value = mock_index_attempt

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=5
        )

        # Test that heartbeat doesn't update before freq is reached
        for _ in range(4):
            heartbeat.heartbeat()
        mock_db_session.commit.assert_not_called()

        # Test that heartbeat updates when freq is reached
        heartbeat.heartbeat()
        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        assert mock_index_attempt.time_updated is not None
        mock_db_session.commit.assert_called_once()

        # Reset mock calls
        mock_db_session.reset_mock()
        mock_get_index_attempt.reset_mock()

        # Test that heartbeat updates again after freq more calls
        for _ in range(5):
            heartbeat.heartbeat()
        mock_get_index_attempt.assert_called_once()
        mock_db_session.commit.assert_called_once()


def test_indexing_heartbeat_not_found(mock_db_session: MagicMock) -> None:
    with patch(
        "danswer.indexing.indexing_heartbeat.get_index_attempt"
    ) as mock_get_index_attempt, patch(
        "danswer.indexing.indexing_heartbeat.logger"
    ) as mock_logger:
        mock_get_index_attempt.return_value = None

        heartbeat = IndexingHeartbeat(
            index_attempt_id=1, db_session=mock_db_session, freq=1
        )

        heartbeat.heartbeat()

        mock_get_index_attempt.assert_called_once_with(
            db_session=mock_db_session, index_attempt_id=1
        )
        mock_logger.error.assert_called_once_with(
            "Index attempt not found, this should not happen!"
        )
        mock_db_session.commit.assert_not_called()