Feature/vespa jinja (#4558)

* tool to generate vespa schema variations for our cloud

* extraneous assign

* use a real templating system instead of search/replace

* fix float

* maybe this should be double

* remove redundant var

* template the other files

* try a spawned process

* move the wrapper

* fix args

* increase timeout

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
rkuo-danswer
2025-04-20 15:28:55 -07:00
committed by GitHub
parent 87478c5ca6
commit 2111eccf07
8 changed files with 180 additions and 147 deletions

View File

@ -2,43 +2,47 @@
import argparse
import jinja2
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.vespa.index import _replace_template_values_in_schema
from onyx.document_index.vespa.index import _replace_tenant_template_value_in_schema
from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from onyx.utils.logger import setup_logger
from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS
logger = setup_logger()
def write_schema(index_name: str, dim: int, template: jinja2.Template) -> None:
    """Render the schema template for one index and write it to ``<index_name>.sd``.

    The output path is relative, so the file is created in the current
    working directory.

    Args:
        index_name: Name of the index; passed to the template as
            ``schema_name`` and used as the output file's stem.
        dim: Embedding dimension passed to the template as ``dim``.
        template: Compiled Jinja template of the schema.
    """
    index_filename = index_name + ".sd"

    # Schemas written by this tool are always rendered in multi-tenant mode
    # with FLOAT embedding precision.
    schema = template.render(
        multi_tenant=True,
        schema_name=index_name,
        dim=dim,
        embedding_precision=EmbeddingPrecision.FLOAT.value,
    )

    with open(index_filename, "w", encoding="utf-8") as f:
        f.write(schema)

    logger.info(f"Wrote {index_filename}")
def main() -> None:
    """Generate a primary and an alternate schema file for every supported model.

    Reads the Jinja template named by the required ``--template`` argument,
    compiles it once, and calls ``write_schema`` twice per entry in
    ``SUPPORTED_EMBEDDING_MODELS``: once for the model's index name and once
    for the same name with the ``__danswer_alt_index`` suffix. The number of
    schemas written is logged at the end.
    """
    parser = argparse.ArgumentParser(description="Generate multi tenant Vespa schemas")
    parser.add_argument("--template", help="The Jinja template to use", required=True)
    args = parser.parse_args()

    jinja_env = jinja2.Environment()

    with open(args.template, "r", encoding="utf-8") as f:
        template_str = f.read()

    # Compile once up front; the same template object is rendered per index.
    template = jinja_env.from_string(template_str)

    num_indexes = 0
    for model in SUPPORTED_EMBEDDING_MODELS:
        write_schema(model.index_name, model.dim, template)
        write_schema(model.index_name + "__danswer_alt_index", model.dim, template)
        num_indexes += 2

    logger.info(f"Wrote {num_indexes} indexes.")