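"""Retrieval helpers for reading chunks out of a Vespa index.

Two retrieval paths are implemented here:
- the Visit (document) API, used for uncapped requests that may span an
  unbounded number of chunks per document, and
- the Search API with YQL id-based retrieval, used for batched, capped
  chunk-range requests.

Hits from either path are normalized into `InferenceChunkUncleaned` objects.
"""
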
import json
import string
from collections.abc import Callable
from collections.abc import Mapping
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast

import httpx
from retry import retry

from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
    build_vespa_filters,
)
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
    build_vespa_id_based_retrieval_yql,
)
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
from onyx.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE
from onyx.document_index.vespa_constants import MAX_OR_CONDITIONS
from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import RECENCY_BIAS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
from onyx.document_index.vespa_constants import SOURCE_LINKS
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


def _process_dynamic_summary(
    dynamic_summary: str, max_summary_length: int = 400
) -> list[str]:
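    """Split a Vespa dynamic summary on its `<sep />` markers and truncate the
    result to roughly `max_summary_length` characters, breaking at the last
    whole word and appending "..." to mark the truncation."""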
    if not dynamic_summary:
        return []

    current_length = 0
    processed_summary: list[str] = []
    for summary_section in dynamic_summary.split("<sep />"):
        # if we're past the desired max length, break at the last word
        if current_length + len(summary_section) >= max_summary_length:
            summary_section = summary_section[: max_summary_length - current_length]
            summary_section = summary_section.lstrip()  # remove any leading whitespace

            # handle the case where the truncated section is either just a
            # single (partial) word or is empty
            first_space = summary_section.find(" ")
            if first_space == -1:
                # add "..." to the previous section
                if processed_summary:
                    processed_summary[-1] += "..."
                break

            # handle the valid truncated section case
            summary_section = summary_section.rsplit(" ", 1)[0]
            if summary_section[-1] in string.punctuation:
                summary_section = summary_section[:-1]
            summary_section += "..."
            processed_summary.append(summary_section)
            break

        processed_summary.append(summary_section)
        current_length += len(summary_section)

    return processed_summary


def _vespa_hit_to_inference_chunk(
    hit: dict[str, Any], null_score: bool = False
) -> InferenceChunkUncleaned:
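    """Convert a raw Vespa hit into an `InferenceChunkUncleaned`.

    Fields stored as strings (metadata JSON, epoch timestamps, source links)
    are parsed back into their native types. With `null_score=True` the
    relevance score is dropped, e.g. for Visit API results, which carry no
    relevance ranking.
    """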
    fields = cast(dict[str, Any], hit["fields"])

    # parse fields that are stored as strings, but are really json / datetime
    metadata = json.loads(fields[METADATA]) if METADATA in fields else {}
    updated_at = (
        datetime.fromtimestamp(fields[DOC_UPDATED_AT], tz=timezone.utc)
        if DOC_UPDATED_AT in fields
        else None
    )

    match_highlights = _process_dynamic_summary(
        # fall back to the regular `content` field if the `content_summary`
        # field isn't present
        dynamic_summary=fields.get(CONTENT_SUMMARY, fields[CONTENT]),
    )
    semantic_identifier = fields.get(SEMANTIC_IDENTIFIER, "")
    if not semantic_identifier:
        logger.error(
            f"Chunk with blurb: {fields.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier"
        )

    source_links = fields.get(SOURCE_LINKS, {})
    source_links_dict_unprocessed = (
        json.loads(source_links) if isinstance(source_links, str) else source_links
    )
    source_links_dict = {
        int(k): v
        for k, v in cast(dict[str, str], source_links_dict_unprocessed).items()
    }

    return InferenceChunkUncleaned(
        chunk_id=fields[CHUNK_ID],
        blurb=fields.get(BLURB, ""),  # unused
        content=fields[CONTENT],  # includes extra title prefix and metadata suffix
        source_links=source_links_dict or {0: ""},
        section_continuation=fields[SECTION_CONTINUATION],
        document_id=fields[DOCUMENT_ID],
        source_type=fields[SOURCE_TYPE],
        title=fields.get(TITLE),
        semantic_identifier=fields[SEMANTIC_IDENTIFIER],
        boost=fields.get(BOOST, 1),
        recency_bias=fields.get("matchfeatures", {}).get(RECENCY_BIAS, 1.0),
        score=None if null_score else hit.get("relevance", 0),
        hidden=fields.get(HIDDEN, False),
        primary_owners=fields.get(PRIMARY_OWNERS),
        secondary_owners=fields.get(SECONDARY_OWNERS),
        large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
        metadata=metadata,
        metadata_suffix=fields.get(METADATA_SUFFIX),
        match_highlights=match_highlights,
        updated_at=updated_at,
    )


def _get_chunks_via_visit_api(
    chunk_request: VespaChunkRequest,
    index_name: str,
    filters: IndexFilters,
    field_names: list[str] | None = None,
    get_large_chunks: bool = False,
) -> list[dict]:
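    """Fetch all chunks for a single document via the Vespa Visit API,
    following continuation tokens until the result set is exhausted.

    Because the Document Selector Language can't express ACL membership
    checks, access control filtering is applied client-side as a post-filter.
    """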
    # Construct the URL for the Visit API.
    # NOTE: the Visit API uses the same URL as the Document API, but with
    # different params.
    url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)

    # build the list of fields to retrieve
    field_set_list = (
        None
        if not field_names
        else [f"{index_name}:{field_name}" for field_name in field_names]
    )
    acl_fieldset_entry = f"{index_name}:{ACCESS_CONTROL_LIST}"
    if (
        field_set_list
        and filters.access_control_list
        and acl_fieldset_entry not in field_set_list
    ):
        field_set_list.append(acl_fieldset_entry)
    field_set = ",".join(field_set_list) if field_set_list else None

    # build filters
    selection = f"{index_name}.document_id=='{chunk_request.document_id}'"

    if chunk_request.is_capped:
        selection += f" and {index_name}.chunk_id>={chunk_request.min_chunk_ind or 0}"
        selection += f" and {index_name}.chunk_id<={chunk_request.max_chunk_ind}"
    if not get_large_chunks:
        selection += f" and {index_name}.large_chunk_reference_ids == null"

    # set up the selection criteria in the query parameters
    params = {
        # NOTE: the Document Selector Language doesn't allow `contains`, so we
        # can't check for the ACL in the selection. Instead, we have to check
        # it as a post-filter below.
        "selection": selection,
        "continuation": None,
        "wantedDocumentCount": 1_000,
        "fieldSet": field_set,
    }

    document_chunks: list[dict] = []
    while True:
        try:
            filtered_params = {k: v for k, v in params.items() if v is not None}
            with get_vespa_http_client() as http_client:
                response = http_client.get(url, params=filtered_params)
                response.raise_for_status()
        except httpx.HTTPError as e:
            error_base = "Failed to query Vespa"
            logger.error(
                f"{error_base}:\n"
                f"Request URL: {e.request.url}\n"
                f"Request Headers: {e.request.headers}\n"
                f"Request Payload: {params}\n"
                f"Exception: {str(e)}"
            )
            raise httpx.HTTPError(error_base) from e

        # check if the response contains any documents
        response_data = response.json()
        if "documents" in response_data:
            for document in response_data["documents"]:
                if filters.access_control_list:
                    document_acl = document["fields"].get(ACCESS_CONTROL_LIST)
                    if not document_acl or not any(
                        user_acl_entry in document_acl
                        for user_acl_entry in filters.access_control_list
                    ):
                        continue
                document_chunks.append(document)

        # check for a continuation token to handle pagination
        if "continuation" in response_data and response_data["continuation"]:
            params["continuation"] = response_data["continuation"]
        else:
            break  # exit the loop if there is no continuation token

    return document_chunks


@retry(tries=10, delay=1, backoff=2)
def get_all_vespa_ids_for_document_id(
    document_id: str,
    index_name: str,
    filters: IndexFilters | None = None,
    get_large_chunks: bool = False,
) -> list[str]:
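    """Return the Vespa document IDs (the part after the `::` namespace
    separator) of every chunk belonging to the given document."""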
    document_chunks = _get_chunks_via_visit_api(
        chunk_request=VespaChunkRequest(document_id=document_id),
        index_name=index_name,
        filters=filters or IndexFilters(access_control_list=None),
        field_names=[DOCUMENT_ID],
        get_large_chunks=get_large_chunks,
    )
    return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks]


def parallel_visit_api_retrieval(
    index_name: str,
    chunk_requests: list[VespaChunkRequest],
    filters: IndexFilters,
    get_large_chunks: bool = False,
) -> list[InferenceChunkUncleaned]:
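    """Run one Visit API call per chunk request in parallel and flatten the
    results. Failed requests are tolerated and dropped; scores are nulled out
    since the Visit API provides no relevance ranking."""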
    functions_with_args: list[tuple[Callable, tuple]] = [
        (
            _get_chunks_via_visit_api,
            (chunk_request, index_name, filters, get_large_chunks),
        )
        for chunk_request in chunk_requests
    ]

    parallel_results = run_functions_tuples_in_parallel(
        functions_with_args, allow_failures=True
    )

    # any failed retrieval yields a None; drop the Nones and empty lists
    vespa_chunk_sets = [res for res in parallel_results if res]

    flattened_vespa_chunks = []
    for chunk_set in vespa_chunk_sets:
        flattened_vespa_chunks.extend(chunk_set)

    inference_chunks = [
        _vespa_hit_to_inference_chunk(chunk, null_score=True)
        for chunk in flattened_vespa_chunks
    ]

    return inference_chunks


@retry(tries=3, delay=1, backoff=2)
def query_vespa(
    query_params: Mapping[str, str | int | float]
) -> list[InferenceChunkUncleaned]:
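    """Execute a search request against the Vespa Search API and convert the
    hits into `InferenceChunkUncleaned` objects.

    Hits without a `content` field are logged and filtered out, since such
    documents can't be meaningfully retrieved.
    """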
    if "query" in query_params and not cast(str, query_params["query"]).strip():
        raise ValueError("No/empty query received")

    params = dict(
        **query_params,
        **{
            "presentation.timing": True,
        }
        if LOG_VESPA_TIMING_INFORMATION
        else {},
    )

    try:
        with get_vespa_http_client() as http_client:
            response = http_client.post(SEARCH_ENDPOINT, json=params)
            response.raise_for_status()
    except httpx.HTTPError as e:
        error_base = "Failed to query Vespa"
        logger.error(
            f"{error_base}:\n"
            f"Request URL: {e.request.url}\n"
            f"Request Headers: {e.request.headers}\n"
            f"Request Payload: {params}\n"
            f"Exception: {str(e)}"
        )
        raise httpx.HTTPError(error_base) from e

    response_json: dict[str, Any] = response.json()

    if LOG_VESPA_TIMING_INFORMATION:
        logger.debug("Vespa timing info: %s", response_json.get("timing"))
    hits = response_json["root"].get("children", [])

    if not hits:
        logger.warning(
            f"No hits found for YQL Query: {query_params.get('yql', 'No YQL Query')}"
        )
        logger.debug(f"Vespa Response: {response.text}")

    for hit in hits:
        if hit["fields"].get(CONTENT) is None:
            identifier = hit["fields"].get("documentid") or hit["id"]
            logger.error(
                f"Vespa Index with Vespa ID {identifier} has no contents. "
                f"This is invalid because the vector is not meaningful and "
                f"keyword search cannot fetch this document."
            )

    filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]

    inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
    # good debugging spot
    return inference_chunks


def _get_chunks_via_batch_search(
    index_name: str,
    chunk_requests: list[VespaChunkRequest],
    filters: IndexFilters,
    get_large_chunks: bool = False,
) -> list[InferenceChunkUncleaned]:
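    """Retrieve a batch of capped chunk ranges with a single YQL query by
    OR-ing together the id-based retrieval clause of every request. Results
    are sorted by chunk id; large chunks are filtered out unless requested."""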
    if not chunk_requests:
        return []

    filters_str = build_vespa_filters(filters=filters, include_hidden=True)

    yql = (
        YQL_BASE.format(index_name=index_name)
        + filters_str
        + build_vespa_id_based_retrieval_yql(chunk_requests[0])
    )
    chunk_requests.pop(0)

    for request in chunk_requests:
        yql += " or " + build_vespa_id_based_retrieval_yql(request)
    params: dict[str, str | int | float] = {
        "yql": yql,
        "hits": MAX_ID_SEARCH_QUERY_SIZE,
    }

    inference_chunks = query_vespa(params)
    if not get_large_chunks:
        inference_chunks = [
            chunk for chunk in inference_chunks if not chunk.large_chunk_reference_ids
        ]
    inference_chunks.sort(key=lambda chunk: chunk.chunk_id)
    return inference_chunks


def batch_search_api_retrieval(
    index_name: str,
    chunk_requests: list[VespaChunkRequest],
    filters: IndexFilters,
    get_large_chunks: bool = False,
) -> list[InferenceChunkUncleaned]:
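    """Entry point for retrieving a list of chunk requests.

    Capped requests (those with an explicit chunk range) are grouped into
    batches that respect MAX_ID_SEARCH_QUERY_SIZE and MAX_OR_CONDITIONS and
    retrieved via the Search API; uncapped requests fall back to the
    parallel Visit API path.
    """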
    retrieved_chunks: list[InferenceChunkUncleaned] = []
    capped_requests: list[VespaChunkRequest] = []
    uncapped_requests: list[VespaChunkRequest] = []
    chunk_count = 0
    for req_ind, request in enumerate(chunk_requests, start=1):
        # All requests without a chunk range are uncapped.
        # Uncapped requests are retrieved using the Visit API.
        range = request.range
        if range is None:
            uncapped_requests.append(request)
            continue

        # flush the accumulated batch if adding this request would exceed the
        # hit limit or the maximum number of OR conditions in a single query
        if (
            chunk_count + range > MAX_ID_SEARCH_QUERY_SIZE
            or req_ind % MAX_OR_CONDITIONS == 0
        ):
            retrieved_chunks.extend(
                _get_chunks_via_batch_search(
                    index_name=index_name,
                    chunk_requests=capped_requests,
                    filters=filters,
                    get_large_chunks=get_large_chunks,
                )
            )
            capped_requests = []
            chunk_count = 0
        capped_requests.append(request)
        chunk_count += range

    if capped_requests:
        retrieved_chunks.extend(
            _get_chunks_via_batch_search(
                index_name=index_name,
                chunk_requests=capped_requests,
                filters=filters,
                get_large_chunks=get_large_chunks,
            )
        )

    if uncapped_requests:
        logger.debug(f"Retrieving {len(uncapped_requests)} uncapped requests")
        retrieved_chunks.extend(
            parallel_visit_api_retrieval(
                index_name, uncapped_requests, filters, get_large_chunks
            )
        )

    return retrieved_chunks
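

# A minimal usage sketch (the index name and document id below are
# hypothetical; real callers live in the document-index layer). A request
# without a chunk range is uncapped and goes through the Visit API path:
#
#     filters = IndexFilters(access_control_list=None)
#     chunks = batch_search_api_retrieval(
#         index_name="danswer_chunk",
#         chunk_requests=[VespaChunkRequest(document_id="doc-123")],
#         filters=filters,
#     )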