Feature/vespa jinja (#4558)

* tool to generate vespa schema variations for our cloud

* extraneous assign

* use a real templating system instead of search/replace

* fix float

* maybe this should be double

* remove redundant var

* template the other files

* try a spawned process

* move the wrapper

* fix args

* increase timeout

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
This commit is contained in:
rkuo-danswer
2025-04-20 15:28:55 -07:00
committed by GitHub
parent 87478c5ca6
commit 2111eccf07
8 changed files with 180 additions and 147 deletions

View File

@ -1,11 +1,17 @@
schema DANSWER_CHUNK_NAME { schema {{ schema_name }} {
document DANSWER_CHUNK_NAME { document {{ schema_name }} {
TENANT_ID_REPLACEMENT {% if multi_tenant %}
field tenant_id type string {
indexing: summary | attribute
rank: filter
attribute: fast-search
}
{% endif %}
# Not to be confused with the UUID generated for this chunk which is called documentid by default # Not to be confused with the UUID generated for this chunk which is called documentid by default
field document_id type string { field document_id type string {
indexing: summary | attribute indexing: summary | attribute
attribute: fast-search
rank: filter rank: filter
attribute: fast-search
} }
field chunk_id type int { field chunk_id type int {
indexing: summary | attribute indexing: summary | attribute
@ -37,7 +43,7 @@ schema DANSWER_CHUNK_NAME {
summary: dynamic summary: dynamic
} }
# Title embedding (x1) # Title embedding (x1)
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) { field title_embedding type tensor<{{ embedding_precision }}>(x[{{ dim }}]) {
indexing: attribute | index indexing: attribute | index
attribute { attribute {
distance-metric: angular distance-metric: angular
@ -45,7 +51,7 @@ schema DANSWER_CHUNK_NAME {
} }
# Content embeddings (chunk + optional mini chunks embeddings) # Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords # "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) { field embeddings type tensor<{{ embedding_precision }}>(t{},x[{{ dim }}]) {
indexing: attribute | index indexing: attribute | index
attribute { attribute {
distance-metric: angular distance-metric: angular
@ -176,9 +182,9 @@ schema DANSWER_CHUNK_NAME {
match-features: recency_bias match-features: recency_bias
} }
rank-profile hybrid_search_semantic_base_VARIABLE_DIM inherits default, default_rank { rank-profile hybrid_search_semantic_base_{{ dim }} inherits default, default_rank {
inputs { inputs {
query(query_embedding) tensor<float>(x[VARIABLE_DIM]) query(query_embedding) tensor<float>(x[{{ dim }}])
} }
function title_vector_score() { function title_vector_score() {
@ -244,9 +250,9 @@ schema DANSWER_CHUNK_NAME {
} }
rank-profile hybrid_search_keyword_base_VARIABLE_DIM inherits default, default_rank { rank-profile hybrid_search_keyword_base_{{ dim }} inherits default, default_rank {
inputs { inputs {
query(query_embedding) tensor<float>(x[VARIABLE_DIM]) query(query_embedding) tensor<float>(x[{{ dim }}])
} }
function title_vector_score() { function title_vector_score() {

View File

@ -14,7 +14,7 @@
<redundancy>1</redundancy> <redundancy>1</redundancy>
<documents> <documents>
<!-- <document type="danswer_chunk" mode="index" /> --> <!-- <document type="danswer_chunk" mode="index" /> -->
DOCUMENT_REPLACEMENT {{ document_elements }}
</documents> </documents>
<nodes> <nodes>
<node hostalias="danswer-node" distribution-key="0" /> <node hostalias="danswer-node" distribution-key="0" />
@ -31,7 +31,7 @@
<tuning> <tuning>
<searchnode> <searchnode>
<requestthreads> <requestthreads>
<persearch>SEARCH_THREAD_NUMBER</persearch> <persearch>{{ num_search_threads }}</persearch>
</requestthreads> </requestthreads>
</searchnode> </searchnode>
</tuning> </tuning>

View File

@ -1,11 +1,11 @@
<validation-overrides> <validation-overrides>
<allow <allow
until="DATE_REPLACEMENT" until="{{ until_date }}"
comment="We need to be able to create/delete indices for swapping models">schema-removal</allow> comment="We need to be able to create/delete indices for swapping models">schema-removal</allow>
<allow <allow
until="DATE_REPLACEMENT" until="{{ until_date }}"
comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow> comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow>
<allow <allow
until='DATE_REPLACEMENT' until="{{ until_date }}"
comment="Prevents old alt indices from interfering with changes">field-type-change</allow> comment="Prevents old alt indices from interfering with changes">field-type-change</allow>
</validation-overrides> </validation-overrides>

View File

@ -16,6 +16,7 @@ from typing import List
from uuid import UUID from uuid import UUID
import httpx # type: ignore import httpx # type: ignore
import jinja2
import requests # type: ignore import requests # type: ignore
from retry import retry from retry import retry
@ -61,21 +62,13 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BATCH_SIZE from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import BOOST from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CONTENT_SUMMARY from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DANSWER_CHUNK_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import DATE_REPLACEMENT
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import DOCUMENT_SETS from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import HIDDEN from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import NUM_THREADS from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
from onyx.document_index.vespa_constants import TENANT_ID_PAT
from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from onyx.document_index.vespa_constants import USER_FILE from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER from onyx.document_index.vespa_constants import USER_FOLDER
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk from onyx.indexing.models import DocMetadataAwareIndexChunk
@ -118,28 +111,6 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
return "\n".join(doc_lines) return "\n".join(doc_lines)
def _replace_template_values_in_schema(
schema_template: str,
index_name: str,
embedding_dim: int,
embedding_precision: EmbeddingPrecision,
) -> str:
return (
schema_template.replace(
EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value
)
.replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name)
.replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
)
def _replace_tenant_template_value_in_schema(
schema_template: str,
tenant_field: str,
) -> str:
return schema_template.replace(TENANT_ID_PAT, tenant_field)
def add_ngrams_to_schema(schema_content: str) -> str: def add_ngrams_to_schema(schema_content: str) -> str:
# Add the match blocks containing gram and gram-size to title and content fields # Add the match blocks containing gram and gram-size to title and content fields
schema_content = re.sub( schema_content = re.sub(
@ -156,6 +127,9 @@ def add_ngrams_to_schema(schema_content: str) -> str:
class VespaIndex(DocumentIndex): class VespaIndex(DocumentIndex):
VESPA_SCHEMA_JINJA_FILENAME = "danswer_chunk.sd.jinja"
def __init__( def __init__(
self, self,
index_name: str, index_name: str,
@ -202,26 +176,32 @@ class VespaIndex(DocumentIndex):
) )
return None return None
jinja_env = jinja2.Environment()
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate" deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
logger.notice(f"Deploying Vespa application package to {deploy_url}") logger.notice(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join( vespa_schema_path = os.path.join(
os.getcwd(), "onyx", "document_index", "vespa", "app_config" os.getcwd(), "onyx", "document_index", "vespa", "app_config"
) )
schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd") schema_jinja_file = os.path.join(
services_file = os.path.join(vespa_schema_path, "services.xml") vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml")
with open(services_file, "r") as services_f:
services_template = services_f.read()
schema_names = [self.index_name, self.secondary_index_name]
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
services = services.replace(
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
) )
services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
overrides_jinja_file = os.path.join(
vespa_schema_path, "validation-overrides.xml.jinja"
)
with open(services_jinja_file, "r") as services_f:
schema_names = [self.index_name, self.secondary_index_name]
doc_lines = _create_document_xml_lines(schema_names)
services_template_str = services_f.read()
services_template = jinja_env.from_string(services_template_str)
services = services_template.render(
document_elements=doc_lines,
num_search_threads=str(VESPA_SEARCHER_THREADS),
)
kv_store = get_shared_kv_store() kv_store = get_shared_kv_store()
@ -231,30 +211,33 @@ class VespaIndex(DocumentIndex):
except Exception: except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams") logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
# Vespa requires an override to erase data including the indices we're no longer using # Vespa requires an override to erase data including the indices we're no longer using
# It also has a 30 day cap from current so we set it to 7 dynamically # It also has a 30 day cap from current so we set it to 7 dynamically
now = datetime.now() with open(overrides_jinja_file, "r") as overrides_f:
date_in_7_days = now + timedelta(days=7) overrides_template_str = overrides_f.read()
formatted_date = date_in_7_days.strftime("%Y-%m-%d") overrides_template = jinja_env.from_string(overrides_template_str)
overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date) now = datetime.now()
date_in_7_days = now + timedelta(days=7)
formatted_date = date_in_7_days.strftime("%Y-%m-%d")
overrides = overrides_template.render(
until_date=formatted_date,
)
zip_dict = { zip_dict = {
"services.xml": services.encode("utf-8"), "services.xml": services.encode("utf-8"),
"validation-overrides.xml": overrides.encode("utf-8"), "validation-overrides.xml": overrides.encode("utf-8"),
} }
with open(schema_file, "r") as schema_f: with open(schema_jinja_file, "r") as schema_f:
schema_template = schema_f.read() template_str = schema_f.read()
schema = _replace_tenant_template_value_in_schema(schema_template, "")
schema = _replace_template_values_in_schema( template = jinja_env.from_string(template_str)
schema, schema = template.render(
self.index_name, multi_tenant=MULTI_TENANT,
primary_embedding_dim, schema_name=self.index_name,
primary_embedding_precision, dim=primary_embedding_dim,
embedding_precision=primary_embedding_precision.value,
) )
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
@ -266,12 +249,13 @@ class VespaIndex(DocumentIndex):
if secondary_index_embedding_precision is None: if secondary_index_embedding_precision is None:
raise ValueError("Secondary index embedding precision is required") raise ValueError("Secondary index embedding precision is required")
upcoming_schema = _replace_template_values_in_schema( upcoming_schema = template.render(
schema_template, multi_tenant=MULTI_TENANT,
self.secondary_index_name, schema_name=self.secondary_index_name,
secondary_index_embedding_dim, dim=secondary_index_embedding_dim,
secondary_index_embedding_precision, embedding_precision=secondary_index_embedding_precision.value,
) )
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8") zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict) zip_file = in_memory_zip_from_file_bytes(zip_dict)
@ -301,24 +285,27 @@ class VespaIndex(DocumentIndex):
vespa_schema_path = os.path.join( vespa_schema_path = os.path.join(
os.getcwd(), "onyx", "document_index", "vespa", "app_config" os.getcwd(), "onyx", "document_index", "vespa", "app_config"
) )
schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd") schema_jinja_file = os.path.join(
services_file = os.path.join(vespa_schema_path, "services.xml") vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml") )
services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
overrides_jinja_file = os.path.join(
vespa_schema_path, "validation-overrides.xml.jinja"
)
with open(services_file, "r") as services_f: jinja_env = jinja2.Environment()
services_template = services_f.read()
# Generate schema names from index settings # Generate schema names from index settings
schema_names = [index_name for index_name in indices] with open(services_jinja_file, "r") as services_f:
schema_names = [index_name for index_name in indices]
doc_lines = _create_document_xml_lines(schema_names)
full_schemas = schema_names services_template_str = services_f.read()
services_template = jinja_env.from_string(services_template_str)
doc_lines = _create_document_xml_lines(full_schemas) services = services_template.render(
document_elements=doc_lines,
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines) num_search_threads=str(VESPA_SEARCHER_THREADS),
services = services.replace( )
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_shared_kv_store() kv_store = get_shared_kv_store()
@ -328,24 +315,28 @@ class VespaIndex(DocumentIndex):
except Exception: except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams") logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
# Vespa requires an override to erase data including the indices we're no longer using # Vespa requires an override to erase data including the indices we're no longer using
# It also has a 30 day cap from current so we set it to 7 dynamically # It also has a 30 day cap from current so we set it to 7 dynamically
now = datetime.now() with open(overrides_jinja_file, "r") as overrides_f:
date_in_7_days = now + timedelta(days=7) overrides_template_str = overrides_f.read()
formatted_date = date_in_7_days.strftime("%Y-%m-%d") overrides_template = jinja_env.from_string(overrides_template_str)
overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date) now = datetime.now()
date_in_7_days = now + timedelta(days=7)
formatted_date = date_in_7_days.strftime("%Y-%m-%d")
overrides = overrides_template.render(
until_date=formatted_date,
)
zip_dict = { zip_dict = {
"services.xml": services.encode("utf-8"), "services.xml": services.encode("utf-8"),
"validation-overrides.xml": overrides.encode("utf-8"), "validation-overrides.xml": overrides.encode("utf-8"),
} }
with open(schema_file, "r") as schema_f: with open(schema_jinja_file, "r") as schema_f:
schema_template = schema_f.read() schema_template_str = schema_f.read()
schema_template = jinja_env.from_string(schema_template_str)
for i, index_name in enumerate(indices): for i, index_name in enumerate(indices):
embedding_dim = embedding_dims[i] embedding_dim = embedding_dims[i]
@ -354,15 +345,11 @@ class VespaIndex(DocumentIndex):
f"Creating index: {index_name} with embedding dimension: {embedding_dim}" f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
) )
schema = _replace_template_values_in_schema( schema = schema_template.render(
schema_template, index_name, embedding_dim, embedding_precision multi_tenant=MULTI_TENANT,
) schema_name=index_name,
dim=embedding_dim,
tenant_id_replacement = "" embedding_precision=embedding_precision.value,
if MULTI_TENANT:
tenant_id_replacement = TENANT_ID_REPLACEMENT
schema = _replace_tenant_template_value_in_schema(
schema, tenant_id_replacement
) )
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema schema = add_ngrams_to_schema(schema) if needs_reindexing else schema

View File

@ -5,20 +5,6 @@ from onyx.configs.app_configs import VESPA_PORT
from onyx.configs.app_configs import VESPA_TENANT_PORT from onyx.configs.app_configs import VESPA_TENANT_PORT
from onyx.configs.constants import SOURCE_TYPE from onyx.configs.constants import SOURCE_TYPE
VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION"
DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
DATE_REPLACEMENT = "DATE_REPLACEMENT"
SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"
TENANT_ID_REPLACEMENT = """field tenant_id type string {
indexing: summary | attribute
rank: filter
attribute: fast-search
}"""
# config server # config server
@ -31,7 +17,7 @@ VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}" VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}"
# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd # danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd.jinja
DOCUMENT_ID_ENDPOINT = ( DOCUMENT_ID_ENDPOINT = (
f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid" f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
) )

View File

@ -2,43 +2,47 @@
import argparse import argparse
import jinja2
from onyx.db.enums import EmbeddingPrecision from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.vespa.index import _replace_template_values_in_schema
from onyx.document_index.vespa.index import _replace_tenant_template_value_in_schema
from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_logger
from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS
logger = setup_logger() logger = setup_logger()
def write_schema(index_name: str, dim: int, template: str) -> None: def write_schema(index_name: str, dim: int, template: jinja2.Template) -> None:
index_filename = index_name + ".sd" index_filename = index_name + ".sd"
index_rendered_str = _replace_tenant_template_value_in_schema(
template, TENANT_ID_REPLACEMENT schema = template.render(
) multi_tenant=True,
index_rendered_str = _replace_template_values_in_schema( schema_name=index_name,
index_rendered_str, index_name, dim, EmbeddingPrecision.FLOAT dim=dim,
embedding_precision=EmbeddingPrecision.FLOAT.value,
) )
with open(index_filename, "w", encoding="utf-8") as f: with open(index_filename, "w", encoding="utf-8") as f:
f.write(index_rendered_str) f.write(schema)
logger.info(f"Wrote {index_filename}") logger.info(f"Wrote {index_filename}")
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser(description="Generate multi tenant Vespa schemas") parser = argparse.ArgumentParser(description="Generate multi tenant Vespa schemas")
parser.add_argument("--template", help="The schema template to use", required=True) parser.add_argument("--template", help="The Jinja template to use", required=True)
args = parser.parse_args() args = parser.parse_args()
jinja_env = jinja2.Environment()
with open(args.template, "r", encoding="utf-8") as f: with open(args.template, "r", encoding="utf-8") as f:
template_str = f.read() template_str = f.read()
template = jinja_env.from_string(template_str)
num_indexes = 0 num_indexes = 0
for model in SUPPORTED_EMBEDDING_MODELS: for model in SUPPORTED_EMBEDDING_MODELS:
write_schema(model.index_name, model.dim, template_str) write_schema(model.index_name, model.dim, template)
write_schema(model.index_name + "__danswer_alt_index", model.dim, template_str) write_schema(model.index_name + "__danswer_alt_index", model.dim, template)
num_indexes += 2 num_indexes += 2
logger.info(f"Wrote {num_indexes} indexes.") logger.info(f"Wrote {num_indexes} indexes.")

View File

@ -156,7 +156,7 @@ def reset_postgres(
"""Reset the Postgres database.""" """Reset the Postgres database."""
# this seems to hang due to locking issues, so run with a timeout with a few retries # this seems to hang due to locking issues, so run with a timeout with a few retries
NUM_TRIES = 10 NUM_TRIES = 10
TIMEOUT = 10 TIMEOUT = 40
success = False success = False
for _ in range(NUM_TRIES): for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})") logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")

View File

@ -1,20 +1,70 @@
# import multiprocessing
# from collections.abc import Callable
# from typing import Any
# from typing import TypeVar
# T = TypeVar("T")
# def run_with_timeout_multiproc(
# task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
# ) -> T:
# # Use multiprocessing to prevent a thread from blocking the main thread
# with multiprocessing.Pool(processes=1) as pool:
# async_result = pool.apply_async(task, kwds=kwargs)
# try:
# # Wait at most timeout seconds for the function to complete
# result = async_result.get(timeout=timeout)
# return result
# except multiprocessing.TimeoutError:
# raise TimeoutError(f"Function timed out after {timeout} seconds")
import multiprocessing import multiprocessing
import traceback
from collections.abc import Callable from collections.abc import Callable
from multiprocessing import Queue
from typing import Any from typing import Any
from typing import TypeVar from typing import TypeVar
T = TypeVar("T") T = TypeVar("T")
def _multiproc_wrapper(
task: Callable[..., T], kwargs: dict[str, Any], q: Queue
) -> None:
try:
result = task(**kwargs)
q.put(("success", result))
except Exception:
q.put(("error", traceback.format_exc()))
def run_with_timeout_multiproc( def run_with_timeout_multiproc(
task: Callable[..., T], timeout: int, kwargs: dict[str, Any] task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
) -> T: ) -> T:
# Use multiprocessing to prevent a thread from blocking the main thread ctx = multiprocessing.get_context("spawn")
with multiprocessing.Pool(processes=1) as pool: q: Queue = ctx.Queue()
async_result = pool.apply_async(task, kwds=kwargs) p = ctx.Process(
try: target=_multiproc_wrapper,
# Wait at most timeout seconds for the function to complete args=(
result = async_result.get(timeout=timeout) task,
kwargs,
q,
),
)
p.start()
p.join(timeout)
if p.is_alive():
p.terminate()
raise TimeoutError(f"{task.__name__} timed out after {timeout} seconds")
if not q.empty():
status, result = q.get()
if status == "success":
return result return result
except multiprocessing.TimeoutError: else:
raise TimeoutError(f"Function timed out after {timeout} seconds") raise RuntimeError(f"{task.__name__} failed:\n{result}")
else:
raise RuntimeError(f"{task.__name__} returned no result")