Mirror of https://github.com/danswer-ai/danswer.git

Add Vespa and rework Document Indices (#317)
backend/.gitignore

@@ -6,4 +6,5 @@ api_keys.py
*ipynb
qdrant-data/
typesense-data/
.env
vespa-app.zip
@@ -0,0 +1,39 @@
"""Restructure Document Indices

Revision ID: 8aabb57f3b49
Revises: 5e84129c8be3
Create Date: 2023-08-18 21:15:57.629515

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "8aabb57f3b49"
down_revision = "5e84129c8be3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_table("chunk")
    op.execute("DROP TYPE IF EXISTS documentstoretype")


def downgrade() -> None:
    op.create_table(
        "chunk",
        sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column(
            "document_store_type",
            postgresql.ENUM("VECTOR", "KEYWORD", name="documentstoretype"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column("document_id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(
            ["document_id"], ["document.id"], name="chunk_document_id_fkey"
        ),
        sa.PrimaryKeyConstraint("id", "document_store_type", name="chunk_pkey"),
    )
@@ -17,12 +17,9 @@ from datetime import datetime
from sqlalchemy.orm import Session

from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import StoreType
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.connector import fetch_connector_by_id
from danswer.db.connector_credential_pair import delete_connector_credential_pair
from danswer.db.connector_credential_pair import get_connector_credential_pair
@@ -31,13 +28,12 @@ from danswer.db.deletion_attempt import delete_deletion_attempts
from danswer.db.deletion_attempt import get_deletion_attempts
from danswer.db.document import delete_document_by_connector_credential_pair
from danswer.db.document import delete_documents_complete
from danswer.db.document import get_chunk_ids_for_document_ids
from danswer.db.document import (
    get_chunks_with_single_connector_credential_pair,
)
from danswer.db.document import (
    get_document_by_connector_credential_pairs_indexed_by_multiple,
)
from danswer.db.document import (
    get_documents_with_single_connector_credential_pair,
)
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import delete_index_attempts
from danswer.db.models import Credential
@@ -50,8 +46,7 @@ logger = setup_logger()

def _delete_connector_credential_pair(
    db_session: Session,
    vector_index: VectorIndex,
    keyword_index: KeywordIndex,
    document_index: DocumentIndex,
    deletion_attempt: DeletionAttempt,
) -> int:
    connector_id = deletion_attempt.connector_id
@@ -59,33 +54,24 @@ def _delete_connector_credential_pair(

    def _delete_singly_indexed_docs() -> int:
        # if a document store entry is only indexed by this connector_credential_pair, delete it
        num_docs_deleted = 0
        chunks_to_delete = get_chunks_with_single_connector_credential_pair(
        docs_to_delete = get_documents_with_single_connector_credential_pair(
            db_session=db_session,
            connector_id=connector_id,
            credential_id=credential_id,
        )
        if chunks_to_delete:
            document_ids: set[str] = set()
            vector_chunk_ids_to_delete: list[str] = []
            keyword_chunk_ids_to_delete: list[str] = []
            for chunk in chunks_to_delete:
                document_ids.add(chunk.document_id)
                if chunk.document_store_type == StoreType.KEYWORD:
                    keyword_chunk_ids_to_delete.append(chunk.id)
                else:
                    vector_chunk_ids_to_delete.append(chunk.id)

            vector_index.delete(ids=vector_chunk_ids_to_delete)
            keyword_index.delete(ids=keyword_chunk_ids_to_delete)
            # removes all `Chunk`, `DocumentByConnectorCredentialPair`, and `Document`
        if docs_to_delete:
            document_ids = [doc.id for doc in docs_to_delete]
            document_index.delete(doc_ids=document_ids)

            # removes all `DocumentByConnectorCredentialPair`, and `Document`
            # rows from the DB
            delete_documents_complete(
                db_session=db_session,
                document_ids=list(document_ids),
            )
            num_docs_deleted += len(document_ids)
        return num_docs_deleted

        return len(docs_to_delete)

    num_docs_deleted = _delete_singly_indexed_docs()
    logger.info(f"Deleted {num_docs_deleted} documents from document stores")
@@ -144,18 +130,10 @@ def _delete_connector_credential_pair(

        # actually perform the updates in the document store
        update_requests = [
            UpdateRequest(
                ids=list(
                    get_chunk_ids_for_document_ids(
                        db_session=db_session, document_ids=document_ids
                    )
                ),
                allowed_users=list(allowed_users),
            )
            UpdateRequest(document_ids=document_ids, allowed_users=list(allowed_users))
            for allowed_users, document_ids in update_groups.items()
        ]
        vector_index.update(update_requests=update_requests)
        keyword_index.update(update_requests=update_requests)
        document_index.update(update_requests=update_requests)

        # delete the `document_by_connector_credential_pair` rows for the connector / credential pair
        delete_document_by_connector_credential_pair(
@@ -246,8 +224,7 @@ def _run_deletion(db_session: Session) -> None:
    try:
        num_docs_deleted = _delete_connector_credential_pair(
            db_session=db_session,
            vector_index=QdrantIndex(),
            keyword_index=TypesenseIndex(),
            document_index=get_default_document_index(),
            deletion_attempt=deletion_attempt,
        )
    except Exception as e:
@@ -2,7 +2,7 @@ import abc
import re
from collections.abc import Callable

from danswer.chunking.models import IndexChunk
from danswer.chunking.models import DocAwareChunk
from danswer.configs.app_configs import BLURB_LENGTH
from danswer.configs.app_configs import CHUNK_MAX_CHAR_OVERLAP
from danswer.configs.app_configs import CHUNK_SIZE
@@ -12,7 +12,7 @@ from danswer.connectors.models import Section
from danswer.utils.text_processing import shared_precompare_cleanup

SECTION_SEPARATOR = "\n\n"
ChunkFunc = Callable[[Document], list[IndexChunk]]
ChunkFunc = Callable[[Document], list[DocAwareChunk]]


def extract_blurb(text: str, blurb_len: int) -> str:
@@ -51,7 +51,7 @@ def chunk_large_section(
    word_overlap: int = CHUNK_WORD_OVERLAP,
    blurb_len: int = BLURB_LENGTH,
    chunk_overflow_max: int = CHUNK_MAX_CHAR_OVERLAP,
) -> list[IndexChunk]:
) -> list[DocAwareChunk]:
    """Split large sections into multiple chunks with the final chunk having as much previous overlap as possible.
    Backtracks word_overlap words, delimited by whitespace, backtrack up to chunk_overflow_max characters max
    When chunk is finished in forward direction, attempt to finish the word, but only up to chunk_overflow_max
@@ -129,7 +129,7 @@ def chunk_large_section(
    chunks = []
    for chunk_ind, chunk_str in enumerate(chunk_strs):
        chunks.append(
            IndexChunk(
            DocAwareChunk(
                source_document=document,
                chunk_id=start_chunk_id + chunk_ind,
                blurb=blurb,
@@ -146,8 +146,8 @@ def chunk_document(
    chunk_size: int = CHUNK_SIZE,
    subsection_overlap: int = CHUNK_WORD_OVERLAP,
    blurb_len: int = BLURB_LENGTH,
) -> list[IndexChunk]:
    chunks: list[IndexChunk] = []
) -> list[DocAwareChunk]:
    chunks: list[DocAwareChunk] = []
    link_offsets: dict[int, str] = {}
    chunk_text = ""
    for section in document.sections:
@@ -160,7 +160,7 @@ def chunk_document(
        if section_length > chunk_size:
            if chunk_text:
                chunks.append(
                    IndexChunk(
                    DocAwareChunk(
                        source_document=document,
                        chunk_id=len(chunks),
                        blurb=extract_blurb(chunk_text, blurb_len),
@@ -191,7 +191,7 @@ def chunk_document(
                link_offsets[curr_offset_len] = section.link
        else:
            chunks.append(
                IndexChunk(
                DocAwareChunk(
                    source_document=document,
                    chunk_id=len(chunks),
                    blurb=extract_blurb(chunk_text, blurb_len),
@@ -206,7 +206,7 @@ def chunk_document(
    # Once we hit the end, if we're still in the process of building a chunk, add what we have
    if chunk_text:
        chunks.append(
            IndexChunk(
            DocAwareChunk(
                source_document=document,
                chunk_id=len(chunks),
                blurb=extract_blurb(chunk_text, blurb_len),
@@ -220,10 +220,10 @@ def chunk_document(

class Chunker:
    @abc.abstractmethod
    def chunk(self, document: Document) -> list[IndexChunk]:
    def chunk(self, document: Document) -> list[DocAwareChunk]:
        raise NotImplementedError


class DefaultChunker(Chunker):
    def chunk(self, document: Document) -> list[IndexChunk]:
    def chunk(self, document: Document) -> list[DocAwareChunk]:
        return chunk_document(document)
@@ -14,6 +14,15 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()


Embedding = list[float]


@dataclass
class ChunkEmbedding:
    full_embedding: Embedding
    mini_chunk_embeddings: list[Embedding]


@dataclass
class BaseChunk:
    chunk_id: int
@@ -26,7 +35,7 @@ class BaseChunk:


@dataclass
class IndexChunk(BaseChunk):
class DocAwareChunk(BaseChunk):
    # During indexing flow, we have access to a complete "Document"
    # During inference we only have access to the document id and do not reconstruct the Document
    source_document: Document
@@ -39,8 +48,8 @@ class IndexChunk(BaseChunk):


@dataclass
class EmbeddedIndexChunk(IndexChunk):
    embeddings: list[list[float]]
class IndexChunk(DocAwareChunk):
    embeddings: ChunkEmbedding


@dataclass
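
Editor's note: as a small illustration of the reshaped embedding model above (a sketch, not part of this commit), each IndexChunk now carries one full embedding plus any mini-chunk embeddings instead of a flat list of vectors. The 768-dim size is only assumed here because it matches the Vespa schema added later in this diff.

from danswer.chunking.models import ChunkEmbedding

# Hypothetical 768-dim vectors (placeholder values).
embedding = ChunkEmbedding(
    full_embedding=[0.1] * 768,
    mini_chunk_embeddings=[[0.2] * 768, [0.3] * 768],
)

# Consumers such as index_qdrant_chunks (further down in this diff) flatten
# it back into a list of vectors, full embedding first:
vector_list = [embedding.full_embedding, *embedding.mini_chunk_embeddings]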
@@ -1,5 +1,7 @@
import os

from danswer.configs.constants import DocumentIndexType

#####
# App Configs
#####
@@ -62,20 +64,28 @@ MASK_CREDENTIAL_PREFIX = (
#####
# DB Configs
#####
DOCUMENT_INDEX_NAME = "danswer_index"  # Shared by vector/keyword indices
# Vespa is now the default document index store for both keyword and vector
DOCUMENT_INDEX_TYPE = os.environ.get(
    "DOCUMENT_INDEX_TYPE", DocumentIndexType.SPLIT.value
)
VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# The default below is for dockerized deployment
VESPA_DEPLOYMENT_ZIP = (
    os.environ.get("VESPA_DEPLOYMENT_ZIP") or "/app/danswer/vespa-app.zip"
)
# Qdrant is Semantic Search Vector DB
# Url / Key are used to connect to a remote Qdrant instance
QDRANT_URL = os.environ.get("QDRANT_URL", "")
QDRANT_API_KEY = os.environ.get("QDRANT_API_KEY", "")
# Host / Port are used for connecting to local Qdrant instance
QDRANT_HOST = os.environ.get("QDRANT_HOST", "localhost")
QDRANT_HOST = os.environ.get("QDRANT_HOST") or "localhost"
QDRANT_PORT = 6333
QDRANT_DEFAULT_COLLECTION = os.environ.get("QDRANT_DEFAULT_COLLECTION", "danswer_index")
# Typesense is the Keyword Search Engine
TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST", "localhost")
TYPESENSE_HOST = os.environ.get("TYPESENSE_HOST") or "localhost"
TYPESENSE_PORT = 8108
TYPESENSE_DEFAULT_COLLECTION = os.environ.get(
    "TYPESENSE_DEFAULT_COLLECTION", "danswer_index"
)
TYPESENSE_API_KEY = os.environ.get("TYPESENSE_API_KEY", "")
# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
INDEX_BATCH_SIZE = 16
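
Editor's note: the new DOCUMENT_INDEX_TYPE setting above is what later drives index selection in get_default_document_index (shown further down in this diff). A minimal deployment-side sketch, not part of the commit; the docker service names are assumptions and the value "combined" must match DocumentIndexType.COMBINED.value from danswer.configs.constants.

import os

# Opt into the combined (Vespa) document index; anything else falls back to
# the SPLIT default (Typesense + Qdrant).
os.environ["DOCUMENT_INDEX_TYPE"] = "combined"
os.environ["VESPA_HOST"] = "vespa"         # e.g. a docker-compose service name (assumption)
os.environ["VESPA_PORT"] = "8081"          # Vespa query/document API port
os.environ["VESPA_TENANT_PORT"] = "19071"  # Vespa config/deploy API port

# Must be set before danswer.configs.app_configs is imported, since the module
# reads the environment at import time.
from danswer.configs.app_configs import DOCUMENT_INDEX_TYPE
print(DOCUMENT_INDEX_TYPE)  # -> "combined"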
@@ -9,6 +9,7 @@ SOURCE_LINKS = "source_links"
SOURCE_LINK = "link"
SEMANTIC_IDENTIFIER = "semantic_identifier"
SECTION_CONTINUATION = "section_continuation"
EMBEDDINGS = "embeddings"
ALLOWED_USERS = "allowed_users"
ALLOWED_GROUPS = "allowed_groups"
METADATA = "metadata"
@@ -16,6 +17,7 @@ GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"
QUOTE = "quote"
BOOST = "boost"


class DocumentSource(str, Enum):
@@ -35,6 +37,11 @@ class DocumentSource(str, Enum):
    LINEAR = "linear"


class DocumentIndexType(str, Enum):
    COMBINED = "combined"  # Vespa
    SPLIT = "split"  # Typesense + Qdrant


class DanswerGenAIModel(str, Enum):
    """This represents the internal Danswer GenAI model which determines the class that is used
    to generate responses to the user query. Different models/services require different internal
@@ -24,7 +24,7 @@ from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.connectors.utils import batch_generator
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -5,7 +5,6 @@ from typing import TypeVar

from pydantic import BaseModel

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.constants import PUBLIC_DOC_PAT
@@ -16,7 +15,7 @@ DEFAULT_BATCH_SIZE = 30


def get_uuid_from_chunk(
    chunk: IndexChunk | EmbeddedIndexChunk | InferenceChunk, mini_chunk_ind: int = 0
    chunk: IndexChunk | InferenceChunk, mini_chunk_ind: int = 0
) -> uuid.UUID:
    doc_str = (
        chunk.document_id
@@ -58,7 +57,7 @@ def _add_if_not_exists(l: list[T], item: T) -> list[T]:


def update_cross_connector_document_metadata_map(
    chunk: IndexChunk | EmbeddedIndexChunk,
    chunk: IndexChunk,
    cross_connector_document_metadata_map: dict[str, CrossConnectorDocumentMetadata],
    doc_store_cross_connector_document_metadata_fetch_fn: CrossConnectorDocumentMetadataFetchCallable,
    index_attempt_metadata: IndexAttemptMetadata,
backend/danswer/datastores/document_index.py

@@ -0,0 +1,111 @@
from typing import Type
from uuid import UUID

from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import DOCUMENT_INDEX_TYPE
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentIndexType
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.vespa.store import VespaIndex
from danswer.utils.logger import setup_logger


logger = setup_logger()


class SplitDocumentIndex(DocumentIndex):
    """A Document index that uses 2 separate storages
    one for keyword index and one for vector index"""

    def __init__(
        self,
        index_name: str | None = DOCUMENT_INDEX_NAME,
        keyword_index_cls: Type[KeywordIndex] = TypesenseIndex,
        vector_index_cls: Type[VectorIndex] = QdrantIndex,
    ) -> None:
        index_name = index_name or DOCUMENT_INDEX_NAME
        self.keyword_index = keyword_index_cls(index_name)
        self.vector_index = vector_index_cls(index_name)

    def ensure_indices_exist(self) -> None:
        self.keyword_index.ensure_indices_exist()
        self.vector_index.ensure_indices_exist()

    def index(
        self,
        chunks: list[IndexChunk],
        index_attempt_metadata: IndexAttemptMetadata,
    ) -> set[DocumentInsertionRecord]:
        keyword_index_result = self.keyword_index.index(chunks, index_attempt_metadata)
        vector_index_result = self.vector_index.index(chunks, index_attempt_metadata)
        if keyword_index_result != vector_index_result:
            logger.error(
                f"Inconsistent document indexing:\n"
                f"Keyword: {keyword_index_result}\n"
                f"Vector: {vector_index_result}"
            )
        return keyword_index_result.union(vector_index_result)

    def update(self, update_requests: list[UpdateRequest]) -> None:
        self.keyword_index.update(update_requests)
        self.vector_index.update(update_requests)

    def delete(self, doc_ids: list[str]) -> None:
        self.keyword_index.delete(doc_ids)
        self.vector_index.delete(doc_ids)

    def keyword_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int = NUM_RETURNED_HITS,
    ) -> list[InferenceChunk]:
        return self.keyword_index.keyword_retrieval(
            query, user_id, filters, num_to_retrieve
        )

    def semantic_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int = NUM_RETURNED_HITS,
        distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
    ) -> list[InferenceChunk]:
        return self.vector_index.semantic_retrieval(
            query, user_id, filters, num_to_retrieve, distance_cutoff
        )

    def hybrid_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int,
    ) -> list[InferenceChunk]:
        """Currently results from vector and keyword indices are not combined post retrieval.
        This may change in the future but for now, the default behavior is to use semantic search
        which should be a more flexible/powerful search option"""
        return self.semantic_retrieval(query, user_id, filters, num_to_retrieve)


def get_default_document_index(
    collection: str | None = DOCUMENT_INDEX_NAME, index_type: str = DOCUMENT_INDEX_TYPE
) -> DocumentIndex:
    if index_type == DocumentIndexType.COMBINED.value:
        return VespaIndex()  # Can't specify collection without modifying the deployment
    elif index_type == DocumentIndexType.SPLIT.value:
        return SplitDocumentIndex(index_name=collection)
    else:
        raise ValueError("Invalid document index type selected")
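
Editor's note: to make the factory above concrete, here is a minimal usage sketch (not part of the commit), using only the signatures defined in this file; the call sites in this diff (connector deletion and the indexing pipeline) use it the same way. The query string is illustrative.

from danswer.datastores.document_index import get_default_document_index

# Picks VespaIndex when DOCUMENT_INDEX_TYPE == "combined", otherwise a
# SplitDocumentIndex wrapping TypesenseIndex (keyword) + QdrantIndex (vector).
document_index = get_default_document_index()

# Create the schema/collections if missing, then retrieve.
document_index.ensure_indices_exist()
hits = document_index.semantic_retrieval(
    query="how do I add a connector?",  # illustrative query (assumption)
    user_id=None,                        # None -> public docs / auth disabled
    filters=None,
    num_to_retrieve=10,
)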
@@ -6,15 +6,13 @@ from sqlalchemy.orm import Session

from danswer.chunking.chunk import Chunker
from danswer.chunking.chunk import DefaultChunker
from danswer.chunking.models import DocAwareChunk
from danswer.connectors.models import Document
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import ChunkMetadata
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import StoreType
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.document import upsert_documents_complete
from danswer.db.engine import get_sqlalchemy_engine
from danswer.search.models import Embedder
@@ -32,20 +30,17 @@ class IndexingPipelineProtocol(Protocol):


def _upsert_insertion_records(
    insertion_records: list[ChunkInsertionRecord],
    insertion_records: set[DocumentInsertionRecord],
    index_attempt_metadata: IndexAttemptMetadata,
    document_store_type: StoreType,
) -> None:
    with Session(get_sqlalchemy_engine()) as session:
        upsert_documents_complete(
            db_session=session,
            document_metadata_batch=[
                ChunkMetadata(
                DocumentMetadata(
                    connector_id=index_attempt_metadata.connector_id,
                    credential_id=index_attempt_metadata.credential_id,
                    document_id=insertion_record.document_id,
                    store_id=insertion_record.store_id,
                    document_store_type=document_store_type,
                )
                for insertion_record in insertion_records
            ],
@@ -53,7 +48,7 @@ def _upsert_insertion_records(


def _get_net_new_documents(
    insertion_records: list[ChunkInsertionRecord],
    insertion_records: list[DocumentInsertionRecord],
) -> int:
    net_new_documents = 0
    seen_documents: set[str] = set()
@@ -71,106 +66,61 @@ def _indexing_pipeline(
    *,
    chunker: Chunker,
    embedder: Embedder,
    vector_index: VectorIndex,
    keyword_index: KeywordIndex,
    document_index: DocumentIndex,
    documents: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
) -> tuple[int, int]:
    """Takes different pieces of the indexing pipeline and applies it to a batch of documents
    Note that the documents should already be batched at this point so that it does not inflate the
    memory requirements"""
    # Chunk the documents into reasonably-sized chunks so they can fit into the
    # context-sizes of our embedding models
    chunks = list(chain(*[chunker.chunk(document=document) for document in documents]))
    chunks: list[DocAwareChunk] = list(
        chain(*[chunker.chunk(document=document) for document in documents])
    )
    logger.debug(
        f"Indexing the following chunks: {[chunk.to_short_descriptor() for chunk in chunks]}"
    )

    # Insert the chunks into our Keyword document store + store records of these
    # documents / chunks into our database
    # TODO keyword indexing can occur at same time as embedding
    keyword_store_insertion_records = keyword_index.index(
        chunks=chunks, index_attempt_metadata=index_attempt_metadata
    )
    logger.debug(f"Keyword store insertion records: {keyword_store_insertion_records}")
    # TODO (chris): remove this try/except after issue with null document_id is resolved
    try:
        _upsert_insertion_records(
            insertion_records=keyword_store_insertion_records,
            index_attempt_metadata=index_attempt_metadata,
            document_store_type=StoreType.KEYWORD,
        )
    except Exception as e:
        logger.error(
            f"Failed to upsert insertion records from keyword index for documents: "
            f"{[document.to_short_descriptor() for document in documents]}, "
            f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks]},"
            f"for insertion records: {keyword_store_insertion_records}"
        )
        raise e
    net_doc_count_keyword = _get_net_new_documents(
        insertion_records=keyword_store_insertion_records
    )

    # Embed the chunks and then insert them into our Vector document store
    # + store records of these documents / chunks into our database
    chunks_with_embeddings = embedder.embed(chunks=chunks)
    vector_store_insertion_records = vector_index.index(

    # A document will not be spread across different batches, so all the documents with chunks in this set, are fully
    # represented by the chunks in this set
    insertion_records = document_index.index(
        chunks=chunks_with_embeddings, index_attempt_metadata=index_attempt_metadata
    )
    logger.debug(f"Vector store insertion records: {keyword_store_insertion_records}")

    # TODO (chris): remove this try/except after issue with null document_id is resolved
    try:
        _upsert_insertion_records(
            insertion_records=vector_store_insertion_records,
            insertion_records=insertion_records,
            index_attempt_metadata=index_attempt_metadata,
            document_store_type=StoreType.VECTOR,
        )
    except Exception as e:
        logger.error(
            f"Failed to upsert insertion records from vector index for documents: "
            f"{[document.to_short_descriptor() for document in documents]}, "
            f"for chunks: {[chunk.to_short_descriptor() for chunk in chunks_with_embeddings]}"
            f"for insertion records: {vector_store_insertion_records}"
            f"for insertion records: {insertion_records}"
        )
        raise e
    net_doc_count_vector = _get_net_new_documents(
        insertion_records=vector_store_insertion_records
    )

    if net_doc_count_vector != net_doc_count_keyword:
        logger.warning("Document count change from keyword/vector indices don't align")
    net_new_docs = max(net_doc_count_keyword, net_doc_count_vector)
    logger.info(f"Indexed {net_new_docs} new documents")
    return net_new_docs, len(chunks)
    return len(insertion_records), len(chunks)


def build_indexing_pipeline(
    *,
    chunker: Chunker | None = None,
    embedder: Embedder | None = None,
    vector_index: VectorIndex | None = None,
    keyword_index: KeywordIndex | None = None,
    document_index: DocumentIndex | None = None,
) -> IndexingPipelineProtocol:
    """Builds a pipeline which takes in a list (batch) of docs and indexes them.
    """Builds a pipeline which takes in a list (batch) of docs and indexes them."""
    chunker = chunker or DefaultChunker()

    Default uses _ chunker, _ embedder, and qdrant for the datastore"""
    if chunker is None:
        chunker = DefaultChunker()
    embedder = embedder or DefaultEmbedder()

    if embedder is None:
        embedder = DefaultEmbedder()

    if vector_index is None:
        vector_index = QdrantIndex()

    if keyword_index is None:
        keyword_index = TypesenseIndex()
    document_index = document_index or get_default_document_index()

    return partial(
        _indexing_pipeline,
        chunker=chunker,
        embedder=embedder,
        vector_index=vector_index,
        keyword_index=keyword_index,
        document_index=document_index,
    )
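
Editor's note: a minimal sketch of driving the reworked pipeline (not part of the commit). The module path danswer.datastores.indexing_pipeline is assumed from context, and documents / index_attempt_metadata are placeholders that a connector run would normally supply.

from danswer.datastores.indexing_pipeline import build_indexing_pipeline  # path assumed

# With no arguments, this uses DefaultChunker, DefaultEmbedder, and whatever
# get_default_document_index() resolves to (Vespa, or Typesense + Qdrant).
index_batch = build_indexing_pipeline()

# `documents` comes from a connector; `index_attempt_metadata` ties the run back
# to the connector / credential pair, as in the rest of this diff.
net_new_docs, num_chunks = index_batch(
    documents=documents,
    index_attempt_metadata=index_attempt_metadata,
)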
@@ -1,62 +1,63 @@
import abc
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import Generic
from typing import TypeVar
from uuid import UUID

from danswer.chunking.models import BaseChunk
from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata


T = TypeVar("T", bound=BaseChunk)
IndexFilter = dict[str, str | list[str] | None]


class StoreType(str, Enum):
    VECTOR = "vector"
    KEYWORD = "keyword"


@dataclass
class ChunkInsertionRecord:
@dataclass(frozen=True)
class DocumentInsertionRecord:
    document_id: str
    store_id: str
    already_existed: bool


@dataclass
class ChunkMetadata:
class DocumentMetadata:
    connector_id: int
    credential_id: int
    document_id: str
    store_id: str
    document_store_type: StoreType


@dataclass
class UpdateRequest:
    ids: list[str]
    """For all document_ids, update the allowed_users and the boost to the new value
    ignore if None"""

    document_ids: list[str]
    # all other fields will be left alone
    allowed_users: list[str]
    allowed_users: list[str] | None = None
    boost: int | None = None


class Indexable(Generic[T], abc.ABC):
class Verifiable(abc.ABC):
    @abc.abstractmethod
    def __init__(self, index_name: str, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.index_name = index_name

    @abc.abstractmethod
    def ensure_indices_exist(self) -> None:
        raise NotImplementedError


class Indexable(abc.ABC):
    @abc.abstractmethod
    def index(
        self, chunks: list[T], index_attempt_metadata: IndexAttemptMetadata
    ) -> list[ChunkInsertionRecord]:
        self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
    ) -> set[DocumentInsertionRecord]:
        """Indexes document chunks into the Document Index and return the IDs of all the documents indexed"""
        raise NotImplementedError


class Deletable(abc.ABC):
    @abc.abstractmethod
    def delete(self, ids: list[str]) -> None:
    def delete(self, doc_ids: list[str]) -> None:
        """Removes the specified documents from the Index"""
        raise NotImplementedError

@@ -64,11 +65,23 @@ class Deletable(abc.ABC):
class Updatable(abc.ABC):
    @abc.abstractmethod
    def update(self, update_requests: list[UpdateRequest]) -> None:
        """Updates metadata for the specified documents in the Index"""
        """Updates metadata for the specified document sets in the Index"""
        raise NotImplementedError


class VectorIndex(Indexable[EmbeddedIndexChunk], Deletable, Updatable, abc.ABC):
class KeywordCapable(abc.ABC):
    @abc.abstractmethod
    def keyword_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int,
    ) -> list[InferenceChunk]:
        raise NotImplementedError


class VectorCapable(abc.ABC):
    @abc.abstractmethod
    def semantic_retrieval(
        self,
@@ -76,13 +89,14 @@ class VectorIndex(Indexable[EmbeddedIndexChunk], Deletable, Updatable, abc.ABC):
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int,
        distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
    ) -> list[InferenceChunk]:
        raise NotImplementedError


class KeywordIndex(Indexable[IndexChunk], Deletable, Updatable, abc.ABC):
class HybridCapable(abc.ABC):
    @abc.abstractmethod
    def keyword_search(
    def hybrid_retrieval(
        self,
        query: str,
        user_id: UUID | None,
@@ -90,3 +104,27 @@ class KeywordIndex(Indexable[IndexChunk], Deletable, Updatable, abc.ABC):
        num_to_retrieve: int,
    ) -> list[InferenceChunk]:
        raise NotImplementedError


class BaseIndex(Verifiable, Indexable, Updatable, Deletable, abc.ABC):
    """All basic functionalities excluding a specific retrieval approach
    Indices need to be able to
    - Check that the index exists with a schema definition
    - Can index documents
    - Can delete documents
    - Can update document metadata (such as access permissions and document specific boost)
    """

    pass


class KeywordIndex(KeywordCapable, BaseIndex, abc.ABC):
    pass


class VectorIndex(VectorCapable, BaseIndex, abc.ABC):
    pass


class DocumentIndex(KeywordCapable, VectorCapable, HybridCapable, BaseIndex, abc.ABC):
    pass
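
Editor's note: a short sketch of the reworked UpdateRequest in use (not part of the commit), built only from the interfaces defined above; the document ids are hypothetical. Fields left as None (here: boost) are not touched, which is the point of the change.

from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import UpdateRequest

# Illustrative permission change: make two documents visible to one extra user.
update = UpdateRequest(
    document_ids=["doc-id-1", "doc-id-2"],          # hypothetical ids
    allowed_users=["PUBLIC", "user@example.com"],   # PUBLIC_DOC_PAT marks public docs
)

document_index = get_default_document_index()
document_index.update(update_requests=[update])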
@@ -1,18 +1,14 @@
import json
from functools import partial
from typing import cast
from uuid import UUID

from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.models.models import UpdateResult
from qdrant_client.models import CollectionsResponse
from qdrant_client.models import Distance
from qdrant_client.models import PointStruct
from qdrant_client.models import VectorParams

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
@@ -24,7 +20,6 @@ from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
@@ -32,7 +27,7 @@ from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
    update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.qdrant.utils import get_payload_from_record
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logger import setup_logger
@@ -41,22 +36,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()


def list_qdrant_collections() -> CollectionsResponse:
    return get_qdrant_client().get_collections()


def create_qdrant_collection(
    collection_name: str, embedding_dim: int = DOC_EMBEDDING_DIM
) -> None:
    logger.info(f"Attempting to create collection {collection_name}")
    result = get_qdrant_client().create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
    )
    if not result:
        raise RuntimeError("Could not create Qdrant collection")


def get_qdrant_document_cross_connector_metadata(
    doc_chunk_id: str, collection_name: str, q_client: QdrantClient
) -> CrossConnectorDocumentMetadata | None:
@@ -113,18 +92,18 @@ def delete_qdrant_doc_chunks(


def index_qdrant_chunks(
    chunks: list[EmbeddedIndexChunk],
    chunks: list[IndexChunk],
    index_attempt_metadata: IndexAttemptMetadata,
    collection: str,
    client: QdrantClient | None = None,
    batch_upsert: bool = True,
) -> list[ChunkInsertionRecord]:
) -> set[DocumentInsertionRecord]:
    # Public documents will have the PUBLIC string in ALLOWED_USERS
    # If credential that kicked this off has no user associated, either Auth is off or the doc is public
    q_client: QdrantClient = client if client else get_qdrant_client()

    point_structs: list[PointStruct] = []
    insertion_records: list[ChunkInsertionRecord] = []
    insertion_records: set[DocumentInsertionRecord] = set()
    # Maps document id to dict of whitelists for users/groups each containing list of users/groups as strings
    cross_connector_document_metadata_map: dict[
        str, CrossConnectorDocumentMetadata
@@ -152,12 +131,15 @@ def index_qdrant_chunks(
            delete_qdrant_doc_chunks(document.id, collection, q_client)
            already_existing_documents.add(document.id)

        for minichunk_ind, embedding in enumerate(chunk.embeddings):
        embeddings = chunk.embeddings
        vector_list = [embeddings.full_embedding]
        vector_list.extend(embeddings.mini_chunk_embeddings)

        for minichunk_ind, embedding in enumerate(vector_list):
            qdrant_id = str(get_uuid_from_chunk(chunk, minichunk_ind))
            insertion_records.append(
                ChunkInsertionRecord(
            insertion_records.add(
                DocumentInsertionRecord(
                    document_id=document.id,
                    store_id=qdrant_id,
                    already_existed=document.id in already_existing_documents,
                )
            )
@@ -1,6 +1,6 @@
from typing import Any
from uuid import UUID

from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import ResponseHandlingException
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import FieldCondition
@@ -8,22 +8,25 @@ from qdrant_client.http.models import Filter
from qdrant_client.http.models import MatchAny
from qdrant_client.http.models import MatchValue

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.utils import batch_generator
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import UpdateRequest
from danswer.datastores.interfaces import VectorIndex
from danswer.datastores.qdrant.indexing import index_qdrant_chunks
from danswer.datastores.qdrant.utils import create_qdrant_collection
from danswer.datastores.qdrant.utils import list_qdrant_collections
from danswer.search.search_utils import get_default_embedding_model
from danswer.utils.batching import batch_generator
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -81,16 +84,51 @@ def _build_qdrant_filters(
    return filter_conditions


def _get_points_from_document_ids(
    document_ids: list[str],
    collection: str,
    client: QdrantClient,
) -> list[int | str]:
    offset: int | str | None = 0
    chunk_ids = []
    while offset is not None:
        matches, offset = client.scroll(
            collection_name=collection,
            scroll_filter=Filter(
                must=[FieldCondition(key=DOCUMENT_ID, match=MatchAny(any=document_ids))]
            ),
            limit=_BATCH_SIZE,
            with_payload=False,
            with_vectors=False,
            offset=offset,
        )
        for match in matches:
            chunk_ids.append(match.id)

    return chunk_ids


class QdrantIndex(VectorIndex):
    def __init__(self, collection: str = QDRANT_DEFAULT_COLLECTION) -> None:
        self.collection = collection
    def __init__(self, index_name: str = DOCUMENT_INDEX_NAME) -> None:
        # In Qdrant, the vector index is referred to as a collection
        self.collection = index_name
        self.client = get_qdrant_client()

    def ensure_indices_exist(self) -> None:
        if self.collection not in {
            collection.name
            for collection in list_qdrant_collections(self.client).collections
        }:
            logger.info(f"Creating Qdrant collection with name: {self.collection}")
            create_qdrant_collection(
                collection_name=self.collection, q_client=self.client
            )

    def index(
        self,
        chunks: list[EmbeddedIndexChunk],
        chunks: list[IndexChunk],
        index_attempt_metadata: IndexAttemptMetadata,
    ) -> list[ChunkInsertionRecord]:
    ) -> set[DocumentInsertionRecord]:
        return index_qdrant_chunks(
            chunks=chunks,
            index_attempt_metadata=index_attempt_metadata,
@@ -98,6 +136,35 @@ class QdrantIndex(VectorIndex):
            client=self.client,
        )

    def update(self, update_requests: list[UpdateRequest]) -> None:
        logger.info(
            f"Updating {len(update_requests)} documents' allowed_users in Qdrant"
        )
        for update_request in update_requests:
            for doc_id_batch in batch_generator(
                items=update_request.document_ids,
                batch_size=_BATCH_SIZE,
            ):
                chunk_ids = _get_points_from_document_ids(
                    doc_id_batch, self.collection, self.client
                )
                self.client.set_payload(
                    collection_name=self.collection,
                    payload={ALLOWED_USERS: update_request.allowed_users},
                    points=chunk_ids,
                )

    def delete(self, doc_ids: list[str]) -> None:
        logger.info(f"Deleting {len(doc_ids)} documents from Qdrant")
        for doc_id_batch in batch_generator(items=doc_ids, batch_size=_BATCH_SIZE):
            chunk_ids = _get_points_from_document_ids(
                doc_id_batch, self.collection, self.client
            )
            self.client.delete(
                collection_name=self.collection,
                points_selector=chunk_ids,
            )

    @log_function_time()
    def semantic_retrieval(
        self,
@@ -105,8 +172,8 @@ class QdrantIndex(VectorIndex):
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int = NUM_RETURNED_HITS,
        page_size: int = NUM_RETURNED_HITS,
        distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
        page_size: int = NUM_RETURNED_HITS,
    ) -> list[InferenceChunk]:
        query_embedding = get_default_embedding_model().encode(
            query
@@ -156,45 +223,3 @@ class QdrantIndex(VectorIndex):
            found_chunk_uuids.add(inf_chunk_id)

        return found_inference_chunks

    def delete(self, ids: list[str]) -> None:
        logger.info(f"Deleting {len(ids)} documents from Qdrant")
        for id_batch in batch_generator(items=ids, batch_size=_BATCH_SIZE):
            self.client.delete(
                collection_name=self.collection,
                points_selector=id_batch,
            )

    def update(self, update_requests: list[UpdateRequest]) -> None:
        logger.info(
            f"Updating {len(update_requests)} documents' allowed_users in Qdrant"
        )
        for update_request in update_requests:
            for id_batch in batch_generator(
                items=update_request.ids,
                batch_size=_BATCH_SIZE,
            ):
                self.client.set_payload(
                    collection_name=self.collection,
                    payload={ALLOWED_USERS: update_request.allowed_users},
                    points=id_batch,
                )

    def get_from_id(self, object_id: str) -> InferenceChunk | None:
        matches, _ = self.client.scroll(
            collection_name=self.collection,
            scroll_filter=Filter(
                must=[FieldCondition(key="id", match=MatchValue(value=object_id))]
            ),
        )
        if not matches:
            return None

        if len(matches) > 1:
            logger.error(f"Found multiple matches for {logger}: {matches}")

        match = matches[0]
        if not match.payload:
            return None

        return InferenceChunk.from_dict(match.payload)
@@ -1,6 +1,39 @@
from typing import Any

from qdrant_client import QdrantClient
from qdrant_client.http.models import Record
from qdrant_client.models import CollectionsResponse
from qdrant_client.models import Distance
from qdrant_client.models import VectorParams

from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.utils.clients import get_qdrant_client
from danswer.utils.logger import setup_logger


logger = setup_logger()


def list_qdrant_collections(
    q_client: QdrantClient | None = None,
) -> CollectionsResponse:
    q_client = q_client or get_qdrant_client()
    return q_client.get_collections()


def create_qdrant_collection(
    collection_name: str,
    embedding_dim: int = DOC_EMBEDDING_DIM,
    q_client: QdrantClient | None = None,
) -> None:
    q_client = q_client or get_qdrant_client()
    logger.info(f"Attempting to create collection {collection_name}")
    result = q_client.create_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=embedding_dim, distance=Distance.COSINE),
    )
    if not result:
        raise RuntimeError("Could not create Qdrant collection")


def get_payload_from_record(record: Record, is_required: bool = True) -> dict[str, Any]:
@@ -7,10 +7,10 @@ from uuid import UUID
import typesense  # type: ignore
from typesense.exceptions import ObjectNotFound  # type: ignore

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
@@ -24,17 +24,17 @@ from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.connectors.models import IndexAttemptMetadata
from danswer.connectors.utils import batch_generator
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import DEFAULT_BATCH_SIZE
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
    update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import ChunkInsertionRecord
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.datastores.interfaces import UpdateRequest
from danswer.utils.batching import batch_generator
from danswer.utils.clients import get_typesense_client
from danswer.utils.logger import setup_logger
@@ -45,19 +45,8 @@ logger = setup_logger()
_BATCH_SIZE = 200


def check_typesense_collection_exist(
    collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
) -> bool:
    client = get_typesense_client()
    try:
        client.collections[collection_name].retrieve()
    except ObjectNotFound:
        return False
    return True


def create_typesense_collection(
    collection_name: str = TYPESENSE_DEFAULT_COLLECTION,
    collection_name: str = DOCUMENT_INDEX_NAME,
) -> None:
    ts_client = get_typesense_client()
    collection_schema = {
@@ -81,7 +70,18 @@ def create_typesense_collection(
    ts_client.collections.create(collection_schema)


def get_typesense_document_cross_connector_metadata(
def _check_typesense_collection_exist(
    collection_name: str = DOCUMENT_INDEX_NAME,
) -> bool:
    client = get_typesense_client()
    try:
        client.collections[collection_name].retrieve()
    except ObjectNotFound:
        return False
    return True


def _get_typesense_document_cross_connector_metadata(
    doc_chunk_id: str, collection_name: str, ts_client: typesense.Client
) -> CrossConnectorDocumentMetadata | None:
    """Returns whether the document already exists and the users/group whitelists"""
@@ -115,7 +115,7 @@ def get_typesense_document_cross_connector_metadata(
    )


def delete_typesense_doc_chunks(
def _delete_typesense_doc_chunks(
    document_id: str, collection_name: str, ts_client: typesense.Client
) -> bool:
    doc_id_filter = {"filter_by": f"{DOCUMENT_ID}:'{document_id}'"}
@@ -126,16 +126,16 @@ def delete_typesense_doc_chunks(
    return del_result["num_deleted"] != 0


def index_typesense_chunks(
    chunks: list[IndexChunk | EmbeddedIndexChunk],
def _index_typesense_chunks(
    chunks: list[IndexChunk],
    index_attempt_metadata: IndexAttemptMetadata,
    collection: str,
    client: typesense.Client | None = None,
    batch_upsert: bool = True,
) -> list[ChunkInsertionRecord]:
) -> set[DocumentInsertionRecord]:
    ts_client: typesense.Client = client if client else get_typesense_client()

    insertion_records: list[ChunkInsertionRecord] = []
    insertion_records: set[DocumentInsertionRecord] = set()
    new_documents: list[dict[str, Any]] = []
    cross_connector_document_metadata_map: dict[
        str, CrossConnectorDocumentMetadata
@@ -151,7 +151,7 @@ def index_typesense_chunks(
            chunk=chunk,
            cross_connector_document_metadata_map=cross_connector_document_metadata_map,
            doc_store_cross_connector_document_metadata_fetch_fn=partial(
                get_typesense_document_cross_connector_metadata,
                _get_typesense_document_cross_connector_metadata,
                collection_name=collection,
                ts_client=ts_client,
            ),
@@ -160,14 +160,13 @@ def index_typesense_chunks(

        if should_delete_doc:
            # Processing the first chunk of the doc and the doc exists
            delete_typesense_doc_chunks(document.id, collection, ts_client)
            _delete_typesense_doc_chunks(document.id, collection, ts_client)
            already_existing_documents.add(document.id)

        typesense_id = str(get_uuid_from_chunk(chunk))
        insertion_records.append(
            ChunkInsertionRecord(
        insertion_records.add(
            DocumentInsertionRecord(
                document_id=document.id,
                store_id=typesense_id,
                already_existed=document.id in already_existing_documents,
            )
        )
@@ -250,49 +249,55 @@ def _build_typesense_filters(


class TypesenseIndex(KeywordIndex):
    def __init__(self, collection: str = TYPESENSE_DEFAULT_COLLECTION) -> None:
        self.collection = collection
    def __init__(self, index_name: str = DOCUMENT_INDEX_NAME) -> None:
        # In Typesense, the document index is referred to as a collection
        self.collection = index_name
        self.ts_client = get_typesense_client()

    def ensure_indices_exist(self) -> None:
        if not _check_typesense_collection_exist(self.collection):
            logger.info(f"Creating Typesense collection with name: {self.collection}")
            create_typesense_collection(collection_name=self.collection)

    def index(
        self, chunks: list[IndexChunk], index_attempt_metadata: IndexAttemptMetadata
    ) -> list[ChunkInsertionRecord]:
        return index_typesense_chunks(
    ) -> set[DocumentInsertionRecord]:
        return _index_typesense_chunks(
            chunks=chunks,
            index_attempt_metadata=index_attempt_metadata,
            collection=self.collection,
            client=self.ts_client,
        )

    def delete(self, ids: list[str]) -> None:
        logger.info(f"Deleting {len(ids)} documents from Typesense")
        for id_batch in batch_generator(items=ids, batch_size=_BATCH_SIZE):
            self.ts_client.collections[self.collection].documents.delete(
                {"filter_by": f'id:[{",".join(id_batch)}]'}
            )

    def update(self, update_requests: list[UpdateRequest]) -> None:
        logger.info(
            f"Updating {len(update_requests)} documents' allowed_users in Typesense"
        )
        for update_request in update_requests:
            for id_batch in batch_generator(
                items=update_request.ids, batch_size=_BATCH_SIZE
                items=update_request.document_ids, batch_size=_BATCH_SIZE
            ):
                typesense_updates = [
                    {"id": doc_id, ALLOWED_USERS: update_request.allowed_users}
                    {DOCUMENT_ID: doc_id, ALLOWED_USERS: update_request.allowed_users}
                    for doc_id in id_batch
                ]
                self.ts_client.collections[self.collection].documents.import_(
                    typesense_updates, {"action": "update"}
                )

    def keyword_search(
    def delete(self, doc_ids: list[str]) -> None:
        logger.info(f"Deleting {len(doc_ids)} documents from Typesense")
        for id_batch in batch_generator(items=doc_ids, batch_size=_BATCH_SIZE):
            self.ts_client.collections[self.collection].documents.delete(
                {"filter_by": f'{DOCUMENT_ID}:[{",".join(id_batch)}]'}
            )

    def keyword_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int,
        num_to_retrieve: int = NUM_RETURNED_HITS,
    ) -> list[InferenceChunk]:
        filters_str = _build_typesense_filters(user_id, filters)
backend/danswer/datastores/vespa/__init__.py (new, empty)
@@ -0,0 +1,89 @@
|
||||
schema danswer_chunk {
    document danswer_chunk {
        # Not to be confused with the UUID generated for this chunk which is called documentid by default
        field document_id type string {
            indexing: summary | attribute
        }
        field chunk_id type int {
            indexing: summary | attribute
        }
        field blurb type string {
            indexing: summary | attribute
        }
        # Can separate out title in the future and give heavier bm-25 weighting
        # Need to consider that not every doc has a separable title (ie. slack message)
        # Set summary options to enable bolding
        field content type string {
            indexing: summary | attribute | index
            index: enable-bm25
        }
        # https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it
        field source_type type string {
            indexing: summary | attribute
        }
        # Can also index links https://docs.vespa.ai/en/reference/schema-reference.html#attribute
        # URL type matching
        field source_links type string {
            indexing: summary | attribute
        }
        field semantic_identifier type string {
            indexing: summary | attribute
        }
        field section_continuation type bool {
            indexing: summary | attribute
        }
        field boost type float {
            indexing: summary | attribute
        }
        field metadata type string {
            indexing: summary | attribute
        }
        field embeddings type tensor<float>(t{},x[768]) {
            indexing: attribute
            attribute {
                distance-metric: angular
            }
        }
        field allowed_users type array<string> {
            indexing: summary | attribute
            attribute: fast-search
        }
        field allowed_groups type array<string> {
            indexing: summary | attribute
            attribute: fast-search
        }
    }

    fieldset default {
        fields: content
    }

    rank-profile keyword_search inherits default {
        first-phase {
            expression: bm25(content) * attribute(boost)
        }
    }

    rank-profile semantic_search inherits default {
        inputs {
            query(query_embedding) tensor<float>(x[768])
        }
        first-phase {
            expression: closeness(field, embeddings) * attribute(boost)
        }
        match-features: closest(embeddings)
    }

    rank-profile hybrid_search inherits default {
        inputs {
            query(query_embedding) tensor<float>(x[768])
        }
        first-phase {
            expression: bm25(content)
        }
        second-phase {
            expression: closeness(field, embeddings) * attribute(boost)
        }
        match-features: closest(embeddings)
    }
}
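Editor's note on the schema above (a sketch, not part of the commit): the embeddings field is typed tensor<float>(t{},x[768]), i.e. a mapped dimension t holding any number of named 768-dimensional vectors per chunk, all of which the nearestNeighbor operator in the semantic_search and hybrid_search rank profiles searches together. A minimal illustration of the tensor cells for one chunk, using the "full_chunk"/"mini_chunk_N" labels that store.py further down assigns; the vector values here are made up and truncated:

# Hypothetical feed-payload fragment for the embeddings field defined above.
# Real vectors would contain 768 floats each.
example_embeddings_cells = {
    "full_chunk": [0.012, -0.044, 0.310],
    "mini_chunk_0": [0.007, -0.051, 0.298],
    "mini_chunk_1": [0.021, -0.038, 0.275],
}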
19
backend/danswer/datastores/vespa/app_config/services.xml
Normal file
@@ -0,0 +1,19 @@
<?xml version="1.0" encoding="utf-8" ?>
<services version="1.0" xmlns:deploy="vespa" xmlns:preprocess="properties">
    <container id="default" version="1.0">
        <document-api/>
        <search/>
        <nodes>
            <node hostalias="node1" />
        </nodes>
    </container>
    <content id="danswer_index" version="1.0">
        <redundancy>2</redundancy>
        <documents>
            <document type="danswer_chunk" mode="index" />
        </documents>
        <nodes>
            <node hostalias="node1" distribution-key="0" />
        </nodes>
    </content>
</services>
402
backend/danswer/datastores/vespa/store.py
Normal file
@@ -0,0 +1,402 @@
import json
from collections.abc import Mapping
from functools import partial
from typing import cast
from uuid import UUID

import requests

from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.app_configs import VESPA_DEPLOYMENT_ZIP
from danswer.configs.app_configs import VESPA_HOST
from danswer.configs.app_configs import VESPA_PORT
from danswer.configs.app_configs import VESPA_TENANT_PORT
from danswer.configs.constants import ALLOWED_GROUPS
from danswer.configs.constants import ALLOWED_USERS
from danswer.configs.constants import BLURB
from danswer.configs.constants import BOOST
from danswer.configs.constants import CHUNK_ID
from danswer.configs.constants import CONTENT
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import EMBEDDINGS
from danswer.configs.constants import METADATA
from danswer.configs.constants import PUBLIC_DOC_PAT
from danswer.configs.constants import SECTION_CONTINUATION
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINKS
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import SEARCH_DISTANCE_CUTOFF
from danswer.connectors.models import IndexAttemptMetadata
from danswer.datastores.datastore_utils import CrossConnectorDocumentMetadata
from danswer.datastores.datastore_utils import get_uuid_from_chunk
from danswer.datastores.datastore_utils import (
    update_cross_connector_document_metadata_map,
)
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import DocumentInsertionRecord
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import UpdateRequest
from danswer.search.search_utils import get_default_embedding_model
from danswer.utils.logger import setup_logger

logger = setup_logger()


VESPA_CONFIG_SERVER_URL = f"http://{VESPA_HOST}:{VESPA_TENANT_PORT}"
VESPA_APP_CONTAINER_URL = f"http://{VESPA_HOST}:{VESPA_PORT}"
VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd
DOCUMENT_ID_ENDPOINT = (
    f"{VESPA_APP_CONTAINER_URL}/document/v1/default/danswer_chunk/docid"
)
SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/"
_BATCH_SIZE = 100  # Specific to Vespa


def _get_vespa_document_cross_connector_metadata(
    doc_chunk_id: str,
) -> CrossConnectorDocumentMetadata | None:
    """Returns whether the document already exists and the users/group whitelists"""
    doc_fetch_response = requests.get(f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}")
    if doc_fetch_response.status_code == 404:
        return None

    if doc_fetch_response.status_code != 200:
        raise RuntimeError(
            f"Unexpected fetch document by ID value from Vespa "
            f"with error {doc_fetch_response.status_code}"
        )

    doc_fields = doc_fetch_response.json()["fields"]
    allowed_users = doc_fields.get(ALLOWED_USERS)
    allowed_groups = doc_fields.get(ALLOWED_GROUPS)
    # Add group permission later, empty list gets saved/loaded as a null
    if allowed_users is None:
        raise RuntimeError(
            "Vespa Index is corrupted, Document found with no access lists."
        )
    return CrossConnectorDocumentMetadata(
        allowed_users=allowed_users or [],
        allowed_user_groups=allowed_groups or [],
        already_in_index=True,
    )

def _get_vespa_chunk_ids_by_document_id(
    document_id: str, hits_per_page: int = _BATCH_SIZE
) -> list[str]:
    offset = 0
    doc_chunk_ids = []
    params: dict[str, int | str] = {
        "yql": f"select documentid from {DOCUMENT_INDEX_NAME} where document_id contains '{document_id}'",
        "timeout": "10s",
        "offset": offset,
        "hits": hits_per_page,
    }
    while True:
        results = requests.get(SEARCH_ENDPOINT, params=params).json()
        hits = results["root"].get("children", [])
        doc_chunk_ids.extend(
            [hit.get("fields", {}).get("documentid").split("::")[1] for hit in hits]
        )
        params["offset"] += hits_per_page  # type: ignore

        if len(hits) < hits_per_page:
            break
    return doc_chunk_ids


def _delete_vespa_doc_chunks(document_id: str) -> bool:
    doc_chunk_ids = _get_vespa_chunk_ids_by_document_id(document_id)

    failures = [
        requests.delete(f"{DOCUMENT_ID_ENDPOINT}/{doc}").status_code != 200
        for doc in doc_chunk_ids
    ]
    return not any(failures)


def _index_vespa_chunks(
    chunks: list[IndexChunk],
    index_attempt_metadata: IndexAttemptMetadata,
) -> set[DocumentInsertionRecord]:
    json_header = {"Content-Type": "application/json"}
    insertion_records: set[DocumentInsertionRecord] = set()
    cross_connector_document_metadata_map: dict[
        str, CrossConnectorDocumentMetadata
    ] = {}
    # document ids of documents that existed BEFORE this indexing
    already_existing_documents: set[str] = set()
    for chunk in chunks:
        document = chunk.source_document
        (
            cross_connector_document_metadata_map,
            should_delete_doc,
        ) = update_cross_connector_document_metadata_map(
            chunk=chunk,
            cross_connector_document_metadata_map=cross_connector_document_metadata_map,
            doc_store_cross_connector_document_metadata_fetch_fn=partial(
                _get_vespa_document_cross_connector_metadata,
            ),
            index_attempt_metadata=index_attempt_metadata,
        )

        if should_delete_doc:
            # Processing the first chunk of the doc and the doc exists
            deletion_success = _delete_vespa_doc_chunks(document.id)
            if not deletion_success:
                logger.error(
                    f"Failed to delete pre-existing chunks for with document with id: {document.id}"
                )
            already_existing_documents.add(document.id)

        vespa_chunk_id = str(get_uuid_from_chunk(chunk))

        embeddings = chunk.embeddings
        embeddings_name_vector_map = {"full_chunk": embeddings.full_embedding}
        if embeddings.mini_chunk_embeddings:
            for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
                embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed

        vespa_document = {
            "fields": {
                DOCUMENT_ID: document.id,
                CHUNK_ID: chunk.chunk_id,
                BLURB: chunk.blurb,
                CONTENT: chunk.content,
                SOURCE_TYPE: str(document.source.value),
                SOURCE_LINKS: json.dumps(chunk.source_links),
                SEMANTIC_IDENTIFIER: document.semantic_identifier,
                SECTION_CONTINUATION: chunk.section_continuation,
                METADATA: json.dumps(document.metadata),
                EMBEDDINGS: embeddings_name_vector_map,
                BOOST: 1,  # Boost value always starts at 1 for 0 impact on weight
                ALLOWED_USERS: cross_connector_document_metadata_map[
                    document.id
                ].allowed_users,
                ALLOWED_GROUPS: cross_connector_document_metadata_map[
                    document.id
                ].allowed_user_groups,
            }
        }

        url = f"{DOCUMENT_ID_ENDPOINT}/{vespa_chunk_id}"

        res = requests.post(url, headers=json_header, json=vespa_document)
        res.raise_for_status()

        insertion_records.add(
            DocumentInsertionRecord(
                document_id=document.id,
                already_existed=document.id in already_existing_documents,
            )
        )

    return insertion_records


def _build_vespa_filters(
    user_id: UUID | None, filters: list[IndexFilter] | None
) -> str:
    filter_str = ""
    # Permissions filter
    # TODO group permissioning
    if user_id:
        filter_str += (
            f'({ALLOWED_USERS} contains "{user_id}" or '
            f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}") and '
        )
    else:
        filter_str += f'{ALLOWED_USERS} contains "{PUBLIC_DOC_PAT}" and '

    # Provided query filters
    if filters:
        for filter_dict in filters:
            valid_filters = {
                key: value for key, value in filter_dict.items() if value is not None
            }
            for filter_key, filter_val in valid_filters.items():
                if isinstance(filter_val, str):
                    filter_str += f'{filter_key} = "{filter_val}" and '
                elif isinstance(filter_val, list):
                    quoted_elems = [f'"{elem}"' for elem in filter_val]
                    filters_or = ",".join(quoted_elems)
                    filter_str += f"{filter_key} in [{filters_or}] and "
                else:
                    raise ValueError("Invalid filters provided")
    return filter_str

def _query_vespa(query_params: Mapping[str, str | int]) -> list[InferenceChunk]:
    if "query" in query_params and not cast(str, query_params["query"]).strip():
        raise ValueError(
            "Query only consisted of stopwords, should not use Keyword Search"
        )
    response = requests.get(SEARCH_ENDPOINT, params=query_params)
    response.raise_for_status()

    hits = response.json()["root"].get("children", [])
    inference_chunks = [InferenceChunk.from_dict(hit["fields"]) for hit in hits]

    return inference_chunks


class VespaIndex(DocumentIndex):
    yql_base = (
        f"select "
        f"documentid, "
        f"{DOCUMENT_ID}, "
        f"{CHUNK_ID}, "
        f"{BLURB}, "
        f"{CONTENT}, "
        f"{SOURCE_TYPE}, "
        f"{SOURCE_LINKS}, "
        f"{SEMANTIC_IDENTIFIER}, "
        f"{SECTION_CONTINUATION}, "
        f"{BOOST}, "
        f"{METADATA} "
        f"from {DOCUMENT_INDEX_NAME} where "
    )

    def __init__(self, deployment_zip: str = VESPA_DEPLOYMENT_ZIP) -> None:
        # Vespa index name isn't configurable via code alone because of the config .sd file that needs
        # to be updated + zipped + deployed, not supporting the option for simplicity
        self.deployment_zip = deployment_zip

    def ensure_indices_exist(self) -> None:
        """Verifying indices is more involved as there is no good way to
        verify the deployed app against the zip locally. But deploying the latest app.zip will ensure that
        the index is up-to-date with the expected schema and this does not erase the existing index.
        If the changes cannot be applied without conflict with existing data, it will fail with a non 200
        """
        deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
        headers = {"Content-Type": "application/zip"}
        with open(self.deployment_zip, "rb") as f:
            response = requests.post(deploy_url, headers=headers, data=f)
            if response.status_code != 200:
                raise RuntimeError(
                    f"Failed to prepare Vespa Danswer Index. Response: {response.text}"
                )

    def index(
        self,
        chunks: list[IndexChunk],
        index_attempt_metadata: IndexAttemptMetadata,
    ) -> set[DocumentInsertionRecord]:
        return _index_vespa_chunks(
            chunks=chunks, index_attempt_metadata=index_attempt_metadata
        )

    def update(self, update_requests: list[UpdateRequest]) -> None:
        logger.info(
            f"Updating {len(update_requests)} documents' allowed_users in Vespa"
        )

        json_header = {"Content-Type": "application/json"}

        for update_request in update_requests:
            if update_request.boost is None and update_request.allowed_users is None:
                logger.error("Update request received but nothing to update")

            update_dict: dict[str, dict[str, list[str] | int]] = {"fields": {}}
            if update_request.boost:
                update_dict["fields"][BOOST] = update_request.boost
            if update_request.allowed_users:
                update_dict["fields"][ALLOWED_USERS] = update_request.allowed_users

            for document_id in update_request.document_ids:
                for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(document_id):
                    url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
                    res = requests.put(url, headers=json_header, json=update_dict)

                    if res.status_code != 200:
                        logger.error(f"Failed to update document: {document_id}")

    def delete(self, doc_ids: list[str]) -> None:
        logger.info(f"Deleting {len(doc_ids)} documents from Vespa")
        for doc_id in doc_ids:
            for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(doc_id):
                url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
                logger.debug("Deleting: " + url)
                requests.delete(url)

    def keyword_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int = NUM_RETURNED_HITS,
    ) -> list[InferenceChunk]:
        vespa_where_clauses = _build_vespa_filters(user_id, filters)
        yql = (
            VespaIndex.yql_base
            + vespa_where_clauses
            + '({grammar: "weakAnd"}userInput(@query))'
        )

        params: dict[str, str | int] = {
            "yql": yql,
            "query": query,
            "hits": num_to_retrieve,
            "num_to_rerank": 10 * num_to_retrieve,
            "ranking.profile": "keyword_search",
        }

        return _query_vespa(params)

    def semantic_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int,
        distance_cutoff: float | None = SEARCH_DISTANCE_CUTOFF,
    ) -> list[InferenceChunk]:
        vespa_where_clauses = _build_vespa_filters(user_id, filters)
        yql = (
            VespaIndex.yql_base
            + vespa_where_clauses
            + f"({{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding))"
        )

        query_embedding = get_default_embedding_model().encode(query)
        if not isinstance(query_embedding, list):
            query_embedding = query_embedding.tolist()

        params = {
            "yql": yql,
            "input.query(query_embedding)": str(query_embedding),
            "ranking.profile": "semantic_search",
        }

        return _query_vespa(params)

    def hybrid_retrieval(
        self,
        query: str,
        user_id: UUID | None,
        filters: list[IndexFilter] | None,
        num_to_retrieve: int,
    ) -> list[InferenceChunk]:
        vespa_where_clauses = _build_vespa_filters(user_id, filters)
        yql = (
            VespaIndex.yql_base
            + vespa_where_clauses
            + f'{{targetHits: {10 * num_to_retrieve}}}nearestNeighbor(embeddings, query_embedding) or {{grammar: "weakAnd"}}userInput(@query)'
        )

        query_embedding = get_default_embedding_model().encode(query)
        if not isinstance(query_embedding, list):
            query_embedding = query_embedding.tolist()

        params = {
            "yql": yql,
            "query": query,
            "input.query(query_embedding)": str(query_embedding),
            "ranking.profile": "hybrid_search",
        }

        return _query_vespa(params)
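Usage sketch (editor's addition, not part of the commit): assuming a running Vespa config server and container reachable at VESPA_HOST and a built vespa-app.zip at VESPA_DEPLOYMENT_ZIP, the VespaIndex above would typically be driven as below. The query text and hit count are arbitrary examples; InferenceChunk is assumed to expose the schema fields (blurb, semantic_identifier, ...) as attributes.

# Minimal, illustrative driver for the VespaIndex defined above.
from danswer.datastores.vespa.store import VespaIndex

vespa_index = VespaIndex()
vespa_index.ensure_indices_exist()  # deploys the app package via prepareandactivate

# Keyword-only retrieval for an anonymous user (public docs only)
keyword_hits = vespa_index.keyword_retrieval(
    query="how do I reset my password",
    user_id=None,
    filters=None,
)

# Hybrid retrieval: bm25 in the first phase, embedding closeness in the second
hybrid_hits = vespa_index.hybrid_retrieval(
    query="how do I reset my password",
    user_id=None,
    filters=None,
    num_to_retrieve=10,
)
for chunk in hybrid_hits:
    print(chunk.semantic_identifier, chunk.blurb)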
@@ -1,5 +1,4 @@
from collections.abc import Sequence
from dataclasses import dataclass

from sqlalchemy import and_
from sqlalchemy import delete
@@ -8,8 +7,7 @@ from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from danswer.datastores.interfaces import ChunkMetadata
from danswer.db.models import Chunk
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.models import Document
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.utils import model_to_dict
@@ -18,11 +16,11 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()


def get_chunks_with_single_connector_credential_pair(
def get_documents_with_single_connector_credential_pair(
    db_session: Session,
    connector_id: int,
    credential_id: int,
) -> Sequence[Chunk]:
) -> Sequence[Document]:
    initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
        and_(
            DocumentByConnectorCredentialPair.connector_id == connector_id,
@@ -30,6 +28,8 @@ def get_chunks_with_single_connector_credential_pair(
        )
    )

    # Filter it down to the documents with only a single connector/credential pair
    # Meaning if this connector/credential pair is removed, this doc should be gone
    trimmed_doc_ids_stmt = (
        select(Document.id)
        .join(
@@ -41,7 +41,7 @@ def get_chunks_with_single_connector_credential_pair(
        .having(func.count(DocumentByConnectorCredentialPair.id) == 1)
    )

    stmt = select(Chunk).where(Chunk.document_id.in_(trimmed_doc_ids_stmt))
    stmt = select(Document).where(Document.id.in_(trimmed_doc_ids_stmt))
    return db_session.scalars(stmt).all()


@@ -57,6 +57,8 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
        )
    )

    # Filter it down to the documents with more than 1 connector/credential pair
    # Meaning if this connector/credential pair is removed, this doc is still accessible
    trimmed_doc_ids_stmt = (
        select(Document.id)
        .join(
@@ -75,15 +77,8 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
    return db_session.execute(stmt).scalars().all()


def get_chunk_ids_for_document_ids(
    db_session: Session, document_ids: list[str]
) -> Sequence[str]:
    stmt = select(Chunk.id).where(Chunk.document_id.in_(document_ids))
    return db_session.execute(stmt).scalars().all()


def upsert_documents(
    db_session: Session, document_metadata_batch: list[ChunkMetadata]
    db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
    """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
    seen_document_ids: set[str] = set()
@@ -102,7 +97,7 @@ def upsert_documents(


def upsert_document_by_connector_credential_pair(
    db_session: Session, document_metadata_batch: list[ChunkMetadata]
    db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
    """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
    insert_stmt = insert(DocumentByConnectorCredentialPair).values(
@@ -124,45 +119,16 @@ def upsert_document_by_connector_credential_pair(
    db_session.commit()


def upsert_chunks(
    db_session: Session, document_metadata_batch: list[ChunkMetadata]
) -> None:
    """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
    insert_stmt = insert(Chunk).values(
        [
            model_to_dict(
                Chunk(
                    id=document_metadata.store_id,
                    document_id=document_metadata.document_id,
                    document_store_type=document_metadata.document_store_type,
                )
            )
            for document_metadata in document_metadata_batch
        ]
    )
    on_conflict_stmt = insert_stmt.on_conflict_do_update(
        index_elements=["id", "document_store_type"],
        set_=dict(document_id=insert_stmt.excluded.document_id),
    )
    db_session.execute(on_conflict_stmt)
    db_session.commit()


def upsert_documents_complete(
    db_session: Session, document_metadata_batch: list[ChunkMetadata]
    db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
    upsert_documents(db_session, document_metadata_batch)
    upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
    upsert_chunks(db_session, document_metadata_batch)
    logger.info(
        f"Upserted {len(document_metadata_batch)} document store entries into DB"
    )


def delete_document_store_entries(db_session: Session, document_ids: list[str]) -> None:
    db_session.execute(delete(Chunk).where(Chunk.document_id.in_(document_ids)))


def delete_document_by_connector_credential_pair(
    db_session: Session, document_ids: list[str]
) -> None:
@@ -179,7 +145,6 @@ def delete_documents(db_session: Session, document_ids: list[str]) -> None:

def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
    logger.info(f"Deleting {len(document_ids)} documents from the DB")
    delete_document_store_entries(db_session, document_ids)
    delete_document_by_connector_credential_pair(db_session, document_ids)
    delete_documents(db_session, document_ids)
    db_session.commit()
@@ -25,7 +25,6 @@ from sqlalchemy.orm import relationship
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DocumentSource
from danswer.connectors.models import InputType
from danswer.datastores.interfaces import StoreType


class IndexingStatus(str, PyEnum):
@@ -185,7 +184,7 @@ class IndexAttempt(Base):
        nullable=True,
    )
    status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus))
    num_docs_indexed: Mapped[int] = mapped_column(Integer, default=0)
    num_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
    error_msg: Mapped[str | None] = mapped_column(
        String(), default=None
    )  # only filled if status = "failed"
@@ -195,7 +194,7 @@ class IndexAttempt(Base):
    )
    # when the actual indexing run began
    # NOTE: will use the api_server clock rather than DB server clock
    time_started: Mapped[datetime.datetime] = mapped_column(
    time_started: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
@@ -261,10 +260,7 @@ class DeletionAttempt(Base):

class Document(Base):
    """Represents a single documents from a source. This is used to store
    document level metadata so we don't need to duplicate it in a bunch of
    DocumentByConnectorCredentialPair's/Chunk's for documents
    that are split into many chunks and/or indexed by many connector / credential
    pairs."""
    document level metadata, but currently nothing is stored"""

    __tablename__ = "document"

@@ -272,10 +268,6 @@ class Document(Base):
    # in Danswer)
    id: Mapped[str] = mapped_column(String, primary_key=True)

    document_store_entries: Mapped["Chunk"] = relationship(
        "Chunk", back_populates="document"
    )


class DocumentByConnectorCredentialPair(Base):
    """Represents an indexing of a document by a specific connector / credential
@@ -297,21 +289,3 @@ class DocumentByConnectorCredentialPair(Base):
    credential: Mapped[Credential] = relationship(
        "Credential", back_populates="documents_by_credential"
    )


class Chunk(Base):
    """A row represents a single entry in a document store (e.g. a single chunk
    in Qdrant/Typesense)"""

    __tablename__ = "chunk"

    # this should correspond to the ID in the document store
    id: Mapped[str] = mapped_column(String, primary_key=True)
    document_store_type: Mapped[StoreType] = mapped_column(
        Enum(StoreType), primary_key=True
    )
    document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))

    document: Mapped[Document] = relationship(
        "Document", back_populates="document_store_entries"
    )
@@ -2,8 +2,7 @@ from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
@@ -43,12 +42,12 @@ def answer_question(
    user_id = None if user is None else user.id
    if use_keyword:
        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
            query, user_id, filters, TypesenseIndex(collection)
            query, user_id, filters, get_default_document_index(collection=collection)
        )
        unranked_chunks: list[InferenceChunk] | None = []
    else:
        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
            query, user_id, filters, QdrantIndex(collection)
            query, user_id, filters, get_default_document_index(collection=collection)
        )
    if not ranked_chunks:
        return QAResponse(
@@ -13,7 +13,7 @@ from slack_sdk.socket_mode.response import SocketModeResponse
from danswer.configs.app_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.app_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
from danswer.configs.app_configs import DANSWER_BOT_NUM_RETRIES
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.constants import DocumentSource
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer
@@ -206,7 +206,7 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
    answer = _get_answer(
        QuestionRequest(
            query=req.payload.get("event", {}).get("text"),
            collection=QDRANT_DEFAULT_COLLECTION,
            collection=DOCUMENT_INDEX_NAME,
            use_keyword=None,
            filters=None,
            offset=None,
@@ -21,17 +21,13 @@ from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import OAUTH_TYPE
from danswer.configs.app_configs import OPENID_CONFIG_URL
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.model_configs import API_BASE_OPENAI
from danswer.configs.model_configs import API_TYPE_OPENAI
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.typesense.store import check_typesense_collection_exist
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.datastores.document_index import get_default_document_index
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.llm_utils import get_default_llm
from danswer.server.event_loading import router as event_processing_router
@@ -149,7 +145,6 @@ def get_application() -> FastAPI:
        from danswer.search.search_utils import (
            warm_up_models,
        )
        from danswer.datastores.qdrant.indexing import create_qdrant_collection

        if DISABLE_GENERATIVE_AI:
            logger.info("Generative AI Q&A disabled")
@@ -190,19 +185,8 @@ def get_application() -> FastAPI:
        logger.info("Verifying public credential exists.")
        create_initial_public_credential()

        logger.info("Verifying Document Indexes are available.")
        if QDRANT_DEFAULT_COLLECTION not in {
            collection.name for collection in list_qdrant_collections().collections
        }:
            logger.info(
                f"Creating Qdrant collection with name: {QDRANT_DEFAULT_COLLECTION}"
            )
            create_qdrant_collection(collection_name=QDRANT_DEFAULT_COLLECTION)
        if not check_typesense_collection_exist(TYPESENSE_DEFAULT_COLLECTION):
            logger.info(
                f"Creating Typesense collection with name: {TYPESENSE_DEFAULT_COLLECTION}"
            )
            create_typesense_collection(collection_name=TYPESENSE_DEFAULT_COLLECTION)
        logger.info("Verifying Document Index(s) is/are available.")
        get_default_document_index().ensure_indices_exist()

    return application

@@ -7,8 +7,8 @@ from nltk.tokenize import word_tokenize # type:ignore

from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import KeywordIndex
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

@@ -38,11 +38,11 @@ def retrieve_keyword_documents(
    query: str,
    user_id: UUID | None,
    filters: list[IndexFilter] | None,
    datastore: KeywordIndex,
    datastore: DocumentIndex,
    num_hits: int = NUM_RETURNED_HITS,
) -> list[InferenceChunk] | None:
    edited_query = query_processing(query)
    top_chunks = datastore.keyword_search(edited_query, user_id, filters, num_hits)
    top_chunks = datastore.keyword_retrieval(edited_query, user_id, filters, num_hits)
    if not top_chunks:
        filters_log_msg = json.dumps(filters, separators=(",", ":")).replace("\n", "")
        logger.warning(
@@ -1,6 +1,6 @@
from enum import Enum

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import DocAwareChunk
from danswer.chunking.models import IndexChunk


@@ -15,5 +15,5 @@ class QueryFlow(str, Enum):


class Embedder:
    def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
    def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
        raise NotImplementedError
@@ -4,7 +4,8 @@ from uuid import UUID
import numpy
from sentence_transformers import SentenceTransformer  # type: ignore

from danswer.chunking.models import EmbeddedIndexChunk
from danswer.chunking.models import ChunkEmbedding
from danswer.chunking.models import DocAwareChunk
from danswer.chunking.models import IndexChunk
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
@@ -12,8 +13,8 @@ from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.app_configs import NUM_RERANKED_RESULTS
from danswer.configs.app_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.datastores.interfaces import DocumentIndex
from danswer.datastores.interfaces import IndexFilter
from danswer.datastores.interfaces import VectorIndex
from danswer.search.models import Embedder
from danswer.search.search_utils import get_default_embedding_model
from danswer.search.search_utils import get_default_reranking_model_ensemble
@@ -69,7 +70,7 @@ def retrieve_ranked_documents(
    query: str,
    user_id: UUID | None,
    filters: list[IndexFilter] | None,
    datastore: VectorIndex,
    datastore: DocumentIndex,
    num_hits: int = NUM_RETURNED_HITS,
    num_rerank: int = NUM_RERANKED_RESULTS,
) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None]:
@@ -127,12 +128,12 @@ def split_chunk_text_into_mini_chunks(

@log_function_time()
def encode_chunks(
    chunks: list[IndexChunk],
    chunks: list[DocAwareChunk],
    embedding_model: SentenceTransformer | None = None,
    batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
    enable_mini_chunk: bool = ENABLE_MINI_CHUNK,
) -> list[EmbeddedIndexChunk]:
    embedded_chunks: list[EmbeddedIndexChunk] = []
) -> list[IndexChunk]:
    embedded_chunks: list[IndexChunk] = []
    if embedding_model is None:
        embedding_model = get_default_embedding_model()

@@ -163,9 +164,12 @@ def encode_chunks(
        chunk_embeddings = embeddings[
            embedding_ind_start : embedding_ind_start + num_embeddings
        ]
        new_embedded_chunk = EmbeddedIndexChunk(
        new_embedded_chunk = IndexChunk(
            **{k: getattr(chunk, k) for k in chunk.__dataclass_fields__},
            embeddings=chunk_embeddings,
            embeddings=ChunkEmbedding(
                full_embedding=chunk_embeddings[0],
                mini_chunk_embeddings=chunk_embeddings[1:],
            ),
        )
        embedded_chunks.append(new_embedded_chunk)
        embedding_ind_start += num_embeddings
@@ -174,5 +178,5 @@ def encode_chunks(


class DefaultEmbedder(Embedder):
    def embed(self, chunks: list[IndexChunk]) -> list[EmbeddedIndexChunk]:
    def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]:
        return encode_chunks(chunks)
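Editor's note (a sketch, not part of the diff): the reworked encode_chunks above packs the first vector of each chunk's embedding batch into ChunkEmbedding.full_embedding and the remainder into mini_chunk_embeddings; the Vespa indexer earlier in this commit then fans those out into the named tensor cells. A small illustration with dummy 3-dimensional vectors (real ones are 768-dimensional); the field names follow the diff.

from danswer.chunking.models import ChunkEmbedding

embedding = ChunkEmbedding(
    full_embedding=[0.1, 0.2, 0.3],           # chunk_embeddings[0]
    mini_chunk_embeddings=[[0.1, 0.1, 0.4]],  # chunk_embeddings[1:]
)

# Same name-to-vector mapping that _index_vespa_chunks builds before feeding Vespa
embeddings_name_vector_map = {"full_chunk": embedding.full_embedding}
for ind, m_c_embed in enumerate(embedding.mini_chunk_embeddings):
    embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed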
@@ -10,8 +10,7 @@ from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.exceptions import OpenAIKeyMissing
@@ -60,7 +59,7 @@ def semantic_search(

    user_id = None if user is None else user.id
    ranked_chunks, unranked_chunks = retrieve_ranked_documents(
        query, user_id, filters, QdrantIndex(collection)
        query, user_id, filters, get_default_document_index(collection=collection)
    )
    if not ranked_chunks:
        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
@@ -82,7 +81,7 @@ def keyword_search(

    user_id = None if user is None else user.id
    ranked_chunks = retrieve_keyword_documents(
        query, user_id, filters, TypesenseIndex(collection)
        query, user_id, filters, get_default_document_index(collection=collection)
    )
    if not ranked_chunks:
        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
@@ -110,6 +109,8 @@ def stream_direct_qa(

    logger.debug(f"Received QA query: {question.query}")
    logger.debug(f"Query filters: {question.filters}")
    if question.use_keyword:
        logger.debug(f"User selected Keyword Search")

    @log_generator_function_time()
    def stream_qa_portions(
@@ -128,12 +129,18 @@ def stream_direct_qa(
        user_id = None if user is None else user.id
        if use_keyword:
            ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
                query, user_id, filters, TypesenseIndex(collection)
                query,
                user_id,
                filters,
                get_default_document_index(collection=collection),
            )
            unranked_chunks: list[InferenceChunk] | None = []
        else:
            ranked_chunks, unranked_chunks = retrieve_ranked_documents(
                query, user_id, filters, QdrantIndex(collection)
                query,
                user_id,
                filters,
                get_default_document_index(collection=collection),
            )
        if not ranked_chunks:
            logger.debug("No Documents Found")
@@ -1,4 +1,4 @@
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.utils.clients import get_typesense_client


@@ -14,9 +14,7 @@ if __name__ == "__main__":
            "page": page_number,
            "per_page": per_page,
        }
        response = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.search(
            params
        )
        response = ts_client.collections[DOCUMENT_INDEX_NAME].documents.search(params)
        documents = response.get("hits")
        if not documents:
            break  # if there are no more documents, break out of the loop
@@ -11,17 +11,16 @@ from typesense.exceptions import ObjectNotFound # type: ignore

from alembic import command
from alembic.config import Config
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.app_configs import POSTGRES_DB
from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import QDRANT_HOST
from danswer.configs.app_configs import QDRANT_PORT
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.datastores.qdrant.indexing import create_qdrant_collection
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.qdrant.utils import create_qdrant_collection
from danswer.datastores.qdrant.utils import list_qdrant_collections
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.utils.clients import get_qdrant_client
from danswer.utils.clients import get_typesense_client
@@ -60,13 +59,13 @@ def snapshot_time_compare(snap: SnapshotDescription) -> datetime:
def save_qdrant(filename: str) -> None:
    logger.info("Attempting to take Qdrant snapshot")
    qdrant_client = get_qdrant_client()
    qdrant_client.create_snapshot(collection_name=QDRANT_DEFAULT_COLLECTION)
    snapshots = qdrant_client.list_snapshots(collection_name=QDRANT_DEFAULT_COLLECTION)
    qdrant_client.create_snapshot(collection_name=DOCUMENT_INDEX_NAME)
    snapshots = qdrant_client.list_snapshots(collection_name=DOCUMENT_INDEX_NAME)
    valid_snapshots = [snap for snap in snapshots if snap.creation_time is not None]

    sorted_snapshots = sorted(valid_snapshots, key=snapshot_time_compare)
    last_snapshot_name = sorted_snapshots[-1].name
    url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/{last_snapshot_name}"
    url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{DOCUMENT_INDEX_NAME}/snapshots/{last_snapshot_name}"

    response = requests.get(url, stream=True)

@@ -80,11 +79,11 @@ def save_qdrant(filename: str) -> None:

def load_qdrant(filename: str) -> None:
    logger.info("Attempting to load Qdrant snapshot")
    if QDRANT_DEFAULT_COLLECTION not in {
    if DOCUMENT_INDEX_NAME not in {
        collection.name for collection in list_qdrant_collections().collections
    }:
        create_qdrant_collection(QDRANT_DEFAULT_COLLECTION)
    snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{QDRANT_DEFAULT_COLLECTION}/snapshots/"
        create_qdrant_collection(DOCUMENT_INDEX_NAME)
    snapshot_url = f"http://{QDRANT_HOST}:{QDRANT_PORT}/collections/{DOCUMENT_INDEX_NAME}/snapshots/"

    with open(filename, "rb") as f:
        files = {"snapshot": (os.path.basename(filename), f)}
@@ -104,7 +103,7 @@ def load_qdrant(filename: str) -> None:
def save_typesense(filename: str) -> None:
    logger.info("Attempting to take Typesense snapshot")
    ts_client = get_typesense_client()
    all_docs = ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.export()
    all_docs = ts_client.collections[DOCUMENT_INDEX_NAME].documents.export()
    with open(filename, "w") as f:
        f.write(all_docs)

@@ -113,14 +112,14 @@ def load_typesense(filename: str) -> None:
    logger.info("Attempting to load Typesense snapshot")
    ts_client = get_typesense_client()
    try:
        ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].delete()
        ts_client.collections[DOCUMENT_INDEX_NAME].delete()
    except ObjectNotFound:
        pass

    create_typesense_collection(TYPESENSE_DEFAULT_COLLECTION)
    create_typesense_collection(DOCUMENT_INDEX_NAME)

    with open(filename) as jsonl_file:
        ts_client.collections[TYPESENSE_DEFAULT_COLLECTION].documents.import_(
        ts_client.collections[DOCUMENT_INDEX_NAME].documents.import_(
            jsonl_file.read().encode("utf-8"), {"action": "create"}
        )

@@ -7,7 +7,7 @@ from pprint import pprint
import requests

from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.constants import SOURCE_TYPE


@@ -95,7 +95,7 @@ if __name__ == "__main__":

        query_json = {
            "query": query,
            "collection": QDRANT_DEFAULT_COLLECTION,
            "collection": DOCUMENT_INDEX_NAME,
            "use_keyword": flow_type == "keyword",  # Ignore if not QA Endpoints
            "filters": [{SOURCE_TYPE: source_types}],
        }