Transition to using access_control_list to manage access in Vespa (#450)

Chris Weaver 2023-09-26 12:26:39 -07:00 committed by GitHub
parent c4e4e88301
commit 8594bac30b
17 changed files with 513 additions and 598 deletions

View File

@ -52,6 +52,9 @@ WORKDIR /app/danswer/datastores/vespa/app_config
RUN zip -r /app/danswer/vespa-app.zip .
WORKDIR /app
# TODO: remove this once all users have migrated
COPY ./danswer/scripts/migrate_vespa_to_acl.py /app/migrate_vespa_to_acl.py
ENV PYTHONPATH /app
# By default this container does nothing, it is used by api server and background which specify their own CMD

View File

@ -0,0 +1,36 @@
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.server.models import ConnectorCredentialPairIdentifier
def _get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None,
db_session: Session,
) -> dict[str, DocumentAccess]:
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
document_ids=document_ids,
cc_pair_to_delete=cc_pair_to_delete,
)
return {
document_id: DocumentAccess.build(user_ids, is_public)
for document_id, user_ids, is_public in document_access_info
}
def get_access_for_documents(
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
db_session: Session | None = None,
) -> dict[str, DocumentAccess]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
return _get_access_for_documents(
document_ids, cc_pair_to_delete, db_session
)
return _get_access_for_documents(document_ids, cc_pair_to_delete, db_session)
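
A quick usage sketch of the helper above (illustrative only: "some-doc-id" is a made-up document ID, and the helper opens its own session here since none is passed in):

# Illustrative usage of get_access_for_documents; "some-doc-id" is hypothetical
from danswer.access.access import get_access_for_documents

access_by_doc = get_access_for_documents(document_ids=["some-doc-id"])
# DocumentAccess.to_acl() flattens the access info into the entries stored in the index,
# e.g. ["<user-uuid>", "PUBLIC"] for a public document
acl = access_by_doc["some-doc-id"].to_acl()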

View File

@ -0,0 +1,20 @@
from dataclasses import dataclass
from uuid import UUID
from danswer.configs.constants import PUBLIC_DOC_PAT
@dataclass(frozen=True)
class DocumentAccess:
user_ids: set[str] # stringified UUIDs
is_public: bool
def to_acl(self) -> list[str]:
return list(self.user_ids) + ([PUBLIC_DOC_PAT] if self.is_public else [])
@classmethod
def build(cls, user_ids: list[UUID | None], is_public: bool) -> "DocumentAccess":
return cls(
user_ids={str(user_id) for user_id in user_ids if user_id},
is_public=is_public,
)
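
For illustration, here is how the new model is built from raw credential info and flattened into ACL entries (the UUID below is made up; None entries come from public credentials and are dropped):

from uuid import UUID
from danswer.access.models import DocumentAccess

access = DocumentAccess.build(
    user_ids=[UUID("11111111-1111-1111-1111-111111111111"), None],  # None is filtered out
    is_public=True,
)
# to_acl() yields the stringified user IDs plus the PUBLIC marker (set order is not guaranteed)
print(access.to_acl())  # ['11111111-1111-1111-1111-111111111111', 'PUBLIC']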

View File

@ -10,11 +10,9 @@ are multiple connector / credential pairs that have indexed it
connector / credential pair from the access list
(6) delete all relevant entries from postgres
"""
from collections import defaultdict
from sqlalchemy.orm import Session
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.access.access import get_access_for_documents
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
@ -22,24 +20,78 @@ from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import delete_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pair
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import (
delete_document_by_connector_credential_pair_for_connector_credential_pair,
)
from danswer.db.document import delete_document_by_connector_credential_pair
from danswer.db.document import delete_documents_complete
from danswer.db.document import (
get_document_by_connector_credential_pairs_indexed_by_multiple,
)
from danswer.db.document import (
get_documents_with_single_connector_credential_pair,
)
from danswer.db.document import get_document_connector_cnts
from danswer.db.document import get_documents_for_connector_credential_pair
from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import Credential
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
logger = setup_logger()
_DELETION_BATCH_SIZE = 1000
def _delete_connector_credential_pair_batch(
document_ids: list[str],
connector_id: int,
credential_id: int,
document_index: DocumentIndex,
) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
# acquire lock for all documents in this batch so that indexing can't
# override the deletion
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
document_connector_cnts = get_document_connector_cnts(
db_session=db_session, document_ids=document_ids
)
# figure out which docs need to be completely deleted
document_ids_to_delete = [
document_id for document_id, cnt in document_connector_cnts if cnt == 1
]
logger.debug(f"Deleting documents: {document_ids_to_delete}")
document_index.delete(doc_ids=document_ids_to_delete)
delete_documents_complete(
db_session=db_session,
document_ids=document_ids_to_delete,
)
# figure out which docs need to be updated
document_ids_to_update = [
document_id for document_id, cnt in document_connector_cnts if cnt > 1
]
access_for_documents = get_access_for_documents(
document_ids=document_ids_to_update,
db_session=db_session,
cc_pair_to_delete=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
update_requests = [
UpdateRequest(
document_ids=[document_id],
access=access,
)
for document_id, access in access_for_documents.items()
]
logger.debug(f"Updating documents: {document_ids_to_update}")
document_index.update(update_requests=update_requests)
delete_document_by_connector_credential_pair(
db_session=db_session,
document_ids=document_ids_to_update,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
db_session.commit()
def _delete_connector_credential_pair(
db_session: Session,
@ -47,147 +99,45 @@ def _delete_connector_credential_pair(
connector_id: int,
credential_id: int,
) -> int:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
num_docs_deleted = 0
while True:
documents = get_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
limit=_DELETION_BATCH_SIZE,
)
if not documents:
break
_delete_connector_credential_pair_batch(
document_ids=[document.id for document in documents],
connector_id=connector_id,
credential_id=credential_id,
document_index=document_index,
)
num_docs_deleted += len(documents)
# clean up everything else
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
)
def _delete_singly_indexed_docs() -> int:
# if a document store entry is only indexed by this connector_credential_pair, delete it
docs_to_delete = get_documents_with_single_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if docs_to_delete:
document_ids = [doc.id for doc in docs_to_delete]
document_index.delete(doc_ids=document_ids)
# removes all `DocumentByConnectorCredentialPair`, and `Document`
# rows from the DB
delete_documents_complete(
db_session=db_session,
document_ids=list(document_ids),
)
return len(docs_to_delete)
num_docs_deleted = _delete_singly_indexed_docs()
logger.info(f"Deleted {num_docs_deleted} documents from document stores")
def _update_multi_indexed_docs(
connector_credential_pair: ConnectorCredentialPair,
) -> None:
# if a document is indexed by multiple connector_credential_pairs, we should
# update its access rather than outright delete it
document_by_connector_credential_pairs_to_update = (
get_document_by_connector_credential_pairs_indexed_by_multiple(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
)
def _get_user(
credential: Credential,
) -> str:
if credential.public_doc or not credential.user:
return PUBLIC_DOC_PAT
return str(credential.user.id)
# find out which documents need to be updated and what their new allowed_users
# should be. This is a bit slow as it requires looping through all the documents
to_be_deleted_user = _get_user(connector_credential_pair.credential)
document_ids_not_needing_update: set[str] = set()
document_id_to_allowed_users: dict[str, list[str]] = defaultdict(list)
for (
document_by_connector_credential_pair
) in document_by_connector_credential_pairs_to_update:
document_id = document_by_connector_credential_pair.id
user = _get_user(document_by_connector_credential_pair.credential)
document_id_to_allowed_users[document_id].append(user)
# if there's another connector / credential pair which has indexed this
# document with the same access, we don't need to update it since removing
# the access from this connector / credential pair won't change anything
if (
document_by_connector_credential_pair.connector_id != connector_id
or document_by_connector_credential_pair.credential_id != credential_id
) and user == to_be_deleted_user:
document_ids_not_needing_update.add(document_id)
# categorize into groups of updates to try and batch them more efficiently
update_groups: dict[tuple[str, ...], list[str]] = {}
for document_id, allowed_users_lst in document_id_to_allowed_users.items():
if document_id in document_ids_not_needing_update:
continue
allowed_users_lst.remove(to_be_deleted_user)
allowed_users = tuple(sorted(set(allowed_users_lst)))
update_groups[allowed_users] = update_groups.get(allowed_users, []) + [
document_id
]
# actually perform the updates in the document store
update_requests = [
UpdateRequest(document_ids=document_ids, allowed_users=list(allowed_users))
for allowed_users, document_ids in update_groups.items()
]
document_index.update(update_requests=update_requests)
# delete the rest of the `document_by_connector_credential_pair` rows for
# this connector / credential pair
delete_document_by_connector_credential_pair_for_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
_update_multi_indexed_docs(cc_pair)
def _cleanup() -> None:
# clean up everything else
# we cannot undo the deletion of the document store entries if something
# goes wrong since they happen outside the postgres world. Best we can do
# is keep everything else around and mark the deletion attempt as failed.
# If it's a transient failure, re-deleting the connector / credential pair should
# fix the weird state.
# TODO: lock anything to do with this connector via transaction isolation
# NOTE: we have to delete index_attempts and deletion_attempts since they both
# have foreign key columns to the connector
delete_index_attempts(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
delete_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.debug("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
_cleanup()
delete_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id,
)
if not connector or not len(connector.credentials):
logger.debug("Found no credentials left for connector, deleting connector")
db_session.delete(connector)
db_session.commit()
logger.info(
"Successfully deleted connector_credential_pair with connector_id:"
@ -199,6 +149,21 @@ def _delete_connector_credential_pair(
def cleanup_connector_credential_pair(connector_id: int, credential_id: int) -> int:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# validate that the connector / credential pair is deletable
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if not cc_pair or not check_deletion_attempt_is_allowed(
connector_credential_pair=cc_pair
):
raise ValueError(
"Cannot run deletion attempt - connector_credential_pair is not deletable. "
"This is likely because there is an ongoing / planned indexing attempt OR the "
"connector is not disabled."
)
try:
return _delete_connector_credential_pair(
db_session=db_session,

View File

@ -1,9 +1,11 @@
import inspect
import json
from dataclasses import dataclass
from dataclasses import fields
from typing import Any
from typing import cast
from danswer.access.models import DocumentAccess
from danswer.configs.constants import BLURB
from danswer.configs.constants import BOOST
from danswer.configs.constants import MATCH_HIGHLIGHTS
@ -14,6 +16,7 @@ from danswer.configs.constants import SOURCE_LINKS
from danswer.connectors.models import Document
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -55,6 +58,34 @@ class IndexChunk(DocAwareChunk):
embeddings: ChunkEmbedding
@dataclass
class DocMetadataAwareIndexChunk(IndexChunk):
"""An `IndexChunk` that contains all necessary metadata to be indexed. This includes
the following:
access: holds all information about which users should have access to the
source document for this chunk.
document_sets: all document sets the source document for this chunk is a part
of. This is used for filtering / personas.
"""
access: "DocumentAccess"
document_sets: set[str]
@classmethod
def from_index_chunk(
cls, index_chunk: IndexChunk, access: "DocumentAccess", document_sets: set[str]
) -> "DocMetadataAwareIndexChunk":
return cls(
**{
field.name: getattr(index_chunk, field.name)
for field in fields(index_chunk)
},
access=access,
document_sets=document_sets,
)
@dataclass
class InferenceChunk(BaseChunk):
document_id: str
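
A minimal sketch of how an embedded chunk is promoted to a metadata-aware chunk before indexing, mirroring the indexing pipeline later in this diff (index_chunk is assumed to be an existing IndexChunk):

from danswer.access.models import DocumentAccess
from danswer.chunking.models import DocMetadataAwareIndexChunk

# index_chunk: an IndexChunk produced by the embedder (assumed to exist)
chunk = DocMetadataAwareIndexChunk.from_index_chunk(
    index_chunk=index_chunk,
    access=DocumentAccess.build(user_ids=[], is_public=True),
    document_sets=set(),  # not populated yet (see the TODO in the indexing pipeline)
)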

View File

@ -11,7 +11,8 @@ SEMANTIC_IDENTIFIER = "semantic_identifier"
SECTION_CONTINUATION = "section_continuation"
EMBEDDINGS = "embeddings"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
ACCESS_CONTROL_LIST = "access_control_list"
DOCUMENT_SETS = "document_sets"
METADATA = "metadata"
MATCH_HIGHLIGHTS = "match_highlights"
# stored in the `metadata` of a chunk. Used to signify that this chunk should
@ -20,6 +21,7 @@ MATCH_HIGHLIGHTS = "match_highlights"
IGNORE_FOR_QA = "ignore_for_qa"
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
PUBLIC_DOC_PAT = "PUBLIC"
PUBLIC_DOCUMENT_SET = "__PUBLIC"
QUOTE = "quote"
BOOST = "boost"
SCORE = "score"

View File

@ -1,15 +1,8 @@
import math
import uuid
from collections.abc import Callable
from copy import deepcopy
from typing import TypeVar
from pydantic import BaseModel
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.connectors.models import IndexAttemptMetadata
DEFAULT_BATCH_SIZE = 30
@ -37,89 +30,3 @@ def get_uuid_from_chunk(
[doc_str, str(chunk.chunk_id), str(mini_chunk_ind)]
)
return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
class CrossConnectorDocumentMetadata(BaseModel):
"""Represents metadata about a single document. This is needed since the
`Document` class represents a document from a single connector, but that same
document may be indexed by multiple connectors."""
allowed_users: list[str]
allowed_user_groups: list[str]
already_in_index: bool
# Takes the chunk identifier and returns the existing metadata about that chunk
CrossConnectorDocumentMetadataFetchCallable = Callable[
[str], CrossConnectorDocumentMetadata | None
]
T = TypeVar("T")
def _add_if_not_exists(obj_list: list[T], item: T) -> list[T]:
if item in obj_list:
return obj_list
return obj_list + [item]
def update_cross_connector_document_metadata_map(
chunk: IndexChunk,
cross_connector_document_metadata_map: dict[str, CrossConnectorDocumentMetadata],
doc_store_cross_connector_document_metadata_fetch_fn: CrossConnectorDocumentMetadataFetchCallable,
index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[dict[str, CrossConnectorDocumentMetadata], bool]:
"""Returns an updated document_id -> CrossConnectorDocumentMetadata map and
if the document's chunks need to be wiped."""
user_str = (
PUBLIC_DOC_PAT
if index_attempt_metadata.user_id is None
else str(index_attempt_metadata.user_id)
)
cross_connector_document_metadata_map = deepcopy(
cross_connector_document_metadata_map
)
first_chunk_uuid = str(get_uuid_from_chunk(chunk))
document = chunk.source_document
if document.id not in cross_connector_document_metadata_map:
document_metadata_in_doc_store = (
doc_store_cross_connector_document_metadata_fetch_fn(first_chunk_uuid)
)
if not document_metadata_in_doc_store:
cross_connector_document_metadata_map[
document.id
] = CrossConnectorDocumentMetadata(
allowed_users=[user_str],
allowed_user_groups=[],
already_in_index=False,
)
# First chunk does not exist so document does not exist, no need for deletion
return cross_connector_document_metadata_map, False
else:
# TODO introduce groups logic here
cross_connector_document_metadata_map[
document.id
] = CrossConnectorDocumentMetadata(
allowed_users=_add_if_not_exists(
document_metadata_in_doc_store.allowed_users, user_str
),
allowed_user_groups=document_metadata_in_doc_store.allowed_user_groups,
already_in_index=True,
)
# First chunk exists, but after the update there may be fewer total chunks now
# Must delete rest of document chunks
return cross_connector_document_metadata_map, True
existing_document_metadata = cross_connector_document_metadata_map[document.id]
cross_connector_document_metadata_map[document.id] = CrossConnectorDocumentMetadata(
allowed_users=_add_if_not_exists(
existing_document_metadata.allowed_users, user_str
),
allowed_user_groups=existing_document_metadata.allowed_user_groups,
already_in_index=existing_document_metadata.already_in_index,
)
# If document is already in the mapping, don't delete again
return cross_connector_document_metadata_map, False

View File

@ -1,14 +1,13 @@
from typing import Type
from uuid import UUID
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import DOCUMENT_INDEX_TYPE
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentIndexType
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
@ -43,11 +42,10 @@ class SplitDocumentIndex(DocumentIndex):
def index(
self,
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
chunks: list[DocMetadataAwareIndexChunk],
) -> set[DocumentInsertionRecord]:
keyword_index_result = self.keyword_index.index(chunks, index_attempt_metadata)
vector_index_result = self.vector_index.index(chunks, index_attempt_metadata)
keyword_index_result = self.keyword_index.index(chunks)
vector_index_result = self.vector_index.index(chunks)
if keyword_index_result != vector_index_result:
logger.error(
f"Inconsistent document indexing:\n"

View File

@ -4,16 +4,17 @@ from typing import Protocol
from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.chunking.chunk import Chunker
from danswer.chunking.chunk import DefaultChunker
from danswer.chunking.models import DocAwareChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.document import prepare_to_modify_documents
from danswer.db.document import upsert_documents_complete
from danswer.db.engine import get_sqlalchemy_engine
from danswer.search.models import Embedder
@ -30,40 +31,25 @@ class IndexingPipelineProtocol(Protocol):
...
def _upsert_insertion_records(
insertion_records: set[DocumentInsertionRecord],
def _upsert_documents(
document_ids: list[str],
index_attempt_metadata: IndexAttemptMetadata,
doc_m_data_lookup: dict[str, tuple[str, str]],
db_session: Session,
) -> None:
with Session(get_sqlalchemy_engine()) as session:
upsert_documents_complete(
db_session=session,
document_metadata_batch=[
DocumentMetadata(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_id=i_r.document_id,
semantic_identifier=doc_m_data_lookup[i_r.document_id][0],
first_link=doc_m_data_lookup[i_r.document_id][1],
)
for i_r in insertion_records
],
)
def _get_net_new_documents(
insertion_records: list[DocumentInsertionRecord],
) -> int:
net_new_documents = 0
seen_documents: set[str] = set()
for insertion_record in insertion_records:
if insertion_record.already_existed:
continue
if insertion_record.document_id not in seen_documents:
net_new_documents += 1
seen_documents.add(insertion_record.document_id)
return net_new_documents
upsert_documents_complete(
db_session=db_session,
document_metadata_batch=[
DocumentMetadata(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_id=document_id,
semantic_identifier=doc_m_data_lookup[document_id][0],
first_link=doc_m_data_lookup[document_id][1],
)
for document_id in document_ids
],
)
def _extract_minimal_document_metadata(doc: Document) -> tuple[str, str]:
@ -82,60 +68,52 @@ def _indexing_pipeline(
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
document_ids = [document.id for document in documents]
document_metadata_lookup = {
doc.id: _extract_minimal_document_metadata(doc) for doc in documents
}
chunks: list[DocAwareChunk] = list(
chain(*[chunker.chunk(document=document) for document in documents])
)
logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
)
chunks_with_embeddings = embedder.embed(chunks=chunks)
with Session(get_sqlalchemy_engine()) as db_session:
# acquires a lock on the documents so that no other process can modify them
prepare_to_modify_documents(db_session=db_session, document_ids=document_ids)
# if there are any empty chunks, remove them. This usually happens due to a
# bug in a connector. Handling here to prevent a bad connector from
# breaking retrieval completely.
final_chunks_with_embeddings: list[IndexChunk] = []
for chunk in chunks_with_embeddings:
if chunk.content:
final_chunks_with_embeddings.append(chunk)
else:
bad_chunk_link = (
chunk.source_document.sections[0].link
if chunk.source_document.sections
else ""
)
logger.error(
f"Found empty chunk, skipping. Chunk ID: '{chunk.chunk_id}', "
f"Document ID: '{chunk.source_document.id}', "
f"Document Link: '{bad_chunk_link}'"
)
# A document will not be spread across different batches, so all the documents with chunks in this set, are fully
# represented by the chunks in this set
insertion_records = document_index.index(
chunks=final_chunks_with_embeddings,
index_attempt_metadata=index_attempt_metadata,
)
# TODO (chris): remove this try/except after issue with null document_id is resolved
try:
_upsert_insertion_records(
insertion_records=insertion_records,
# create records in the source of truth about these documents
_upsert_documents(
document_ids=document_ids,
index_attempt_metadata=index_attempt_metadata,
doc_m_data_lookup=document_metadata_lookup,
db_session=db_session,
)
except Exception as e:
logger.error(
f"Failed to upsert insertion records from vector index for documents: "
f"{[document.to_short_descriptor() for document in documents]}, "
f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks_with_embeddings]}"
f"for insertion records: {insertion_records}"
chunks: list[DocAwareChunk] = list(
chain(*[chunker.chunk(document=document) for document in documents])
)
logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
)
chunks_with_embeddings = embedder.embed(chunks=chunks)
# Attach the latest status from Postgres (source of truth for access) to each
# chunk. This access status will be attached to each chunk in the document index
# TODO: attach document sets to the chunk based on the status of Postgres as well
document_id_to_access_info = get_access_for_documents(
document_ids=document_ids, db_session=db_session
)
access_aware_chunks = [
DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=document_id_to_access_info[chunk.source_document.id],
document_sets=set(),
)
for chunk in chunks_with_embeddings
]
# A document will not be spread across different batches, so all the
# documents with chunks in this set are fully represented by the chunks
# in this set
insertion_records = document_index.index(
chunks=access_aware_chunks,
)
raise e
return len([r for r in insertion_records if r.already_existed is False]), len(
chunks

View File

@ -3,10 +3,10 @@ from dataclasses import dataclass
from typing import Any
from uuid import UUID
from danswer.chunking.models import IndexChunk
from danswer.access.models import DocumentAccess
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
IndexFilter = dict[str, str | list[str] | None]
@ -33,7 +33,8 @@ class UpdateRequest:
document_ids: list[str]
# all other fields will be left alone
allowed_users: list[str] | None = None
access: DocumentAccess | None = None
document_sets: set[str] | None = None
boost: float | None = None
@ -51,7 +52,7 @@ class Verifiable(abc.ABC):
class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
self, chunks: list[DocMetadataAwareIndexChunk]
) -> set[DocumentInsertionRecord]:
"""Indexes document chunks into the Document Index and return the IDs of all the documents indexed"""
raise NotImplementedError
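
For example, a caller that only wants to change a document's access builds an UpdateRequest with just that field set; fields left as None are not touched by the index (the document ID below is hypothetical):

from danswer.access.models import DocumentAccess
from danswer.datastores.interfaces import UpdateRequest

update = UpdateRequest(
    document_ids=["some-doc-id"],  # hypothetical document ID
    access=DocumentAccess.build(user_ids=[], is_public=True),
    # document_sets and boost stay None -> left alone by the document index
)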

View File

@ -1,6 +1,4 @@
import json
from functools import partial
from typing import cast
from qdrant_client import QdrantClient
from qdrant_client.http import models
@ -8,8 +6,7 @@ from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.models.models import UpdateResult
from qdrant_client.models import PointStruct
from danswer.chunking.models import IndexChunk
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
from danswer.configs.constants import CHUNK_ID
@ -20,15 +17,9 @@ from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.qdrant.utils import get_payload_from_record
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logger import setup_logger
@ -36,40 +27,18 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_qdrant_document_cross_connector_metadata(
def _does_document_exist(
doc_chunk_id: str, collection_name: str, q_client: QdrantClient
) -> CrossConnectorDocumentMetadata | None:
) -> bool:
"""Get whether a document is found and the existing whitelists"""
results = q_client.retrieve(
collection_name=collection_name,
ids=[doc_chunk_id],
with_payload=[ALLOWED_USERS, ALLOWED_GROUPS],
)
if len(results) == 0:
return None
payload = get_payload_from_record(results[0])
allowed_users = cast(list[str] | None, payload.get(ALLOWED_USERS))
allowed_groups = cast(list[str] | None, payload.get(ALLOWED_GROUPS))
if allowed_users is None:
allowed_users = []
logger.error(
"Qdrant Index is corrupted, Document found with no user access lists."
f"Assuming no users have access to chunk with ID '{doc_chunk_id}'."
)
if allowed_groups is None:
allowed_groups = []
logger.error(
"Qdrant Index is corrupted, Document found with no groups access lists."
f"Assuming no groups have access to chunk with ID '{doc_chunk_id}'."
)
return False
return CrossConnectorDocumentMetadata(
# if either `allowed_users` or `allowed_groups` are missing from the
# point, then assume that the document has no allowed users.
allowed_users=allowed_users,
allowed_user_groups=allowed_groups,
already_in_index=True,
)
return True
def delete_qdrant_doc_chunks(
@ -92,8 +61,7 @@ def delete_qdrant_doc_chunks(
def index_qdrant_chunks(
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
chunks: list[DocMetadataAwareIndexChunk],
collection: str,
client: QdrantClient | None = None,
batch_upsert: bool = True,
@ -104,29 +72,15 @@ def index_qdrant_chunks(
point_structs: list[PointStruct] = []
insertion_records: set[DocumentInsertionRecord] = set()
# Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
cross_connector_document_metadata_map: dict[
str, CrossConnectorDocumentMetadata
] = {}
# document ids of documents that existed BEFORE this indexing
already_existing_documents: set[str] = set()
for chunk in chunks:
document = chunk.source_document
(
cross_connector_document_metadata_map,
should_delete_doc,
) = update_cross_connector_document_metadata_map(
chunk=chunk,
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
doc_store_cross_connector_document_metadata_fetch_fn=partial(
get_qdrant_document_cross_connector_metadata,
collection_name=collection,
q_client=q_client,
),
index_attempt_metadata=index_attempt_metadata,
)
if should_delete_doc:
# Delete all chunks related to the document if (1) it already exists and
# (2) this is our first time running into it during this indexing attempt
document_exists = _does_document_exist(document.id, collection, q_client)
if document_exists and document.id not in already_existing_documents:
# Processing the first chunk of the doc and the doc exists
delete_qdrant_doc_chunks(document.id, collection, q_client)
already_existing_documents.add(document.id)
@ -155,12 +109,7 @@ def index_qdrant_chunks(
SOURCE_LINKS: chunk.source_links,
SEMANTIC_IDENTIFIER: document.semantic_identifier,
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: cross_connector_document_metadata_map[
document.id
].allowed_users,
ALLOWED_GROUPS: cross_connector_document_metadata_map[
document.id
].allowed_user_groups,
ALLOWED_USERS: json.dumps(chunk.access.to_acl()),
METADATA: json.dumps(document.metadata),
},
vector=embedding,

View File

@ -8,7 +8,7 @@ from qdrant_client.http.models import Filter
from qdrant_client.http.models import MatchAny
from qdrant_client.http.models import MatchValue
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
@ -16,7 +16,6 @@ from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
@ -126,12 +125,10 @@ class QdrantIndex(VectorIndex):
def index(
self,
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
chunks: list[DocMetadataAwareIndexChunk],
) -> set[DocumentInsertionRecord]:
return index_qdrant_chunks(
chunks=chunks,
index_attempt_metadata=index_attempt_metadata,
collection=self.collection,
client=self.client,
)
@ -145,12 +142,15 @@ class QdrantIndex(VectorIndex):
items=update_request.document_ids,
batch_size=_BATCH_SIZE,
):
if update_request.access is None:
continue
chunk_ids = _get_points_from_document_ids(
doc_id_batch, self.collection, self.client
)
self.client.set_payload(
collection_name=self.collection,
payload={ALLOWED_USERS: update_request.allowed_users},
payload={ALLOWED_USERS: update_request.access.to_acl()},
points=chunk_ids,
)

View File

@ -1,17 +1,14 @@
import json
from functools import partial
from typing import Any
from typing import cast
from uuid import UUID
import typesense # type: ignore
from typesense.exceptions import ObjectNotFound # type: ignore
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
from danswer.configs.constants import CHUNK_ID
@ -23,13 +20,8 @@ from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
@ -63,7 +55,6 @@ def create_typesense_collection(
{"name": SEMANTIC_IDENTIFIER, "type": "string"},
{"name": SECTION_CONTINUATION, "type": "bool"},
{"name": ALLOWED_USERS, "type": "string[]"},
{"name": ALLOWED_GROUPS, "type": "string[]"},
{"name": METADATA, "type": "string"},
],
}
@ -81,38 +72,16 @@ def _check_typesense_collection_exist(
return True
def _get_typesense_document_cross_connector_metadata(
def _does_document_exist(
doc_chunk_id: str, collection_name: str, ts_client: typesense.Client
) -> CrossConnectorDocumentMetadata | None:
) -> bool:
"""Returns whether the document already exists and the users/group whitelists"""
try:
document = cast(
dict[str, Any],
ts_client.collections[collection_name].documents[doc_chunk_id].retrieve(),
)
ts_client.collections[collection_name].documents[doc_chunk_id].retrieve()
except ObjectNotFound:
return None
return False
allowed_users = cast(list[str] | None, document.get(ALLOWED_USERS))
allowed_groups = cast(list[str] | None, document.get(ALLOWED_GROUPS))
if allowed_users is None:
allowed_users = []
logger.error(
"Typesense Index is corrupted, Document found with no user access lists."
f"Assuming no users have access to chunk with ID '{doc_chunk_id}'."
)
if allowed_groups is None:
allowed_groups = []
logger.error(
"Typesense Index is corrupted, Document found with no groups access lists."
f"Assuming no groups have access to chunk with ID '{doc_chunk_id}'."
)
return CrossConnectorDocumentMetadata(
allowed_users=allowed_users,
allowed_user_groups=allowed_groups,
already_in_index=True,
)
return True
def _delete_typesense_doc_chunks(
@ -127,8 +96,7 @@ def _delete_typesense_doc_chunks(
def _index_typesense_chunks(
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
chunks: list[DocMetadataAwareIndexChunk],
collection: str,
client: typesense.Client | None = None,
batch_upsert: bool = True,
@ -137,33 +105,20 @@ def _index_typesense_chunks(
insertion_records: set[DocumentInsertionRecord] = set()
new_documents: list[dict[str, Any]] = []
cross_connector_document_metadata_map: dict[
str, CrossConnectorDocumentMetadata
] = {}
# document ids of documents that existed BEFORE this indexing
already_existing_documents: set[str] = set()
for chunk in chunks:
document = chunk.source_document
(
cross_connector_document_metadata_map,
should_delete_doc,
) = update_cross_connector_document_metadata_map(
chunk=chunk,
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
doc_store_cross_connector_document_metadata_fetch_fn=partial(
_get_typesense_document_cross_connector_metadata,
collection_name=collection,
ts_client=ts_client,
),
index_attempt_metadata=index_attempt_metadata,
)
typesense_id = str(get_uuid_from_chunk(chunk))
if should_delete_doc:
# Delete all chunks related to the document if (1) it already exists and
# (2) this is our first time running into it during this indexing attempt
document_exists = _does_document_exist(typesense_id, collection, ts_client)
if document_exists and document.id not in already_existing_documents:
# Processing the first chunk of the doc and the doc exists
_delete_typesense_doc_chunks(document.id, collection, ts_client)
already_existing_documents.add(document.id)
typesense_id = str(get_uuid_from_chunk(chunk))
insertion_records.add(
DocumentInsertionRecord(
document_id=document.id,
@ -181,12 +136,7 @@ def _index_typesense_chunks(
SOURCE_LINKS: json.dumps(chunk.source_links),
SEMANTIC_IDENTIFIER: document.semantic_identifier,
SECTION_CONTINUATION: chunk.section_continuation,
ALLOWED_USERS: cross_connector_document_metadata_map[
document.id
].allowed_users,
ALLOWED_GROUPS: cross_connector_document_metadata_map[
document.id
].allowed_user_groups,
ALLOWED_USERS: json.dumps(chunk.access.to_acl()),
METADATA: json.dumps(document.metadata),
}
)
@ -260,11 +210,10 @@ class TypesenseIndex(KeywordIndex):
create_typesense_collection(collection_name=self.collection)
def index(
self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
self, chunks: list[DocMetadataAwareIndexChunk]
) -> set[DocumentInsertionRecord]:
return _index_typesense_chunks(
chunks=chunks,
index_attempt_metadata=index_attempt_metadata,
collection=self.collection,
client=self.ts_client,
)
@ -277,8 +226,14 @@ class TypesenseIndex(KeywordIndex):
for id_batch in batch_generator(
items=update_request.document_ids, batch_size=_BATCH_SIZE
):
if update_request.access is None:
continue
typesense_updates = [
{DOCUMENT_ID: doc_id, ALLOWED_USERS: update_request.allowed_users}
{
DOCUMENT_ID: doc_id,
ALLOWED_USERS: update_request.access.to_acl(),
}
for doc_id in id_batch
]
self.ts_client.collections[self.collection].documents.import_(

View File

@ -56,11 +56,11 @@ schema danswer_chunk {
distance-metric: angular
}
}
field allowed_users type array<string> {
field access_control_list type weightedset<string> {
indexing: summary | attribute
attribute: fast-search
}
field allowed_groups type array<string> {
field document_sets type weightedset<string> {
indexing: summary | attribute
attribute: fast-search
}
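
Since weightedset is the closest thing Vespa has to a plain set, each ACL entry and document set is fed with a dummy weight of 1. A rough sketch of the two fields' values in a feed document (the UUID is made up):

# Approximate shape of the new weightedset fields when a chunk is fed to Vespa
fields = {
    "access_control_list": {"11111111-1111-1111-1111-111111111111": 1, "PUBLIC": 1},
    "document_sets": {},  # not yet populated by the indexing pipeline
}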

View File

@ -9,7 +9,7 @@ import requests
from requests import HTTPError
from requests import Response
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import DocMetadataAwareIndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
@ -17,14 +17,14 @@ from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
from danswer.configs.app_configs import VESPA_HOST
from danswer.configs.app_configs import VESPA_PORT
from danswer.configs.app_configs import VESPA_TENANT_PORT
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import ACCESS_CONTROL_LIST
from danswer.configs.constants import BLURB
from danswer.configs.constants import BOOST
from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import DOCUMENT_SETS
from danswer.configs.constants import EMBEDDINGS
from danswer.configs.constants import MATCH_HIGHLIGHTS
from danswer.configs.constants import METADATA
@ -35,12 +35,7 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
@ -65,33 +60,20 @@ _BATCH_SIZE = 100 # Specific to Vespa
CONTENT_SUMMARY = "content_summary"
def _get_vespa_document_cross_connector_metadata(
def _does_document_exist(
doc_chunk_id: str,
) -> CrossConnectorDocumentMetadata | None:
) -> bool:
"""Returns whether the document already exists and the users/group whitelists"""
doc_fetch_response = requests.get(f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}")
if doc_fetch_response.status_code == 404:
return None
return False
if doc_fetch_response.status_code != 200:
raise RuntimeError(
f"Unexpected fetch document by ID value from Vespa "
f"with error {doc_fetch_response.status_code}"
)
doc_fields = doc_fetch_response.json()["fields"]
allowed_users = doc_fields.get(ALLOWED_USERS)
allowed_groups = doc_fields.get(ALLOWED_GROUPS)
# Add group permission later, empty list gets saved/loaded as a null
if allowed_users is None:
raise RuntimeError(
"Vespa Index is corrupted, Document found with no access lists."
)
return CrossConnectorDocumentMetadata(
allowed_users=allowed_users or [],
allowed_user_groups=allowed_groups or [],
already_in_index=True,
)
return True
def _get_vespa_chunk_ids_by_document_id(
@ -127,32 +109,23 @@ def _delete_vespa_doc_chunks(document_id: str) -> bool:
def _index_vespa_chunks(
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
chunks: list[DocMetadataAwareIndexChunk],
) -> set[DocumentInsertionRecord]:
json_header = {
"Content-Type": "application/json",
}
insertion_records: set[DocumentInsertionRecord] = set()
cross_connector_document_metadata_map: dict[
str, CrossConnectorDocumentMetadata
] = {}
# document ids of documents that existed BEFORE this indexing
already_existing_documents: set[str] = set()
for chunk in chunks:
document = chunk.source_document
(
cross_connector_document_metadata_map,
should_delete_doc,
) = update_cross_connector_document_metadata_map(
chunk=chunk,
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
doc_store_cross_connector_document_metadata_fetch_fn=_get_vespa_document_cross_connector_metadata,
index_attempt_metadata=index_attempt_metadata,
)
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
if should_delete_doc:
# Processing the first chunk of the doc and the doc exists
# Delete all chunks related to the document if (1) it already exists and
# (2) this is our first time running into it during this indexing attempt
chunk_exists = _does_document_exist(vespa_chunk_id)
if chunk_exists and document.id not in already_existing_documents:
deletion_success = _delete_vespa_doc_chunks(document.id)
if not deletion_success:
raise RuntimeError(
@ -160,9 +133,6 @@ def _index_vespa_chunks(
)
already_existing_documents.add(document.id)
# No minichunk documents in vespa, minichunk vectors are stored in the chunk itself
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
embeddings = chunk.embeddings
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
if embeddings.mini_chunk_embeddings:
@ -183,12 +153,10 @@ def _index_vespa_chunks(
METADATA: json.dumps(document.metadata),
EMBEDDINGS: embeddings_name_vector_map,
BOOST: DEFAULT_BOOST,
ALLOWED_USERS: cross_connector_document_metadata_map[
document.id
].allowed_users,
ALLOWED_GROUPS: cross_connector_document_metadata_map[
document.id
].allowed_user_groups,
# the only `set` vespa has is `weightedset`, so we have to give each
# element an arbitrary weight
ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
}
def _index_chunk(
@ -244,16 +212,13 @@ def _index_vespa_chunks(
def _build_vespa_filters(
user_id: UUID | None, filters: list[IndexFilter] | None
) -> str:
filter_str = ""
# Permissions filter
# TODO group permissioning
# Permissions filters
acl_filter_stmts = [f'{ACCESS_CONTROL_LIST} contains "{PUBLIC_DOC_PAT}"']
if user_id:
filter_str += (
f'({ALLOWED_USERS} contains "{user_id}" or '
f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}") and '
)
else:
filter_str += f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}" and '
acl_filter_stmts.append(f'{ACCESS_CONTROL_LIST} contains "{user_id}"')
filter_str = "(" + " or ".join(acl_filter_stmts) + ") and"
# TODO: have document sets passed in + add document set based filters
# Provided query filters
if filters:
@ -399,31 +364,38 @@ class VespaIndex(DocumentIndex):
def index(
self,
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
chunks: list[DocMetadataAwareIndexChunk],
) -> set[DocumentInsertionRecord]:
return _index_vespa_chunks(
chunks=chunks, index_attempt_metadata=index_attempt_metadata
)
return _index_vespa_chunks(chunks=chunks)
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(
f"Updating {len(update_requests)} documents' allowed_users in Vespa"
)
logger.info(f"Updating {len(update_requests)} documents in Vespa")
json_header = {"Content-Type": "application/json"}
for update_request in update_requests:
if update_request.boost is None and update_request.allowed_users is None:
if (
update_request.boost is None
and update_request.access is None
and update_request.document_sets is None
):
logger.error("Update request received but nothing to update")
continue
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost is not None:
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.allowed_users is not None:
update_dict["fields"][ALLOWED_USERS] = {
"assign": update_request.allowed_users
if update_request.document_sets is not None:
update_dict["fields"][DOCUMENT_SETS] = {
"assign": {
document_set: 1 for document_set in update_request.document_sets
}
}
if update_request.access is not None:
update_dict["fields"][ACCESS_CONTROL_LIST] = {
"assign": {
acl_entry: 1 for acl_entry in update_request.access.to_acl()
}
}
for document_id in update_request.document_ids:
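
For reference, the partial update assembled above for an access change looks roughly like the following per-document body (the UUID is made up; it is sent to Vespa's document API for each document ID in the request):

# Rough shape of the JSON body for one document's ACL update
update_body = {
    "fields": {
        "access_control_list": {
            "assign": {"11111111-1111-1111-1111-111111111111": 1, "PUBLIC": 1}
        }
    }
}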

View File

@ -1,4 +1,6 @@
import time
from collections.abc import Sequence
from uuid import UUID
from sqlalchemy import and_
from sqlalchemy import delete
@ -10,18 +12,18 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import DEFAULT_BOOST
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.feedback import delete_document_feedback_for_documents
from danswer.db.models import Credential
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.utils import model_to_dict
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_documents_with_single_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
@ -29,54 +31,63 @@ def get_documents_with_single_connector_credential_pair(
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
# Filter it down to the documents with only a single connector/credential pair
# Meaning if this connector/credential pair is removed, this doc should be gone
trimmed_doc_ids_stmt = (
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == DbDocument.id,
)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) == 1)
)
stmt = select(DbDocument).where(DbDocument.id.in_(trimmed_doc_ids_stmt))
stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct()
if limit:
stmt = stmt.limit(limit)
return db_session.scalars(stmt).all()
def get_document_by_connector_credential_pairs_indexed_by_multiple(
def get_document_connector_cnts(
db_session: Session,
connector_id: int,
credential_id: int,
) -> Sequence[DocumentByConnectorCredentialPair]:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
document_ids: list[str],
) -> Sequence[tuple[str, int]]:
stmt = (
select(
DocumentByConnectorCredentialPair.id,
func.count(),
)
.where(DocumentByConnectorCredentialPair.id.in_(document_ids))
.group_by(DocumentByConnectorCredentialPair.id)
)
return db_session.execute(stmt).all() # type: ignore
# Filter it down to the documents with more than 1 connector/credential pair
# Meaning if this connector/credential pair is removed, this doc is still accessible
trimmed_doc_ids_stmt = (
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == DbDocument.id,
def get_acccess_info_for_documents(
db_session: Session,
document_ids: list[str],
cc_pair_to_delete: ConnectorCredentialPairIdentifier | None = None,
) -> Sequence[tuple[str, list[UUID | None], bool]]:
"""Gets back all relevant access info for the given documents. This includes
the user_ids for cc pairs that the document is associated with + whether any
of the associated cc pairs are intending to make the document globally public.
If `cc_pair_to_delete` is specified, gets the above access info as if that
pair had been deleted. This is needed since we want to delete from Vespa
before deleting from Postgres, to ensure that the state of Postgres never "loses"
documents that still exist in Vespa.
"""
stmt = select(
DocumentByConnectorCredentialPair.id,
func.array_agg(Credential.user_id).label("user_ids"),
func.bool_or(Credential.public_doc).label("public_doc"),
).where(DocumentByConnectorCredentialPair.id.in_(document_ids))
# pretend that the specified cc pair doesn't exist
if cc_pair_to_delete:
stmt = stmt.where(
and_(
DocumentByConnectorCredentialPair.connector_id
!= cc_pair_to_delete.connector_id,
DocumentByConnectorCredentialPair.credential_id
!= cc_pair_to_delete.credential_id,
)
)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) > 1)
)
stmt = select(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(trimmed_doc_ids_stmt)
)
return db_session.execute(stmt).scalars().all()
stmt = stmt.join(
Credential,
DocumentByConnectorCredentialPair.credential_id == Credential.id,
).group_by(DocumentByConnectorCredentialPair.id)
return db_session.execute(stmt).all() # type: ignore
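
The query above yields one row per document: the document ID, the array-aggregated user IDs of all attached credentials (None for public credentials), and whether any attached credential marks the document public. A sketch of consuming it (assumes an open db_session and a hypothetical document ID):

from danswer.access.models import DocumentAccess
from danswer.db.document import get_acccess_info_for_documents

for document_id, user_ids, is_public in get_acccess_info_for_documents(
    db_session=db_session, document_ids=["some-doc-id"]
):
    access = DocumentAccess.build(user_ids=user_ids, is_public=is_public)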
def upsert_documents(
@ -153,26 +164,24 @@ def upsert_documents_complete(
def delete_document_by_connector_credential_pair(
db_session: Session, document_ids: list[str]
db_session: Session,
document_ids: list[str],
connector_credential_pair_identifier: ConnectorCredentialPairIdentifier
| None = None,
) -> None:
db_session.execute(
delete(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(document_ids)
)
stmt = delete(DocumentByConnectorCredentialPair).where(
DocumentByConnectorCredentialPair.id.in_(document_ids)
)
def delete_document_by_connector_credential_pair_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int
) -> None:
db_session.execute(
delete(DocumentByConnectorCredentialPair).where(
if connector_credential_pair_identifier:
stmt = stmt.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
DocumentByConnectorCredentialPair.connector_id
== connector_credential_pair_identifier.connector_id,
DocumentByConnectorCredentialPair.credential_id
== connector_credential_pair_identifier.credential_id,
)
)
)
db_session.execute(stmt)
def delete_documents(db_session: Session, document_ids: list[str]) -> None:
@ -187,3 +196,48 @@ def delete_documents_complete(db_session: Session, document_ids: list[str]) -> N
)
delete_documents(db_session, document_ids)
db_session.commit()
def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
"""Acquire locks for the specified documents. Ideally this shouldn't be
called with large list of document_ids (an exception could be made if the
length of holding the lock is very short).
Will simply raise an exception if any of the documents are already locked.
This prevents deadlocks (assuming that the caller passes in all required
document IDs in a single call).
"""
stmt = (
select(DbDocument)
.where(DbDocument.id.in_(document_ids))
.with_for_update(nowait=True)
)
# will raise exception if any of the documents are already locked
db_session.execute(stmt)
return True
_NUM_LOCK_ATTEMPTS = 10
_LOCK_RETRY_DELAY = 30
def prepare_to_modify_documents(db_session: Session, document_ids: list[str]) -> None:
"""Try and acquire locks for the documents to prevent other jobs from
modifying them at the same time (e.g. avoid race conditions). This should be
called ahead of any modification to Vespa. Locks should be released by the
caller as soon as updates are complete by finishing the transaction."""
lock_acquired = False
for _ in range(_NUM_LOCK_ATTEMPTS):
try:
lock_acquired = acquire_document_locks(
db_session=db_session, document_ids=document_ids
)
except Exception as e:
logger.info(f"Failed to acquire locks for documents, retrying. Error: {e}")
time.sleep(_LOCK_RETRY_DELAY)
if not lock_acquired:
raise RuntimeError(
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts "
f"for documents: {document_ids}"
)
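
Callers are expected to hold the row locks for the duration of the index write and release them by finishing the transaction, e.g. (a sketch; the document ID is illustrative):

from sqlalchemy.orm import Session
from danswer.db.document import prepare_to_modify_documents
from danswer.db.engine import get_sqlalchemy_engine

with Session(get_sqlalchemy_engine()) as db_session:
    # row locks are held until the transaction ends
    prepare_to_modify_documents(db_session=db_session, document_ids=["some-doc-id"])
    ...  # update the document index / Postgres rows here
    db_session.commit()  # commits the changes and releases the locks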

View File

@ -0,0 +1,44 @@
"""Script which updates Vespa to align with the access described in Postgres.
Should be run wehn a user who has docs already indexed switches over to the new
access control system. This allows them to not have to re-index all documents."""
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.access.models import DocumentAccess
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.vespa.store import VespaIndex
from danswer.db.document import get_acccess_info_for_documents
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import Document
from danswer.utils.logger import setup_logger
logger = setup_logger()
def _migrate_vespa_to_acl() -> None:
vespa_index = get_default_document_index()
if not isinstance(vespa_index, VespaIndex):
raise ValueError("This script is only for Vespa indexes")
with Session(get_sqlalchemy_engine()) as db_session:
# for all documents, set the `access_control_list` field appropriately
# based on the state of Postgres
documents = db_session.scalars(select(Document)).all()
document_access_info = get_acccess_info_for_documents(
db_session=db_session,
document_ids=[document.id for document in documents],
)
vespa_index.update(
update_requests=[
UpdateRequest(
document_ids=[document_id],
access=DocumentAccess.build(user_ids, is_public),
)
for document_id, user_ids, is_public in document_access_info
],
)
if __name__ == "__main__":
_migrate_vespa_to_acl()