Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-06-24 23:11:01 +02:00)

Always Use Model Server (#1306)

This commit is contained in:
parent 795243283d
commit 2db906b7a2
.github/workflows/pr-python-checks.yml (vendored; 2 changed lines)
@@ -20,10 +20,12 @@ jobs:
           cache-dependency-path: |
             backend/requirements/default.txt
             backend/requirements/dev.txt
+            backend/requirements/model_server.txt
       - run: |
           python -m pip install --upgrade pip
           pip install -r backend/requirements/default.txt
           pip install -r backend/requirements/dev.txt
+          pip install -r backend/requirements/model_server.txt
 
       - name: Run MyPy
         run: |
@@ -85,6 +85,7 @@ Install the required python dependencies:
 ```bash
 pip install -r danswer/backend/requirements/default.txt
 pip install -r danswer/backend/requirements/dev.txt
+pip install -r danswer/backend/requirements/model_server.txt
 ```
 
 Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.

@@ -117,7 +118,19 @@ To start the frontend, navigate to `danswer/web` and run:
 npm run dev
 ```
 
-The first time running Danswer, you will also need to run the DB migrations for Postgres.
+Next, start the model server which runs the local NLP models.
+Navigate to `danswer/backend` and run:
+```bash
+uvicorn model_server.main:app --reload --port 9000
+```
+_For Windows (for compatibility with both PowerShell and Command Prompt):_
+```bash
+powershell -Command "
+    uvicorn model_server.main:app --reload --port 9000
+"
+```
+
+The first time running Danswer, you will need to run the DB migrations for Postgres.
 After the first time, this is no longer required unless the DB models change.
 
 Navigate to `danswer/backend` and with the venv active, run:
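Not part of the commit, but a quick way to verify the setup above: a minimal smoke test against a locally running model server, assuming it was started with the uvicorn command shown and that the request/response fields match `EmbedRequest`/`EmbedResponse` from `shared_configs/model_server_models.py` as they appear later in this diff. The model name is an assumption; substitute whichever bi-encoder you actually configured.

```python
# Hypothetical smoke test for a local model server on port 9000 (sketch).
import requests

embed_request = {
    "texts": ["Danswer is amazing!"],
    "model_name": "intfloat/e5-base-v2",  # assumption: your configured bi-encoder
    "max_context_length": 512,
    "normalize_embeddings": True,
}

response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed", json=embed_request
)
response.raise_for_status()
embeddings = response.json()["embeddings"]
print(f"Received {len(embeddings)} embedding(s) of dimension {len(embeddings[0])}")
```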
@@ -40,7 +40,7 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma
 # Set up application files
 WORKDIR /app
 COPY ./danswer /app/danswer
-COPY ./shared_models /app/shared_models
+COPY ./shared_configs /app/shared_configs
 COPY ./alembic /app/alembic
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /usr/etc/supervisord.conf
@@ -25,11 +25,8 @@ COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
 # Place to fetch version information
 COPY ./danswer/__init__.py /app/danswer/__init__.py
 
-# Shared implementations for running NLP models locally
-COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
-
 # Request/Response models
-COPY ./shared_models /app/shared_models
+COPY ./shared_configs /app/shared_configs
 
 # Model Server main code
 COPY ./model_server /app/model_server
@@ -6,18 +6,15 @@ NOTE: cannot use Celery directly due to
 https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
 from collections.abc import Callable
 from dataclasses import dataclass
+from multiprocessing import Process
 from typing import Any
 from typing import Literal
 from typing import Optional
-from typing import TYPE_CHECKING
 
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
 
-if TYPE_CHECKING:
-    from torch.multiprocessing import Process
-
 JobStatusType = (
     Literal["error"]
     | Literal["finished"]
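Context for this change: indexing jobs no longer run torch models in-process, so the lazily imported `torch.multiprocessing.Process` is replaced with the stdlib equivalent. A minimal illustrative sketch of the pattern, not repo code (`run_job` stands in for `run_indexing_entrypoint`):

```python
# Plain stdlib multiprocessing, with no torch dependency in the parent process.
from multiprocessing import Process

def run_job(attempt_id: int) -> None:
    # placeholder for run_indexing_entrypoint(attempt_id)
    print(f"running indexing attempt {attempt_id}")

if __name__ == "__main__":
    p = Process(target=run_job, args=(1,))
    p.start()
    p.join()
```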
@@ -89,8 +86,6 @@ class SimpleJobClient:
 
     def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
         """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
-        from torch.multiprocessing import Process
-
         self._cleanup_completed_jobs()
         if len(self.jobs) >= self.n_workers:
             logger.debug("No available workers to run job")
@@ -330,20 +330,15 @@ def _run_indexing(
     )
 
 
-def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
+def run_indexing_entrypoint(index_attempt_id: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
     Wraps the actual logic in a `try` block so that we can catch any exceptions
     and mark the attempt as failed."""
-    import torch
-
     try:
         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
         IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
 
-        logger.info(f"Setting task to use {num_threads} threads")
-        torch.set_num_threads(num_threads)
-
         with Session(get_sqlalchemy_engine()) as db_session:
             attempt = get_index_attempt(
                 db_session=db_session, index_attempt_id=index_attempt_id
@@ -15,9 +15,10 @@ from danswer.background.indexing.run_indexing import run_indexing_entrypoint
 from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
 from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
+from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
 from danswer.configs.app_configs import LOG_LEVEL
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
-from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.db.connector import fetch_connectors
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
@@ -43,6 +44,7 @@ from danswer.db.models import EmbeddingModel
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
+from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.utils.logger import setup_logger
 
 logger = setup_logger()
@@ -56,18 +58,6 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
 )
 
 
-"""Util funcs"""
-
-
-def _get_num_threads() -> int:
-    """Get # of "threads" to use for ML models in an indexing job. By default uses
-    the torch implementation, which returns the # of physical cores on the machine.
-    """
-    import torch
-
-    return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
-
-
 def _should_create_new_indexing(
     connector: Connector,
     last_index: IndexAttempt | None,
@@ -346,12 +336,10 @@ def kickoff_indexing_jobs(
 
         if use_secondary_index:
             run = secondary_client.submit(
-                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
+                run_indexing_entrypoint, attempt.id, pure=False
             )
         else:
-            run = client.submit(
-                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
-            )
+            run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
 
         if run:
             secondary_str = "(secondary index) " if use_secondary_index else ""
@@ -409,6 +397,20 @@ def check_index_swap(db_session: Session) -> None:
 
 
 def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
+    engine = get_sqlalchemy_engine()
+    with Session(engine) as db_session:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+
+    # So that the first time users aren't surprised by really slow speed of first
+    # batch of documents indexed
+    logger.info("Running a first inference to warm up embedding model")
+    warm_up_encoders(
+        model_name=db_embedding_model.model_name,
+        normalize=db_embedding_model.normalize,
+        model_server_host=INDEXING_MODEL_SERVER_HOST,
+        model_server_port=MODEL_SERVER_PORT,
+    )
+
     client_primary: Client | SimpleJobClient
     client_secondary: Client | SimpleJobClient
     if DASK_JOB_CLIENT_ENABLED:
@@ -435,7 +437,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
         client_secondary = SimpleJobClient(n_workers=num_workers)
 
     existing_jobs: dict[int, Future | SimpleJob] = {}
-    engine = get_sqlalchemy_engine()
 
     with Session(engine) as db_session:
         # Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
@@ -472,14 +473,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
 
 
 def update__main() -> None:
-    # needed for CUDA to work with multiprocessing
-    # NOTE: needs to be done on application startup
-    # before any other torch code has been run
-    import torch
-
-    if not DASK_JOB_CLIENT_ENABLED:
-        torch.multiprocessing.set_start_method("spawn")
-
     logger.info("Starting Indexing Loop")
     update_loop()
 
@@ -207,15 +207,11 @@ DISABLE_DOCUMENT_CLEANUP = (
 #####
 # Model Server Configs
 #####
-# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
-# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
-MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
+MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
 MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
 MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
-# specify this env variable directly to have a different model server for the background
-# indexing job vs the api server so that background indexing does not effect query-time
-# performance
+# Model server for indexing should use a separate one to not allow indexing to introduce delay
+# for inference
 INDEXING_MODEL_SERVER_HOST = (
     os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
 )
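A sketch (not repo code) of how these defaults resolve at runtime: with nothing set, every component now talks to a model server on localhost:9000, and `INDEXING_MODEL_SERVER_HOST` can point at a second server so background indexing does not add latency to query-time inference.

```python
# Resolution order mirrored from the config lines above (illustrative only).
import os

host = os.environ.get("MODEL_SERVER_HOST") or "localhost"
port = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
# Falls back to the main model server unless explicitly overridden.
indexing_host = os.environ.get("INDEXING_MODEL_SERVER_HOST") or host

print(host, port, indexing_host)  # -> localhost 9000 localhost (when unset)
```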
@@ -37,33 +37,13 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
 ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
-# This controls the minimum number of pytorch "threads" to allocate to the embedding
-# model. If torch finds more threads on its own, this value is not used.
-MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
-
-# Cross Encoder Settings
-ENABLE_RERANKING_ASYNC_FLOW = (
-    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
-)
-ENABLE_RERANKING_REAL_TIME_FLOW = (
-    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
-)
-# Only using one for now
-CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
-# For score normalizing purposes, only way is to know the expected ranges
+# For score display purposes, only way is to know the expected ranges
 CROSS_ENCODER_RANGE_MAX = 12
 CROSS_ENCODER_RANGE_MIN = -12
-CROSS_EMBED_CONTEXT_SIZE = 512
 
 # Unused currently, can't be used with the current default encoder model due to its output range
 SEARCH_DISTANCE_CUTOFF = 0
 
-# Intent model max context size
-QUERY_MAX_CONTEXT_SIZE = 256
-
-# Danswer custom Deep Learning Models
-INTENT_MODEL_VERSION = "danswer/intent-model"
-
-
 #####
 # Generative AI Model Configs
@@ -22,7 +22,6 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
 from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
 from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
 from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.danswerbot.slack.blocks import build_documents_blocks
 from danswer.danswerbot.slack.blocks import build_follow_up_block
 from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -52,6 +51,7 @@ from danswer.search.models import BaseFilters
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
 from danswer.utils.logger import setup_logger
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
 
 logger_base = setup_logger()
 
@@ -10,10 +10,11 @@ from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
 from sqlalchemy.orm import Session
 
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.constants import MessageType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
 from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
 from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
@@ -43,7 +44,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.one_shot_answer.models import ThreadMessage
-from danswer.search.search_nlp_models import warm_up_models
+from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
 
@@ -390,10 +391,11 @@ if __name__ == "__main__":
     with Session(get_sqlalchemy_engine()) as db_session:
         embedding_model = get_current_db_embedding_model(db_session)
 
-        warm_up_models(
+        warm_up_encoders(
             model_name=embedding_model.model_name,
             normalize=embedding_model.normalize,
-            skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW,
+            model_server_host=MODEL_SERVER_HOST,
+            model_server_port=MODEL_SERVER_PORT,
         )
 
     slack_bot_tokens = latest_slack_bot_tokens
@@ -16,8 +16,9 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
+from danswer.search.enums import EmbedTextType
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.search.search_nlp_models import EmbedTextType
+from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
 
 
@@ -73,6 +74,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         title_embed_dict: dict[str, list[float]] = {}
         embedded_chunks: list[IndexChunk] = []
 
+        # Create Mini Chunks for more precise matching of details
+        # Off by default with unedited settings
         chunk_texts = []
         chunk_mini_chunks_count = {}
         for chunk_ind, chunk in enumerate(chunks):
|
|||||||
chunk_texts.extend(mini_chunk_texts)
|
chunk_texts.extend(mini_chunk_texts)
|
||||||
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
|
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
|
||||||
|
|
||||||
text_batches = [
|
# Batching for embedding
|
||||||
chunk_texts[i : i + batch_size]
|
text_batches = batch_list(chunk_texts, batch_size)
|
||||||
for i in range(0, len(chunk_texts), batch_size)
|
|
||||||
]
|
|
||||||
|
|
||||||
embeddings: list[list[float]] = []
|
embeddings: list[list[float]] = []
|
||||||
len_text_batches = len(text_batches)
|
len_text_batches = len(text_batches)
|
||||||
for idx, text_batch in enumerate(text_batches, start=1):
|
for idx, text_batch in enumerate(text_batches, start=1):
|
||||||
logger.debug(f"Embedding text batch {idx} of {len_text_batches}")
|
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
|
||||||
# Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
|
# Normalize embeddings is only configured via model_configs.py, be sure to use right
|
||||||
|
# value for the set loss
|
||||||
embeddings.extend(
|
embeddings.extend(
|
||||||
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
|
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace line above with the line below for easy debugging of indexing flow, skipping the actual model
|
# Replace line above with the line below for easy debugging of indexing flow
|
||||||
|
# skipping the actual model
|
||||||
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
|
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
|
||||||
|
|
||||||
|
chunk_titles = {
|
||||||
|
chunk.source_document.get_title_for_document_index() for chunk in chunks
|
||||||
|
}
|
||||||
|
chunk_titles.discard(None)
|
||||||
|
|
||||||
|
# Embed Titles in batches
|
||||||
|
title_batches = batch_list(list(chunk_titles), batch_size)
|
||||||
|
len_title_batches = len(title_batches)
|
||||||
|
for ind_batch, title_batch in enumerate(title_batches, start=1):
|
||||||
|
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
|
||||||
|
title_embeddings = self.embedding_model.encode(
|
||||||
|
title_batch, text_type=EmbedTextType.PASSAGE
|
||||||
|
)
|
||||||
|
title_embed_dict.update(
|
||||||
|
{title: vector for title, vector in zip(title_batch, title_embeddings)}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mapping embeddings to chunks
|
||||||
embedding_ind_start = 0
|
embedding_ind_start = 0
|
||||||
for chunk_ind, chunk in enumerate(chunks):
|
for chunk_ind, chunk in enumerate(chunks):
|
||||||
num_embeddings = chunk_mini_chunks_count[chunk_ind]
|
num_embeddings = chunk_mini_chunks_count[chunk_ind]
|
||||||
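The design choice behind the new title handling above, as an illustrative sketch (not repo code): many chunks come from the same source document, so each unique title is embedded once and looked up per chunk instead of being re-encoded.

```python
# Dedupe-then-cache pattern used by the title embedding code above (sketch).
chunks = [("Doc A", "chunk 1"), ("Doc A", "chunk 2"), ("Doc B", "chunk 3")]

titles = {title for title, _ in chunks}                    # {"Doc A", "Doc B"}
title_embed_dict = {t: f"embedding({t})" for t in titles}  # one encode per title

for title, _ in chunks:
    vector = title_embed_dict[title]                       # cached lookup per chunk
```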
@@ -114,9 +135,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             title_embedding = None
             if title:
                 if title in title_embed_dict:
-                    # Using cached value for speedup
+                    # Using cached value to avoid recalculating for every chunk
                     title_embedding = title_embed_dict[title]
                 else:
+                    logger.error(
+                        "Title had to be embedded separately, this should not happen!"
+                    )
                     title_embedding = self.embedding_model.encode(
                         [title], text_type=EmbedTextType.PASSAGE
                     )[0]
@@ -1,10 +1,10 @@
+import time
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from typing import Any
 from typing import cast
 
 import nltk  # type:ignore
-import torch  # Import here is fine, API server needs torch anyway and nothing imports main.py
 import uvicorn
 from fastapi import APIRouter
 from fastapi import FastAPI
@@ -36,7 +36,6 @@ from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.constants import AuthType
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.chat import delete_old_default_personas
@@ -54,7 +53,7 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_default_llm_version
-from danswer.search.search_nlp_models import warm_up_models
+from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.danswer_api.ingestion import get_danswer_api_key
 from danswer.server.danswer_api.ingestion import router as danswer_api_router
 from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -82,6 +81,7 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
 from danswer.utils.variable_functionality import fetch_versioned_implementation
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 
 
 logger = setup_logger()
@@ -204,24 +204,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     if ENABLE_RERANKING_REAL_TIME_FLOW:
         logger.info("Reranking step of search flow is enabled.")
 
-    if MODEL_SERVER_HOST:
-        logger.info(
-            f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
-        )
-    else:
-        logger.info("Warming up local NLP models.")
-        warm_up_models(
-            model_name=db_embedding_model.model_name,
-            normalize=db_embedding_model.normalize,
-            skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
-        )
-
-        if torch.cuda.is_available():
-            logger.info("GPU is available")
-        else:
-            logger.info("GPU is not available")
-        logger.info(f"Torch Threads: {torch.get_num_threads()}")
-
     logger.info("Verifying query preprocessing (NLTK) data is downloaded")
     nltk.download("stopwords", quiet=True)
     nltk.download("wordnet", quiet=True)
@@ -237,19 +219,34 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     load_chat_yamls()
 
     logger.info("Verifying Document Index(s) is/are available.")
 
     document_index = get_default_document_index(
         primary_index_name=db_embedding_model.index_name,
         secondary_index_name=secondary_db_embedding_model.index_name
         if secondary_db_embedding_model
         else None,
     )
-    document_index.ensure_indices_exist(
-        index_embedding_dim=db_embedding_model.model_dim,
-        secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-        if secondary_db_embedding_model
-        else None,
-    )
+    # Vespa startup is a bit slow, so give it a few seconds
+    wait_time = 5
+    for attempt in range(5):
+        try:
+            document_index.ensure_indices_exist(
+                index_embedding_dim=db_embedding_model.model_dim,
+                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+                if secondary_db_embedding_model
+                else None,
+            )
+            break
+        except Exception:
+            logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
+            time.sleep(wait_time)
+
+    logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
+    warm_up_encoders(
+        model_name=db_embedding_model.model_name,
+        normalize=db_embedding_model.normalize,
+        model_server_host=MODEL_SERVER_HOST,
+        model_server_port=MODEL_SERVER_PORT,
+    )
 
     optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
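The commit leans on the same fixed-wait retry idiom in two places: the Vespa readiness loop above and the encoder warm-up later in the diff. A generic sketch of that idiom (`retry_with_wait` is a hypothetical helper, not repo code):

```python
# Fixed-wait retry idiom (illustrative sketch).
import time
from collections.abc import Callable
from typing import Any

def retry_with_wait(fn: Callable[[], Any], attempts: int, wait_seconds: int) -> Any:
    for attempt in range(attempts):
        try:
            return fn()
        except Exception:
            print(f"Attempt {attempt + 1} failed, retrying in {wait_seconds}s...")
            time.sleep(wait_seconds)
    raise RuntimeError(f"Still failing after {attempts} attempts")
```

Unlike this sketch, the Vespa loop above falls through without raising after its final attempt, leaving the subsequent warm-up call to surface any persistent failure.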
@@ -28,3 +28,8 @@ class SearchType(str, Enum):
 class QueryFlow(str, Enum):
     SEARCH = "search"
     QUESTION_ANSWER = "question-answer"
+
+
+class EmbedTextType(str, Enum):
+    QUERY = "query"
+    PASSAGE = "passage"
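How callers use the relocated enum, mirroring the `EmbeddingModel` call sites in this commit (a sketch; the model name and prefix values are assumptions based on defaults visible elsewhere in the diff):

```python
# Query vs. passage embedding with the relocated enum (illustrative sketch).
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel

model = EmbeddingModel(
    model_name="intfloat/e5-base-v2",  # assumption: any supported bi-encoder
    query_prefix="query: ",            # matches the ASYM_QUERY_PREFIX default
    passage_prefix="passage: ",        # matches the ASYM_PASSAGE_PREFIX default
    normalize=True,
    server_host="localhost",
    server_port=9000,
)
query_vectors = model.encode(["how to deploy danswer"], text_type=EmbedTextType.QUERY)
passage_vectors = model.encode(["Deployment guide ..."], text_type=EmbedTextType.PASSAGE)
```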
@@ -8,10 +8,10 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
 from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.configs.constants import DocumentSource
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.db.models import Persona
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.enums import SearchType
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 
 
 MAX_METRICS_CONTENT = (
@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.db.models import User
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import RecencyBiasSetting
@@ -22,6 +21,7 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import FunctionCall
 from danswer.utils.threadpool_concurrency import run_functions_in_parallel
 from danswer.utils.timing import log_function_time
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 
 
 logger = setup_logger()
@@ -14,6 +14,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.document_index.interfaces import DocumentIndex
 from danswer.indexing.models import InferenceChunk
+from danswer.search.enums import EmbedTextType
 from danswer.search.models import ChunkMetric
 from danswer.search.models import IndexFilters
 from danswer.search.models import MAX_METRICS_CONTENT
@@ -21,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SearchQuery
 from danswer.search.models import SearchType
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.search.search_nlp_models import EmbedTextType
 from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
 from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -1,56 +1,38 @@
 import gc
 import os
-from enum import Enum
+import time
 from typing import Optional
 from typing import TYPE_CHECKING
 
-import numpy as np
 import requests
 from transformers import logging as transformer_logging  # type:ignore
 
 from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
-from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
-from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
-from danswer.configs.model_configs import INTENT_MODEL_VERSION
-from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
+from danswer.search.enums import EmbedTextType
 from danswer.utils.logger import setup_logger
-from shared_models.model_server_models import EmbedRequest
-from shared_models.model_server_models import EmbedResponse
-from shared_models.model_server_models import IntentRequest
-from shared_models.model_server_models import IntentResponse
-from shared_models.model_server_models import RerankRequest
-from shared_models.model_server_models import RerankResponse
+from shared_configs.model_server_models import EmbedRequest
+from shared_configs.model_server_models import EmbedResponse
+from shared_configs.model_server_models import IntentRequest
+from shared_configs.model_server_models import IntentResponse
+from shared_configs.model_server_models import RerankRequest
+from shared_configs.model_server_models import RerankResponse
 
+transformer_logging.set_verbosity_error()
+
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
 
 logger = setup_logger()
-transformer_logging.set_verbosity_error()
-
 
 if TYPE_CHECKING:
-    from sentence_transformers import CrossEncoder  # type: ignore
-    from sentence_transformers import SentenceTransformer  # type: ignore
     from transformers import AutoTokenizer  # type: ignore
-    from transformers import TFDistilBertForSequenceClassification  # type: ignore
 
 
 _TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
-_EMBED_MODEL: tuple[Optional["SentenceTransformer"], str | None] = (None, None)
-_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
-_INTENT_TOKENIZER: Optional["AutoTokenizer"] = None
-_INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None
-
-
-class EmbedTextType(str, Enum):
-    QUERY = "query"
-    PASSAGE = "passage"
 
 
 def clean_model_name(model_str: str) -> str:
@@ -84,89 +66,10 @@ def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
     return _TOKENIZER[0]
 
 
-def get_local_embedding_model(
-    model_name: str,
-    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
-) -> "SentenceTransformer":
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from sentence_transformers import SentenceTransformer  # type: ignore
-
-    global _EMBED_MODEL
-    if (
-        _EMBED_MODEL[0] is None
-        or max_context_length != _EMBED_MODEL[0].max_seq_length
-        or model_name != _EMBED_MODEL[1]
-    ):
-        if _EMBED_MODEL[0] is not None:
-            del _EMBED_MODEL
-            gc.collect()
-
-        logger.info(f"Loading {model_name}")
-        _EMBED_MODEL = (SentenceTransformer(model_name), model_name)
-        _EMBED_MODEL[0].max_seq_length = max_context_length
-    return _EMBED_MODEL[0]
-
-
-def get_local_reranking_model_ensemble(
-    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
-    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
-) -> list["CrossEncoder"]:
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from sentence_transformers import CrossEncoder
-
-    global _RERANK_MODELS
-    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
-        del _RERANK_MODELS
-        gc.collect()
-
-        _RERANK_MODELS = []
-        for model_name in model_names:
-            logger.info(f"Loading {model_name}")
-            model = CrossEncoder(model_name)
-            model.max_length = max_context_length
-            _RERANK_MODELS.append(model)
-    return _RERANK_MODELS
-
-
-def get_intent_model_tokenizer(
-    model_name: str = INTENT_MODEL_VERSION,
-) -> "AutoTokenizer":
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from transformers import AutoTokenizer  # type: ignore
-
-    global _INTENT_TOKENIZER
-    if _INTENT_TOKENIZER is None:
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
-    return _INTENT_TOKENIZER
-
-
-def get_local_intent_model(
-    model_name: str = INTENT_MODEL_VERSION,
-    max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
-) -> "TFDistilBertForSequenceClassification":
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from transformers import TFDistilBertForSequenceClassification  # type: ignore
-
-    global _INTENT_MODEL
-    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
-        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
-            model_name
-        )
-        _INTENT_MODEL.max_seq_length = max_context_length
-    return _INTENT_MODEL
-
-
 def build_model_server_url(
-    model_server_host: str | None,
-    model_server_port: int | None,
-) -> str | None:
-    if not model_server_host or model_server_port is None:
-        return None
-
+    model_server_host: str,
+    model_server_port: int,
+) -> str:
     model_server_url = f"{model_server_host}:{model_server_port}"
 
     # use protocol if provided
|
|||||||
query_prefix: str | None,
|
query_prefix: str | None,
|
||||||
passage_prefix: str | None,
|
passage_prefix: str | None,
|
||||||
normalize: bool,
|
normalize: bool,
|
||||||
server_host: str | None, # Changes depending on indexing or inference
|
server_host: str, # Changes depending on indexing or inference
|
||||||
server_port: int | None,
|
server_port: int,
|
||||||
# The following are globals are currently not configurable
|
# The following are globals are currently not configurable
|
||||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||||
) -> None:
|
) -> None:
|
||||||
@@ -196,17 +99,7 @@ class EmbeddingModel:
         self.normalize = normalize
 
         model_server_url = build_model_server_url(server_host, server_port)
-        self.embed_server_endpoint = (
-            f"{model_server_url}/encoder/bi-encoder-embed" if model_server_url else None
-        )
-
-    def load_model(self) -> Optional["SentenceTransformer"]:
-        if self.embed_server_endpoint:
-            return None
-
-        return get_local_embedding_model(
-            model_name=self.model_name, max_context_length=self.max_seq_length
-        )
+        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
 
     def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
         if text_type == EmbedTextType.QUERY and self.query_prefix:
|
|||||||
else:
|
else:
|
||||||
prefixed_texts = texts
|
prefixed_texts = texts
|
||||||
|
|
||||||
if self.embed_server_endpoint:
|
embed_request = EmbedRequest(
|
||||||
embed_request = EmbedRequest(
|
texts=prefixed_texts,
|
||||||
texts=prefixed_texts,
|
model_name=self.model_name,
|
||||||
model_name=self.model_name,
|
max_context_length=self.max_seq_length,
|
||||||
normalize_embeddings=self.normalize,
|
normalize_embeddings=self.normalize,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
|
||||||
response = requests.post(
|
response.raise_for_status()
|
||||||
self.embed_server_endpoint, json=embed_request.dict()
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
return EmbedResponse(**response.json()).embeddings
|
return EmbedResponse(**response.json()).embeddings
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.exception(f"Failed to get Embedding: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
local_model = self.load_model()
|
|
||||||
|
|
||||||
if local_model is None:
|
|
||||||
raise RuntimeError("Failed to load local Embedding Model")
|
|
||||||
|
|
||||||
return local_model.encode(
|
|
||||||
prefixed_texts, normalize_embeddings=self.normalize
|
|
||||||
).tolist()
|
|
||||||
|
|
||||||
|
|
||||||
class CrossEncoderEnsembleModel:
|
class CrossEncoderEnsembleModel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
|
model_server_host: str = MODEL_SERVER_HOST,
|
||||||
max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
|
|
||||||
model_server_host: str | None = MODEL_SERVER_HOST,
|
|
||||||
model_server_port: int = MODEL_SERVER_PORT,
|
model_server_port: int = MODEL_SERVER_PORT,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_names = model_names
|
|
||||||
self.max_seq_length = max_seq_length
|
|
||||||
|
|
||||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||||
self.rerank_server_endpoint = (
|
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
|
||||||
model_server_url + "/encoder/cross-encoder-scores"
|
|
||||||
if model_server_url
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_model(self) -> list["CrossEncoder"] | None:
|
|
||||||
if (
|
|
||||||
ENABLE_RERANKING_REAL_TIME_FLOW is False
|
|
||||||
and ENABLE_RERANKING_ASYNC_FLOW is False
|
|
||||||
):
|
|
||||||
logger.warning(
|
|
||||||
"Running rerankers but they are globally disabled."
|
|
||||||
"Was this specified explicitly via an API?"
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.rerank_server_endpoint:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return get_local_reranking_model_ensemble(
|
|
||||||
model_names=self.model_names, max_context_length=self.max_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
|
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
|
||||||
if self.rerank_server_endpoint:
|
rerank_request = RerankRequest(query=query, documents=passages)
|
||||||
rerank_request = RerankRequest(query=query, documents=passages)
|
|
||||||
|
|
||||||
try:
|
response = requests.post(
|
||||||
response = requests.post(
|
self.rerank_server_endpoint, json=rerank_request.dict()
|
||||||
self.rerank_server_endpoint, json=rerank_request.dict()
|
)
|
||||||
)
|
response.raise_for_status()
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
return RerankResponse(**response.json()).scores
|
return RerankResponse(**response.json()).scores
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.exception(f"Failed to get Reranking Scores: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
local_models = self.load_model()
|
|
||||||
|
|
||||||
if local_models is None:
|
|
||||||
raise RuntimeError("Failed to load local Reranking Model Ensemble")
|
|
||||||
|
|
||||||
scores = [
|
|
||||||
cross_encoder.predict([(query, passage) for passage in passages]).tolist() # type: ignore
|
|
||||||
for cross_encoder in local_models
|
|
||||||
]
|
|
||||||
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
class IntentModel:
|
class IntentModel:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_name: str = INTENT_MODEL_VERSION,
|
model_server_host: str = MODEL_SERVER_HOST,
|
||||||
max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
|
|
||||||
model_server_host: str | None = MODEL_SERVER_HOST,
|
|
||||||
model_server_port: int = MODEL_SERVER_PORT,
|
model_server_port: int = MODEL_SERVER_PORT,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.model_name = model_name
|
|
||||||
self.max_seq_length = max_seq_length
|
|
||||||
|
|
||||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||||
self.intent_server_endpoint = (
|
self.intent_server_endpoint = model_server_url + "/custom/intent-model"
|
||||||
model_server_url + "/custom/intent-model" if model_server_url else None
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_model(self) -> Optional["SentenceTransformer"]:
|
|
||||||
if self.intent_server_endpoint:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return get_local_intent_model(
|
|
||||||
model_name=self.model_name, max_context_length=self.max_seq_length
|
|
||||||
)
|
|
||||||
|
|
||||||
def predict(
|
def predict(
|
||||||
self,
|
self,
|
||||||
query: str,
|
query: str,
|
||||||
) -> list[float]:
|
) -> list[float]:
|
||||||
# NOTE: doing a local import here to avoid reduce memory usage caused by
|
intent_request = IntentRequest(query=query)
|
||||||
# processes importing this file despite not using any of this
|
|
||||||
import tensorflow as tf # type: ignore
|
|
||||||
|
|
||||||
if self.intent_server_endpoint:
|
response = requests.post(
|
||||||
intent_request = IntentRequest(query=query)
|
self.intent_server_endpoint, json=intent_request.dict()
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.post(
|
|
||||||
self.intent_server_endpoint, json=intent_request.dict()
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
return IntentResponse(**response.json()).class_probs
|
|
||||||
except requests.RequestException as e:
|
|
||||||
logger.exception(f"Failed to get Embedding: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
tokenizer = get_intent_model_tokenizer()
|
|
||||||
local_model = self.load_model()
|
|
||||||
|
|
||||||
if local_model is None:
|
|
||||||
raise RuntimeError("Failed to load local Intent Model")
|
|
||||||
|
|
||||||
intent_model = get_local_intent_model()
|
|
||||||
model_input = tokenizer(
|
|
||||||
query, return_tensors="tf", truncation=True, padding=True
|
|
||||||
)
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
predictions = intent_model(model_input)[0]
|
return IntentResponse(**response.json()).class_probs
|
||||||
probabilities = tf.nn.softmax(predictions, axis=-1)
|
|
||||||
class_percentages = np.round(probabilities.numpy() * 100, 2)
|
|
||||||
|
|
||||||
return list(class_percentages.tolist()[0])
|
|
||||||
|
|
||||||
|
|
||||||
def warm_up_models(
|
def warm_up_encoders(
|
||||||
model_name: str,
|
model_name: str,
|
||||||
normalize: bool,
|
normalize: bool,
|
||||||
skip_cross_encoders: bool = True,
|
model_server_host: str = MODEL_SERVER_HOST,
|
||||||
indexer_only: bool = False,
|
model_server_port: int = MODEL_SERVER_PORT,
|
||||||
) -> None:
|
) -> None:
|
||||||
warm_up_str = (
|
warm_up_str = (
|
||||||
"Danswer is amazing! Check out our easy deployment guide at "
|
"Danswer is amazing! Check out our easy deployment guide at "
|
||||||
@ -387,23 +181,23 @@ def warm_up_models(
|
|||||||
embed_model = EmbeddingModel(
|
embed_model = EmbeddingModel(
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
normalize=normalize,
|
normalize=normalize,
|
||||||
# These don't matter, if it's a remote model, this function shouldn't be called
|
# Not a big deal if prefix is incorrect
|
||||||
query_prefix=None,
|
query_prefix=None,
|
||||||
passage_prefix=None,
|
passage_prefix=None,
|
||||||
server_host=None,
|
server_host=model_server_host,
|
||||||
server_port=None,
|
server_port=model_server_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
# First time downloading the models it may take even longer, but just in case,
|
||||||
|
# retry the whole server
|
||||||
if indexer_only:
|
wait_time = 5
|
||||||
return
|
for attempt in range(20):
|
||||||
|
try:
|
||||||
if not skip_cross_encoders:
|
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||||
CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
|
return
|
||||||
|
except Exception:
|
||||||
intent_tokenizer = get_intent_model_tokenizer()
|
logger.info(
|
||||||
inputs = intent_tokenizer(
|
f"Failed to run test embedding, retrying in {wait_time} seconds..."
|
||||||
warm_up_str, return_tensors="tf", truncation=True, padding=True
|
)
|
||||||
)
|
time.sleep(wait_time)
|
||||||
get_local_intent_model()(inputs)
|
raise Exception("Failed to run test embedding.")
|
||||||
|
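With the local fallback removed, `predict` reduces to a plain HTTP round trip against the model server. For reference, a minimal sketch of an equivalent call (the `localhost:9000` host and port are assumptions matching the dev setup, not values taken from this diff):

```python
import requests

# Assumed host/port; in the deployments these come from MODEL_SERVER_HOST
# and MODEL_SERVER_PORT.
INTENT_ENDPOINT = "http://localhost:9000/custom/intent-model"

response = requests.post(INTENT_ENDPOINT, json={"query": "What is Danswer?"})
response.raise_for_status()
class_probs = response.json()["class_probs"]  # one percentage per intent class
```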
@ -21,3 +21,10 @@ def batch_generator(
         if pre_batch_yield:
             pre_batch_yield(batch)
         yield batch
+
+
+def batch_list(
+    lst: list[T],
+    batch_size: int,
+) -> list[list[T]]:
+    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
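Unlike `batch_generator`, the new helper chunks a list eagerly and keeps the remainder as a final short batch. A quick usage sketch (the import path is an assumption):

```python
from danswer.utils.batching import batch_list  # import path assumed

# Fixed-size chunks, with the leftover elements kept as a final batch:
assert batch_list([1, 2, 3, 4, 5], batch_size=2) == [[1, 2], [3, 4], [5]]
```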
backend/model_server/constants.py (new file):

@ -0,0 +1 @@
+MODEL_WARM_UP_STRING = "hi " * 512
backend/model_server/custom_models.py:

@ -1,19 +1,58 @@
-import numpy as np
-from fastapi import APIRouter
+from typing import Optional

-from danswer.search.search_nlp_models import get_intent_model_tokenizer
-from danswer.search.search_nlp_models import get_local_intent_model
-from danswer.utils.timing import log_function_time
-from shared_models.model_server_models import IntentRequest
-from shared_models.model_server_models import IntentResponse
+import numpy as np
+import tensorflow as tf  # type: ignore
+from fastapi import APIRouter
+from transformers import AutoTokenizer  # type: ignore
+from transformers import TFDistilBertForSequenceClassification
+
+from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.utils import simple_log_function_time
+from shared_configs.model_server_models import IntentRequest
+from shared_configs.model_server_models import IntentResponse
+from shared_configs.nlp_model_configs import INDEXING_ONLY
+from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
+from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION

 router = APIRouter(prefix="/custom")

+_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
+_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None
+
+
+def get_intent_model_tokenizer(
+    model_name: str = INTENT_MODEL_VERSION,
+) -> "AutoTokenizer":
+    global _INTENT_TOKENIZER
+    if _INTENT_TOKENIZER is None:
+        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
+    return _INTENT_TOKENIZER
+
+
+def get_local_intent_model(
+    model_name: str = INTENT_MODEL_VERSION,
+    max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
+) -> TFDistilBertForSequenceClassification:
+    global _INTENT_MODEL
+    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
+        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
+            model_name
+        )
+        _INTENT_MODEL.max_seq_length = max_context_length
+    return _INTENT_MODEL
+
+
+def warm_up_intent_model() -> None:
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
+
+
-@log_function_time(print_only=True)
+@simple_log_function_time()
 def classify_intent(query: str) -> list[float]:
-    import tensorflow as tf  # type:ignore
-
     tokenizer = get_intent_model_tokenizer()
     intent_model = get_local_intent_model()
     model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
@ -26,16 +65,11 @@ def classify_intent(query: str) -> list[float]:
 @router.post("/intent-model")
-def process_intent_request(
+async def process_intent_request(
     intent_request: IntentRequest,
 ) -> IntentResponse:
+    if INDEXING_ONLY:
+        raise RuntimeError("Indexing model server should not call intent endpoint")
+
     class_percentages = classify_intent(intent_request.query)
     return IntentResponse(class_probs=class_percentages)
-
-
-def warm_up_intent_model() -> None:
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        "danswer", return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
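The tokenizer and model now live in module-level singletons, so they load once per process and are reused across requests. The same pattern in isolation, with illustrative names (not the module's own):

```python
from typing import Callable, Optional

_MODEL: Optional[object] = None  # illustrative cache slot


def get_model(load: Callable[[], object]) -> object:
    # Load lazily on first use, then serve the cached instance to every request.
    global _MODEL
    if _MODEL is None:
        _MODEL = load()
    return _MODEL
```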
backend/model_server/encoders.py:

@ -1,34 +1,33 @@
-from typing import TYPE_CHECKING
+import gc
+from typing import Optional

 from fastapi import APIRouter
 from fastapi import HTTPException
+from sentence_transformers import CrossEncoder  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore

-from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
-from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
 from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_function_time
-from shared_models.model_server_models import EmbedRequest
-from shared_models.model_server_models import EmbedResponse
-from shared_models.model_server_models import RerankRequest
-from shared_models.model_server_models import RerankResponse
-
-if TYPE_CHECKING:
-    from sentence_transformers import SentenceTransformer  # type: ignore
+from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.utils import simple_log_function_time
+from shared_configs.model_server_models import EmbedRequest
+from shared_configs.model_server_models import EmbedResponse
+from shared_configs.model_server_models import RerankRequest
+from shared_configs.model_server_models import RerankResponse
+from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE
+from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from shared_configs.nlp_model_configs import INDEXING_ONLY

 logger = setup_logger()

-WARM_UP_STRING = "Danswer is amazing"
-
 router = APIRouter(prefix="/encoder")

 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
+_RERANK_MODELS: Optional[list["CrossEncoder"]] = None


 def get_embedding_model(
     model_name: str,
-    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+    max_context_length: int,
 ) -> "SentenceTransformer":
     from sentence_transformers import SentenceTransformer  # type: ignore

@ -48,11 +47,44 @@ def get_embedding_model(
     return _GLOBAL_MODELS_DICT[model_name]


-@log_function_time(print_only=True)
+def get_local_reranking_model_ensemble(
+    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
+) -> list[CrossEncoder]:
+    global _RERANK_MODELS
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        del _RERANK_MODELS
+        gc.collect()
+
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
+            model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
+    return _RERANK_MODELS
+
+
+def warm_up_cross_encoders() -> None:
+    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
+
+    cross_encoders = get_local_reranking_model_ensemble()
+    [
+        cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
+        for cross_encoder in cross_encoders
+    ]
+
+
+@simple_log_function_time()
 def embed_text(
-    texts: list[str], model_name: str, normalize_embeddings: bool
+    texts: list[str],
+    model_name: str,
+    max_context_length: int,
+    normalize_embeddings: bool,
 ) -> list[list[float]]:
-    model = get_embedding_model(model_name=model_name)
+    model = get_embedding_model(
+        model_name=model_name, max_context_length=max_context_length
+    )
     embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)

     if not isinstance(embeddings, list):
@ -61,7 +93,7 @@ def embed_text(
     return embeddings


-@log_function_time(print_only=True)
+@simple_log_function_time()
 def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
     cross_encoders = get_local_reranking_model_ensemble()
     sim_scores = [
@ -72,13 +104,14 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:


 @router.post("/bi-encoder-embed")
-def process_embed_request(
+async def process_embed_request(
     embed_request: EmbedRequest,
 ) -> EmbedResponse:
     try:
         embeddings = embed_text(
             texts=embed_request.texts,
             model_name=embed_request.model_name,
+            max_context_length=embed_request.max_context_length,
             normalize_embeddings=embed_request.normalize_embeddings,
         )
         return EmbedResponse(embeddings=embeddings)
@ -87,7 +120,11 @@ def process_embed_request(


 @router.post("/cross-encoder-scores")
-def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
+async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
+    """Cross encoders can be purely black box from the app perspective"""
+    if INDEXING_ONLY:
+        raise RuntimeError("Indexing model server should not call intent endpoint")
+
     try:
         sim_scores = calc_sim_scores(
             query=embed_request.query, docs=embed_request.documents
@ -95,13 +132,3 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
         return RerankResponse(scores=sim_scores)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
-
-
-def warm_up_cross_encoders() -> None:
-    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
-
-    cross_encoders = get_local_reranking_model_ensemble()
-    [
-        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
-        for cross_encoder in cross_encoders
-    ]
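End to end, embedding is now a single POST carrying the new `max_context_length` field. A hedged sketch of a direct call (the host, port, and model name are assumptions for illustration):

```python
import requests

payload = {
    # Texts arrive with any query/passage prefixes already applied by the caller.
    "texts": ["query: What is Danswer?"],
    "model_name": "intfloat/e5-base-v2",  # example model, not prescribed by this diff
    "max_context_length": 512,
    "normalize_embeddings": True,
}
resp = requests.post("http://localhost:9000/encoder/bi-encoder-embed", json=payload)
resp.raise_for_status()
embeddings = resp.json()["embeddings"]  # one float vector per input text
```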
backend/model_server/main.py:

@ -1,40 +1,61 @@
+import os
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+
 import torch
 import uvicorn
 from fastapi import FastAPI
+from transformers import logging as transformer_logging  # type:ignore

 from danswer import __version__
 from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
-from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.utils.logger import setup_logger
 from model_server.custom_models import router as custom_models_router
 from model_server.custom_models import warm_up_intent_model
 from model_server.encoders import router as encoders_router
 from model_server.encoders import warm_up_cross_encoders
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
+from shared_configs.nlp_model_configs import INDEXING_ONLY
+from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
+
+transformer_logging.set_verbosity_error()

 logger = setup_logger()


+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator:
+    if torch.cuda.is_available():
+        logger.info("GPU is available")
+    else:
+        logger.info("GPU is not available")
+
+    torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
+    logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+    if not INDEXING_ONLY:
+        warm_up_intent_model()
+        if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
+            warm_up_cross_encoders()
+    else:
+        logger.info("This model server should only run document indexing.")
+
+    yield
+
+
 def get_model_app() -> FastAPI:
-    application = FastAPI(title="Danswer Model Server", version=__version__)
+    application = FastAPI(
+        title="Danswer Model Server", version=__version__, lifespan=lifespan
+    )

     application.include_router(encoders_router)
     application.include_router(custom_models_router)

-    @application.on_event("startup")
-    def startup_event() -> None:
-        if torch.cuda.is_available():
-            logger.info("GPU is available")
-        else:
-            logger.info("GPU is not available")
-
-        torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
-        logger.info(f"Torch Threads: {torch.get_num_threads()}")
-
-        warm_up_cross_encoders()
-        warm_up_intent_model()
-
     return application
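The startup work moves from the deprecated `@app.on_event("startup")` hook into a lifespan context manager, which FastAPI runs once around the serving loop. The pattern in isolation, as a minimal sketch:

```python
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    # Everything before `yield` runs once, before the first request is served;
    # warm-ups and thread tuning belong here.
    print("warming up models...")
    yield
    # Anything after `yield` runs on shutdown.


app = FastAPI(lifespan=lifespan)
```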
backend/model_server/utils.py (new file):

@ -0,0 +1,41 @@
+import time
+from collections.abc import Callable
+from collections.abc import Generator
+from collections.abc import Iterator
+from functools import wraps
+from typing import Any
+from typing import cast
+from typing import TypeVar
+
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+F = TypeVar("F", bound=Callable)
+FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
+
+
+def simple_log_function_time(
+    func_name: str | None = None,
+    debug_only: bool = False,
+    include_args: bool = False,
+) -> Callable[[F], F]:
+    def decorator(func: F) -> F:
+        @wraps(func)
+        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
+            start_time = time.time()
+            result = func(*args, **kwargs)
+            elapsed_time_str = str(time.time() - start_time)
+            log_name = func_name or func.__name__
+            args_str = f" args={args} kwargs={kwargs}" if include_args else ""
+            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
+            if debug_only:
+                logger.debug(final_log)
+            else:
+                logger.info(final_log)
+
+            return result
+
+        return cast(F, wrapped_func)
+
+    return decorator
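`simple_log_function_time` is a decorator factory, so it is always applied with parentheses. A usage sketch based on the implementation above:

```python
@simple_log_function_time(include_args=True)
def add(a: int, b: int) -> int:
    return a + b


add(1, 2)  # logs roughly: "add args=(1, 2) kwargs={} took 0.00001... seconds"
```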
backend/requirements/default.txt:

@ -54,19 +54,12 @@ requests-oauthlib==1.3.1
 retry==0.9.2  # This pulls in py which is in CVE-2022-42969, must remove py from image
 rfc3986==1.5.0
 rt==3.1.2
-# need to pin `safetensors` version, since the latest versions requires
-# building from source using Rust
-safetensors==0.4.2
-sentence-transformers==2.6.1
 slack-sdk==3.20.2
 SQLAlchemy[mypy]==2.0.15
 starlette==0.36.3
 supervisor==4.2.5
-tensorflow==2.15.0
 tiktoken==0.4.0
 timeago==1.0.16
-torch==2.0.1
-torchvision==0.15.2
 transformers==4.39.2
 uvicorn==0.21.1
 zulip==0.8.2
backend/requirements/model_server.txt:

@ -1,8 +1,8 @@
-fastapi==0.109.1
+fastapi==0.109.2
 pydantic==1.10.7
-safetensors==0.3.1
+safetensors==0.4.2
-sentence-transformers==2.2.2
+sentence-transformers==2.6.1
 tensorflow==2.15.0
 torch==2.0.1
-transformers==4.36.2
+transformers==4.39.2
 uvicorn==0.21.1
backend/shared_configs/model_server_models.py:

@ -2,8 +2,10 @@ from pydantic import BaseModel


 class EmbedRequest(BaseModel):
+    # This already includes any prefixes, the text is just passed directly to the model
     texts: list[str]
     model_name: str
+    max_context_length: int
     normalize_embeddings: bool
backend/shared_configs/nlp_model_configs.py (new file):

@ -0,0 +1,26 @@
+import os
+
+
+# Danswer custom Deep Learning Models
+INTENT_MODEL_VERSION = "danswer/intent-model"
+INTENT_MODEL_CONTEXT_SIZE = 256
+
+# Bi-Encoder, other details
+DOC_EMBEDDING_CONTEXT_SIZE = 512
+
+# Cross Encoder Settings
+ENABLE_RERANKING_ASYNC_FLOW = (
+    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
+)
+ENABLE_RERANKING_REAL_TIME_FLOW = (
+    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
+)
+# Only using one cross-encoder for now
+CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
+CROSS_EMBED_CONTEXT_SIZE = 512
+
+# This controls the minimum number of pytorch "threads" to allocate to the embedding
+# model. If torch finds more threads on its own, this value is not used.
+MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
+
+INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"
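All the boolean flags here follow the same environment-variable convention: unset means false, and any casing of "true" enables the flag. A quick sketch of that parsing behavior:

```python
import os

# Same convention as INDEXING_ONLY and the reranking flags above:
os.environ["INDEXING_ONLY"] = "True"
assert (os.environ.get("INDEXING_ONLY", "").lower() == "true") is True

del os.environ["INDEXING_ONLY"]
assert (os.environ.get("INDEXING_ONLY", "").lower() == "true") is False
```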
deployment/docker_compose/docker-compose.dev.yml:

@ -67,7 +67,7 @@ services:
       - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
       - ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
       - ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
-      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
       - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Leave this on pretty please? Nothing sensitive is collected!
       # https://docs.danswer.dev/more/telemetry
@ -80,9 +80,7 @@ services:
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage
-      - model_cache_torch:/root/.cache/torch/
       - model_cache_nltk:/root/nltk_data/
-      - model_cache_huggingface:/root/.cache/huggingface/
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@ -90,6 +88,8 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
   background:
     image: danswer/danswer-backend:latest
     build:
@ -137,10 +137,9 @@ services:
       - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}  # Needed by DanswerBot
       - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
-      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
       - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
-      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-}
+      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
-      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
       # Indexing Configs
       - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
       - DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-}
@ -174,9 +173,7 @@ services:
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage
-      - model_cache_torch:/root/.cache/torch/
       - model_cache_nltk:/root/nltk_data/
-      - model_cache_huggingface:/root/.cache/huggingface/
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@ -184,6 +181,8 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
   web_server:
     image: danswer/danswer-web-server:latest
     build:
@ -198,6 +197,63 @@ services:
     environment:
       - INTERNAL_URL=http://api_server:8080
       - WEB_DOMAIN=${WEB_DOMAIN:-}
+
+
+  inference_model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    command: >
+      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
+        echo 'Skipping service...';
+        exit 0;
+      else
+        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
+      fi"
+    restart: on-failure
+    environment:
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
+    logging:
+      driver: json-file
+      options:
+        max-size: "50m"
+        max-file: "6"
+
+
+  indexing_model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    command: >
+      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
+        echo 'Skipping service...';
+        exit 0;
+      else
+        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
+      fi"
+    restart: on-failure
+    environment:
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      - INDEXING_ONLY=True
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
+    logging:
+      driver: json-file
+      options:
+        max-size: "50m"
+        max-file: "6"
+
+
   relational_db:
     image: postgres:15.2-alpine
     restart: always
@ -208,6 +264,8 @@ services:
       - "5432:5432"
     volumes:
       - db_volume:/var/lib/postgresql/data
+
+
   # This container name cannot have an underscore in it due to Vespa expectations of the URL
   index:
     image: vespaengine/vespa:8.277.17
@ -222,6 +280,8 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
   nginx:
     image: nginx:1.23.4-alpine
     restart: always
@ -250,32 +310,8 @@ services:
     command: >
       /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
       && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"
-  # Run with --profile model-server to bring up the danswer-model-server container
-  # Be sure to change MODEL_SERVER_HOST (see above) as well
-  # ie. MODEL_SERVER_HOST="model_server" docker compose -f docker-compose.dev.yml -p danswer-stack --profile model-server up -d --build
-  model_server:
-    image: danswer/danswer-model-server:latest
-    build:
-      context: ../../backend
-      dockerfile: Dockerfile.model_server
-    profiles:
-      - "model-server"
-    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
-    restart: always
-    environment:
-      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
-      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
-      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
-      # Set to debug to get more fine-grained logs
-      - LOG_LEVEL=${LOG_LEVEL:-info}
-    volumes:
-      - model_cache_torch:/root/.cache/torch/
-      - model_cache_huggingface:/root/.cache/huggingface/
-    logging:
-      driver: json-file
-      options:
-        max-size: "50m"
-        max-file: "6"
+
+
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
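Both new services listen on port 9000 inside the compose network. A small Python sketch for checking that they came up, run from a container attached to the same network (the service names are the ones defined above):

```python
import socket

for host in ("inference_model_server", "indexing_model_server"):
    # create_connection raises on failure; a successful connect means uvicorn is listening.
    with socket.create_connection((host, 9000), timeout=5):
        print(f"{host}:9000 is up")
```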
deployment/docker_compose/docker-compose.prod-no-letsencrypt.yml:

@ -22,9 +22,7 @@ services:
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage
-      - model_cache_torch:/root/.cache/torch/
       - model_cache_nltk:/root/nltk_data/
-      - model_cache_huggingface:/root/.cache/huggingface/
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@ -32,6 +30,8 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
   background:
     image: danswer/danswer-backend:latest
     build:
@ -51,9 +51,7 @@ services:
     volumes:
       - local_dynamic_storage:/home/storage
       - file_connector_tmp_storage:/home/file_connector_storage
-      - model_cache_torch:/root/.cache/torch/
       - model_cache_nltk:/root/nltk_data/
-      - model_cache_huggingface:/root/.cache/huggingface/
     extra_hosts:
       - "host.docker.internal:host-gateway"
     logging:
@ -61,6 +59,8 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
   web_server:
     image: danswer/danswer-web-server:latest
     build:
@ -81,6 +81,63 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
+  inference_model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    command: >
+      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
+        echo 'Skipping service...';
+        exit 0;
+      else
+        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
+      fi"
+    restart: on-failure
+    environment:
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
+    logging:
+      driver: json-file
+      options:
+        max-size: "50m"
+        max-file: "6"
+
+
+  indexing_model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    command: >
+      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
+        echo 'Skipping service...';
+        exit 0;
+      else
+        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
+      fi"
+    restart: on-failure
+    environment:
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      - INDEXING_ONLY=True
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
+    logging:
+      driver: json-file
+      options:
+        max-size: "50m"
+        max-file: "6"
+
+
   relational_db:
     image: postgres:15.2-alpine
     restart: always
@ -94,6 +151,8 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
   # This container name cannot have an underscore in it due to Vespa expectations of the URL
   index:
     image: vespaengine/vespa:8.277.17
@ -108,6 +167,8 @@ services:
       options:
         max-size: "50m"
         max-file: "6"
+
+
   nginx:
     image: nginx:1.23.4-alpine
     restart: always
@ -137,30 +198,8 @@ services:
       && /etc/nginx/conf.d/run-nginx.sh app.conf.template.no-letsencrypt"
     env_file:
       - .env.nginx
-  # Run with --profile model-server to bring up the danswer-model-server container
-  model_server:
-    image: danswer/danswer-model-server:latest
-    build:
-      context: ../../backend
-      dockerfile: Dockerfile.model_server
-    profiles:
-      - "model-server"
-    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
-    restart: always
-    environment:
-      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
-      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
-      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
-      # Set to debug to get more fine-grained logs
-      - LOG_LEVEL=${LOG_LEVEL:-info}
-    volumes:
-      - model_cache_torch:/root/.cache/torch/
-      - model_cache_huggingface:/root/.cache/huggingface/
-    logging:
-      driver: json-file
-      options:
-        max-size: "50m"
-        max-file: "6"
+
+
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
|
|||||||
volumes:
|
volumes:
|
||||||
- local_dynamic_storage:/home/storage
|
- local_dynamic_storage:/home/storage
|
||||||
- file_connector_tmp_storage:/home/file_connector_storage
|
- file_connector_tmp_storage:/home/file_connector_storage
|
||||||
- model_cache_torch:/root/.cache/torch/
|
|
||||||
- model_cache_nltk:/root/nltk_data/
|
- model_cache_nltk:/root/nltk_data/
|
||||||
- model_cache_huggingface:/root/.cache/huggingface/
|
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
- "host.docker.internal:host-gateway"
|
- "host.docker.internal:host-gateway"
|
||||||
logging:
|
logging:
|
||||||
@ -32,6 +30,8 @@ services:
|
|||||||
options:
|
options:
|
||||||
max-size: "50m"
|
max-size: "50m"
|
||||||
max-file: "6"
|
max-file: "6"
|
||||||
|
|
||||||
|
|
||||||
background:
|
background:
|
||||||
image: danswer/danswer-backend:latest
|
image: danswer/danswer-backend:latest
|
||||||
build:
|
build:
|
||||||
@ -51,9 +51,7 @@ services:
|
|||||||
volumes:
|
volumes:
|
||||||
- local_dynamic_storage:/home/storage
|
- local_dynamic_storage:/home/storage
|
||||||
- file_connector_tmp_storage:/home/file_connector_storage
|
- file_connector_tmp_storage:/home/file_connector_storage
|
||||||
- model_cache_torch:/root/.cache/torch/
|
|
||||||
- model_cache_nltk:/root/nltk_data/
|
- model_cache_nltk:/root/nltk_data/
|
||||||
- model_cache_huggingface:/root/.cache/huggingface/
|
|
||||||
extra_hosts:
|
extra_hosts:
|
||||||
- "host.docker.internal:host-gateway"
|
- "host.docker.internal:host-gateway"
|
||||||
logging:
|
logging:
|
||||||
@ -61,6 +59,8 @@ services:
|
|||||||
options:
|
options:
|
||||||
max-size: "50m"
|
max-size: "50m"
|
||||||
max-file: "6"
|
max-file: "6"
|
||||||
|
|
||||||
|
|
||||||
web_server:
|
web_server:
|
||||||
image: danswer/danswer-web-server:latest
|
image: danswer/danswer-web-server:latest
|
||||||
build:
|
build:
|
||||||
@ -94,6 +94,63 @@ services:
|
|||||||
options:
|
options:
|
||||||
max-size: "50m"
|
max-size: "50m"
|
||||||
max-file: "6"
|
max-file: "6"
|
||||||
|
|
||||||
|
|
||||||
|
inference_model_server:
|
||||||
|
image: danswer/danswer-model-server:latest
|
||||||
|
build:
|
||||||
|
context: ../../backend
|
||||||
|
dockerfile: Dockerfile.model_server
|
||||||
|
command: >
|
||||||
|
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
|
||||||
|
echo 'Skipping service...';
|
||||||
|
exit 0;
|
||||||
|
else
|
||||||
|
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
|
||||||
|
fi"
|
||||||
|
restart: on-failure
|
||||||
|
environment:
|
||||||
|
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||||
|
# Set to debug to get more fine-grained logs
|
||||||
|
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||||
|
volumes:
|
||||||
|
- model_cache_torch:/root/.cache/torch/
|
||||||
|
- model_cache_huggingface:/root/.cache/huggingface/
|
||||||
|
logging:
|
||||||
|
driver: json-file
|
||||||
|
options:
|
||||||
|
max-size: "50m"
|
||||||
|
max-file: "6"
|
||||||
|
|
||||||
|
|
||||||
|
indexing_model_server:
|
||||||
|
image: danswer/danswer-model-server:latest
|
||||||
|
build:
|
||||||
|
context: ../../backend
|
||||||
|
dockerfile: Dockerfile.model_server
|
||||||
|
command: >
|
||||||
|
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
|
||||||
|
echo 'Skipping service...';
|
||||||
|
exit 0;
|
||||||
|
else
|
||||||
|
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
|
||||||
|
fi"
|
||||||
|
restart: on-failure
|
||||||
|
environment:
|
||||||
|
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||||
|
- INDEXING_ONLY=True
|
||||||
|
# Set to debug to get more fine-grained logs
|
||||||
|
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||||
|
volumes:
|
||||||
|
- model_cache_torch:/root/.cache/torch/
|
||||||
|
- model_cache_huggingface:/root/.cache/huggingface/
|
||||||
|
logging:
|
||||||
|
driver: json-file
|
||||||
|
options:
|
||||||
|
max-size: "50m"
|
||||||
|
max-file: "6"
|
||||||
|
|
||||||
|
|
||||||
# This container name cannot have an underscore in it due to Vespa expectations of the URL
|
# This container name cannot have an underscore in it due to Vespa expectations of the URL
|
||||||
index:
|
index:
|
||||||
image: vespaengine/vespa:8.277.17
|
image: vespaengine/vespa:8.277.17
|
||||||
@ -108,6 +165,8 @@ services:
|
|||||||
options:
|
options:
|
||||||
max-size: "50m"
|
max-size: "50m"
|
||||||
max-file: "6"
|
max-file: "6"
|
||||||
|
|
||||||
|
|
||||||
nginx:
|
nginx:
|
||||||
image: nginx:1.23.4-alpine
|
image: nginx:1.23.4-alpine
|
||||||
restart: always
|
restart: always
|
||||||
@ -141,6 +200,8 @@ services:
|
|||||||
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
|
||||||
env_file:
|
env_file:
|
||||||
- .env.nginx
|
- .env.nginx
|
||||||
|
|
||||||
|
|
||||||
# follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71
|
# follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71
|
||||||
certbot:
|
certbot:
|
||||||
image: certbot/certbot
|
image: certbot/certbot
|
||||||
@ -154,30 +215,8 @@ services:
|
|||||||
max-size: "50m"
|
max-size: "50m"
|
||||||
max-file: "6"
|
max-file: "6"
|
||||||
entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
|
entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
|
||||||
# Run with --profile model-server to bring up the danswer-model-server container
|
|
||||||
model_server:
|
|
||||||
image: danswer/danswer-model-server:latest
|
|
||||||
build:
|
|
||||||
context: ../../backend
|
|
||||||
dockerfile: Dockerfile.model_server
|
|
||||||
profiles:
|
|
||||||
- "model-server"
|
|
||||||
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
|
|
||||||
restart: always
|
|
||||||
environment:
|
|
||||||
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
|
|
||||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
|
||||||
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
|
||||||
# Set to debug to get more fine-grained logs
|
|
||||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
|
||||||
volumes:
|
|
||||||
- model_cache_torch:/root/.cache/torch/
|
|
||||||
- model_cache_huggingface:/root/.cache/huggingface/
|
|
||||||
logging:
|
|
||||||
driver: json-file
|
|
||||||
options:
|
|
||||||
max-size: "50m"
|
|
||||||
max-file: "6"
|
|
||||||
volumes:
|
volumes:
|
||||||
local_dynamic_storage:
|
local_dynamic_storage:
|
||||||
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
|
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
|
||||||
|
deployment/kubernetes/env-configmap.yaml:

@ -43,9 +43,9 @@ data:
   ASYM_PASSAGE_PREFIX: ""
   ENABLE_RERANKING_REAL_TIME_FLOW: ""
   ENABLE_RERANKING_ASYNC_FLOW: ""
-  MODEL_SERVER_HOST: ""
+  MODEL_SERVER_HOST: "inference-model-server-service"
   MODEL_SERVER_PORT: ""
-  INDEXING_MODEL_SERVER_HOST: ""
+  INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service"
   MIN_THREADS_ML_MODELS: ""
   # Indexing Configs
   NUM_INDEXING_WORKERS: ""
New Kubernetes manifest (indexing model server):

@ -0,0 +1,59 @@
+apiVersion: v1
+kind: Service
+metadata:
+  name: indexing-model-server-service
+spec:
+  selector:
+    app: indexing-model-server
+  ports:
+  - name: indexing-model-server-port
+    protocol: TCP
+    port: 9000
+    targetPort: 9000
+  type: ClusterIP
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: indexing-model-server-deployment
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: indexing-model-server
+  template:
+    metadata:
+      labels:
+        app: indexing-model-server
+    spec:
+      containers:
+      - name: indexing-model-server
+        image: danswer/danswer-model-server:latest
+        imagePullPolicy: IfNotPresent
+        command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
+        ports:
+        - containerPort: 9000
+        envFrom:
+        - configMapRef:
+            name: env-configmap
+        env:
+        - name: INDEXING_ONLY
+          value: "True"
+        volumeMounts:
+        - name: indexing-model-storage
+          mountPath: /root/.cache
+      volumes:
+      - name: indexing-model-storage
+        persistentVolumeClaim:
+          claimName: indexing-model-pvc
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+  name: indexing-model-pvc
+spec:
+  accessModes:
+  - ReadWriteOnce
+  resources:
+    requests:
+      storage: 3Gi
New Kubernetes manifest (inference model server):

@ -0,0 +1,56 @@
+apiVersion: v1
+kind: Service
+metadata:
+  name: inference-model-server-service
+spec:
+  selector:
+    app: inference-model-server
+  ports:
+  - name: inference-model-server-port
+    protocol: TCP
+    port: 9000
+    targetPort: 9000
+  type: ClusterIP
+---
+apiVersion: apps/v1
+kind: Deployment
+metadata:
+  name: inference-model-server-deployment
+spec:
+  replicas: 1
+  selector:
+    matchLabels:
+      app: inference-model-server
+  template:
+    metadata:
+      labels:
+        app: inference-model-server
+    spec:
+      containers:
+      - name: inference-model-server
+        image: danswer/danswer-model-server:latest
+        imagePullPolicy: IfNotPresent
+        command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
+        ports:
+        - containerPort: 9000
+        envFrom:
+        - configMapRef:
+            name: env-configmap
+        volumeMounts:
+        - name: inference-model-storage
+          mountPath: /root/.cache
+      volumes:
+      - name: inference-model-storage
+        persistentVolumeClaim:
+          claimName: inference-model-pvc
+---
+apiVersion: v1
+kind: PersistentVolumeClaim
+metadata:
+  name: inference-model-pvc
+spec:
+  accessModes:
+  - ReadWriteOnce
+  resources:
+    requests:
+      storage: 3Gi