Mirror of https://github.com/danswer-ai/danswer.git
Use Vespa Visit to handle long documents

commit 683addc390 (parent 2952b1dd96)
@@ -185,7 +185,7 @@ class IdRetrievalCapable(abc.ABC):
        document_id: str,
        min_chunk_ind: int | None,
        max_chunk_ind: int | None,
        filters: IndexFilters,
        user_access_control_list: list[str] | None = None,
    ) -> list[InferenceChunk]:
        """
        Fetch chunk(s) based on document id
@@ -5,7 +5,6 @@ import os
import string
import time
import zipfile
from collections.abc import Callable
from collections.abc import Mapping
from dataclasses import dataclass
from datetime import datetime
@@ -69,7 +68,6 @@ from danswer.search.retrieval.search_runner import query_processing
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.utils.batching import batch_generator
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()

@@ -142,35 +140,58 @@ def _vespa_get_updated_at_attribute(t: datetime | None) -> int | None:
    return int(t.timestamp())


def _get_vespa_chunk_ids_by_document_id(
def _get_vespa_chunks_by_document_id(
    document_id: str,
    index_name: str,
    hits_per_page: int = _BATCH_SIZE,
    index_filters: IndexFilters | None = None,
) -> list[str]:
    filters_str = (
        _build_vespa_filters(filters=index_filters, include_hidden=True)
        if index_filters is not None
        else ""
    )
    user_access_control_list: list[str] | None = None,
    min_chunk_ind: int | None = None,
    max_chunk_ind: int | None = None,
    field_names: list[str] | None = None,
) -> list[dict]:
    # Constructing the URL for the Visit API
    # NOTE: visit API uses the same URL as the document API, but with different params
    url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)

    offset = 0
    doc_chunk_ids = []
    params: dict[str, int | str] = {
        "yql": f"select documentid from {index_name} where {filters_str}document_id contains '{document_id}'",
        "timeout": "10s",
        "offset": offset,
        "hits": hits_per_page,
    # build the list of fields to retrieve
    field_set_list = (
        None
        if not field_names
        else [f"{index_name}:{field_name}" for field_name in field_names]
    )
    acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
    if (
        field_set_list
        and user_access_control_list
        and acl_fieldset_entry not in field_set_list
    ):
        field_set_list.append(acl_fieldset_entry)
    field_set = ",".join(field_set_list) if field_set_list else None

    # build filters
    selection = f"{index_name}.document_id=='{document_id}'"
    if min_chunk_ind is not None:
        selection += f" and {index_name}.chunk_id>={min_chunk_ind}"
    if max_chunk_ind is not None:
        selection += f" and {index_name}.chunk_id<={max_chunk_ind}"

    # Setting up the selection criteria in the query parameters
    params = {
        # NOTE: Document Selector Language doesn't allow `contains`, so we can't check
        # for the ACL in the selection. Instead, we have to check as a postfilter
        "selection": selection,
        "continuation": None,
        "wantedDocumentCount": 1_000,
        "fieldSet": field_set,
    }

    document_chunks: list[dict] = []
    while True:
        res = requests.post(SEARCH_ENDPOINT, json=params)
        response = requests.get(url, params=params)
        try:
            res.raise_for_status()
            response.raise_for_status()
        except requests.HTTPError as e:
            request_info = f"Headers: {res.request.headers}\nPayload: {params}"
            response_info = (
                f"Status Code: {res.status_code}\nResponse Content: {res.text}"
            )
            request_info = f"Headers: {response.request.headers}\nPayload: {params}"
            response_info = f"Status Code: {response.status_code}\nResponse Content: {response.text}"
            error_base = f"Error occurred getting chunk by Document ID {document_id}"
            logger.error(
                f"{error_base}:\n"
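The new helper above builds the Visit request in two parts: a fieldSet that restricts which fields are returned (adding the ACL field whenever a user access-control list is supplied, since ACLs must be checked client-side) and a Document Selector Language selection that pins the document id plus an optional chunk_id range. A minimal sketch of that parameter construction, assuming an illustrative helper name (build_visit_params is not part of the diff):

# Sketch only: the helper name "build_visit_params" is hypothetical; parameter and
# field names mirror the diff above.
def build_visit_params(
    index_name: str,
    document_id: str,
    min_chunk_ind: int | None = None,
    max_chunk_ind: int | None = None,
    field_names: list[str] | None = None,
) -> dict:
    # Document Selector Language has no `contains`, so only exact-match conditions
    # go here; ACL checks happen later as a post-filter on the returned documents.
    selection = f"{index_name}.document_id=='{document_id}'"
    if min_chunk_ind is not None:
        selection += f" and {index_name}.chunk_id>={min_chunk_ind}"
    if max_chunk_ind is not None:
        selection += f" and {index_name}.chunk_id<={max_chunk_ind}"

    # fieldSet entries follow the "<document type>:<field>" pattern used in the diff;
    # leaving it as None omits the parameter so Vespa's default applies.
    field_set = (
        ",".join(f"{index_name}:{name}" for name in field_names) if field_names else None
    )
    return {
        "selection": selection,
        "continuation": None,
        "wantedDocumentCount": 1_000,
        "fieldSet": field_set,
    }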
@@ -180,17 +201,39 @@ def _get_vespa_chunk_ids_by_document_id(
            )
            raise requests.HTTPError(error_base) from e

        results = res.json()
        hits = results["root"].get("children", [])
        # Check if the response contains any documents
        response_data = response.json()
        if "documents" in response_data:
            for document in response_data["documents"]:
                if user_access_control_list:
                    document_acl = document["fields"].get(ACCESS_CONTROL_LIST)
                    if not document_acl or not any(
                        user_acl_entry in document_acl
                        for user_acl_entry in user_access_control_list
                    ):
                        continue
                document_chunks.append(document)
            document_chunks.extend(response_data["documents"])

        doc_chunk_ids.extend(
            [hit["fields"]["documentid"].split("::", 1)[-1] for hit in hits]
        )
        params["offset"] += hits_per_page  # type: ignore
        # Check for continuation token to handle pagination
        if "continuation" in response_data and response_data["continuation"]:
            params["continuation"] = response_data["continuation"]
        else:
            break  # Exit loop if no continuation token

        if len(hits) < hits_per_page:
            break
    return doc_chunk_ids
    return document_chunks


def _get_vespa_chunk_ids_by_document_id(
    document_id: str, index_name: str, user_access_control_list: list[str] | None = None
) -> list[str]:
    document_chunks = _get_vespa_chunks_by_document_id(
        document_id=document_id,
        index_name=index_name,
        user_access_control_list=user_access_control_list,
        field_names=[DOCUMENT_ID],
    )
    return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]


@retry(tries=3, delay=1, backoff=2)
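With those parameters in hand, the loop above pages through the Visit API using continuation tokens instead of the old offset/hits search pagination, which is what lets it return every chunk of an arbitrarily long document. A condensed sketch of that loop, assuming a locally reachable Vespa instance (the URL below is illustrative; the real code formats DOCUMENT_ID_ENDPOINT):

# Sketch of paging through a Vespa visit with continuation tokens.
import requests

def visit_all_chunks(index_name: str, params: dict) -> list[dict]:
    # Illustrative endpoint; the diff builds this URL from DOCUMENT_ID_ENDPOINT.
    url = f"http://localhost:8080/document/v1/default/{index_name}/docid"

    chunks: list[dict] = []
    while True:
        response = requests.get(url, params=params)
        response.raise_for_status()
        data = response.json()

        # Each page carries a "documents" list; ACL post-filtering would go here.
        chunks.extend(data.get("documents", []))

        # Vespa returns a continuation token while more documents remain.
        if data.get("continuation"):
            params["continuation"] = data["continuation"]
        else:
            break
    return chunks

A call such as visit_all_chunks("danswer_chunk", build_visit_params("danswer_chunk", "some-document-id")) would then return one dict per stored chunk, ready for ACL post-filtering (both names are the hypothetical ones from the sketches above).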
@@ -872,43 +915,22 @@ class VespaIndex(DocumentIndex):
        document_id: str,
        min_chunk_ind: int | None,
        max_chunk_ind: int | None,
        filters: IndexFilters,
        user_access_control_list: list[str] | None = None,
    ) -> list[InferenceChunk]:
        if min_chunk_ind is None and max_chunk_ind is None:
            vespa_chunk_ids = _get_vespa_chunk_ids_by_document_id(
                document_id=document_id,
                index_name=self.index_name,
                index_filters=filters,
            )

            if not vespa_chunk_ids:
                return []

            functions_with_args: list[tuple[Callable, tuple]] = [
                (_inference_chunk_by_vespa_id, (vespa_chunk_id, self.index_name))
                for vespa_chunk_id in vespa_chunk_ids
            ]

            inference_chunks = run_functions_tuples_in_parallel(
                functions_with_args, allow_failures=True
            )
            inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
            return inference_chunks

        filters_str = _build_vespa_filters(filters=filters, include_hidden=True)
        yql = (
            VespaIndex.yql_base.format(index_name=self.index_name)
            + filters_str
            + f"({DOCUMENT_ID} contains '{document_id}'"
        vespa_chunks = _get_vespa_chunks_by_document_id(
            document_id=document_id,
            index_name=self.index_name,
            user_access_control_list=user_access_control_list,
            min_chunk_ind=min_chunk_ind,
            max_chunk_ind=max_chunk_ind,
        )

        if min_chunk_ind is not None:
            yql += f" and {min_chunk_ind} <= {CHUNK_ID}"
        if max_chunk_ind is not None:
            yql += f" and {max_chunk_ind} >= {CHUNK_ID}"
        yql = yql + ")"
        if not vespa_chunks:
            return []

        inference_chunks = _query_vespa({"yql": yql})
        inference_chunks = [
            _vespa_hit_to_inference_chunk(chunk) for chunk in vespa_chunks
        ]
        inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
        return inference_chunks

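After this change, VespaIndex.id_based_retrieval no longer needs the two code paths shown above (parallel per-chunk fetches when no bounds are given, a YQL range query otherwise); it always visits the document's chunks and converts each hit with _vespa_hit_to_inference_chunk. A hedged usage sketch, in which the import paths, constructor arguments, document id, and ACL value are assumptions rather than values taken from the repository:

# Assumptions: import paths and the VespaIndex constructor signature are inferred,
# and the document id / ACL entries below are placeholders.
from danswer.document_index.vespa.index import VespaIndex
from danswer.search.models import IndexFilters

index = VespaIndex(index_name="danswer_chunk")
filters = IndexFilters(access_control_list=["PUBLIC"])

# Whole document: no chunk bounds, so every chunk is visited and returned in order.
all_chunks = index.id_based_retrieval(
    document_id="example-document-id",
    min_chunk_ind=None,
    max_chunk_ind=None,
    filters=filters,
    user_access_control_list=["PUBLIC"],
)

# Single chunk: the same index at both ends of the range.
one_chunk = index.id_based_retrieval(
    document_id="example-document-id",
    min_chunk_ind=5,
    max_chunk_ind=5,
    filters=filters,
    user_access_control_list=["PUBLIC"],
)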
@@ -11,7 +11,6 @@ from danswer.db.models import User
from danswer.document_index.factory import get_default_document_index
from danswer.llm.utils import get_default_llm_token_encode
from danswer.prompts.prompt_utils import build_doc_context_str
from danswer.search.models import IndexFilters
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.server.documents.models import ChunkInfo
from danswer.server.documents.models import DocumentInfo
@@ -35,13 +34,11 @@ def get_document_info(
    )

    user_acl_filters = build_access_filters_for_user(user, db_session)
    filters = IndexFilters(access_control_list=user_acl_filters)

    inference_chunks = document_index.id_based_retrieval(
        document_id=document_id,
        min_chunk_ind=None,
        max_chunk_ind=None,
        filters=filters,
        user_access_control_list=user_acl_filters,
    )

    if not inference_chunks:
@@ -83,13 +80,11 @@ def get_chunk_info(
    )

    user_acl_filters = build_access_filters_for_user(user, db_session)
    filters = IndexFilters(access_control_list=user_acl_filters)

    inference_chunks = document_index.id_based_retrieval(
        document_id=document_id,
        min_chunk_ind=chunk_id,
        max_chunk_ind=chunk_id,
        filters=filters,
        user_access_control_list=user_acl_filters,
    )

    if not inference_chunks: