Fix document counts (#3671)

* Various fixes/improvements to document counting

* Add new column + index

* Avoid double scan

* comment fixes

* Fix revision history

* Fix IT

* Fix IT

* Fix migration

* Rebase
Chris Weaver 2025-01-18 21:36:07 -08:00 committed by GitHub
parent b25668c83a
commit 342bb9f685
6 changed files with 184 additions and 49 deletions

View File

@ -0,0 +1,48 @@
"""Add has_been_indexed to DocumentByConnectorCredentialPair
Revision ID: c7bf5721733e
Revises: 027381bce97c
Create Date: 2025-01-13 12:39:05.831693
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "c7bf5721733e"
down_revision = "027381bce97c"
branch_labels = None
depends_on = None
def upgrade() -> None:
# assume all existing rows have been indexed, no better approach
op.add_column(
"document_by_connector_credential_pair",
sa.Column("has_been_indexed", sa.Boolean(), nullable=True),
)
op.execute(
"UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE"
)
op.alter_column(
"document_by_connector_credential_pair",
"has_been_indexed",
nullable=False,
)
# Add index to optimize get_document_counts_for_cc_pairs query pattern
op.create_index(
"idx_document_cc_pair_counts",
"document_by_connector_credential_pair",
["connector_id", "credential_id", "has_been_indexed"],
unique=False,
)
def downgrade() -> None:
# Remove the index first before removing the column
op.drop_index(
"idx_document_cc_pair_counts",
table_name="document_by_connector_credential_pair",
)
op.drop_column("document_by_connector_credential_pair", "has_been_indexed")

View File

@ -101,10 +101,17 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
for doc in doc_batch:
cleaned_doc = doc.model_copy()
# Postgres cannot handle NUL characters in text fields
if "\x00" in cleaned_doc.id:
logger.warning(f"NUL characters found in document ID: {cleaned_doc.id}")
cleaned_doc.id = cleaned_doc.id.replace("\x00", "")
if cleaned_doc.title and "\x00" in cleaned_doc.title:
logger.warning(
f"NUL characters found in document title: {cleaned_doc.title}"
)
cleaned_doc.title = cleaned_doc.title.replace("\x00", "")
if "\x00" in cleaned_doc.semantic_identifier:
logger.warning(
f"NUL characters found in document semantic identifier: {cleaned_doc.semantic_identifier}"
@ -120,6 +127,9 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
return cleaned_batch
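
The "avoid double scan" note above refers to skipping the membership check used for the id/title/semantic-identifier fields; a rough illustration of the two patterns (illustrative only, not part of the diff):

# checked pattern: `"\x00" in text` walks the string, and replace() walks it
# again whenever a NUL byte is actually found
if "\x00" in section.text:
    section.text = section.text.replace("\x00", "")

# unconditional pattern used for section text above: replace() leaves the
# string untouched when no NUL bytes are present, so it is scanned only once
section.text = section.text.replace("\x00", "")

For id, title, and semantic identifier the check is kept so a warning can be logged; for potentially long section text the warning is dropped in favor of a single pass.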
@ -277,8 +287,6 @@ def _run_indexing(
tenant_id=tenant_id,
)
all_connector_doc_ids: set[str] = set()
tracer_counter = 0
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
@ -347,16 +355,15 @@ def _run_indexing(
index_attempt_md.batch_num = batch_num + 1 # use 1-index for this
# real work happens here!
new_docs, total_batch_chunks = indexing_pipeline(
index_pipeline_result = indexing_pipeline(
document_batch=doc_batch_cleaned,
index_attempt_metadata=index_attempt_md,
)
batch_num += 1
net_doc_change += new_docs
chunk_count += total_batch_chunks
document_count += len(doc_batch_cleaned)
all_connector_doc_ids.update(doc.id for doc in doc_batch_cleaned)
net_doc_change += index_pipeline_result.new_docs
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
@ -365,9 +372,6 @@ def _run_indexing(
# be inaccurate
db_session.commit()
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
update_docs_indexed(
@ -378,6 +382,9 @@ def _run_indexing(
docs_removed_from_index=0,
)
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0

View File

@ -1,6 +1,7 @@
import contextlib
import time
from collections.abc import Generator
from collections.abc import Iterable
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
@ -13,6 +14,7 @@ from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import tuple_
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.exc import OperationalError
@ -226,10 +228,13 @@ def get_document_counts_for_cc_pairs(
func.count(),
)
.where(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids)
and_(
tuple_(
DocumentByConnectorCredentialPair.connector_id,
DocumentByConnectorCredentialPair.credential_id,
).in_(cc_ids),
DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
)
)
.group_by(
DocumentByConnectorCredentialPair.connector_id,
@ -382,18 +387,40 @@ def upsert_document_by_connector_credential_pair(
id=doc_id,
connector_id=connector_id,
credential_id=credential_id,
has_been_indexed=False,
)
)
for doc_id in document_ids
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
# this must be `on_conflict_do_nothing` rather than `on_conflict_do_update`
# since we don't want to update the `has_been_indexed` field for documents
# that already exist
on_conflict_stmt = insert_stmt.on_conflict_do_nothing()
db_session.execute(on_conflict_stmt)
db_session.commit()
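
For contrast with the comment above, the rejected `on_conflict_do_update` variant would look roughly like this (a sketch; the conflict target is assumed to be the table's composite primary key) and would reset `has_been_indexed` on every re-upsert of an existing row:

# hypothetical alternative, NOT what the diff does: this would flip
# has_been_indexed back to False for rows that were already indexed
on_conflict_stmt = insert_stmt.on_conflict_do_update(
    index_elements=["id", "connector_id", "credential_id"],
    set_={"has_been_indexed": False},
)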
def mark_document_as_indexed_for_cc_pair__no_commit(
db_session: Session,
connector_id: int,
credential_id: int,
document_ids: Iterable[str],
) -> None:
"""Should be called only after a successful index operation for a batch."""
db_session.execute(
update(DocumentByConnectorCredentialPair)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
DocumentByConnectorCredentialPair.id.in_(document_ids),
)
)
.values(has_been_indexed=True)
)
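
A minimal usage sketch of the new helper (the `__no_commit` suffix means the caller owns the transaction; `successfully_indexed_docs` is a placeholder name, not from the diff):

mark_document_as_indexed_for_cc_pair__no_commit(
    db_session=db_session,
    connector_id=connector_id,
    credential_id=credential_id,
    document_ids=[doc.id for doc in successfully_indexed_docs],
)
db_session.commit()  # caller decides when the batch's changes become visible

The indexing pipeline changes below call it exactly this way after each successfully indexed batch, and once for batches that were skipped entirely, before committing.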
def update_docs_updated_at__no_commit(
ids_to_new_updated_at: dict[str, datetime],
db_session: Session,

View File

@ -941,6 +941,12 @@ class DocumentByConnectorCredentialPair(Base):
ForeignKey("credential.id"), primary_key=True
)
# used to better keep track of document counts at a connector level
# e.g. if a document is added as part of permission syncing, it should
# not be counted as part of the connector's document count until
# the actual indexing is complete
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector"
)
@ -955,6 +961,14 @@ class DocumentByConnectorCredentialPair(Base):
"credential_id",
unique=False,
),
# Index to optimize get_document_counts_for_cc_pairs query pattern
Index(
"idx_document_cc_pair_counts",
"connector_id",
"credential_id",
"has_been_indexed",
unique=False,
),
)
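
To make the column's intent concrete, a single cc-pair's indexed-document count could be read back with an ORM query along these lines (a sketch; `db_session`, `connector_id`, and `credential_id` are assumptions, and the real multi-pair version is get_document_counts_for_cc_pairs above):

from sqlalchemy import func, select

# count only documents whose indexing actually completed for this cc-pair;
# the idx_document_cc_pair_counts index defined above covers this lookup
indexed_count = db_session.scalar(
    select(func.count())
    .select_from(DocumentByConnectorCredentialPair)
    .where(
        DocumentByConnectorCredentialPair.connector_id == connector_id,
        DocumentByConnectorCredentialPair.credential_id == credential_id,
        DocumentByConnectorCredentialPair.has_been_indexed.is_(True),
    )
)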

View File

@ -21,6 +21,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.document import prepare_to_modify_documents
from onyx.db.document import update_docs_chunk_count__no_commit
from onyx.db.document import update_docs_last_modified__no_commit
@ -55,12 +56,23 @@ class DocumentBatchPrepareContext(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
class IndexingPipelineResult(BaseModel):
# number of documents that are completely new (e.g. did
# not exist as a part of this OR any other connector)
new_docs: int
# NOTE: need total_docs, since the pipeline can skip some docs
# (e.g. not even insert them into Postgres)
total_docs: int
# number of chunks that were inserted into Vespa
total_chunks: int
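
As a concrete, purely hypothetical reading of these fields: a filtered batch of 10 documents where 3 were skipped as already up to date, 2 had never been seen by any connector, and the 7 updatable documents produced 45 chunks would come back as:

# hypothetical numbers, for illustration only
IndexingPipelineResult(new_docs=2, total_docs=10, total_chunks=45)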
class IndexingPipelineProtocol(Protocol):
def __call__(
self,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
...
@ -147,10 +159,12 @@ def index_doc_batch_with_handler(
db_session: Session,
ignore_time_skip: bool = False,
tenant_id: str | None = None,
) -> tuple[int, int]:
r = (0, 0)
) -> IndexingPipelineResult:
index_pipeline_result = IndexingPipelineResult(
new_docs=0, total_docs=len(document_batch), total_chunks=0
)
try:
r = index_doc_batch(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
document_index=document_index,
@ -203,7 +217,7 @@ def index_doc_batch_with_handler(
else:
pass
return r
return index_pipeline_result
def index_doc_batch_prepare(
@ -227,6 +241,15 @@ def index_doc_batch_prepare(
if not ignore_time_skip
else documents
)
if len(updatable_docs) != len(documents):
updatable_doc_ids = [doc.id for doc in updatable_docs]
skipped_doc_ids = [
doc.id for doc in documents if doc.id not in updatable_doc_ids
]
logger.info(
f"Skipping {len(skipped_doc_ids)} documents "
f"because they are up to date. Skipped doc IDs: {skipped_doc_ids}"
)
# for all updatable docs, upsert into the DB
# Does not include doc_updated_at which is also used to indicate a successful update
@ -263,21 +286,6 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
# Remove any NUL characters from title/semantic_id
# This is a known issue with the Zendesk connector
# Postgres cannot handle NUL characters in text fields
if document.title:
document.title = document.title.replace("\x00", "")
if document.semantic_identifier:
document.semantic_identifier = document.semantic_identifier.replace(
"\x00", ""
)
# Remove NUL characters from all sections
for section in document.sections:
if section.text is not None:
section.text = section.text.replace("\x00", "")
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
@ -333,7 +341,7 @@ def index_doc_batch(
ignore_time_skip: bool = False,
tenant_id: str | None = None,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> tuple[int, int]:
) -> IndexingPipelineResult:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements
@ -359,7 +367,18 @@ def index_doc_batch(
db_session=db_session,
)
if not ctx:
return 0, 0
# even though we didn't actually index anything, we should still
# mark them as "completed" for the CC Pair in order to make the
# counts match
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
return IndexingPipelineResult(
new_docs=0, total_docs=len(filtered_documents), total_chunks=0
)
logger.debug("Starting chunking")
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
@ -425,7 +444,8 @@ def index_doc_batch(
]
logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
"Indexing the following chunks: "
f"{[chunk.to_short_descriptor() for chunk in access_aware_chunks]}"
)
# A document will not be spread across different batches, so all the
# documents with chunks in this set, are fully represented by the chunks
@ -440,14 +460,17 @@ def index_doc_batch(
),
)
successful_doc_ids = [record.document_id for record in insertion_records]
successful_docs = [
doc for doc in ctx.updatable_docs if doc.id in successful_doc_ids
]
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Successful IDs: {successful_doc_ids}"
)
last_modified_ids = []
ids_to_new_updated_at = {}
for doc in successful_docs:
for doc in ctx.updatable_docs:
last_modified_ids.append(doc.id)
# doc_updated_at is the source's idea (on the other end of the connector)
# of when the doc was last modified
@ -469,11 +492,24 @@ def index_doc_batch(
db_session=db_session,
)
# these documents can now be counted as part of the CC Pairs
# document count, so we need to mark them as indexed
# NOTE: even documents we skipped since they were already up
# to date should be counted here in order to maintain parity
# between CC Pair and index attempt counts
mark_document_as_indexed_for_cc_pair__no_commit(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_ids=[doc.id for doc in filtered_documents],
db_session=db_session,
)
db_session.commit()
result = (
len([r for r in insertion_records if r.already_existed is False]),
len(access_aware_chunks),
result = IndexingPipelineResult(
new_docs=len([r for r in insertion_records if r.already_existed is False]),
total_docs=len(filtered_documents),
total_chunks=len(access_aware_chunks),
)
return result

View File

@ -108,7 +108,7 @@ def upsert_ingestion_doc(
tenant_id=tenant_id,
)
new_doc, __chunk_count = indexing_pipeline(
indexing_pipeline_result = indexing_pipeline(
document_batch=[document],
index_attempt_metadata=IndexAttemptMetadata(
connector_id=cc_pair.connector_id,
@ -150,4 +150,7 @@ def upsert_ingestion_doc(
),
)
return IngestionResult(document_id=document.id, already_existed=not bool(new_doc))
return IngestionResult(
document_id=document.id,
already_existed=indexing_pipeline_result.new_docs > 0,
)