Allow different model servers for different models / indexing jobs
commit 3cec854c5c (parent 26c6651a03)
@@ -232,6 +232,24 @@ MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
 MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
 MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
 
+EMBEDDING_MODEL_SERVER_HOST = (
+    os.environ.get("EMBEDDING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
+)
+CROSS_ENCODER_MODEL_SERVER_HOST = (
+    os.environ.get("CROSS_ENCODER_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
+)
+INTENT_MODEL_SERVER_HOST = (
+    os.environ.get("INTENT_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
+)
+
+# specify this env variable directly to have a different model server for the background
+# indexing job vs the api server so that background indexing does not affect query-time
+# performance
+BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST = (
+    os.environ.get("BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST")
+    or EMBEDDING_MODEL_SERVER_HOST
+)
+
 
 #####
 # Miscellaneous
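For context, the added hosts form a two-level fallback: each per-model host inherits MODEL_SERVER_HOST unless explicitly set, and the background-job embedding host inherits EMBEDDING_MODEL_SERVER_HOST. A minimal sketch (not part of the diff) of how this resolves, using made-up host names:

    # Hypothetical demo of the fallback chain above; host names are made up.
    import os

    os.environ["MODEL_SERVER_HOST"] = "model-server"
    os.environ["BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST"] = "indexing-model-server"

    MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
    EMBEDDING_MODEL_SERVER_HOST = (
        os.environ.get("EMBEDDING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
    )
    BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST = (
        os.environ.get("BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST")
        or EMBEDDING_MODEL_SERVER_HOST
    )

    print(EMBEDDING_MODEL_SERVER_HOST)  # model-server (inherited)
    print(BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST)  # indexing-model-server (overridden)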
@@ -242,6 +260,11 @@ DYNAMIC_CONFIG_STORE = os.environ.get(
 )
 DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
 JOB_TIMEOUT = 60 * 60 * 6  # 6 hours default
+# used to allow the background indexing jobs to use a different embedding
+# model server than the API server
+CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
+    os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
+)
 # Logs every model prompt and output, mostly used for development or exploration purposes
 LOG_ALL_MODEL_INTERACTIONS = (
     os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"
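Note the flag parsing above: `.lower() == "true"` makes the check case-insensitive, and an unset variable reads as False. A small illustrative sketch (the flag_is_set helper is hypothetical, not from the codebase):

    import os

    def flag_is_set(name: str) -> bool:
        # case-insensitive "true" check; unset or any other value reads as False
        return os.environ.get(name, "").lower() == "true"

    os.environ["CURRENT_PROCESS_IS_AN_INDEXING_JOB"] = "TRUE"
    print(flag_is_set("CURRENT_PROCESS_IS_AN_INDEXING_JOB"))  # True
    print(flag_is_set("LOG_ALL_MODEL_INTERACTIONS"))          # False (unset)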
@@ -6,7 +6,11 @@ from sentence_transformers import SentenceTransformer  # type: ignore
 from transformers import AutoTokenizer  # type: ignore
 from transformers import TFDistilBertForSequenceClassification  # type: ignore
 
-from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST
+from danswer.configs.app_configs import CROSS_ENCODER_MODEL_SERVER_HOST
+from danswer.configs.app_configs import CURRENT_PROCESS_IS_AN_INDEXING_JOB
+from danswer.configs.app_configs import EMBEDDING_MODEL_SERVER_HOST
+from danswer.configs.app_configs import INTENT_MODEL_SERVER_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
@@ -88,18 +92,46 @@ def get_local_intent_model(
     return _INTENT_MODEL
 
 
+def build_model_server_url(
+    model_server_host: str,
+    model_server_port: int | None,
+) -> str:
+    model_server_url = model_server_host + (
+        f":{model_server_port}" if model_server_port else ""
+    )
+
+    # use protocol if provided
+    if "http" in model_server_url:
+        return model_server_url
+
+    # otherwise default to http
+    return f"http://{model_server_url}"
+
+
 class EmbeddingModel:
     def __init__(
         self,
         model_name: str = DOCUMENT_ENCODER_MODEL,
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
-        model_server_host: str | None = MODEL_SERVER_HOST,
+        # `model_server_host` has to default to `None` since its
+        # default value is conditional
+        model_server_host: str | None = None,
         model_server_port: int = MODEL_SERVER_PORT,
     ) -> None:
+        if model_server_host is None:
+            model_server_host = (
+                BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST
+                if CURRENT_PROCESS_IS_AN_INDEXING_JOB
+                else EMBEDDING_MODEL_SERVER_HOST
+            )
+
         self.model_name = model_name
         self.max_seq_length = max_seq_length
         self.embed_server_endpoint = (
-            f"http://{model_server_host}:{model_server_port}/encoder/bi-encoder-embed"
+            (
+                build_model_server_url(model_server_host, model_server_port)
+                + "/encoder/bi-encoder-embed"
+            )
             if model_server_host
             else None
         )
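To make the new helper's behavior concrete, here is a short sketch of what build_model_server_url returns for a couple of inputs (host names are made up):

    # Assumes build_model_server_url from the hunk above is in scope.
    print(build_model_server_url("localhost", 9000))
    # -> http://localhost:9000 (no scheme in the host, so "http://" is prepended)
    print(build_model_server_url("https://models.internal", None))
    # -> https://models.internal (existing scheme kept; no port appended when None)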
@@ -144,13 +176,16 @@ class CrossEncoderEnsembleModel:
         self,
         model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
         max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
-        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_host: str | None = CROSS_ENCODER_MODEL_SERVER_HOST,
         model_server_port: int = MODEL_SERVER_PORT,
     ) -> None:
         self.model_names = model_names
         self.max_seq_length = max_seq_length
         self.rerank_server_endpoint = (
-            f"http://{model_server_host}:{model_server_port}/encoder/cross-encoder-scores"
+            (
+                build_model_server_url(model_server_host, model_server_port)
+                + "/encoder/cross-encoder-scores"
+            )
             if model_server_host
             else None
         )
@@ -196,13 +231,16 @@ class IntentModel:
         self,
         model_name: str = INTENT_MODEL_VERSION,
         max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
-        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_host: str | None = INTENT_MODEL_SERVER_HOST,
         model_server_port: int = MODEL_SERVER_PORT,
     ) -> None:
         self.model_name = model_name
         self.max_seq_length = max_seq_length
         self.intent_server_endpoint = (
-            f"http://{model_server_host}:{model_server_port}/custom/intent-model"
+            (
+                build_model_server_url(model_server_host, model_server_port)
+                + "/custom/intent-model"
+            )
             if model_server_host
             else None
        )
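The same default-host pattern is applied to all three model classes, so call sites stay unchanged while each class resolves its own server from config. A hypothetical usage sketch (instances and comments are illustrative; when a host resolves to None, the endpoint stays None and the classes presumably fall back to the local model path seen elsewhere in this file):

    # Assumes the classes from the hunks above are importable.
    embedding = EmbeddingModel()
    # -> BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST in indexing jobs,
    #    EMBEDDING_MODEL_SERVER_HOST otherwise
    reranker = CrossEncoderEnsembleModel()  # -> CROSS_ENCODER_MODEL_SERVER_HOST
    intent_model = IntentModel()            # -> INTENT_MODEL_SERVER_HOST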
@@ -6,6 +6,7 @@ logfile=/var/log/supervisord.log
 # Cannot place this in Celery for now because Celery must run as a single process (see note below)
 # Indexing uses multi-processing to speed things up
 [program:document_indexing]
+environment=CURRENT_PROCESS_IS_AN_INDEXING_JOB=true
 command=python danswer/background/update.py
 stdout_logfile=/var/log/update.log
 stdout_logfile_maxbytes=52428800
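Because supervisord applies `environment=` per program, only the document_indexing process sees this flag. A hypothetical check from inside that process, mirroring the parsing in app_configs.py:

    # Runs inside the process spawned by [program:document_indexing] above.
    import os

    assert os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
    # With the flag set, CURRENT_PROCESS_IS_AN_INDEXING_JOB is True in app_configs.py,
    # so EmbeddingModel() targets BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST.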