mirror of https://github.com/danswer-ai/danswer.git
Feature/vespa jinja (#4558)
* tool to generate vespa schema variations for our cloud
* extraneous assign
* use a real templating system instead of search/replace
* fix float
* maybe this should be double
* remove redundant var
* template the other files
* try a spawned process
* move the wrapper
* fix args
* increase timeout

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>

onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja (templated version of danswer_chunk.sd):

@@ -1,11 +1,17 @@
-schema DANSWER_CHUNK_NAME {
-    document DANSWER_CHUNK_NAME {
-        TENANT_ID_REPLACEMENT
+schema {{ schema_name }} {
+    document {{ schema_name }} {
+{% if multi_tenant %}
+        field tenant_id type string {
+            indexing: summary | attribute
+            rank: filter
+            attribute: fast-search
+        }
+{% endif %}
         # Not to be confused with the UUID generated for this chunk which is called documentid by default
         field document_id type string {
             indexing: summary | attribute
             attribute: fast-search
             rank: filter
         }
         field chunk_id type int {
             indexing: summary | attribute
@@ -37,7 +43,7 @@ schema DANSWER_CHUNK_NAME {
             summary: dynamic
         }
         # Title embedding (x1)
-        field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
+        field title_embedding type tensor<{{ embedding_precision }}>(x[{{ dim }}]) {
             indexing: attribute | index
             attribute {
                 distance-metric: angular
@@ -45,7 +51,7 @@ schema DANSWER_CHUNK_NAME {
         }
         # Content embeddings (chunk + optional mini chunks embeddings)
         # "t" and "x" are arbitrary names, not special keywords
-        field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
+        field embeddings type tensor<{{ embedding_precision }}>(t{},x[{{ dim }}]) {
            indexing: attribute | index
            attribute {
                distance-metric: angular
@@ -176,9 +182,9 @@ schema DANSWER_CHUNK_NAME {
        match-features: recency_bias
    }

-    rank-profile hybrid_search_semantic_base_VARIABLE_DIM inherits default, default_rank {
+    rank-profile hybrid_search_semantic_base_{{ dim }} inherits default, default_rank {
        inputs {
-            query(query_embedding) tensor<float>(x[VARIABLE_DIM])
+            query(query_embedding) tensor<float>(x[{{ dim }}])
        }

        function title_vector_score() {
@@ -244,9 +250,9 @@ schema DANSWER_CHUNK_NAME {
    }


-    rank-profile hybrid_search_keyword_base_VARIABLE_DIM inherits default, default_rank {
+    rank-profile hybrid_search_keyword_base_{{ dim }} inherits default, default_rank {
        inputs {
-            query(query_embedding) tensor<float>(x[VARIABLE_DIM])
+            query(query_embedding) tensor<float>(x[{{ dim }}])
        }

        function title_vector_score() {
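
For illustration, a minimal sketch of how this template is rendered. The variable names come from the template above; the file path and values here are hypothetical, and the real rendering happens in VespaIndex further down in this diff.

    import jinja2

    jinja_env = jinja2.Environment()

    # Hypothetical path and values; the real code reads the template from
    # app_config/schemas/danswer_chunk.sd.jinja and takes values from index settings.
    with open("danswer_chunk.sd.jinja", "r") as f:
        template = jinja_env.from_string(f.read())

    schema = template.render(
        multi_tenant=True,                    # emits the tenant_id field block
        schema_name="danswer_chunk_example",  # hypothetical index name
        dim=768,                              # embedding dimension
        embedding_precision="float",
    )
    # First line of the result: "schema danswer_chunk_example {"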

onyx/document_index/vespa/app_config/services.xml.jinja (formerly services.xml):

@@ -14,7 +14,7 @@
         <redundancy>1</redundancy>
         <documents>
             <!-- <document type="danswer_chunk" mode="index" /> -->
-            DOCUMENT_REPLACEMENT
+            {{ document_elements }}
         </documents>
         <nodes>
             <node hostalias="danswer-node" distribution-key="0" />
@@ -31,7 +31,7 @@
         <tuning>
             <searchnode>
                 <requestthreads>
-                    <persearch>SEARCH_THREAD_NUMBER</persearch>
+                    <persearch>{{ num_search_threads }}</persearch>
                 </requestthreads>
             </searchnode>
         </tuning>
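
The {{ document_elements }} value is built by _create_document_xml_lines in index.py (rendered later in this diff). Its body is not shown here, but based on the commented-out element above it presumably expands to one <document> element per schema, roughly:

    def create_document_xml_lines(doc_names: list[str | None]) -> str:
        # Sketch of the helper's likely output format; the real implementation
        # lives in onyx/document_index/vespa/index.py and is not part of this excerpt.
        return "\n".join(
            f'<document type="{name}" mode="index" />' for name in doc_names if name
        )

    print(create_document_xml_lines(["danswer_chunk_example", None]))
    # <document type="danswer_chunk_example" mode="index" />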

onyx/document_index/vespa/app_config/validation-overrides.xml.jinja (formerly validation-overrides.xml):

@@ -1,11 +1,11 @@
 <validation-overrides>
     <allow
-        until="DATE_REPLACEMENT"
+        until="{{ until_date }}"
         comment="We need to be able to create/delete indices for swapping models">schema-removal</allow>
     <allow
-        until="DATE_REPLACEMENT"
+        until="{{ until_date }}"
         comment="We need to be able to update the schema for updates to the Onyx schema">indexing-change</allow>
     <allow
-        until='DATE_REPLACEMENT'
+        until="{{ until_date }}"
         comment="Prevents old alt indices from interfering with changes">field-type-change</allow>
 </validation-overrides>
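
index.py (below) renders this template with a date seven days out, since Vespa caps validation overrides at 30 days from the current date. A minimal sketch of that rendering, with a hypothetical relative path:

    from datetime import datetime, timedelta

    import jinja2

    jinja_env = jinja2.Environment()
    with open("validation-overrides.xml.jinja", "r") as f:  # hypothetical path
        overrides_template = jinja_env.from_string(f.read())

    formatted_date = (datetime.now() + timedelta(days=7)).strftime("%Y-%m-%d")
    overrides = overrides_template.render(until_date=formatted_date)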

onyx/document_index/vespa/index.py:

@@ -16,6 +16,7 @@ from typing import List
 from uuid import UUID

 import httpx  # type: ignore
+import jinja2
 import requests  # type: ignore
 from retry import retry

@@ -61,21 +62,13 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
 from onyx.document_index.vespa_constants import BATCH_SIZE
 from onyx.document_index.vespa_constants import BOOST
 from onyx.document_index.vespa_constants import CONTENT_SUMMARY
-from onyx.document_index.vespa_constants import DANSWER_CHUNK_REPLACEMENT_PAT
-from onyx.document_index.vespa_constants import DATE_REPLACEMENT
 from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
-from onyx.document_index.vespa_constants import DOCUMENT_REPLACEMENT_PAT
 from onyx.document_index.vespa_constants import DOCUMENT_SETS
-from onyx.document_index.vespa_constants import EMBEDDING_PRECISION_REPLACEMENT_PAT
 from onyx.document_index.vespa_constants import HIDDEN
 from onyx.document_index.vespa_constants import NUM_THREADS
-from onyx.document_index.vespa_constants import SEARCH_THREAD_NUMBER_PAT
-from onyx.document_index.vespa_constants import TENANT_ID_PAT
-from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
 from onyx.document_index.vespa_constants import USER_FILE
 from onyx.document_index.vespa_constants import USER_FOLDER
 from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
-from onyx.document_index.vespa_constants import VESPA_DIM_REPLACEMENT_PAT
 from onyx.document_index.vespa_constants import VESPA_TIMEOUT
 from onyx.document_index.vespa_constants import YQL_BASE
 from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -118,28 +111,6 @@ def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
     return "\n".join(doc_lines)


-def _replace_template_values_in_schema(
-    schema_template: str,
-    index_name: str,
-    embedding_dim: int,
-    embedding_precision: EmbeddingPrecision,
-) -> str:
-    return (
-        schema_template.replace(
-            EMBEDDING_PRECISION_REPLACEMENT_PAT, embedding_precision.value
-        )
-        .replace(DANSWER_CHUNK_REPLACEMENT_PAT, index_name)
-        .replace(VESPA_DIM_REPLACEMENT_PAT, str(embedding_dim))
-    )
-
-
-def _replace_tenant_template_value_in_schema(
-    schema_template: str,
-    tenant_field: str,
-) -> str:
-    return schema_template.replace(TENANT_ID_PAT, tenant_field)
-
-
 def add_ngrams_to_schema(schema_content: str) -> str:
     # Add the match blocks containing gram and gram-size to title and content fields
     schema_content = re.sub(
@@ -156,6 +127,9 @@ def add_ngrams_to_schema(schema_content: str) -> str:


 class VespaIndex(DocumentIndex):
+
+    VESPA_SCHEMA_JINJA_FILENAME = "danswer_chunk.sd.jinja"
+
     def __init__(
         self,
         index_name: str,
@@ -202,25 +176,31 @@ class VespaIndex(DocumentIndex):
             )
             return None

+        jinja_env = jinja2.Environment()
+
         deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
         logger.notice(f"Deploying Vespa application package to {deploy_url}")

         vespa_schema_path = os.path.join(
             os.getcwd(), "onyx", "document_index", "vespa", "app_config"
         )
-        schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd")
-        services_file = os.path.join(vespa_schema_path, "services.xml")
-        overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml")
+        schema_jinja_file = os.path.join(
+            vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
+        )
+        services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
+        overrides_jinja_file = os.path.join(
+            vespa_schema_path, "validation-overrides.xml.jinja"
+        )

-        with open(services_file, "r") as services_f:
-            services_template = services_f.read()
+        with open(services_jinja_file, "r") as services_f:
             schema_names = [self.index_name, self.secondary_index_name]

             doc_lines = _create_document_xml_lines(schema_names)

-            services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
-            services = services.replace(
-                SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
+            services_template_str = services_f.read()
+            services_template = jinja_env.from_string(services_template_str)
+            services = services_template.render(
+                document_elements=doc_lines,
+                num_search_threads=str(VESPA_SEARCHER_THREADS),
             )

         kv_store = get_shared_kv_store()
@@ -231,30 +211,33 @@ class VespaIndex(DocumentIndex):
         except Exception:
             logger.debug("Could not load the reindexing flag. Using ngrams")

-        with open(overrides_file, "r") as overrides_f:
-            overrides_template = overrides_f.read()
-
         # Vespa requires an override to erase data including the indices we're no longer using
         # It also has a 30 day cap from current so we set it to 7 dynamically
+        with open(overrides_jinja_file, "r") as overrides_f:
+            overrides_template_str = overrides_f.read()
+            overrides_template = jinja_env.from_string(overrides_template_str)

         now = datetime.now()
         date_in_7_days = now + timedelta(days=7)
         formatted_date = date_in_7_days.strftime("%Y-%m-%d")

-        overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date)
+        overrides = overrides_template.render(
+            until_date=formatted_date,
+        )

         zip_dict = {
             "services.xml": services.encode("utf-8"),
             "validation-overrides.xml": overrides.encode("utf-8"),
         }

-        with open(schema_file, "r") as schema_f:
-            schema_template = schema_f.read()
-            schema = _replace_tenant_template_value_in_schema(schema_template, "")
-            schema = _replace_template_values_in_schema(
-                schema,
-                self.index_name,
-                primary_embedding_dim,
-                primary_embedding_precision,
+        with open(schema_jinja_file, "r") as schema_f:
+            template_str = schema_f.read()
+
+            template = jinja_env.from_string(template_str)
+            schema = template.render(
+                multi_tenant=MULTI_TENANT,
+                schema_name=self.index_name,
+                dim=primary_embedding_dim,
+                embedding_precision=primary_embedding_precision.value,
             )

         schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
@@ -266,12 +249,13 @@ class VespaIndex(DocumentIndex):
             if secondary_index_embedding_precision is None:
                 raise ValueError("Secondary index embedding precision is required")

-            upcoming_schema = _replace_template_values_in_schema(
-                schema_template,
-                self.secondary_index_name,
-                secondary_index_embedding_dim,
-                secondary_index_embedding_precision,
+            upcoming_schema = template.render(
+                multi_tenant=MULTI_TENANT,
+                schema_name=self.secondary_index_name,
+                dim=secondary_index_embedding_dim,
+                embedding_precision=secondary_index_embedding_precision.value,
             )

             zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")

         zip_file = in_memory_zip_from_file_bytes(zip_dict)
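
The rendered strings are collected into zip_dict and handed to in_memory_zip_from_file_bytes, whose body is not part of this diff. A rough sketch of what such a helper plausibly does with the dict shown above:

    import io
    import zipfile

    def in_memory_zip_from_file_bytes_sketch(file_contents: dict[str, bytes]) -> io.BytesIO:
        # Assumption: the real helper zips {archive path -> bytes} into an in-memory
        # Vespa application package; this is an illustrative stand-in, not the real code.
        zip_buffer = io.BytesIO()
        with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zf:
            for filename, content in file_contents.items():
                zf.writestr(filename, content)
        zip_buffer.seek(0)
        return zip_buffer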
@@ -301,23 +285,26 @@ class VespaIndex(DocumentIndex):
         vespa_schema_path = os.path.join(
             os.getcwd(), "onyx", "document_index", "vespa", "app_config"
         )
-        schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd")
-        services_file = os.path.join(vespa_schema_path, "services.xml")
-        overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml")
+        schema_jinja_file = os.path.join(
+            vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
+        )
+        services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
+        overrides_jinja_file = os.path.join(
+            vespa_schema_path, "validation-overrides.xml.jinja"
+        )

-        with open(services_file, "r") as services_f:
-            services_template = services_f.read()
+        jinja_env = jinja2.Environment()
+
+        # Generate schema names from index settings
+        with open(services_jinja_file, "r") as services_f:
             schema_names = [index_name for index_name in indices]
+            doc_lines = _create_document_xml_lines(schema_names)

-        full_schemas = schema_names
-
-        doc_lines = _create_document_xml_lines(full_schemas)
-
-        services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
-        services = services.replace(
-            SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
+            services_template_str = services_f.read()
+            services_template = jinja_env.from_string(services_template_str)
+            services = services_template.render(
+                document_elements=doc_lines,
+                num_search_threads=str(VESPA_SEARCHER_THREADS),
             )

         kv_store = get_shared_kv_store()
@@ -328,24 +315,28 @@ class VespaIndex(DocumentIndex):
         except Exception:
             logger.debug("Could not load the reindexing flag. Using ngrams")

-        with open(overrides_file, "r") as overrides_f:
-            overrides_template = overrides_f.read()
-
         # Vespa requires an override to erase data including the indices we're no longer using
         # It also has a 30 day cap from current so we set it to 7 dynamically
+        with open(overrides_jinja_file, "r") as overrides_f:
+            overrides_template_str = overrides_f.read()
+            overrides_template = jinja_env.from_string(overrides_template_str)

         now = datetime.now()
         date_in_7_days = now + timedelta(days=7)
         formatted_date = date_in_7_days.strftime("%Y-%m-%d")

-        overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date)
+        overrides = overrides_template.render(
+            until_date=formatted_date,
+        )

         zip_dict = {
             "services.xml": services.encode("utf-8"),
             "validation-overrides.xml": overrides.encode("utf-8"),
         }

-        with open(schema_file, "r") as schema_f:
-            schema_template = schema_f.read()
+        with open(schema_jinja_file, "r") as schema_f:
+            schema_template_str = schema_f.read()
+
+        schema_template = jinja_env.from_string(schema_template_str)

         for i, index_name in enumerate(indices):
             embedding_dim = embedding_dims[i]
@@ -354,15 +345,11 @@ class VespaIndex(DocumentIndex):
                 f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
             )

-            schema = _replace_template_values_in_schema(
-                schema_template, index_name, embedding_dim, embedding_precision
-            )
-
-            tenant_id_replacement = ""
-            if MULTI_TENANT:
-                tenant_id_replacement = TENANT_ID_REPLACEMENT
-            schema = _replace_tenant_template_value_in_schema(
-                schema, tenant_id_replacement
+            schema = schema_template.render(
+                multi_tenant=MULTI_TENANT,
+                schema_name=index_name,
+                dim=embedding_dim,
+                embedding_precision=embedding_precision.value,
             )

             schema = add_ngrams_to_schema(schema) if needs_reindexing else schema

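Neither flow's final upload step appears in this excerpt, but deploy_url points at Vespa's prepareandactivate endpoint, which accepts a zipped application package. A hedged sketch of that last step; the header and call shape are assumed from Vespa's deploy API rather than taken from this diff:

    import requests

    def deploy_app_package(deploy_url: str, zip_file: bytes) -> None:
        # POST the zipped application package; Vespa's deploy API expects the
        # raw zip bytes with an application/zip content type.
        response = requests.post(
            deploy_url,
            headers={"Content-Type": "application/zip"},
            data=zip_file,
        )
        response.raise_for_status()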

onyx/document_index/vespa_constants.py:

@@ -5,20 +5,6 @@ from onyx.configs.app_configs import VESPA_PORT
 from onyx.configs.app_configs import VESPA_TENANT_PORT
 from onyx.configs.constants import SOURCE_TYPE

-VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM"
-EMBEDDING_PRECISION_REPLACEMENT_PAT = "EMBEDDING_PRECISION"
-DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME"
-DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT"
-SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
-DATE_REPLACEMENT = "DATE_REPLACEMENT"
-SEARCH_THREAD_NUMBER_PAT = "SEARCH_THREAD_NUMBER"
-TENANT_ID_PAT = "TENANT_ID_REPLACEMENT"
-
-TENANT_ID_REPLACEMENT = """field tenant_id type string {
-            indexing: summary | attribute
-            rank: filter
-            attribute: fast-search
-        }"""
 # config server


@@ -31,7 +17,7 @@ VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2"
 VESPA_APP_CONTAINER_URL = VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}"


-# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd
+# danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd.jinja
 DOCUMENT_ID_ENDPOINT = (
     f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
 )
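
Note that the doubled braces in that f-string survive as a literal {index_name} placeholder, so callers fill in the schema name later with str.format. For example, with hypothetical values:

    # Hypothetical values for illustration only.
    VESPA_APP_CONTAINER_URL = "http://localhost:8081"
    DOCUMENT_ID_ENDPOINT = (
        f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid"
    )

    url = DOCUMENT_ID_ENDPOINT.format(index_name="danswer_chunk_example")
    # http://localhost:8081/document/v1/default/danswer_chunk_example/docid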

Multi-tenant schema generation script (the new cloud schema-variation tool):

@@ -2,43 +2,47 @@

 import argparse

+import jinja2
+
 from onyx.db.enums import EmbeddingPrecision
-from onyx.document_index.vespa.index import _replace_template_values_in_schema
-from onyx.document_index.vespa.index import _replace_tenant_template_value_in_schema
-from onyx.document_index.vespa_constants import TENANT_ID_REPLACEMENT
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import SUPPORTED_EMBEDDING_MODELS

 logger = setup_logger()


-def write_schema(index_name: str, dim: int, template: str) -> None:
+def write_schema(index_name: str, dim: int, template: jinja2.Template) -> None:
     index_filename = index_name + ".sd"
-    index_rendered_str = _replace_tenant_template_value_in_schema(
-        template, TENANT_ID_REPLACEMENT
-    )
-    index_rendered_str = _replace_template_values_in_schema(
-        index_rendered_str, index_name, dim, EmbeddingPrecision.FLOAT
-    )
+
+    schema = template.render(
+        multi_tenant=True,
+        schema_name=index_name,
+        dim=dim,
+        embedding_precision=EmbeddingPrecision.FLOAT.value,
+    )

     with open(index_filename, "w", encoding="utf-8") as f:
-        f.write(index_rendered_str)
+        f.write(schema)

     logger.info(f"Wrote {index_filename}")


 def main() -> None:
     parser = argparse.ArgumentParser(description="Generate multi tenant Vespa schemas")
-    parser.add_argument("--template", help="The schema template to use", required=True)
+    parser.add_argument("--template", help="The Jinja template to use", required=True)
     args = parser.parse_args()

+    jinja_env = jinja2.Environment()
+
     with open(args.template, "r", encoding="utf-8") as f:
         template_str = f.read()

+    template = jinja_env.from_string(template_str)
+
     num_indexes = 0
     for model in SUPPORTED_EMBEDDING_MODELS:
-        write_schema(model.index_name, model.dim, template_str)
-        write_schema(model.index_name + "__danswer_alt_index", model.dim, template_str)
+        write_schema(model.index_name, model.dim, template)
+        write_schema(model.index_name + "__danswer_alt_index", model.dim, template)
         num_indexes += 2

     logger.info(f"Wrote {num_indexes} indexes.")
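
write_schema now expects a compiled jinja2.Template instead of a raw string. A minimal caller sketch; the inline template is a stub standing in for danswer_chunk.sd.jinja and the index name is hypothetical:

    import jinja2

    jinja_env = jinja2.Environment()
    stub_template = jinja_env.from_string("schema {{ schema_name }} { }")

    # Writes danswer_chunk_example.sd into the current working directory.
    write_schema("danswer_chunk_example", 768, stub_template)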

reset_postgres test helper:

@@ -156,7 +156,7 @@ def reset_postgres(
     """Reset the Postgres database."""
     # this seems to hang due to locking issues, so run with a timeout with a few retries
     NUM_TRIES = 10
-    TIMEOUT = 10
+    TIMEOUT = 40
     success = False
     for _ in range(NUM_TRIES):
         logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")

run_with_timeout_multiproc helper (Pool-based version replaced by a spawned process):

@@ -1,20 +1,70 @@
+# import multiprocessing
+# from collections.abc import Callable
+# from typing import Any
+# from typing import TypeVar
+
+# T = TypeVar("T")
+
+
+# def run_with_timeout_multiproc(
+#     task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
+# ) -> T:
+#     # Use multiprocessing to prevent a thread from blocking the main thread
+#     with multiprocessing.Pool(processes=1) as pool:
+#         async_result = pool.apply_async(task, kwds=kwargs)
+#         try:
+#             # Wait at most timeout seconds for the function to complete
+#             result = async_result.get(timeout=timeout)
+#             return result
+#         except multiprocessing.TimeoutError:
+#             raise TimeoutError(f"Function timed out after {timeout} seconds")
+
+
 import multiprocessing
+import traceback
 from collections.abc import Callable
+from multiprocessing import Queue
 from typing import Any
 from typing import TypeVar

 T = TypeVar("T")


+def _multiproc_wrapper(
+    task: Callable[..., T], kwargs: dict[str, Any], q: Queue
+) -> None:
+    try:
+        result = task(**kwargs)
+        q.put(("success", result))
+    except Exception:
+        q.put(("error", traceback.format_exc()))
+
+
 def run_with_timeout_multiproc(
     task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
 ) -> T:
-    # Use multiprocessing to prevent a thread from blocking the main thread
-    with multiprocessing.Pool(processes=1) as pool:
-        async_result = pool.apply_async(task, kwds=kwargs)
-        try:
-            # Wait at most timeout seconds for the function to complete
-            result = async_result.get(timeout=timeout)
-            return result
-        except multiprocessing.TimeoutError:
-            raise TimeoutError(f"Function timed out after {timeout} seconds")
+    ctx = multiprocessing.get_context("spawn")
+    q: Queue = ctx.Queue()
+    p = ctx.Process(
+        target=_multiproc_wrapper,
+        args=(
+            task,
+            kwargs,
+            q,
+        ),
+    )
+    p.start()
+    p.join(timeout)
+
+    if p.is_alive():
+        p.terminate()
+        raise TimeoutError(f"{task.__name__} timed out after {timeout} seconds")
+
+    if not q.empty():
+        status, result = q.get()
+        if status == "success":
+            return result
+        else:
+            raise RuntimeError(f"{task.__name__} failed:\n{result}")
+    else:
+        raise RuntimeError(f"{task.__name__} returned no result")
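
A usage sketch for the new spawn-based helper, with the function above in scope; the task and timeout values are made up. With the spawn context the task and its kwargs must be picklable, and callers should keep the entry point guarded:

    import time

    def slow_task(duration: int) -> str:
        time.sleep(duration)
        return "done"

    if __name__ == "__main__":
        # Completes well inside the timeout and returns "done".
        print(run_with_timeout_multiproc(slow_task, timeout=5, kwargs={"duration": 1}))

        try:
            # Exceeds the timeout: the child is terminated and TimeoutError is raised.
            run_with_timeout_multiproc(slow_task, timeout=1, kwargs={"duration": 10})
        except TimeoutError as e:
            print(e)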