Use Vespa Visit to handle long documents

Weves 2024-05-09 14:23:43 -07:00 committed by Chris Weaver
parent 2952b1dd96
commit 683addc390
3 changed files with 91 additions and 74 deletions

View File

@@ -185,7 +185,7 @@ class IdRetrievalCapable(abc.ABC):
         document_id: str,
         min_chunk_ind: int | None,
         max_chunk_ind: int | None,
-        filters: IndexFilters,
+        user_access_control_list: list[str] | None = None,
     ) -> list[InferenceChunk]:
         """
         Fetch chunk(s) based on document id
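For orientation, here is how a caller of the revised interface looks once this lands; a minimal sketch, assuming `index` is any IdRetrievalCapable implementation (e.g. a VespaIndex) and with the ACL argument passed through as a plain list instead of the old IndexFilters wrapper:

def fetch_full_document(
    index,  # any IdRetrievalCapable implementation, e.g. a VespaIndex
    document_id: str,
    user_acl: list[str] | None,
):
    # Passing None for both bounds retrieves every chunk of the document;
    # the ACL entries are enforced by the index implementation itself.
    return index.id_based_retrieval(
        document_id=document_id,
        min_chunk_ind=None,
        max_chunk_ind=None,
        user_access_control_list=user_acl,
    )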
View File

@@ -5,7 +5,6 @@ import os
 import string
 import time
 import zipfile
-from collections.abc import Callable
 from collections.abc import Mapping
 from dataclasses import dataclass
 from datetime import datetime
@@ -69,7 +68,6 @@ from danswer.search.retrieval.search_runner import query_processing
 from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
 from danswer.utils.batching import batch_generator
 from danswer.utils.logger import setup_logger
-from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 
 
 logger = setup_logger()
@@ -142,35 +140,58 @@ def _vespa_get_updated_at_attribute(t: datetime | None) -> int | None:
     return int(t.timestamp())
 
 
-def _get_vespa_chunk_ids_by_document_id(
+def _get_vespa_chunks_by_document_id(
     document_id: str,
     index_name: str,
-    hits_per_page: int = _BATCH_SIZE,
-    index_filters: IndexFilters | None = None,
-) -> list[str]:
-    filters_str = (
-        _build_vespa_filters(filters=index_filters, include_hidden=True)
-        if index_filters is not None
-        else ""
-    )
-
-    offset = 0
-    doc_chunk_ids = []
-    params: dict[str, int | str] = {
-        "yql": f"select documentid from {index_name} where {filters_str}document_id contains '{document_id}'",
-        "timeout": "10s",
-        "offset": offset,
-        "hits": hits_per_page,
-    }
-    while True:
-        res = requests.post(SEARCH_ENDPOINT, json=params)
-        try:
-            res.raise_for_status()
-        except requests.HTTPError as e:
-            request_info = f"Headers: {res.request.headers}\nPayload: {params}"
-            response_info = (
-                f"Status Code: {res.status_code}\nResponse Content: {res.text}"
-            )
+    user_access_control_list: list[str] | None = None,
+    min_chunk_ind: int | None = None,
+    max_chunk_ind: int | None = None,
+    field_names: list[str] | None = None,
+) -> list[dict]:
+    # Constructing the URL for the Visit API
+    # NOTE: visit API uses the same URL as the document API, but with different params
+    url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
+
+    # build the list of fields to retrieve
+    field_set_list = (
+        None
+        if not field_names
+        else [f"{index_name}:{field_name}" for field_name in field_names]
+    )
+    acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
+    if (
+        field_set_list
+        and user_access_control_list
+        and acl_fieldset_entry not in field_set_list
+    ):
+        field_set_list.append(acl_fieldset_entry)
+    field_set = ",".join(field_set_list) if field_set_list else None
+
+    # build filters
+    selection = f"{index_name}.document_id=='{document_id}'"
+    if min_chunk_ind is not None:
+        selection += f" and {index_name}.chunk_id>={min_chunk_ind}"
+    if max_chunk_ind is not None:
+        selection += f" and {index_name}.chunk_id<={max_chunk_ind}"
+
+    # Setting up the selection criteria in the query parameters
+    params = {
+        # NOTE: Document Selector Language doesn't allow `contains`, so we can't check
+        # for the ACL in the selection. Instead, we have to check as a postfilter
+        "selection": selection,
+        "continuation": None,
+        "wantedDocumentCount": 1_000,
+        "fieldSet": field_set,
+    }
+
+    document_chunks: list[dict] = []
+    while True:
+        response = requests.get(url, params=params)
+        try:
+            response.raise_for_status()
+        except requests.HTTPError as e:
+            request_info = f"Headers: {response.request.headers}\nPayload: {params}"
+            response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
             error_base = f"Error occurred getting chunk by Document ID {document_id}"
             logger.error(
                 f"{error_base}:\n"
@@ -180,17 +201,39 @@ def _get_vespa_chunk_ids_by_document_id(
             )
             raise requests.HTTPError(error_base) from e
 
-        results = res.json()
-        hits = results["root"].get("children", [])
+        # Check if the response contains any documents
+        response_data = response.json()
+        if "documents" in response_data:
+            for document in response_data["documents"]:
+                if user_access_control_list:
+                    document_acl = document["fields"].get(ACCESS_CONTROL_LIST)
+                    if not document_acl or not any(
+                        user_acl_entry in document_acl
+                        for user_acl_entry in user_access_control_list
+                    ):
+                        continue
+                document_chunks.append(document)
+            document_chunks.extend(response_data["documents"])
 
-        doc_chunk_ids.extend(
-            [hit["fields"]["documentid"].split("::", 1)[-1] for hit in hits]
-        )
-        params["offset"] += hits_per_page  # type: ignore
+        # Check for continuation token to handle pagination
+        if "continuation" in response_data and response_data["continuation"]:
+            params["continuation"] = response_data["continuation"]
+        else:
+            break  # Exit loop if no continuation token
 
-        if len(hits) < hits_per_page:
-            break
-
-    return doc_chunk_ids
+    return document_chunks
+
+
+def _get_vespa_chunk_ids_by_document_id(
+    document_id: str, index_name: str, user_access_control_list: list[str] | None = None
+) -> list[str]:
+    document_chunks = _get_vespa_chunks_by_document_id(
+        document_id=document_id,
+        index_name=index_name,
+        user_access_control_list=user_access_control_list,
+        field_names=[DOCUMENT_ID],
+    )
+    return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]
 
 
 @retry(tries=3, delay=1, backoff=2)
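The pattern the new helper is built on is Vespa's /document/v1 visit: each GET returns one page of matching documents plus a continuation token, which the caller re-sends until the token stops appearing, so no offset/hits paging limit applies. A self-contained sketch of that loop, assuming a plain Vespa instance at localhost and placeholder namespace/doctype names (not Danswer's actual endpoint constants):

import requests

def visit_all(base_url: str, namespace: str, doctype: str, selection: str) -> list[dict]:
    # Visit uses the same /document/v1 path as single-document GETs,
    # just without a specific document id and with selection params.
    url = f"{base_url}/document/v1/{namespace}/{doctype}/docid"
    params: dict = {"selection": selection, "wantedDocumentCount": 500}
    docs: list[dict] = []
    while True:
        response = requests.get(url, params=params)
        response.raise_for_status()
        data = response.json()
        docs.extend(data.get("documents", []))
        if not data.get("continuation"):
            break  # no token means the visit is complete
        params["continuation"] = data["continuation"]
    return docs

# Hypothetical usage:
# visit_all("http://localhost:8080", "danswer", "danswer_chunk",
#           "danswer_chunk.document_id=='doc-123'")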
@@ -872,43 +915,22 @@ class VespaIndex(DocumentIndex):
         document_id: str,
         min_chunk_ind: int | None,
         max_chunk_ind: int | None,
-        filters: IndexFilters,
+        user_access_control_list: list[str] | None = None,
     ) -> list[InferenceChunk]:
-        if min_chunk_ind is None and max_chunk_ind is None:
-            vespa_chunk_ids = _get_vespa_chunk_ids_by_document_id(
-                document_id=document_id,
-                index_name=self.index_name,
-                index_filters=filters,
-            )
+        vespa_chunks = _get_vespa_chunks_by_document_id(
+            document_id=document_id,
+            index_name=self.index_name,
+            user_access_control_list=user_access_control_list,
+            min_chunk_ind=min_chunk_ind,
+            max_chunk_ind=max_chunk_ind,
+        )
 
-            if not vespa_chunk_ids:
-                return []
+        if not vespa_chunks:
+            return []
 
-            functions_with_args: list[tuple[Callable, tuple]] = [
-                (_inference_chunk_by_vespa_id, (vespa_chunk_id, self.index_name))
-                for vespa_chunk_id in vespa_chunk_ids
-            ]
-            inference_chunks = run_functions_tuples_in_parallel(
-                functions_with_args, allow_failures=True
-            )
-            inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
-            return inference_chunks
-
-        filters_str = _build_vespa_filters(filters=filters, include_hidden=True)
-        yql = (
-            VespaIndex.yql_base.format(index_name=self.index_name)
-            + filters_str
-            + f"({DOCUMENT_ID} contains '{document_id}'"
-        )
-        if min_chunk_ind is not None:
-            yql += f" and {min_chunk_ind} <= {CHUNK_ID}"
-        if max_chunk_ind is not None:
-            yql += f" and {max_chunk_ind} >= {CHUNK_ID}"
-        yql = yql + ")"
-        inference_chunks = _query_vespa({"yql": yql})
+        inference_chunks = [
+            _vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks
+        ]
         inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
         return inference_chunks
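The NOTE inside the new params dict is the crux of the ACL handling: Vespa's Document Selector Language (the `selection` syntax) has no `contains` operator, so the membership test the old YQL filter expressed now happens client-side after each page arrives. Isolated as a function, the post-filter amounts to the following sketch, assuming the ACL field comes back as a weightedset (which Vespa serializes as a dict) or an array; the sample data is invented:

def passes_acl(document: dict, user_acl: list[str] | None) -> bool:
    if not user_acl:
        return True  # nothing to enforce
    document_acl = document.get("fields", {}).get("access_control_list")
    # `in` works the same for dicts (weightedset) and lists (array) here
    return bool(document_acl) and any(entry in document_acl for entry in user_acl)

chunk = {"fields": {"access_control_list": {"user_email:alice@example.com": 1}}}
assert passes_acl(chunk, ["user_email:alice@example.com"])
assert not passes_acl(chunk, ["user_email:bob@example.com"])

The trade-off is that filtered-out chunks still cross the wire, which is why the fieldSet logic above ensures the ACL field is fetched whenever a user ACL is supplied.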

View File

@@ -11,7 +11,6 @@ from danswer.db.models import User
 from danswer.document_index.factory import get_default_document_index
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.prompts.prompt_utils import build_doc_context_str
-from danswer.search.models import IndexFilters
 from danswer.search.preprocessing.access_filters import build_access_filters_for_user
 from danswer.server.documents.models import ChunkInfo
 from danswer.server.documents.models import DocumentInfo
@@ -35,13 +34,11 @@ def get_document_info(
     )
 
     user_acl_filters = build_access_filters_for_user(user, db_session)
-    filters = IndexFilters(access_control_list=user_acl_filters)
     inference_chunks = document_index.id_based_retrieval(
         document_id=document_id,
         min_chunk_ind=None,
         max_chunk_ind=None,
-        filters=filters,
+        user_access_control_list=user_acl_filters,
     )
 
     if not inference_chunks:
@@ -83,13 +80,11 @@ def get_chunk_info(
     )
 
     user_acl_filters = build_access_filters_for_user(user, db_session)
-    filters = IndexFilters(access_control_list=user_acl_filters)
     inference_chunks = document_index.id_based_retrieval(
         document_id=document_id,
         min_chunk_ind=chunk_id,
         max_chunk_ind=chunk_id,
-        filters=filters,
+        user_access_control_list=user_acl_filters,
    )
 
     if not inference_chunks:
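Net effect on the API layer: both endpoints now hand the user's ACL entries straight to the index instead of wrapping them in IndexFilters, and a single chunk is addressed by pinning both bounds to the same index. Roughly, as a sketch with the route and session plumbing omitted:

from danswer.search.preprocessing.access_filters import build_access_filters_for_user

def get_single_chunk(document_index, user, db_session, document_id: str, chunk_id: int):
    user_acl_filters = build_access_filters_for_user(user, db_session)
    # min == max restricts the visit selection to exactly one chunk
    chunks = document_index.id_based_retrieval(
        document_id=document_id,
        min_chunk_ind=chunk_id,
        max_chunk_ind=chunk_id,
        user_access_control_list=user_acl_filters,
    )
    return chunks[0] if chunks else None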