Mirror of https://github.com/danswer-ai/danswer.git
Transition to using access_control_list to manage access in Vespa (#450)

parent c4e4e88301
commit 8594bac30b
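At a high level, this commit collapses the per-chunk `allowed_users` / `allowed_groups` whitelists into a single `access_control_list` field (plus a new `document_sets` field) stored as Vespa weightedsets. A minimal sketch of that shape change, with made-up values (not taken from the diff):

```python
# Before: two parallel whitelists per chunk.
old_fields = {
    "allowed_users": ["<user-uuid>", "PUBLIC"],
    "allowed_groups": [],
}

# After: a single ACL plus document sets; Vespa's only set type is weightedset,
# so each entry is stored with an arbitrary weight of 1.
new_fields = {
    "access_control_list": {"<user-uuid>": 1, "PUBLIC": 1},
    "document_sets": {},
}
```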
@@ -52,6 +52,9 @@ WORKDIR /app/danswer/datastores/vespa/app_config
RUN zip -r /app/danswer/vespa-app.zip .
WORKDIR /app

# TODO: remove this once all users have migrated
COPY ./danswer/scripts/migrate_vespa_to_acl.py /app/migrate_vespa_to_acl.py

ENV PYTHONPATH /app

# By default this container does nothing, it is used by api server and background which specify their own CMD
backend/danswer/access/access.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from sqlalchemy.orm import Session

from danswer.access.models import DocumentAccess
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.server.models import ConnectorCredentialPairIdentifier


def _get_access_for_documents(
    document_ids: list[str],
    cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
    db_session: Session,
) -> dict[str, DocumentAccess]:
    document_access_info = get_acccess_info_for_documents(
        db_session=db_session,
        document_ids=document_ids,
        cc_pair_to_delete=cc_pair_to_delete,
    )
    return {
        document_id: DocumentAccess.build(user_ids, is_public)
        for document_id, user_ids, is_public in document_access_info
    }


def get_access_for_documents(
    document_ids: list[str],
    cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
    db_session: Session | None = None,
) -> dict[str, DocumentAccess]:
    if db_session is None:
        with Session(get_sqlalchemy_engine()) as db_session:
            return _get_access_for_documents(
                document_ids, cc_pair_to_delete, db_session
            )

    return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
backend/danswer/access/models.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from dataclasses import dataclass
from uuid import UUID

from danswer.configs.constants import PUBLIC_DOC_PAT


@dataclass(frozen=True)
class DocumentAccess:
    user_ids: set[str]  # stringified UUIDs
    is_public: bool

    def to_acl(self) -> list[str]:
        return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])

    @classmethod
    def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
        return cls(
            user_ids={str(user_id) for user_id in user_ids if user_id},
            is_public=is_public,
        )
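A quick usage sketch of the new access model (assumes the `danswer` package from this repo is on the path; the user id is made up):

```python
from uuid import uuid4

from danswer.access.models import DocumentAccess

user_id = uuid4()
# `None` user ids (e.g. credentials with no associated user) are dropped by build().
access = DocumentAccess.build(user_ids=[user_id, None], is_public=True)

# Public documents get the PUBLIC_DOC_PAT ("PUBLIC") sentinel appended to the ACL.
print(access.to_acl())  # e.g. [str(user_id), "PUBLIC"] (set order is not guaranteed)
```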
@ -10,11 +10,9 @@ are multiple connector / credential pairs that have indexed it
|
||||
connector / credential pair from the access list
|
||||
(6) delete all relevant entries from postgres
|
||||
"""
|
||||
from collections import defaultdict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import UpdateRequest
|
||||
@ -22,24 +20,78 @@ from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector_credential_pair import delete_connector_credential_pair
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import (
|
||||
delete_document_by_connector_credential_pair_for_connector_credential_pair,
|
||||
)
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair
|
||||
from danswer.db.document import delete_documents_complete
|
||||
from danswer.db.document import (
|
||||
get_document_by_connector_credential_pairs_indexed_by_multiple,
|
||||
)
|
||||
from danswer.db.document import (
|
||||
get_documents_with_single_connector_credential_pair,
|
||||
)
|
||||
from danswer.db.document import get_document_connector_cnts
|
||||
from danswer.db.document import get_documents_for_connector_credential_pair
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import delete_index_attempts
|
||||
from danswer.db.models import ConnectorCredentialPair
|
||||
from danswer.db.models import Credential
|
||||
from danswer.server.models import ConnectorCredentialPairIdentifier
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_DELETION_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def _delete_connector_credential_pair_batch(
|
||||
document_ids: list[str],
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
document_index: DocumentIndex,
|
||||
) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquire lock for all documents in this batch so that indexing can't
|
||||
# override the deletion
|
||||
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
|
||||
|
||||
document_connector_cnts = get_document_connector_cnts(
|
||||
db_session=db_session, document_ids=document_ids
|
||||
)
|
||||
|
||||
# figure out which docs need to be completely deleted
|
||||
document_ids_to_delete = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt == 1
|
||||
]
|
||||
logger.debug(f"Deleting documents: {document_ids_to_delete}")
|
||||
document_index.delete(doc_ids=document_ids_to_delete)
|
||||
delete_documents_complete(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_delete,
|
||||
)
|
||||
|
||||
# figure out which docs need to be updated
|
||||
document_ids_to_update = [
|
||||
document_id for document_id, cnt in document_connector_cnts if cnt > 1
|
||||
]
|
||||
access_for_documents = get_access_for_documents(
|
||||
document_ids=document_ids_to_update,
|
||||
db_session=db_session,
|
||||
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
update_requests = [
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=access,
|
||||
)
|
||||
for document_id, access in access_for_documents.items()
|
||||
]
|
||||
logger.debug(f"Updating documents: {document_ids_to_update}")
|
||||
document_index.update(update_requests=update_requests)
|
||||
delete_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids_to_update,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
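To make the delete-versus-update split in `_delete_connector_credential_pair_batch` above concrete, a toy example with invented document ids and counts:

```python
# (document_id, number of connector/credential pairs that have indexed it)
document_connector_cnts = [("doc-a", 1), ("doc-b", 3)]

# Indexed only by the pair being deleted -> remove the document entirely.
document_ids_to_delete = [doc_id for doc_id, cnt in document_connector_cnts if cnt == 1]
# Also indexed by other pairs -> keep it and just push a recomputed ACL.
document_ids_to_update = [doc_id for doc_id, cnt in document_connector_cnts if cnt > 1]

print(document_ids_to_delete)  # ['doc-a']
print(document_ids_to_update)  # ['doc-b']
```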
def _delete_connector_credential_pair(
|
||||
db_session: Session,
|
||||
@ -47,147 +99,45 @@ def _delete_connector_credential_pair(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
num_docs_deleted = 0
|
||||
while True:
|
||||
documents = get_documents_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
limit=_DELETION_BATCH_SIZE,
|
||||
)
|
||||
if not documents:
|
||||
break
|
||||
|
||||
_delete_connector_credential_pair_batch(
|
||||
document_ids=[document.id for document in documents],
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_index=document_index,
|
||||
)
|
||||
num_docs_deleted += len(documents)
|
||||
|
||||
# clean up everything else
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair or not check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
|
||||
"This is likely because there is an ongoing / planned indexing attempt OR the "
|
||||
"connector is not disabled."
|
||||
)
|
||||
|
||||
def _delete_singly_indexed_docs() -> int:
|
||||
# if a document store entry is only indexed by this connector_credential_pair, delete it
|
||||
docs_to_delete = get_documents_with_single_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
if docs_to_delete:
|
||||
document_ids = [doc.id for doc in docs_to_delete]
|
||||
document_index.delete(doc_ids=document_ids)
|
||||
|
||||
# removes all `DocumentByConnectorCredentialPair`, and `Document`
|
||||
# rows from the DB
|
||||
delete_documents_complete(
|
||||
db_session=db_session,
|
||||
document_ids=list(document_ids),
|
||||
)
|
||||
|
||||
return len(docs_to_delete)
|
||||
|
||||
num_docs_deleted = _delete_singly_indexed_docs()
|
||||
logger.info(f"Deleted {num_docs_deleted} documents from document stores")
|
||||
|
||||
def _update_multi_indexed_docs(
|
||||
connector_credential_pair: ConnectorCredentialPair,
|
||||
) -> None:
|
||||
# if a document is indexed by multiple connector_credential_pairs, we should
|
||||
# update its access rather than outright delete it
|
||||
document_by_connector_credential_pairs_to_update = (
|
||||
get_document_by_connector_credential_pairs_indexed_by_multiple(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
)
|
||||
|
||||
def _get_user(
|
||||
credential: Credential,
|
||||
) -> str:
|
||||
if credential.public_doc or not credential.user:
|
||||
return PUBLIC_DOC_PAT
|
||||
|
||||
return str(credential.user.id)
|
||||
|
||||
# find out which documents need to be updated and what their new allowed_users
|
||||
# should be. This is a bit slow as it requires looping through all the documents
|
||||
to_be_deleted_user = _get_user(connector_credential_pair.credential)
|
||||
document_ids_not_needing_update: set[str] = set()
|
||||
document_id_to_allowed_users: dict[str, list[str]] = defaultdict(list)
|
||||
for (
|
||||
document_by_connector_credential_pair
|
||||
) in document_by_connector_credential_pairs_to_update:
|
||||
document_id = document_by_connector_credential_pair.id
|
||||
user = _get_user(document_by_connector_credential_pair.credential)
|
||||
document_id_to_allowed_users[document_id].append(user)
|
||||
|
||||
# if there's another connector / credential pair which has indexed this
|
||||
# document with the same access, we don't need to update it since removing
|
||||
# the access from this connector / credential pair won't change anything
|
||||
if (
|
||||
document_by_connector_credential_pair.connector_id != connector_id
|
||||
or document_by_connector_credential_pair.credential_id != credential_id
|
||||
) and user == to_be_deleted_user:
|
||||
document_ids_not_needing_update.add(document_id)
|
||||
|
||||
# categorize into groups of updates to try and batch them more efficiently
|
||||
update_groups: dict[tuple[str, ...], list[str]] = {}
|
||||
for document_id, allowed_users_lst in document_id_to_allowed_users.items():
|
||||
if document_id in document_ids_not_needing_update:
|
||||
continue
|
||||
|
||||
allowed_users_lst.remove(to_be_deleted_user)
|
||||
allowed_users = tuple(sorted(set(allowed_users_lst)))
|
||||
update_groups[allowed_users] = update_groups.get(allowed_users, []) + [
|
||||
document_id
|
||||
]
|
||||
|
||||
# actually perform the updates in the document store
|
||||
update_requests = [
|
||||
UpdateRequest(document_ids=document_ids, allowed_users=list(allowed_users))
|
||||
for allowed_users, document_ids in update_groups.items()
|
||||
]
|
||||
document_index.update(update_requests=update_requests)
|
||||
|
||||
# delete the rest of the `document_by_connector_credential_pair` rows for
|
||||
# this connector / credential pair
|
||||
delete_document_by_connector_credential_pair_for_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
_update_multi_indexed_docs(cc_pair)
|
||||
|
||||
def _cleanup() -> None:
|
||||
# clean up everything else
|
||||
# we cannot undo the deletion of the document store entries if something
|
||||
# goes wrong since they happen outside the postgres world. Best we can do
|
||||
# is keep everything else around and mark the deletion attempt as failed.
|
||||
# If it's a transient failure, re-deleting the connector / credential pair should
|
||||
# fix the weird state.
|
||||
# TODO: lock anything to do with this connector via transaction isolation
|
||||
# NOTE: we have to delete index_attempts and deletion_attempts since they both
|
||||
# have foreign key columns to the connector
|
||||
delete_index_attempts(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.debug("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
_cleanup()
|
||||
delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
# if there are no credentials left, delete the connector
|
||||
connector = fetch_connector_by_id(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
)
|
||||
if not connector or not len(connector.credentials):
|
||||
logger.debug("Found no credentials left for connector, deleting connector")
|
||||
db_session.delete(connector)
|
||||
db_session.commit()
|
||||
|
||||
logger.info(
|
||||
"Successfully deleted connector_credential_pair with connector_id:"
|
||||
@ -199,6 +149,21 @@ def _delete_connector_credential_pair(
|
||||
def cleanup_connector_credential_pair(connector_id: int, credential_id: int) -> int:
|
||||
engine = get_sqlalchemy_engine()
|
||||
with Session(engine) as db_session:
|
||||
# validate that the connector / credential pair is deletable
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
if not cc_pair or not check_deletion_attempt_is_allowed(
|
||||
connector_credential_pair=cc_pair
|
||||
):
|
||||
raise ValueError(
|
||||
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
|
||||
"This is likely because there is an ongoing / planned indexing attempt OR the "
|
||||
"connector is not disabled."
|
||||
)
|
||||
|
||||
try:
|
||||
return _delete_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
|
@ -1,9 +1,11 @@
|
||||
import inspect
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import fields
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.configs.constants import BLURB
|
||||
from danswer.configs.constants import BOOST
|
||||
from danswer.configs.constants import MATCH_HIGHLIGHTS
|
||||
@ -14,6 +16,7 @@ from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@ -55,6 +58,34 @@ class IndexChunk(DocAwareChunk):
|
||||
embeddings: ChunkEmbedding
|
||||
|
||||
|
||||
@dataclass
|
||||
class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
"""An `IndexChunk` that contains all necessary metadata to be indexed. This includes
|
||||
the following:
|
||||
|
||||
access: holds all information about which users should have access to the
|
||||
source document for this chunk.
|
||||
document_sets: all document sets the source document for this chunk is a part
|
||||
of. This is used for filtering / personas.
|
||||
"""
|
||||
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
|
||||
@classmethod
|
||||
def from_index_chunk(
|
||||
cls, index_chunk: IndexChunk, access: "DocumentAccess", document_sets: set[str]
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
return cls(
|
||||
**{
|
||||
field.name: getattr(index_chunk, field.name)
|
||||
for field in fields(index_chunk)
|
||||
},
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InferenceChunk(BaseChunk):
|
||||
document_id: str
|
||||
|
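The `from_index_chunk` helper above simply copies every dataclass field of the base chunk and attaches the access metadata. A self-contained sketch of that pattern using simplified stand-in classes (not the real danswer models):

```python
from dataclasses import dataclass, fields


@dataclass
class Chunk:
    chunk_id: int
    content: str


@dataclass
class AccessAwareChunk(Chunk):
    access_control_list: list[str]

    @classmethod
    def from_chunk(cls, chunk: Chunk, acl: list[str]) -> "AccessAwareChunk":
        # Copy every field of the base chunk, then attach the ACL.
        return cls(
            **{field.name: getattr(chunk, field.name) for field in fields(chunk)},
            access_control_list=acl,
        )


print(AccessAwareChunk.from_chunk(Chunk(chunk_id=0, content="hello"), acl=["PUBLIC"]))
```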
@@ -11,7 +11,8 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
SECTION_CONTINUATION = "section_continuation"
EMBEDDINGS = "embeddings"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
METADATA = "metadata"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
@@ -20,6 +21,7 @@ MATCH_HIGHLIGHTS = "match_highlights"
IGNORE_FOR_QA = "ignore_for_qa"
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
PUBLIC_DOC_PAT = "PUBLIC"
PUBLIC_DOCUMENT_SET = "__PUBLIC"
QUOTE = "quote"
BOOST = "boost"
SCORE = "score"
@ -1,15 +1,8 @@
|
||||
import math
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from copy import deepcopy
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
|
||||
|
||||
DEFAULT_BATCH_SIZE = 30
|
||||
@ -37,89 +30,3 @@ def get_uuid_from_chunk(
|
||||
[doc_str, str(chunk.chunk_id), str(mini_chunk_ind)]
|
||||
)
|
||||
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
|
||||
class CrossConnectorDocumentMetadata(BaseModel):
|
||||
"""Represents metadata about a single document. This is needed since the
|
||||
`Document` class represents a document from a single connector, but that same
|
||||
document may be indexed by multiple connectors."""
|
||||
|
||||
allowed_users: list[str]
|
||||
allowed_user_groups: list[str]
|
||||
already_in_index: bool
|
||||
|
||||
|
||||
# Takes the chunk identifier and returns the existing metadata about that chunk
|
||||
CrossConnectorDocumentMetadataFetchCallable = Callable[
|
||||
[str], CrossConnectorDocumentMetadata | None
|
||||
]
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def _add_if_not_exists(obj_list: list[T], item: T) -> list[T]:
|
||||
if item in obj_list:
|
||||
return obj_list
|
||||
return obj_list + [item]
|
||||
|
||||
|
||||
def update_cross_connector_document_metadata_map(
|
||||
chunk: IndexChunk,
|
||||
cross_connector_document_metadata_map: dict[str, CrossConnectorDocumentMetadata],
|
||||
doc_store_cross_connector_document_metadata_fetch_fn: CrossConnectorDocumentMetadataFetchCallable,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
) -> tuple[dict[str, CrossConnectorDocumentMetadata], bool]:
|
||||
"""Returns an updated document_id -> CrossConnectorDocumentMetadata map and
|
||||
if the document's chunks need to be wiped."""
|
||||
user_str = (
|
||||
PUBLIC_DOC_PAT
|
||||
if index_attempt_metadata.user_id is None
|
||||
else str(index_attempt_metadata.user_id)
|
||||
)
|
||||
|
||||
cross_connector_document_metadata_map = deepcopy(
|
||||
cross_connector_document_metadata_map
|
||||
)
|
||||
first_chunk_uuid = str(get_uuid_from_chunk(chunk))
|
||||
document = chunk.source_document
|
||||
if document.id not in cross_connector_document_metadata_map:
|
||||
document_metadata_in_doc_store = (
|
||||
doc_store_cross_connector_document_metadata_fetch_fn(first_chunk_uuid)
|
||||
)
|
||||
|
||||
if not document_metadata_in_doc_store:
|
||||
cross_connector_document_metadata_map[
|
||||
document.id
|
||||
] = CrossConnectorDocumentMetadata(
|
||||
allowed_users=[user_str],
|
||||
allowed_user_groups=[],
|
||||
already_in_index=False,
|
||||
)
|
||||
# First chunk does not exist so document does not exist, no need for deletion
|
||||
return cross_connector_document_metadata_map, False
|
||||
else:
|
||||
# TODO introduce groups logic here
|
||||
cross_connector_document_metadata_map[
|
||||
document.id
|
||||
] = CrossConnectorDocumentMetadata(
|
||||
allowed_users=_add_if_not_exists(
|
||||
document_metadata_in_doc_store.allowed_users, user_str
|
||||
),
|
||||
allowed_user_groups=document_metadata_in_doc_store.allowed_user_groups,
|
||||
already_in_index=True,
|
||||
)
|
||||
# First chunk exists, but with the update there may be fewer total chunks now
|
||||
# Must delete rest of document chunks
|
||||
return cross_connector_document_metadata_map, True
|
||||
|
||||
existing_document_metadata = cross_connector_document_metadata_map[document.id]
|
||||
cross_connector_document_metadata_map[document.id] = CrossConnectorDocumentMetadata(
|
||||
allowed_users=_add_if_not_exists(
|
||||
existing_document_metadata.allowed_users, user_str
|
||||
),
|
||||
allowed_user_groups=existing_document_metadata.allowed_user_groups,
|
||||
already_in_index=existing_document_metadata.already_in_index,
|
||||
)
|
||||
# If document is already in the mapping, don't delete again
|
||||
return cross_connector_document_metadata_map, False
|
||||
|
@ -1,14 +1,13 @@
|
||||
from typing import Type
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_TYPE
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.constants import DocumentIndexType
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import DocumentInsertionRecord
|
||||
from danswer.datastores.interfaces import IndexFilter
|
||||
@ -43,11 +42,10 @@ class SplitDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[IndexChunk],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
keyword_index_result = self.keyword_index.index(chunks, index_attempt_metadata)
|
||||
vector_index_result = self.vector_index.index(chunks, index_attempt_metadata)
|
||||
keyword_index_result = self.keyword_index.index(chunks)
|
||||
vector_index_result = self.vector_index.index(chunks)
|
||||
if keyword_index_result != vector_index_result:
|
||||
logger.error(
|
||||
f"Inconsistent document indexing:\n"
|
||||
|
@ -4,16 +4,17 @@ from typing import Protocol
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.access import get_access_for_documents
|
||||
from danswer.chunking.chunk import Chunker
|
||||
from danswer.chunking.chunk import DefaultChunker
|
||||
from danswer.chunking.models import DocAwareChunk
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import DocumentInsertionRecord
|
||||
from danswer.datastores.interfaces import DocumentMetadata
|
||||
from danswer.db.document import prepare_to_modify_documents
|
||||
from danswer.db.document import upsert_documents_complete
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.search.models import Embedder
|
||||
@ -30,40 +31,25 @@ class IndexingPipelineProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
def _upsert_insertion_records(
|
||||
insertion_records: set[DocumentInsertionRecord],
|
||||
def _upsert_documents(
|
||||
document_ids: list[str],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
doc_m_data_lookup: dict[str, tuple[str, str]],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as session:
|
||||
upsert_documents_complete(
|
||||
db_session=session,
|
||||
document_metadata_batch=[
|
||||
DocumentMetadata(
|
||||
connector_id=index_attempt_metadata.connector_id,
|
||||
credential_id=index_attempt_metadata.credential_id,
|
||||
document_id=i_r.document_id,
|
||||
semantic_identifier=doc_m_data_lookup[i_r.document_id][0],
|
||||
first_link=doc_m_data_lookup[i_r.document_id][1],
|
||||
)
|
||||
for i_r in insertion_records
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _get_net_new_documents(
|
||||
insertion_records: list[DocumentInsertionRecord],
|
||||
) -> int:
|
||||
net_new_documents = 0
|
||||
seen_documents: set[str] = set()
|
||||
for insertion_record in insertion_records:
|
||||
if insertion_record.already_existed:
|
||||
continue
|
||||
|
||||
if insertion_record.document_id not in seen_documents:
|
||||
net_new_documents += 1
|
||||
seen_documents.add(insertion_record.document_id)
|
||||
return net_new_documents
|
||||
upsert_documents_complete(
|
||||
db_session=db_session,
|
||||
document_metadata_batch=[
|
||||
DocumentMetadata(
|
||||
connector_id=index_attempt_metadata.connector_id,
|
||||
credential_id=index_attempt_metadata.credential_id,
|
||||
document_id=document_id,
|
||||
semantic_identifier=doc_m_data_lookup[document_id][0],
|
||||
first_link=doc_m_data_lookup[document_id][1],
|
||||
)
|
||||
for document_id in document_ids
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _extract_minimal_document_metadata(doc: Document) -> tuple[str, str]:
|
||||
@ -82,60 +68,52 @@ def _indexing_pipeline(
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
Note that the documents should already be batched at this point so that it does not inflate the
|
||||
memory requirements"""
|
||||
|
||||
document_ids = [document.id for document in documents]
|
||||
document_metadata_lookup = {
|
||||
doc.id: _extract_minimal_document_metadata(doc) for doc in documents
|
||||
}
|
||||
|
||||
chunks: list[DocAwareChunk] = list(
|
||||
chain(*[chunker.chunk(document=document) for document in documents])
|
||||
)
|
||||
logger.debug(
|
||||
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
|
||||
)
|
||||
chunks_with_embeddings = embedder.embed(chunks=chunks)
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# acquires a lock on the documents so that no other process can modify them
|
||||
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
|
||||
|
||||
# if there are any empty chunks, remove them. This usually happens due to a
|
||||
# bug in a connector. Handling here to prevent a bad connector from
|
||||
# breaking retrieval completely.
|
||||
final_chunks_with_embeddings: list[IndexChunk] = []
|
||||
for chunk in chunks_with_embeddings:
|
||||
if chunk.content:
|
||||
final_chunks_with_embeddings.append(chunk)
|
||||
else:
|
||||
bad_chunk_link = (
|
||||
chunk.source_document.sections[0].link
|
||||
if chunk.source_document.sections
|
||||
else ""
|
||||
)
|
||||
logger.error(
|
||||
f"Found empty chunk, skipping. Chunk ID: '{chunk.chunk_id}', "
|
||||
f"Document ID: '{chunk.source_document.id}', "
|
||||
f"Document Link: '{bad_chunk_link}'"
|
||||
)
|
||||
|
||||
# A document will not be spread across different batches, so all the documents with chunks in this set are fully
|
||||
# represented by the chunks in this set
|
||||
insertion_records = document_index.index(
|
||||
chunks=final_chunks_with_embeddings,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
)
|
||||
|
||||
# TODO (chris): remove this try/except after issue with null document_id is resolved
|
||||
try:
|
||||
_upsert_insertion_records(
|
||||
insertion_records=insertion_records,
|
||||
# create records in the source of truth about these documents
|
||||
_upsert_documents(
|
||||
document_ids=document_ids,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
doc_m_data_lookup=document_metadata_lookup,
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to upsert insertion records from vector index for documents: "
|
||||
f"{[document.to_short_descriptor() for document in documents]}, "
|
||||
f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks_with_embeddings]}"
|
||||
f"for insertion records: {insertion_records}"
|
||||
|
||||
chunks: list[DocAwareChunk] = list(
|
||||
chain(*[chunker.chunk(document=document) for document in documents])
|
||||
)
|
||||
logger.debug(
|
||||
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
|
||||
)
|
||||
chunks_with_embeddings = embedder.embed(chunks=chunks)
|
||||
|
||||
# Attach the latest status from Postgres (source of truth for access) to each
|
||||
# chunk. This access status will be attached to each chunk in the document index
|
||||
# TODO: attach document sets to the chunk based on the status of Postgres as well
|
||||
document_id_to_access_info = get_access_for_documents(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
)
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=document_id_to_access_info[chunk.source_document.id],
|
||||
document_sets=set(),
|
||||
)
|
||||
for chunk in chunks_with_embeddings
|
||||
]
|
||||
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set are fully represented by the chunks
|
||||
# in this set
|
||||
insertion_records = document_index.index(
|
||||
chunks=access_aware_chunks,
|
||||
)
|
||||
raise e
|
||||
|
||||
return len([r for r in insertion_records if r.already_existed is False]), len(
|
||||
chunks
|
||||
|
@ -3,10 +3,10 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
|
||||
IndexFilter = dict[str, str | list[str] | None]
|
||||
|
||||
@ -33,7 +33,8 @@ class UpdateRequest:
|
||||
|
||||
document_ids: list[str]
|
||||
# all other fields will be left alone
|
||||
allowed_users: list[str] | None = None
|
||||
access: DocumentAccess | None = None
|
||||
document_sets: set[str] | None = None
|
||||
boost: float | None = None
|
||||
|
||||
|
||||
@ -51,7 +52,7 @@ class Verifiable(abc.ABC):
|
||||
class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
|
||||
self, chunks: list[DocMetadataAwareIndexChunk]
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""Indexes document chunks into the Document Index and return the IDs of all the documents indexed"""
|
||||
raise NotImplementedError
|
||||
|
@ -1,6 +1,4 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import cast
|
||||
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.http import models
|
||||
@ -8,8 +6,7 @@ from qdrant_client.http.exceptions import ResponseHandlingException
|
||||
from qdrant_client.http.models.models import UpdateResult
|
||||
from qdrant_client.models import PointStruct
|
||||
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.configs.constants import ALLOWED_GROUPS
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.configs.constants import ALLOWED_USERS
|
||||
from danswer.configs.constants import BLURB
|
||||
from danswer.configs.constants import CHUNK_ID
|
||||
@ -20,15 +17,9 @@ from danswer.configs.constants import SECTION_CONTINUATION
|
||||
from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
||||
from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
|
||||
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
|
||||
from danswer.datastores.datastore_utils import get_uuid_from_chunk
|
||||
from danswer.datastores.datastore_utils import (
|
||||
update_cross_connector_document_metadata_map,
|
||||
)
|
||||
from danswer.datastores.interfaces import DocumentInsertionRecord
|
||||
from danswer.datastores.qdrant.utils import get_payload_from_record
|
||||
from danswer.utils.clients import get_qdrant_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@ -36,40 +27,18 @@ from danswer.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_qdrant_document_cross_connector_metadata(
|
||||
def _does_document_exist(
|
||||
doc_chunk_id: str, collection_name: str, q_client: QdrantClient
|
||||
) -> CrossConnectorDocumentMetadata | None:
|
||||
) -> bool:
|
||||
"""Get whether a document is found and the existing whitelists"""
|
||||
results = q_client.retrieve(
|
||||
collection_name=collection_name,
|
||||
ids=[doc_chunk_id],
|
||||
with_payload=[ALLOWED_USERS, ALLOWED_GROUPS],
|
||||
)
|
||||
if len(results) == 0:
|
||||
return None
|
||||
payload = get_payload_from_record(results[0])
|
||||
allowed_users = cast(list[str] | None, payload.get(ALLOWED_USERS))
|
||||
allowed_groups = cast(list[str] | None, payload.get(ALLOWED_GROUPS))
|
||||
if allowed_users is None:
|
||||
allowed_users = []
|
||||
logger.error(
|
||||
"Qdrant Index is corrupted, Document found with no user access lists."
|
||||
f"Assuming no users have access to chunk with ID '{doc_chunk_id}'."
|
||||
)
|
||||
if allowed_groups is None:
|
||||
allowed_groups = []
|
||||
logger.error(
|
||||
"Qdrant Index is corrupted, Document found with no groups access lists."
|
||||
f"Assuming no groups have access to chunk with ID '{doc_chunk_id}'."
|
||||
)
|
||||
return False
|
||||
|
||||
return CrossConnectorDocumentMetadata(
|
||||
# if either `allowed_users` or `allowed_groups` are missing from the
|
||||
# point, then assume that the document has no allowed users.
|
||||
allowed_users=allowed_users,
|
||||
allowed_user_groups=allowed_groups,
|
||||
already_in_index=True,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def delete_qdrant_doc_chunks(
|
||||
@ -92,8 +61,7 @@ def delete_qdrant_doc_chunks(
|
||||
|
||||
|
||||
def index_qdrant_chunks(
|
||||
chunks: list[IndexChunk],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
collection: str,
|
||||
client: QdrantClient | None = None,
|
||||
batch_upsert: bool = True,
|
||||
@ -104,29 +72,15 @@ def index_qdrant_chunks(
|
||||
|
||||
point_structs: list[PointStruct] = []
|
||||
insertion_records: set[DocumentInsertionRecord] = set()
|
||||
# Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
|
||||
cross_connector_document_metadata_map: dict[
|
||||
str, CrossConnectorDocumentMetadata
|
||||
] = {}
|
||||
# document ids of documents that existed BEFORE this indexing
|
||||
already_existing_documents: set[str] = set()
|
||||
for chunk in chunks:
|
||||
document = chunk.source_document
|
||||
(
|
||||
cross_connector_document_metadata_map,
|
||||
should_delete_doc,
|
||||
) = update_cross_connector_document_metadata_map(
|
||||
chunk=chunk,
|
||||
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
|
||||
doc_store_cross_connector_document_metadata_fetch_fn=partial(
|
||||
get_qdrant_document_cross_connector_metadata,
|
||||
collection_name=collection,
|
||||
q_client=q_client,
|
||||
),
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
)
|
||||
|
||||
if should_delete_doc:
|
||||
# Delete all chunks related to the document if (1) it already exists and
|
||||
# (2) this is our first time running into it during this indexing attempt
|
||||
document_exists = _does_document_exist(document.id, collection, q_client)
|
||||
if document_exists and document.id not in already_existing_documents:
|
||||
# Processing the first chunk of the doc and the doc exists
|
||||
delete_qdrant_doc_chunks(document.id, collection, q_client)
|
||||
already_existing_documents.add(document.id)
|
||||
@ -155,12 +109,7 @@ def index_qdrant_chunks(
|
||||
SOURCE_LINKS: chunk.source_links,
|
||||
SEMANTIC_IDENTIFIER: document.semantic_identifier,
|
||||
SECTION_CONTINUATION: chunk.section_continuation,
|
||||
ALLOWED_USERS: cross_connector_document_metadata_map[
|
||||
document.id
|
||||
].allowed_users,
|
||||
ALLOWED_GROUPS: cross_connector_document_metadata_map[
|
||||
document.id
|
||||
].allowed_user_groups,
|
||||
ALLOWED_USERS: json.dumps(chunk.access.to_acl()),
|
||||
METADATA: json.dumps(document.metadata),
|
||||
},
|
||||
vector=embedding,
|
||||
|
@ -8,7 +8,7 @@ from qdrant_client.http.models import Filter
|
||||
from qdrant_client.http.models import MatchAny
|
||||
from qdrant_client.http.models import MatchValue
|
||||
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
@ -16,7 +16,6 @@ from danswer.configs.constants import ALLOWED_USERS
|
||||
from danswer.configs.constants import DOCUMENT_ID
|
||||
from danswer.configs.constants import PUBLIC_DOC_PAT
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.datastores.datastore_utils import get_uuid_from_chunk
|
||||
from danswer.datastores.interfaces import DocumentInsertionRecord
|
||||
from danswer.datastores.interfaces import IndexFilter
|
||||
@ -126,12 +125,10 @@ class QdrantIndex(VectorIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[IndexChunk],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
return index_qdrant_chunks(
|
||||
chunks=chunks,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
collection=self.collection,
|
||||
client=self.client,
|
||||
)
|
||||
@ -145,12 +142,15 @@ class QdrantIndex(VectorIndex):
|
||||
items=update_request.document_ids,
|
||||
batch_size=_BATCH_SIZE,
|
||||
):
|
||||
if update_request.access is None:
|
||||
continue
|
||||
|
||||
chunk_ids = _get_points_from_document_ids(
|
||||
doc_id_batch, self.collection, self.client
|
||||
)
|
||||
self.client.set_payload(
|
||||
collection_name=self.collection,
|
||||
payload={ALLOWED_USERS: update_request.allowed_users},
|
||||
payload={ALLOWED_USERS: update_request.access.to_acl()},
|
||||
points=chunk_ids,
|
||||
)
|
||||
|
||||
|
@ -1,17 +1,14 @@
|
||||
import json
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
import typesense # type: ignore
|
||||
from typesense.exceptions import ObjectNotFound # type: ignore
|
||||
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.constants import ALLOWED_GROUPS
|
||||
from danswer.configs.constants import ALLOWED_USERS
|
||||
from danswer.configs.constants import BLURB
|
||||
from danswer.configs.constants import CHUNK_ID
|
||||
@ -23,13 +20,8 @@ from danswer.configs.constants import SECTION_CONTINUATION
|
||||
from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
||||
from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
|
||||
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
|
||||
from danswer.datastores.datastore_utils import get_uuid_from_chunk
|
||||
from danswer.datastores.datastore_utils import (
|
||||
update_cross_connector_document_metadata_map,
|
||||
)
|
||||
from danswer.datastores.interfaces import DocumentInsertionRecord
|
||||
from danswer.datastores.interfaces import IndexFilter
|
||||
from danswer.datastores.interfaces import KeywordIndex
|
||||
@ -63,7 +55,6 @@ def create_typesense_collection(
|
||||
{"name": SEMANTIC_IDENTIFIER, "type": "string"},
|
||||
{"name": SECTION_CONTINUATION, "type": "bool"},
|
||||
{"name": ALLOWED_USERS, "type": "string[]"},
|
||||
{"name": ALLOWED_GROUPS, "type": "string[]"},
|
||||
{"name": METADATA, "type": "string"},
|
||||
],
|
||||
}
|
||||
@ -81,38 +72,16 @@ def _check_typesense_collection_exist(
|
||||
return True
|
||||
|
||||
|
||||
def _get_typesense_document_cross_connector_metadata(
|
||||
def _does_document_exist(
|
||||
doc_chunk_id: str, collection_name: str, ts_client: typesense.Client
|
||||
) -> CrossConnectorDocumentMetadata | None:
|
||||
) -> bool:
|
||||
"""Returns whether the document already exists and the users/group whitelists"""
|
||||
try:
|
||||
document = cast(
|
||||
dict[str, Any],
|
||||
ts_client.collections[collection_name].documents[doc_chunk_id].retrieve(),
|
||||
)
|
||||
ts_client.collections[collection_name].documents[doc_chunk_id].retrieve()
|
||||
except ObjectNotFound:
|
||||
return None
|
||||
return False
|
||||
|
||||
allowed_users = cast(list[str] | None, document.get(ALLOWED_USERS))
|
||||
allowed_groups = cast(list[str] | None, document.get(ALLOWED_GROUPS))
|
||||
if allowed_users is None:
|
||||
allowed_users = []
|
||||
logger.error(
|
||||
"Typesense Index is corrupted, Document found with no user access lists."
|
||||
f"Assuming no users have access to chunk with ID '{doc_chunk_id}'."
|
||||
)
|
||||
if allowed_groups is None:
|
||||
allowed_groups = []
|
||||
logger.error(
|
||||
"Typesense Index is corrupted, Document found with no groups access lists."
|
||||
f"Assuming no groups have access to chunk with ID '{doc_chunk_id}'."
|
||||
)
|
||||
|
||||
return CrossConnectorDocumentMetadata(
|
||||
allowed_users=allowed_users,
|
||||
allowed_user_groups=allowed_groups,
|
||||
already_in_index=True,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _delete_typesense_doc_chunks(
|
||||
@ -127,8 +96,7 @@ def _delete_typesense_doc_chunks(
|
||||
|
||||
|
||||
def _index_typesense_chunks(
|
||||
chunks: list[IndexChunk],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
collection: str,
|
||||
client: typesense.Client | None = None,
|
||||
batch_upsert: bool = True,
|
||||
@ -137,33 +105,20 @@ def _index_typesense_chunks(
|
||||
|
||||
insertion_records: set[DocumentInsertionRecord] = set()
|
||||
new_documents: list[dict[str, Any]] = []
|
||||
cross_connector_document_metadata_map: dict[
|
||||
str, CrossConnectorDocumentMetadata
|
||||
] = {}
|
||||
# document ids of documents that existed BEFORE this indexing
|
||||
already_existing_documents: set[str] = set()
|
||||
for chunk in chunks:
|
||||
document = chunk.source_document
|
||||
(
|
||||
cross_connector_document_metadata_map,
|
||||
should_delete_doc,
|
||||
) = update_cross_connector_document_metadata_map(
|
||||
chunk=chunk,
|
||||
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
|
||||
doc_store_cross_connector_document_metadata_fetch_fn=partial(
|
||||
_get_typesense_document_cross_connector_metadata,
|
||||
collection_name=collection,
|
||||
ts_client=ts_client,
|
||||
),
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
)
|
||||
typesense_id = str(get_uuid_from_chunk(chunk))
|
||||
|
||||
if should_delete_doc:
|
||||
# Delete all chunks related to the document if (1) it already exists and
|
||||
# (2) this is our first time running into it during this indexing attempt
|
||||
document_exists = _does_document_exist(typesense_id, collection, ts_client)
|
||||
if document_exists and document.id not in already_existing_documents:
|
||||
# Processing the first chunk of the doc and the doc exists
|
||||
_delete_typesense_doc_chunks(document.id, collection, ts_client)
|
||||
already_existing_documents.add(document.id)
|
||||
|
||||
typesense_id = str(get_uuid_from_chunk(chunk))
|
||||
insertion_records.add(
|
||||
DocumentInsertionRecord(
|
||||
document_id=document.id,
|
||||
@ -181,12 +136,7 @@ def _index_typesense_chunks(
|
||||
SOURCE_LINKS: json.dumps(chunk.source_links),
|
||||
SEMANTIC_IDENTIFIER: document.semantic_identifier,
|
||||
SECTION_CONTINUATION: chunk.section_continuation,
|
||||
ALLOWED_USERS: cross_connector_document_metadata_map[
|
||||
document.id
|
||||
].allowed_users,
|
||||
ALLOWED_GROUPS: cross_connector_document_metadata_map[
|
||||
document.id
|
||||
].allowed_user_groups,
|
||||
ALLOWED_USERS: json.dumps(chunk.access.to_acl()),
|
||||
METADATA: json.dumps(document.metadata),
|
||||
}
|
||||
)
|
||||
@ -260,11 +210,10 @@ class TypesenseIndex(KeywordIndex):
|
||||
create_typesense_collection(collection_name=self.collection)
|
||||
|
||||
def index(
|
||||
self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
|
||||
self, chunks: list[DocMetadataAwareIndexChunk]
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
return _index_typesense_chunks(
|
||||
chunks=chunks,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
collection=self.collection,
|
||||
client=self.ts_client,
|
||||
)
|
||||
@ -277,8 +226,14 @@ class TypesenseIndex(KeywordIndex):
|
||||
for id_batch in batch_generator(
|
||||
items=update_request.document_ids, batch_size=_BATCH_SIZE
|
||||
):
|
||||
if update_request.access is None:
|
||||
continue
|
||||
|
||||
typesense_updates = [
|
||||
{DOCUMENT_ID: doc_id, ALLOWED_USERS: update_request.allowed_users}
|
||||
{
|
||||
DOCUMENT_ID: doc_id,
|
||||
ALLOWED_USERS: update_request.access.to_acl(),
|
||||
}
|
||||
for doc_id in id_batch
|
||||
]
|
||||
self.ts_client.collections[self.collection].documents.import_(
|
||||
|
@@ -56,11 +56,11 @@ schema danswer_chunk {
            distance-metric: angular
        }
    }
    field allowed_users type array<string> {
    field access_control_list type weightedset<string> {
        indexing: summary | attribute
        attribute: fast-search
    }
    field allowed_groups type array<string> {
    field document_sets type weightedset<string> {
        indexing: summary | attribute
        attribute: fast-search
    }
@ -9,7 +9,7 @@ import requests
|
||||
from requests import HTTPError
|
||||
from requests import Response
|
||||
|
||||
from danswer.chunking.models import IndexChunk
|
||||
from danswer.chunking.models import DocMetadataAwareIndexChunk
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
|
||||
from danswer.configs.app_configs import NUM_RETURNED_HITS
|
||||
@ -17,14 +17,14 @@ from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
|
||||
from danswer.configs.app_configs import VESPA_HOST
|
||||
from danswer.configs.app_configs import VESPA_PORT
|
||||
from danswer.configs.app_configs import VESPA_TENANT_PORT
|
||||
from danswer.configs.constants import ALLOWED_GROUPS
|
||||
from danswer.configs.constants import ALLOWED_USERS
|
||||
from danswer.configs.constants import ACCESS_CONTROL_LIST
|
||||
from danswer.configs.constants import BLURB
|
||||
from danswer.configs.constants import BOOST
|
||||
from danswer.configs.constants import CHUNK_ID
|
||||
from danswer.configs.constants import CONTENT
|
||||
from danswer.configs.constants import DEFAULT_BOOST
|
||||
from danswer.configs.constants import DOCUMENT_ID
|
||||
from danswer.configs.constants import DOCUMENT_SETS
|
||||
from danswer.configs.constants import EMBEDDINGS
|
||||
from danswer.configs.constants import MATCH_HIGHLIGHTS
|
||||
from danswer.configs.constants import METADATA
|
||||
@ -35,12 +35,7 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
||||
from danswer.configs.constants import SOURCE_LINKS
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
|
||||
from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
|
||||
from danswer.datastores.datastore_utils import get_uuid_from_chunk
|
||||
from danswer.datastores.datastore_utils import (
|
||||
update_cross_connector_document_metadata_map,
|
||||
)
|
||||
from danswer.datastores.interfaces import DocumentIndex
|
||||
from danswer.datastores.interfaces import DocumentInsertionRecord
|
||||
from danswer.datastores.interfaces import IndexFilter
|
||||
@ -65,33 +60,20 @@ _BATCH_SIZE = 100 # Specific to Vespa
|
||||
CONTENT_SUMMARY = "content_summary"
|
||||
|
||||
|
||||
def _get_vespa_document_cross_connector_metadata(
|
||||
def _does_document_exist(
|
||||
doc_chunk_id: str,
|
||||
) -> CrossConnectorDocumentMetadata | None:
|
||||
) -> bool:
|
||||
"""Returns whether the document already exists and the users/group whitelists"""
|
||||
doc_fetch_response = requests.get(f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}")
|
||||
if doc_fetch_response.status_code == 404:
|
||||
return None
|
||||
return False
|
||||
|
||||
if doc_fetch_response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Unexpected fetch document by ID value from Vespa "
|
||||
f"with error {doc_fetch_response.status_code}"
|
||||
)
|
||||
|
||||
doc_fields = doc_fetch_response.json()["fields"]
|
||||
allowed_users = doc_fields.get(ALLOWED_USERS)
|
||||
allowed_groups = doc_fields.get(ALLOWED_GROUPS)
|
||||
# Add group permission later, empty list gets saved/loaded as a null
|
||||
if allowed_users is None:
|
||||
raise RuntimeError(
|
||||
"Vespa Index is corrupted, Document found with no access lists."
|
||||
)
|
||||
return CrossConnectorDocumentMetadata(
|
||||
allowed_users=allowed_users or [],
|
||||
allowed_user_groups=allowed_groups or [],
|
||||
already_in_index=True,
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
def _get_vespa_chunk_ids_by_document_id(
|
||||
@ -127,32 +109,23 @@ def _delete_vespa_doc_chunks(document_id: str) -> bool:
|
||||
|
||||
|
||||
def _index_vespa_chunks(
|
||||
chunks: list[IndexChunk],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
json_header = {
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
insertion_records: set[DocumentInsertionRecord] = set()
|
||||
cross_connector_document_metadata_map: dict[
|
||||
str, CrossConnectorDocumentMetadata
|
||||
] = {}
|
||||
# document ids of documents that existed BEFORE this indexing
|
||||
already_existing_documents: set[str] = set()
|
||||
for chunk in chunks:
|
||||
document = chunk.source_document
|
||||
(
|
||||
cross_connector_document_metadata_map,
|
||||
should_delete_doc,
|
||||
) = update_cross_connector_document_metadata_map(
|
||||
chunk=chunk,
|
||||
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
|
||||
doc_store_cross_connector_document_metadata_fetch_fn=_get_vespa_document_cross_connector_metadata,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
)
|
||||
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
|
||||
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
|
||||
|
||||
if should_delete_doc:
|
||||
# Processing the first chunk of the doc and the doc exists
|
||||
# Delete all chunks related to the document if (1) it already exists and
|
||||
# (2) this is our first time running into it during this indexing attempt
|
||||
chunk_exists = _does_document_exist(vespa_chunk_id)
|
||||
if chunk_exists and document.id not in already_existing_documents:
|
||||
deletion_success = _delete_vespa_doc_chunks(document.id)
|
||||
if not deletion_success:
|
||||
raise RuntimeError(
|
||||
@ -160,9 +133,6 @@ def _index_vespa_chunks(
|
||||
)
|
||||
already_existing_documents.add(document.id)
|
||||
|
||||
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
|
||||
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
|
||||
|
||||
embeddings = chunk.embeddings
|
||||
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
|
||||
if embeddings.mini_chunk_embeddings:
|
||||
@ -183,12 +153,10 @@ def _index_vespa_chunks(
|
||||
METADATA: json.dumps(document.metadata),
|
||||
EMBEDDINGS: embeddings_name_vector_map,
|
||||
BOOST: DEFAULT_BOOST,
|
||||
ALLOWED_USERS: cross_connector_document_metadata_map[
|
||||
document.id
|
||||
].allowed_users,
|
||||
ALLOWED_GROUPS: cross_connector_document_metadata_map[
|
||||
document.id
|
||||
].allowed_user_groups,
|
||||
# the only `set` vespa has is `weightedset`, so we have to give each
|
||||
# element an arbitrary weight
|
||||
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
|
||||
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
|
||||
}
|
||||
|
||||
def _index_chunk(
|
||||
@@ -244,16 +212,13 @@ def _index_vespa_chunks(
def _build_vespa_filters(
    user_id: UUID | None, filters: list[IndexFilter] | None
) -> str:
    filter_str = ""
    # Permissions filter
    # TODO group permissioning
    # Permissions filters
    acl_filter_stmts = [f'{ACCESS_CONTROL_LIST} contains "{PUBLIC_DOC_PAT}"']
    if user_id:
        filter_str += (
            f'({ALLOWED_USERS} contains "{user_id}" or '
            f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}") and '
        )
    else:
        filter_str += f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}" and '
        acl_filter_stmts.append(f'{ACCESS_CONTROL_LIST} contains "{user_id}"')
    filter_str = "(" + " or ".join(acl_filter_stmts) + ") and"

    # TODO: have document sets passed in + add document set based filters

    # Provided query filters
    if filters:
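For reference, a standalone sketch of the ACL clause the new `_build_vespa_filters` produces (mirrors the string-building above; `PUBLIC_DOC_PAT` is `"PUBLIC"` per `danswer.configs.constants`):

```python
from uuid import UUID, uuid4

ACCESS_CONTROL_LIST = "access_control_list"
PUBLIC_DOC_PAT = "PUBLIC"


def build_acl_filter(user_id: UUID | None) -> str:
    # Public chunks are always visible; a logged-in user additionally matches their own id.
    acl_filter_stmts = [f'{ACCESS_CONTROL_LIST} contains "{PUBLIC_DOC_PAT}"']
    if user_id:
        acl_filter_stmts.append(f'{ACCESS_CONTROL_LIST} contains "{user_id}"')
    return "(" + " or ".join(acl_filter_stmts) + ") and"


print(build_acl_filter(None))
# -> (access_control_list contains "PUBLIC") and
print(build_acl_filter(uuid4()))
# -> (access_control_list contains "PUBLIC" or access_control_list contains "<uuid>") and
```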
@ -399,31 +364,38 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[IndexChunk],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
return _index_vespa_chunks(
|
||||
chunks=chunks, index_attempt_metadata=index_attempt_metadata
|
||||
)
|
||||
return _index_vespa_chunks(chunks=chunks)
|
||||
|
||||
def update(self, update_requests: list[UpdateRequest]) -> None:
|
||||
logger.info(
|
||||
f"Updating {len(update_requests)} documents' allowed_users in Vespa"
|
||||
)
|
||||
logger.info(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
json_header = {"Content-Type": "application/json"}
|
||||
|
||||
for update_request in update_requests:
|
||||
if update_request.boost is None and update_request.allowed_users is None:
|
||||
if (
|
||||
update_request.boost is None
|
||||
and update_request.access is None
|
||||
and update_request.document_sets is None
|
||||
):
|
||||
logger.error("Update request received but nothing to update")
|
||||
continue
|
||||
|
||||
update_dict: dict[str, dict] = {"fields": {}}
|
||||
if update_request.boost is not None:
|
||||
update_dict["fields"][BOOST] = {"assign": update_request.boost}
|
||||
if update_request.allowed_users is not None:
|
||||
update_dict["fields"][ALLOWED_USERS] = {
|
||||
"assign": update_request.allowed_users
|
||||
if update_request.document_sets is not None:
|
||||
update_dict["fields"][DOCUMENT_SETS] = {
|
||||
"assign": {
|
||||
document_set: 1 for document_set in update_request.document_sets
|
||||
}
|
||||
}
|
||||
if update_request.access is not None:
|
||||
update_dict["fields"][ACCESS_CONTROL_LIST] = {
|
||||
"assign": {
|
||||
acl_entry: 1 for acl_entry in update_request.access.to_acl()
|
||||
}
|
||||
}
|
||||
|
||||
for document_id in update_request.document_ids:
|
||||
|
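As a sketch of what the loop above assembles (all values are invented; the real field names come from the BOOST / ACCESS_CONTROL_LIST / DOCUMENT_SETS constants), a request that updates boost, access, and document sets produces a partial-update body along these lines:

# hypothetical partial-update body for one document; every field uses Vespa's
# "assign" operation, and the set-valued fields are weightedsets with weight 1
update_dict = {
    "fields": {
        "boost": {"assign": 1.0},
        "access_control_list": {
            "assign": {"PUBLIC": 1, "7f9f9c2e-1111-2222-3333-444455556666": 1}
        },
        "document_sets": {"assign": {"engineering-docs": 1}},
    }
}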
@ -1,4 +1,6 @@
import time
from collections.abc import Sequence
from uuid import UUID

from sqlalchemy import and_
from sqlalchemy import delete
@ -10,18 +12,18 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DEFAULT_BOOST
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.feedback import delete_document_feedback_for_documents
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.utils import model_to_dict
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger

logger = setup_logger()


def get_documents_with_single_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
@ -29,54 +31,63 @@ def get_documents_with_single_connector_credential_pair(
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)

# Filter it down to the documents with only a single connector/credential pair
# Meaning if this connector/credential pair is removed, this doc should be gone
trimmed_doc_ids_stmt = (
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == DbDocument.id,
)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) == 1)
)

stmt = select(DbDocument).where(DbDocument.id.in_(trimmed_doc_ids_stmt))
stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct()
if limit:
stmt = stmt.limit(limit)
return db_session.scalars(stmt).all()


def get_document_by_connector_credential_pairs_indexed_by_multiple(
def get_document_connector_cnts(
db_session: Session,
connector_id: int,
credential_id: int,
) -> Sequence[DocumentByConnectorCredentialPair]:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
document_ids: list[str],
) -> Sequence[tuple[str, int]]:
stmt = (
select(
DocumentByConnectorCredentialPair.id,
func.count(),
)
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
.group_by(DocumentByConnectorCredentialPair.id)
)
return db_session.execute(stmt).all() # type: ignore

# Filter it down to the documents with more than 1 connector/credential pair
# Meaning if this connector/credential pair is removed, this doc is still accessible
trimmed_doc_ids_stmt = (
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == DbDocument.id,

def get_acccess_info_for_documents(
db_session: Session,
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> Sequence[tuple[str, list[UUID | None], bool]]:
"""Gets back all relevant access info for the given documents. This includes
the user_ids for the cc pairs that the document is associated with, as well as
whether any of the associated cc pairs intend to make the document globally public.

If `cc_pair_to_delete` is specified, gets the above access info as if that
pair had been deleted. This is needed since we want to delete from Vespa
before deleting from Postgres to ensure that the state of Postgres never "loses"
documents that still exist in Vespa.
"""
stmt = select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(Credential.public_doc).label("public_doc"),
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))

# pretend that the specified cc pair doesn't exist
if cc_pair_to_delete:
stmt = stmt.where(
and_(
DocumentByConnectorCredentialPair.connector_id
!= cc_pair_to_delete.connector_id,
DocumentByConnectorCredentialPair.credential_id
!= cc_pair_to_delete.credential_id,
)
)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) > 1)
)

stmt = select(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(trimmed_doc_ids_stmt)
)

return db_session.execute(stmt).scalars().all()
stmt = stmt.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
).group_by(DocumentByConnectorCredentialPair.id)
return db_session.execute(stmt).all() # type: ignore

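To make the shape of the returned rows concrete, a small sketch (the document id, the aggregated UUIDs, and the session setup are placeholders) of consuming them the same way the migration script further below does:

from sqlalchemy.orm import Session

from danswer.access.models import DocumentAccess
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    # each row is (document_id, aggregated credential owner ids, any-public flag)
    rows = get_acccess_info_for_documents(
        db_session=db_session,
        document_ids=["hypothetical-doc-id"],
    )
    for document_id, user_ids, is_public in rows:
        # collapses the per-credential info into a single access object
        access = DocumentAccess.build(user_ids, is_public)
        print(document_id, access.to_acl())
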
def upsert_documents(
@ -153,26 +164,24 @@ def upsert_documents_complete(


def delete_document_by_connector_credential_pair(
db_session: Session, document_ids: list[str]
db_session: Session,
document_ids: list[str],
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
| None = None,
) -> None:
db_session.execute(
delete(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(document_ids)
)
stmt = delete(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(document_ids)
)


def delete_document_by_connector_credential_pair_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int
) -> None:
db_session.execute(
delete(DocumentByConnectorCredentialPair).where(
if connector_credential_pair_identifier:
stmt = stmt.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
DocumentByConnectorCredentialPair.connector_id
== connector_credential_pair_identifier.connector_id,
DocumentByConnectorCredentialPair.credential_id
== connector_credential_pair_identifier.credential_id,
)
)
)
db_session.execute(stmt)


def delete_documents(db_session: Session, document_ids: list[str]) -> None:
@ -187,3 +196,48 @@ def delete_documents_complete(db_session: Session, document_ids: list[str]) -> N
)
delete_documents(db_session, document_ids)
db_session.commit()


def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
"""Acquire locks for the specified documents. Ideally this shouldn't be
called with a large list of document_ids (an exception could be made if the
locks are only held for a very short time).

Will simply raise an exception if any of the documents are already locked.
This prevents deadlocks (assuming that the caller passes in all required
document IDs in a single call).
"""
stmt = (
select(DbDocument)
.where(DbDocument.id.in_(document_ids))
.with_for_update(nowait=True)
)
# will raise an exception if any of the documents are already locked
db_session.execute(stmt)
return True


_NUM_LOCK_ATTEMPTS = 10
_LOCK_RETRY_DELAY = 30


def prepare_to_modify_documents(db_session: Session, document_ids: list[str]) -> None:
"""Try to acquire locks for the documents to prevent other jobs from
modifying them at the same time (e.g. avoid race conditions). This should be
called ahead of any modification to Vespa. Locks should be released by the
caller as soon as updates are complete by finishing the transaction."""
lock_acquired = False
for _ in range(_NUM_LOCK_ATTEMPTS):
try:
lock_acquired = acquire_document_locks(
db_session=db_session, document_ids=document_ids
)
except Exception as e:
logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}")
time.sleep(_LOCK_RETRY_DELAY)

if not lock_acquired:
raise RuntimeError(
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts "
f"for documents: {document_ids}"
)

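A minimal usage sketch of the intended locking pattern (the document ids and the Vespa update step are placeholders): lock the rows in Postgres, mutate Vespa, then finish the transaction to release the locks.

from sqlalchemy.orm import Session

from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine

doc_ids = ["hypothetical-doc-id-1", "hypothetical-doc-id-2"]

with Session(get_sqlalchemy_engine()) as db_session:
    # blocks concurrent jobs from touching these documents (retries internally)
    prepare_to_modify_documents(db_session=db_session, document_ids=doc_ids)

    # ... perform the corresponding Vespa index/update/delete calls here ...

    # committing (or rolling back) ends the transaction and releases the row locks
    db_session.commit()
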
44
backend/scripts/migrate_vespa_to_acl.py
Normal file
@ -0,0 +1,44 @@
"""Script which updates Vespa to align with the access described in Postgres.
|
||||
Should be run wehn a user who has docs already indexed switches over to the new
|
||||
access control system. This allows them to not have to re-index all documents."""
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.access.models import DocumentAccess
|
||||
from danswer.datastores.document_index import get_default_document_index
|
||||
from danswer.datastores.interfaces import UpdateRequest
|
||||
from danswer.datastores.vespa.store import VespaIndex
|
||||
from danswer.db.document import get_acccess_info_for_documents
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import Document
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _migrate_vespa_to_acl() -> None:
|
||||
vespa_index = get_default_document_index()
|
||||
if not isinstance(vespa_index, VespaIndex):
|
||||
raise ValueError("This script is only for Vespa indexes")
|
||||
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
# for all documents, set the `access_control_list` field apporpriately
|
||||
# based on the state of Postgres
|
||||
documents = db_session.scalars(select(Document)).all()
|
||||
document_access_info = get_acccess_info_for_documents(
|
||||
db_session=db_session,
|
||||
document_ids=[document.id for document in documents],
|
||||
)
|
||||
vespa_index.update(
|
||||
update_requests=[
|
||||
UpdateRequest(
|
||||
document_ids=[document_id],
|
||||
access=DocumentAccess.build(user_ids, is_public),
|
||||
)
|
||||
for document_id, user_ids, is_public in document_access_info
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
_migrate_vespa_to_acl()