diff --git a/backend/Dockerfile b/backend/Dockerfile index 59fdaa756..eb0dfd126 100644 --- a/backend/Dockerfile +++ b/backend/Dockerfile @@ -52,6 +52,9 @@ WORKDIR /app/danswer/datastores/vespa/app_config RUN zip -r /app/danswer/vespa-app.zip . WORKDIR /app +# TODO: remove this once all users have migrated +COPY ./danswer/scripts/migrate_vespa_to_acl.py /app/migrate_vespa_to_acl.py + ENV PYTHONPATH /app # By default this container does nothing, it is used by api server and background which specify their own CMD diff --git a/backend/danswer/access/access.py b/backend/danswer/access/access.py new file mode 100644 index 000000000..e2c32a0bf --- /dev/null +++ b/backend/danswer/access/access.py @@ -0,0 +1,36 @@ +from sqlalchemy.orm import Session + +from danswer.access.models import DocumentAccess +from danswer.db.document import get_acccess_info_for_documents +from danswer.db.engine import get_sqlalchemy_engine +from danswer.server.models import ConnectorCredentialPairIdentifier + + +def _get_access_for_documents( + document_ids: list[str], + cc_pair_to_delete: ConnectorCredentialPairIdentifier | None, + db_session: Session, +) -> dict[str, DocumentAccess]: + document_access_info = get_acccess_info_for_documents( + db_session=db_session, + document_ids=document_ids, + cc_pair_to_delete=cc_pair_to_delete, + ) + return { + document_id: DocumentAccess.build(user_ids, is_public) + for document_id, user_ids, is_public in document_access_info + } + + +def get_access_for_documents( + document_ids: list[str], + cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, + db_session: Session | None = None, +) -> dict[str, DocumentAccess]: + if db_session is None: + with Session(get_sqlalchemy_engine()) as db_session: + return _get_access_for_documents( + document_ids, cc_pair_to_delete, db_session + ) + + return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session) diff --git a/backend/danswer/access/models.py b/backend/danswer/access/models.py new file mode 100644 index 000000000..94a18528c --- /dev/null +++ b/backend/danswer/access/models.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass +from uuid import UUID + +from danswer.configs.constants import PUBLIC_DOC_PAT + + +@dataclass(frozen=True) +class DocumentAccess: + user_ids: set[str] # stringified UUIDs + is_public: bool + + def to_acl(self) -> list[str]: + return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else []) + + @classmethod + def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess": + return cls( + user_ids={str(user_id) for user_id in user_ids if user_id}, + is_public=is_public, + ) diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index 601acbf51..09312fc70 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -10,11 +10,9 @@ are multiple connector / credential pairs that have indexed it connector / credential pair from the access list (6) delete all relevant entries from postgres """ -from collections import defaultdict - from sqlalchemy.orm import Session -from danswer.configs.constants import PUBLIC_DOC_PAT +from danswer.access.access import get_access_for_documents from danswer.datastores.document_index import get_default_document_index from danswer.datastores.interfaces import DocumentIndex from danswer.datastores.interfaces import UpdateRequest @@ -22,24 +20,78 @@ from danswer.db.connector import fetch_connector_by_id from 
danswer.db.connector_credential_pair import delete_connector_credential_pair from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed -from danswer.db.document import ( - delete_document_by_connector_credential_pair_for_connector_credential_pair, -) +from danswer.db.document import delete_document_by_connector_credential_pair from danswer.db.document import delete_documents_complete -from danswer.db.document import ( - get_document_by_connector_credential_pairs_indexed_by_multiple, -) -from danswer.db.document import ( - get_documents_with_single_connector_credential_pair, -) +from danswer.db.document import get_document_connector_cnts +from danswer.db.document import get_documents_for_connector_credential_pair +from danswer.db.document import prepare_to_modify_documents from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import delete_index_attempts -from danswer.db.models import ConnectorCredentialPair -from danswer.db.models import Credential +from danswer.server.models import ConnectorCredentialPairIdentifier from danswer.utils.logger import setup_logger logger = setup_logger() +_DELETION_BATCH_SIZE = 1000 + + +def _delete_connector_credential_pair_batch( + document_ids: list[str], + connector_id: int, + credential_id: int, + document_index: DocumentIndex, +) -> None: + with Session(get_sqlalchemy_engine()) as db_session: + # acquire lock for all documents in this batch so that indexing can't + # override the deletion + prepare_to_modify_documents(db_session=db_session, document_ids=document_ids) + + document_connector_cnts = get_document_connector_cnts( + db_session=db_session, document_ids=document_ids + ) + + # figure out which docs need to be completely deleted + document_ids_to_delete = [ + document_id for document_id, cnt in document_connector_cnts if cnt == 1 + ] + logger.debug(f"Deleting documents: {document_ids_to_delete}") + document_index.delete(doc_ids=document_ids_to_delete) + delete_documents_complete( + db_session=db_session, + document_ids=document_ids_to_delete, + ) + + # figure out which docs need to be updated + document_ids_to_update = [ + document_id for document_id, cnt in document_connector_cnts if cnt > 1 + ] + access_for_documents = get_access_for_documents( + document_ids=document_ids_to_update, + db_session=db_session, + cc_pair_to_delete=ConnectorCredentialPairIdentifier( + connector_id=connector_id, + credential_id=credential_id, + ), + ) + update_requests = [ + UpdateRequest( + document_ids=[document_id], + access=access, + ) + for document_id, access in access_for_documents.items() + ] + logger.debug(f"Updating documents: {document_ids_to_update}") + document_index.update(update_requests=update_requests) + delete_document_by_connector_credential_pair( + db_session=db_session, + document_ids=document_ids_to_update, + connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( + connector_id=connector_id, + credential_id=credential_id, + ), + ) + db_session.commit() + def _delete_connector_credential_pair( db_session: Session, @@ -47,147 +99,45 @@ def _delete_connector_credential_pair( connector_id: int, credential_id: int, ) -> int: - # validate that the connector / credential pair is deletable - cc_pair = get_connector_credential_pair( + num_docs_deleted = 0 + while True: + documents = get_documents_for_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + 
limit=_DELETION_BATCH_SIZE, + ) + if not documents: + break + + _delete_connector_credential_pair_batch( + document_ids=[document.id for document in documents], + connector_id=connector_id, + credential_id=credential_id, + document_index=document_index, + ) + num_docs_deleted += len(documents) + + # cleanup everything else up + delete_index_attempts( db_session=db_session, connector_id=connector_id, credential_id=credential_id, ) - if not cc_pair or not check_deletion_attempt_is_allowed( - connector_credential_pair=cc_pair - ): - raise ValueError( - "Cannot run deletion attempt - connector_credential_pair is not deletable. " - "This is likely because there is an ongoing / planned indexing attempt OR the " - "connector is not disabled." - ) - - def _delete_singly_indexed_docs() -> int: - # if a document store entry is only indexed by this connector_credential_pair, delete it - docs_to_delete = get_documents_with_single_connector_credential_pair( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - - if docs_to_delete: - document_ids = [doc.id for doc in docs_to_delete] - document_index.delete(doc_ids=document_ids) - - # removes all `DocumentByConnectorCredentialPair`, and `Document` - # rows from the DB - delete_documents_complete( - db_session=db_session, - document_ids=list(document_ids), - ) - - return len(docs_to_delete) - - num_docs_deleted = _delete_singly_indexed_docs() - logger.info(f"Deleted {num_docs_deleted} documents from document stores") - - def _update_multi_indexed_docs( - connector_credential_pair: ConnectorCredentialPair, - ) -> None: - # if a document is indexed by multiple connector_credential_pairs, we should - # update its access rather than outright delete it - document_by_connector_credential_pairs_to_update = ( - get_document_by_connector_credential_pairs_indexed_by_multiple( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - ) - - def _get_user( - credential: Credential, - ) -> str: - if credential.public_doc or not credential.user: - return PUBLIC_DOC_PAT - - return str(credential.user.id) - - # find out which documents need to be updated and what their new allowed_users - # should be. 
This is a bit slow as it requires looping through all the documents - to_be_deleted_user = _get_user(connector_credential_pair.credential) - document_ids_not_needing_update: set[str] = set() - document_id_to_allowed_users: dict[str, list[str]] = defaultdict(list) - for ( - document_by_connector_credential_pair - ) in document_by_connector_credential_pairs_to_update: - document_id = document_by_connector_credential_pair.id - user = _get_user(document_by_connector_credential_pair.credential) - document_id_to_allowed_users[document_id].append(user) - - # if there's another connector / credential pair which has indexed this - # document with the same access, we don't need to update it since removing - # the access from this connector / credential pair won't change anything - if ( - document_by_connector_credential_pair.connector_id != connector_id - or document_by_connector_credential_pair.credential_id != credential_id - ) and user == to_be_deleted_user: - document_ids_not_needing_update.add(document_id) - - # categorize into groups of updates to try and batch them more efficiently - update_groups: dict[tuple[str, ...], list[str]] = {} - for document_id, allowed_users_lst in document_id_to_allowed_users.items(): - if document_id in document_ids_not_needing_update: - continue - - allowed_users_lst.remove(to_be_deleted_user) - allowed_users = tuple(sorted(set(allowed_users_lst))) - update_groups[allowed_users] = update_groups.get(allowed_users, []) + [ - document_id - ] - - # actually perform the updates in the document store - update_requests = [ - UpdateRequest(document_ids=document_ids, allowed_users=list(allowed_users)) - for allowed_users, document_ids in update_groups.items() - ] - document_index.update(update_requests=update_requests) - - # delete the rest of the `document_by_connector_credential_pair` rows for - # this connector / credential pair - delete_document_by_connector_credential_pair_for_connector_credential_pair( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - - _update_multi_indexed_docs(cc_pair) - - def _cleanup() -> None: - # cleanup everything else up - # we cannot undo the deletion of the document store entries if something - # goes wrong since they happen outside the postgres world. Best we can do - # is keep everything else around and mark the deletion attempt as failed. - # If it's a transient failure, re-deleting the connector / credential pair should - # fix the weird state. 
- # TODO: lock anything to do with this connector via transaction isolation - # NOTE: we have to delete index_attempts and deletion_attempts since they both - # have foreign key columns to the connector - delete_index_attempts( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - delete_connector_credential_pair( - db_session=db_session, - connector_id=connector_id, - credential_id=credential_id, - ) - # if there are no credentials left, delete the connector - connector = fetch_connector_by_id( - db_session=db_session, - connector_id=connector_id, - ) - if not connector or not len(connector.credentials): - logger.debug("Found no credentials left for connector, deleting connector") - db_session.delete(connector) - db_session.commit() - - _cleanup() + delete_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) + # if there are no credentials left, delete the connector + connector = fetch_connector_by_id( + db_session=db_session, + connector_id=connector_id, + ) + if not connector or not len(connector.credentials): + logger.debug("Found no credentials left for connector, deleting connector") + db_session.delete(connector) + db_session.commit() logger.info( "Successfully deleted connector_credential_pair with connector_id:" @@ -199,6 +149,21 @@ def _delete_connector_credential_pair( def cleanup_connector_credential_pair(connector_id: int, credential_id: int) -> int: engine = get_sqlalchemy_engine() with Session(engine) as db_session: + # validate that the connector / credential pair is deletable + cc_pair = get_connector_credential_pair( + db_session=db_session, + connector_id=connector_id, + credential_id=credential_id, + ) + if not cc_pair or not check_deletion_attempt_is_allowed( + connector_credential_pair=cc_pair + ): + raise ValueError( + "Cannot run deletion attempt - connector_credential_pair is not deletable. " + "This is likely because there is an ongoing / planned indexing attempt OR the " + "connector is not disabled." + ) + try: return _delete_connector_credential_pair( db_session=db_session, diff --git a/backend/danswer/chunking/models.py b/backend/danswer/chunking/models.py index faee6333c..cf8337284 100644 --- a/backend/danswer/chunking/models.py +++ b/backend/danswer/chunking/models.py @@ -1,9 +1,11 @@ import inspect import json from dataclasses import dataclass +from dataclasses import fields from typing import Any from typing import cast +from danswer.access.models import DocumentAccess from danswer.configs.constants import BLURB from danswer.configs.constants import BOOST from danswer.configs.constants import MATCH_HIGHLIGHTS @@ -14,6 +16,7 @@ from danswer.configs.constants import SOURCE_LINKS from danswer.connectors.models import Document from danswer.utils.logger import setup_logger + logger = setup_logger() @@ -55,6 +58,34 @@ class IndexChunk(DocAwareChunk): embeddings: ChunkEmbedding +@dataclass +class DocMetadataAwareIndexChunk(IndexChunk): + """An `IndexChunk` that contains all necessary metadata to be indexed. This includes + the following: + + access: holds all information about which users should have access to the + source document for this chunk. + document_sets: all document sets the source document for this chunk is a part + of. This is used for filtering / personas. 
+ """ + + access: "DocumentAccess" + document_sets: set[str] + + @classmethod + def from_index_chunk( + cls, index_chunk: IndexChunk, access: "DocumentAccess", document_sets: set[str] + ) -> "DocMetadataAwareIndexChunk": + return cls( + **{ + field.name: getattr(index_chunk, field.name) + for field in fields(index_chunk) + }, + access=access, + document_sets=document_sets, + ) + + @dataclass class InferenceChunk(BaseChunk): document_id: str diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index 3dde92e5d..4ad6fbec8 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -11,7 +11,8 @@ SEMANTIC_IDENTIFIER = "semantic_identifier" SECTION_CONTINUATION = "section_continuation" EMBEDDINGS = "embeddings" ALLOWED_USERS = "allowed_users" -ALLOWED_GROUPS = "allowed_groups" +ACCESS_CONTROL_LIST = "access_control_list" +DOCUMENT_SETS = "document_sets" METADATA = "metadata" MATCH_HIGHLIGHTS = "match_highlights" # stored in the `metadata` of a chunk. Used to signify that this chunk should @@ -20,6 +21,7 @@ MATCH_HIGHLIGHTS = "match_highlights" IGNORE_FOR_QA = "ignore_for_qa" GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key" PUBLIC_DOC_PAT = "PUBLIC" +PUBLIC_DOCUMENT_SET = "__PUBLIC" QUOTE = "quote" BOOST = "boost" SCORE = "score" diff --git a/backend/danswer/datastores/datastore_utils.py b/backend/danswer/datastores/datastore_utils.py index d606f2898..a4fbb02f3 100644 --- a/backend/danswer/datastores/datastore_utils.py +++ b/backend/danswer/datastores/datastore_utils.py @@ -1,15 +1,8 @@ import math import uuid -from collections.abc import Callable -from copy import deepcopy -from typing import TypeVar - -from pydantic import BaseModel from danswer.chunking.models import IndexChunk from danswer.chunking.models import InferenceChunk -from danswer.configs.constants import PUBLIC_DOC_PAT -from danswer.connectors.models import IndexAttemptMetadata DEFAULT_BATCH_SIZE = 30 @@ -37,89 +30,3 @@ def get_uuid_from_chunk( [doc_str, str(chunk.chunk_id), str(mini_chunk_ind)] ) return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string) - - -class CrossConnectorDocumentMetadata(BaseModel): - """Represents metadata about a single document. 
This is needed since the - `Document` class represents a document from a single connector, but that same - document may be indexed by multiple connectors.""" - - allowed_users: list[str] - allowed_user_groups: list[str] - already_in_index: bool - - -# Takes the chunk identifier returns the existing metaddata about that chunk -CrossConnectorDocumentMetadataFetchCallable = Callable[ - [str], CrossConnectorDocumentMetadata | None -] - - -T = TypeVar("T") - - -def _add_if_not_exists(obj_list: list[T], item: T) -> list[T]: - if item in obj_list: - return obj_list - return obj_list + [item] - - -def update_cross_connector_document_metadata_map( - chunk: IndexChunk, - cross_connector_document_metadata_map: dict[str, CrossConnectorDocumentMetadata], - doc_store_cross_connector_document_metadata_fetch_fn: CrossConnectorDocumentMetadataFetchCallable, - index_attempt_metadata: IndexAttemptMetadata, -) -> tuple[dict[str, CrossConnectorDocumentMetadata], bool]: - """Returns an updated document_id -> CrossConnectorDocumentMetadata map and - if the document's chunks need to be wiped.""" - user_str = ( - PUBLIC_DOC_PAT - if index_attempt_metadata.user_id is None - else str(index_attempt_metadata.user_id) - ) - - cross_connector_document_metadata_map = deepcopy( - cross_connector_document_metadata_map - ) - first_chunk_uuid = str(get_uuid_from_chunk(chunk)) - document = chunk.source_document - if document.id not in cross_connector_document_metadata_map: - document_metadata_in_doc_store = ( - doc_store_cross_connector_document_metadata_fetch_fn(first_chunk_uuid) - ) - - if not document_metadata_in_doc_store: - cross_connector_document_metadata_map[ - document.id - ] = CrossConnectorDocumentMetadata( - allowed_users=[user_str], - allowed_user_groups=[], - already_in_index=False, - ) - # First chunk does not exist so document does not exist, no need for deletion - return cross_connector_document_metadata_map, False - else: - # TODO introduce groups logic here - cross_connector_document_metadata_map[ - document.id - ] = CrossConnectorDocumentMetadata( - allowed_users=_add_if_not_exists( - document_metadata_in_doc_store.allowed_users, user_str - ), - allowed_user_groups=document_metadata_in_doc_store.allowed_user_groups, - already_in_index=True, - ) - # First chunk exists, but with update, there may be less total chunks now - # Must delete rest of document chunks - return cross_connector_document_metadata_map, True - - existing_document_metadata = cross_connector_document_metadata_map[document.id] - cross_connector_document_metadata_map[document.id] = CrossConnectorDocumentMetadata( - allowed_users=_add_if_not_exists( - existing_document_metadata.allowed_users, user_str - ), - allowed_user_groups=existing_document_metadata.allowed_user_groups, - already_in_index=existing_document_metadata.already_in_index, - ) - # If document is already in the mapping, don't delete again - return cross_connector_document_metadata_map, False diff --git a/backend/danswer/datastores/document_index.py b/backend/danswer/datastores/document_index.py index 74e303a24..7af973c80 100644 --- a/backend/danswer/datastores/document_index.py +++ b/backend/danswer/datastores/document_index.py @@ -1,14 +1,13 @@ from typing import Type from uuid import UUID -from danswer.chunking.models import IndexChunk +from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.app_configs import DOCUMENT_INDEX_TYPE from 
danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentIndexType from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF -from danswer.connectors.models import IndexAttemptMetadata from danswer.datastores.interfaces import DocumentIndex from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import IndexFilter @@ -43,11 +42,10 @@ class SplitDocumentIndex(DocumentIndex): def index( self, - chunks: list[IndexChunk], - index_attempt_metadata: IndexAttemptMetadata, + chunks: list[DocMetadataAwareIndexChunk], ) -> set[DocumentInsertionRecord]: - keyword_index_result = self.keyword_index.index(chunks, index_attempt_metadata) - vector_index_result = self.vector_index.index(chunks, index_attempt_metadata) + keyword_index_result = self.keyword_index.index(chunks) + vector_index_result = self.vector_index.index(chunks) if keyword_index_result != vector_index_result: logger.error( f"Inconsistent document indexing:\n" diff --git a/backend/danswer/datastores/indexing_pipeline.py b/backend/danswer/datastores/indexing_pipeline.py index 6b693d446..c6b0cfbdd 100644 --- a/backend/danswer/datastores/indexing_pipeline.py +++ b/backend/danswer/datastores/indexing_pipeline.py @@ -4,16 +4,17 @@ from typing import Protocol from sqlalchemy.orm import Session +from danswer.access.access import get_access_for_documents from danswer.chunking.chunk import Chunker from danswer.chunking.chunk import DefaultChunker from danswer.chunking.models import DocAwareChunk -from danswer.chunking.models import IndexChunk +from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.connectors.models import Document from danswer.connectors.models import IndexAttemptMetadata from danswer.datastores.document_index import get_default_document_index from danswer.datastores.interfaces import DocumentIndex -from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import DocumentMetadata +from danswer.db.document import prepare_to_modify_documents from danswer.db.document import upsert_documents_complete from danswer.db.engine import get_sqlalchemy_engine from danswer.search.models import Embedder @@ -30,40 +31,25 @@ class IndexingPipelineProtocol(Protocol): ... 
-def _upsert_insertion_records( - insertion_records: set[DocumentInsertionRecord], +def _upsert_documents( + document_ids: list[str], index_attempt_metadata: IndexAttemptMetadata, doc_m_data_lookup: dict[str, tuple[str, str]], + db_session: Session, ) -> None: - with Session(get_sqlalchemy_engine()) as session: - upsert_documents_complete( - db_session=session, - document_metadata_batch=[ - DocumentMetadata( - connector_id=index_attempt_metadata.connector_id, - credential_id=index_attempt_metadata.credential_id, - document_id=i_r.document_id, - semantic_identifier=doc_m_data_lookup[i_r.document_id][0], - first_link=doc_m_data_lookup[i_r.document_id][1], - ) - for i_r in insertion_records - ], - ) - - -def _get_net_new_documents( - insertion_records: list[DocumentInsertionRecord], -) -> int: - net_new_documents = 0 - seen_documents: set[str] = set() - for insertion_record in insertion_records: - if insertion_record.already_existed: - continue - - if insertion_record.document_id not in seen_documents: - net_new_documents += 1 - seen_documents.add(insertion_record.document_id) - return net_new_documents + upsert_documents_complete( + db_session=db_session, + document_metadata_batch=[ + DocumentMetadata( + connector_id=index_attempt_metadata.connector_id, + credential_id=index_attempt_metadata.credential_id, + document_id=document_id, + semantic_identifier=doc_m_data_lookup[document_id][0], + first_link=doc_m_data_lookup[document_id][1], + ) + for document_id in document_ids + ], + ) def _extract_minimal_document_metadata(doc: Document) -> tuple[str, str]: @@ -82,60 +68,52 @@ def _indexing_pipeline( """Takes different pieces of the indexing pipeline and applies it to a batch of documents Note that the documents should already be batched at this point so that it does not inflate the memory requirements""" - + document_ids = [document.id for document in documents] document_metadata_lookup = { doc.id: _extract_minimal_document_metadata(doc) for doc in documents } - chunks: list[DocAwareChunk] = list( - chain(*[chunker.chunk(document=document) for document in documents]) - ) - logger.debug( - f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}" - ) - chunks_with_embeddings = embedder.embed(chunks=chunks) + with Session(get_sqlalchemy_engine()) as db_session: + # acquires a lock on the documents so that no other process can modify them + prepare_to_modify_documents(db_session=db_session, document_ids=document_ids) - # if there are any empty chunks, remove them. This usually happens due to a - # bug in a connector. Handling here to prevent a bad connector from - # breaking retrieval completely. - final_chunks_with_embeddings: list[IndexChunk] = [] - for chunk in chunks_with_embeddings: - if chunk.content: - final_chunks_with_embeddings.append(chunk) - else: - bad_chunk_link = ( - chunk.source_document.sections[0].link - if chunk.source_document.sections - else "" - ) - logger.error( - f"Found empty chunk, skipping. 
Chunk ID: '{chunk.chunk_id}', " - f"Document ID: '{chunk.source_document.id}', " - f"Document Link: '{bad_chunk_link}'" - ) - - # A document will not be spread across different batches, so all the documents with chunks in this set, are fully - # represented by the chunks in this set - insertion_records = document_index.index( - chunks=final_chunks_with_embeddings, - index_attempt_metadata=index_attempt_metadata, - ) - - # TODO (chris): remove this try/except after issue with null document_id is resolved - try: - _upsert_insertion_records( - insertion_records=insertion_records, + # create records in the source of truth about these documents + _upsert_documents( + document_ids=document_ids, index_attempt_metadata=index_attempt_metadata, doc_m_data_lookup=document_metadata_lookup, + db_session=db_session, ) - except Exception as e: - logger.error( - f"Failed to upsert insertion records from vector index for documents: " - f"{[document.to_short_descriptor() for document in documents]}, " - f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks_with_embeddings]}" - f"for insertion records: {insertion_records}" + + chunks: list[DocAwareChunk] = list( + chain(*[chunker.chunk(document=document) for document in documents]) + ) + logger.debug( + f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}" + ) + chunks_with_embeddings = embedder.embed(chunks=chunks) + + # Attach the latest status from Postgres (source of truth for access) to each + # chunk. This access status will be attached to each chunk in the document index + # TODO: attach document sets to the chunk based on the status of Postgres as well + document_id_to_access_info = get_access_for_documents( + document_ids=document_ids, db_session=db_session + ) + access_aware_chunks = [ + DocMetadataAwareIndexChunk.from_index_chunk( + index_chunk=chunk, + access=document_id_to_access_info[chunk.source_document.id], + document_sets=set(), + ) + for chunk in chunks_with_embeddings + ] + + # A document will not be spread across different batches, so all the + # documents with chunks in this set, are fully represented by the chunks + # in this set + insertion_records = document_index.index( + chunks=access_aware_chunks, ) - raise e return len([r for r in insertion_records if r.already_existed is False]), len( chunks diff --git a/backend/danswer/datastores/interfaces.py b/backend/danswer/datastores/interfaces.py index e5452b6de..969cc1934 100644 --- a/backend/danswer/datastores/interfaces.py +++ b/backend/danswer/datastores/interfaces.py @@ -3,10 +3,10 @@ from dataclasses import dataclass from typing import Any from uuid import UUID -from danswer.chunking.models import IndexChunk +from danswer.access.models import DocumentAccess +from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF -from danswer.connectors.models import IndexAttemptMetadata IndexFilter = dict[str, str | list[str] | None] @@ -33,7 +33,8 @@ class UpdateRequest: document_ids: list[str] # all other fields will be left alone - allowed_users: list[str] | None = None + access: DocumentAccess | None = None + document_sets: set[str] | None = None boost: float | None = None @@ -51,7 +52,7 @@ class Verifiable(abc.ABC): class Indexable(abc.ABC): @abc.abstractmethod def index( - self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata + self, chunks: list[DocMetadataAwareIndexChunk] ) -> 
set[DocumentInsertionRecord]: """Indexes document chunks into the Document Index and return the IDs of all the documents indexed""" raise NotImplementedError diff --git a/backend/danswer/datastores/qdrant/indexing.py b/backend/danswer/datastores/qdrant/indexing.py index 994bcf7ed..c2019d96a 100644 --- a/backend/danswer/datastores/qdrant/indexing.py +++ b/backend/danswer/datastores/qdrant/indexing.py @@ -1,6 +1,4 @@ import json -from functools import partial -from typing import cast from qdrant_client import QdrantClient from qdrant_client.http import models @@ -8,8 +6,7 @@ from qdrant_client.http.exceptions import ResponseHandlingException from qdrant_client.http.models.models import UpdateResult from qdrant_client.models import PointStruct -from danswer.chunking.models import IndexChunk -from danswer.configs.constants import ALLOWED_GROUPS +from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.configs.constants import ALLOWED_USERS from danswer.configs.constants import BLURB from danswer.configs.constants import CHUNK_ID @@ -20,15 +17,9 @@ from danswer.configs.constants import SECTION_CONTINUATION from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINKS from danswer.configs.constants import SOURCE_TYPE -from danswer.connectors.models import IndexAttemptMetadata -from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE from danswer.datastores.datastore_utils import get_uuid_from_chunk -from danswer.datastores.datastore_utils import ( - update_cross_connector_document_metadata_map, -) from danswer.datastores.interfaces import DocumentInsertionRecord -from danswer.datastores.qdrant.utils import get_payload_from_record from danswer.utils.clients import get_qdrant_client from danswer.utils.logger import setup_logger @@ -36,40 +27,18 @@ from danswer.utils.logger import setup_logger logger = setup_logger() -def get_qdrant_document_cross_connector_metadata( +def _does_document_exist( doc_chunk_id: str, collection_name: str, q_client: QdrantClient -) -> CrossConnectorDocumentMetadata | None: +) -> bool: """Get whether a document is found and the existing whitelists""" results = q_client.retrieve( collection_name=collection_name, ids=[doc_chunk_id], - with_payload=[ALLOWED_USERS, ALLOWED_GROUPS], ) if len(results) == 0: - return None - payload = get_payload_from_record(results[0]) - allowed_users = cast(list[str] | None, payload.get(ALLOWED_USERS)) - allowed_groups = cast(list[str] | None, payload.get(ALLOWED_GROUPS)) - if allowed_users is None: - allowed_users = [] - logger.error( - "Qdrant Index is corrupted, Document found with no user access lists." - f"Assuming no users have access to chunk with ID '{doc_chunk_id}'." - ) - if allowed_groups is None: - allowed_groups = [] - logger.error( - "Qdrant Index is corrupted, Document found with no groups access lists." - f"Assuming no groups have access to chunk with ID '{doc_chunk_id}'." - ) + return False - return CrossConnectorDocumentMetadata( - # if either `allowed_users` or `allowed_groups` are missing from the - # point, then assume that the document has no allowed users. 
- allowed_users=allowed_users, - allowed_user_groups=allowed_groups, - already_in_index=True, - ) + return True def delete_qdrant_doc_chunks( @@ -92,8 +61,7 @@ def delete_qdrant_doc_chunks( def index_qdrant_chunks( - chunks: list[IndexChunk], - index_attempt_metadata: IndexAttemptMetadata, + chunks: list[DocMetadataAwareIndexChunk], collection: str, client: QdrantClient | None = None, batch_upsert: bool = True, @@ -104,29 +72,15 @@ def index_qdrant_chunks( point_structs: list[PointStruct] = [] insertion_records: set[DocumentInsertionRecord] = set() - # Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings - cross_connector_document_metadata_map: dict[ - str, CrossConnectorDocumentMetadata - ] = {} # document ids of documents that existed BEFORE this indexing already_existing_documents: set[str] = set() for chunk in chunks: document = chunk.source_document - ( - cross_connector_document_metadata_map, - should_delete_doc, - ) = update_cross_connector_document_metadata_map( - chunk=chunk, - cross_connector_document_metadata_map=cross_connector_document_metadata_map, - doc_store_cross_connector_document_metadata_fetch_fn=partial( - get_qdrant_document_cross_connector_metadata, - collection_name=collection, - q_client=q_client, - ), - index_attempt_metadata=index_attempt_metadata, - ) - if should_delete_doc: + # Delete all chunks related to the document if (1) it already exists and + # (2) this is our first time running into it during this indexing attempt + document_exists = _does_document_exist(document.id, collection, q_client) + if document_exists and document.id not in already_existing_documents: # Processing the first chunk of the doc and the doc exists delete_qdrant_doc_chunks(document.id, collection, q_client) already_existing_documents.add(document.id) @@ -155,12 +109,7 @@ def index_qdrant_chunks( SOURCE_LINKS: chunk.source_links, SEMANTIC_IDENTIFIER: document.semantic_identifier, SECTION_CONTINUATION: chunk.section_continuation, - ALLOWED_USERS: cross_connector_document_metadata_map[ - document.id - ].allowed_users, - ALLOWED_GROUPS: cross_connector_document_metadata_map[ - document.id - ].allowed_user_groups, + ALLOWED_USERS: json.dumps(chunk.access.to_acl()), METADATA: json.dumps(document.metadata), }, vector=embedding, diff --git a/backend/danswer/datastores/qdrant/store.py b/backend/danswer/datastores/qdrant/store.py index c45ea951e..14505c685 100644 --- a/backend/danswer/datastores/qdrant/store.py +++ b/backend/danswer/datastores/qdrant/store.py @@ -8,7 +8,7 @@ from qdrant_client.http.models import Filter from qdrant_client.http.models import MatchAny from qdrant_client.http.models import MatchValue -from danswer.chunking.models import IndexChunk +from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.app_configs import NUM_RETURNED_HITS @@ -16,7 +16,6 @@ from danswer.configs.constants import ALLOWED_USERS from danswer.configs.constants import DOCUMENT_ID from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF -from danswer.connectors.models import IndexAttemptMetadata from danswer.datastores.datastore_utils import get_uuid_from_chunk from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import IndexFilter @@ -126,12 +125,10 @@ class QdrantIndex(VectorIndex): def index( self, - chunks: 
list[IndexChunk], - index_attempt_metadata: IndexAttemptMetadata, + chunks: list[DocMetadataAwareIndexChunk], ) -> set[DocumentInsertionRecord]: return index_qdrant_chunks( chunks=chunks, - index_attempt_metadata=index_attempt_metadata, collection=self.collection, client=self.client, ) @@ -145,12 +142,15 @@ class QdrantIndex(VectorIndex): items=update_request.document_ids, batch_size=_BATCH_SIZE, ): + if update_request.access is None: + continue + chunk_ids = _get_points_from_document_ids( doc_id_batch, self.collection, self.client ) self.client.set_payload( collection_name=self.collection, - payload={ALLOWED_USERS: update_request.allowed_users}, + payload={ALLOWED_USERS: update_request.access.to_acl()}, points=chunk_ids, ) diff --git a/backend/danswer/datastores/typesense/store.py b/backend/danswer/datastores/typesense/store.py index abc7be740..23947a7ca 100644 --- a/backend/danswer/datastores/typesense/store.py +++ b/backend/danswer/datastores/typesense/store.py @@ -1,17 +1,14 @@ import json -from functools import partial from typing import Any -from typing import cast from uuid import UUID import typesense # type: ignore from typesense.exceptions import ObjectNotFound # type: ignore -from danswer.chunking.models import IndexChunk +from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.app_configs import NUM_RETURNED_HITS -from danswer.configs.constants import ALLOWED_GROUPS from danswer.configs.constants import ALLOWED_USERS from danswer.configs.constants import BLURB from danswer.configs.constants import CHUNK_ID @@ -23,13 +20,8 @@ from danswer.configs.constants import SECTION_CONTINUATION from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINKS from danswer.configs.constants import SOURCE_TYPE -from danswer.connectors.models import IndexAttemptMetadata -from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE from danswer.datastores.datastore_utils import get_uuid_from_chunk -from danswer.datastores.datastore_utils import ( - update_cross_connector_document_metadata_map, -) from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import IndexFilter from danswer.datastores.interfaces import KeywordIndex @@ -63,7 +55,6 @@ def create_typesense_collection( {"name": SEMANTIC_IDENTIFIER, "type": "string"}, {"name": SECTION_CONTINUATION, "type": "bool"}, {"name": ALLOWED_USERS, "type": "string[]"}, - {"name": ALLOWED_GROUPS, "type": "string[]"}, {"name": METADATA, "type": "string"}, ], } @@ -81,38 +72,16 @@ def _check_typesense_collection_exist( return True -def _get_typesense_document_cross_connector_metadata( +def _does_document_exist( doc_chunk_id: str, collection_name: str, ts_client: typesense.Client -) -> CrossConnectorDocumentMetadata | None: +) -> bool: """Returns whether the document already exists and the users/group whitelists""" try: - document = cast( - dict[str, Any], - ts_client.collections[collection_name].documents[doc_chunk_id].retrieve(), - ) + ts_client.collections[collection_name].documents[doc_chunk_id].retrieve() except ObjectNotFound: - return None + return False - allowed_users = cast(list[str] | None, document.get(ALLOWED_USERS)) - allowed_groups = cast(list[str] | None, document.get(ALLOWED_GROUPS)) - if allowed_users is None: - allowed_users = [] 
- logger.error( - "Typesense Index is corrupted, Document found with no user access lists." - f"Assuming no users have access to chunk with ID '{doc_chunk_id}'." - ) - if allowed_groups is None: - allowed_groups = [] - logger.error( - "Typesense Index is corrupted, Document found with no groups access lists." - f"Assuming no groups have access to chunk with ID '{doc_chunk_id}'." - ) - - return CrossConnectorDocumentMetadata( - allowed_users=allowed_users, - allowed_user_groups=allowed_groups, - already_in_index=True, - ) + return True def _delete_typesense_doc_chunks( @@ -127,8 +96,7 @@ def _delete_typesense_doc_chunks( def _index_typesense_chunks( - chunks: list[IndexChunk], - index_attempt_metadata: IndexAttemptMetadata, + chunks: list[DocMetadataAwareIndexChunk], collection: str, client: typesense.Client | None = None, batch_upsert: bool = True, @@ -137,33 +105,20 @@ def _index_typesense_chunks( insertion_records: set[DocumentInsertionRecord] = set() new_documents: list[dict[str, Any]] = [] - cross_connector_document_metadata_map: dict[ - str, CrossConnectorDocumentMetadata - ] = {} # document ids of documents that existed BEFORE this indexing already_existing_documents: set[str] = set() for chunk in chunks: document = chunk.source_document - ( - cross_connector_document_metadata_map, - should_delete_doc, - ) = update_cross_connector_document_metadata_map( - chunk=chunk, - cross_connector_document_metadata_map=cross_connector_document_metadata_map, - doc_store_cross_connector_document_metadata_fetch_fn=partial( - _get_typesense_document_cross_connector_metadata, - collection_name=collection, - ts_client=ts_client, - ), - index_attempt_metadata=index_attempt_metadata, - ) + typesense_id = str(get_uuid_from_chunk(chunk)) - if should_delete_doc: + # Delete all chunks related to the document if (1) it already exists and + # (2) this is our first time running into it during this indexing attempt + document_exists = _does_document_exist(typesense_id, collection, ts_client) + if document_exists and document.id not in already_existing_documents: # Processing the first chunk of the doc and the doc exists _delete_typesense_doc_chunks(document.id, collection, ts_client) already_existing_documents.add(document.id) - typesense_id = str(get_uuid_from_chunk(chunk)) insertion_records.add( DocumentInsertionRecord( document_id=document.id, @@ -181,12 +136,7 @@ def _index_typesense_chunks( SOURCE_LINKS: json.dumps(chunk.source_links), SEMANTIC_IDENTIFIER: document.semantic_identifier, SECTION_CONTINUATION: chunk.section_continuation, - ALLOWED_USERS: cross_connector_document_metadata_map[ - document.id - ].allowed_users, - ALLOWED_GROUPS: cross_connector_document_metadata_map[ - document.id - ].allowed_user_groups, + ALLOWED_USERS: json.dumps(chunk.access.to_acl()), METADATA: json.dumps(document.metadata), } ) @@ -260,11 +210,10 @@ class TypesenseIndex(KeywordIndex): create_typesense_collection(collection_name=self.collection) def index( - self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata + self, chunks: list[DocMetadataAwareIndexChunk] ) -> set[DocumentInsertionRecord]: return _index_typesense_chunks( chunks=chunks, - index_attempt_metadata=index_attempt_metadata, collection=self.collection, client=self.ts_client, ) @@ -277,8 +226,14 @@ class TypesenseIndex(KeywordIndex): for id_batch in batch_generator( items=update_request.document_ids, batch_size=_BATCH_SIZE ): + if update_request.access is None: + continue + typesense_updates = [ - {DOCUMENT_ID: doc_id, ALLOWED_USERS: 
update_request.allowed_users} + { + DOCUMENT_ID: doc_id, + ALLOWED_USERS: update_request.access.to_acl(), + } for doc_id in id_batch ] self.ts_client.collections[self.collection].documents.import_( diff --git a/backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd b/backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd index 6d960990a..228480632 100644 --- a/backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd @@ -56,11 +56,11 @@ schema danswer_chunk { distance-metric: angular } } - field allowed_users type array { + field access_control_list type weightedset { indexing: summary | attribute attribute: fast-search } - field allowed_groups type array { + field document_sets type weightedset { indexing: summary | attribute attribute: fast-search } diff --git a/backend/danswer/datastores/vespa/store.py b/backend/danswer/datastores/vespa/store.py index f2fb61eaf..0fc67c1d5 100644 --- a/backend/danswer/datastores/vespa/store.py +++ b/backend/danswer/datastores/vespa/store.py @@ -9,7 +9,7 @@ import requests from requests import HTTPError from requests import Response -from danswer.chunking.models import IndexChunk +from danswer.chunking.models import DocMetadataAwareIndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.app_configs import NUM_RETURNED_HITS @@ -17,14 +17,14 @@ from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP from danswer.configs.app_configs import VESPA_HOST from danswer.configs.app_configs import VESPA_PORT from danswer.configs.app_configs import VESPA_TENANT_PORT -from danswer.configs.constants import ALLOWED_GROUPS -from danswer.configs.constants import ALLOWED_USERS +from danswer.configs.constants import ACCESS_CONTROL_LIST from danswer.configs.constants import BLURB from danswer.configs.constants import BOOST from danswer.configs.constants import CHUNK_ID from danswer.configs.constants import CONTENT from danswer.configs.constants import DEFAULT_BOOST from danswer.configs.constants import DOCUMENT_ID +from danswer.configs.constants import DOCUMENT_SETS from danswer.configs.constants import EMBEDDINGS from danswer.configs.constants import MATCH_HIGHLIGHTS from danswer.configs.constants import METADATA @@ -35,12 +35,7 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINKS from danswer.configs.constants import SOURCE_TYPE from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF -from danswer.connectors.models import IndexAttemptMetadata -from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata from danswer.datastores.datastore_utils import get_uuid_from_chunk -from danswer.datastores.datastore_utils import ( - update_cross_connector_document_metadata_map, -) from danswer.datastores.interfaces import DocumentIndex from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import IndexFilter @@ -65,33 +60,20 @@ _BATCH_SIZE = 100 # Specific to Vespa CONTENT_SUMMARY = "content_summary" -def _get_vespa_document_cross_connector_metadata( +def _does_document_exist( doc_chunk_id: str, -) -> CrossConnectorDocumentMetadata | None: +) -> bool: """Returns whether the document already exists and the users/group whitelists""" doc_fetch_response = requests.get(f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}") if doc_fetch_response.status_code == 404: - return 
None + return False if doc_fetch_response.status_code != 200: raise RuntimeError( f"Unexpected fetch document by ID value from Vespa " f"with error {doc_fetch_response.status_code}" ) - - doc_fields = doc_fetch_response.json()["fields"] - allowed_users = doc_fields.get(ALLOWED_USERS) - allowed_groups = doc_fields.get(ALLOWED_GROUPS) - # Add group permission later, empty list gets saved/loaded as a null - if allowed_users is None: - raise RuntimeError( - "Vespa Index is corrupted, Document found with no access lists." - ) - return CrossConnectorDocumentMetadata( - allowed_users=allowed_users or [], - allowed_user_groups=allowed_groups or [], - already_in_index=True, - ) + return True def _get_vespa_chunk_ids_by_document_id( @@ -127,32 +109,23 @@ def _delete_vespa_doc_chunks(document_id: str) -> bool: def _index_vespa_chunks( - chunks: list[IndexChunk], - index_attempt_metadata: IndexAttemptMetadata, + chunks: list[DocMetadataAwareIndexChunk], ) -> set[DocumentInsertionRecord]: json_header = { "Content-Type": "application/json", } insertion_records: set[DocumentInsertionRecord] = set() - cross_connector_document_metadata_map: dict[ - str, CrossConnectorDocumentMetadata - ] = {} # document ids of documents that existed BEFORE this indexing already_existing_documents: set[str] = set() for chunk in chunks: document = chunk.source_document - ( - cross_connector_document_metadata_map, - should_delete_doc, - ) = update_cross_connector_document_metadata_map( - chunk=chunk, - cross_connector_document_metadata_map=cross_connector_document_metadata_map, - doc_store_cross_connector_document_metadata_fetch_fn=_get_vespa_document_cross_connector_metadata, - index_attempt_metadata=index_attempt_metadata, - ) + # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself + vespa_chunk_id = str(get_uuid_from_chunk(chunk)) - if should_delete_doc: - # Processing the first chunk of the doc and the doc exists + # Delete all chunks related to the document if (1) it already exists and + # (2) this is our first time running into it during this indexing attempt + chunk_exists = _does_document_exist(vespa_chunk_id) + if chunk_exists and document.id not in already_existing_documents: deletion_success = _delete_vespa_doc_chunks(document.id) if not deletion_success: raise RuntimeError( @@ -160,9 +133,6 @@ def _index_vespa_chunks( ) already_existing_documents.add(document.id) - # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself - vespa_chunk_id = str(get_uuid_from_chunk(chunk)) - embeddings = chunk.embeddings embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding} if embeddings.mini_chunk_embeddings: @@ -183,12 +153,10 @@ def _index_vespa_chunks( METADATA: json.dumps(document.metadata), EMBEDDINGS: embeddings_name_vector_map, BOOST: DEFAULT_BOOST, - ALLOWED_USERS: cross_connector_document_metadata_map[ - document.id - ].allowed_users, - ALLOWED_GROUPS: cross_connector_document_metadata_map[ - document.id - ].allowed_user_groups, + # the only `set` vespa has is `weightedset`, so we have to give each + # element an arbitrary weight + ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()}, + DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets}, } def _index_chunk( @@ -244,16 +212,13 @@ def _index_vespa_chunks( def _build_vespa_filters( user_id: UUID | None, filters: list[IndexFilter] | None ) -> str: - filter_str = "" - # Permissions filter - # TODO group permissioning + # Permissions filters + acl_filter_stmts = 
[f'{ACCESS_CONTROL_LIST} contains "{PUBLIC_DOC_PAT}"'] if user_id: - filter_str += ( - f'({ALLOWED_USERS} contains "{user_id}" or ' - f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}") and ' - ) - else: - filter_str += f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}" and ' + acl_filter_stmts.append(f'{ACCESS_CONTROL_LIST} contains "{user_id}"') + filter_str = "(" + " or ".join(acl_filter_stmts) + ") and" + + # TODO: have document sets passed in + add document set based filters # Provided query filters if filters: @@ -399,31 +364,38 @@ class VespaIndex(DocumentIndex): def index( self, - chunks: list[IndexChunk], - index_attempt_metadata: IndexAttemptMetadata, + chunks: list[DocMetadataAwareIndexChunk], ) -> set[DocumentInsertionRecord]: - return _index_vespa_chunks( - chunks=chunks, index_attempt_metadata=index_attempt_metadata - ) + return _index_vespa_chunks(chunks=chunks) def update(self, update_requests: list[UpdateRequest]) -> None: - logger.info( - f"Updating {len(update_requests)} documents' allowed_users in Vespa" - ) + logger.info(f"Updating {len(update_requests)} documents in Vespa") json_header = {"Content-Type": "application/json"} for update_request in update_requests: - if update_request.boost is None and update_request.allowed_users is None: + if ( + update_request.boost is None + and update_request.access is None + and update_request.document_sets is None + ): logger.error("Update request received but nothing to update") continue update_dict: dict[str, dict] = {"fields": {}} if update_request.boost is not None: update_dict["fields"][BOOST] = {"assign": update_request.boost} - if update_request.allowed_users is not None: - update_dict["fields"][ALLOWED_USERS] = { - "assign": update_request.allowed_users + if update_request.document_sets is not None: + update_dict["fields"][DOCUMENT_SETS] = { + "assign": { + document_set: 1 for document_set in update_request.document_sets + } + } + if update_request.access is not None: + update_dict["fields"][ACCESS_CONTROL_LIST] = { + "assign": { + acl_entry: 1 for acl_entry in update_request.access.to_acl() + } } for document_id in update_request.document_ids: diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 17592c40d..c88e9f945 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -1,4 +1,6 @@ +import time from collections.abc import Sequence +from uuid import UUID from sqlalchemy import and_ from sqlalchemy import delete @@ -10,18 +12,18 @@ from sqlalchemy.orm import Session from danswer.configs.constants import DEFAULT_BOOST from danswer.datastores.interfaces import DocumentMetadata from danswer.db.feedback import delete_document_feedback_for_documents +from danswer.db.models import Credential from danswer.db.models import Document as DbDocument from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.utils import model_to_dict +from danswer.server.models import ConnectorCredentialPairIdentifier from danswer.utils.logger import setup_logger logger = setup_logger() -def get_documents_with_single_connector_credential_pair( - db_session: Session, - connector_id: int, - credential_id: int, +def get_documents_for_connector_credential_pair( + db_session: Session, connector_id: int, credential_id: int, limit: int | None = None ) -> Sequence[DbDocument]: initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( and_( @@ -29,54 +31,63 @@ def get_documents_with_single_connector_credential_pair( DocumentByConnectorCredentialPair.credential_id == 
credential_id, ) ) - - # Filter it down to the documents with only a single connector/credential pair - # Meaning if this connector/credential pair is removed, this doc should be gone - trimmed_doc_ids_stmt = ( - select(DbDocument.id) - .join( - DocumentByConnectorCredentialPair, - DocumentByConnectorCredentialPair.id == DbDocument.id, - ) - .where(DbDocument.id.in_(initial_doc_ids_stmt)) - .group_by(DbDocument.id) - .having(func.count(DocumentByConnectorCredentialPair.id) == 1) - ) - - stmt = select(DbDocument).where(DbDocument.id.in_(trimmed_doc_ids_stmt)) + stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct() + if limit: + stmt = stmt.limit(limit) return db_session.scalars(stmt).all() -def get_document_by_connector_credential_pairs_indexed_by_multiple( +def get_document_connector_cnts( db_session: Session, - connector_id: int, - credential_id: int, -) -> Sequence[DocumentByConnectorCredentialPair]: - initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( - and_( - DocumentByConnectorCredentialPair.connector_id == connector_id, - DocumentByConnectorCredentialPair.credential_id == credential_id, + document_ids: list[str], +) -> Sequence[tuple[str, int]]: + stmt = ( + select( + DocumentByConnectorCredentialPair.id, + func.count(), ) + .where(DocumentByConnectorCredentialPair.id.in_(document_ids)) + .group_by(DocumentByConnectorCredentialPair.id) ) + return db_session.execute(stmt).all() # type: ignore - # Filter it down to the documents with more than 1 connector/credential pair - # Meaning if this connector/credential pair is removed, this doc is still accessible - trimmed_doc_ids_stmt = ( - select(DbDocument.id) - .join( - DocumentByConnectorCredentialPair, - DocumentByConnectorCredentialPair.id == DbDocument.id, + +def get_acccess_info_for_documents( + db_session: Session, + document_ids: list[str], + cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None, +) -> Sequence[tuple[str, list[UUID | None], bool]]: + """Gets back all relevant access info for the given documents. This includes + the user_ids for cc pairs that the document is associated with + whether any + of the associated cc pairs are intending to make the document globally public. + + If `cc_pair_to_delete` is specified, gets the above access info as if that + pair had been deleted. This is needed since we want to delete from the Vespa + before deleting from Postgres to ensure that the state of Postgres never "loses" + documents that still exist in Vespa. 
+ """ + stmt = select( + DocumentByConnectorCredentialPair.id, + func.array_agg(Credential.user_id).label("user_ids"), + func.bool_or(Credential.public_doc).label("public_doc"), + ).where(DocumentByConnectorCredentialPair.id.in_(document_ids)) + + # pretend that the specified cc pair doesn't exist + if cc_pair_to_delete: + stmt = stmt.where( + and_( + DocumentByConnectorCredentialPair.connector_id + != cc_pair_to_delete.connector_id, + DocumentByConnectorCredentialPair.credential_id + != cc_pair_to_delete.credential_id, + ) ) - .where(DbDocument.id.in_(initial_doc_ids_stmt)) - .group_by(DbDocument.id) - .having(func.count(DocumentByConnectorCredentialPair.id) > 1) - ) - stmt = select(DocumentByConnectorCredentialPair).where( - DocumentByConnectorCredentialPair.id.in_(trimmed_doc_ids_stmt) - ) - - return db_session.execute(stmt).scalars().all() + stmt = stmt.join( + Credential, + DocumentByConnectorCredentialPair.credential_id == Credential.id, + ).group_by(DocumentByConnectorCredentialPair.id) + return db_session.execute(stmt).all() # type: ignore def upsert_documents( @@ -153,26 +164,24 @@ def upsert_documents_complete( def delete_document_by_connector_credential_pair( - db_session: Session, document_ids: list[str] + db_session: Session, + document_ids: list[str], + connector_credential_pair_identifier: ConnectorCredentialPairIdentifier + | None = None, ) -> None: - db_session.execute( - delete(DocumentByConnectorCredentialPair).where( - DocumentByConnectorCredentialPair.id.in_(document_ids) - ) + stmt = delete(DocumentByConnectorCredentialPair).where( + DocumentByConnectorCredentialPair.id.in_(document_ids) ) - - -def delete_document_by_connector_credential_pair_for_connector_credential_pair( - db_session: Session, connector_id: int, credential_id: int -) -> None: - db_session.execute( - delete(DocumentByConnectorCredentialPair).where( + if connector_credential_pair_identifier: + stmt = stmt.where( and_( - DocumentByConnectorCredentialPair.connector_id == connector_id, - DocumentByConnectorCredentialPair.credential_id == credential_id, + DocumentByConnectorCredentialPair.connector_id + == connector_credential_pair_identifier.connector_id, + DocumentByConnectorCredentialPair.credential_id + == connector_credential_pair_identifier.credential_id, ) ) - ) + db_session.execute(stmt) def delete_documents(db_session: Session, document_ids: list[str]) -> None: @@ -187,3 +196,48 @@ def delete_documents_complete(db_session: Session, document_ids: list[str]) -> N ) delete_documents(db_session, document_ids) db_session.commit() + + +def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool: + """Acquire locks for the specified documents. Ideally this shouldn't be + called with large list of document_ids (an exception could be made if the + length of holding the lock is very short). + + Will simply raise an exception if any of the documents are already locked. + This prevents deadlocks (assuming that the caller passes in all required + document IDs in a single call). + """ + stmt = ( + select(DbDocument) + .where(DbDocument.id.in_(document_ids)) + .with_for_update(nowait=True) + ) + # will raise exception if any of the documents are already locked + db_session.execute(stmt) + return True + + +_NUM_LOCK_ATTEMPTS = 10 +_LOCK_RETRY_DELAY = 30 + + +def prepare_to_modify_documents(db_session: Session, document_ids: list[str]) -> None: + """Try and acquire locks for the documents to prevent other jobs from + modifying them at the same time (e.g. avoid race conditions). 
This should be + called ahead of any modification to Vespa. Locks should be released by the + caller as soon as updates are complete by finishing the transaction.""" + lock_acquired = False + for _ in range(_NUM_LOCK_ATTEMPTS): + try: + lock_acquired = acquire_document_locks( + db_session=db_session, document_ids=document_ids + ) + except Exception as e: + logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}") + time.sleep(_LOCK_RETRY_DELAY) + + if not lock_acquired: + raise RuntimeError( + f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts " + f"for documents: {document_ids}" + ) diff --git a/backend/scripts/migrate_vespa_to_acl.py b/backend/scripts/migrate_vespa_to_acl.py new file mode 100644 index 000000000..4348eeac3 --- /dev/null +++ b/backend/scripts/migrate_vespa_to_acl.py @@ -0,0 +1,44 @@ +"""Script which updates Vespa to align with the access described in Postgres. +Should be run when a user who has docs already indexed switches over to the new +access control system. This allows them to avoid re-indexing all documents.""" +from sqlalchemy import select +from sqlalchemy.orm import Session + +from danswer.access.models import DocumentAccess +from danswer.datastores.document_index import get_default_document_index +from danswer.datastores.interfaces import UpdateRequest +from danswer.datastores.vespa.store import VespaIndex +from danswer.db.document import get_acccess_info_for_documents +from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.models import Document +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def _migrate_vespa_to_acl() -> None: + vespa_index = get_default_document_index() + if not isinstance(vespa_index, VespaIndex): + raise ValueError("This script is only for Vespa indexes") + + with Session(get_sqlalchemy_engine()) as db_session: + # for all documents, set the `access_control_list` field appropriately + # based on the state of Postgres + documents = db_session.scalars(select(Document)).all() + document_access_info = get_acccess_info_for_documents( + db_session=db_session, + document_ids=[document.id for document in documents], + ) + vespa_index.update( + update_requests=[ + UpdateRequest( + document_ids=[document_id], + access=DocumentAccess.build(user_ids, is_public), + ) + for document_id, user_ids, is_public in document_access_info + ], + ) + + +if __name__ == "__main__": + _migrate_vespa_to_acl()
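For reference, a minimal sketch of how the access pieces introduced in this diff fit together, assuming only the modules added above (danswer.access.models, danswer.configs.constants); the UUID below is a made-up placeholder, not a value from the codebase:

from uuid import UUID

from danswer.access.models import DocumentAccess
from danswer.configs.constants import PUBLIC_DOC_PAT

# get_acccess_info_for_documents() returns (document_id, user_ids, is_public)
# rows; DocumentAccess.build() drops None user_ids (credentials with no user).
access = DocumentAccess.build(
    user_ids=[UUID("11111111-1111-1111-1111-111111111111"), None],
    is_public=True,
)

acl = access.to_acl()  # ['11111111-1111-1111-1111-111111111111', 'PUBLIC']
assert PUBLIC_DOC_PAT in acl

# At index time the ACL is written as a Vespa weightedset with arbitrary
# weights, as _index_vespa_chunks does for chunk.access:
vespa_acl_field = {acl_entry: 1 for acl_entry in acl}

# At query time, _build_vespa_filters keeps a chunk only if its
# access_control_list contains PUBLIC_DOC_PAT or the querying user's id.

The weight of 1 attached to each ACL entry carries no meaning: as the comment in the Vespa indexing code above notes, weightedset is simply the closest field type Vespa offers to a plain set.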