From 8159fdcdce3c5c8160aa1b41d1000409be3ae3e7 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Thu, 24 Aug 2023 08:46:28 -0700 Subject: [PATCH] Add Vespa and rework Document Indices (#317) --- backend/.gitignore | 3 +- ...abb57f3b49_restructure_document_indices.py | 39 ++ .../danswer/background/connector_deletion.py | 57 +-- backend/danswer/chunking/chunk.py | 22 +- backend/danswer/chunking/models.py | 15 +- backend/danswer/configs/app_configs.py | 22 +- backend/danswer/configs/constants.py | 7 + .../connectors/google_drive/connector.py | 2 +- backend/danswer/datastores/datastore_utils.py | 5 +- backend/danswer/datastores/document_index.py | 111 +++++ .../danswer/datastores/indexing_pipeline.py | 102 ++--- backend/danswer/datastores/interfaces.py | 94 ++-- backend/danswer/datastores/qdrant/indexing.py | 42 +- backend/danswer/datastores/qdrant/store.py | 129 +++--- backend/danswer/datastores/qdrant/utils.py | 33 ++ backend/danswer/datastores/typesense/store.py | 89 ++-- backend/danswer/datastores/vespa/__init__.py | 0 .../vespa/app_config/schemas/danswer_chunk.sd | 89 ++++ .../datastores/vespa/app_config/services.xml | 19 + backend/danswer/datastores/vespa/store.py | 402 ++++++++++++++++++ backend/danswer/db/document.py | 57 +-- backend/danswer/db/models.py | 32 +- backend/danswer/direct_qa/answer_question.py | 7 +- backend/danswer/listeners/slack_listener.py | 4 +- backend/danswer/main.py | 22 +- backend/danswer/search/keyword_search.py | 6 +- backend/danswer/search/models.py | 4 +- backend/danswer/search/semantic_search.py | 22 +- backend/danswer/server/search_backend.py | 19 +- .../utils.py => utils/batching.py} | 0 backend/scripts/list_typesense_docs.py | 6 +- backend/scripts/save_load_state.py | 27 +- backend/scripts/simulate_frontend.py | 4 +- 33 files changed, 1059 insertions(+), 433 deletions(-) create mode 100644 backend/alembic/versions/8aabb57f3b49_restructure_document_indices.py create mode 100644 backend/danswer/datastores/document_index.py create mode 100644 backend/danswer/datastores/vespa/__init__.py create mode 100644 backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd create mode 100644 backend/danswer/datastores/vespa/app_config/services.xml create mode 100644 backend/danswer/datastores/vespa/store.py rename backend/danswer/{connectors/utils.py => utils/batching.py} (100%) diff --git a/backend/.gitignore b/backend/.gitignore index 8fcbadd18a9..b75d7f4d758 100644 --- a/backend/.gitignore +++ b/backend/.gitignore @@ -6,4 +6,5 @@ api_keys.py *ipynb qdrant-data/ typesense-data/ -.env \ No newline at end of file +.env +vespa-app.zip diff --git a/backend/alembic/versions/8aabb57f3b49_restructure_document_indices.py b/backend/alembic/versions/8aabb57f3b49_restructure_document_indices.py new file mode 100644 index 00000000000..3ff454ba7e9 --- /dev/null +++ b/backend/alembic/versions/8aabb57f3b49_restructure_document_indices.py @@ -0,0 +1,39 @@ +"""Restructure Document Indices + +Revision ID: 8aabb57f3b49 +Revises: 5e84129c8be3 +Create Date: 2023-08-18 21:15:57.629515 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. 
+revision = "8aabb57f3b49" +down_revision = "5e84129c8be3" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.drop_table("chunk") + op.execute("DROP TYPE IF EXISTS documentstoretype") + + +def downgrade() -> None: + op.create_table( + "chunk", + sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.Column( + "document_store_type", + postgresql.ENUM("VECTOR", "KEYWORD", name="documentstoretype"), + autoincrement=False, + nullable=False, + ), + sa.Column("document_id", sa.VARCHAR(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint( + ["document_id"], ["document.id"], name="chunk_document_id_fkey" + ), + sa.PrimaryKeyConstraint("id", "document_store_type", name="chunk_pkey"), + ) diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py index d39d6ba3ab0..f3b27c2c191 100644 --- a/backend/danswer/background/connector_deletion.py +++ b/backend/danswer/background/connector_deletion.py @@ -17,12 +17,9 @@ from datetime import datetime from sqlalchemy.orm import Session from danswer.configs.constants import PUBLIC_DOC_PAT -from danswer.datastores.interfaces import KeywordIndex -from danswer.datastores.interfaces import StoreType +from danswer.datastores.document_index import get_default_document_index +from danswer.datastores.interfaces import DocumentIndex from danswer.datastores.interfaces import UpdateRequest -from danswer.datastores.interfaces import VectorIndex -from danswer.datastores.qdrant.store import QdrantIndex -from danswer.datastores.typesense.store import TypesenseIndex from danswer.db.connector import fetch_connector_by_id from danswer.db.connector_credential_pair import delete_connector_credential_pair from danswer.db.connector_credential_pair import get_connector_credential_pair @@ -31,13 +28,12 @@ from danswer.db.deletion_attempt import delete_deletion_attempts from danswer.db.deletion_attempt import get_deletion_attempts from danswer.db.document import delete_document_by_connector_credential_pair from danswer.db.document import delete_documents_complete -from danswer.db.document import get_chunk_ids_for_document_ids -from danswer.db.document import ( - get_chunks_with_single_connector_credential_pair, -) from danswer.db.document import ( get_document_by_connector_credential_pairs_indexed_by_multiple, ) +from danswer.db.document import ( + get_documents_with_single_connector_credential_pair, +) from danswer.db.engine import get_sqlalchemy_engine from danswer.db.index_attempt import delete_index_attempts from danswer.db.models import Credential @@ -50,8 +46,7 @@ logger = setup_logger() def _delete_connector_credential_pair( db_session: Session, - vector_index: VectorIndex, - keyword_index: KeywordIndex, + document_index: DocumentIndex, deletion_attempt: DeletionAttempt, ) -> int: connector_id = deletion_attempt.connector_id @@ -59,33 +54,24 @@ def _delete_connector_credential_pair( def _delete_singly_indexed_docs() -> int: # if a document store entry is only indexed by this connector_credential_pair, delete it - num_docs_deleted = 0 - chunks_to_delete = get_chunks_with_single_connector_credential_pair( + docs_to_delete = get_documents_with_single_connector_credential_pair( db_session=db_session, connector_id=connector_id, credential_id=credential_id, ) - if chunks_to_delete: - document_ids: set[str] = set() - vector_chunk_ids_to_delete: list[str] = [] - keyword_chunk_ids_to_delete: list[str] = [] - for chunk in chunks_to_delete: - document_ids.add(chunk.document_id) - if 
chunk.document_store_type == StoreType.KEYWORD: - keyword_chunk_ids_to_delete.append(chunk.id) - else: - vector_chunk_ids_to_delete.append(chunk.id) - vector_index.delete(ids=vector_chunk_ids_to_delete) - keyword_index.delete(ids=keyword_chunk_ids_to_delete) - # removes all `Chunk`, `DocumentByConnectorCredentialPair`, and `Document` + if docs_to_delete: + document_ids = [doc.id for doc in docs_to_delete] + document_index.delete(doc_ids=document_ids) + + # removes all `DocumentByConnectorCredentialPair`, and `Document` # rows from the DB delete_documents_complete( db_session=db_session, document_ids=list(document_ids), ) - num_docs_deleted += len(document_ids) - return num_docs_deleted + + return len(docs_to_delete) num_docs_deleted = _delete_singly_indexed_docs() logger.info(f"Deleted {num_docs_deleted} documents from document stores") @@ -144,18 +130,10 @@ def _delete_connector_credential_pair( # actually perform the updates in the document store update_requests = [ - UpdateRequest( - ids=list( - get_chunk_ids_for_document_ids( - db_session=db_session, document_ids=document_ids - ) - ), - allowed_users=list(allowed_users), - ) + UpdateRequest(document_ids=document_ids, allowed_users=list(allowed_users)) for allowed_users, document_ids in update_groups.items() ] - vector_index.update(update_requests=update_requests) - keyword_index.update(update_requests=update_requests) + document_index.update(update_requests=update_requests) # delete the `document_by_connector_credential_pair` rows for the connector / credential pair delete_document_by_connector_credential_pair( @@ -246,8 +224,7 @@ def _run_deletion(db_session: Session) -> None: try: num_docs_deleted = _delete_connector_credential_pair( db_session=db_session, - vector_index=QdrantIndex(), - keyword_index=TypesenseIndex(), + document_index=get_default_document_index(), deletion_attempt=deletion_attempt, ) except Exception as e: diff --git a/backend/danswer/chunking/chunk.py b/backend/danswer/chunking/chunk.py index faa3196b59f..d9e8fea34c6 100644 --- a/backend/danswer/chunking/chunk.py +++ b/backend/danswer/chunking/chunk.py @@ -2,7 +2,7 @@ import abc import re from collections.abc import Callable -from danswer.chunking.models import IndexChunk +from danswer.chunking.models import DocAwareChunk from danswer.configs.app_configs import BLURB_LENGTH from danswer.configs.app_configs import CHUNK_MAX_CHAR_OVERLAP from danswer.configs.app_configs import CHUNK_SIZE @@ -12,7 +12,7 @@ from danswer.connectors.models import Section from danswer.utils.text_processing import shared_precompare_cleanup SECTION_SEPARATOR = "\n\n" -ChunkFunc = Callable[[Document], list[IndexChunk]] +ChunkFunc = Callable[[Document], list[DocAwareChunk]] def extract_blurb(text: str, blurb_len: int) -> str: @@ -51,7 +51,7 @@ def chunk_large_section( word_overlap: int = CHUNK_WORD_OVERLAP, blurb_len: int = BLURB_LENGTH, chunk_overflow_max: int = CHUNK_MAX_CHAR_OVERLAP, -) -> list[IndexChunk]: +) -> list[DocAwareChunk]: """Split large sections into multiple chunks with the final chunk having as much previous overlap as possible. 
Backtracks word_overlap words, delimited by whitespace, backtrack up to chunk_overflow_max characters max When chunk is finished in forward direction, attempt to finish the word, but only up to chunk_overflow_max @@ -129,7 +129,7 @@ def chunk_large_section( chunks = [] for chunk_ind, chunk_str in enumerate(chunk_strs): chunks.append( - IndexChunk( + DocAwareChunk( source_document=document, chunk_id=start_chunk_id + chunk_ind, blurb=blurb, @@ -146,8 +146,8 @@ def chunk_document( chunk_size: int = CHUNK_SIZE, subsection_overlap: int = CHUNK_WORD_OVERLAP, blurb_len: int = BLURB_LENGTH, -) -> list[IndexChunk]: - chunks: list[IndexChunk] = [] +) -> list[DocAwareChunk]: + chunks: list[DocAwareChunk] = [] link_offsets: dict[int, str] = {} chunk_text = "" for section in document.sections: @@ -160,7 +160,7 @@ def chunk_document( if section_length > chunk_size: if chunk_text: chunks.append( - IndexChunk( + DocAwareChunk( source_document=document, chunk_id=len(chunks), blurb=extract_blurb(chunk_text, blurb_len), @@ -191,7 +191,7 @@ def chunk_document( link_offsets[curr_offset_len] = section.link else: chunks.append( - IndexChunk( + DocAwareChunk( source_document=document, chunk_id=len(chunks), blurb=extract_blurb(chunk_text, blurb_len), @@ -206,7 +206,7 @@ def chunk_document( # Once we hit the end, if we're still in the process of building a chunk, add what we have if chunk_text: chunks.append( - IndexChunk( + DocAwareChunk( source_document=document, chunk_id=len(chunks), blurb=extract_blurb(chunk_text, blurb_len), @@ -220,10 +220,10 @@ def chunk_document( class Chunker: @abc.abstractmethod - def chunk(self, document: Document) -> list[IndexChunk]: + def chunk(self, document: Document) -> list[DocAwareChunk]: raise NotImplementedError class DefaultChunker(Chunker): - def chunk(self, document: Document) -> list[IndexChunk]: + def chunk(self, document: Document) -> list[DocAwareChunk]: return chunk_document(document) diff --git a/backend/danswer/chunking/models.py b/backend/danswer/chunking/models.py index 55c5d2e31ca..0ead152b779 100644 --- a/backend/danswer/chunking/models.py +++ b/backend/danswer/chunking/models.py @@ -14,6 +14,15 @@ from danswer.utils.logger import setup_logger logger = setup_logger() +Embedding = list[float] + + +@dataclass +class ChunkEmbedding: + full_embedding: Embedding + mini_chunk_embeddings: list[Embedding] + + @dataclass class BaseChunk: chunk_id: int @@ -26,7 +35,7 @@ class BaseChunk: @dataclass -class IndexChunk(BaseChunk): +class DocAwareChunk(BaseChunk): # During indexing flow, we have access to a complete "Document" # During inference we only have access to the document id and do not reconstruct the Document source_document: Document @@ -39,8 +48,8 @@ class IndexChunk(BaseChunk): @dataclass -class EmbeddedIndexChunk(IndexChunk): - embeddings: list[list[float]] +class IndexChunk(DocAwareChunk): + embeddings: ChunkEmbedding @dataclass diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 4aa0a95196a..fdb3d11fffb 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -1,5 +1,7 @@ import os +from danswer.configs.constants import DocumentIndexType + ##### # App Configs ##### @@ -62,20 +64,28 @@ MASK_CREDENTIAL_PREFIX = ( ##### # DB Configs ##### +DOCUMENT_INDEX_NAME = "danswer_index" # Shared by vector/keyword indices +# Vespa is now the default document index store for both keyword and vector +DOCUMENT_INDEX_TYPE = os.environ.get( + "DOCUMENT_INDEX_TYPE", 
DocumentIndexType.SPLIT.value +) +VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost" +VESPA_PORT = os.environ.get("VESPA_PORT") or "8081" +VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071" +# The default below is for dockerized deployment +VESPA_DEPLOYMENT_ZIP = ( + os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip" +) # Qdrant is Semantic Search Vector DB # Url / Key are used to connect to a remote Qdrant instance QDRANT_URL = os.environ.get("QDRANT_URL", "") QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "") # Host / Port are used for connecting to local Qdrant instance -QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost") +QDRANT_HOST = os.environ.get("QDRANT_HOST") or "localhost" QDRANT_PORT = 6333 -QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_DEFAULT_COLLECTION", "danswer_index") # Typesense is the Keyword Search Engine -TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST", "localhost") +TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST") or "localhost" TYPESENSE_PORT = 8108 -TYPESENSE_DEFAULT_COLLECTION = os.environ.get( - "TYPESENSE_DEFAULT_COLLECTION", "danswer_index" -) TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "") # Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder) INDEX_BATCH_SIZE = 16 diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py index e7a2a5e5a76..467e7e7a178 100644 --- a/backend/danswer/configs/constants.py +++ b/backend/danswer/configs/constants.py @@ -9,6 +9,7 @@ SOURCE_LINKS = "source_links" SOURCE_LINK = "link" SEMANTIC_IDENTIFIER = "semantic_identifier" SECTION_CONTINUATION = "section_continuation" +EMBEDDINGS = "embeddings" ALLOWED_USERS = "allowed_users" ALLOWED_GROUPS = "allowed_groups" METADATA = "metadata" @@ -16,6 +17,7 @@ GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key" HTML_SEPARATOR = "\n" PUBLIC_DOC_PAT = "PUBLIC" QUOTE = "quote" +BOOST = "boost" class DocumentSource(str, Enum): @@ -35,6 +37,11 @@ class DocumentSource(str, Enum): LINEAR = "linear" +class DocumentIndexType(str, Enum): + COMBINED = "combined" # Vespa + SPLIT = "split" # Typesense + Qdrant + + class DanswerGenAIModel(str, Enum): """This represents the internal Danswer GenAI model which determines the class that is used to generate responses to the user query. 
Different models/services require different internal diff --git a/backend/danswer/connectors/google_drive/connector.py b/backend/danswer/connectors/google_drive/connector.py index 88b4681b4d3..1d462531cd2 100644 --- a/backend/danswer/connectors/google_drive/connector.py +++ b/backend/danswer/connectors/google_drive/connector.py @@ -24,7 +24,7 @@ from danswer.connectors.interfaces import PollConnector from danswer.connectors.interfaces import SecondsSinceUnixEpoch from danswer.connectors.models import Document from danswer.connectors.models import Section -from danswer.connectors.utils import batch_generator +from danswer.utils.batching import batch_generator from danswer.utils.logger import setup_logger logger = setup_logger() diff --git a/backend/danswer/datastores/datastore_utils.py b/backend/danswer/datastores/datastore_utils.py index 9359562abeb..9b8a4aab252 100644 --- a/backend/danswer/datastores/datastore_utils.py +++ b/backend/danswer/datastores/datastore_utils.py @@ -5,7 +5,6 @@ from typing import TypeVar from pydantic import BaseModel -from danswer.chunking.models import EmbeddedIndexChunk from danswer.chunking.models import IndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.constants import PUBLIC_DOC_PAT @@ -16,7 +15,7 @@ DEFAULT_BATCH_SIZE = 30 def get_uuid_from_chunk( - chunk: IndexChunk | EmbeddedIndexChunk | InferenceChunk, mini_chunk_ind: int = 0 + chunk: IndexChunk | InferenceChunk, mini_chunk_ind: int = 0 ) -> uuid.UUID: doc_str = ( chunk.document_id @@ -58,7 +57,7 @@ def _add_if_not_exists(l: list[T], item: T) -> list[T]: def update_cross_connector_document_metadata_map( - chunk: IndexChunk | EmbeddedIndexChunk, + chunk: IndexChunk, cross_connector_document_metadata_map: dict[str, CrossConnectorDocumentMetadata], doc_store_cross_connector_document_metadata_fetch_fn: CrossConnectorDocumentMetadataFetchCallable, index_attempt_metadata: IndexAttemptMetadata, diff --git a/backend/danswer/datastores/document_index.py b/backend/danswer/datastores/document_index.py new file mode 100644 index 00000000000..74e303a2498 --- /dev/null +++ b/backend/danswer/datastores/document_index.py @@ -0,0 +1,111 @@ +from typing import Type +from uuid import UUID + +from danswer.chunking.models import IndexChunk +from danswer.chunking.models import InferenceChunk +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME +from danswer.configs.app_configs import DOCUMENT_INDEX_TYPE +from danswer.configs.app_configs import NUM_RETURNED_HITS +from danswer.configs.constants import DocumentIndexType +from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF +from danswer.connectors.models import IndexAttemptMetadata +from danswer.datastores.interfaces import DocumentIndex +from danswer.datastores.interfaces import DocumentInsertionRecord +from danswer.datastores.interfaces import IndexFilter +from danswer.datastores.interfaces import KeywordIndex +from danswer.datastores.interfaces import UpdateRequest +from danswer.datastores.interfaces import VectorIndex +from danswer.datastores.qdrant.store import QdrantIndex +from danswer.datastores.typesense.store import TypesenseIndex +from danswer.datastores.vespa.store import VespaIndex +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +class SplitDocumentIndex(DocumentIndex): + """A Document index that uses 2 separate storages + one for keyword index and one for vector index""" + + def __init__( + self, + index_name: str | None = DOCUMENT_INDEX_NAME, + keyword_index_cls: Type[KeywordIndex] = 
TypesenseIndex, + vector_index_cls: Type[VectorIndex] = QdrantIndex, + ) -> None: + index_name = index_name or DOCUMENT_INDEX_NAME + self.keyword_index = keyword_index_cls(index_name) + self.vector_index = vector_index_cls(index_name) + + def ensure_indices_exist(self) -> None: + self.keyword_index.ensure_indices_exist() + self.vector_index.ensure_indices_exist() + + def index( + self, + chunks: list[IndexChunk], + index_attempt_metadata: IndexAttemptMetadata, + ) -> set[DocumentInsertionRecord]: + keyword_index_result = self.keyword_index.index(chunks, index_attempt_metadata) + vector_index_result = self.vector_index.index(chunks, index_attempt_metadata) + if keyword_index_result != vector_index_result: + logger.error( + f"Inconsistent document indexing:\n" + f"Keyword: {keyword_index_result}\n" + f"Vector: {vector_index_result}" + ) + return keyword_index_result.union(vector_index_result) + + def update(self, update_requests: list[UpdateRequest]) -> None: + self.keyword_index.update(update_requests) + self.vector_index.update(update_requests) + + def delete(self, doc_ids: list[str]) -> None: + self.keyword_index.delete(doc_ids) + self.vector_index.delete(doc_ids) + + def keyword_retrieval( + self, + query: str, + user_id: UUID | None, + filters: list[IndexFilter] | None, + num_to_retrieve: int = NUM_RETURNED_HITS, + ) -> list[InferenceChunk]: + return self.keyword_index.keyword_retrieval( + query, user_id, filters, num_to_retrieve + ) + + def semantic_retrieval( + self, + query: str, + user_id: UUID | None, + filters: list[IndexFilter] | None, + num_to_retrieve: int = NUM_RETURNED_HITS, + distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF, + ) -> list[InferenceChunk]: + return self.vector_index.semantic_retrieval( + query, user_id, filters, num_to_retrieve, distance_cutoff + ) + + def hybrid_retrieval( + self, + query: str, + user_id: UUID | None, + filters: list[IndexFilter] | None, + num_to_retrieve: int, + ) -> list[InferenceChunk]: + """Currently results from vector and keyword indices are not combined post retrieval. 
+ This may change in the future but for now, the default behavior is to use semantic search + which should be a more flexible/powerful search option""" + return self.semantic_retrieval(query, user_id, filters, num_to_retrieve) + + +def get_default_document_index( + collection: str | None = DOCUMENT_INDEX_NAME, index_type: str = DOCUMENT_INDEX_TYPE +) -> DocumentIndex: + if index_type == DocumentIndexType.COMBINED.value: + return VespaIndex() # Can't specify collection without modifying the deployment + elif index_type == DocumentIndexType.SPLIT.value: + return SplitDocumentIndex(index_name=collection) + else: + raise ValueError("Invalid document index type selected") diff --git a/backend/danswer/datastores/indexing_pipeline.py b/backend/danswer/datastores/indexing_pipeline.py index d8a2faf5aa1..11d00f7142a 100644 --- a/backend/danswer/datastores/indexing_pipeline.py +++ b/backend/danswer/datastores/indexing_pipeline.py @@ -6,15 +6,13 @@ from sqlalchemy.orm import Session from danswer.chunking.chunk import Chunker from danswer.chunking.chunk import DefaultChunker +from danswer.chunking.models import DocAwareChunk from danswer.connectors.models import Document from danswer.connectors.models import IndexAttemptMetadata -from danswer.datastores.interfaces import ChunkInsertionRecord -from danswer.datastores.interfaces import ChunkMetadata -from danswer.datastores.interfaces import KeywordIndex -from danswer.datastores.interfaces import StoreType -from danswer.datastores.interfaces import VectorIndex -from danswer.datastores.qdrant.store import QdrantIndex -from danswer.datastores.typesense.store import TypesenseIndex +from danswer.datastores.document_index import get_default_document_index +from danswer.datastores.interfaces import DocumentIndex +from danswer.datastores.interfaces import DocumentInsertionRecord +from danswer.datastores.interfaces import DocumentMetadata from danswer.db.document import upsert_documents_complete from danswer.db.engine import get_sqlalchemy_engine from danswer.search.models import Embedder @@ -32,20 +30,17 @@ class IndexingPipelineProtocol(Protocol): def _upsert_insertion_records( - insertion_records: list[ChunkInsertionRecord], + insertion_records: set[DocumentInsertionRecord], index_attempt_metadata: IndexAttemptMetadata, - document_store_type: StoreType, ) -> None: with Session(get_sqlalchemy_engine()) as session: upsert_documents_complete( db_session=session, document_metadata_batch=[ - ChunkMetadata( + DocumentMetadata( connector_id=index_attempt_metadata.connector_id, credential_id=index_attempt_metadata.credential_id, document_id=insertion_record.document_id, - store_id=insertion_record.store_id, - document_store_type=document_store_type, ) for insertion_record in insertion_records ], @@ -53,7 +48,7 @@ def _upsert_insertion_records( def _get_net_new_documents( - insertion_records: list[ChunkInsertionRecord], + insertion_records: list[DocumentInsertionRecord], ) -> int: net_new_documents = 0 seen_documents: set[str] = set() @@ -71,106 +66,61 @@ def _indexing_pipeline( *, chunker: Chunker, embedder: Embedder, - vector_index: VectorIndex, - keyword_index: KeywordIndex, + document_index: DocumentIndex, documents: list[Document], index_attempt_metadata: IndexAttemptMetadata, ) -> tuple[int, int]: """Takes different pieces of the indexing pipeline and applies it to a batch of documents Note that the documents should already be batched at this point so that it does not inflate the memory requirements""" - # Chunk the documents into reasonably-sized chunks so they 
can fit into the - # context-sizes of our embedding models - chunks = list(chain(*[chunker.chunk(document=document) for document in documents])) + chunks: list[DocAwareChunk] = list( + chain(*[chunker.chunk(document=document) for document in documents]) + ) logger.debug( f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}" ) - - # Insert the chunks into our Keyword document store + store records of these - # documents / chunks into our database - # TODO keyword indexing can occur at same time as embedding - keyword_store_insertion_records = keyword_index.index( - chunks=chunks, index_attempt_metadata=index_attempt_metadata - ) - logger.debug(f"Keyword store insertion records: {keyword_store_insertion_records}") - # TODO (chris): remove this try/except after issue with null document_id is resolved - try: - _upsert_insertion_records( - insertion_records=keyword_store_insertion_records, - index_attempt_metadata=index_attempt_metadata, - document_store_type=StoreType.KEYWORD, - ) - except Exception as e: - logger.error( - f"Failed to upsert insertion records from keyword index for documents: " - f"{[document.to_short_descriptor() for document in documents]}, " - f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks]}," - f"for insertion records: {keyword_store_insertion_records}" - ) - raise e - net_doc_count_keyword = _get_net_new_documents( - insertion_records=keyword_store_insertion_records - ) - - # Embed the chunks and then insert them into our Vector document store - # + store records of these documents / chunks into our database chunks_with_embeddings = embedder.embed(chunks=chunks) - vector_store_insertion_records = vector_index.index( + + # A document will not be spread across different batches, so all the documents with chunks in this set, are fully + # represented by the chunks in this set + insertion_records = document_index.index( chunks=chunks_with_embeddings, index_attempt_metadata=index_attempt_metadata ) - logger.debug(f"Vector store insertion records: {keyword_store_insertion_records}") + # TODO (chris): remove this try/except after issue with null document_id is resolved try: _upsert_insertion_records( - insertion_records=vector_store_insertion_records, + insertion_records=insertion_records, index_attempt_metadata=index_attempt_metadata, - document_store_type=StoreType.VECTOR, ) except Exception as e: logger.error( f"Failed to upsert insertion records from vector index for documents: " f"{[document.to_short_descriptor() for document in documents]}, " f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks_with_embeddings]}" - f"for insertion records: {vector_store_insertion_records}" + f"for insertion records: {insertion_records}" ) raise e - net_doc_count_vector = _get_net_new_documents( - insertion_records=vector_store_insertion_records - ) - if net_doc_count_vector != net_doc_count_keyword: - logger.warning("Document count change from keyword/vector indices don't align") - net_new_docs = max(net_doc_count_keyword, net_doc_count_vector) - logger.info(f"Indexed {net_new_docs} new documents") - return net_new_docs, len(chunks) + return len(insertion_records), len(chunks) def build_indexing_pipeline( *, chunker: Chunker | None = None, embedder: Embedder | None = None, - vector_index: VectorIndex | None = None, - keyword_index: KeywordIndex | None = None, + document_index: DocumentIndex | None = None, ) -> IndexingPipelineProtocol: - """Builds a pipeline which takes in a list (batch) of docs and indexes them. 
+ """Builds a pipline which takes in a list (batch) of docs and indexes them.""" + chunker = chunker or DefaultChunker() - Default uses _ chunker, _ embedder, and qdrant for the datastore""" - if chunker is None: - chunker = DefaultChunker() + embedder = embedder or DefaultEmbedder() - if embedder is None: - embedder = DefaultEmbedder() - - if vector_index is None: - vector_index = QdrantIndex() - - if keyword_index is None: - keyword_index = TypesenseIndex() + document_index = document_index or get_default_document_index() return partial( _indexing_pipeline, chunker=chunker, embedder=embedder, - vector_index=vector_index, - keyword_index=keyword_index, + document_index=document_index, ) diff --git a/backend/danswer/datastores/interfaces.py b/backend/danswer/datastores/interfaces.py index 382a6278476..0ccddc81a84 100644 --- a/backend/danswer/datastores/interfaces.py +++ b/backend/danswer/datastores/interfaces.py @@ -1,62 +1,63 @@ import abc from dataclasses import dataclass -from enum import Enum from typing import Any -from typing import Generic -from typing import TypeVar from uuid import UUID -from danswer.chunking.models import BaseChunk -from danswer.chunking.models import EmbeddedIndexChunk from danswer.chunking.models import IndexChunk from danswer.chunking.models import InferenceChunk +from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF from danswer.connectors.models import IndexAttemptMetadata - -T = TypeVar("T", bound=BaseChunk) IndexFilter = dict[str, str | list[str] | None] -class StoreType(str, Enum): - VECTOR = "vector" - KEYWORD = "keyword" - - -@dataclass -class ChunkInsertionRecord: +@dataclass(frozen=True) +class DocumentInsertionRecord: document_id: str - store_id: str already_existed: bool @dataclass -class ChunkMetadata: +class DocumentMetadata: connector_id: int credential_id: int document_id: str - store_id: str - document_store_type: StoreType @dataclass class UpdateRequest: - ids: list[str] + """For all document_ids, update the allowed_users and the boost to the new value + ignore if None""" + + document_ids: list[str] # all other fields will be left alone - allowed_users: list[str] + allowed_users: list[str] | None = None + boost: int | None = None -class Indexable(Generic[T], abc.ABC): +class Verifiable(abc.ABC): + @abc.abstractmethod + def __init__(self, index_name: str, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.index_name = index_name + + @abc.abstractmethod + def ensure_indices_exist(self) -> None: + raise NotImplementedError + + +class Indexable(abc.ABC): @abc.abstractmethod def index( - self, chunks: list[T], index_attempt_metadata: IndexAttemptMetadata - ) -> list[ChunkInsertionRecord]: + self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata + ) -> set[DocumentInsertionRecord]: """Indexes document chunks into the Document Index and return the IDs of all the documents indexed""" raise NotImplementedError class Deletable(abc.ABC): @abc.abstractmethod - def delete(self, ids: list[str]) -> None: + def delete(self, doc_ids: list[str]) -> None: """Removes the specified documents from the Index""" raise NotImplementedError @@ -64,11 +65,23 @@ class Deletable(abc.ABC): class Updatable(abc.ABC): @abc.abstractmethod def update(self, update_requests: list[UpdateRequest]) -> None: - """Updates metadata for the specified documents in the Index""" + """Updates metadata for the specified documents sets in the Index""" raise NotImplementedError -class VectorIndex(Indexable[EmbeddedIndexChunk], 
Deletable, Updatable, abc.ABC): +class KeywordCapable(abc.ABC): + @abc.abstractmethod + def keyword_retrieval( + self, + query: str, + user_id: UUID | None, + filters: list[IndexFilter] | None, + num_to_retrieve: int, + ) -> list[InferenceChunk]: + raise NotImplementedError + + +class VectorCapable(abc.ABC): @abc.abstractmethod def semantic_retrieval( self, @@ -76,13 +89,14 @@ class VectorIndex(Indexable[EmbeddedIndexChunk], Deletable, Updatable, abc.ABC): user_id: UUID | None, filters: list[IndexFilter] | None, num_to_retrieve: int, + distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF, ) -> list[InferenceChunk]: raise NotImplementedError -class KeywordIndex(Indexable[IndexChunk], Deletable, Updatable, abc.ABC): +class HybridCapable(abc.ABC): @abc.abstractmethod - def keyword_search( + def hybrid_retrieval( self, query: str, user_id: UUID | None, @@ -90,3 +104,27 @@ class KeywordIndex(Indexable[IndexChunk], Deletable, Updatable, abc.ABC): num_to_retrieve: int, ) -> list[InferenceChunk]: raise NotImplementedError + + +class BaseIndex(Verifiable, Indexable, Updatable, Deletable, abc.ABC): + """All basic functionalities excluding a specific retrieval approach + Indices need to be able to + - Check that the index exists with a schema definition + - Can index documents + - Can delete documents + - Can update document metadata (such as access permissions and document specific boost) + """ + + pass + + +class KeywordIndex(KeywordCapable, BaseIndex, abc.ABC): + pass + + +class VectorIndex(VectorCapable, BaseIndex, abc.ABC): + pass + + +class DocumentIndex(KeywordCapable, VectorCapable, HybridCapable, BaseIndex, abc.ABC): + pass diff --git a/backend/danswer/datastores/qdrant/indexing.py b/backend/danswer/datastores/qdrant/indexing.py index 2f4d45e4784..994bcf7ede0 100644 --- a/backend/danswer/datastores/qdrant/indexing.py +++ b/backend/danswer/datastores/qdrant/indexing.py @@ -1,18 +1,14 @@ import json from functools import partial from typing import cast -from uuid import UUID from qdrant_client import QdrantClient from qdrant_client.http import models from qdrant_client.http.exceptions import ResponseHandlingException from qdrant_client.http.models.models import UpdateResult -from qdrant_client.models import CollectionsResponse -from qdrant_client.models import Distance from qdrant_client.models import PointStruct -from qdrant_client.models import VectorParams -from danswer.chunking.models import EmbeddedIndexChunk +from danswer.chunking.models import IndexChunk from danswer.configs.constants import ALLOWED_GROUPS from danswer.configs.constants import ALLOWED_USERS from danswer.configs.constants import BLURB @@ -24,7 +20,6 @@ from danswer.configs.constants import SECTION_CONTINUATION from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINKS from danswer.configs.constants import SOURCE_TYPE -from danswer.configs.model_configs import DOC_EMBEDDING_DIM from danswer.connectors.models import IndexAttemptMetadata from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE @@ -32,7 +27,7 @@ from danswer.datastores.datastore_utils import get_uuid_from_chunk from danswer.datastores.datastore_utils import ( update_cross_connector_document_metadata_map, ) -from danswer.datastores.interfaces import ChunkInsertionRecord +from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.qdrant.utils import get_payload_from_record from 
danswer.utils.clients import get_qdrant_client from danswer.utils.logger import setup_logger @@ -41,22 +36,6 @@ from danswer.utils.logger import setup_logger logger = setup_logger() -def list_qdrant_collections() -> CollectionsResponse: - return get_qdrant_client().get_collections() - - -def create_qdrant_collection( - collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM -) -> None: - logger.info(f"Attempting to create collection {collection_name}") - result = get_qdrant_client().create_collection( - collection_name=collection_name, - vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE), - ) - if not result: - raise RuntimeError("Could not create Qdrant collection") - - def get_qdrant_document_cross_connector_metadata( doc_chunk_id: str, collection_name: str, q_client: QdrantClient ) -> CrossConnectorDocumentMetadata | None: @@ -113,18 +92,18 @@ def delete_qdrant_doc_chunks( def index_qdrant_chunks( - chunks: list[EmbeddedIndexChunk], + chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata, collection: str, client: QdrantClient | None = None, batch_upsert: bool = True, -) -> list[ChunkInsertionRecord]: +) -> set[DocumentInsertionRecord]: # Public documents will have the PUBLIC string in ALLOWED_USERS # If credential that kicked this off has no user associated, either Auth is off or the doc is public q_client: QdrantClient = client if client else get_qdrant_client() point_structs: list[PointStruct] = [] - insertion_records: list[ChunkInsertionRecord] = [] + insertion_records: set[DocumentInsertionRecord] = set() # Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings cross_connector_document_metadata_map: dict[ str, CrossConnectorDocumentMetadata @@ -152,12 +131,15 @@ def index_qdrant_chunks( delete_qdrant_doc_chunks(document.id, collection, q_client) already_existing_documents.add(document.id) - for minichunk_ind, embedding in enumerate(chunk.embeddings): + embeddings = chunk.embeddings + vector_list = [embeddings.full_embedding] + vector_list.extend(embeddings.mini_chunk_embeddings) + + for minichunk_ind, embedding in enumerate(vector_list): qdrant_id = str(get_uuid_from_chunk(chunk, minichunk_ind)) - insertion_records.append( - ChunkInsertionRecord( + insertion_records.add( + DocumentInsertionRecord( document_id=document.id, - store_id=qdrant_id, already_existed=document.id in already_existing_documents, ) ) diff --git a/backend/danswer/datastores/qdrant/store.py b/backend/danswer/datastores/qdrant/store.py index 68ea1b6d19b..b2e0c403b1c 100644 --- a/backend/danswer/datastores/qdrant/store.py +++ b/backend/danswer/datastores/qdrant/store.py @@ -1,6 +1,6 @@ -from typing import Any from uuid import UUID +from qdrant_client import QdrantClient from qdrant_client.http.exceptions import ResponseHandlingException from qdrant_client.http.exceptions import UnexpectedResponse from qdrant_client.http.models import FieldCondition @@ -8,22 +8,25 @@ from qdrant_client.http.models import Filter from qdrant_client.http.models import MatchAny from qdrant_client.http.models import MatchValue -from danswer.chunking.models import EmbeddedIndexChunk +from danswer.chunking.models import IndexChunk from danswer.chunking.models import InferenceChunk +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.app_configs import NUM_RETURNED_HITS -from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION from danswer.configs.constants import ALLOWED_USERS +from danswer.configs.constants 
import DOCUMENT_ID from danswer.configs.constants import PUBLIC_DOC_PAT from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF from danswer.connectors.models import IndexAttemptMetadata -from danswer.connectors.utils import batch_generator from danswer.datastores.datastore_utils import get_uuid_from_chunk -from danswer.datastores.interfaces import ChunkInsertionRecord +from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import IndexFilter from danswer.datastores.interfaces import UpdateRequest from danswer.datastores.interfaces import VectorIndex from danswer.datastores.qdrant.indexing import index_qdrant_chunks +from danswer.datastores.qdrant.utils import create_qdrant_collection +from danswer.datastores.qdrant.utils import list_qdrant_collections from danswer.search.search_utils import get_default_embedding_model +from danswer.utils.batching import batch_generator from danswer.utils.clients import get_qdrant_client from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time @@ -81,16 +84,51 @@ def _build_qdrant_filters( return filter_conditions +def _get_points_from_document_ids( + document_ids: list[str], + collection: str, + client: QdrantClient, +) -> list[int | str]: + offset: int | str | None = 0 + chunk_ids = [] + while offset is not None: + matches, offset = client.scroll( + collection_name=collection, + scroll_filter=Filter( + must=[FieldCondition(key=DOCUMENT_ID, match=MatchAny(any=document_ids))] + ), + limit=_BATCH_SIZE, + with_payload=False, + with_vectors=False, + offset=offset, + ) + for match in matches: + chunk_ids.append(match.id) + + return chunk_ids + + class QdrantIndex(VectorIndex): - def __init__(self, collection: str = QDRANT_DEFAULT_COLLECTION) -> None: - self.collection = collection + def __init__(self, index_name: str = DOCUMENT_INDEX_NAME) -> None: + # In Qdrant, the vector index is referred to as a collection + self.collection = index_name self.client = get_qdrant_client() + def ensure_indices_exist(self) -> None: + if self.collection not in { + collection.name + for collection in list_qdrant_collections(self.client).collections + }: + logger.info(f"Creating Qdrant collection with name: {self.collection}") + create_qdrant_collection( + collection_name=self.collection, q_client=self.client + ) + def index( self, - chunks: list[EmbeddedIndexChunk], + chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata, - ) -> list[ChunkInsertionRecord]: + ) -> set[DocumentInsertionRecord]: return index_qdrant_chunks( chunks=chunks, index_attempt_metadata=index_attempt_metadata, @@ -98,6 +136,35 @@ class QdrantIndex(VectorIndex): client=self.client, ) + def update(self, update_requests: list[UpdateRequest]) -> None: + logger.info( + f"Updating {len(update_requests)} documents' allowed_users in Qdrant" + ) + for update_request in update_requests: + for doc_id_batch in batch_generator( + items=update_request.document_ids, + batch_size=_BATCH_SIZE, + ): + chunk_ids = _get_points_from_document_ids( + doc_id_batch, self.collection, self.client + ) + self.client.set_payload( + collection_name=self.collection, + payload={ALLOWED_USERS: update_request.allowed_users}, + points=chunk_ids, + ) + + def delete(self, doc_ids: list[str]) -> None: + logger.info(f"Deleting {len(doc_ids)} documents from Qdrant") + for doc_id_batch in batch_generator(items=doc_ids, batch_size=_BATCH_SIZE): + chunk_ids = _get_points_from_document_ids( + doc_id_batch, self.collection, self.client + ) + 
self.client.delete( + collection_name=self.collection, + points_selector=chunk_ids, + ) + @log_function_time() def semantic_retrieval( self, @@ -105,8 +172,8 @@ class QdrantIndex(VectorIndex): user_id: UUID | None, filters: list[IndexFilter] | None, num_to_retrieve: int = NUM_RETURNED_HITS, - page_size: int = NUM_RETURNED_HITS, distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF, + page_size: int = NUM_RETURNED_HITS, ) -> list[InferenceChunk]: query_embedding = get_default_embedding_model().encode( query @@ -156,45 +223,3 @@ class QdrantIndex(VectorIndex): found_chunk_uuids.add(inf_chunk_id) return found_inference_chunks - - def delete(self, ids: list[str]) -> None: - logger.info(f"Deleting {len(ids)} documents from Qdrant") - for id_batch in batch_generator(items=ids, batch_size=_BATCH_SIZE): - self.client.delete( - collection_name=self.collection, - points_selector=id_batch, - ) - - def update(self, update_requests: list[UpdateRequest]) -> None: - logger.info( - f"Updating {len(update_requests)} documents' allowed_users in Qdrant" - ) - for update_request in update_requests: - for id_batch in batch_generator( - items=update_request.ids, - batch_size=_BATCH_SIZE, - ): - self.client.set_payload( - collection_name=self.collection, - payload={ALLOWED_USERS: update_request.allowed_users}, - points=id_batch, - ) - - def get_from_id(self, object_id: str) -> InferenceChunk | None: - matches, _ = self.client.scroll( - collection_name=self.collection, - scroll_filter=Filter( - must=[FieldCondition(key="id", match=MatchValue(value=object_id))] - ), - ) - if not matches: - return None - - if len(matches) > 1: - logger.error(f"Found multiple matches for {logger}: {matches}") - - match = matches[0] - if not match.payload: - return None - - return InferenceChunk.from_dict(match.payload) diff --git a/backend/danswer/datastores/qdrant/utils.py b/backend/danswer/datastores/qdrant/utils.py index 6d4e7b7eaa9..2c2bf95452b 100644 --- a/backend/danswer/datastores/qdrant/utils.py +++ b/backend/danswer/datastores/qdrant/utils.py @@ -1,6 +1,39 @@ from typing import Any +from qdrant_client import QdrantClient from qdrant_client.http.models import Record +from qdrant_client.models import CollectionsResponse +from qdrant_client.models import Distance +from qdrant_client.models import VectorParams + +from danswer.configs.model_configs import DOC_EMBEDDING_DIM +from danswer.utils.clients import get_qdrant_client +from danswer.utils.logger import setup_logger + + +logger = setup_logger() + + +def list_qdrant_collections( + q_client: QdrantClient | None = None, +) -> CollectionsResponse: + q_client = q_client or get_qdrant_client() + return q_client.get_collections() + + +def create_qdrant_collection( + collection_name: str, + embedding_dim: int = DOC_EMBEDDING_DIM, + q_client: QdrantClient | None = None, +) -> None: + q_client = q_client or get_qdrant_client() + logger.info(f"Attempting to create collection {collection_name}") + result = q_client.create_collection( + collection_name=collection_name, + vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE), + ) + if not result: + raise RuntimeError("Could not create Qdrant collection") def get_payload_from_record(record: Record, is_required: bool = True) -> dict[str, Any]: diff --git a/backend/danswer/datastores/typesense/store.py b/backend/danswer/datastores/typesense/store.py index 59f657dae82..abc7be74005 100644 --- a/backend/danswer/datastores/typesense/store.py +++ b/backend/danswer/datastores/typesense/store.py @@ -7,10 +7,10 @@ from uuid 
import UUID import typesense # type: ignore from typesense.exceptions import ObjectNotFound # type: ignore -from danswer.chunking.models import EmbeddedIndexChunk from danswer.chunking.models import IndexChunk from danswer.chunking.models import InferenceChunk -from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME +from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.constants import ALLOWED_GROUPS from danswer.configs.constants import ALLOWED_USERS from danswer.configs.constants import BLURB @@ -24,17 +24,17 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER from danswer.configs.constants import SOURCE_LINKS from danswer.configs.constants import SOURCE_TYPE from danswer.connectors.models import IndexAttemptMetadata -from danswer.connectors.utils import batch_generator from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE from danswer.datastores.datastore_utils import get_uuid_from_chunk from danswer.datastores.datastore_utils import ( update_cross_connector_document_metadata_map, ) -from danswer.datastores.interfaces import ChunkInsertionRecord +from danswer.datastores.interfaces import DocumentInsertionRecord from danswer.datastores.interfaces import IndexFilter from danswer.datastores.interfaces import KeywordIndex from danswer.datastores.interfaces import UpdateRequest +from danswer.utils.batching import batch_generator from danswer.utils.clients import get_typesense_client from danswer.utils.logger import setup_logger @@ -45,19 +45,8 @@ logger = setup_logger() _BATCH_SIZE = 200 -def check_typesense_collection_exist( - collection_name: str = TYPESENSE_DEFAULT_COLLECTION, -) -> bool: - client = get_typesense_client() - try: - client.collections[collection_name].retrieve() - except ObjectNotFound: - return False - return True - - def create_typesense_collection( - collection_name: str = TYPESENSE_DEFAULT_COLLECTION, + collection_name: str = DOCUMENT_INDEX_NAME, ) -> None: ts_client = get_typesense_client() collection_schema = { @@ -81,7 +70,18 @@ def create_typesense_collection( ts_client.collections.create(collection_schema) -def get_typesense_document_cross_connector_metadata( +def _check_typesense_collection_exist( + collection_name: str = DOCUMENT_INDEX_NAME, +) -> bool: + client = get_typesense_client() + try: + client.collections[collection_name].retrieve() + except ObjectNotFound: + return False + return True + + +def _get_typesense_document_cross_connector_metadata( doc_chunk_id: str, collection_name: str, ts_client: typesense.Client ) -> CrossConnectorDocumentMetadata | None: """Returns whether the document already exists and the users/group whitelists""" @@ -115,7 +115,7 @@ def get_typesense_document_cross_connector_metadata( ) -def delete_typesense_doc_chunks( +def _delete_typesense_doc_chunks( document_id: str, collection_name: str, ts_client: typesense.Client ) -> bool: doc_id_filter = {"filter_by": f"{DOCUMENT_ID}:'{document_id}'"} @@ -126,16 +126,16 @@ def delete_typesense_doc_chunks( return del_result["num_deleted"] != 0 -def index_typesense_chunks( - chunks: list[IndexChunk | EmbeddedIndexChunk], +def _index_typesense_chunks( + chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata, collection: str, client: typesense.Client | None = None, batch_upsert: bool = True, -) -> list[ChunkInsertionRecord]: +) -> set[DocumentInsertionRecord]: ts_client: 
typesense.Client = client if client else get_typesense_client() - insertion_records: list[ChunkInsertionRecord] = [] + insertion_records: set[DocumentInsertionRecord] = set() new_documents: list[dict[str, Any]] = [] cross_connector_document_metadata_map: dict[ str, CrossConnectorDocumentMetadata @@ -151,7 +151,7 @@ def index_typesense_chunks( chunk=chunk, cross_connector_document_metadata_map=cross_connector_document_metadata_map, doc_store_cross_connector_document_metadata_fetch_fn=partial( - get_typesense_document_cross_connector_metadata, + _get_typesense_document_cross_connector_metadata, collection_name=collection, ts_client=ts_client, ), @@ -160,14 +160,13 @@ def index_typesense_chunks( if should_delete_doc: # Processing the first chunk of the doc and the doc exists - delete_typesense_doc_chunks(document.id, collection, ts_client) + _delete_typesense_doc_chunks(document.id, collection, ts_client) already_existing_documents.add(document.id) typesense_id = str(get_uuid_from_chunk(chunk)) - insertion_records.append( - ChunkInsertionRecord( + insertion_records.add( + DocumentInsertionRecord( document_id=document.id, - store_id=typesense_id, already_existed=document.id in already_existing_documents, ) ) @@ -250,49 +249,55 @@ def _build_typesense_filters( class TypesenseIndex(KeywordIndex): - def __init__(self, collection: str = TYPESENSE_DEFAULT_COLLECTION) -> None: - self.collection = collection + def __init__(self, index_name: str = DOCUMENT_INDEX_NAME) -> None: + # In Typesense, the document index is referred to as a collection + self.collection = index_name self.ts_client = get_typesense_client() + def ensure_indices_exist(self) -> None: + if not _check_typesense_collection_exist(self.collection): + logger.info(f"Creating Typesense collection with name: {self.collection}") + create_typesense_collection(collection_name=self.collection) + def index( self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata - ) -> list[ChunkInsertionRecord]: - return index_typesense_chunks( + ) -> set[DocumentInsertionRecord]: + return _index_typesense_chunks( chunks=chunks, index_attempt_metadata=index_attempt_metadata, collection=self.collection, client=self.ts_client, ) - def delete(self, ids: list[str]) -> None: - logger.info(f"Deleting {len(ids)} documents from Typesense") - for id_batch in batch_generator(items=ids, batch_size=_BATCH_SIZE): - self.ts_client.collections[self.collection].documents.delete( - {"filter_by": f'id:[{",".join(id_batch)}]'} - ) - def update(self, update_requests: list[UpdateRequest]) -> None: logger.info( f"Updating {len(update_requests)} documents' allowed_users in Typesense" ) for update_request in update_requests: for id_batch in batch_generator( - items=update_request.ids, batch_size=_BATCH_SIZE + items=update_request.document_ids, batch_size=_BATCH_SIZE ): typesense_updates = [ - {"id": doc_id, ALLOWED_USERS: update_request.allowed_users} + {DOCUMENT_ID: doc_id, ALLOWED_USERS: update_request.allowed_users} for doc_id in id_batch ] self.ts_client.collections[self.collection].documents.import_( typesense_updates, {"action": "update"} ) - def keyword_search( + def delete(self, doc_ids: list[str]) -> None: + logger.info(f"Deleting {len(doc_ids)} documents from Typesense") + for id_batch in batch_generator(items=doc_ids, batch_size=_BATCH_SIZE): + self.ts_client.collections[self.collection].documents.delete( + {"filter_by": f'{DOCUMENT_ID}:[{",".join(id_batch)}]'} + ) + + def keyword_retrieval( self, query: str, user_id: UUID | None, filters: 
list[IndexFilter] | None, - num_to_retrieve: int, + num_to_retrieve: int = NUM_RETURNED_HITS, ) -> list[InferenceChunk]: filters_str = _build_typesense_filters(user_id, filters) diff --git a/backend/danswer/datastores/vespa/__init__.py b/backend/danswer/datastores/vespa/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd b/backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd new file mode 100644 index 00000000000..43afc4da19d --- /dev/null +++ b/backend/danswer/datastores/vespa/app_config/schemas/danswer_chunk.sd @@ -0,0 +1,89 @@ +schema danswer_chunk { + document danswer_chunk { + # Not to be confused with the UUID generated for this chunk which is called documentid by default + field document_id type string { + indexing: summary | attribute + } + field chunk_id type int { + indexing: summary | attribute + } + field blurb type string { + indexing: summary | attribute + } + # Can separate out title in the future and give heavier bm-25 weighting + # Need to consider that not every doc has a separable title (ie. slack message) + # Set summary options to enable bolding + field content type string { + indexing: summary | attribute | index + index: enable-bm25 + } + # https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it + field source_type type string { + indexing: summary | attribute + } + # Can also index links https://docs.vespa.ai/en/reference/schema-reference.html#attribute + # URL type matching + field source_links type string { + indexing: summary | attribute + } + field semantic_identifier type string { + indexing: summary | attribute + } + field section_continuation type bool { + indexing: summary | attribute + } + field boost type float { + indexing: summary | attribute + } + field metadata type string { + indexing: summary | attribute + } + field embeddings type tensor(t{},x[768]) { + indexing: attribute + attribute { + distance-metric: angular + } + } + field allowed_users type array { + indexing: summary | attribute + attribute: fast-search + } + field allowed_groups type array { + indexing: summary | attribute + attribute: fast-search + } + } + + fieldset default { + fields: content + } + + rank-profile keyword_search inherits default { + first-phase { + expression: bm25(content) * attribute(boost) + } + } + + rank-profile semantic_search inherits default { + inputs { + query(query_embedding) tensor(x[768]) + } + first-phase { + expression: closeness(field, embeddings) * attribute(boost) + } + match-features: closest(embeddings) + } + + rank-profile hybrid_search inherits default { + inputs { + query(query_embedding) tensor(x[768]) + } + first-phase { + expression: bm25(content) + } + second-phase { + expression: closeness(field, embeddings) * attribute(boost) + } + match-features: closest(embeddings) + } +} diff --git a/backend/danswer/datastores/vespa/app_config/services.xml b/backend/danswer/datastores/vespa/app_config/services.xml new file mode 100644 index 00000000000..1eab58c241d --- /dev/null +++ b/backend/danswer/datastores/vespa/app_config/services.xml @@ -0,0 +1,19 @@ + + + + + + + + + + + 2 + + + + + + + + diff --git a/backend/danswer/datastores/vespa/store.py b/backend/danswer/datastores/vespa/store.py new file mode 100644 index 00000000000..41c6c9a3a8f --- /dev/null +++ b/backend/danswer/datastores/vespa/store.py @@ -0,0 +1,402 @@ +import json +from collections.abc import Mapping +from functools import partial 
+from typing import cast +from uuid import UUID + +import requests + +from danswer.chunking.models import IndexChunk +from danswer.chunking.models import InferenceChunk +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME +from danswer.configs.app_configs import NUM_RETURNED_HITS +from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP +from danswer.configs.app_configs import VESPA_HOST +from danswer.configs.app_configs import VESPA_PORT +from danswer.configs.app_configs import VESPA_TENANT_PORT +from danswer.configs.constants import ALLOWED_GROUPS +from danswer.configs.constants import ALLOWED_USERS +from danswer.configs.constants import BLURB +from danswer.configs.constants import BOOST +from danswer.configs.constants import CHUNK_ID +from danswer.configs.constants import CONTENT +from danswer.configs.constants import DOCUMENT_ID +from danswer.configs.constants import EMBEDDINGS +from danswer.configs.constants import METADATA +from danswer.configs.constants import PUBLIC_DOC_PAT +from danswer.configs.constants import SECTION_CONTINUATION +from danswer.configs.constants import SEMANTIC_IDENTIFIER +from danswer.configs.constants import SOURCE_LINKS +from danswer.configs.constants import SOURCE_TYPE +from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF +from danswer.connectors.models import IndexAttemptMetadata +from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata +from danswer.datastores.datastore_utils import get_uuid_from_chunk +from danswer.datastores.datastore_utils import ( + update_cross_connector_document_metadata_map, +) +from danswer.datastores.interfaces import DocumentIndex +from danswer.datastores.interfaces import DocumentInsertionRecord +from danswer.datastores.interfaces import IndexFilter +from danswer.datastores.interfaces import UpdateRequest +from danswer.search.search_utils import get_default_embedding_model +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +VESPA_CONFIG_SERVER_URL = f"http://{VESPA_HOST}:{VESPA_TENANT_PORT}" +VESPA_APP_CONTAINER_URL = f"http://{VESPA_HOST}:{VESPA_PORT}" +VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2" +# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd +DOCUMENT_ID_ENDPOINT = ( + f"{VESPA_APP_CONTAINER_URL}/document/v1/default/danswer_chunk/docid" +) +SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/" +_BATCH_SIZE = 100 # Specific to Vespa + + +def _get_vespa_document_cross_connector_metadata( + doc_chunk_id: str, +) -> CrossConnectorDocumentMetadata | None: + """Returns whether the document already exists and the users/group whitelists""" + doc_fetch_response = requests.get(f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}") + if doc_fetch_response.status_code == 404: + return None + + if doc_fetch_response.status_code != 200: + raise RuntimeError( + f"Unexpected fetch document by ID value from Vespa " + f"with error {doc_fetch_response.status_code}" + ) + + doc_fields = doc_fetch_response.json()["fields"] + allowed_users = doc_fields.get(ALLOWED_USERS) + allowed_groups = doc_fields.get(ALLOWED_GROUPS) + # Add group permission later, empty list gets saved/loaded as a null + if allowed_users is None: + raise RuntimeError( + "Vespa Index is corrupted, Document found with no access lists." 
+ ) + return CrossConnectorDocumentMetadata( + allowed_users=allowed_users or [], + allowed_user_groups=allowed_groups or [], + already_in_index=True, + ) + + +def _get_vespa_chunk_ids_by_document_id( + document_id: str, hits_per_page: int = _BATCH_SIZE +) -> list[str]: + offset = 0 + doc_chunk_ids = [] + params: dict[str, int | str] = { + "yql": f"select documentid from {DOCUMENT_INDEX_NAME} where document_id contains '{document_id}'", + "timeout": "10s", + "offset": offset, + "hits": hits_per_page, + } + while True: + results = requests.get(SEARCH_ENDPOINT, params=params).json() + hits = results["root"].get("children", []) + doc_chunk_ids.extend( + [hit.get("fields", {}).get("documentid").split("::")[1] for hit in hits] + ) + params["offset"] += hits_per_page # type: ignore + + if len(hits) < hits_per_page: + break + return doc_chunk_ids + + +def _delete_vespa_doc_chunks(document_id: str) -> bool: + doc_chunk_ids = _get_vespa_chunk_ids_by_document_id(document_id) + + failures = [ + requests.delete(f"{DOCUMENT_ID_ENDPOINT}/{doc}").status_code != 200 + for doc in doc_chunk_ids + ] + return not any(failures) + + +def _index_vespa_chunks( + chunks: list[IndexChunk], + index_attempt_metadata: IndexAttemptMetadata, +) -> set[DocumentInsertionRecord]: + json_header = {"Content-Type": "application/json"} + insertion_records: set[DocumentInsertionRecord] = set() + cross_connector_document_metadata_map: dict[ + str, CrossConnectorDocumentMetadata + ] = {} + # document ids of documents that existed BEFORE this indexing + already_existing_documents: set[str] = set() + for chunk in chunks: + document = chunk.source_document + ( + cross_connector_document_metadata_map, + should_delete_doc, + ) = update_cross_connector_document_metadata_map( + chunk=chunk, + cross_connector_document_metadata_map=cross_connector_document_metadata_map, + doc_store_cross_connector_document_metadata_fetch_fn=partial( + _get_vespa_document_cross_connector_metadata, + ), + index_attempt_metadata=index_attempt_metadata, + ) + + if should_delete_doc: + # Processing the first chunk of the doc and the doc exists + deletion_success = _delete_vespa_doc_chunks(document.id) + if not deletion_success: + logger.error( + f"Failed to delete pre-existing chunks for with document with id: {document.id}" + ) + already_existing_documents.add(document.id) + + vespa_chunk_id = str(get_uuid_from_chunk(chunk)) + + embeddings = chunk.embeddings + embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding} + if embeddings.mini_chunk_embeddings: + for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings): + embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed + + vespa_document = { + "fields": { + DOCUMENT_ID: document.id, + CHUNK_ID: chunk.chunk_id, + BLURB: chunk.blurb, + CONTENT: chunk.content, + SOURCE_TYPE: str(document.source.value), + SOURCE_LINKS: json.dumps(chunk.source_links), + SEMANTIC_IDENTIFIER: document.semantic_identifier, + SECTION_CONTINUATION: chunk.section_continuation, + METADATA: json.dumps(document.metadata), + EMBEDDINGS: embeddings_name_vector_map, + BOOST: 1, # Boost value always starts at 1 for 0 impact on weight + ALLOWED_USERS: cross_connector_document_metadata_map[ + document.id + ].allowed_users, + ALLOWED_GROUPS: cross_connector_document_metadata_map[ + document.id + ].allowed_user_groups, + } + } + + url = f"{DOCUMENT_ID_ENDPOINT}/{vespa_chunk_id}" + + res = requests.post(url, headers=json_header, json=vespa_document) + res.raise_for_status() + + insertion_records.add( + 
DocumentInsertionRecord( + document_id=document.id, + already_existed=document.id in already_existing_documents, + ) + ) + + return insertion_records + + +def _build_vespa_filters( + user_id: UUID | None, filters: list[IndexFilter] | None +) -> str: + filter_str = "" + # Permissions filter + # TODO group permissioning + if user_id: + filter_str += ( + f'({ALLOWED_USERS} contains "{user_id}" or ' + f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}") and ' + ) + else: + filter_str += f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}" and ' + + # Provided query filters + if filters: + for filter_dict in filters: + valid_filters = { + key: value for key, value in filter_dict.items() if value is not None + } + for filter_key, filter_val in valid_filters.items(): + if isinstance(filter_val, str): + filter_str += f'{filter_key} = "{filter_val}" and ' + elif isinstance(filter_val, list): + quoted_elems = [f'"{elem}"' for elem in filter_val] + filters_or = ",".join(quoted_elems) + filter_str += f"{filter_key} in [{filters_or}] and " + else: + raise ValueError("Invalid filters provided") + return filter_str + + +def _query_vespa(query_params: Mapping[str, str | int]) -> list[InferenceChunk]: + if "query" in query_params and not cast(str, query_params["query"]).strip(): + raise ValueError( + "Query only consisted of stopwords, should not use Keyword Search" + ) + response = requests.get(SEARCH_ENDPOINT, params=query_params) + response.raise_for_status() + + hits = response.json()["root"].get("children", []) + inference_chunks = [InferenceChunk.from_dict(hit["fields"]) for hit in hits] + + return inference_chunks + + +class VespaIndex(DocumentIndex): + yql_base = ( + f"select " + f"documentid, " + f"{DOCUMENT_ID}, " + f"{CHUNK_ID}, " + f"{BLURB}, " + f"{CONTENT}, " + f"{SOURCE_TYPE}, " + f"{SOURCE_LINKS}, " + f"{SEMANTIC_IDENTIFIER}, " + f"{SECTION_CONTINUATION}, " + f"{BOOST}, " + f"{METADATA} " + f"from {DOCUMENT_INDEX_NAME} where " + ) + + def __init__(self, deployment_zip: str = VESPA_DEPLOYMENT_ZIP) -> None: + # Vespa index name isn't configurable via code alone because of the config .sd file that needs + # to be updated + zipped + deployed, not supporting the option for simplicity + self.deployment_zip = deployment_zip + + def ensure_indices_exist(self) -> None: + """Verifying indices is more involved as there is no good way to + verify the deployed app against the zip locally. But deploying the latest app.zip will ensure that + the index is up-to-date with the expected schema and this does not erase the existing index. + If the changes cannot be applied without conflict with existing data, it will fail with a non 200 + """ + deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate" + headers = {"Content-Type": "application/zip"} + with open(self.deployment_zip, "rb") as f: + response = requests.post(deploy_url, headers=headers, data=f) + if response.status_code != 200: + raise RuntimeError( + f"Failed to prepare Vespa Danswer Index. 
Response: {response.text}" + ) + + def index( + self, + chunks: list[IndexChunk], + index_attempt_metadata: IndexAttemptMetadata, + ) -> set[DocumentInsertionRecord]: + return _index_vespa_chunks( + chunks=chunks, index_attempt_metadata=index_attempt_metadata + ) + + def update(self, update_requests: list[UpdateRequest]) -> None: + logger.info( + f"Updating {len(update_requests)} documents' allowed_users in Vespa" + ) + + json_header = {"Content-Type": "application/json"} + + for update_request in update_requests: + if update_request.boost is None and update_request.allowed_users is None: + logger.error("Update request received but nothing to update") + + update_dict: dict[str, dict[str, list[str] | int]] = {"fields": {}} + if update_request.boost: + update_dict["fields"][BOOST] = update_request.boost + if update_request.allowed_users: + update_dict["fields"][ALLOWED_USERS] = update_request.allowed_users + + for document_id in update_request.document_ids: + for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(document_id): + url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}" + res = requests.put(url, headers=json_header, json=update_dict) + + if res.status_code != 200: + logger.error(f"Failed to update document: {document_id}") + + def delete(self, doc_ids: list[str]) -> None: + logger.info(f"Deleting {len(doc_ids)} documents from Vespa") + for doc_id in doc_ids: + for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(doc_id): + url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}" + logger.debug("Deleting: " + url) + requests.delete(url) + + def keyword_retrieval( + self, + query: str, + user_id: UUID | None, + filters: list[IndexFilter] | None, + num_to_retrieve: int = NUM_RETURNED_HITS, + ) -> list[InferenceChunk]: + vespa_where_clauses = _build_vespa_filters(user_id, filters) + yql = ( + VespaIndex.yql_base + + vespa_where_clauses + + '({grammar: "weakAnd"}userInput(@query))' + ) + + params: dict[str, str | int] = { + "yql": yql, + "query": query, + "hits": num_to_retrieve, + "num_to_rerank": 10 * num_to_retrieve, + "ranking.profile": "keyword_search", + } + + return _query_vespa(params) + + def semantic_retrieval( + self, + query: str, + user_id: UUID | None, + filters: list[IndexFilter] | None, + num_to_retrieve: int, + distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF, + ) -> list[InferenceChunk]: + vespa_where_clauses = _build_vespa_filters(user_id, filters) + yql = ( + VespaIndex.yql_base + + vespa_where_clauses + + f"({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding))" + ) + + query_embedding = get_default_embedding_model().encode(query) + if not isinstance(query_embedding, list): + query_embedding = query_embedding.tolist() + + params = { + "yql": yql, + "input.query(query_embedding)": str(query_embedding), + "ranking.profile": "semantic_search", + } + + return _query_vespa(params) + + def hybrid_retrieval( + self, + query: str, + user_id: UUID | None, + filters: list[IndexFilter] | None, + num_to_retrieve: int, + ) -> list[InferenceChunk]: + vespa_where_clauses = _build_vespa_filters(user_id, filters) + yql = ( + VespaIndex.yql_base + + vespa_where_clauses + + f'{{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding) or {{grammar: "weakAnd"}}userInput(@query)' + ) + + query_embedding = get_default_embedding_model().encode(query) + if not isinstance(query_embedding, list): + query_embedding = query_embedding.tolist() + + params = { + "yql": yql, + "query": query, + "input.query(query_embedding)": str(query_embedding), + 
"ranking.profile": "hybrid_search", + } + + return _query_vespa(params) diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py index 4e8c4980477..c35db0c8682 100644 --- a/backend/danswer/db/document.py +++ b/backend/danswer/db/document.py @@ -1,5 +1,4 @@ from collections.abc import Sequence -from dataclasses import dataclass from sqlalchemy import and_ from sqlalchemy import delete @@ -8,8 +7,7 @@ from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session -from danswer.datastores.interfaces import ChunkMetadata -from danswer.db.models import Chunk +from danswer.datastores.interfaces import DocumentMetadata from danswer.db.models import Document from danswer.db.models import DocumentByConnectorCredentialPair from danswer.db.utils import model_to_dict @@ -18,11 +16,11 @@ from danswer.utils.logger import setup_logger logger = setup_logger() -def get_chunks_with_single_connector_credential_pair( +def get_documents_with_single_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int, -) -> Sequence[Chunk]: +) -> Sequence[Document]: initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, @@ -30,6 +28,8 @@ def get_chunks_with_single_connector_credential_pair( ) ) + # Filter it down to the documents with only a single connector/credential pair + # Meaning if this connector/credential pair is removed, this doc should be gone trimmed_doc_ids_stmt = ( select(Document.id) .join( @@ -41,7 +41,7 @@ def get_chunks_with_single_connector_credential_pair( .having(func.count(DocumentByConnectorCredentialPair.id) == 1) ) - stmt = select(Chunk).where(Chunk.document_id.in_(trimmed_doc_ids_stmt)) + stmt = select(Document).where(Document.id.in_(trimmed_doc_ids_stmt)) return db_session.scalars(stmt).all() @@ -57,6 +57,8 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple( ) ) + # Filter it down to the documents with more than 1 connector/credential pair + # Meaning if this connector/credential pair is removed, this doc is still accessible trimmed_doc_ids_stmt = ( select(Document.id) .join( @@ -75,15 +77,8 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple( return db_session.execute(stmt).scalars().all() -def get_chunk_ids_for_document_ids( - db_session: Session, document_ids: list[str] -) -> Sequence[str]: - stmt = select(Chunk.id).where(Chunk.document_id.in_(document_ids)) - return db_session.execute(stmt).scalars().all() - - def upsert_documents( - db_session: Session, document_metadata_batch: list[ChunkMetadata] + db_session: Session, document_metadata_batch: list[DocumentMetadata] ) -> None: """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause.""" seen_document_ids: set[str] = set() @@ -102,7 +97,7 @@ def upsert_documents( def upsert_document_by_connector_credential_pair( - db_session: Session, document_metadata_batch: list[ChunkMetadata] + db_session: Session, document_metadata_batch: list[DocumentMetadata] ) -> None: """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause.""" insert_stmt = insert(DocumentByConnectorCredentialPair).values( @@ -124,45 +119,16 @@ def upsert_document_by_connector_credential_pair( db_session.commit() -def upsert_chunks( - db_session: Session, document_metadata_batch: list[ChunkMetadata] -) -> None: - """NOTE: this function is Postgres specific. 
Not all DBs support the ON CONFLICT clause.""" - insert_stmt = insert(Chunk).values( - [ - model_to_dict( - Chunk( - id=document_metadata.store_id, - document_id=document_metadata.document_id, - document_store_type=document_metadata.document_store_type, - ) - ) - for document_metadata in document_metadata_batch - ] - ) - on_conflict_stmt = insert_stmt.on_conflict_do_update( - index_elements=["id", "document_store_type"], - set_=dict(document_id=insert_stmt.excluded.document_id), - ) - db_session.execute(on_conflict_stmt) - db_session.commit() - - def upsert_documents_complete( - db_session: Session, document_metadata_batch: list[ChunkMetadata] + db_session: Session, document_metadata_batch: list[DocumentMetadata] ) -> None: upsert_documents(db_session, document_metadata_batch) upsert_document_by_connector_credential_pair(db_session, document_metadata_batch) - upsert_chunks(db_session, document_metadata_batch) logger.info( f"Upserted {len(document_metadata_batch)} document store entries into DB" ) -def delete_document_store_entries(db_session: Session, document_ids: list[str]) -> None: - db_session.execute(delete(Chunk).where(Chunk.document_id.in_(document_ids))) - - def delete_document_by_connector_credential_pair( db_session: Session, document_ids: list[str] ) -> None: @@ -179,7 +145,6 @@ def delete_documents(db_session: Session, document_ids: list[str]) -> None: def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None: logger.info(f"Deleting {len(document_ids)} documents from the DB") - delete_document_store_entries(db_session, document_ids) delete_document_by_connector_credential_pair(db_session, document_ids) delete_documents(db_session, document_ids) db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index e38f1793529..e8b42be56f7 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -25,7 +25,6 @@ from sqlalchemy.orm import relationship from danswer.auth.schemas import UserRole from danswer.configs.constants import DocumentSource from danswer.connectors.models import InputType -from danswer.datastores.interfaces import StoreType class IndexingStatus(str, PyEnum): @@ -185,7 +184,7 @@ class IndexAttempt(Base): nullable=True, ) status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus)) - num_docs_indexed: Mapped[int] = mapped_column(Integer, default=0) + num_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) error_msg: Mapped[str | None] = mapped_column( String(), default=None ) # only filled if status = "failed" @@ -195,7 +194,7 @@ class IndexAttempt(Base): ) # when the actual indexing run began # NOTE: will use the api_server clock rather than DB server clock - time_started: Mapped[datetime.datetime] = mapped_column( + time_started: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None ) time_updated: Mapped[datetime.datetime] = mapped_column( @@ -261,10 +260,7 @@ class DeletionAttempt(Base): class Document(Base): """Represents a single documents from a source. 
This is used to store - document level metadata so we don't need to duplicate it in a bunch of - DocumentByConnectorCredentialPair's/Chunk's for documents - that are split into many chunks and/or indexed by many connector / credential - pairs.""" + document level metadata, but currently nothing is stored""" __tablename__ = "document" @@ -272,10 +268,6 @@ class Document(Base): # in Danswer) id: Mapped[str] = mapped_column(String, primary_key=True) - document_store_entries: Mapped["Chunk"] = relationship( - "Chunk", back_populates="document" - ) - class DocumentByConnectorCredentialPair(Base): """Represents an indexing of a document by a specific connector / credential @@ -297,21 +289,3 @@ class DocumentByConnectorCredentialPair(Base): credential: Mapped[Credential] = relationship( "Credential", back_populates="documents_by_credential" ) - - -class Chunk(Base): - """A row represents a single entry in a document store (e.g. a single chunk - in Qdrant/Typesense)""" - - __tablename__ = "chunk" - - # this should correspond to the ID in the document store - id: Mapped[str] = mapped_column(String, primary_key=True) - document_store_type: Mapped[StoreType] = mapped_column( - Enum(StoreType), primary_key=True - ) - document_id: Mapped[str] = mapped_column(ForeignKey("document.id")) - - document: Mapped[Document] = relationship( - "Document", back_populates="document_store_entries" - ) diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index 3494bdac4bc..6f9fa6e737d 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -2,8 +2,7 @@ from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS from danswer.configs.app_configs import QA_TIMEOUT -from danswer.datastores.qdrant.store import QdrantIndex -from danswer.datastores.typesense.store import TypesenseIndex +from danswer.datastores.document_index import get_default_document_index from danswer.db.models import User from danswer.direct_qa.exceptions import OpenAIKeyMissing from danswer.direct_qa.exceptions import UnknownModelError @@ -43,12 +42,12 @@ def answer_question( user_id = None if user is None else user.id if use_keyword: ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents( - query, user_id, filters, TypesenseIndex(collection) + query, user_id, filters, get_default_document_index(collection=collection) ) unranked_chunks: list[InferenceChunk] | None = [] else: ranked_chunks, unranked_chunks = retrieve_ranked_documents( - query, user_id, filters, QdrantIndex(collection) + query, user_id, filters, get_default_document_index(collection=collection) ) if not ranked_chunks: return QAResponse( diff --git a/backend/danswer/listeners/slack_listener.py b/backend/danswer/listeners/slack_listener.py index a0fbffc7c47..6854ce64270 100644 --- a/backend/danswer/listeners/slack_listener.py +++ b/backend/danswer/listeners/slack_listener.py @@ -13,7 +13,7 @@ from slack_sdk.socket_mode.response import SocketModeResponse from danswer.configs.app_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT from danswer.configs.app_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY from danswer.configs.app_configs import DANSWER_BOT_NUM_RETRIES -from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.constants import DocumentSource from 
danswer.connectors.slack.utils import make_slack_api_rate_limited from danswer.connectors.slack.utils import UserIdReplacer @@ -206,7 +206,7 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non answer = _get_answer( QuestionRequest( query=req.payload.get("event", {}).get("text"), - collection=QDRANT_DEFAULT_COLLECTION, + collection=DOCUMENT_INDEX_NAME, use_keyword=None, filters=None, offset=None, diff --git a/backend/danswer/main.py b/backend/danswer/main.py index b17f6cb329a..72962ef26b2 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -21,17 +21,13 @@ from danswer.configs.app_configs import OAUTH_CLIENT_ID from danswer.configs.app_configs import OAUTH_CLIENT_SECRET from danswer.configs.app_configs import OAUTH_TYPE from danswer.configs.app_configs import OPENID_CONFIG_URL -from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION from danswer.configs.app_configs import SECRET -from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.model_configs import API_BASE_OPENAI from danswer.configs.model_configs import API_TYPE_OPENAI from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.configs.model_configs import INTERNAL_MODEL_VERSION -from danswer.datastores.qdrant.indexing import list_qdrant_collections -from danswer.datastores.typesense.store import check_typesense_collection_exist -from danswer.datastores.typesense.store import create_typesense_collection +from danswer.datastores.document_index import get_default_document_index from danswer.db.credentials import create_initial_public_credential from danswer.direct_qa.llm_utils import get_default_llm from danswer.server.event_loading import router as event_processing_router @@ -149,7 +145,6 @@ def get_application() -> FastAPI: from danswer.search.search_utils import ( warm_up_models, ) - from danswer.datastores.qdrant.indexing import create_qdrant_collection if DISABLE_GENERATIVE_AI: logger.info("Generative AI Q&A disabled") @@ -190,19 +185,8 @@ def get_application() -> FastAPI: logger.info("Verifying public credential exists.") create_initial_public_credential() - logger.info("Verifying Document Indexes are available.") - if QDRANT_DEFAULT_COLLECTION not in { - collection.name for collection in list_qdrant_collections().collections - }: - logger.info( - f"Creating Qdrant collection with name: {QDRANT_DEFAULT_COLLECTION}" - ) - create_qdrant_collection(collection_name=QDRANT_DEFAULT_COLLECTION) - if not check_typesense_collection_exist(TYPESENSE_DEFAULT_COLLECTION): - logger.info( - f"Creating Typesense collection with name: {TYPESENSE_DEFAULT_COLLECTION}" - ) - create_typesense_collection(collection_name=TYPESENSE_DEFAULT_COLLECTION) + logger.info("Verifying Document Index(s) is/are available.") + get_default_document_index().ensure_indices_exist() return application diff --git a/backend/danswer/search/keyword_search.py b/backend/danswer/search/keyword_search.py index 232166af943..6a09ca12417 100644 --- a/backend/danswer/search/keyword_search.py +++ b/backend/danswer/search/keyword_search.py @@ -7,8 +7,8 @@ from nltk.tokenize import word_tokenize # type:ignore from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import NUM_RETURNED_HITS +from danswer.datastores.interfaces import DocumentIndex from danswer.datastores.interfaces import IndexFilter -from danswer.datastores.interfaces import KeywordIndex from danswer.utils.logger import setup_logger 
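# A hedged usage sketch of the unified DocumentIndex that the QA flow and app
# startup above now resolve via get_default_document_index(). Only the call
# signatures come from this patch; the query text and filter values below are
# illustrative, and the InferenceChunk fields printed are assumed from the
# danswer_chunk schema above.
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.datastores.document_index import get_default_document_index

document_index = get_default_document_index(collection=DOCUMENT_INDEX_NAME)
top_chunks = document_index.keyword_retrieval(
    query="how do I rotate the api keys",   # illustrative query
    user_id=None,                           # anonymous: only public docs match
    filters=[{"source_type": ["slack"]}],   # illustrative IndexFilter
)
for chunk in top_chunks[:3]:
    print(chunk.semantic_identifier, chunk.blurb)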
from danswer.utils.timing import log_function_time @@ -38,11 +38,11 @@ def retrieve_keyword_documents( query: str, user_id: UUID | None, filters: list[IndexFilter] | None, - datastore: KeywordIndex, + datastore: DocumentIndex, num_hits: int = NUM_RETURNED_HITS, ) -> list[InferenceChunk] | None: edited_query = query_processing(query) - top_chunks = datastore.keyword_search(edited_query, user_id, filters, num_hits) + top_chunks = datastore.keyword_retrieval(edited_query, user_id, filters, num_hits) if not top_chunks: filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "") logger.warning( diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index c0ee44a86cc..afc2d540ac0 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -1,6 +1,6 @@ from enum import Enum -from danswer.chunking.models import EmbeddedIndexChunk +from danswer.chunking.models import DocAwareChunk from danswer.chunking.models import IndexChunk @@ -15,5 +15,5 @@ class QueryFlow(str, Enum): class Embedder: - def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]: + def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]: raise NotImplementedError diff --git a/backend/danswer/search/semantic_search.py b/backend/danswer/search/semantic_search.py index 31eb94f7913..6dedaa87cbf 100644 --- a/backend/danswer/search/semantic_search.py +++ b/backend/danswer/search/semantic_search.py @@ -4,7 +4,8 @@ from uuid import UUID import numpy from sentence_transformers import SentenceTransformer # type: ignore -from danswer.chunking.models import EmbeddedIndexChunk +from danswer.chunking.models import ChunkEmbedding +from danswer.chunking.models import DocAwareChunk from danswer.chunking.models import IndexChunk from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import ENABLE_MINI_CHUNK @@ -12,8 +13,8 @@ from danswer.configs.app_configs import MINI_CHUNK_SIZE from danswer.configs.app_configs import NUM_RERANKED_RESULTS from danswer.configs.app_configs import NUM_RETURNED_HITS from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS +from danswer.datastores.interfaces import DocumentIndex from danswer.datastores.interfaces import IndexFilter -from danswer.datastores.interfaces import VectorIndex from danswer.search.models import Embedder from danswer.search.search_utils import get_default_embedding_model from danswer.search.search_utils import get_default_reranking_model_ensemble @@ -69,7 +70,7 @@ def retrieve_ranked_documents( query: str, user_id: UUID | None, filters: list[IndexFilter] | None, - datastore: VectorIndex, + datastore: DocumentIndex, num_hits: int = NUM_RETURNED_HITS, num_rerank: int = NUM_RERANKED_RESULTS, ) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]: @@ -127,12 +128,12 @@ def split_chunk_text_into_mini_chunks( @log_function_time() def encode_chunks( - chunks: list[IndexChunk], + chunks: list[DocAwareChunk], embedding_model: SentenceTransformer | None = None, batch_size: int = BATCH_SIZE_ENCODE_CHUNKS, enable_mini_chunk: bool = ENABLE_MINI_CHUNK, -) -> list[EmbeddedIndexChunk]: - embedded_chunks: list[EmbeddedIndexChunk] = [] +) -> list[IndexChunk]: + embedded_chunks: list[IndexChunk] = [] if embedding_model is None: embedding_model = get_default_embedding_model() @@ -163,9 +164,12 @@ def encode_chunks( chunk_embeddings = embeddings[ embedding_ind_start : embedding_ind_start + num_embeddings ] - new_embedded_chunk = EmbeddedIndexChunk( + new_embedded_chunk = 
IndexChunk( **{k: getattr(chunk, k) for k in chunk.__dataclass_fields__}, - embeddings=chunk_embeddings, + embeddings=ChunkEmbedding( + full_embedding=chunk_embeddings[0], + mini_chunk_embeddings=chunk_embeddings[1:], + ), ) embedded_chunks.append(new_embedded_chunk) embedding_ind_start += num_embeddings @@ -174,5 +178,5 @@ def encode_chunks( class DefaultEmbedder(Embedder): - def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]: + def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]: return encode_chunks(chunks) diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index f1cad9d208d..084fc255e51 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -10,8 +10,7 @@ from danswer.auth.users import current_user from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS -from danswer.datastores.qdrant.store import QdrantIndex -from danswer.datastores.typesense.store import TypesenseIndex +from danswer.datastores.document_index import get_default_document_index from danswer.db.models import User from danswer.direct_qa.answer_question import answer_question from danswer.direct_qa.exceptions import OpenAIKeyMissing @@ -60,7 +59,7 @@ def semantic_search( user_id = None if user is None else user.id ranked_chunks, unranked_chunks = retrieve_ranked_documents( - query, user_id, filters, QdrantIndex(collection) + query, user_id, filters, get_default_document_index(collection=collection) ) if not ranked_chunks: return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None) @@ -82,7 +81,7 @@ def keyword_search( user_id = None if user is None else user.id ranked_chunks = retrieve_keyword_documents( - query, user_id, filters, TypesenseIndex(collection) + query, user_id, filters, get_default_document_index(collection=collection) ) if not ranked_chunks: return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None) @@ -110,6 +109,8 @@ def stream_direct_qa( logger.debug(f"Received QA query: {question.query}") logger.debug(f"Query filters: {question.filters}") + if question.use_keyword: + logger.debug(f"User selected Keyword Search") @log_generator_function_time() def stream_qa_portions( @@ -128,12 +129,18 @@ def stream_direct_qa( user_id = None if user is None else user.id if use_keyword: ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents( - query, user_id, filters, TypesenseIndex(collection) + query, + user_id, + filters, + get_default_document_index(collection=collection), ) unranked_chunks: list[InferenceChunk] | None = [] else: ranked_chunks, unranked_chunks = retrieve_ranked_documents( - query, user_id, filters, QdrantIndex(collection) + query, + user_id, + filters, + get_default_document_index(collection=collection), ) if not ranked_chunks: logger.debug("No Documents Found") diff --git a/backend/danswer/connectors/utils.py b/backend/danswer/utils/batching.py similarity index 100% rename from backend/danswer/connectors/utils.py rename to backend/danswer/utils/batching.py diff --git a/backend/scripts/list_typesense_docs.py b/backend/scripts/list_typesense_docs.py index 429030e4dbb..9cb6823d403 100644 --- a/backend/scripts/list_typesense_docs.py +++ b/backend/scripts/list_typesense_docs.py @@ -1,4 +1,4 @@ -from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME 
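# A hedged sketch of how the ChunkEmbedding produced by encode_chunks above maps
# onto the mapped tensor field `embeddings type tensor(t{},x[768])` declared in
# danswer_chunk.sd, mirroring the logic in _index_vespa_chunks. The vectors here
# are dummy placeholders; real values come from the sentence-transformers model.
from danswer.chunking.models import ChunkEmbedding

EMBEDDING_DIM = 768  # matches the x[768] dimension in the schema

embedding = ChunkEmbedding(
    full_embedding=[0.0] * EMBEDDING_DIM,
    mini_chunk_embeddings=[[0.1] * EMBEDDING_DIM, [0.2] * EMBEDDING_DIM],
)

# One named vector per mini-chunk, keyed on the mapped dimension `t`
embeddings_name_vector_map = {"full_chunk": embedding.full_embedding}
for ind, mini_chunk_vec in enumerate(embedding.mini_chunk_embeddings):
    embeddings_name_vector_map[f"mini_chunk_{ind}"] = mini_chunk_vec

# Keys: "full_chunk", "mini_chunk_0", "mini_chunk_1". At query time the
# nearestNeighbor operator scores against the closest of these vectors,
# which is what match-features: closest(embeddings) surfaces.
print(sorted(embeddings_name_vector_map))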
from danswer.utils.clients import get_typesense_client @@ -14,9 +14,7 @@ if __name__ == "__main__": "page": page_number, "per_page": per_page, } - response = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.search( - params - ) + response = ts_client.collections[DOCUMENT_INDEX_NAME].documents.search(params) documents = response.get("hits") if not documents: break # if there are no more documents, break out of the loop diff --git a/backend/scripts/save_load_state.py b/backend/scripts/save_load_state.py index 9848dd98f86..18ab12b26eb 100644 --- a/backend/scripts/save_load_state.py +++ b/backend/scripts/save_load_state.py @@ -11,17 +11,16 @@ from typesense.exceptions import ObjectNotFound # type: ignore from alembic import command from alembic.config import Config +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.app_configs import POSTGRES_DB from danswer.configs.app_configs import POSTGRES_HOST from danswer.configs.app_configs import POSTGRES_PASSWORD from danswer.configs.app_configs import POSTGRES_PORT from danswer.configs.app_configs import POSTGRES_USER -from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION from danswer.configs.app_configs import QDRANT_HOST from danswer.configs.app_configs import QDRANT_PORT -from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION -from danswer.datastores.qdrant.indexing import create_qdrant_collection -from danswer.datastores.qdrant.indexing import list_qdrant_collections +from danswer.datastores.qdrant.utils import create_qdrant_collection +from danswer.datastores.qdrant.utils import list_qdrant_collections from danswer.datastores.typesense.store import create_typesense_collection from danswer.utils.clients import get_qdrant_client from danswer.utils.clients import get_typesense_client @@ -60,13 +59,13 @@ def snapshot_time_compare(snap: SnapshotDescription) -> datetime: def save_qdrant(filename: str) -> None: logger.info("Attempting to take Qdrant snapshot") qdrant_client = get_qdrant_client() - qdrant_client.create_snapshot(collection_name=QDRANT_DEFAULT_COLLECTION) - snapshots = qdrant_client.list_snapshots(collection_name=QDRANT_DEFAULT_COLLECTION) + qdrant_client.create_snapshot(collection_name=DOCUMENT_INDEX_NAME) + snapshots = qdrant_client.list_snapshots(collection_name=DOCUMENT_INDEX_NAME) valid_snapshots = [snap for snap in snapshots if snap.creation_time is not None] sorted_snapshots = sorted(valid_snapshots, key=snapshot_time_compare) last_snapshot_name = sorted_snapshots[-1].name - url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/{last_snapshot_name}" + url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{DOCUMENT_INDEX_NAME}/snapshots/{last_snapshot_name}" response = requests.get(url, stream=True) @@ -80,11 +79,11 @@ def save_qdrant(filename: str) -> None: def load_qdrant(filename: str) -> None: logger.info("Attempting to load Qdrant snapshot") - if QDRANT_DEFAULT_COLLECTION not in { + if DOCUMENT_INDEX_NAME not in { collection.name for collection in list_qdrant_collections().collections }: - create_qdrant_collection(QDRANT_DEFAULT_COLLECTION) - snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/" + create_qdrant_collection(DOCUMENT_INDEX_NAME) + snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{DOCUMENT_INDEX_NAME}/snapshots/" with open(filename, "rb") as f: files = {"snapshot": (os.path.basename(filename), f)} @@ -104,7 +103,7 @@ def load_qdrant(filename: 
str) -> None: def save_typesense(filename: str) -> None: logger.info("Attempting to take Typesense snapshot") ts_client = get_typesense_client() - all_docs = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.export() + all_docs = ts_client.collections[DOCUMENT_INDEX_NAME].documents.export() with open(filename, "w") as f: f.write(all_docs) @@ -113,14 +112,14 @@ def load_typesense(filename: str) -> None: logger.info("Attempting to load Typesense snapshot") ts_client = get_typesense_client() try: - ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].delete() + ts_client.collections[DOCUMENT_INDEX_NAME].delete() except ObjectNotFound: pass - create_typesense_collection(TYPESENSE_DEFAULT_COLLECTION) + create_typesense_collection(DOCUMENT_INDEX_NAME) with open(filename) as jsonl_file: - ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.import_( + ts_client.collections[DOCUMENT_INDEX_NAME].documents.import_( jsonl_file.read().encode("utf-8"), {"action": "create"} ) diff --git a/backend/scripts/simulate_frontend.py b/backend/scripts/simulate_frontend.py index 8ccca2ba3ff..08c314b451d 100644 --- a/backend/scripts/simulate_frontend.py +++ b/backend/scripts/simulate_frontend.py @@ -7,7 +7,7 @@ from pprint import pprint import requests from danswer.configs.app_configs import APP_PORT -from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION +from danswer.configs.app_configs import DOCUMENT_INDEX_NAME from danswer.configs.constants import SOURCE_TYPE @@ -95,7 +95,7 @@ if __name__ == "__main__": query_json = { "query": query, - "collection": QDRANT_DEFAULT_COLLECTION, + "collection": DOCUMENT_INDEX_NAME, "use_keyword": flow_type == "keyword", # Ignore if not QA Endpoints "filters": [{SOURCE_TYPE: source_types}], }
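Note on the new hybrid path: the hybrid_search rank profile and
VespaIndex.hybrid_retrieval added in this patch do not appear to be called by the
QA or search endpoints shown above. The following is a minimal sketch of invoking
it directly, assuming Vespa is deployed and indexed; the query text and filter
values are illustrative, not taken from the patch.

    from danswer.datastores.vespa.store import VespaIndex

    vespa_index = VespaIndex()
    chunks = vespa_index.hybrid_retrieval(
        query="vespa deployment instructions",         # illustrative query
        user_id=None,                                   # anonymous: only public docs
        filters=[{"source_type": ["web", "github"]}],   # illustrative filter values
        num_to_retrieve=10,
    )
    for chunk in chunks:
        print(f"{chunk.semantic_identifier} (chunk {chunk.chunk_id}): {chunk.blurb}")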