From 683addc3908ad5dbf6308902964a7c8487f972fd Mon Sep 17 00:00:00 2001
From: Weves
Date: Thu, 9 May 2024 14:23:43 -0700
Subject: [PATCH] Use Vespa Visit to handle long documents

---
 backend/danswer/document_index/interfaces.py  |   2 +-
 backend/danswer/document_index/vespa/index.py | 153 ++++++++++--------
 backend/danswer/server/documents/document.py  |   9 +-
 3 files changed, 90 insertions(+), 74 deletions(-)

diff --git a/backend/danswer/document_index/interfaces.py b/backend/danswer/document_index/interfaces.py
index e59874fa0..6adedd452 100644
--- a/backend/danswer/document_index/interfaces.py
+++ b/backend/danswer/document_index/interfaces.py
@@ -185,7 +185,7 @@ class IdRetrievalCapable(abc.ABC):
         document_id: str,
         min_chunk_ind: int | None,
         max_chunk_ind: int | None,
-        filters: IndexFilters,
+        user_access_control_list: list[str] | None = None,
     ) -> list[InferenceChunk]:
         """
         Fetch chunk(s) based on document id
diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py
index 78c79f137..cad1fa7b1 100644
--- a/backend/danswer/document_index/vespa/index.py
+++ b/backend/danswer/document_index/vespa/index.py
@@ -5,7 +5,6 @@ import os
 import string
 import time
 import zipfile
-from collections.abc import Callable
 from collections.abc import Mapping
 from dataclasses import dataclass
 from datetime import datetime
@@ -69,7 +68,6 @@ from danswer.search.retrieval.search_runner import query_processing
 from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
-from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 
 logger = setup_logger()
 
@@ -142,35 +140,58 @@ def _vespa_get_updated_at_attribute(t: datetime | None) -> int | None:
     return int(t.timestamp())
 
 
-def _get_vespa_chunk_ids_by_document_id(
+def _get_vespa_chunks_by_document_id(
     document_id: str,
     index_name: str,
-    hits_per_page: int = _BATCH_SIZE,
-    index_filters: IndexFilters | None = None,
-) -> list[str]:
-    filters_str = (
-        _build_vespa_filters(filters=index_filters, include_hidden=True)
-        if index_filters is not None
-        else ""
-    )
+    user_access_control_list: list[str] | None = None,
+    min_chunk_ind: int | None = None,
+    max_chunk_ind: int | None = None,
+    field_names: list[str] | None = None,
+) -> list[dict]:
+    # Constructing the URL for the Visit API
+    # NOTE: visit API uses the same URL as the document API, but with different params
+    url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
 
-    offset = 0
-    doc_chunk_ids = []
-    params: dict[str, int | str] = {
-        "yql": f"select documentid from {index_name} where {filters_str}document_id contains '{document_id}'",
-        "timeout": "10s",
-        "offset": offset,
-        "hits": hits_per_page,
+    # build the list of fields to retrieve
+    field_set_list = (
+        None
+        if not field_names
+        else [f"{index_name}:{field_name}" for field_name in field_names]
+    )
+    acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
+    if (
+        field_set_list
+        and user_access_control_list
+        and acl_fieldset_entry not in field_set_list
+    ):
+        field_set_list.append(acl_fieldset_entry)
+    field_set = ",".join(field_set_list) if field_set_list else None
+
+    # build filters
+    selection = f"{index_name}.document_id=='{document_id}'"
+    if min_chunk_ind is not None:
+        selection += f" and {index_name}.chunk_id>={min_chunk_ind}"
+    if max_chunk_ind is not None:
+        selection += f" and {index_name}.chunk_id<={max_chunk_ind}"
+
+    # Setting up the selection criteria in the query parameters
+    params = {
+        # NOTE: Document Selector Language doesn't allow `contains`, so we can't check
+        # for the ACL in the selection. Instead, we have to check as a postfilter
+        "selection": selection,
+        "continuation": None,
+        "wantedDocumentCount": 1_000,
+        "fieldSet": field_set,
     }
+
+    document_chunks: list[dict] = []
     while True:
-        res = requests.post(SEARCH_ENDPOINT, json=params)
+        response = requests.get(url, params=params)
         try:
-            res.raise_for_status()
+            response.raise_for_status()
         except requests.HTTPError as e:
-            request_info = f"Headers: {res.request.headers}\nPayload: {params}"
-            response_info = (
-                f"Status Code: {res.status_code}\nResponse Content: {res.text}"
-            )
+            request_info = f"Headers: {response.request.headers}\nPayload: {params}"
+            response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
             error_base = f"Error occurred getting chunk by Document ID {document_id}"
             logger.error(
                 f"{error_base}:\n"
@@ -180,17 +201,38 @@ def _get_vespa_chunk_ids_by_document_id(
             )
             raise requests.HTTPError(error_base) from e
 
-        results = res.json()
-        hits = results["root"].get("children", [])
+        # Check if the response contains any documents
+        response_data = response.json()
+        if "documents" in response_data:
+            for document in response_data["documents"]:
+                if user_access_control_list:
+                    document_acl = document["fields"].get(ACCESS_CONTROL_LIST)
+                    if not document_acl or not any(
+                        user_acl_entry in document_acl
+                        for user_acl_entry in user_access_control_list
+                    ):
+                        continue
+                document_chunks.append(document)
 
-        doc_chunk_ids.extend(
-            [hit["fields"]["documentid"].split("::", 1)[-1] for hit in hits]
-        )
-        params["offset"] += hits_per_page  # type: ignore
+        # Check for continuation token to handle pagination
+        if "continuation" in response_data and response_data["continuation"]:
+            params["continuation"] = response_data["continuation"]
+        else:
+            break  # Exit loop if no continuation token
 
-        if len(hits) < hits_per_page:
-            break
-    return doc_chunk_ids
+    return document_chunks
+
+
+def _get_vespa_chunk_ids_by_document_id(
+    document_id: str, index_name: str, user_access_control_list: list[str] | None = None
+) -> list[str]:
+    document_chunks = _get_vespa_chunks_by_document_id(
+        document_id=document_id,
+        index_name=index_name,
+        user_access_control_list=user_access_control_list,
+        field_names=[DOCUMENT_ID],
+    )
+    return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
 
 
 @retry(tries=3, delay=1, backoff=2)
@@ -872,43 +914,22 @@ class VespaIndex(DocumentIndex):
         document_id: str,
         min_chunk_ind: int | None,
         max_chunk_ind: int | None,
-        filters: IndexFilters,
+        user_access_control_list: list[str] | None = None,
     ) -> list[InferenceChunk]:
-        if min_chunk_ind is None and max_chunk_ind is None:
-            vespa_chunk_ids = _get_vespa_chunk_ids_by_document_id(
-                document_id=document_id,
-                index_name=self.index_name,
-                index_filters=filters,
-            )
-
-            if not vespa_chunk_ids:
-                return []
-
-            functions_with_args: list[tuple[Callable, tuple]] = [
-                (_inference_chunk_by_vespa_id, (vespa_chunk_id, self.index_name))
-                for vespa_chunk_id in vespa_chunk_ids
-            ]
-
-            inference_chunks = run_functions_tuples_in_parallel(
-                functions_with_args, allow_failures=True
-            )
-            inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
-            return inference_chunks
-
-        filters_str = _build_vespa_filters(filters=filters, include_hidden=True)
-        yql = (
-            VespaIndex.yql_base.format(index_name=self.index_name)
-            + filters_str
- + f"({DOCUMENT_ID} contains '{document_id}'" + vespa_chunks = _get_vespa_chunks_by_document_id( + document_id=document_id, + index_name=self.index_name, + user_access_control_list=user_access_control_list, + min_chunk_ind=min_chunk_ind, + max_chunk_ind=max_chunk_ind, ) - if min_chunk_ind is not None: - yql += f" and {min_chunk_ind} <= {CHUNK_ID}" - if max_chunk_ind is not None: - yql += f" and {max_chunk_ind} >= {CHUNK_ID}" - yql = yql + ")" + if not vespa_chunks: + return [] - inference_chunks = _query_vespa({"yql": yql}) + inference_chunks = [ + _vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks + ] inference_chunks.sort(key=lambda chunk: chunk.chunk_id) return inference_chunks diff --git a/backend/danswer/server/documents/document.py b/backend/danswer/server/documents/document.py index 06dd712d1..3b0adea24 100644 --- a/backend/danswer/server/documents/document.py +++ b/backend/danswer/server/documents/document.py @@ -11,7 +11,6 @@ from danswer.db.models import User from danswer.document_index.factory import get_default_document_index from danswer.llm.utils import get_default_llm_token_encode from danswer.prompts.prompt_utils import build_doc_context_str -from danswer.search.models import IndexFilters from danswer.search.preprocessing.access_filters import build_access_filters_for_user from danswer.server.documents.models import ChunkInfo from danswer.server.documents.models import DocumentInfo @@ -35,13 +34,11 @@ def get_document_info( ) user_acl_filters = build_access_filters_for_user(user, db_session) - filters = IndexFilters(access_control_list=user_acl_filters) - inference_chunks = document_index.id_based_retrieval( document_id=document_id, min_chunk_ind=None, max_chunk_ind=None, - filters=filters, + user_access_control_list=user_acl_filters, ) if not inference_chunks: @@ -83,13 +80,11 @@ def get_chunk_info( ) user_acl_filters = build_access_filters_for_user(user, db_session) - filters = IndexFilters(access_control_list=user_acl_filters) - inference_chunks = document_index.id_based_retrieval( document_id=document_id, min_chunk_ind=chunk_id, max_chunk_ind=chunk_id, - filters=filters, + user_access_control_list=user_acl_filters, ) if not inference_chunks:
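
Note on the retrieval pattern: the Visit API adopted above is Vespa's
/document/v1 endpoint. Instead of paging a YQL search by offset, it streams
every document matching a Document Selector expression and hands back a
continuation token while more pages remain. Below is a minimal,
self-contained sketch of that loop; the URL, the `default` namespace, the
`danswer_chunk` document type, and the `access_control_list` field name are
illustrative assumptions standing in for DOCUMENT_ID_ENDPOINT and
ACCESS_CONTROL_LIST in the actual code.

    import requests

    # Hypothetical endpoint for illustration; the patch builds the real one
    # from DOCUMENT_ID_ENDPOINT.format(index_name=...).
    VISIT_URL = "http://localhost:8080/document/v1/default/danswer_chunk/docid"

    def visit_document_chunks(
        document_id: str, user_acl: list[str] | None = None
    ) -> list[dict]:
        """Fetch all chunks of a document, following continuation tokens."""
        params: dict[str, str | int] = {
            # Document Selector Language: exact match on the document_id field
            "selection": f"danswer_chunk.document_id=='{document_id}'",
            "wantedDocumentCount": 1_000,
        }

        chunks: list[dict] = []
        while True:
            response = requests.get(VISIT_URL, params=params)
            response.raise_for_status()
            data = response.json()

            for document in data.get("documents", []):
                # The selector language has no `contains` operator, so ACLs
                # are enforced client-side as a post-filter on each page.
                if user_acl is not None:
                    document_acl = document["fields"].get("access_control_list") or {}
                    if not any(entry in document_acl for entry in user_acl):
                        continue
                chunks.append(document)

            continuation = data.get("continuation")
            if not continuation:
                break  # no token means this was the last page
            params["continuation"] = continuation

        return chunks

Since the selector expression cannot test ACL membership, the patch also has
to request the ACL field in `fieldSet` whenever a user ACL list is supplied,
so the post-filter has something to check against.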