Kg batch clustering (#4847)

* super genius kg_entity parent migration

* feat: batched clustering

* fix: nit
Rei Meguro
2025-06-08 17:16:10 -04:00
committed by GitHub
parent c5adbe4180
commit 2b812b7d7d
5 changed files with 372 additions and 226 deletions

View File

@@ -0,0 +1,29 @@
"""kgentity_parent
Revision ID: cec7ec36c505
Revises: 495cb26ce93e
Create Date: 2025-06-07 20:07:46.400770
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "cec7ec36c505"
down_revision = "495cb26ce93e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"kg_entity",
sa.Column("parent_key", sa.String(), nullable=True, index=True),
)
# NOTE: you will have to reindex the KG after this migration as the parent_key will be null
def downgrade() -> None:
op.drop_column("kg_entity", "parent_key")
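Every pre-existing kg_entity row comes out of this migration with a NULL parent_key, hence the reindex note above. A minimal sanity check before kicking off the reindex, as a hedged sketch (the DSN is a placeholder, not part of this PR):

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/onyx")  # hypothetical DSN

with engine.connect() as conn:
    # rows still lacking a parent_key are only filled in by a KG reindex
    missing = conn.execute(
        text("SELECT COUNT(*) FROM kg_entity WHERE parent_key IS NULL")
    ).scalar_one()
    print(f"{missing} kg_entity rows have no parent_key yet")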

View File

@@ -130,6 +130,7 @@ def transfer_entity(
         entity_class=entity.entity_class,
         entity_subtype=entity.entity_subtype,
         entity_key=entity.entity_key,
+        parent_key=entity.parent_key,
         alternative_names=entity.alternative_names or [],
         entity_type_id_name=entity.entity_type_id_name,
         document_id=entity.document_id,

View File

@@ -849,6 +849,9 @@ class KGEntity(Base):
     entity_key: Mapped[str] = mapped_column(
         NullFilteredString, nullable=True, index=True
     )
+    parent_key: Mapped[str | None] = mapped_column(
+        NullFilteredString, nullable=True, index=True
+    )
     entity_subtype: Mapped[str] = mapped_column(
         NullFilteredString, nullable=True, index=True
     )
@@ -1003,7 +1006,7 @@ class KGEntityExtractionStaging(Base):
     )

     # Basic entity information
-    parent_key: Mapped[str] = mapped_column(
+    parent_key: Mapped[str | None] = mapped_column(
         NullFilteredString, nullable=True, index=True
     )

View File

@@ -27,7 +27,7 @@ logger = setup_logger()
 def upsert_staging_relationship(
     db_session: Session,
     relationship_id_name: str,
-    source_document_id: str,
+    source_document_id: str | None,
     occurrences: int = 1,
 ) -> KGRelationshipExtractionStaging:
     """
@@ -99,6 +99,72 @@ def upsert_staging_relationship(
     return result


+def upsert_relationship(
+    db_session: Session,
+    relationship_id_name: str,
+    source_document_id: str | None,
+    occurrences: int = 1,
+) -> KGRelationship:
+    """
+    Upsert a new relationship directly to the database.
+
+    Args:
+        db_session: SQLAlchemy database session
+        relationship_id_name: The ID name of the relationship in format
+            "source__relationship__target"
+        source_document_id: ID of the source document
+        occurrences: Number of times this relationship has been found
+
+    Returns:
+        The created or updated KGRelationship object
+
+    Raises:
+        sqlalchemy.exc.IntegrityError: If there's an error with the database operation
+    """
+    # Generate a unique ID for the relationship
+    relationship_id_name = format_relationship_id(relationship_id_name)
+    (
+        source_entity_id_name,
+        relationship_string,
+        target_entity_id_name,
+    ) = split_relationship_id(relationship_id_name)
+    source_entity_type = get_entity_type(source_entity_id_name)
+    target_entity_type = get_entity_type(target_entity_id_name)
+    relationship_type = extract_relationship_type_id(relationship_id_name)
+
+    # Insert the new relationship
+    stmt = (
+        postgresql.insert(KGRelationship)
+        .values(
+            {
+                "id_name": relationship_id_name,
+                "source_node": source_entity_id_name,
+                "target_node": target_entity_id_name,
+                "source_node_type": source_entity_type,
+                "target_node_type": target_entity_type,
+                "type": relationship_string.lower(),
+                "relationship_type_id_name": relationship_type,
+                "source_document": source_document_id,
+                "occurrences": occurrences,
+            }
+        )
+        .on_conflict_do_update(
+            index_elements=["id_name", "source_document"],
+            set_=dict(
+                occurrences=KGRelationship.occurrences + occurrences,
+            ),
+        )
+        .returning(KGRelationship)
+    )
+    new_relationship = db_session.execute(stmt).scalar()
+    if new_relationship is None:
+        raise RuntimeError(
+            f"Failed to upsert relationship with id_name: {relationship_id_name}"
+        )
+    db_session.flush()
+
+    return new_relationship
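For orientation, a hedged usage sketch of the new helper; the id_name and document id below are hypothetical, following the "source__relationship__target" format from the docstring:

with get_session_with_current_tenant() as db_session:
    rel = upsert_relationship(
        db_session=db_session,
        relationship_id_name="ACCOUNT:acme__owns__TICKET:42",  # hypothetical ids
        source_document_id="doc-123",  # hypothetical document id
    )
    db_session.commit()

A second call with the same id_name and source document inserts no duplicate row; the ON CONFLICT clause only increments occurrences.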
 def transfer_relationship(
     db_session: Session,
     relationship: KGRelationshipExtractionStaging,
@@ -108,12 +174,8 @@ def transfer_relationship(
     Transfer a relationship from the staging table to the normalized table.
     """
     # Translate the source and target nodes
-    source_node = entity_translations.get(
-        relationship.source_node, relationship.source_node
-    )
-    target_node = entity_translations.get(
-        relationship.target_node, relationship.target_node
-    )
+    source_node = entity_translations[relationship.source_node]
+    target_node = entity_translations[relationship.target_node]
     relationship_id_name = make_relationship_id(
         source_node, relationship.type, target_node
     )
@@ -218,6 +280,65 @@ def upsert_staging_relationship_type(
     return result


+def upsert_relationship_type(
+    db_session: Session,
+    source_entity_type: str,
+    relationship_type: str,
+    target_entity_type: str,
+    definition: bool = False,
+    extraction_count: int = 1,
+) -> KGRelationshipType:
+    """
+    Upsert a new relationship type directly to the database.
+
+    Args:
+        db_session: SQLAlchemy session
+        source_entity_type: Type of the source entity
+        relationship_type: Type of relationship
+        target_entity_type: Type of the target entity
+        definition: Whether this relationship type represents a definition (default False)
+        extraction_count: Number of times this relationship type has been found
+
+    Returns:
+        The created KGRelationshipType object
+    """
+    id_name = make_relationship_type_id(
+        source_entity_type, relationship_type, target_entity_type
+    )
+
+    # Create new relationship type
+    stmt = (
+        postgresql.insert(KGRelationshipType)
+        .values(
+            {
+                "id_name": id_name,
+                "name": relationship_type,
+                "source_entity_type_id_name": source_entity_type.upper(),
+                "target_entity_type_id_name": target_entity_type.upper(),
+                "definition": definition,
+                "occurrences": extraction_count,
+                "type": relationship_type,  # Using the relationship_type as the type
+                "active": True,  # Setting as active by default
+            }
+        )
+        .on_conflict_do_update(
+            index_elements=["id_name"],
+            set_=dict(
+                occurrences=KGRelationshipType.occurrences + extraction_count,
+            ),
+        )
+        .returning(KGRelationshipType)
+    )
+    new_relationship_type = db_session.execute(stmt).scalar()
+    if new_relationship_type is None:
+        raise RuntimeError(
+            f"Failed to upsert relationship type with id_name: {id_name}"
+        )
+    db_session.flush()
+
+    return new_relationship_type
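The same conflict pattern applies here, keyed on id_name alone, so re-extracting a known type triple simply accumulates its occurrence count. A hedged sketch with hypothetical entity types:

with get_session_with_current_tenant() as db_session:
    rel_type = upsert_relationship_type(
        db_session=db_session,
        source_entity_type="ACCOUNT",  # hypothetical entity types
        relationship_type="has_subcomponent",
        target_entity_type="TICKET",
        extraction_count=3,
    )
    db_session.commit()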
 def transfer_relationship_type(
     db_session: Session,
     relationship_type: KGRelationshipTypeExtractionStaging,
@@ -262,112 +383,6 @@ def transfer_relationship_type(
     return new_relationship_type


-def get_parent_child_relationships_and_types(
-    db_session: Session,
-    depth: int,
-) -> tuple[
-    list[KGRelationshipExtractionStaging], list[KGRelationshipTypeExtractionStaging]
-]:
-    """
-    Create parent-child relationships and relationship types from staging entities with
-    a parent key, if the parent exists in the normalized entities table. Will create
-    relationships up to depth levels. E.g., if depth is 2, a relationship will be created
-    between the entity and its parent, and the entity and its grandparents (if any).
-    A relationship will not be created if the parent does not exist.
-    """
-    relationship_types: dict[str, KGRelationshipTypeExtractionStaging] = {}
-    relationships: dict[tuple[str, str | None], KGRelationshipExtractionStaging] = {}
-
-    parented_entities = (
-        db_session.query(KGEntityExtractionStaging)
-        .filter(KGEntityExtractionStaging.parent_key.isnot(None))
-        .all()
-    )
-
-    # create has_subcomponent relationships and relationship types
-    for entity in parented_entities:
-        child = entity
-        if entity.transferred_id_name is None:
-            logger.warning(f"Entity {entity.id_name} has not yet been transferred")
-            continue
-
-        for i in range(depth):
-            if not child.parent_key:
-                break
-
-            # find the transferred parent entity
-            parent = (
-                db_session.query(KGEntity)
-                .filter(
-                    KGEntity.entity_class == child.entity_class,
-                    KGEntity.entity_key == child.parent_key,
-                )
-                .first()
-            )
-            if parent is None:
-                logger.warning(f"Parent entity not found for {entity.id_name}")
-                break
-
-            # create the relationship type
-            relationship_type = upsert_staging_relationship_type(
-                db_session=db_session,
-                source_entity_type=parent.entity_type_id_name,
-                relationship_type="has_subcomponent",
-                target_entity_type=entity.entity_type_id_name,
-                definition=False,
-                extraction_count=1,
-            )
-            relationship_types[relationship_type.id_name] = relationship_type
-
-            # create the relationship
-            # (don't add it to the table as we're using the transferred id, which breaks fk constraints)
-            relationship_id_name = make_relationship_id(
-                parent.id_name, "has_subcomponent", entity.transferred_id_name
-            )
-            if (parent.id_name, entity.document_id) not in relationships:
-                (
-                    source_entity_id_name,
-                    relationship_string,
-                    target_entity_id_name,
-                ) = split_relationship_id(relationship_id_name)
-                source_entity_type = get_entity_type(source_entity_id_name)
-                target_entity_type = get_entity_type(target_entity_id_name)
-                relationship_type_id_name = extract_relationship_type_id(
-                    relationship_id_name
-                )
-                relationships[(relationship_id_name, entity.document_id)] = (
-                    KGRelationshipExtractionStaging(
-                        id_name=relationship_id_name,
-                        source_node=source_entity_id_name,
-                        target_node=target_entity_id_name,
-                        source_node_type=source_entity_type,
-                        target_node_type=target_entity_type,
-                        type=relationship_string,
-                        relationship_type_id_name=relationship_type_id_name,
-                        source_document=entity.document_id,
-                        occurrences=1,
-                    )
-                )
-            else:
-                relationships[(parent.id_name, entity.document_id)].occurrences += 1
-
-            # set parent as the next child (unless we're at the max depth)
-            if i < depth - 1:
-                parent_staging = (
-                    db_session.query(KGEntityExtractionStaging)
-                    .filter(
-                        KGEntityExtractionStaging.transferred_id_name == parent.id_name
-                    )
-                    .first()
-                )
-                if parent_staging is None:
-                    break
-                child = parent_staging
-
-    return list(relationships.values()), list(relationship_types.values())
-
-
 def delete_relationships_by_id_names(
     db_session: Session, id_names: list[str], kg_stage: KGStage
 ) -> int:

View File

@@ -1,3 +1,4 @@
+from collections.abc import Generator
 from typing import cast

 from rapidfuzz.fuzz import ratio
@@ -16,14 +17,17 @@ from onyx.db.models import Document
 from onyx.db.models import KGEntityType
 from onyx.db.models import KGRelationshipExtractionStaging
 from onyx.db.models import KGRelationshipTypeExtractionStaging
-from onyx.db.relationships import get_parent_child_relationships_and_types
 from onyx.db.relationships import transfer_relationship
 from onyx.db.relationships import transfer_relationship_type
+from onyx.db.relationships import upsert_relationship
+from onyx.db.relationships import upsert_relationship_type
 from onyx.document_index.vespa.kg_interactions import (
     get_kg_vespa_info_update_requests_for_document,
 )
 from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info
 from onyx.kg.configuration import validate_kg_settings
 from onyx.kg.models import KGGroundingType
+from onyx.kg.utils.formatting_utils import make_relationship_id
 from onyx.utils.logger import setup_logger
 from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
@@ -31,6 +35,84 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
 logger = setup_logger()


+def _get_batch_untransferred_grounded_entities(
+    batch_size: int,
+) -> Generator[list[KGEntityExtractionStaging], None, None]:
+    while True:
+        with get_session_with_current_tenant() as db_session:
+            batch = (
+                db_session.query(KGEntityExtractionStaging)
+                .join(
+                    KGEntityType,
+                    KGEntityExtractionStaging.entity_type_id_name
+                    == KGEntityType.id_name,
+                )
+                .filter(
+                    KGEntityType.grounding == KGGroundingType.GROUNDED,
+                    KGEntityExtractionStaging.transferred_id_name.is_(None),
+                )
+                .limit(batch_size)
+                .all()
+            )
+        if not batch:
+            break
+        yield batch
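Note that this generator paginates without an offset: it re-runs the same filter until it comes back empty. A minimal usage sketch, assuming the caller marks each row as transferred before requesting the next batch (as _cluster_one_grounded_entity does):

for batch in _get_batch_untransferred_grounded_entities(batch_size=100):
    for staging_entity in batch:
        # sets transferred_id_name, so the row drops out of the next query;
        # skipping this step would make the generator yield forever
        _cluster_one_grounded_entity(staging_entity)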
+def _get_batch_untransferred_relationship_types(
+    batch_size: int,
+) -> Generator[list[KGRelationshipTypeExtractionStaging], None, None]:
+    while True:
+        with get_session_with_current_tenant() as db_session:
+            batch = (
+                db_session.query(KGRelationshipTypeExtractionStaging)
+                .filter(KGRelationshipTypeExtractionStaging.transferred.is_(False))
+                .limit(batch_size)
+                .all()
+            )
+        if not batch:
+            break
+        yield batch
+
+
+def _get_batch_untransferred_relationships(
+    batch_size: int,
+) -> Generator[list[KGRelationshipExtractionStaging], None, None]:
+    while True:
+        with get_session_with_current_tenant() as db_session:
+            batch = (
+                db_session.query(KGRelationshipExtractionStaging)
+                .filter(KGRelationshipExtractionStaging.transferred.is_(False))
+                .limit(batch_size)
+                .all()
+            )
+        if not batch:
+            break
+        yield batch
+
+
+def _get_batch_entities_with_parent(
+    batch_size: int,
+) -> Generator[list[KGEntityExtractionStaging], None, None]:
+    offset = 0
+    while True:
+        with get_session_with_current_tenant() as db_session:
+            batch = (
+                db_session.query(KGEntityExtractionStaging)
+                .filter(KGEntityExtractionStaging.parent_key.isnot(None))
+                .order_by(KGEntityExtractionStaging.id_name)
+                .offset(offset)
+                .limit(batch_size)
+                .all()
+            )
+        if not batch:
+            break
+        # we can't filter out ""s earlier as it will mess up the pagination
+        yield [entity for entity in batch if entity.parent_key != ""]
+        offset += batch_size
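The empty-string handling is the subtle part: parent_key encodes three states, and only one of them can be excluded in SQL without shifting the OFFSET window while rows are rewritten mid-iteration. An illustrative helper (hypothetical, mirroring the rules above):

def _needs_parent_link(parent_key: str | None) -> bool:
    # None -> never had a parent; excluded by the SQL filter
    # ""   -> ancestor chain exhausted; excluded later, in Python
    # else -> still has an ancestor to link on the next pass
    return parent_key is not None and parent_key != ""

assert _needs_parent_link("PROJ-1")  # hypothetical parent key
assert not _needs_parent_link("")
assert not _needs_parent_link(None)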
 def _cluster_one_grounded_entity(
     entity: KGEntityExtractionStaging,
 ) -> tuple[KGEntity, bool]:
@@ -105,14 +187,75 @@ def _cluster_one_grounded_entity(
     return transferred_entity, update_vespa


-def _transfer_batch_relationship(
+def _create_one_parent_child_relationship(entity: KGEntityExtractionStaging) -> None:
+    """
+    Creates a relationship between the entity and its parent, if it exists.
+    Then, updates the entity's parent to the next ancestor.
+    """
+    with get_session_with_current_tenant() as db_session:
+        # find the next ancestor
+        parent = (
+            db_session.query(KGEntity)
+            .filter(KGEntity.entity_key == entity.parent_key)
+            .first()
+        )
+
+        if parent is not None:
+            # create parent child relationship and relationship type
+            upsert_relationship_type(
+                db_session=db_session,
+                source_entity_type=parent.entity_type_id_name,
+                relationship_type="has_subcomponent",
+                target_entity_type=entity.entity_type_id_name,
+            )
+            relationship_id_name = make_relationship_id(
+                parent.id_name,
+                "has_subcomponent",
+                cast(str, entity.transferred_id_name),
+            )
+            upsert_relationship(
+                db_session=db_session,
+                relationship_id_name=relationship_id_name,
+                source_document_id=entity.document_id,
+            )
+            next_ancestor = parent.parent_key or ""
+        else:
+            next_ancestor = ""
+
+        # set the staging entity's parent to the next ancestor
+        # if there is no parent or next ancestor, set to "" to differentiate from None
+        # None will mess up the pagination in _get_batch_entities_with_parent
+        db_session.query(KGEntityExtractionStaging).filter(
+            KGEntityExtractionStaging.id_name == entity.id_name
+        ).update({"parent_key": next_ancestor})
+        db_session.commit()
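Each call links exactly one ancestor level and then promotes parent_key, so the depth loop in kg_clustering below walks the full chain across passes. A hedged walk-through with hypothetical keys:

# staging entity "JIRA-10" whose ancestors are "EPIC-1" -> "PROJ-1":
#   pass 1: parent_key == "EPIC-1" -> upsert EPIC-1 has_subcomponent JIRA-10,
#           then parent_key is set to "PROJ-1"
#   pass 2: parent_key == "PROJ-1" -> upsert PROJ-1 has_subcomponent JIRA-10,
#           then parent_key is set to "" (PROJ-1 has no parent)
#   later passes: the "" sentinel is filtered out and the row is skipped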
+def _transfer_batch_relationship_and_update_vespa(
     relationships: list[KGRelationshipExtractionStaging],
-    entity_translations: dict[str, str],
-) -> set[str]:
-    updated_documents: set[str] = set()
+    index_name: str,
+    tenant_id: str,
+) -> None:
+    docs_to_update: set[str] = set()
     with get_session_with_current_tenant() as db_session:
         entity_id_names: set[str] = set()
+
+        # get the translations
+        staging_entity_id_names: set[str] = set()
+        for relationship in relationships:
+            staging_entity_id_names.add(relationship.source_node)
+            staging_entity_id_names.add(relationship.target_node)
+        entity_translations: dict[str, str] = {
+            entity.id_name: entity.transferred_id_name
+            for entity in db_session.query(KGEntityExtractionStaging)
+            .filter(KGEntityExtractionStaging.id_name.in_(staging_entity_id_names))
+            .all()
+            if entity.transferred_id_name is not None
+        }
+
+        # transfer the relationships
         for relationship in relationships:
             transferred_relationship = transfer_relationship(
                 db_session=db_session,
@@ -121,19 +264,29 @@ def _transfer_batch_relationship(
             )
             entity_id_names.add(transferred_relationship.source_node)
             entity_id_names.add(transferred_relationship.target_node)
-        updated_documents.update(
-            (
-                res[0]
-                for res in db_session.query(KGEntity.document_id)
-                .filter(KGEntity.id_name.in_(entity_id_names))
-                .all()
-                if res[0] is not None
-            )
-        )
+
+        # get all documents that require a vespa update
+        docs_to_update |= {
+            entity.document_id
+            for entity in db_session.query(KGEntity)
+            .filter(KGEntity.id_name.in_(entity_id_names))
+            .all()
+            if entity.document_id is not None
+        }
         db_session.commit()
-    return updated_documents
+
+    # update vespa in parallel
+    batch_update_requests = run_functions_tuples_in_parallel(
+        [
+            (
+                get_kg_vespa_info_update_requests_for_document,
+                (document_id, index_name, tenant_id),
+            )
+            for document_id in docs_to_update
+        ]
+    )
+    for update_requests in batch_update_requests:
+        update_kg_chunks_vespa_info(update_requests, index_name, tenant_id)


 def kg_clustering(
@@ -151,110 +304,55 @@ def kg_clustering(
     This will change with deep extraction, where grounded-sourceless entities
     can be extracted and then need to be clustered.
     """
+    # TODO: revisit splitting into batches
     logger.info(f"Starting kg clustering for tenant {tenant_id}")

     with get_session_with_current_tenant() as db_session:
         kg_config_settings = get_kg_config_settings(db_session)
         validate_kg_settings(kg_config_settings)

-    # Retrieve staging data
-    with get_session_with_current_tenant() as db_session:
-        untransferred_relationship_types = (
-            db_session.query(KGRelationshipTypeExtractionStaging)
-            .filter(KGRelationshipTypeExtractionStaging.transferred.is_(False))
-            .all()
-        )
-        untransferred_relationships = (
-            db_session.query(KGRelationshipExtractionStaging)
-            .filter(KGRelationshipExtractionStaging.transferred.is_(False))
-            .all()
-        )
-        grounded_entities = (
-            db_session.query(KGEntityExtractionStaging)
-            .join(
-                KGEntityType,
-                KGEntityExtractionStaging.entity_type_id_name == KGEntityType.id_name,
-            )
-            .filter(KGEntityType.grounding == KGGroundingType.GROUNDED)
-            .all()
-        )
-
-    # Cluster and transfer grounded entities
-    untransferred_grounded_entities = [
-        entity for entity in grounded_entities if entity.transferred_id_name is None
-    ]
-    entity_translations: dict[str, str] = {
-        entity.id_name: entity.transferred_id_name
-        for entity in grounded_entities
-        if entity.transferred_id_name is not None
-    }
-    vespa_update_documents: set[str] = set()
-    for entity in untransferred_grounded_entities:
-        added_entity, update_vespa = _cluster_one_grounded_entity(entity)
-        entity_translations[entity.id_name] = added_entity.id_name
-        if update_vespa and added_entity.document_id is not None:
-            vespa_update_documents.add(added_entity.document_id)
-    logger.info(f"Transferred {len(untransferred_grounded_entities)} entities")
+    # Cluster and transfer grounded entities sequentially
+    for untransferred_grounded_entities in _get_batch_untransferred_grounded_entities(
+        batch_size=processing_chunk_batch_size
+    ):
+        for entity in untransferred_grounded_entities:
+            _cluster_one_grounded_entity(entity)
+    # NOTE: we assume every entity is transferred, as we currently only have grounded entities
+    logger.info("Finished transferring all entities")

-    # Add parent-child relationships and relationship types
-    with get_session_with_current_tenant() as db_session:
-        parent_child_relationships, parent_child_relationship_types = (
-            get_parent_child_relationships_and_types(
-                db_session, depth=kg_config_settings.KG_MAX_PARENT_RECURSION_DEPTH
-            )
-        )
-        untransferred_relationship_types.extend(parent_child_relationship_types)
-        untransferred_relationships.extend(parent_child_relationships)
-        db_session.commit()
+    # Create parent-child relationships in parallel
+    for _ in range(kg_config_settings.KG_MAX_PARENT_RECURSION_DEPTH):
+        for root_entities in _get_batch_entities_with_parent(
+            batch_size=processing_chunk_batch_size
+        ):
+            run_functions_tuples_in_parallel(
+                [
+                    (_create_one_parent_child_relationship, (root_entity,))
+                    for root_entity in root_entities
+                ]
+            )
+    logger.info("Finished creating all parent-child relationships")

-    # Transfer the relationship types
-    for relationship_type in untransferred_relationship_types:
-        with get_session_with_current_tenant() as db_session:
-            transfer_relationship_type(db_session, relationship_type=relationship_type)
-            db_session.commit()
-    logger.info(
-        f"Transferred {len(untransferred_relationship_types)} relationship types"
-    )
+    # Transfer the relationship types (no need to do in parallel as there's only a few)
+    for relationship_types in _get_batch_untransferred_relationship_types(
+        batch_size=processing_chunk_batch_size
+    ):
+        with get_session_with_current_tenant() as db_session:
+            for relationship_type in relationship_types:
+                transfer_relationship_type(db_session, relationship_type)
+            db_session.commit()
+    logger.info("Finished transferring all relationship types")

-    # Transfer relationships in parallel
-    updated_documents_batch: list[set[str]] = run_functions_tuples_in_parallel(
-        [
-            (
-                _transfer_batch_relationship,
-                (
-                    untransferred_relationships[
-                        batch_i : batch_i + processing_chunk_batch_size
-                    ],
-                    entity_translations,
-                ),
-            )
-            for batch_i in range(
-                0, len(untransferred_relationships), processing_chunk_batch_size
-            )
-        ]
-    )
-    for updated_documents in updated_documents_batch:
-        vespa_update_documents.update(updated_documents)
-    logger.info(f"Transferred {len(untransferred_relationships)} relationships")
-
-    # Update vespa for documents that had their kg info updated in parallel
-    for i in range(0, len(vespa_update_documents), processing_chunk_batch_size):
-        batch_update_requests = run_functions_tuples_in_parallel(
-            [
-                (
-                    get_kg_vespa_info_update_requests_for_document,
-                    (document_id, index_name, tenant_id),
-                )
-                for document_id in list(vespa_update_documents)[
-                    i : i + processing_chunk_batch_size
-                ]
-            ]
-        )
-        for update_requests in batch_update_requests:
-            update_kg_chunks_vespa_info(update_requests, index_name, tenant_id)
+    # Transfer the relationships and update vespa in parallel
+    # NOTE: we assume there are no entities that aren't part of any relationships
+    for untransferred_relationships in _get_batch_untransferred_relationships(
+        batch_size=processing_chunk_batch_size
+    ):
+        _transfer_batch_relationship_and_update_vespa(
+            relationships=untransferred_relationships,
+            index_name=index_name,
+            tenant_id=tenant_id,
+        )
+    logger.info("Finished transferring all relationships")

     # Delete the transferred objects from the staging tables
     try: