From 2111eccf071a3cb7c42c545d86acad51cd070248 Mon Sep 17 00:00:00 2001 From: rkuo-danswer Date: Sun, 20 Apr 2025 15:28:55 -0700 Subject: [PATCH] Feature/vespa jinja (#4558) * tool to generate vespa schema variations for our cloud * extraneous assign * use a real templating system instead of search/replace * fix float * maybe this should be double * remove redundant var * template the other files * try a spawned process * move the wrapper * fix args * increase timeout --------- Co-authored-by: Richard Kuo (Onyx) Co-authored-by: Richard Kuo --- ...answer_chunk.sd => danswer_chunk.sd.jinja} | 26 ++- .../{services.xml => services.xml.jinja} | 4 +- ...des.xml => validation-overrides.xml.jinja} | 6 +- backend/onyx/document_index/vespa/index.py | 177 ++++++++---------- .../onyx/document_index/vespa_constants.py | 16 +- .../scripts/debugging/onyx_vespa_schemas.py | 30 +-- .../tests/integration/common_utils/reset.py | 2 +- .../tests/integration/common_utils/timeout.py | 66 ++++++- 8 files changed, 180 insertions(+), 147 deletions(-) rename backend/onyx/document_index/vespa/app_config/schemas/{danswer_chunk.sd => danswer_chunk.sd.jinja} (94%) rename backend/onyx/document_index/vespa/app_config/{services.xml => services.xml.jinja} (92%) rename backend/onyx/document_index/vespa/app_config/{validation-overrides.xml => validation-overrides.xml.jinja} (80%) diff --git a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja similarity index 94% rename from backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd rename to backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja index c70c83aff9..c6fa2f075d 100644 --- a/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja @@ -1,11 +1,17 @@ -schema DANSWER_CHUNK_NAME { - document DANSWER_CHUNK_NAME { - TENANT_ID_REPLACEMENT +schema {{ schema_name }} { + document {{ schema_name }} { + {% if multi_tenant %} + field tenant_id type string { + indexing: summary | attribute + rank: filter + attribute: fast-search + } + {% endif %} # Not to be confused with the UUID generated for this chunk which is called documentid by default field document_id type string { indexing: summary | attribute - attribute: fast-search rank: filter + attribute: fast-search } field chunk_id type int { indexing: summary | attribute @@ -37,7 +43,7 @@ schema DANSWER_CHUNK_NAME { summary: dynamic } # Title embedding (x1) - field title_embedding type tensor(x[VARIABLE_DIM]) { + field title_embedding type tensor<{{ embedding_precision }}>(x[{{ dim }}]) { indexing: attribute | index attribute { distance-metric: angular @@ -45,7 +51,7 @@ schema DANSWER_CHUNK_NAME { } # Content embeddings (chunk + optional mini chunks embeddings) # "t" and "x" are arbitrary names, not special keywords - field embeddings type tensor(t{},x[VARIABLE_DIM]) { + field embeddings type tensor<{{ embedding_precision }}>(t{},x[{{ dim }}]) { indexing: attribute | index attribute { distance-metric: angular @@ -176,9 +182,9 @@ schema DANSWER_CHUNK_NAME { match-features: recency_bias } - rank-profile hybrid_search_semantic_base_VARIABLE_DIM inherits default, default_rank { + rank-profile hybrid_search_semantic_base_{{ dim }} inherits default, default_rank { inputs { - query(query_embedding) tensor(x[VARIABLE_DIM]) + query(query_embedding) tensor(x[{{ dim }}]) } function title_vector_score() { @@ -244,9 +250,9 @@ schema DANSWER_CHUNK_NAME { } - rank-profile hybrid_search_keyword_base_VARIABLE_DIM inherits default, default_rank { + rank-profile hybrid_search_keyword_base_{{ dim }} inherits default, default_rank { inputs { - query(query_embedding) tensor(x[VARIABLE_DIM]) + query(query_embedding) tensor(x[{{ dim }}]) } function title_vector_score() { diff --git a/backend/onyx/document_index/vespa/app_config/services.xml b/backend/onyx/document_index/vespa/app_config/services.xml.jinja similarity index 92% rename from backend/onyx/document_index/vespa/app_config/services.xml rename to backend/onyx/document_index/vespa/app_config/services.xml.jinja index 5fa386a9ad..67cebc3466 100644 --- a/backend/onyx/document_index/vespa/app_config/services.xml +++ b/backend/onyx/document_index/vespa/app_config/services.xml.jinja @@ -14,7 +14,7 @@ 1 - DOCUMENT_REPLACEMENT + {{ document_elements }} @@ -31,7 +31,7 @@ - SEARCH_THREAD_NUMBER + {{ num_search_threads }} diff --git a/backend/onyx/document_index/vespa/app_config/validation-overrides.xml b/backend/onyx/document_index/vespa/app_config/validation-overrides.xml.jinja similarity index 80% rename from backend/onyx/document_index/vespa/app_config/validation-overrides.xml rename to backend/onyx/document_index/vespa/app_config/validation-overrides.xml.jinja index 7b0709620a..a5f620c091 100644 --- a/backend/onyx/document_index/vespa/app_config/validation-overrides.xml +++ b/backend/onyx/document_index/vespa/app_config/validation-overrides.xml.jinja @@ -1,11 +1,11 @@ schema-removal indexing-change field-type-change diff --git a/backend/onyx/document_index/vespa/index.py b/backend/onyx/document_index/vespa/index.py index d5ba53cfe4..9805e2aff4 100644 --- a/backend/onyx/document_index/vespa/index.py +++ b/backend/onyx/document_index/vespa/index.py @@ -16,6 +16,7 @@ from typing import List from uuid import UUID import httpx # type: ignore +import jinja2 import requests # type: ignore from retry import retry @@ -61,21 +62,13 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST from onyx.document_index.vespa_constants import BATCH_SIZE from onyx.document_index.vespa_constants import BOOST from onyx.document_index.vespa_constants import CONTENT_SUMMARY -from onyx.document_index.vespa_constants import DANSWER_CHUNK_REPLACEMENT_PAT -from onyx.document_index.vespa_constants import DATE_REPLACEMENT from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT -from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT from onyx.document_index.vespa_constants import DOCUMENT_SETS -from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT from onyx.document_index.vespa_constants import HIDDEN from onyx.document_index.vespa_constants import NUM_THREADS -from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT -from onyx.document_index.vespa_constants import TENANT_ID_PAT -from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT from onyx.document_index.vespa_constants import USER_FILE from onyx.document_index.vespa_constants import USER_FOLDER from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT -from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT from onyx.document_index.vespa_constants import VESPA_TIMEOUT from onyx.document_index.vespa_constants import YQL_BASE from onyx.indexing.models import DocMetadataAwareIndexChunk @@ -118,28 +111,6 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str: return "\n".join(doc_lines) -def _replace_template_values_in_schema( - schema_template: str, - index_name: str, - embedding_dim: int, - embedding_precision: EmbeddingPrecision, -) -> str: - return ( - schema_template.replace( - EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value - ) - .replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name) - .replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim)) - ) - - -def _replace_tenant_template_value_in_schema( - schema_template: str, - tenant_field: str, -) -> str: - return schema_template.replace(TENANT_ID_PAT, tenant_field) - - def add_ngrams_to_schema(schema_content: str) -> str: # Add the match blocks containing gram and gram-size to title and content fields schema_content = re.sub( @@ -156,6 +127,9 @@ def add_ngrams_to_schema(schema_content: str) -> str: class VespaIndex(DocumentIndex): + + VESPA_SCHEMA_JINJA_FILENAME = "danswer_chunk.sd.jinja" + def __init__( self, index_name: str, @@ -202,26 +176,32 @@ class VespaIndex(DocumentIndex): ) return None + jinja_env = jinja2.Environment() + deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate" logger.notice(f"Deploying Vespa application package to {deploy_url}") vespa_schema_path = os.path.join( os.getcwd(), "onyx", "document_index", "vespa", "app_config" ) - schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd") - services_file = os.path.join(vespa_schema_path, "services.xml") - overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml") - - with open(services_file, "r") as services_f: - services_template = services_f.read() - - schema_names = [self.index_name, self.secondary_index_name] - - doc_lines = _create_document_xml_lines(schema_names) - services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines) - services = services.replace( - SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS) + schema_jinja_file = os.path.join( + vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME ) + services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja") + overrides_jinja_file = os.path.join( + vespa_schema_path, "validation-overrides.xml.jinja" + ) + + with open(services_jinja_file, "r") as services_f: + schema_names = [self.index_name, self.secondary_index_name] + doc_lines = _create_document_xml_lines(schema_names) + + services_template_str = services_f.read() + services_template = jinja_env.from_string(services_template_str) + services = services_template.render( + document_elements=doc_lines, + num_search_threads=str(VESPA_SEARCHER_THREADS), + ) kv_store = get_shared_kv_store() @@ -231,30 +211,33 @@ class VespaIndex(DocumentIndex): except Exception: logger.debug("Could not load the reindexing flag. Using ngrams") - with open(overrides_file, "r") as overrides_f: - overrides_template = overrides_f.read() - # Vespa requires an override to erase data including the indices we're no longer using # It also has a 30 day cap from current so we set it to 7 dynamically - now = datetime.now() - date_in_7_days = now + timedelta(days=7) - formatted_date = date_in_7_days.strftime("%Y-%m-%d") + with open(overrides_jinja_file, "r") as overrides_f: + overrides_template_str = overrides_f.read() + overrides_template = jinja_env.from_string(overrides_template_str) - overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date) + now = datetime.now() + date_in_7_days = now + timedelta(days=7) + formatted_date = date_in_7_days.strftime("%Y-%m-%d") + overrides = overrides_template.render( + until_date=formatted_date, + ) zip_dict = { "services.xml": services.encode("utf-8"), "validation-overrides.xml": overrides.encode("utf-8"), } - with open(schema_file, "r") as schema_f: - schema_template = schema_f.read() - schema = _replace_tenant_template_value_in_schema(schema_template, "") - schema = _replace_template_values_in_schema( - schema, - self.index_name, - primary_embedding_dim, - primary_embedding_precision, + with open(schema_jinja_file, "r") as schema_f: + template_str = schema_f.read() + + template = jinja_env.from_string(template_str) + schema = template.render( + multi_tenant=MULTI_TENANT, + schema_name=self.index_name, + dim=primary_embedding_dim, + embedding_precision=primary_embedding_precision.value, ) schema = add_ngrams_to_schema(schema) if needs_reindexing else schema @@ -266,12 +249,13 @@ class VespaIndex(DocumentIndex): if secondary_index_embedding_precision is None: raise ValueError("Secondary index embedding precision is required") - upcoming_schema = _replace_template_values_in_schema( - schema_template, - self.secondary_index_name, - secondary_index_embedding_dim, - secondary_index_embedding_precision, + upcoming_schema = template.render( + multi_tenant=MULTI_TENANT, + schema_name=self.secondary_index_name, + dim=secondary_index_embedding_dim, + embedding_precision=secondary_index_embedding_precision.value, ) + zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8") zip_file = in_memory_zip_from_file_bytes(zip_dict) @@ -301,24 +285,27 @@ class VespaIndex(DocumentIndex): vespa_schema_path = os.path.join( os.getcwd(), "onyx", "document_index", "vespa", "app_config" ) - schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd") - services_file = os.path.join(vespa_schema_path, "services.xml") - overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml") + schema_jinja_file = os.path.join( + vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME + ) + services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja") + overrides_jinja_file = os.path.join( + vespa_schema_path, "validation-overrides.xml.jinja" + ) - with open(services_file, "r") as services_f: - services_template = services_f.read() + jinja_env = jinja2.Environment() # Generate schema names from index settings - schema_names = [index_name for index_name in indices] + with open(services_jinja_file, "r") as services_f: + schema_names = [index_name for index_name in indices] + doc_lines = _create_document_xml_lines(schema_names) - full_schemas = schema_names - - doc_lines = _create_document_xml_lines(full_schemas) - - services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines) - services = services.replace( - SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS) - ) + services_template_str = services_f.read() + services_template = jinja_env.from_string(services_template_str) + services = services_template.render( + document_elements=doc_lines, + num_search_threads=str(VESPA_SEARCHER_THREADS), + ) kv_store = get_shared_kv_store() @@ -328,24 +315,28 @@ class VespaIndex(DocumentIndex): except Exception: logger.debug("Could not load the reindexing flag. Using ngrams") - with open(overrides_file, "r") as overrides_f: - overrides_template = overrides_f.read() - # Vespa requires an override to erase data including the indices we're no longer using # It also has a 30 day cap from current so we set it to 7 dynamically - now = datetime.now() - date_in_7_days = now + timedelta(days=7) - formatted_date = date_in_7_days.strftime("%Y-%m-%d") + with open(overrides_jinja_file, "r") as overrides_f: + overrides_template_str = overrides_f.read() + overrides_template = jinja_env.from_string(overrides_template_str) - overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date) + now = datetime.now() + date_in_7_days = now + timedelta(days=7) + formatted_date = date_in_7_days.strftime("%Y-%m-%d") + overrides = overrides_template.render( + until_date=formatted_date, + ) zip_dict = { "services.xml": services.encode("utf-8"), "validation-overrides.xml": overrides.encode("utf-8"), } - with open(schema_file, "r") as schema_f: - schema_template = schema_f.read() + with open(schema_jinja_file, "r") as schema_f: + schema_template_str = schema_f.read() + + schema_template = jinja_env.from_string(schema_template_str) for i, index_name in enumerate(indices): embedding_dim = embedding_dims[i] @@ -354,15 +345,11 @@ class VespaIndex(DocumentIndex): f"Creating index: {index_name} with embedding dimension: {embedding_dim}" ) - schema = _replace_template_values_in_schema( - schema_template, index_name, embedding_dim, embedding_precision - ) - - tenant_id_replacement = "" - if MULTI_TENANT: - tenant_id_replacement = TENANT_ID_REPLACEMENT - schema = _replace_tenant_template_value_in_schema( - schema, tenant_id_replacement + schema = schema_template.render( + multi_tenant=MULTI_TENANT, + schema_name=index_name, + dim=embedding_dim, + embedding_precision=embedding_precision.value, ) schema = add_ngrams_to_schema(schema) if needs_reindexing else schema diff --git a/backend/onyx/document_index/vespa_constants.py b/backend/onyx/document_index/vespa_constants.py index 2b8f72c357..da82ed9287 100644 --- a/backend/onyx/document_index/vespa_constants.py +++ b/backend/onyx/document_index/vespa_constants.py @@ -5,20 +5,6 @@ from onyx.configs.app_configs import VESPA_PORT from onyx.configs.app_configs import VESPA_TENANT_PORT from onyx.configs.constants import SOURCE_TYPE -VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM" -EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION" -DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME" -DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT" -SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER" -DATE_REPLACEMENT = "DATE_REPLACEMENT" -SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER" -TENANT_ID_PAT = "TENANT_ID_REPLACEMENT" - -TENANT_ID_REPLACEMENT = """field tenant_id type string { - indexing: summary | attribute - rank: filter - attribute: fast-search - }""" # config server @@ -31,7 +17,7 @@ VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2" VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}" -# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd +# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd.jinja DOCUMENT_ID_ENDPOINT = ( f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid" ) diff --git a/backend/scripts/debugging/onyx_vespa_schemas.py b/backend/scripts/debugging/onyx_vespa_schemas.py index ca668acf39..5acad609f9 100644 --- a/backend/scripts/debugging/onyx_vespa_schemas.py +++ b/backend/scripts/debugging/onyx_vespa_schemas.py @@ -2,43 +2,47 @@ import argparse +import jinja2 + from onyx.db.enums import EmbeddingPrecision -from onyx.document_index.vespa.index import _replace_template_values_in_schema -from onyx.document_index.vespa.index import _replace_tenant_template_value_in_schema -from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT from onyx.utils.logger import setup_logger from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS logger = setup_logger() -def write_schema(index_name: str, dim: int, template: str) -> None: +def write_schema(index_name: str, dim: int, template: jinja2.Template) -> None: index_filename = index_name + ".sd" - index_rendered_str = _replace_tenant_template_value_in_schema( - template, TENANT_ID_REPLACEMENT - ) - index_rendered_str = _replace_template_values_in_schema( - index_rendered_str, index_name, dim, EmbeddingPrecision.FLOAT + + schema = template.render( + multi_tenant=True, + schema_name=index_name, + dim=dim, + embedding_precision=EmbeddingPrecision.FLOAT.value, ) with open(index_filename, "w", encoding="utf-8") as f: - f.write(index_rendered_str) + f.write(schema) logger.info(f"Wrote {index_filename}") def main() -> None: parser = argparse.ArgumentParser(description="Generate multi tenant Vespa schemas") - parser.add_argument("--template", help="The schema template to use", required=True) + parser.add_argument("--template", help="The Jinja template to use", required=True) args = parser.parse_args() + jinja_env = jinja2.Environment() + with open(args.template, "r", encoding="utf-8") as f: template_str = f.read() + template = jinja_env.from_string(template_str) + num_indexes = 0 for model in SUPPORTED_EMBEDDING_MODELS: - write_schema(model.index_name, model.dim, template_str) - write_schema(model.index_name + "__danswer_alt_index", model.dim, template_str) + write_schema(model.index_name, model.dim, template) + write_schema(model.index_name + "__danswer_alt_index", model.dim, template) num_indexes += 2 logger.info(f"Wrote {num_indexes} indexes.") diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index 366f50814b..1123d8bdfa 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -156,7 +156,7 @@ def reset_postgres( """Reset the Postgres database.""" # this seems to hang due to locking issues, so run with a timeout with a few retries NUM_TRIES = 10 - TIMEOUT = 10 + TIMEOUT = 40 success = False for _ in range(NUM_TRIES): logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})") diff --git a/backend/tests/integration/common_utils/timeout.py b/backend/tests/integration/common_utils/timeout.py index 52c5ac0a0b..2d06e56ffa 100644 --- a/backend/tests/integration/common_utils/timeout.py +++ b/backend/tests/integration/common_utils/timeout.py @@ -1,20 +1,70 @@ +# import multiprocessing +# from collections.abc import Callable +# from typing import Any +# from typing import TypeVar + +# T = TypeVar("T") + + +# def run_with_timeout_multiproc( +# task: Callable[..., T], timeout: int, kwargs: dict[str, Any] +# ) -> T: +# # Use multiprocessing to prevent a thread from blocking the main thread +# with multiprocessing.Pool(processes=1) as pool: +# async_result = pool.apply_async(task, kwds=kwargs) +# try: +# # Wait at most timeout seconds for the function to complete +# result = async_result.get(timeout=timeout) +# return result +# except multiprocessing.TimeoutError: +# raise TimeoutError(f"Function timed out after {timeout} seconds") + + import multiprocessing +import traceback from collections.abc import Callable +from multiprocessing import Queue from typing import Any from typing import TypeVar T = TypeVar("T") +def _multiproc_wrapper( + task: Callable[..., T], kwargs: dict[str, Any], q: Queue +) -> None: + try: + result = task(**kwargs) + q.put(("success", result)) + except Exception: + q.put(("error", traceback.format_exc())) + + def run_with_timeout_multiproc( task: Callable[..., T], timeout: int, kwargs: dict[str, Any] ) -> T: - # Use multiprocessing to prevent a thread from blocking the main thread - with multiprocessing.Pool(processes=1) as pool: - async_result = pool.apply_async(task, kwds=kwargs) - try: - # Wait at most timeout seconds for the function to complete - result = async_result.get(timeout=timeout) + ctx = multiprocessing.get_context("spawn") + q: Queue = ctx.Queue() + p = ctx.Process( + target=_multiproc_wrapper, + args=( + task, + kwargs, + q, + ), + ) + p.start() + p.join(timeout) + + if p.is_alive(): + p.terminate() + raise TimeoutError(f"{task.__name__} timed out after {timeout} seconds") + + if not q.empty(): + status, result = q.get() + if status == "success": return result - except multiprocessing.TimeoutError: - raise TimeoutError(f"Function timed out after {timeout} seconds") + else: + raise RuntimeError(f"{task.__name__} failed:\n{result}") + else: + raise RuntimeError(f"{task.__name__} returned no result")