From b59912884bed7a32a0858d2e634d5402647c5ea6 Mon Sep 17 00:00:00 2001
From: Yuhong Sun <yuhongsun96@gmail.com>
Date: Wed, 10 Apr 2024 23:13:22 -0700
Subject: [PATCH] Fix Model Server (#1320)

---
 backend/Dockerfile.model_server                  |  8 +-------
 backend/danswer/background/update.py             |  6 +++---
 backend/danswer/configs/app_configs.py           | 16 +---------------
 .../danswerbot/slack/handlers/handle_message.py  |  2 +-
 backend/danswer/danswerbot/slack/listener.py     |  4 ++--
 backend/danswer/indexing/embedder.py             |  4 ++--
 backend/danswer/llm/utils.py                     |  2 +-
 backend/danswer/main.py                          |  6 +++---
 backend/danswer/search/models.py                 |  2 +-
 .../search/preprocessing/preprocessing.py        |  2 +-
 .../danswer/search/retrieval/search_runner.py    |  4 ++--
 backend/danswer/search/search_nlp_models.py      |  4 ++--
 backend/danswer/utils/logger.py                  |  2 +-
 backend/model_server/custom_models.py            |  6 +++---
 backend/model_server/encoders.py                 |  6 +++---
 backend/model_server/main.py                     | 12 ++++++------
 backend/requirements/model_server.txt            |  1 +
 .../{nlp_model_configs.py => configs.py}         | 14 ++++++++++++++
 18 files changed, 48 insertions(+), 53 deletions(-)
 rename backend/shared_configs/{nlp_model_configs.py => configs.py} (57%)

diff --git a/backend/Dockerfile.model_server b/backend/Dockerfile.model_server
index 0eb455c51..cb7115c0b 100644
--- a/backend/Dockerfile.model_server
+++ b/backend/Dockerfile.model_server
@@ -13,19 +13,13 @@ RUN apt-get remove -y --allow-remove-essential perl-base && \
 
 WORKDIR /app
 
-# Needed for model configs and defaults
-COPY ./danswer/configs /app/danswer/configs
-COPY ./danswer/dynamic_configs /app/danswer/dynamic_configs
-
 # Utils used by model server
 COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
-COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
-COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
 
 # Place to fetch version information
 COPY ./danswer/__init__.py /app/danswer/__init__.py
 
-# Request/Response models
+# Shared between Danswer Backend and Model Server
 COPY ./shared_configs /app/shared_configs
 
 # Model Server main code
diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index 8d8de8da4..6042e02b1 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -15,9 +15,6 @@ from danswer.background.indexing.run_indexing import run_indexing_entrypoint
 from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
 from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
-from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
-from danswer.configs.app_configs import LOG_LEVEL
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.db.connector import fetch_connectors
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
@@ -46,6 +43,9 @@ from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
 from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
+from shared_configs.configs import LOG_LEVEL
+from shared_configs.configs import MODEL_SERVER_PORT
 
 logger = setup_logger()
 
diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py
index b8bcc97f6..1e4809d07 100644
--- a/backend/danswer/configs/app_configs.py
+++ b/backend/danswer/configs/app_configs.py
@@ -209,19 +209,6 @@ DISABLE_DOCUMENT_CLEANUP = (
 )
 
 
-#####
-# Model Server Configs
-#####
-MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
-MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
-MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
-# Model server for indexing should use a separate one to not allow indexing to introduce delay
-# for inference
-INDEXING_MODEL_SERVER_HOST = (
-    os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
-)
-
-
 #####
 # Miscellaneous
 #####
@@ -246,8 +233,7 @@ LOG_VESPA_TIMING_INFORMATION = (
 )
 # Anonymous usage telemetry
 DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"
-# notset, debug, info, warning, error, or critical
-LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
+
 TOKEN_BUDGET_GLOBALLY_ENABLED = (
     os.environ.get("TOKEN_BUDGET_GLOBALLY_ENABLED", "").lower() == "true"
 )
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index 0886c0c17..fc1c038ae 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -51,7 +51,7 @@ from danswer.search.models import BaseFilters
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
 from danswer.utils.logger import setup_logger
-from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
+from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
 
 logger_base = setup_logger()
 
diff --git a/backend/danswer/danswerbot/slack/listener.py b/backend/danswer/danswerbot/slack/listener.py
index 08aa58411..7f935c1f0 100644
--- a/backend/danswer/danswerbot/slack/listener.py
+++ b/backend/danswer/danswerbot/slack/listener.py
@@ -10,8 +10,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
 from sqlalchemy.orm import Session
 
-from danswer.configs.app_configs import MODEL_SERVER_HOST
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.constants import MessageType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
 from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
@@ -47,6 +45,8 @@ from danswer.one_shot_answer.models import ThreadMessage
 from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import MODEL_SERVER_HOST
+from shared_configs.configs import MODEL_SERVER_PORT
 
 logger = setup_logger()
 
diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py
index 446122df1..20a8690e3 100644
--- a/backend/danswer/indexing/embedder.py
+++ b/backend/danswer/indexing/embedder.py
@@ -4,8 +4,6 @@ from abc import abstractmethod
 from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
-from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.embedding_model import get_current_db_embedding_model
@@ -20,6 +18,8 @@ from danswer.search.enums import EmbedTextType
 from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
+from shared_configs.configs import MODEL_SERVER_PORT
 
 
 logger = setup_logger()
diff --git a/backend/danswer/llm/utils.py b/backend/danswer/llm/utils.py
index b41c85b9e..05b36f6ff 100644
--- a/backend/danswer/llm/utils.py
+++ b/backend/danswer/llm/utils.py
@@ -20,7 +20,6 @@ from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from tiktoken.core import Encoding
 
-from danswer.configs.app_configs import LOG_LEVEL
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.configs.constants import GEN_AI_DETECTED_MODEL
 from danswer.configs.constants import MessageType
@@ -37,6 +36,7 @@ from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.interfaces import LLM
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import LOG_LEVEL
 
 if TYPE_CHECKING:
     from danswer.llm.answering.models import PreviousMessage
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 9ce32fe01..3fb9a1175 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -28,8 +28,6 @@ from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
-from danswer.configs.app_configs import MODEL_SERVER_HOST
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import SECRET
@@ -81,7 +79,9 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
 from danswer.utils.variable_functionality import fetch_versioned_implementation
-from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from shared_configs.configs import MODEL_SERVER_HOST
+from shared_configs.configs import MODEL_SERVER_PORT
 
 
 logger = setup_logger()
diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py
index 9d3eb39b0..7fc247fa4 100644
--- a/backend/danswer/search/models.py
+++ b/backend/danswer/search/models.py
@@ -11,7 +11,7 @@ from danswer.configs.constants import DocumentSource
 from danswer.db.models import Persona
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.enums import SearchType
-from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
 
 
 MAX_METRICS_CONTENT = (
diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py
index 7da6db4ce..ec9fc2dae 100644
--- a/backend/danswer/search/preprocessing/preprocessing.py
+++ b/backend/danswer/search/preprocessing/preprocessing.py
@@ -21,7 +21,7 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import FunctionCall
 from danswer.utils.threadpool_concurrency import run_functions_in_parallel
 from danswer.utils.timing import log_function_time
-from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
 
 
 logger = setup_logger()
diff --git a/backend/danswer/search/retrieval/search_runner.py b/backend/danswer/search/retrieval/search_runner.py
index bb1725392..1189053db 100644
--- a/backend/danswer/search/retrieval/search_runner.py
+++ b/backend/danswer/search/retrieval/search_runner.py
@@ -7,8 +7,6 @@ from nltk.tokenize import word_tokenize  # type:ignore
 from sqlalchemy.orm import Session
 
 from danswer.chat.models import LlmDoc
-from danswer.configs.app_configs import MODEL_SERVER_HOST
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.chat_configs import HYBRID_ALPHA
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.db.embedding_model import get_current_db_embedding_model
@@ -26,6 +24,8 @@ from danswer.secondary_llm_flows.query_expansion import multilingual_query_expan
 from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
 from danswer.utils.timing import log_function_time
+from shared_configs.configs import MODEL_SERVER_HOST
+from shared_configs.configs import MODEL_SERVER_PORT
 
 
 logger = setup_logger()
diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py
index 95bd4d0f2..39d762238 100644
--- a/backend/danswer/search/search_nlp_models.py
+++ b/backend/danswer/search/search_nlp_models.py
@@ -7,12 +7,12 @@ from typing import TYPE_CHECKING
 import requests
 from transformers import logging as transformer_logging  # type:ignore
 
-from danswer.configs.app_configs import MODEL_SERVER_HOST
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.search.enums import EmbedTextType
 from danswer.utils.logger import setup_logger
+from shared_configs.configs import MODEL_SERVER_HOST
+from shared_configs.configs import MODEL_SERVER_PORT
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import IntentRequest
diff --git a/backend/danswer/utils/logger.py b/backend/danswer/utils/logger.py
index c4dd59742..38e24a367 100644
--- a/backend/danswer/utils/logger.py
+++ b/backend/danswer/utils/logger.py
@@ -3,7 +3,7 @@ import os
 from collections.abc import MutableMapping
 from typing import Any
 
-from danswer.configs.app_configs import LOG_LEVEL
+from shared_configs.configs import LOG_LEVEL
 
 
 class IndexAttemptSingleton:
diff --git a/backend/model_server/custom_models.py b/backend/model_server/custom_models.py
index 9b8066e96..ee97ded78 100644
--- a/backend/model_server/custom_models.py
+++ b/backend/model_server/custom_models.py
@@ -8,11 +8,11 @@ from transformers import TFDistilBertForSequenceClassification
 
 from model_server.constants import MODEL_WARM_UP_STRING
 from model_server.utils import simple_log_function_time
+from shared_configs.configs import INDEXING_ONLY
+from shared_configs.configs import INTENT_MODEL_CONTEXT_SIZE
+from shared_configs.configs import INTENT_MODEL_VERSION
 from shared_configs.model_server_models import IntentRequest
 from shared_configs.model_server_models import IntentResponse
-from shared_configs.nlp_model_configs import INDEXING_ONLY
-from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
-from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION
 
 
 router = APIRouter(prefix="/custom")
diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py
index f1f3fdf0c..705386a8c 100644
--- a/backend/model_server/encoders.py
+++ b/backend/model_server/encoders.py
@@ -9,13 +9,13 @@ from sentence_transformers import SentenceTransformer  # type: ignore
 from danswer.utils.logger import setup_logger
 from model_server.constants import MODEL_WARM_UP_STRING
 from model_server.utils import simple_log_function_time
+from shared_configs.configs import CROSS_EMBED_CONTEXT_SIZE
+from shared_configs.configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from shared_configs.configs import INDEXING_ONLY
 from shared_configs.model_server_models import EmbedRequest
 from shared_configs.model_server_models import EmbedResponse
 from shared_configs.model_server_models import RerankRequest
 from shared_configs.model_server_models import RerankResponse
-from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE
-from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
-from shared_configs.nlp_model_configs import INDEXING_ONLY
 
 logger = setup_logger()
 
diff --git a/backend/model_server/main.py b/backend/model_server/main.py
index aaac1d0d1..c7b2a2f93 100644
--- a/backend/model_server/main.py
+++ b/backend/model_server/main.py
@@ -8,17 +8,17 @@ from fastapi import FastAPI
 from transformers import logging as transformer_logging  # type:ignore
 
 from danswer import __version__
-from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
-from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.utils.logger import setup_logger
 from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
 from model_server.encoders import warm_up_cross_encoders
-from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
-from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
-from shared_configs.nlp_model_configs import INDEXING_ONLY
-from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS
+from shared_configs.configs import ENABLE_RERANKING_ASYNC_FLOW
+from shared_configs.configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from shared_configs.configs import INDEXING_ONLY
+from shared_configs.configs import MIN_THREADS_ML_MODELS
+from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
+from shared_configs.configs import MODEL_SERVER_PORT
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
diff --git a/backend/requirements/model_server.txt b/backend/requirements/model_server.txt
index 487e6338d..8f133657b 100644
--- a/backend/requirements/model_server.txt
+++ b/backend/requirements/model_server.txt
@@ -1,4 +1,5 @@
 fastapi==0.109.2
+h5py==3.9.0
 pydantic==1.10.7
 safetensors==0.4.2
 sentence-transformers==2.6.1
diff --git a/backend/shared_configs/nlp_model_configs.py b/backend/shared_configs/configs.py
similarity index 57%
rename from backend/shared_configs/nlp_model_configs.py
rename to backend/shared_configs/configs.py
index cc58a56b0..41b46723e 100644
--- a/backend/shared_configs/nlp_model_configs.py
+++ b/backend/shared_configs/configs.py
@@ -1,6 +1,15 @@
 import os
 
 
+MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
+MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
+MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
+# Model server for indexing should use a separate one to not allow indexing to introduce delay
+# for inference
+INDEXING_MODEL_SERVER_HOST = (
+    os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
+)
+
 # Danswer custom Deep Learning Models
 INTENT_MODEL_VERSION = "danswer/intent-model"
 INTENT_MODEL_CONTEXT_SIZE = 256
@@ -23,4 +32,9 @@ CROSS_EMBED_CONTEXT_SIZE = 512
 # model. If torch finds more threads on its own, this value is not used.
 MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
 
+# Model server that has indexing only set will throw exception if used for reranking
+# or intent classification
 INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"
+
+# notset, debug, info, warning, error, or critical
+LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")