Allow different model servers for different models / indexing jobs

This commit is contained in:
Weves 2023-11-23 23:24:50 -08:00 committed by Chris Weaver
parent 26c6651a03
commit 3cec854c5c
3 changed files with 69 additions and 7 deletions

View File

@ -232,6 +232,24 @@ MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
EMBEDDING_MODEL_SERVER_HOST = (
os.environ.get("EMBEDDING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
CROSS_ENCODER_MODEL_SERVER_HOST = (
os.environ.get("CROSS_ENCODER_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
INTENT_MODEL_SERVER_HOST = (
os.environ.get("INTENT_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
# specify this env variable directly to have a different model server for the background
# indexing job vs the api server so that background indexing does not affect query-time
# performance
BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST = (
os.environ.get("BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST")
or EMBEDDING_MODEL_SERVER_HOST
)
#####
# Miscellaneous
@ -242,6 +260,11 @@ DYNAMIC_CONFIG_STORE = os.environ.get(
)
DYNAMIC_CONFIG_DIR_PATH = os.environ.get("DYNAMIC_CONFIG_DIR_PATH", "/home/storage")
JOB_TIMEOUT = 60 * 60 * 6 # 6 hours default
# used to allow the background indexing jobs to use a different embedding
# model server than the API server
CURRENT_PROCESS_IS_AN_INDEXING_JOB = (
os.environ.get("CURRENT_PROCESS_IS_AN_INDEXING_JOB", "").lower() == "true"
)
# Logs every model prompt and output, mostly used for development or exploration purposes
LOG_ALL_MODEL_INTERACTIONS = (
os.environ.get("LOG_ALL_MODEL_INTERACTIONS", "").lower() == "true"

View File

@ -6,7 +6,11 @@ from sentence_transformers import SentenceTransformer # type: ignore
from transformers import AutoTokenizer # type: ignore
from transformers import TFDistilBertForSequenceClassification # type: ignore
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST
from danswer.configs.app_configs import CROSS_ENCODER_MODEL_SERVER_HOST
from danswer.configs.app_configs import CURRENT_PROCESS_IS_AN_INDEXING_JOB
from danswer.configs.app_configs import EMBEDDING_MODEL_SERVER_HOST
from danswer.configs.app_configs import INTENT_MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
@ -88,18 +92,46 @@ def get_local_intent_model(
return _INTENT_MODEL
def build_model_server_url(
model_server_host: str,
model_server_port: int | None,
) -> str:
model_server_url = model_server_host + (
f":{model_server_port}" if model_server_port else ""
)
# use protocol if provided
if "http" in model_server_url:
return model_server_url
# otherwise default to http
return f"http://{model_server_url}"
class EmbeddingModel:
def __init__(
self,
model_name: str = DOCUMENT_ENCODER_MODEL,
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
# `model_server_host` has to default to `None` since its
# default value is conditional
model_server_host: str | None = None,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
if model_server_host is None:
model_server_host = (
BACKGROUND_JOB_EMBEDDING_MODEL_SERVER_HOST
if CURRENT_PROCESS_IS_AN_INDEXING_JOB
else EMBEDDING_MODEL_SERVER_HOST
)
self.model_name = model_name
self.max_seq_length = max_seq_length
self.embed_server_endpoint = (
f"http://{model_server_host}:{model_server_port}/encoder/bi-encoder-embed"
(
build_model_server_url(model_server_host, model_server_port)
+ "/encoder/bi-encoder-embed"
)
if model_server_host
else None
)
@ -144,13 +176,16 @@ class CrossEncoderEnsembleModel:
self,
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
model_server_host: str | None = CROSS_ENCODER_MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
self.model_names = model_names
self.max_seq_length = max_seq_length
self.rerank_server_endpoint = (
f"http://{model_server_host}:{model_server_port}/encoder/cross-encoder-scores"
(
build_model_server_url(model_server_host, model_server_port)
+ "/encoder/cross-encoder-scores"
)
if model_server_host
else None
)
@ -196,13 +231,16 @@ class IntentModel:
self,
model_name: str = INTENT_MODEL_VERSION,
max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
model_server_host: str | None = INTENT_MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
self.model_name = model_name
self.max_seq_length = max_seq_length
self.intent_server_endpoint = (
f"http://{model_server_host}:{model_server_port}/custom/intent-model"
(
build_model_server_url(model_server_host, model_server_port)
+ "/custom/intent-model"
)
if model_server_host
else None
)

View File

@ -6,6 +6,7 @@ logfile=/var/log/supervisord.log
# Cannot place this in Celery for now because Celery must run as a single process (see note below)
# Indexing uses multi-processing to speed things up
[program:document_indexing]
environment=CURRENT_PROCESS_IS_AN_INDEXING_JOB=true
command=python danswer/background/update.py
stdout_logfile=/var/log/update.log
stdout_logfile_maxbytes=52428800