Feature/vespa jinja (#4558)

* tool to generate vespa schema variations for our cloud

* extraneous assign

* use a real templating system instead of search/replace

* fix float

* maybe this should be double

* remove redundant var

* template the other files

* try a spawned process

* move the wrapper

* fix args

* increase timeout

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
rkuo-danswer
2025-04-20 15:28:55 -07:00
committed by GitHub
parent 87478c5ca6
commit 2111eccf07
8 changed files with 180 additions and 147 deletions

View File

@ -1,11 +1,17 @@
schema DANSWER_CHUNK_NAME {
document DANSWER_CHUNK_NAME {
TENANT_ID_REPLACEMENT
schema {{ schema_name }} {
document {{ schema_name }} {
{% if multi_tenant %}
field tenant_id type string {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
{% endif %}
# Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string {
indexing: summary | attribute
attribute: fast-search
rank: filter
attribute: fast-search
}
field chunk_id type int {
indexing: summary | attribute
@ -37,7 +43,7 @@ schema DANSWER_CHUNK_NAME {
summary: dynamic
}
# Title embedding (x1)
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
field title_embedding type tensor<{{ embedding_precision }}>(x[{{ dim }}]) {
indexing: attribute | index
attribute {
distance-metric: angular
@ -45,7 +51,7 @@ schema DANSWER_CHUNK_NAME {
}
# Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
field embeddings type tensor<{{ embedding_precision }}>(t{},x[{{ dim }}]) {
indexing: attribute | index
attribute {
distance-metric: angular
@ -176,9 +182,9 @@ schema DANSWER_CHUNK_NAME {
match-features: recency_bias
}
rank-profile hybrid_search_semantic_base_VARIABLE_DIM inherits default, default_rank {
rank-profile hybrid_search_semantic_base_{{ dim }} inherits default, default_rank {
inputs {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
query(query_embedding) tensor<float>(x[{{ dim }}])
}
function title_vector_score() {
@ -244,9 +250,9 @@ schema DANSWER_CHUNK_NAME {
}
rank-profile hybrid_search_keyword_base_VARIABLE_DIM inherits default, default_rank {
rank-profile hybrid_search_keyword_base_{{ dim }} inherits default, default_rank {
inputs {
query(query_embedding) tensor<float>(x[VARIABLE_DIM])
query(query_embedding) tensor<float>(x[{{ dim }}])
}
function title_vector_score() {

View File

@ -14,7 +14,7 @@
<redundancy>1</redundancy>
<documents>
<!-- <document type="danswer_chunk" mode="index" /> -->
DOCUMENT_REPLACEMENT
{{ document_elements }}
</documents>
<nodes>
<node hostalias="danswer-node" distribution-key="0" />
@ -31,7 +31,7 @@
<tuning>
<searchnode>
<requestthreads>
<persearch>SEARCH_THREAD_NUMBER</persearch>
<persearch>{{ num_search_threads }}</persearch>
</requestthreads>
</searchnode>
</tuning>

View File

@ -1,11 +1,11 @@
<validation-overrides>
<allow
until="DATE_REPLACEMENT"
until="{{ until_date }}"
comment="We need to be able to create/delete indices for swapping models">schema-removal</allow>
<allow
until="DATE_REPLACEMENT"
until="{{ until_date }}"
comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow>
<allow
until='DATE_REPLACEMENT'
until="{{ until_date }}"
comment="Prevents old alt indices from interfering with changes">field-type-change</allow>
</validation-overrides>

View File

@ -16,6 +16,7 @@ from typing import List
from uuid import UUID
import httpx # type: ignore
import jinja2
import requests # type: ignore
from retry import retry
@ -61,21 +62,13 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DANSWER_CHUNK_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import DATE_REPLACEMENT
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
from onyx.document_index.vespa_constants import TENANT_ID_PAT
from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
@ -118,28 +111,6 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
return "\n".join(doc_lines)
def _replace_template_values_in_schema(
schema_template: str,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
) -> str:
return (
schema_template.replace(
EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value
)
.replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name)
.replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
)
def _replace_tenant_template_value_in_schema(
schema_template: str,
tenant_field: str,
) -> str:
return schema_template.replace(TENANT_ID_PAT, tenant_field)
def add_ngrams_to_schema(schema_content: str) -> str:
# Add the match blocks containing gram and gram-size to title and content fields
schema_content = re.sub(
@ -156,6 +127,9 @@ def add_ngrams_to_schema(schema_content: str) -> str:
class VespaIndex(DocumentIndex):
VESPA_SCHEMA_JINJA_FILENAME = "danswer_chunk.sd.jinja"
def __init__(
self,
index_name: str,
@ -202,25 +176,31 @@ class VespaIndex(DocumentIndex):
)
return None
jinja_env = jinja2.Environment()
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.notice(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
os.getcwd(), "onyx", "document_index", "vespa", "app_config"
)
schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd")
services_file = os.path.join(vespa_schema_path, "services.xml")
overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml")
with open(services_file, "r") as services_f:
services_template = services_f.read()
schema_jinja_file = os.path.join(
vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
)
services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
overrides_jinja_file = os.path.join(
vespa_schema_path, "validation-overrides.xml.jinja"
)
with open(services_jinja_file, "r") as services_f:
schema_names = [self.index_name, self.secondary_index_name]
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
services = services.replace(
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
services_template_str = services_f.read()
services_template = jinja_env.from_string(services_template_str)
services = services_template.render(
document_elements=doc_lines,
num_search_threads=str(VESPA_SEARCHER_THREADS),
)
kv_store = get_shared_kv_store()
@ -231,30 +211,33 @@ class VespaIndex(DocumentIndex):
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
# Vespa requires an override to erase data including the indices we're no longer using
# It also has a 30 day cap from current so we set it to 7 dynamically
with open(overrides_jinja_file, "r") as overrides_f:
overrides_template_str = overrides_f.read()
overrides_template = jinja_env.from_string(overrides_template_str)
now = datetime.now()
date_in_7_days = now + timedelta(days=7)
formatted_date = date_in_7_days.strftime("%Y-%m-%d")
overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date)
overrides = overrides_template.render(
until_date=formatted_date,
)
zip_dict = {
"services.xml": services.encode("utf-8"),
"validation-overrides.xml": overrides.encode("utf-8"),
}
with open(schema_file, "r") as schema_f:
schema_template = schema_f.read()
schema = _replace_tenant_template_value_in_schema(schema_template, "")
schema = _replace_template_values_in_schema(
schema,
self.index_name,
primary_embedding_dim,
primary_embedding_precision,
with open(schema_jinja_file, "r") as schema_f:
template_str = schema_f.read()
template = jinja_env.from_string(template_str)
schema = template.render(
multi_tenant=MULTI_TENANT,
schema_name=self.index_name,
dim=primary_embedding_dim,
embedding_precision=primary_embedding_precision.value,
)
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
@ -266,12 +249,13 @@ class VespaIndex(DocumentIndex):
if secondary_index_embedding_precision is None:
raise ValueError("Secondary index embedding precision is required")
upcoming_schema = _replace_template_values_in_schema(
schema_template,
self.secondary_index_name,
secondary_index_embedding_dim,
secondary_index_embedding_precision,
upcoming_schema = template.render(
multi_tenant=MULTI_TENANT,
schema_name=self.secondary_index_name,
dim=secondary_index_embedding_dim,
embedding_precision=secondary_index_embedding_precision.value,
)
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
@ -301,23 +285,26 @@ class VespaIndex(DocumentIndex):
vespa_schema_path = os.path.join(
os.getcwd(), "onyx", "document_index", "vespa", "app_config"
)
schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd")
services_file = os.path.join(vespa_schema_path, "services.xml")
overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml")
schema_jinja_file = os.path.join(
vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
)
services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
overrides_jinja_file = os.path.join(
vespa_schema_path, "validation-overrides.xml.jinja"
)
with open(services_file, "r") as services_f:
services_template = services_f.read()
jinja_env = jinja2.Environment()
# Generate schema names from index settings
with open(services_jinja_file, "r") as services_f:
schema_names = [index_name for index_name in indices]
doc_lines = _create_document_xml_lines(schema_names)
full_schemas = schema_names
doc_lines = _create_document_xml_lines(full_schemas)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
services = services.replace(
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
services_template_str = services_f.read()
services_template = jinja_env.from_string(services_template_str)
services = services_template.render(
document_elements=doc_lines,
num_search_threads=str(VESPA_SEARCHER_THREADS),
)
kv_store = get_shared_kv_store()
@ -328,24 +315,28 @@ class VespaIndex(DocumentIndex):
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
# Vespa requires an override to erase data including the indices we're no longer using
# It also has a 30 day cap from current so we set it to 7 dynamically
with open(overrides_jinja_file, "r") as overrides_f:
overrides_template_str = overrides_f.read()
overrides_template = jinja_env.from_string(overrides_template_str)
now = datetime.now()
date_in_7_days = now + timedelta(days=7)
formatted_date = date_in_7_days.strftime("%Y-%m-%d")
overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date)
overrides = overrides_template.render(
until_date=formatted_date,
)
zip_dict = {
"services.xml": services.encode("utf-8"),
"validation-overrides.xml": overrides.encode("utf-8"),
}
with open(schema_file, "r") as schema_f:
schema_template = schema_f.read()
with open(schema_jinja_file, "r") as schema_f:
schema_template_str = schema_f.read()
schema_template = jinja_env.from_string(schema_template_str)
for i, index_name in enumerate(indices):
embedding_dim = embedding_dims[i]
@ -354,15 +345,11 @@ class VespaIndex(DocumentIndex):
f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
)
schema = _replace_template_values_in_schema(
schema_template, index_name, embedding_dim, embedding_precision
)
tenant_id_replacement = ""
if MULTI_TENANT:
tenant_id_replacement = TENANT_ID_REPLACEMENT
schema = _replace_tenant_template_value_in_schema(
schema, tenant_id_replacement
schema = schema_template.render(
multi_tenant=MULTI_TENANT,
schema_name=index_name,
dim=embedding_dim,
embedding_precision=embedding_precision.value,
)
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema

View File

@ -5,20 +5,6 @@ from onyx.configs.app_configs import VESPA_PORT
from onyx.configs.app_configs import VESPA_TENANT_PORT
from onyx.configs.constants import SOURCE_TYPE
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION"
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
DATE_REPLACEMENT = "DATE_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"
TENANT_ID_REPLACEMENT = """field tenant_id type string {
indexing: summary | attribute
rank: filter
attribute: fast-search
}"""
# config server
@ -31,7 +17,7 @@ VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}"
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd.jinja
DOCUMENT_ID_ENDPOINT = (
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
)

View File

@ -2,43 +2,47 @@
import argparse
import jinja2
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.vespa.index import _replace_template_values_in_schema
from onyx.document_index.vespa.index import _replace_tenant_template_value_in_schema
from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from onyx.utils.logger import setup_logger
from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS
logger = setup_logger()
def write_schema(index_name: str, dim: int, template: str) -> None:
def write_schema(index_name: str, dim: int, template: jinja2.Template) -> None:
index_filename = index_name + ".sd"
index_rendered_str = _replace_tenant_template_value_in_schema(
template, TENANT_ID_REPLACEMENT
)
index_rendered_str = _replace_template_values_in_schema(
index_rendered_str, index_name, dim, EmbeddingPrecision.FLOAT
schema = template.render(
multi_tenant=True,
schema_name=index_name,
dim=dim,
embedding_precision=EmbeddingPrecision.FLOAT.value,
)
with open(index_filename, "w", encoding="utf-8") as f:
f.write(index_rendered_str)
f.write(schema)
logger.info(f"Wrote {index_filename}")
def main() -> None:
parser = argparse.ArgumentParser(description="Generate multi tenant Vespa schemas")
parser.add_argument("--template", help="The schema template to use", required=True)
parser.add_argument("--template", help="The Jinja template to use", required=True)
args = parser.parse_args()
jinja_env = jinja2.Environment()
with open(args.template, "r", encoding="utf-8") as f:
template_str = f.read()
template = jinja_env.from_string(template_str)
num_indexes = 0
for model in SUPPORTED_EMBEDDING_MODELS:
write_schema(model.index_name, model.dim, template_str)
write_schema(model.index_name + "__danswer_alt_index", model.dim, template_str)
write_schema(model.index_name, model.dim, template)
write_schema(model.index_name + "__danswer_alt_index", model.dim, template)
num_indexes += 2
logger.info(f"Wrote {num_indexes} indexes.")

View File

@ -156,7 +156,7 @@ def reset_postgres(
"""Reset the Postgres database."""
# this seems to hang due to locking issues, so run with a timeout with a few retries
NUM_TRIES = 10
TIMEOUT = 10
TIMEOUT = 40
success = False
for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")

View File

@ -1,20 +1,70 @@
# import multiprocessing
# from collections.abc import Callable
# from typing import Any
# from typing import TypeVar
# T = TypeVar("T")
# def run_with_timeout_multiproc(
# task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
# ) -> T:
# # Use multiprocessing to prevent a thread from blocking the main thread
# with multiprocessing.Pool(processes=1) as pool:
# async_result = pool.apply_async(task, kwds=kwargs)
# try:
# # Wait at most timeout seconds for the function to complete
# result = async_result.get(timeout=timeout)
# return result
# except multiprocessing.TimeoutError:
# raise TimeoutError(f"Function timed out after {timeout} seconds")
import multiprocessing
import traceback
from collections.abc import Callable
from multiprocessing import Queue
from typing import Any
from typing import TypeVar
T = TypeVar("T")
def _multiproc_wrapper(
task: Callable[..., T], kwargs: dict[str, Any], q: Queue
) -> None:
try:
result = task(**kwargs)
q.put(("success", result))
except Exception:
q.put(("error", traceback.format_exc()))
def run_with_timeout_multiproc(
task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
) -> T:
# Use multiprocessing to prevent a thread from blocking the main thread
with multiprocessing.Pool(processes=1) as pool:
async_result = pool.apply_async(task, kwds=kwargs)
try:
# Wait at most timeout seconds for the function to complete
result = async_result.get(timeout=timeout)
ctx = multiprocessing.get_context("spawn")
q: Queue = ctx.Queue()
p = ctx.Process(
target=_multiproc_wrapper,
args=(
task,
kwargs,
q,
),
)
p.start()
p.join(timeout)
if p.is_alive():
p.terminate()
raise TimeoutError(f"{task.__name__} timed out after {timeout} seconds")
if not q.empty():
status, result = q.get()
if status == "success":
return result
except multiprocessing.TimeoutError:
raise TimeoutError(f"Function timed out after {timeout} seconds")
else:
raise RuntimeError(f"{task.__name__} failed:\n{result}")
else:
raise RuntimeError(f"{task.__name__} returned no result")