Add Vespa and rework Document Indices (#317)

Yuhong Sun
2023-08-24 08:46:28 -07:00
committed by GitHub
parent a2d3a3f116
commit 8159fdcdce
33 changed files with 1059 additions and 433 deletions

backend/.gitignore
View File

@@ -7,3 +7,4 @@ api_keys.py
qdrant-data/
typesense-data/
.env
vespa-app.zip

View File

@@ -0,0 +1,39 @@
"""Restructure Document Indices
Revision ID: 8aabb57f3b49
Revises: 5e84129c8be3
Create Date: 2023-08-18 21:15:57.629515
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "8aabb57f3b49"
down_revision = "5e84129c8be3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_table("chunk")
op.execute("DROP TYPE IF EXISTS documentstoretype")
def downgrade() -> None:
op.create_table(
"chunk",
sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.Column(
"document_store_type",
postgresql.ENUM("VECTOR", "KEYWORD", name="documentstoretype"),
autoincrement=False,
nullable=False,
),
sa.Column("document_id", sa.VARCHAR(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(
["document_id"], ["document.id"], name="chunk_document_id_fkey"
),
sa.PrimaryKeyConstraint("id", "document_store_type", name="chunk_pkey"),
)
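
For reference, a minimal sketch of applying this revision programmatically (the standard `alembic upgrade head` CLI works just as well); the `alembic.ini` path is an assumption about the deployment layout:

from alembic import command
from alembic.config import Config

# Assumed to run from the backend/ directory where alembic.ini lives.
alembic_cfg = Config("alembic.ini")

# Apply this revision (drops the now-unused `chunk` table).
command.upgrade(alembic_cfg, "8aabb57f3b49")

# Roll back to the prior revision if needed (recreates the `chunk` table).
# command.downgrade(alembic_cfg, "5e84129c8be3")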

View File

@@ -17,12 +17,9 @@ from datetime import datetime
from sqlalchemy.orm import Session
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import StoreType
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import delete_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -31,13 +28,12 @@ from danswer.db.deletion_attempt import delete_deletion_attempts
from danswer.db.deletion_attempt import get_deletion_attempts
from danswer.db.document import delete_document_by_connector_credential_pair
from danswer.db.document import delete_documents_complete
from danswer.db.document import get_chunk_ids_for_document_ids
from danswer.db.document import (
get_chunks_with_single_connector_credential_pair,
)
from danswer.db.document import (
get_document_by_connector_credential_pairs_indexed_by_multiple,
)
from danswer.db.document import (
get_documents_with_single_connector_credential_pair,
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import Credential
@@ -50,8 +46,7 @@ logger = setup_logger()
def _delete_connector_credential_pair(
db_session: Session,
vector_index: VectorIndex,
keyword_index: KeywordIndex,
document_index: DocumentIndex,
deletion_attempt: DeletionAttempt,
) -> int:
connector_id = deletion_attempt.connector_id
@@ -59,33 +54,24 @@ def _delete_connector_credential_pair(
def _delete_singly_indexed_docs() -> int:
# if a document store entry is only indexed by this connector_credential_pair, delete it
num_docs_deleted = 0
chunks_to_delete = get_chunks_with_single_connector_credential_pair(
docs_to_delete = get_documents_with_single_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
credential_id=credential_id,
)
if chunks_to_delete:
document_ids: set[str] = set()
vector_chunk_ids_to_delete: list[str] = []
keyword_chunk_ids_to_delete: list[str] = []
for chunk in chunks_to_delete:
document_ids.add(chunk.document_id)
if chunk.document_store_type == StoreType.KEYWORD:
keyword_chunk_ids_to_delete.append(chunk.id)
else:
vector_chunk_ids_to_delete.append(chunk.id)
vector_index.delete(ids=vector_chunk_ids_to_delete)
keyword_index.delete(ids=keyword_chunk_ids_to_delete)
# removes all `Chunk`, `DocumentByConnectorCredentialPair`, and `Document`
if docs_to_delete:
document_ids = [doc.id for doc in docs_to_delete]
document_index.delete(doc_ids=document_ids)
# removes all `DocumentByConnectorCredentialPair` and `Document`
# rows from the DB
delete_documents_complete(
db_session=db_session,
document_ids=list(document_ids),
)
num_docs_deleted += len(document_ids)
return num_docs_deleted
return len(docs_to_delete)
num_docs_deleted = _delete_singly_indexed_docs()
logger.info(f"Deleted {num_docs_deleted} documents from document stores")
@@ -144,18 +130,10 @@ def _delete_connector_credential_pair(
# actually perform the updates in the document store
update_requests = [
UpdateRequest(
ids=list(
get_chunk_ids_for_document_ids(
db_session=db_session, document_ids=document_ids
)
),
allowed_users=list(allowed_users),
)
UpdateRequest(document_ids=document_ids, allowed_users=list(allowed_users))
for allowed_users, document_ids in update_groups.items()
]
vector_index.update(update_requests=update_requests)
keyword_index.update(update_requests=update_requests)
document_index.update(update_requests=update_requests)
# delete the `document_by_connector_credential_pair` rows for the connector / credential pair
delete_document_by_connector_credential_pair(
@@ -246,8 +224,7 @@ def _run_deletion(db_session: Session) -> None:
try:
num_docs_deleted = _delete_connector_credential_pair(
db_session=db_session,
vector_index=QdrantIndex(),
keyword_index=TypesenseIndex(),
document_index=get_default_document_index(),
deletion_attempt=deletion_attempt,
)
except Exception as e:
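
The reworked deletion flow above boils down to two document-level calls against the unified index. A minimal sketch, with placeholder document ids and user lists:

from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import UpdateRequest

document_index = get_default_document_index()

# Documents indexed only by the connector/credential pair being deleted are removed outright.
document_index.delete(doc_ids=["doc-owned-only-by-this-pair"])  # placeholder id

# Documents shared with other connector/credential pairs just have their access lists rewritten.
document_index.update(
    update_requests=[
        UpdateRequest(
            document_ids=["doc-shared-with-other-pairs"],  # placeholder id
            allowed_users=["user@example.com", "PUBLIC"],  # remaining allowed users
        )
    ]
)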

View File

@@ -2,7 +2,7 @@ import abc
import re
from collections.abc import Callable
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import DocAwareChunk
from danswer.configs.app_configs import BLURB_LENGTH
from danswer.configs.app_configs import CHUNK_MAX_CHAR_OVERLAP
from danswer.configs.app_configs import CHUNK_SIZE
@@ -12,7 +12,7 @@ from danswer.connectors.models import Section
from danswer.utils.text_processing import shared_precompare_cleanup
SECTION_SEPARATOR = "\n\n"
ChunkFunc = Callable[[Document], list[IndexChunk]]
ChunkFunc = Callable[[Document], list[DocAwareChunk]]
def extract_blurb(text: str, blurb_len: int) -> str:
@@ -51,7 +51,7 @@ def chunk_large_section(
word_overlap: int = CHUNK_WORD_OVERLAP,
blurb_len: int = BLURB_LENGTH,
chunk_overflow_max: int = CHUNK_MAX_CHAR_OVERLAP,
) -> list[IndexChunk]:
) -> list[DocAwareChunk]:
"""Split large sections into multiple chunks with the final chunk having as much previous overlap as possible.
Backtracks word_overlap words, delimited by whitespace, backtracking up to chunk_overflow_max characters.
When a chunk is finished in the forward direction, attempt to complete the final word, but only up to chunk_overflow_max extra characters
@@ -129,7 +129,7 @@ def chunk_large_section(
chunks = []
for chunk_ind, chunk_str in enumerate(chunk_strs):
chunks.append(
IndexChunk(
DocAwareChunk(
source_document=document,
chunk_id=start_chunk_id + chunk_ind,
blurb=blurb,
@@ -146,8 +146,8 @@ def chunk_document(
chunk_size: int = CHUNK_SIZE,
subsection_overlap: int = CHUNK_WORD_OVERLAP,
blurb_len: int = BLURB_LENGTH,
) -> list[IndexChunk]:
chunks: list[IndexChunk] = []
) -> list[DocAwareChunk]:
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
for section in document.sections:
@@ -160,7 +160,7 @@ def chunk_document(
if section_length > chunk_size:
if chunk_text:
chunks.append(
IndexChunk(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_len),
@@ -191,7 +191,7 @@ def chunk_document(
link_offsets[curr_offset_len] = section.link
else:
chunks.append(
IndexChunk(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_len),
@@ -206,7 +206,7 @@ def chunk_document(
# Once we hit the end, if we're still in the process of building a chunk, add what we have
if chunk_text:
chunks.append(
IndexChunk(
DocAwareChunk(
source_document=document,
chunk_id=len(chunks),
blurb=extract_blurb(chunk_text, blurb_len),
@@ -220,10 +220,10 @@ def chunk_document(
class Chunker:
@abc.abstractmethod
def chunk(self, document: Document) -> list[IndexChunk]:
def chunk(self, document: Document) -> list[DocAwareChunk]:
raise NotImplementedError
class DefaultChunker(Chunker):
def chunk(self, document: Document) -> list[IndexChunk]:
def chunk(self, document: Document) -> list[DocAwareChunk]:
return chunk_document(document)
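
A short sketch of what the renamed chunker output looks like for a single document; the Document/Section constructor arguments shown here are assumptions for illustration only:

from danswer.chunking.chunk import DefaultChunker
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import Document, Section

# Hypothetical document; the exact fields are defined in danswer.connectors.models.
document = Document(
    id="example-doc-1",
    sections=[Section(link="https://example.com/doc", text="Some long body of text ...")],
    source=DocumentSource.WEB,
    semantic_identifier="Example Doc",
    metadata={},
)

chunks = DefaultChunker().chunk(document=document)  # -> list[DocAwareChunk]
for chunk in chunks:
    print(chunk.chunk_id, chunk.blurb)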

View File

@@ -14,6 +14,15 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
Embedding = list[float]
@dataclass
class ChunkEmbedding:
full_embedding: Embedding
mini_chunk_embeddings: list[Embedding]
@dataclass
class BaseChunk:
chunk_id: int
@@ -26,7 +35,7 @@ class BaseChunk:
@dataclass
class IndexChunk(BaseChunk):
class DocAwareChunk(BaseChunk):
# During indexing flow, we have access to a complete "Document"
# During inference we only have access to the document id and do not reconstruct the Document
source_document: Document
@@ -39,8 +48,8 @@ class IndexChunk(BaseChunk):
@dataclass
class EmbeddedIndexChunk(IndexChunk):
embeddings: list[list[float]]
class IndexChunk(DocAwareChunk):
embeddings: ChunkEmbedding
@dataclass

View File

@@ -1,5 +1,7 @@
import os
from danswer.configs.constants import DocumentIndexType
#####
# App Configs
#####
@@ -62,20 +64,28 @@ MASK_CREDENTIAL_PREFIX = (
#####
# DB Configs
#####
DOCUMENT_INDEX_NAME = "danswer_index" # Shared by vector/keyword indices
# Controls whether the combined Vespa index or the split Typesense + Qdrant indices are used (split is the default)
DOCUMENT_INDEX_TYPE = os.environ.get(
"DOCUMENT_INDEX_TYPE", DocumentIndexType.SPLIT.value
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# The default below is for dockerized deployment
VESPA_DEPLOYMENT_ZIP = (
os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
)
# Qdrant is Semantic Search Vector DB
# Url / Key are used to connect to a remote Qdrant instance
QDRANT_URL = os.environ.get("QDRANT_URL", "")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
# Host / Port are used for connecting to local Qdrant instance
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_HOST = os.environ.get("QDRANT_HOST") or "localhost"
QDRANT_PORT = 6333
QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_DEFAULT_COLLECTION", "danswer_index")
# Typesense is the Keyword Search Engine
TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST", "localhost")
TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST") or "localhost"
TYPESENSE_PORT = 8108
TYPESENSE_DEFAULT_COLLECTION = os.environ.get(
"TYPESENSE_DEFAULT_COLLECTION", "danswer_index"
)
TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
INDEX_BATCH_SIZE = 16
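
Note the switch from `os.environ.get(VAR, default)` to `os.environ.get(VAR) or default` for the host settings: the latter also falls back when the variable is set but empty, which is common with unfilled docker-compose variables. A small sketch of the difference:

import os

os.environ["QDRANT_HOST"] = ""  # present but empty, e.g. an unfilled docker-compose variable

# `get` with a default only falls back when the variable is missing entirely.
print(os.environ.get("QDRANT_HOST", "localhost"))    # -> "" (empty string)

# `get(...) or default` also falls back on empty strings.
print(os.environ.get("QDRANT_HOST") or "localhost")  # -> "localhost"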

View File

@@ -9,6 +9,7 @@ SOURCE_LINKS = "source_links"
SOURCE_LINK = "link"
SEMANTIC_IDENTIFIER = "semantic_identifier"
SECTION_CONTINUATION = "section_continuation"
EMBEDDINGS = "embeddings"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
METADATA = "metadata"
@@ -16,6 +17,7 @@ GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"
QUOTE = "quote"
BOOST = "boost"
class DocumentSource(str, Enum):
@@ -35,6 +37,11 @@ class DocumentSource(str, Enum):
LINEAR = "linear"
class DocumentIndexType(str, Enum):
COMBINED = "combined" # Vespa
SPLIT = "split" # Typesense + Qdrant
class DanswerGenAIModel(str, Enum):
"""This represents the internal Danswer GenAI model which determines the class that is used
to generate responses to the user query. Different models/services require different internal

View File

@@ -24,7 +24,7 @@ from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.utils import batch_generator
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@@ -5,7 +5,6 @@ from typing import TypeVar
from pydantic import BaseModel
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import PUBLIC_DOC_PAT
@@ -16,7 +15,7 @@ DEFAULT_BATCH_SIZE = 30
def get_uuid_from_chunk(
chunk: IndexChunk | EmbeddedIndexChunk | InferenceChunk, mini_chunk_ind: int = 0
chunk: IndexChunk | InferenceChunk, mini_chunk_ind: int = 0
) -> uuid.UUID:
doc_str = (
chunk.document_id
@@ -58,7 +57,7 @@ def _add_if_not_exists(l: list[T], item: T) -> list[T]:
def update_cross_connector_document_metadata_map(
chunk: IndexChunk | EmbeddedIndexChunk,
chunk: IndexChunk,
cross_connector_document_metadata_map: dict[str, CrossConnectorDocumentMetadata],
doc_store_cross_connector_document_metadata_fetch_fn: CrossConnectorDocumentMetadataFetchCallable,
index_attempt_metadata: IndexAttemptMetadata,

View File

@@ -0,0 +1,111 @@
from typing import Type
from uuid import UUID
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import DOCUMENT_INDEX_TYPE
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentIndexType
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.vespa.store import VespaIndex
from danswer.utils.logger import setup_logger
logger = setup_logger()
class SplitDocumentIndex(DocumentIndex):
"""A Document index that uses 2 separate storages
one for keyword index and one for vector index"""
def __init__(
self,
index_name: str | None = DOCUMENT_INDEX_NAME,
keyword_index_cls: Type[KeywordIndex] = TypesenseIndex,
vector_index_cls: Type[VectorIndex] = QdrantIndex,
) -> None:
index_name = index_name or DOCUMENT_INDEX_NAME
self.keyword_index = keyword_index_cls(index_name)
self.vector_index = vector_index_cls(index_name)
def ensure_indices_exist(self) -> None:
self.keyword_index.ensure_indices_exist()
self.vector_index.ensure_indices_exist()
def index(
self,
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
) -> set[DocumentInsertionRecord]:
keyword_index_result = self.keyword_index.index(chunks, index_attempt_metadata)
vector_index_result = self.vector_index.index(chunks, index_attempt_metadata)
if keyword_index_result != vector_index_result:
logger.error(
f"Inconsistent document indexing:\n"
f"Keyword: {keyword_index_result}\n"
f"Vector: {vector_index_result}"
)
return keyword_index_result.union(vector_index_result)
def update(self, update_requests: list[UpdateRequest]) -> None:
self.keyword_index.update(update_requests)
self.vector_index.update(update_requests)
def delete(self, doc_ids: list[str]) -> None:
self.keyword_index.delete(doc_ids)
self.vector_index.delete(doc_ids)
def keyword_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk]:
return self.keyword_index.keyword_retrieval(
query, user_id, filters, num_to_retrieve
)
def semantic_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int = NUM_RETURNED_HITS,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
) -> list[InferenceChunk]:
return self.vector_index.semantic_retrieval(
query, user_id, filters, num_to_retrieve, distance_cutoff
)
def hybrid_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
) -> list[InferenceChunk]:
"""Currently results from vector and keyword indices are not combined post retrieval.
This may change in the future but for now, the default behavior is to use semantic search
which should be a more flexible/powerful search option"""
return self.semantic_retrieval(query, user_id, filters, num_to_retrieve)
def get_default_document_index(
collection: str | None = DOCUMENT_INDEX_NAME, index_type: str = DOCUMENT_INDEX_TYPE
) -> DocumentIndex:
if index_type == DocumentIndexType.COMBINED.value:
return VespaIndex() # Can't specify collection without modifying the deployment
elif index_type == DocumentIndexType.SPLIT.value:
return SplitDocumentIndex(index_name=collection)
else:
raise ValueError("Invalid document index type selected")
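
A brief usage sketch of the new factory: the split Typesense + Qdrant setup remains the default, while setting DOCUMENT_INDEX_TYPE=combined (or passing the type explicitly) selects Vespa:

from danswer.configs.constants import DocumentIndexType
from danswer.datastores.document_index import get_default_document_index

# Default: resolves DOCUMENT_INDEX_TYPE from app configs (split Typesense + Qdrant unless overridden).
document_index = get_default_document_index()
document_index.ensure_indices_exist()

# Explicitly request the combined Vespa index instead.
vespa_index = get_default_document_index(index_type=DocumentIndexType.COMBINED.value)
vespa_index.ensure_indices_exist()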

View File

@@ -6,15 +6,13 @@ from sqlalchemy.orm import Session
from danswer.chunking.chunk import Chunker
from danswer.chunking.chunk import DefaultChunker
from danswer.chunking.models import DocAwareChunk
from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import ChunkMetadata
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import StoreType
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.document import upsert_documents_complete
from danswer.db.engine import get_sqlalchemy_engine
from danswer.search.models import Embedder
@@ -32,20 +30,17 @@ class IndexingPipelineProtocol(Protocol):
def _upsert_insertion_records(
insertion_records: list[ChunkInsertionRecord],
insertion_records: set[DocumentInsertionRecord],
index_attempt_metadata: IndexAttemptMetadata,
document_store_type: StoreType,
) -> None:
with Session(get_sqlalchemy_engine()) as session:
upsert_documents_complete(
db_session=session,
document_metadata_batch=[
ChunkMetadata(
DocumentMetadata(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_id=insertion_record.document_id,
store_id=insertion_record.store_id,
document_store_type=document_store_type,
)
for insertion_record in insertion_records
],
@@ -53,7 +48,7 @@ def _upsert_insertion_records(
def _get_net_new_documents(
insertion_records: list[ChunkInsertionRecord],
insertion_records: list[DocumentInsertionRecord],
) -> int:
net_new_documents = 0
seen_documents: set[str] = set()
@@ -71,106 +66,61 @@ def _indexing_pipeline(
*,
chunker: Chunker,
embedder: Embedder,
vector_index: VectorIndex,
keyword_index: KeywordIndex,
document_index: DocumentIndex,
documents: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[int, int]:
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
# Chunk the documents into reasonably-sized chunks so they can fit into the
# context-sizes of our embedding models
chunks = list(chain(*[chunker.chunk(document=document) for document in documents]))
chunks: list[DocAwareChunk] = list(
chain(*[chunker.chunk(document=document) for document in documents])
)
logger.debug(
f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
)
# Insert the chunks into our Keyword document store + store records of these
# documents / chunks into our database
# TODO keyword indexing can occur at same time as embedding
keyword_store_insertion_records = keyword_index.index(
chunks=chunks, index_attempt_metadata=index_attempt_metadata
)
logger.debug(f"Keyword store insertion records: {keyword_store_insertion_records}")
# TODO (chris): remove this try/except after issue with null document_id is resolved
try:
_upsert_insertion_records(
insertion_records=keyword_store_insertion_records,
index_attempt_metadata=index_attempt_metadata,
document_store_type=StoreType.KEYWORD,
)
except Exception as e:
logger.error(
f"Failed to upsert insertion records from keyword index for documents: "
f"{[document.to_short_descriptor() for document in documents]}, "
f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks]},"
f"for insertion records: {keyword_store_insertion_records}"
)
raise e
net_doc_count_keyword = _get_net_new_documents(
insertion_records=keyword_store_insertion_records
)
# Embed the chunks and then insert them into our Vector document store
# + store records of these documents / chunks into our database
chunks_with_embeddings = embedder.embed(chunks=chunks)
vector_store_insertion_records = vector_index.index(
# A document will not be spread across different batches, so all the documents with chunks in this set are fully
# represented by the chunks in this set
insertion_records = document_index.index(
chunks=chunks_with_embeddings, index_attempt_metadata=index_attempt_metadata
)
logger.debug(f"Vector store insertion records: {keyword_store_insertion_records}")
# TODO (chris): remove this try/except after issue with null document_id is resolved
try:
_upsert_insertion_records(
insertion_records=vector_store_insertion_records,
insertion_records=insertion_records,
index_attempt_metadata=index_attempt_metadata,
document_store_type=StoreType.VECTOR,
)
except Exception as e:
logger.error(
f"Failed to upsert insertion records from vector index for documents: "
f"{[document.to_short_descriptor() for document in documents]}, "
f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks_with_embeddings]}"
f"for insertion records: {vector_store_insertion_records}"
f"for insertion records: {insertion_records}"
)
raise e
net_doc_count_vector = _get_net_new_documents(
insertion_records=vector_store_insertion_records
)
if net_doc_count_vector != net_doc_count_keyword:
logger.warning("Document count change from keyword/vector indices don't align")
net_new_docs = max(net_doc_count_keyword, net_doc_count_vector)
logger.info(f"Indexed {net_new_docs} new documents")
return net_new_docs, len(chunks)
return len(insertion_records), len(chunks)
def build_indexing_pipeline(
*,
chunker: Chunker | None = None,
embedder: Embedder | None = None,
vector_index: VectorIndex | None = None,
keyword_index: KeywordIndex | None = None,
document_index: DocumentIndex | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them.
"""Builds a pipline which takes in a list (batch) of docs and indexes them."""
chunker = chunker or DefaultChunker()
Default uses _ chunker, _ embedder, and qdrant for the datastore"""
if chunker is None:
chunker = DefaultChunker()
embedder = embedder or DefaultEmbedder()
if embedder is None:
embedder = DefaultEmbedder()
if vector_index is None:
vector_index = QdrantIndex()
if keyword_index is None:
keyword_index = TypesenseIndex()
document_index = document_index or get_default_document_index()
return partial(
_indexing_pipeline,
chunker=chunker,
embedder=embedder,
vector_index=vector_index,
keyword_index=keyword_index,
document_index=document_index,
)
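
A hedged sketch of driving the rebuilt pipeline for one batch of documents (module path assumed; `documents` is a placeholder batch produced by a connector):

from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.indexing_pipeline import build_indexing_pipeline

index_pipeline = build_indexing_pipeline()  # default chunker, embedder, and document index

documents = []  # placeholder: a batch of Document objects from a connector

net_new_docs, num_chunks = index_pipeline(
    documents=documents,
    index_attempt_metadata=IndexAttemptMetadata(connector_id=1, credential_id=1),
)
print(f"Indexed {net_new_docs} new documents across {num_chunks} chunks")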

View File

@@ -1,62 +1,63 @@
import abc
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Generic
from typing import TypeVar
from uuid import UUID
from danswer.chunking.models import BaseChunk
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
T = TypeVar("T", bound=BaseChunk)
IndexFilter = dict[str, str | list[str] | None]
class StoreType(str, Enum):
VECTOR = "vector"
KEYWORD = "keyword"
@dataclass
class ChunkInsertionRecord:
@dataclass(frozen=True)
class DocumentInsertionRecord:
document_id: str
store_id: str
already_existed: bool
@dataclass
class ChunkMetadata:
class DocumentMetadata:
connector_id: int
credential_id: int
document_id: str
store_id: str
document_store_type: StoreType
@dataclass
class UpdateRequest:
ids: list[str]
"""For all document_ids, update the allowed_users and the boost to the new value
ignore if None"""
document_ids: list[str]
# all other fields will be left alone
allowed_users: list[str]
allowed_users: list[str] | None = None
boost: int | None = None
class Indexable(Generic[T], abc.ABC):
class Verifiable(abc.ABC):
@abc.abstractmethod
def __init__(self, index_name: str, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.index_name = index_name
@abc.abstractmethod
def ensure_indices_exist(self) -> None:
raise NotImplementedError
class Indexable(abc.ABC):
@abc.abstractmethod
def index(
self, chunks: list[T], index_attempt_metadata: IndexAttemptMetadata
) -> list[ChunkInsertionRecord]:
self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
) -> set[DocumentInsertionRecord]:
"""Indexes document chunks into the Document Index and return the IDs of all the documents indexed"""
raise NotImplementedError
class Deletable(abc.ABC):
@abc.abstractmethod
def delete(self, ids: list[str]) -> None:
def delete(self, doc_ids: list[str]) -> None:
"""Removes the specified documents from the Index"""
raise NotImplementedError
@@ -64,11 +65,23 @@ class Deletable(abc.ABC):
class Updatable(abc.ABC):
@abc.abstractmethod
def update(self, update_requests: list[UpdateRequest]) -> None:
"""Updates metadata for the specified documents in the Index"""
"""Updates metadata for the specified documents sets in the Index"""
raise NotImplementedError
class VectorIndex(Indexable[EmbeddedIndexChunk], Deletable, Updatable, abc.ABC):
class KeywordCapable(abc.ABC):
@abc.abstractmethod
def keyword_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
) -> list[InferenceChunk]:
raise NotImplementedError
class VectorCapable(abc.ABC):
@abc.abstractmethod
def semantic_retrieval(
self,
@@ -76,13 +89,14 @@ class VectorIndex(Indexable[EmbeddedIndexChunk], Deletable, Updatable, abc.ABC):
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
) -> list[InferenceChunk]:
raise NotImplementedError
class KeywordIndex(Indexable[IndexChunk], Deletable, Updatable, abc.ABC):
class HybridCapable(abc.ABC):
@abc.abstractmethod
def keyword_search(
def hybrid_retrieval(
self,
query: str,
user_id: UUID | None,
@@ -90,3 +104,27 @@ class KeywordIndex(Indexable[IndexChunk], Deletable, Updatable, abc.ABC):
num_to_retrieve: int,
) -> list[InferenceChunk]:
raise NotImplementedError
class BaseIndex(Verifiable, Indexable, Updatable, Deletable, abc.ABC):
"""All basic functionalities excluding a specific retrieval approach
Indices need to be able to
- Check that the index exists with a schema definition
- Can index documents
- Can delete documents
- Can update document metadata (such as access permissions and document specific boost)
"""
pass
class KeywordIndex(KeywordCapable, BaseIndex, abc.ABC):
pass
class VectorIndex(VectorCapable, BaseIndex, abc.ABC):
pass
class DocumentIndex(KeywordCapable, VectorCapable, HybridCapable, BaseIndex, abc.ABC):
pass
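
A short sketch of the new document-level UpdateRequest semantics implied by these interfaces; fields left as None are left untouched by the index:

from danswer.datastores.interfaces import UpdateRequest

# Rewrite the access list for two documents, leaving their boost untouched.
acl_update = UpdateRequest(
    document_ids=["doc-1", "doc-2"],  # placeholder ids
    allowed_users=["user@example.com", "PUBLIC"],
)

# Bump the boost for one document, leaving its access list untouched.
boost_update = UpdateRequest(document_ids=["doc-3"], boost=2)

# Any DocumentIndex implementation (Vespa, or the split Typesense + Qdrant pair) accepts both:
# document_index.update(update_requests=[acl_update, boost_update])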

View File

@@ -1,18 +1,14 @@
import json
from functools import partial
from typing import cast
from uuid import UUID
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.models.models import UpdateResult
from qdrant_client.models import CollectionsResponse
from qdrant_client.models import Distance
from qdrant_client.models import PointStruct
from qdrant_client.models import VectorParams
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
@@ -24,7 +20,6 @@ from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
@@ -32,7 +27,7 @@ from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.qdrant.utils import get_payload_from_record
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logger import setup_logger
@@ -41,22 +36,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def list_qdrant_collections() -> CollectionsResponse:
return get_qdrant_client().get_collections()
def create_qdrant_collection(
collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
) -> None:
logger.info(f"Attempting to create collection {collection_name}")
result = get_qdrant_client().create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
if not result:
raise RuntimeError("Could not create Qdrant collection")
def get_qdrant_document_cross_connector_metadata(
doc_chunk_id: str, collection_name: str, q_client: QdrantClient
) -> CrossConnectorDocumentMetadata | None:
@@ -113,18 +92,18 @@ def delete_qdrant_doc_chunks(
def index_qdrant_chunks(
chunks: list[EmbeddedIndexChunk],
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
collection: str,
client: QdrantClient | None = None,
batch_upsert: bool = True,
) -> list[ChunkInsertionRecord]:
) -> set[DocumentInsertionRecord]:
# Public documents will have the PUBLIC string in ALLOWED_USERS
# If the credential that kicked this off has no associated user, either auth is off or the doc is public
q_client: QdrantClient = client if client else get_qdrant_client()
point_structs: list[PointStruct] = []
insertion_records: list[ChunkInsertionRecord] = []
insertion_records: set[DocumentInsertionRecord] = set()
# Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
cross_connector_document_metadata_map: dict[
str, CrossConnectorDocumentMetadata
@@ -152,12 +131,15 @@ def index_qdrant_chunks(
delete_qdrant_doc_chunks(document.id, collection, q_client)
already_existing_documents.add(document.id)
for minichunk_ind, embedding in enumerate(chunk.embeddings):
embeddings = chunk.embeddings
vector_list = [embeddings.full_embedding]
vector_list.extend(embeddings.mini_chunk_embeddings)
for minichunk_ind, embedding in enumerate(vector_list):
qdrant_id = str(get_uuid_from_chunk(chunk, minichunk_ind))
insertion_records.append(
ChunkInsertionRecord(
insertion_records.add(
DocumentInsertionRecord(
document_id=document.id,
store_id=qdrant_id,
already_existed=document.id in already_existing_documents,
)
)
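
Because DocumentInsertionRecord is now a frozen (hashable) dataclass keyed on the document rather than the chunk, using a set deduplicates records even though each chunk contributes one Qdrant point per full embedding and mini-chunk embedding. A small sketch with placeholder vectors:

from danswer.chunking.models import ChunkEmbedding
from danswer.datastores.interfaces import DocumentInsertionRecord

embedding = ChunkEmbedding(
    full_embedding=[0.1] * 768,               # placeholder vector
    mini_chunk_embeddings=[[0.2] * 768] * 2,  # placeholder mini-chunk vectors
)
vector_list = [embedding.full_embedding, *embedding.mini_chunk_embeddings]  # 3 Qdrant points

records: set[DocumentInsertionRecord] = set()
for _ in vector_list:
    records.add(DocumentInsertionRecord(document_id="doc-1", already_existed=False))
print(len(records))  # -> 1: one record per document, not one per point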

View File

@@ -1,6 +1,6 @@
from typing import Any
from uuid import UUID
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import FieldCondition
@@ -8,22 +8,25 @@ from qdrant_client.http.models import Filter
from qdrant_client.http.models import MatchAny
from qdrant_client.http.models import MatchValue
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.utils import batch_generator
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.indexing import index_qdrant_chunks
from danswer.datastores.qdrant.utils import create_qdrant_collection
from danswer.datastores.qdrant.utils import list_qdrant_collections
from danswer.search.search_utils import get_default_embedding_model
from danswer.utils.batching import batch_generator
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -81,16 +84,51 @@ def _build_qdrant_filters(
return filter_conditions
def _get_points_from_document_ids(
document_ids: list[str],
collection: str,
client: QdrantClient,
) -> list[int | str]:
offset: int | str | None = 0
chunk_ids = []
while offset is not None:
matches, offset = client.scroll(
collection_name=collection,
scroll_filter=Filter(
must=[FieldCondition(key=DOCUMENT_ID, match=MatchAny(any=document_ids))]
),
limit=_BATCH_SIZE,
with_payload=False,
with_vectors=False,
offset=offset,
)
for match in matches:
chunk_ids.append(match.id)
return chunk_ids
class QdrantIndex(VectorIndex):
def __init__(self, collection: str = QDRANT_DEFAULT_COLLECTION) -> None:
self.collection = collection
def __init__(self, index_name: str = DOCUMENT_INDEX_NAME) -> None:
# In Qdrant, the vector index is referred to as a collection
self.collection = index_name
self.client = get_qdrant_client()
def ensure_indices_exist(self) -> None:
if self.collection not in {
collection.name
for collection in list_qdrant_collections(self.client).collections
}:
logger.info(f"Creating Qdrant collection with name: {self.collection}")
create_qdrant_collection(
collection_name=self.collection, q_client=self.client
)
def index(
self,
chunks: list[EmbeddedIndexChunk],
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
) -> list[ChunkInsertionRecord]:
) -> set[DocumentInsertionRecord]:
return index_qdrant_chunks(
chunks=chunks,
index_attempt_metadata=index_attempt_metadata,
@@ -98,6 +136,35 @@ class QdrantIndex(VectorIndex):
client=self.client,
)
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(
f"Updating {len(update_requests)} documents' allowed_users in Qdrant"
)
for update_request in update_requests:
for doc_id_batch in batch_generator(
items=update_request.document_ids,
batch_size=_BATCH_SIZE,
):
chunk_ids = _get_points_from_document_ids(
doc_id_batch, self.collection, self.client
)
self.client.set_payload(
collection_name=self.collection,
payload={ALLOWED_USERS: update_request.allowed_users},
points=chunk_ids,
)
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Qdrant")
for doc_id_batch in batch_generator(items=doc_ids, batch_size=_BATCH_SIZE):
chunk_ids = _get_points_from_document_ids(
doc_id_batch, self.collection, self.client
)
self.client.delete(
collection_name=self.collection,
points_selector=chunk_ids,
)
@log_function_time()
def semantic_retrieval(
self,
@@ -105,8 +172,8 @@ class QdrantIndex(VectorIndex):
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int = NUM_RETURNED_HITS,
page_size: int = NUM_RETURNED_HITS,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
page_size: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk]:
query_embedding = get_default_embedding_model().encode(
query
@@ -156,45 +223,3 @@ class QdrantIndex(VectorIndex):
found_chunk_uuids.add(inf_chunk_id)
return found_inference_chunks
def delete(self, ids: list[str]) -> None:
logger.info(f"Deleting {len(ids)} documents from Qdrant")
for id_batch in batch_generator(items=ids, batch_size=_BATCH_SIZE):
self.client.delete(
collection_name=self.collection,
points_selector=id_batch,
)
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(
f"Updating {len(update_requests)} documents' allowed_users in Qdrant"
)
for update_request in update_requests:
for id_batch in batch_generator(
items=update_request.ids,
batch_size=_BATCH_SIZE,
):
self.client.set_payload(
collection_name=self.collection,
payload={ALLOWED_USERS: update_request.allowed_users},
points=id_batch,
)
def get_from_id(self, object_id: str) -> InferenceChunk | None:
matches, _ = self.client.scroll(
collection_name=self.collection,
scroll_filter=Filter(
must=[FieldCondition(key="id", match=MatchValue(value=object_id))]
),
)
if not matches:
return None
if len(matches) > 1:
logger.error(f"Found multiple matches for {logger}: {matches}")
match = matches[0]
if not match.payload:
return None
return InferenceChunk.from_dict(match.payload)

View File

@@ -1,6 +1,39 @@
from typing import Any
from qdrant_client import QdrantClient
from qdrant_client.http.models import Record
from qdrant_client.models import CollectionsResponse
from qdrant_client.models import Distance
from qdrant_client.models import VectorParams
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logger import setup_logger
logger = setup_logger()
def list_qdrant_collections(
q_client: QdrantClient | None = None,
) -> CollectionsResponse:
q_client = q_client or get_qdrant_client()
return q_client.get_collections()
def create_qdrant_collection(
collection_name: str,
embedding_dim: int = DOC_EMBEDDING_DIM,
q_client: QdrantClient | None = None,
) -> None:
q_client = q_client or get_qdrant_client()
logger.info(f"Attempting to create collection {collection_name}")
result = q_client.create_collection(
collection_name=collection_name,
vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
)
if not result:
raise RuntimeError("Could not create Qdrant collection")
def get_payload_from_record(record: Record, is_required: bool = True) -> dict[str, Any]:

View File

@@ -7,10 +7,10 @@ from uuid import UUID
import typesense # type: ignore
from typesense.exceptions import ObjectNotFound # type: ignore
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
@@ -24,17 +24,17 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.utils import batch_generator
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.utils.batching import batch_generator
from danswer.utils.clients import get_typesense_client
from danswer.utils.logger import setup_logger
@@ -45,19 +45,8 @@ logger = setup_logger()
_BATCH_SIZE = 200
def check_typesense_collection_exist(
collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
) -> bool:
client = get_typesense_client()
try:
client.collections[collection_name].retrieve()
except ObjectNotFound:
return False
return True
def create_typesense_collection(
collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
collection_name: str = DOCUMENT_INDEX_NAME,
) -> None:
ts_client = get_typesense_client()
collection_schema = {
@@ -81,7 +70,18 @@ def create_typesense_collection(
ts_client.collections.create(collection_schema)
def get_typesense_document_cross_connector_metadata(
def _check_typesense_collection_exist(
collection_name: str = DOCUMENT_INDEX_NAME,
) -> bool:
client = get_typesense_client()
try:
client.collections[collection_name].retrieve()
except ObjectNotFound:
return False
return True
def _get_typesense_document_cross_connector_metadata(
doc_chunk_id: str, collection_name: str, ts_client: typesense.Client
) -> CrossConnectorDocumentMetadata | None:
"""Returns whether the document already exists and the users/group whitelists"""
@@ -115,7 +115,7 @@ def get_typesense_document_cross_connector_metadata(
)
def delete_typesense_doc_chunks(
def _delete_typesense_doc_chunks(
document_id: str, collection_name: str, ts_client: typesense.Client
) -> bool:
doc_id_filter = {"filter_by": f"{DOCUMENT_ID}:'{document_id}'"}
@@ -126,16 +126,16 @@ def delete_typesense_doc_chunks(
return del_result["num_deleted"] != 0
def index_typesense_chunks(
chunks: list[IndexChunk | EmbeddedIndexChunk],
def _index_typesense_chunks(
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
collection: str,
client: typesense.Client | None = None,
batch_upsert: bool = True,
) -> list[ChunkInsertionRecord]:
) -> set[DocumentInsertionRecord]:
ts_client: typesense.Client = client if client else get_typesense_client()
insertion_records: list[ChunkInsertionRecord] = []
insertion_records: set[DocumentInsertionRecord] = set()
new_documents: list[dict[str, Any]] = []
cross_connector_document_metadata_map: dict[
str, CrossConnectorDocumentMetadata
@@ -151,7 +151,7 @@ def index_typesense_chunks(
chunk=chunk,
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
doc_store_cross_connector_document_metadata_fetch_fn=partial(
get_typesense_document_cross_connector_metadata,
_get_typesense_document_cross_connector_metadata,
collection_name=collection,
ts_client=ts_client,
),
@@ -160,14 +160,13 @@ def index_typesense_chunks(
if should_delete_doc:
# Processing the first chunk of the doc and the doc exists
delete_typesense_doc_chunks(document.id, collection, ts_client)
_delete_typesense_doc_chunks(document.id, collection, ts_client)
already_existing_documents.add(document.id)
typesense_id = str(get_uuid_from_chunk(chunk))
insertion_records.append(
ChunkInsertionRecord(
insertion_records.add(
DocumentInsertionRecord(
document_id=document.id,
store_id=typesense_id,
already_existed=document.id in already_existing_documents,
)
)
@@ -250,49 +249,55 @@ def _build_typesense_filters(
class TypesenseIndex(KeywordIndex):
def __init__(self, collection: str = TYPESENSE_DEFAULT_COLLECTION) -> None:
self.collection = collection
def __init__(self, index_name: str = DOCUMENT_INDEX_NAME) -> None:
# In Typesense, the document index is referred to as a collection
self.collection = index_name
self.ts_client = get_typesense_client()
def ensure_indices_exist(self) -> None:
if not _check_typesense_collection_exist(self.collection):
logger.info(f"Creating Typesense collection with name: {self.collection}")
create_typesense_collection(collection_name=self.collection)
def index(
self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
) -> list[ChunkInsertionRecord]:
return index_typesense_chunks(
) -> set[DocumentInsertionRecord]:
return _index_typesense_chunks(
chunks=chunks,
index_attempt_metadata=index_attempt_metadata,
collection=self.collection,
client=self.ts_client,
)
def delete(self, ids: list[str]) -> None:
logger.info(f"Deleting {len(ids)} documents from Typesense")
for id_batch in batch_generator(items=ids, batch_size=_BATCH_SIZE):
self.ts_client.collections[self.collection].documents.delete(
{"filter_by": f'id:[{",".join(id_batch)}]'}
)
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(
f"Updating {len(update_requests)} documents' allowed_users in Typesense"
)
for update_request in update_requests:
for id_batch in batch_generator(
items=update_request.ids, batch_size=_BATCH_SIZE
items=update_request.document_ids, batch_size=_BATCH_SIZE
):
typesense_updates = [
{"id": doc_id, ALLOWED_USERS: update_request.allowed_users}
{DOCUMENT_ID: doc_id, ALLOWED_USERS: update_request.allowed_users}
for doc_id in id_batch
]
self.ts_client.collections[self.collection].documents.import_(
typesense_updates, {"action": "update"}
)
def keyword_search(
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Typesense")
for id_batch in batch_generator(items=doc_ids, batch_size=_BATCH_SIZE):
self.ts_client.collections[self.collection].documents.delete(
{"filter_by": f'{DOCUMENT_ID}:[{",".join(id_batch)}]'}
)
def keyword_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
num_to_retrieve: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk]:
filters_str = _build_typesense_filters(user_id, filters)

View File

@@ -0,0 +1,89 @@
schema danswer_chunk {
document danswer_chunk {
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute
}
field chunk_id type int {
indexing: summary | attribute
}
field blurb type string {
indexing: summary | attribute
}
# Can separate out title in the future and give heavier bm-25 weighting
# Need to consider that not every doc has a separable title (i.e. a Slack message)
# Set summary options to enable bolding
field content type string {
indexing: summary | attribute | index
index: enable-bm25
}
# https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it
field source_type type string {
indexing: summary | attribute
}
# Can also index links https://docs.vespa.ai/en/reference/schema-reference.html#attribute
# URL type matching
field source_links type string {
indexing: summary | attribute
}
field semantic_identifier type string {
indexing: summary | attribute
}
field section_continuation type bool {
indexing: summary | attribute
}
field boost type float {
indexing: summary | attribute
}
field metadata type string {
indexing: summary | attribute
}
field embeddings type tensor<float>(t{},x[768]) {
indexing: attribute
attribute {
distance-metric: angular
}
}
field allowed_users type array<string> {
indexing: summary | attribute
attribute: fast-search
}
field allowed_groups type array<string> {
indexing: summary | attribute
attribute: fast-search
}
}
fieldset default {
fields: content
}
rank-profile keyword_search inherits default {
first-phase {
expression: bm25(content) * attribute(boost)
}
}
rank-profile semantic_search inherits default {
inputs {
query(query_embedding) tensor<float>(x[768])
}
first-phase {
expression: closeness(field, embeddings) * attribute(boost)
}
match-features: closest(embeddings)
}
rank-profile hybrid_search inherits default {
inputs {
query(query_embedding) tensor<float>(x[768])
}
first-phase {
expression: bm25(content)
}
second-phase {
expression: closeness(field, embeddings) * attribute(boost)
}
match-features: closest(embeddings)
}
}
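
To make the rank profiles above concrete, a minimal sketch of a keyword query against the search endpoint; the host, query string, and filter are placeholders, and the keyword_search profile scores hits as bm25(content) * attribute(boost):

import requests

params = {
    "yql": 'select documentid, document_id, content from danswer_index where '
           'allowed_users contains "PUBLIC" and ({grammar: "weakAnd"}userInput(@query))',
    "query": "how do i reset my password",
    "hits": 10,
    "ranking.profile": "keyword_search",
}
response = requests.get("http://localhost:8081/search/", params=params)
hits = response.json()["root"].get("children", [])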

View File

@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8" ?>
<services version="1.0" xmlns:deploy="vespa" xmlns:preprocess="properties">
<container id="default" version="1.0">
<document-api/>
<search/>
<nodes>
<node hostalias="node1" />
</nodes>
</container>
<content id="danswer_index" version="1.0">
<redundancy>2</redundancy>
<documents>
<document type="danswer_chunk" mode="index" />
</documents>
<nodes>
<node hostalias="node1" distribution-key="0" />
</nodes>
</content>
</services>
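
This services.xml ships alongside the schema as a zipped application package (the vespa-app.zip referenced by VESPA_DEPLOYMENT_ZIP). A minimal packaging sketch, assuming the application files live under an app_config/ directory:

import os
import zipfile

APP_DIR = "danswer/datastores/vespa/app_config"  # assumed location of services.xml and schemas/

with zipfile.ZipFile("vespa-app.zip", "w", zipfile.ZIP_DEFLATED) as zf:
    for root, _dirs, files in os.walk(APP_DIR):
        for name in files:
            path = os.path.join(root, name)
            # Store paths relative to the app dir so Vespa sees services.xml at the zip root.
            zf.write(path, arcname=os.path.relpath(path, APP_DIR))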

View File

@@ -0,0 +1,402 @@
import json
from collections.abc import Mapping
from functools import partial
from typing import cast
from uuid import UUID
import requests
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
from danswer.configs.app_configs import VESPA_HOST
from danswer.configs.app_configs import VESPA_PORT
from danswer.configs.app_configs import VESPA_TENANT_PORT
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
from danswer.configs.constants import BOOST
from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import EMBEDDINGS
from danswer.configs.constants import METADATA
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import UpdateRequest
from danswer.search.search_utils import get_default_embedding_model
from danswer.utils.logger import setup_logger
logger = setup_logger()
VESPA_CONFIG_SERVER_URL = f"http://{VESPA_HOST}:{VESPA_TENANT_PORT}"
VESPA_APP_CONTAINER_URL = f"http://{VESPA_HOST}:{VESPA_PORT}"
VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd
DOCUMENT_ID_ENDPOINT = (
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/danswer_chunk/docid"
)
SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
_BATCH_SIZE = 100 # Specific to Vespa
def _get_vespa_document_cross_connector_metadata(
doc_chunk_id: str,
) -> CrossConnectorDocumentMetadata | None:
"""Returns whether the document already exists and the users/group whitelists"""
doc_fetch_response = requests.get(f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}")
if doc_fetch_response.status_code == 404:
return None
if doc_fetch_response.status_code != 200:
raise RuntimeError(
f"Unexpected fetch document by ID value from Vespa "
f"with error {doc_fetch_response.status_code}"
)
doc_fields = doc_fetch_response.json()["fields"]
allowed_users = doc_fields.get(ALLOWED_USERS)
allowed_groups = doc_fields.get(ALLOWED_GROUPS)
# Add group permission later, empty list gets saved/loaded as a null
if allowed_users is None:
raise RuntimeError(
"Vespa Index is corrupted, Document found with no access lists."
)
return CrossConnectorDocumentMetadata(
allowed_users=allowed_users or [],
allowed_user_groups=allowed_groups or [],
already_in_index=True,
)
def _get_vespa_chunk_ids_by_document_id(
document_id: str, hits_per_page: int = _BATCH_SIZE
) -> list[str]:
offset = 0
doc_chunk_ids = []
params: dict[str, int | str] = {
"yql": f"select documentid from {DOCUMENT_INDEX_NAME} where document_id contains '{document_id}'",
"timeout": "10s",
"offset": offset,
"hits": hits_per_page,
}
while True:
results = requests.get(SEARCH_ENDPOINT, params=params).json()
hits = results["root"].get("children", [])
doc_chunk_ids.extend(
[hit.get("fields", {}).get("documentid").split("::")[1] for hit in hits]
)
params["offset"] += hits_per_page # type: ignore
if len(hits) < hits_per_page:
break
return doc_chunk_ids
def _delete_vespa_doc_chunks(document_id: str) -> bool:
doc_chunk_ids = _get_vespa_chunk_ids_by_document_id(document_id)
failures = [
requests.delete(f"{DOCUMENT_ID_ENDPOINT}/{doc}").status_code != 200
for doc in doc_chunk_ids
]
return not any(failures)
def _index_vespa_chunks(
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
) -> set[DocumentInsertionRecord]:
json_header = {"Content-Type": "application/json"}
insertion_records: set[DocumentInsertionRecord] = set()
cross_connector_document_metadata_map: dict[
str, CrossConnectorDocumentMetadata
] = {}
# document ids of documents that existed BEFORE this indexing
already_existing_documents: set[str] = set()
for chunk in chunks:
document = chunk.source_document
(
cross_connector_document_metadata_map,
should_delete_doc,
) = update_cross_connector_document_metadata_map(
chunk=chunk,
cross_connector_document_metadata_map=cross_connector_document_metadata_map,
doc_store_cross_connector_document_metadata_fetch_fn=partial(
_get_vespa_document_cross_connector_metadata,
),
index_attempt_metadata=index_attempt_metadata,
)
if should_delete_doc:
# Processing the first chunk of the doc and the doc exists
deletion_success = _delete_vespa_doc_chunks(document.id)
if not deletion_success:
logger.error(
f"Failed to delete pre-existing chunks for with document with id: {document.id}"
)
already_existing_documents.add(document.id)
vespa_chunk_id = str(get_uuid_from_chunk(chunk))
embeddings = chunk.embeddings
embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
if embeddings.mini_chunk_embeddings:
for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed
vespa_document = {
"fields": {
DOCUMENT_ID: document.id,
CHUNK_ID: chunk.chunk_id,
BLURB: chunk.blurb,
CONTENT: chunk.content,
SOURCE_TYPE: str(document.source.value),
SOURCE_LINKS: json.dumps(chunk.source_links),
SEMANTIC_IDENTIFIER: document.semantic_identifier,
SECTION_CONTINUATION: chunk.section_continuation,
METADATA: json.dumps(document.metadata),
EMBEDDINGS: embeddings_name_vector_map,
BOOST: 1, # Boost value always starts at 1 for 0 impact on weight
ALLOWED_USERS: cross_connector_document_metadata_map[
document.id
].allowed_users,
ALLOWED_GROUPS: cross_connector_document_metadata_map[
document.id
].allowed_user_groups,
}
}
url = f"{DOCUMENT_ID_ENDPOINT}/{vespa_chunk_id}"
res = requests.post(url, headers=json_header, json=vespa_document)
res.raise_for_status()
insertion_records.add(
DocumentInsertionRecord(
document_id=document.id,
already_existed=document.id in already_existing_documents,
)
)
return insertion_records
def _build_vespa_filters(
user_id: UUID | None, filters: list[IndexFilter] | None
) -> str:
filter_str = ""
# Permissions filter
# TODO group permissioning
if user_id:
filter_str += (
f'({ALLOWED_USERS} contains "{user_id}" or '
f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}") and '
)
else:
filter_str += f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}" and '
# Provided query filters
if filters:
for filter_dict in filters:
valid_filters = {
key: value for key, value in filter_dict.items() if value is not None
}
for filter_key, filter_val in valid_filters.items():
if isinstance(filter_val, str):
filter_str += f'{filter_key} = "{filter_val}" and '
elif isinstance(filter_val, list):
quoted_elems = [f'"{elem}"' for elem in filter_val]
filters_or = ",".join(quoted_elems)
filter_str += f"{filter_key} in [{filters_or}] and "
else:
raise ValueError("Invalid filters provided")
return filter_str
def _query_vespa(query_params: Mapping[str, str | int]) -> list[InferenceChunk]:
if "query" in query_params and not cast(str, query_params["query"]).strip():
raise ValueError(
"Query only consisted of stopwords, should not use Keyword Search"
)
response = requests.get(SEARCH_ENDPOINT, params=query_params)
response.raise_for_status()
hits = response.json()["root"].get("children", [])
inference_chunks = [InferenceChunk.from_dict(hit["fields"]) for hit in hits]
return inference_chunks
class VespaIndex(DocumentIndex):
yql_base = (
f"select "
f"documentid, "
f"{DOCUMENT_ID}, "
f"{CHUNK_ID}, "
f"{BLURB}, "
f"{CONTENT}, "
f"{SOURCE_TYPE}, "
f"{SOURCE_LINKS}, "
f"{SEMANTIC_IDENTIFIER}, "
f"{SECTION_CONTINUATION}, "
f"{BOOST}, "
f"{METADATA} "
f"from {DOCUMENT_INDEX_NAME} where "
)
def __init__(self, deployment_zip: str = VESPA_DEPLOYMENT_ZIP) -> None:
        # The Vespa index name isn't configurable via code alone: the schema (.sd) config file
        # would also need to be updated, re-zipped, and redeployed, so the option is not supported for simplicity
self.deployment_zip = deployment_zip
def ensure_indices_exist(self) -> None:
"""Verifying indices is more involved as there is no good way to
verify the deployed app against the zip locally. But deploying the latest app.zip will ensure that
the index is up-to-date with the expected schema and this does not erase the existing index.
If the changes cannot be applied without conflict with existing data, it will fail with a non 200
"""
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
headers = {"Content-Type": "application/zip"}
with open(self.deployment_zip, "rb") as f:
response = requests.post(deploy_url, headers=headers, data=f)
if response.status_code != 200:
raise RuntimeError(
f"Failed to prepare Vespa Danswer Index. Response: {response.text}"
)
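    # Illustrative note (assumption, not part of the original file): the deploy above mirrors
    # posting the zip manually against the Vespa config server, e.g.
    #   curl -X POST -H "Content-Type: application/zip" \
    #        --data-binary @"$VESPA_DEPLOYMENT_ZIP" \
    #        "$VESPA_APPLICATION_ENDPOINT/tenant/default/prepareandactivate"
    # A 200 response indicates the application package was prepared and activated.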
def index(
self,
chunks: list[IndexChunk],
index_attempt_metadata: IndexAttemptMetadata,
) -> set[DocumentInsertionRecord]:
return _index_vespa_chunks(
chunks=chunks, index_attempt_metadata=index_attempt_metadata
)
def update(self, update_requests: list[UpdateRequest]) -> None:
logger.info(
f"Updating {len(update_requests)} documents' allowed_users in Vespa"
)
json_header = {"Content-Type": "application/json"}
for update_request in update_requests:
            if update_request.boost is None and update_request.allowed_users is None:
                logger.error("Update request received but nothing to update")
                continue
            update_dict: dict[str, dict[str, list[str] | int]] = {"fields": {}}
            if update_request.boost is not None:
                update_dict["fields"][BOOST] = update_request.boost
            if update_request.allowed_users is not None:
                update_dict["fields"][ALLOWED_USERS] = update_request.allowed_users
for document_id in update_request.document_ids:
for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(document_id):
url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
res = requests.put(url, headers=json_header, json=update_dict)
if res.status_code != 200:
logger.error(f"Failed to update document: {document_id}")
def delete(self, doc_ids: list[str]) -> None:
logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
for doc_id in doc_ids:
for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(doc_id):
url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
logger.debug("Deleting: " + url)
requests.delete(url)
def keyword_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk]:
vespa_where_clauses = _build_vespa_filters(user_id, filters)
yql = (
VespaIndex.yql_base
+ vespa_where_clauses
+ '({grammar: "weakAnd"}userInput(@query))'
)
params: dict[str, str | int] = {
"yql": yql,
"query": query,
"hits": num_to_retrieve,
"num_to_rerank": 10 * num_to_retrieve,
"ranking.profile": "keyword_search",
}
return _query_vespa(params)
def semantic_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
) -> list[InferenceChunk]:
vespa_where_clauses = _build_vespa_filters(user_id, filters)
yql = (
VespaIndex.yql_base
+ vespa_where_clauses
+ f"({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding))"
)
query_embedding = get_default_embedding_model().encode(query)
if not isinstance(query_embedding, list):
query_embedding = query_embedding.tolist()
params = {
"yql": yql,
"input.query(query_embedding)": str(query_embedding),
"ranking.profile": "semantic_search",
}
return _query_vespa(params)
def hybrid_retrieval(
self,
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
num_to_retrieve: int,
) -> list[InferenceChunk]:
vespa_where_clauses = _build_vespa_filters(user_id, filters)
yql = (
VespaIndex.yql_base
+ vespa_where_clauses
            + f'({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding) or {{grammar: "weakAnd"}}userInput(@query))'
)
query_embedding = get_default_embedding_model().encode(query)
if not isinstance(query_embedding, list):
query_embedding = query_embedding.tolist()
params = {
"yql": yql,
"query": query,
"input.query(query_embedding)": str(query_embedding),
"ranking.profile": "hybrid_search",
}
return _query_vespa(params)
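# Illustrative usage sketch (assumption, not part of the original file): wiring the pieces above
# together might look roughly like the following, assuming a running Vespa instance and that
# InferenceChunk exposes the fields selected in yql_base (e.g. semantic_identifier, blurb).
#
#   index = VespaIndex()
#   index.ensure_indices_exist()
#   chunks = index.keyword_retrieval(query="pto policy", user_id=None, filters=None)
#   for chunk in chunks:
#       print(chunk.semantic_identifier, chunk.blurb)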

View File

@@ -1,5 +1,4 @@
from collections.abc import Sequence
from dataclasses import dataclass
from sqlalchemy import and_
from sqlalchemy import delete
@@ -8,8 +7,7 @@ from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from danswer.datastores.interfaces import ChunkMetadata
from danswer.db.models import Chunk
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.utils import model_to_dict
@@ -18,11 +16,11 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_chunks_with_single_connector_credential_pair(
def get_documents_with_single_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
) -> Sequence[Chunk]:
) -> Sequence[Document]:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
@@ -30,6 +28,8 @@ def get_chunks_with_single_connector_credential_pair(
)
)
# Filter it down to the documents with only a single connector/credential pair
# Meaning if this connector/credential pair is removed, this doc should be gone
trimmed_doc_ids_stmt = (
select(Document.id)
.join(
@@ -41,7 +41,7 @@ def get_chunks_with_single_connector_credential_pair(
.having(func.count(DocumentByConnectorCredentialPair.id) == 1)
)
stmt = select(Chunk).where(Chunk.document_id.in_(trimmed_doc_ids_stmt))
stmt = select(Document).where(Document.id.in_(trimmed_doc_ids_stmt))
return db_session.scalars(stmt).all()
@@ -57,6 +57,8 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
)
)
# Filter it down to the documents with more than 1 connector/credential pair
# Meaning if this connector/credential pair is removed, this doc is still accessible
trimmed_doc_ids_stmt = (
select(Document.id)
.join(
@@ -75,15 +77,8 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
return db_session.execute(stmt).scalars().all()
def get_chunk_ids_for_document_ids(
db_session: Session, document_ids: list[str]
) -> Sequence[str]:
stmt = select(Chunk.id).where(Chunk.document_id.in_(document_ids))
return db_session.execute(stmt).scalars().all()
def upsert_documents(
db_session: Session, document_metadata_batch: list[ChunkMetadata]
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
seen_document_ids: set[str] = set()
@@ -102,7 +97,7 @@ def upsert_documents(
def upsert_document_by_connector_credential_pair(
db_session: Session, document_metadata_batch: list[ChunkMetadata]
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
insert_stmt = insert(DocumentByConnectorCredentialPair).values(
@@ -124,45 +119,16 @@ def upsert_document_by_connector_credential_pair(
db_session.commit()
def upsert_chunks(
db_session: Session, document_metadata_batch: list[ChunkMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
insert_stmt = insert(Chunk).values(
[
model_to_dict(
Chunk(
id=document_metadata.store_id,
document_id=document_metadata.document_id,
document_store_type=document_metadata.document_store_type,
)
)
for document_metadata in document_metadata_batch
]
)
on_conflict_stmt = insert_stmt.on_conflict_do_update(
index_elements=["id", "document_store_type"],
set_=dict(document_id=insert_stmt.excluded.document_id),
)
db_session.execute(on_conflict_stmt)
db_session.commit()
def upsert_documents_complete(
db_session: Session, document_metadata_batch: list[ChunkMetadata]
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
upsert_chunks(db_session, document_metadata_batch)
logger.info(
f"Upserted {len(document_metadata_batch)} document store entries into DB"
)
def delete_document_store_entries(db_session: Session, document_ids: list[str]) -> None:
db_session.execute(delete(Chunk).where(Chunk.document_id.in_(document_ids)))
def delete_document_by_connector_credential_pair(
db_session: Session, document_ids: list[str]
) -> None:
@@ -179,7 +145,6 @@ def delete_documents(db_session: Session, document_ids: list[str]) -> None:
def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
logger.info(f"Deleting {len(document_ids)} documents from the DB")
delete_document_store_entries(db_session, document_ids)
delete_document_by_connector_credential_pair(db_session, document_ids)
delete_documents(db_session, document_ids)
db_session.commit()

View File

@@ -25,7 +25,6 @@ from sqlalchemy.orm import relationship
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.datastores.interfaces import StoreType
class IndexingStatus(str, PyEnum):
@@ -185,7 +184,7 @@ class IndexAttempt(Base):
nullable=True,
)
status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
num_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
num_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
error_msg: Mapped[str | None] = mapped_column(
String(), default=None
) # only filled if status = "failed"
@@ -195,7 +194,7 @@ class IndexAttempt(Base):
)
# when the actual indexing run began
# NOTE: will use the api_server clock rather than DB server clock
time_started: Mapped[datetime.datetime] = mapped_column(
time_started: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), default=None
)
time_updated: Mapped[datetime.datetime] = mapped_column(
@@ -261,10 +260,7 @@ class DeletionAttempt(Base):
class Document(Base):
"""Represents a single documents from a source. This is used to store
document level metadata so we don't need to duplicate it in a bunch of
DocumentByConnectorCredentialPair's/Chunk's for documents
that are split into many chunks and/or indexed by many connector / credential
pairs."""
document level metadata, but currently nothing is stored"""
__tablename__ = "document"
@@ -272,10 +268,6 @@ class Document(Base):
# in Danswer)
id: Mapped[str] = mapped_column(String, primary_key=True)
document_store_entries: Mapped["Chunk"] = relationship(
"Chunk", back_populates="document"
)
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential
@@ -297,21 +289,3 @@ class DocumentByConnectorCredentialPair(Base):
credential: Mapped[Credential] = relationship(
"Credential", back_populates="documents_by_credential"
)
class Chunk(Base):
"""A row represents a single entry in a document store (e.g. a single chunk
in Qdrant/Typesense)"""
__tablename__ = "chunk"
# this should correspond to the ID in the document store
id: Mapped[str] = mapped_column(String, primary_key=True)
document_store_type: Mapped[StoreType] = mapped_column(
Enum(StoreType), primary_key=True
)
document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))
document: Mapped[Document] = relationship(
"Document", back_populates="document_store_entries"
)

View File

@@ -2,8 +2,7 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
@@ -43,12 +42,12 @@ def answer_question(
user_id = None if user is None else user.id
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query, user_id, filters, TypesenseIndex(collection)
query, user_id, filters, get_default_document_index(collection=collection)
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, QdrantIndex(collection)
query, user_id, filters, get_default_document_index(collection=collection)
)
if not ranked_chunks:
return QAResponse(

View File

@@ -13,7 +13,7 @@ from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.configs.app_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.app_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.configs.app_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.constants import DocumentSource
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer
@@ -206,7 +206,7 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
answer = _get_answer(
QuestionRequest(
query=req.payload.get("event", {}).get("text"),
collection=QDRANT_DEFAULT_COLLECTION,
collection=DOCUMENT_INDEX_NAME,
use_keyword=None,
filters=None,
offset=None,

View File

@@ -21,17 +21,13 @@ from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import OAUTH_TYPE
from danswer.configs.app_configs import OPENID_CONFIG_URL
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.typesense.store import check_typesense_collection_exist
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.datastores.document_index import get_default_document_index
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_llm
from danswer.server.event_loading import router as event_processing_router
@@ -149,7 +145,6 @@ def get_application() -> FastAPI:
from danswer.search.search_utils import (
warm_up_models,
)
from danswer.datastores.qdrant.indexing import create_qdrant_collection
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
@@ -190,19 +185,8 @@ def get_application() -> FastAPI:
logger.info("Verifying public credential exists.")
create_initial_public_credential()
logger.info("Verifying Document Indexes are available.")
if QDRANT_DEFAULT_COLLECTION not in {
collection.name for collection in list_qdrant_collections().collections
}:
logger.info(
f"Creating Qdrant collection with name: {QDRANT_DEFAULT_COLLECTION}"
)
create_qdrant_collection(collection_name=QDRANT_DEFAULT_COLLECTION)
if not check_typesense_collection_exist(TYPESENSE_DEFAULT_COLLECTION):
logger.info(
f"Creating Typesense collection with name: {TYPESENSE_DEFAULT_COLLECTION}"
)
create_typesense_collection(collection_name=TYPESENSE_DEFAULT_COLLECTION)
logger.info("Verifying Document Index(s) is/are available.")
get_default_document_index().ensure_indices_exist()
return application

View File

@@ -7,8 +7,8 @@ from nltk.tokenize import word_tokenize # type:ignore
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -38,11 +38,11 @@ def retrieve_keyword_documents(
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
datastore: KeywordIndex,
datastore: DocumentIndex,
num_hits: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk] | None:
edited_query = query_processing(query)
top_chunks = datastore.keyword_search(edited_query, user_id, filters, num_hits)
top_chunks = datastore.keyword_retrieval(edited_query, user_id, filters, num_hits)
if not top_chunks:
filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
logger.warning(

View File

@@ -1,6 +1,6 @@
from enum import Enum
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import DocAwareChunk
from danswer.chunking.models import IndexChunk
@@ -15,5 +15,5 @@ class QueryFlow(str, Enum):
class Embedder:
def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
raise NotImplementedError

View File

@@ -4,7 +4,8 @@ from uuid import UUID
import numpy
from sentence_transformers import SentenceTransformer # type: ignore
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import ChunkEmbedding
from danswer.chunking.models import DocAwareChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
@@ -12,8 +13,8 @@ from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import VectorIndex
from danswer.search.models import Embedder
from danswer.search.search_utils import get_default_embedding_model
from danswer.search.search_utils import get_default_reranking_model_ensemble
@@ -69,7 +70,7 @@ def retrieve_ranked_documents(
query: str,
user_id: UUID | None,
filters: list[IndexFilter] | None,
datastore: VectorIndex,
datastore: DocumentIndex,
num_hits: int = NUM_RETURNED_HITS,
num_rerank: int = NUM_RERANKED_RESULTS,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
@@ -127,12 +128,12 @@ def split_chunk_text_into_mini_chunks(
@log_function_time()
def encode_chunks(
chunks: list[IndexChunk],
chunks: list[DocAwareChunk],
embedding_model: SentenceTransformer | None = None,
batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[EmbeddedIndexChunk]:
embedded_chunks: list[EmbeddedIndexChunk] = []
) -> list[IndexChunk]:
embedded_chunks: list[IndexChunk] = []
if embedding_model is None:
embedding_model = get_default_embedding_model()
@@ -163,9 +164,12 @@ def encode_chunks(
chunk_embeddings = embeddings[
embedding_ind_start : embedding_ind_start + num_embeddings
]
new_embedded_chunk = EmbeddedIndexChunk(
new_embedded_chunk = IndexChunk(
**{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
embeddings=chunk_embeddings,
embeddings=ChunkEmbedding(
full_embedding=chunk_embeddings[0],
mini_chunk_embeddings=chunk_embeddings[1:],
),
)
embedded_chunks.append(new_embedded_chunk)
embedding_ind_start += num_embeddings
@@ -174,5 +178,5 @@ def encode_chunks(
class DefaultEmbedder(Embedder):
def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
return encode_chunks(chunks)

View File

@@ -10,8 +10,7 @@ from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.exceptions import OpenAIKeyMissing
@@ -60,7 +59,7 @@ def semantic_search(
user_id = None if user is None else user.id
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, QdrantIndex(collection)
query, user_id, filters, get_default_document_index(collection=collection)
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
@@ -82,7 +81,7 @@ def keyword_search(
user_id = None if user is None else user.id
ranked_chunks = retrieve_keyword_documents(
query, user_id, filters, TypesenseIndex(collection)
query, user_id, filters, get_default_document_index(collection=collection)
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
@@ -110,6 +109,8 @@ def stream_direct_qa(
logger.debug(f"Received QA query: {question.query}")
logger.debug(f"Query filters: {question.filters}")
if question.use_keyword:
logger.debug(f"User selected Keyword Search")
@log_generator_function_time()
def stream_qa_portions(
@@ -128,12 +129,18 @@ def stream_direct_qa(
user_id = None if user is None else user.id
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query, user_id, filters, TypesenseIndex(collection)
query,
user_id,
filters,
get_default_document_index(collection=collection),
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, QdrantIndex(collection)
query,
user_id,
filters,
get_default_document_index(collection=collection),
)
if not ranked_chunks:
logger.debug("No Documents Found")

View File

@@ -1,4 +1,4 @@
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.utils.clients import get_typesense_client
@@ -14,9 +14,7 @@ if __name__ == "__main__":
"page": page_number,
"per_page": per_page,
}
response = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.search(
params
)
response = ts_client.collections[DOCUMENT_INDEX_NAME].documents.search(params)
documents = response.get("hits")
if not documents:
break # if there are no more documents, break out of the loop

View File

@@ -11,17 +11,16 @@ from typesense.exceptions import ObjectNotFound # type: ignore
from alembic import command
from alembic.config import Config
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import QDRANT_HOST
from danswer.configs.app_configs import QDRANT_PORT
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.datastores.qdrant.indexing import create_qdrant_collection
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.qdrant.utils import create_qdrant_collection
from danswer.datastores.qdrant.utils import list_qdrant_collections
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.utils.clients import get_qdrant_client
from danswer.utils.clients import get_typesense_client
@@ -60,13 +59,13 @@ def snapshot_time_compare(snap: SnapshotDescription) -> datetime:
def save_qdrant(filename: str) -> None:
logger.info("Attempting to take Qdrant snapshot")
qdrant_client = get_qdrant_client()
qdrant_client.create_snapshot(collection_name=QDRANT_DEFAULT_COLLECTION)
snapshots = qdrant_client.list_snapshots(collection_name=QDRANT_DEFAULT_COLLECTION)
qdrant_client.create_snapshot(collection_name=DOCUMENT_INDEX_NAME)
snapshots = qdrant_client.list_snapshots(collection_name=DOCUMENT_INDEX_NAME)
valid_snapshots = [snap for snap in snapshots if snap.creation_time is not None]
sorted_snapshots = sorted(valid_snapshots, key=snapshot_time_compare)
last_snapshot_name = sorted_snapshots[-1].name
url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/{last_snapshot_name}"
url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{DOCUMENT_INDEX_NAME}/snapshots/{last_snapshot_name}"
response = requests.get(url, stream=True)
@@ -80,11 +79,11 @@ def save_qdrant(filename: str) -> None:
def load_qdrant(filename: str) -> None:
logger.info("Attempting to load Qdrant snapshot")
if QDRANT_DEFAULT_COLLECTION not in {
if DOCUMENT_INDEX_NAME not in {
collection.name for collection in list_qdrant_collections().collections
}:
create_qdrant_collection(QDRANT_DEFAULT_COLLECTION)
snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/"
create_qdrant_collection(DOCUMENT_INDEX_NAME)
snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{DOCUMENT_INDEX_NAME}/snapshots/"
with open(filename, "rb") as f:
files = {"snapshot": (os.path.basename(filename), f)}
@@ -104,7 +103,7 @@ def load_qdrant(filename: str) -> None:
def save_typesense(filename: str) -> None:
logger.info("Attempting to take Typesense snapshot")
ts_client = get_typesense_client()
all_docs = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.export()
all_docs = ts_client.collections[DOCUMENT_INDEX_NAME].documents.export()
with open(filename, "w") as f:
f.write(all_docs)
@@ -113,14 +112,14 @@ def load_typesense(filename: str) -> None:
logger.info("Attempting to load Typesense snapshot")
ts_client = get_typesense_client()
try:
ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].delete()
ts_client.collections[DOCUMENT_INDEX_NAME].delete()
except ObjectNotFound:
pass
create_typesense_collection(TYPESENSE_DEFAULT_COLLECTION)
create_typesense_collection(DOCUMENT_INDEX_NAME)
with open(filename) as jsonl_file:
ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.import_(
ts_client.collections[DOCUMENT_INDEX_NAME].documents.import_(
jsonl_file.read().encode("utf-8"), {"action": "create"}
)

View File

@@ -7,7 +7,7 @@ from pprint import pprint
import requests
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.constants import SOURCE_TYPE
@@ -95,7 +95,7 @@ if __name__ == "__main__":
query_json = {
"query": query,
"collection": QDRANT_DEFAULT_COLLECTION,
"collection": DOCUMENT_INDEX_NAME,
"use_keyword": flow_type == "keyword", # Ignore if not QA Endpoints
"filters": [{SOURCE_TYPE: source_types}],
}