Always Use Model Server (#1306)

Yuhong Sun 2024-04-07 21:25:06 -07:00 committed by GitHub
parent 795243283d
commit 2db906b7a2
35 changed files with 724 additions and 550 deletions

View File

@@ -20,10 +20,12 @@ jobs:
           cache-dependency-path: |
             backend/requirements/default.txt
             backend/requirements/dev.txt
+            backend/requirements/model_server.txt
       - run: |
           python -m pip install --upgrade pip
           pip install -r backend/requirements/default.txt
          pip install -r backend/requirements/dev.txt
+          pip install -r backend/requirements/model_server.txt
       - name: Run MyPy
         run: |

View File

@@ -85,6 +85,7 @@ Install the required python dependencies:
 ```bash
 pip install -r danswer/backend/requirements/default.txt
 pip install -r danswer/backend/requirements/dev.txt
+pip install -r danswer/backend/requirements/model_server.txt
 ```
 Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
@@ -117,7 +118,19 @@ To start the frontend, navigate to `danswer/web` and run:
 npm run dev
 ```
-The first time running Danswer, you will also need to run the DB migrations for Postgres.
+Next, start the model server which runs the local NLP models.
+Navigate to `danswer/backend` and run:
+```bash
+uvicorn model_server.main:app --reload --port 9000
+```
+_For Windows (for compatibility with both PowerShell and Command Prompt):_
+```bash
+powershell -Command "
+    uvicorn model_server.main:app --reload --port 9000
+"
+```
+The first time running Danswer, you will need to run the DB migrations for Postgres.
 After the first time, this is no longer required unless the DB models change.
 Navigate to `danswer/backend` and with the venv active, run:

View File

@@ -40,7 +40,7 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma
 # Set up application files
 WORKDIR /app
 COPY ./danswer /app/danswer
-COPY ./shared_models /app/shared_models
+COPY ./shared_configs /app/shared_configs
 COPY ./alembic /app/alembic
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /usr/etc/supervisord.conf

View File

@@ -25,11 +25,8 @@ COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
 # Place to fetch version information
 COPY ./danswer/__init__.py /app/danswer/__init__.py
-# Shared implementations for running NLP models locally
-COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
 # Request/Response models
-COPY ./shared_models /app/shared_models
+COPY ./shared_configs /app/shared_configs
 # Model Server main code
 COPY ./model_server /app/model_server

View File

@@ -6,18 +6,15 @@ NOTE: cannot use Celery directly due to
 https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
 from collections.abc import Callable
 from dataclasses import dataclass
+from multiprocessing import Process
 from typing import Any
 from typing import Literal
 from typing import Optional
-from typing import TYPE_CHECKING
 from danswer.utils.logger import setup_logger
 logger = setup_logger()
-if TYPE_CHECKING:
-    from torch.multiprocessing import Process
 JobStatusType = (
     Literal["error"]
     | Literal["finished"]
@@ -89,8 +86,6 @@ class SimpleJobClient:
     def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
         """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
-        from torch.multiprocessing import Process
         self._cleanup_completed_jobs()
         if len(self.jobs) >= self.n_workers:
             logger.debug("No available workers to run job")

View File

@@ -330,20 +330,15 @@ def _run_indexing(
 )
-def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
+def run_indexing_entrypoint(index_attempt_id: int) -> None:
     """Entrypoint for indexing run when using dask distributed.
     Wraps the actual logic in a `try` block so that we can catch any exceptions
     and mark the attempt as failed."""
-    import torch
     try:
         # set the indexing attempt ID so that all log messages from this process
         # will have it added as a prefix
         IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
-        logger.info(f"Setting task to use {num_threads} threads")
-        torch.set_num_threads(num_threads)
         with Session(get_sqlalchemy_engine()) as db_session:
             attempt = get_index_attempt(
                 db_session=db_session, index_attempt_id=index_attempt_id

View File

@@ -15,9 +15,10 @@ from danswer.background.indexing.run_indexing import run_indexing_entrypoint
 from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
 from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
 from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
+from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
 from danswer.configs.app_configs import LOG_LEVEL
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
-from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.db.connector import fetch_connectors
 from danswer.db.connector_credential_pair import get_connector_credential_pairs
 from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
@@ -43,6 +44,7 @@ from danswer.db.models import EmbeddingModel
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
 from danswer.db.models import IndexModelStatus
+from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.utils.logger import setup_logger
 logger = setup_logger()
@@ -56,18 +58,6 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
 )
-"""Util funcs"""
-def _get_num_threads() -> int:
-    """Get # of "threads" to use for ML models in an indexing job. By default uses
-    the torch implementation, which returns the # of physical cores on the machine.
-    """
-    import torch
-    return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
 def _should_create_new_indexing(
     connector: Connector,
     last_index: IndexAttempt | None,
@@ -346,12 +336,10 @@ def kickoff_indexing_jobs(
         if use_secondary_index:
             run = secondary_client.submit(
-                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
+                run_indexing_entrypoint, attempt.id, pure=False
             )
         else:
-            run = client.submit(
-                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
-            )
+            run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
         if run:
             secondary_str = "(secondary index) " if use_secondary_index else ""
@@ -409,6 +397,20 @@ def check_index_swap(db_session: Session) -> None:
 def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
+    engine = get_sqlalchemy_engine()
+    with Session(engine) as db_session:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+        # So that the first time users aren't surprised by really slow speed of first
+        # batch of documents indexed
+        logger.info("Running a first inference to warm up embedding model")
+        warm_up_encoders(
+            model_name=db_embedding_model.model_name,
+            normalize=db_embedding_model.normalize,
+            model_server_host=INDEXING_MODEL_SERVER_HOST,
+            model_server_port=MODEL_SERVER_PORT,
+        )
     client_primary: Client | SimpleJobClient
     client_secondary: Client | SimpleJobClient
     if DASK_JOB_CLIENT_ENABLED:
@@ -435,7 +437,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
         client_secondary = SimpleJobClient(n_workers=num_workers)
     existing_jobs: dict[int, Future | SimpleJob] = {}
-    engine = get_sqlalchemy_engine()
     with Session(engine) as db_session:
         # Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
@@ -472,14 +473,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
 def update__main() -> None:
-    # needed for CUDA to work with multiprocessing
-    # NOTE: needs to be done on application startup
-    # before any other torch code has been run
-    import torch
-    if not DASK_JOB_CLIENT_ENABLED:
-        torch.multiprocessing.set_start_method("spawn")
     logger.info("Starting Indexing Loop")
     update_loop()

View File

@@ -207,15 +207,11 @@ DISABLE_DOCUMENT_CLEANUP = (
 #####
 # Model Server Configs
 #####
-# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
-# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
-MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
+MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
 MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
 MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
-# specify this env variable directly to have a different model server for the background
-# indexing job vs the api server so that background indexing does not effect query-time
-# performance
+# Model server for indexing should use a separate one to not allow indexing to introduce delay
+# for inference
 INDEXING_MODEL_SERVER_HOST = (
     os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
 )

View File

@@ -37,33 +37,13 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
 ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
 # Purely an optimization, memory limitation consideration
 BATCH_SIZE_ENCODE_CHUNKS = 8
-# This controls the minimum number of pytorch "threads" to allocate to the embedding
-# model. If torch finds more threads on its own, this value is not used.
-MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
-# Cross Encoder Settings
-ENABLE_RERANKING_ASYNC_FLOW = (
-    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
-)
-ENABLE_RERANKING_REAL_TIME_FLOW = (
-    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
-)
-# Only using one for now
-CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
-# For score normalizing purposes, only way is to know the expected ranges
+# For score display purposes, only way is to know the expected ranges
 CROSS_ENCODER_RANGE_MAX = 12
 CROSS_ENCODER_RANGE_MIN = -12
-CROSS_EMBED_CONTEXT_SIZE = 512
 # Unused currently, can't be used with the current default encoder model due to its output range
 SEARCH_DISTANCE_CUTOFF = 0
-# Intent model max context size
-QUERY_MAX_CONTEXT_SIZE = 256
-# Danswer custom Deep Learning Models
-INTENT_MODEL_VERSION = "danswer/intent-model"
 #####
 # Generative AI Model Configs

View File

@@ -22,7 +22,6 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
 from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
 from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
 from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.danswerbot.slack.blocks import build_documents_blocks
 from danswer.danswerbot.slack.blocks import build_follow_up_block
 from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -52,6 +51,7 @@ from danswer.search.models import BaseFilters
 from danswer.search.models import OptionalSearchSetting
 from danswer.search.models import RetrievalDetails
 from danswer.utils.logger import setup_logger
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
 logger_base = setup_logger()

View File

@@ -10,10 +10,11 @@ from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
 from sqlalchemy.orm import Session
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.constants import MessageType
 from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
 from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
 from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
 from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
 from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
@@ -43,7 +44,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.one_shot_answer.models import ThreadMessage
-from danswer.search.search_nlp_models import warm_up_models
+from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.manage.models import SlackBotTokens
 from danswer.utils.logger import setup_logger
@@ -390,10 +391,11 @@ if __name__ == "__main__":
     with Session(get_sqlalchemy_engine()) as db_session:
         embedding_model = get_current_db_embedding_model(db_session)
-        warm_up_models(
+        warm_up_encoders(
             model_name=embedding_model.model_name,
             normalize=embedding_model.normalize,
-            skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW,
+            model_server_host=MODEL_SERVER_HOST,
+            model_server_port=MODEL_SERVER_PORT,
         )
     slack_bot_tokens = latest_slack_bot_tokens

View File

@@ -16,8 +16,9 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
+from danswer.search.enums import EmbedTextType
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.search.search_nlp_models import EmbedTextType
+from danswer.utils.batching import batch_list
 from danswer.utils.logger import setup_logger
@@ -73,6 +74,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         title_embed_dict: dict[str, list[float]] = {}
         embedded_chunks: list[IndexChunk] = []
+        # Create Mini Chunks for more precise matching of details
+        # Off by default with unedited settings
         chunk_texts = []
         chunk_mini_chunks_count = {}
         for chunk_ind, chunk in enumerate(chunks):
@@ -85,23 +88,41 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             chunk_texts.extend(mini_chunk_texts)
             chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
-        text_batches = [
-            chunk_texts[i : i + batch_size]
-            for i in range(0, len(chunk_texts), batch_size)
-        ]
+        # Batching for embedding
+        text_batches = batch_list(chunk_texts, batch_size)
         embeddings: list[list[float]] = []
         len_text_batches = len(text_batches)
         for idx, text_batch in enumerate(text_batches, start=1):
-            logger.debug(f"Embedding text batch {idx} of {len_text_batches}")
-            # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
+            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
+            # Normalize embeddings is only configured via model_configs.py, be sure to use right
+            # value for the set loss
             embeddings.extend(
                 self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
             )
-            # Replace line above with the line below for easy debugging of indexing flow, skipping the actual model
+            # Replace line above with the line below for easy debugging of indexing flow
+            # skipping the actual model
             # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
+        chunk_titles = {
+            chunk.source_document.get_title_for_document_index() for chunk in chunks
+        }
+        chunk_titles.discard(None)
+        # Embed Titles in batches
+        title_batches = batch_list(list(chunk_titles), batch_size)
+        len_title_batches = len(title_batches)
+        for ind_batch, title_batch in enumerate(title_batches, start=1):
+            logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
+            title_embeddings = self.embedding_model.encode(
+                title_batch, text_type=EmbedTextType.PASSAGE
+            )
+            title_embed_dict.update(
+                {title: vector for title, vector in zip(title_batch, title_embeddings)}
+            )
+        # Mapping embeddings to chunks
         embedding_ind_start = 0
         for chunk_ind, chunk in enumerate(chunks):
             num_embeddings = chunk_mini_chunks_count[chunk_ind]
@@ -114,9 +135,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
             title_embedding = None
             if title:
                 if title in title_embed_dict:
-                    # Using cached value for speedup
+                    # Using cached value to avoid recalculating for every chunk
                     title_embedding = title_embed_dict[title]
                 else:
+                    logger.error(
+                        "Title had to be embedded separately, this should not happen!"
+                    )
                     title_embedding = self.embedding_model.encode(
                         [title], text_type=EmbedTextType.PASSAGE
                     )[0]

View File

@@ -1,10 +1,10 @@
+import time
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
 from typing import Any
 from typing import cast
 import nltk  # type:ignore
-import torch  # Import here is fine, API server needs torch anyway and nothing imports main.py
 import uvicorn
 from fastapi import APIRouter
 from fastapi import FastAPI
@@ -36,7 +36,6 @@ from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import WEB_DOMAIN
 from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.constants import AuthType
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.db.chat import delete_old_default_personas
@@ -54,7 +53,7 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
 from danswer.llm.factory import get_default_llm
 from danswer.llm.utils import get_default_llm_version
-from danswer.search.search_nlp_models import warm_up_models
+from danswer.search.search_nlp_models import warm_up_encoders
 from danswer.server.danswer_api.ingestion import get_danswer_api_key
 from danswer.server.danswer_api.ingestion import router as danswer_api_router
 from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -82,6 +81,7 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.telemetry import optional_telemetry
 from danswer.utils.telemetry import RecordType
 from danswer.utils.variable_functionality import fetch_versioned_implementation
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 logger = setup_logger()
@@ -204,24 +204,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     if ENABLE_RERANKING_REAL_TIME_FLOW:
         logger.info("Reranking step of search flow is enabled.")
-    if MODEL_SERVER_HOST:
-        logger.info(
-            f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
-        )
-    else:
-        logger.info("Warming up local NLP models.")
-        warm_up_models(
-            model_name=db_embedding_model.model_name,
-            normalize=db_embedding_model.normalize,
-            skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
-        )
-        if torch.cuda.is_available():
-            logger.info("GPU is available")
-        else:
-            logger.info("GPU is not available")
-        logger.info(f"Torch Threads: {torch.get_num_threads()}")
     logger.info("Verifying query preprocessing (NLTK) data is downloaded")
     nltk.download("stopwords", quiet=True)
     nltk.download("wordnet", quiet=True)
@@ -237,19 +219,34 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
     load_chat_yamls()
     logger.info("Verifying Document Index(s) is/are available.")
     document_index = get_default_document_index(
         primary_index_name=db_embedding_model.index_name,
         secondary_index_name=secondary_db_embedding_model.index_name
         if secondary_db_embedding_model
         else None,
     )
-    document_index.ensure_indices_exist(
-        index_embedding_dim=db_embedding_model.model_dim,
-        secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-        if secondary_db_embedding_model
-        else None,
-    )
+    # Vespa startup is a bit slow, so give it a few seconds
+    wait_time = 5
+    for attempt in range(5):
+        try:
+            document_index.ensure_indices_exist(
+                index_embedding_dim=db_embedding_model.model_dim,
+                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+                if secondary_db_embedding_model
+                else None,
+            )
+            break
+        except Exception:
+            logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
+            time.sleep(wait_time)
+    logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
+    warm_up_encoders(
+        model_name=db_embedding_model.model_name,
+        normalize=db_embedding_model.normalize,
+        model_server_host=MODEL_SERVER_HOST,
+        model_server_port=MODEL_SERVER_PORT,
+    )
     optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})

View File

@@ -28,3 +28,8 @@ class SearchType(str, Enum):
 class QueryFlow(str, Enum):
     SEARCH = "search"
     QUESTION_ANSWER = "question-answer"
+
+
+class EmbedTextType(str, Enum):
+    QUERY = "query"
+    PASSAGE = "passage"
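With the enum now living in `danswer/search/enums.py`, callers tag every encode call as a query or a passage so the right prefix is applied by `EmbeddingModel`. A minimal sketch of the intended usage follows; the model name, prefixes, and texts are illustrative and not mandated by this commit:

```python
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel

# Illustrative values; in Danswer these come from the current DB embedding model
model = EmbeddingModel(
    model_name="intfloat/e5-base-v2",
    query_prefix="query: ",
    passage_prefix="passage: ",
    normalize=True,
    server_host="localhost",
    server_port=9000,
)

# QUERY applies the query prefix, PASSAGE applies the passage prefix
query_vectors = model.encode(["how do I set up danswer?"], text_type=EmbedTextType.QUERY)
passage_vectors = model.encode(["Danswer setup guide ..."], text_type=EmbedTextType.PASSAGE)
```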

View File

@@ -8,10 +8,10 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
 from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
 from danswer.configs.constants import DocumentSource
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.db.models import Persona
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.enums import SearchType
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 MAX_METRICS_CONTENT = (

View File

@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
 from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
 from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
 from danswer.configs.chat_configs import NUM_RETURNED_HITS
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 from danswer.db.models import User
 from danswer.search.enums import QueryFlow
 from danswer.search.enums import RecencyBiasSetting
@@ -22,6 +21,7 @@ from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import FunctionCall
 from danswer.utils.threadpool_concurrency import run_functions_in_parallel
 from danswer.utils.timing import log_function_time
+from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
 logger = setup_logger()

View File

@@ -14,6 +14,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.db.embedding_model import get_current_db_embedding_model
 from danswer.document_index.interfaces import DocumentIndex
 from danswer.indexing.models import InferenceChunk
+from danswer.search.enums import EmbedTextType
 from danswer.search.models import ChunkMetric
 from danswer.search.models import IndexFilters
 from danswer.search.models import MAX_METRICS_CONTENT
@@ -21,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SearchQuery
 from danswer.search.models import SearchType
 from danswer.search.search_nlp_models import EmbeddingModel
-from danswer.search.search_nlp_models import EmbedTextType
 from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
 from danswer.utils.logger import setup_logger
 from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@@ -1,56 +1,38 @@
 import gc
 import os
-from enum import Enum
+import time
 from typing import Optional
 from typing import TYPE_CHECKING
-import numpy as np
 import requests
 from transformers import logging as transformer_logging  # type:ignore
 from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
-from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
-from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
-from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
-from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
-from danswer.configs.model_configs import INTENT_MODEL_VERSION
-from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
+from danswer.search.enums import EmbedTextType
 from danswer.utils.logger import setup_logger
-from shared_models.model_server_models import EmbedRequest
-from shared_models.model_server_models import EmbedResponse
-from shared_models.model_server_models import IntentRequest
-from shared_models.model_server_models import IntentResponse
-from shared_models.model_server_models import RerankRequest
-from shared_models.model_server_models import RerankResponse
+from shared_configs.model_server_models import EmbedRequest
+from shared_configs.model_server_models import EmbedResponse
+from shared_configs.model_server_models import IntentRequest
+from shared_configs.model_server_models import IntentResponse
+from shared_configs.model_server_models import RerankRequest
+from shared_configs.model_server_models import RerankResponse
-transformer_logging.set_verbosity_error()
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
 logger = setup_logger()
+transformer_logging.set_verbosity_error()
 if TYPE_CHECKING:
-    from sentence_transformers import CrossEncoder  # type: ignore
-    from sentence_transformers import SentenceTransformer  # type: ignore
     from transformers import AutoTokenizer  # type: ignore
-    from transformers import TFDistilBertForSequenceClassification  # type: ignore
 _TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
-_EMBED_MODEL: tuple[Optional["SentenceTransformer"], str | None] = (None, None)
-_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
-_INTENT_TOKENIZER: Optional["AutoTokenizer"] = None
-_INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None
-class EmbedTextType(str, Enum):
-    QUERY = "query"
-    PASSAGE = "passage"
 def clean_model_name(model_str: str) -> str:
@@ -84,89 +66,10 @@ def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
     return _TOKENIZER[0]
-def get_local_embedding_model(
-    model_name: str,
-    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
-) -> "SentenceTransformer":
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from sentence_transformers import SentenceTransformer  # type: ignore
-    global _EMBED_MODEL
-    if (
-        _EMBED_MODEL[0] is None
-        or max_context_length != _EMBED_MODEL[0].max_seq_length
-        or model_name != _EMBED_MODEL[1]
-    ):
-        if _EMBED_MODEL[0] is not None:
-            del _EMBED_MODEL
-            gc.collect()
-        logger.info(f"Loading {model_name}")
-        _EMBED_MODEL = (SentenceTransformer(model_name), model_name)
-        _EMBED_MODEL[0].max_seq_length = max_context_length
-    return _EMBED_MODEL[0]
-def get_local_reranking_model_ensemble(
-    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
-    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
-) -> list["CrossEncoder"]:
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from sentence_transformers import CrossEncoder
-    global _RERANK_MODELS
-    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
-        del _RERANK_MODELS
-        gc.collect()
-        _RERANK_MODELS = []
-        for model_name in model_names:
-            logger.info(f"Loading {model_name}")
-            model = CrossEncoder(model_name)
-            model.max_length = max_context_length
-            _RERANK_MODELS.append(model)
-    return _RERANK_MODELS
-def get_intent_model_tokenizer(
-    model_name: str = INTENT_MODEL_VERSION,
-) -> "AutoTokenizer":
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from transformers import AutoTokenizer  # type: ignore
-    global _INTENT_TOKENIZER
-    if _INTENT_TOKENIZER is None:
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
-    return _INTENT_TOKENIZER
-def get_local_intent_model(
-    model_name: str = INTENT_MODEL_VERSION,
-    max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
-) -> "TFDistilBertForSequenceClassification":
-    # NOTE: doing a local import here to avoid reduce memory usage caused by
-    # processes importing this file despite not using any of this
-    from transformers import TFDistilBertForSequenceClassification  # type: ignore
-    global _INTENT_MODEL
-    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
-        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
-            model_name
-        )
-        _INTENT_MODEL.max_seq_length = max_context_length
-    return _INTENT_MODEL
 def build_model_server_url(
-    model_server_host: str | None,
-    model_server_port: int | None,
-) -> str | None:
-    if not model_server_host or model_server_port is None:
-        return None
+    model_server_host: str,
+    model_server_port: int,
+) -> str:
     model_server_url = f"{model_server_host}:{model_server_port}"
     # use protocol if provided
@@ -184,8 +87,8 @@ class EmbeddingModel:
         query_prefix: str | None,
         passage_prefix: str | None,
         normalize: bool,
-        server_host: str | None,  # Changes depending on indexing or inference
-        server_port: int | None,
+        server_host: str,  # Changes depending on indexing or inference
+        server_port: int,
         # The following are globals are currently not configurable
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
     ) -> None:
@@ -196,17 +99,7 @@ class EmbeddingModel:
         self.normalize = normalize
         model_server_url = build_model_server_url(server_host, server_port)
-        self.embed_server_endpoint = (
-            f"{model_server_url}/encoder/bi-encoder-embed" if model_server_url else None
-        )
-    def load_model(self) -> Optional["SentenceTransformer"]:
-        if self.embed_server_endpoint:
-            return None
-        return get_local_embedding_model(
-            model_name=self.model_name, max_context_length=self.max_seq_length
-        )
+        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
     def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
         if text_type == EmbedTextType.QUERY and self.query_prefix:
@@ -216,166 +109,67 @@ class EmbeddingModel:
         else:
             prefixed_texts = texts
-        if self.embed_server_endpoint:
-            embed_request = EmbedRequest(
-                texts=prefixed_texts,
-                model_name=self.model_name,
-                normalize_embeddings=self.normalize,
-            )
-            try:
-                response = requests.post(
-                    self.embed_server_endpoint, json=embed_request.dict()
-                )
-                response.raise_for_status()
-                return EmbedResponse(**response.json()).embeddings
-            except requests.RequestException as e:
-                logger.exception(f"Failed to get Embedding: {e}")
-                raise
-        local_model = self.load_model()
-        if local_model is None:
-            raise RuntimeError("Failed to load local Embedding Model")
-        return local_model.encode(
-            prefixed_texts, normalize_embeddings=self.normalize
-        ).tolist()
+        embed_request = EmbedRequest(
+            texts=prefixed_texts,
+            model_name=self.model_name,
+            max_context_length=self.max_seq_length,
+            normalize_embeddings=self.normalize,
+        )
+        response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
+        response.raise_for_status()
+        return EmbedResponse(**response.json()).embeddings
 class CrossEncoderEnsembleModel:
     def __init__(
         self,
-        model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
-        max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
-        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_host: str = MODEL_SERVER_HOST,
         model_server_port: int = MODEL_SERVER_PORT,
     ) -> None:
-        self.model_names = model_names
-        self.max_seq_length = max_seq_length
         model_server_url = build_model_server_url(model_server_host, model_server_port)
-        self.rerank_server_endpoint = (
-            model_server_url + "/encoder/cross-encoder-scores"
-            if model_server_url
-            else None
-        )
-    def load_model(self) -> list["CrossEncoder"] | None:
-        if (
-            ENABLE_RERANKING_REAL_TIME_FLOW is False
-            and ENABLE_RERANKING_ASYNC_FLOW is False
-        ):
-            logger.warning(
-                "Running rerankers but they are globally disabled."
-                "Was this specified explicitly via an API?"
-            )
-        if self.rerank_server_endpoint:
-            return None
-        return get_local_reranking_model_ensemble(
-            model_names=self.model_names, max_context_length=self.max_seq_length
-        )
+        self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
     def predict(self, query: str, passages: list[str]) -> list[list[float]]:
-        if self.rerank_server_endpoint:
-            rerank_request = RerankRequest(query=query, documents=passages)
-            try:
-                response = requests.post(
-                    self.rerank_server_endpoint, json=rerank_request.dict()
-                )
-                response.raise_for_status()
-                return RerankResponse(**response.json()).scores
-            except requests.RequestException as e:
-                logger.exception(f"Failed to get Reranking Scores: {e}")
-                raise
-        local_models = self.load_model()
-        if local_models is None:
-            raise RuntimeError("Failed to load local Reranking Model Ensemble")
-        scores = [
-            cross_encoder.predict([(query, passage) for passage in passages]).tolist()  # type: ignore
-            for cross_encoder in local_models
-        ]
-        return scores
+        rerank_request = RerankRequest(query=query, documents=passages)
+        response = requests.post(
+            self.rerank_server_endpoint, json=rerank_request.dict()
+        )
+        response.raise_for_status()
+        return RerankResponse(**response.json()).scores
 class IntentModel:
     def __init__(
         self,
-        model_name: str = INTENT_MODEL_VERSION,
-        max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
-        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_host: str = MODEL_SERVER_HOST,
         model_server_port: int = MODEL_SERVER_PORT,
     ) -> None:
-        self.model_name = model_name
-        self.max_seq_length = max_seq_length
         model_server_url = build_model_server_url(model_server_host, model_server_port)
-        self.intent_server_endpoint = (
-            model_server_url + "/custom/intent-model" if model_server_url else None
-        )
-    def load_model(self) -> Optional["SentenceTransformer"]:
-        if self.intent_server_endpoint:
-            return None
-        return get_local_intent_model(
-            model_name=self.model_name, max_context_length=self.max_seq_length
-        )
+        self.intent_server_endpoint = model_server_url + "/custom/intent-model"
     def predict(
         self,
         query: str,
     ) -> list[float]:
-        # NOTE: doing a local import here to avoid reduce memory usage caused by
-        # processes importing this file despite not using any of this
-        import tensorflow as tf  # type: ignore
-        if self.intent_server_endpoint:
-            intent_request = IntentRequest(query=query)
-            try:
-                response = requests.post(
-                    self.intent_server_endpoint, json=intent_request.dict()
-                )
-                response.raise_for_status()
-                return IntentResponse(**response.json()).class_probs
-            except requests.RequestException as e:
-                logger.exception(f"Failed to get Embedding: {e}")
-                raise
-        tokenizer = get_intent_model_tokenizer()
-        local_model = self.load_model()
-        if local_model is None:
-            raise RuntimeError("Failed to load local Intent Model")
-        intent_model = get_local_intent_model()
-        model_input = tokenizer(
-            query, return_tensors="tf", truncation=True, padding=True
-        )
-        predictions = intent_model(model_input)[0]
-        probabilities = tf.nn.softmax(predictions, axis=-1)
-        class_percentages = np.round(probabilities.numpy() * 100, 2)
-        return list(class_percentages.tolist()[0])
+        intent_request = IntentRequest(query=query)
+        response = requests.post(
+            self.intent_server_endpoint, json=intent_request.dict()
+        )
+        response.raise_for_status()
+        return IntentResponse(**response.json()).class_probs
-def warm_up_models(
+def warm_up_encoders(
     model_name: str,
     normalize: bool,
-    skip_cross_encoders: bool = True,
-    indexer_only: bool = False,
+    model_server_host: str = MODEL_SERVER_HOST,
+    model_server_port: int = MODEL_SERVER_PORT,
 ) -> None:
     warm_up_str = (
         "Danswer is amazing! Check out our easy deployment guide at "
@@ -387,23 +181,23 @@ def warm_up_models(
     embed_model = EmbeddingModel(
         model_name=model_name,
         normalize=normalize,
-        # These don't matter, if it's a remote model, this function shouldn't be called
+        # Not a big deal if prefix is incorrect
         query_prefix=None,
         passage_prefix=None,
-        server_host=None,
-        server_port=None,
+        server_host=model_server_host,
+        server_port=model_server_port,
     )
-    embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
-    if indexer_only:
-        return
-    if not skip_cross_encoders:
-        CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        warm_up_str, return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
+    # First time downloading the models it may take even longer, but just in case,
+    # retry the whole server
+    wait_time = 5
+    for attempt in range(20):
+        try:
+            embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
+            return
+        except Exception:
+            logger.info(
+                f"Failed to run test embedding, retrying in {wait_time} seconds..."
+            )
+            time.sleep(wait_time)
+    raise Exception("Failed to run test embedding.")

View File

@@ -21,3 +21,10 @@ def batch_generator(
         if pre_batch_yield:
             pre_batch_yield(batch)
         yield batch
+
+
+def batch_list(
+    lst: list[T],
+    batch_size: int,
+) -> list[list[T]]:
+    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
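A quick sketch of how the new helper behaves, with made-up example values:

```python
texts = ["chunk 1", "chunk 2", "chunk 3", "chunk 4", "chunk 5"]

# batch_list splits a list into consecutive sublists of at most batch_size items
for batch in batch_list(texts, batch_size=2):
    print(batch)
# ['chunk 1', 'chunk 2']
# ['chunk 3', 'chunk 4']
# ['chunk 5']
```

This is the same slicing the indexing embedder previously did inline; pulling it into `danswer/utils/batching.py` lets both the content-text and title embedding loops reuse it.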

View File

@@ -0,0 +1 @@
+MODEL_WARM_UP_STRING = "hi " * 512

View File

@@ -1,19 +1,58 @@
-import numpy as np
-from fastapi import APIRouter
+from typing import Optional
+
+import numpy as np
+import tensorflow as tf  # type: ignore
+from fastapi import APIRouter
+from transformers import AutoTokenizer  # type: ignore
+from transformers import TFDistilBertForSequenceClassification
+
+from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.utils import simple_log_function_time
+from shared_configs.model_server_models import IntentRequest
+from shared_configs.model_server_models import IntentResponse
+from shared_configs.nlp_model_configs import INDEXING_ONLY
+from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
+from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION
-from danswer.search.search_nlp_models import get_intent_model_tokenizer
-from danswer.search.search_nlp_models import get_local_intent_model
-from danswer.utils.timing import log_function_time
-from shared_models.model_server_models import IntentRequest
-from shared_models.model_server_models import IntentResponse
 router = APIRouter(prefix="/custom")
+
+_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
+_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None
+
+
+def get_intent_model_tokenizer(
+    model_name: str = INTENT_MODEL_VERSION,
+) -> "AutoTokenizer":
+    global _INTENT_TOKENIZER
+    if _INTENT_TOKENIZER is None:
+        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
+    return _INTENT_TOKENIZER
+
+
+def get_local_intent_model(
+    model_name: str = INTENT_MODEL_VERSION,
+    max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
+) -> TFDistilBertForSequenceClassification:
+    global _INTENT_MODEL
+    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
+        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
+            model_name
+        )
+        _INTENT_MODEL.max_seq_length = max_context_length
+    return _INTENT_MODEL
+
+
+def warm_up_intent_model() -> None:
+    intent_tokenizer = get_intent_model_tokenizer()
+    inputs = intent_tokenizer(
+        MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
+    )
+    get_local_intent_model()(inputs)
+
+
-@log_function_time(print_only=True)
+@simple_log_function_time()
 def classify_intent(query: str) -> list[float]:
-    import tensorflow as tf  # type:ignore
     tokenizer = get_intent_model_tokenizer()
     intent_model = get_local_intent_model()
     model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
@@ -26,16 +65,11 @@ def classify_intent(query: str) -> list[float]:
 @router.post("/intent-model")
-def process_intent_request(
+async def process_intent_request(
     intent_request: IntentRequest,
 ) -> IntentResponse:
+    if INDEXING_ONLY:
+        raise RuntimeError("Indexing model server should not call intent endpoint")
     class_percentages = classify_intent(intent_request.query)
     return IntentResponse(class_probs=class_percentages)
-def warm_up_intent_model() -> None:
-    intent_tokenizer = get_intent_model_tokenizer()
-    inputs = intent_tokenizer(
-        "danswer", return_tensors="tf", truncation=True, padding=True
-    )
-    get_local_intent_model()(inputs)
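On the API-server side, `IntentModel.predict` reaches this endpoint over HTTP. A minimal sketch of the same call made directly, assuming a model server running locally on the default port 9000 (the query text is made up):

```python
import requests

from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse

# Illustrative direct call to the intent endpoint of the model server
intent_request = IntentRequest(query="What is the PTO policy?")
response = requests.post(
    "http://localhost:9000/custom/intent-model", json=intent_request.dict()
)
response.raise_for_status()
class_probs = IntentResponse(**response.json()).class_probs
```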

View File

@@ -1,34 +1,33 @@
-from typing import TYPE_CHECKING
+import gc
+from typing import Optional
 from fastapi import APIRouter
 from fastapi import HTTPException
+from sentence_transformers import CrossEncoder  # type: ignore
+from sentence_transformers import SentenceTransformer  # type: ignore
-from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
-from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
-from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
 from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_function_time
-from shared_models.model_server_models import EmbedRequest
-from shared_models.model_server_models import EmbedResponse
-from shared_models.model_server_models import RerankRequest
-from shared_models.model_server_models import RerankResponse
+from model_server.constants import MODEL_WARM_UP_STRING
+from model_server.utils import simple_log_function_time
+from shared_configs.model_server_models import EmbedRequest
+from shared_configs.model_server_models import EmbedResponse
+from shared_configs.model_server_models import RerankRequest
+from shared_configs.model_server_models import RerankResponse
+from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE
+from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
+from shared_configs.nlp_model_configs import INDEXING_ONLY
-if TYPE_CHECKING:
-    from sentence_transformers import SentenceTransformer  # type: ignore
 logger = setup_logger()
-WARM_UP_STRING = "Danswer is amazing"
 router = APIRouter(prefix="/encoder")
 _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
+_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
 def get_embedding_model(
     model_name: str,
-    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+    max_context_length: int,
 ) -> "SentenceTransformer":
     from sentence_transformers import SentenceTransformer  # type: ignore
@@ -48,11 +47,44 @@ def get_embedding_model(
     return _GLOBAL_MODELS_DICT[model_name]
-@log_function_time(print_only=True)
+def get_local_reranking_model_ensemble(
+    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
+) -> list[CrossEncoder]:
+    global _RERANK_MODELS
+    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
+        del _RERANK_MODELS
+        gc.collect()
+        _RERANK_MODELS = []
+        for model_name in model_names:
+            logger.info(f"Loading {model_name}")
+            model = CrossEncoder(model_name)
+            model.max_length = max_context_length
+            _RERANK_MODELS.append(model)
+    return _RERANK_MODELS
+
+
+def warm_up_cross_encoders() -> None:
+    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
+    cross_encoders = get_local_reranking_model_ensemble()
+    [
+        cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
+        for cross_encoder in cross_encoders
+    ]
+
+
+@simple_log_function_time()
 def embed_text(
-    texts: list[str], model_name: str, normalize_embeddings: bool
+    texts: list[str],
+    model_name: str,
+    max_context_length: int,
+    normalize_embeddings: bool,
 ) -> list[list[float]]:
-    model = get_embedding_model(model_name=model_name)
+    model = get_embedding_model(
+        model_name=model_name, max_context_length=max_context_length
+    )
     embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
     if not isinstance(embeddings, list):
@@ -61,7 +93,7 @@ def embed_text(
     return embeddings
-@log_function_time(print_only=True)
+@simple_log_function_time()
 def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
     cross_encoders = get_local_reranking_model_ensemble()
     sim_scores = [
@@ -72,13 +104,14 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
 @router.post("/bi-encoder-embed")
-def process_embed_request(
+async def process_embed_request(
     embed_request: EmbedRequest,
 ) -> EmbedResponse:
     try:
         embeddings = embed_text(
             texts=embed_request.texts,
             model_name=embed_request.model_name,
+            max_context_length=embed_request.max_context_length,
             normalize_embeddings=embed_request.normalize_embeddings,
         )
         return EmbedResponse(embeddings=embeddings)
@ -87,7 +120,11 @@ def process_embed_request(
@router.post("/cross-encoder-scores") @router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse: async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call cross-encoder endpoint")
try: try:
sim_scores = calc_sim_scores( sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents query=embed_request.query, docs=embed_request.documents
@ -95,13 +132,3 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
return RerankResponse(scores=sim_scores) return RerankResponse(scores=sim_scores)
except Exception as e: except Exception as e:
raise HTTPException(status_code=500, detail=str(e)) raise HTTPException(status_code=500, detail=str(e))
def warm_up_cross_encoders() -> None:
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
cross_encoders = get_local_reranking_model_ensemble()
[
cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
for cross_encoder in cross_encoders
]
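With embedding and reranking now handled entirely by the model server, the encoder router exposes two HTTP endpoints. For reference, here is a minimal client sketch (not part of this commit) that exercises them against a locally running model server on port 9000; the model name is only an example and should match your configured embedding model.

```python
# Hypothetical client for the model server's encoder endpoints (sketch only).
import requests

MODEL_SERVER_URL = "http://localhost:9000"  # local dev; inside Docker use the service name


def embed(texts: list[str]) -> list[list[float]]:
    # Mirrors the EmbedRequest schema: texts, model_name, max_context_length,
    # normalize_embeddings. The model name below is just an example.
    resp = requests.post(
        f"{MODEL_SERVER_URL}/encoder/bi-encoder-embed",
        json={
            "texts": texts,
            "model_name": "intfloat/e5-base-v2",
            "max_context_length": 512,
            "normalize_embeddings": True,
        },
    )
    resp.raise_for_status()
    return resp.json()["embeddings"]


def rerank(query: str, documents: list[str]) -> list[list[float]]:
    # Mirrors the RerankRequest schema; returns one score list per cross-encoder.
    resp = requests.post(
        f"{MODEL_SERVER_URL}/encoder/cross-encoder-scores",
        json={"query": query, "documents": documents},
    )
    resp.raise_for_status()
    return resp.json()["scores"]


if __name__ == "__main__":
    print(embed(["Danswer is amazing"])[0][:5])
    print(rerank("What is Danswer?", ["Danswer is an open source question answering tool."]))
```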

View File

@ -1,40 +1,61 @@
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import torch import torch
import uvicorn import uvicorn
from fastapi import FastAPI from fastapi import FastAPI
from transformers import logging as transformer_logging # type:ignore
from danswer import __version__ from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_cross_encoders from model_server.encoders import warm_up_cross_encoders
from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.nlp_model_configs import INDEXING_ONLY
from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
transformer_logging.set_verbosity_error()
logger = setup_logger() logger = setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
warm_up_intent_model()
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
warm_up_cross_encoders()
else:
logger.info("This model server should only run document indexing.")
yield
def get_model_app() -> FastAPI: def get_model_app() -> FastAPI:
application = FastAPI(title="Danswer Model Server", version=__version__) application = FastAPI(
title="Danswer Model Server", version=__version__, lifespan=lifespan
)
application.include_router(encoders_router) application.include_router(encoders_router)
application.include_router(custom_models_router) application.include_router(custom_models_router)
@application.on_event("startup")
def startup_event() -> None:
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
warm_up_cross_encoders()
warm_up_intent_model()
return application return application
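The startup work moves from the deprecated `@application.on_event("startup")` hook into a FastAPI lifespan context manager. A minimal, standalone sketch of the pattern (illustrative only, independent of Danswer's code):

```python
# Minimal illustration of the FastAPI lifespan pattern used above.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    # Everything before `yield` runs once at startup (model warm-ups, thread tuning).
    print("starting up: warm up models here")
    yield
    # Everything after `yield` runs once at shutdown.
    print("shutting down")


app = FastAPI(lifespan=lifespan)
```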

View File

@ -0,0 +1,41 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
def simple_log_function_time(
func_name: str | None = None,
debug_only: bool = False,
include_args: bool = False,
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.info(final_log)
return result
return cast(F, wrapped_func)
return decorator
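Typical usage of the timing decorator added above (illustrative only; assumes `simple_log_function_time` from this new utility module is in scope):

```python
import time


# `simple_log_function_time` is the decorator defined in the file above.
@simple_log_function_time(include_args=True)
def slow_add(a: int, b: int) -> int:
    time.sleep(0.1)
    return a + b


slow_add(1, 2)
# logs roughly: "slow_add args=(1, 2) kwargs={} took 0.10... seconds"
```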

View File

@ -54,19 +54,12 @@ requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0 rfc3986==1.5.0
rt==3.1.2 rt==3.1.2
# need to pin `safetensors` version, since the latest versions require
# building from source using Rust
safetensors==0.4.2
sentence-transformers==2.6.1
slack-sdk==3.20.2 slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15 SQLAlchemy[mypy]==2.0.15
starlette==0.36.3 starlette==0.36.3
supervisor==4.2.5 supervisor==4.2.5
tensorflow==2.15.0
tiktoken==0.4.0 tiktoken==0.4.0
timeago==1.0.16 timeago==1.0.16
torch==2.0.1
torchvision==0.15.2
transformers==4.39.2 transformers==4.39.2
uvicorn==0.21.1 uvicorn==0.21.1
zulip==0.8.2 zulip==0.8.2

View File

@ -1,8 +1,8 @@
fastapi==0.109.1 fastapi==0.109.2
pydantic==1.10.7 pydantic==1.10.7
safetensors==0.3.1 safetensors==0.4.2
sentence-transformers==2.2.2 sentence-transformers==2.6.1
tensorflow==2.15.0 tensorflow==2.15.0
torch==2.0.1 torch==2.0.1
transformers==4.36.2 transformers==4.39.2
uvicorn==0.21.1 uvicorn==0.21.1
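To sanity-check that the pinned model-server dependencies resolved after installation, an optional version dump can be run inside the venv (package names as they appear above):

```python
# Optional post-install check of the pinned model-server packages.
from importlib.metadata import version

for pkg in ("fastapi", "pydantic", "safetensors", "sentence-transformers",
            "tensorflow", "torch", "transformers", "uvicorn"):
    print(f"{pkg}=={version(pkg)}")
```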

View File

@ -2,8 +2,10 @@ from pydantic import BaseModel
class EmbedRequest(BaseModel): class EmbedRequest(BaseModel):
# This already includes any prefixes, the text is just passed directly to the model
texts: list[str] texts: list[str]
model_name: str model_name: str
max_context_length: int
normalize_embeddings: bool normalize_embeddings: bool
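The other request/response models used by the encoder endpoints are not part of this hunk. Inferred from how they are used above, they look roughly like the following sketch (field names assumed from this commit, treat as a sketch rather than the definitive definitions):

```python
# Sketch of the companion models (inferred from usage; not shown in this hunk).
from pydantic import BaseModel


class EmbedResponse(BaseModel):
    embeddings: list[list[float]]


class RerankRequest(BaseModel):
    query: str
    documents: list[str]


class RerankResponse(BaseModel):
    # One list of scores per cross-encoder in the ensemble
    scores: list[list[float]]
```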

View File

@ -0,0 +1,26 @@
import os
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
INTENT_MODEL_CONTEXT_SIZE = 256
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512
# Cross Encoder Settings
ENABLE_RERANKING_ASYNC_FLOW = (
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# Only using one cross-encoder for now
CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
CROSS_EMBED_CONTEXT_SIZE = 512
# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.
MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"
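These flags are plain environment variables parsed with `.lower() == "true"`, so only a case-insensitive "true" enables them. A small illustration:

```python
import os

# Only the literal string "true" (any casing) turns these flags on.
os.environ["ENABLE_RERANKING_ASYNC_FLOW"] = "True"
assert os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"

# Values like "1" or "yes" leave the flag disabled.
os.environ["ENABLE_RERANKING_ASYNC_FLOW"] = "1"
assert os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() != "true"
```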

View File

@ -67,7 +67,7 @@ services:
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-} - ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
- ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-} - ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-} - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected! # Leave this on pretty please? Nothing sensitive is collected!
# https://docs.danswer.dev/more/telemetry # https://docs.danswer.dev/more/telemetry
@ -80,9 +80,7 @@ services:
volumes: volumes:
- local_dynamic_storage:/home/storage - local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage - file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/ - model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"
logging: logging:
@ -90,6 +88,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
background: background:
image: danswer/danswer-backend:latest image: danswer/danswer-backend:latest
build: build:
@ -137,10 +137,9 @@ services:
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-} - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-} - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-} - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-} - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-} - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Indexing Configs # Indexing Configs
- NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-} - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
- DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-} - DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-}
@ -174,9 +173,7 @@ services:
volumes: volumes:
- local_dynamic_storage:/home/storage - local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage - file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/ - model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"
logging: logging:
@ -184,6 +181,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
web_server: web_server:
image: danswer/danswer-web-server:latest image: danswer/danswer-web-server:latest
build: build:
@ -198,6 +197,63 @@ services:
environment: environment:
- INTERNAL_URL=http://api_server:8080 - INTERNAL_URL=http://api_server:8080
- WEB_DOMAIN=${WEB_DOMAIN:-} - WEB_DOMAIN=${WEB_DOMAIN:-}
inference_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
relational_db: relational_db:
image: postgres:15.2-alpine image: postgres:15.2-alpine
restart: always restart: always
@ -208,6 +264,8 @@ services:
- "5432:5432" - "5432:5432"
volumes: volumes:
- db_volume:/var/lib/postgresql/data - db_volume:/var/lib/postgresql/data
# This container name cannot have an underscore in it due to Vespa expectations of the URL # This container name cannot have an underscore in it due to Vespa expectations of the URL
index: index:
image: vespaengine/vespa:8.277.17 image: vespaengine/vespa:8.277.17
@ -222,6 +280,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
nginx: nginx:
image: nginx:1.23.4-alpine image: nginx:1.23.4-alpine
restart: always restart: always
@ -250,32 +310,8 @@ services:
command: > command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev" && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"
# Run with --profile model-server to bring up the danswer-model-server container
# Be sure to change MODEL_SERVER_HOST (see above) as well
# ie. MODEL_SERVER_HOST="model_server" docker compose -f docker-compose.dev.yml -p danswer-stack --profile model-server up -d --build
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
volumes: volumes:
local_dynamic_storage: local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
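Both new model-server services listen on port 9000 inside the Compose network under their service names. A hypothetical connectivity check, runnable from another container in the same network (for example `api_server`):

```python
# Hypothetical smoke test from inside the Compose network: service names and
# port come from the compose file above.
import socket

for host in ("inference_model_server", "indexing_model_server"):
    with socket.create_connection((host, 9000), timeout=5):
        print(f"{host}:9000 is reachable")
```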

View File

@ -22,9 +22,7 @@ services:
volumes: volumes:
- local_dynamic_storage:/home/storage - local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage - file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/ - model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"
logging: logging:
@ -32,6 +30,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
background: background:
image: danswer/danswer-backend:latest image: danswer/danswer-backend:latest
build: build:
@ -51,9 +51,7 @@ services:
volumes: volumes:
- local_dynamic_storage:/home/storage - local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage - file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/ - model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"
logging: logging:
@ -61,6 +59,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
web_server: web_server:
image: danswer/danswer-web-server:latest image: danswer/danswer-web-server:latest
build: build:
@ -81,6 +81,63 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
inference_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
relational_db: relational_db:
image: postgres:15.2-alpine image: postgres:15.2-alpine
restart: always restart: always
@ -94,6 +151,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
# This container name cannot have an underscore in it due to Vespa expectations of the URL # This container name cannot have an underscore in it due to Vespa expectations of the URL
index: index:
image: vespaengine/vespa:8.277.17 image: vespaengine/vespa:8.277.17
@ -108,6 +167,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
nginx: nginx:
image: nginx:1.23.4-alpine image: nginx:1.23.4-alpine
restart: always restart: always
@ -137,30 +198,8 @@ services:
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.no-letsencrypt" && /etc/nginx/conf.d/run-nginx.sh app.conf.template.no-letsencrypt"
env_file: env_file:
- .env.nginx - .env.nginx
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
volumes: volumes:
local_dynamic_storage: local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them

View File

@ -22,9 +22,7 @@ services:
volumes: volumes:
- local_dynamic_storage:/home/storage - local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage - file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/ - model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"
logging: logging:
@ -32,6 +30,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
background: background:
image: danswer/danswer-backend:latest image: danswer/danswer-backend:latest
build: build:
@ -51,9 +51,7 @@ services:
volumes: volumes:
- local_dynamic_storage:/home/storage - local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage - file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/ - model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts: extra_hosts:
- "host.docker.internal:host-gateway" - "host.docker.internal:host-gateway"
logging: logging:
@ -61,6 +59,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
web_server: web_server:
image: danswer/danswer-web-server:latest image: danswer/danswer-web-server:latest
build: build:
@ -94,6 +94,63 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
inference_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# This container name cannot have an underscore in it due to Vespa expectations of the URL # This container name cannot have an underscore in it due to Vespa expectations of the URL
index: index:
image: vespaengine/vespa:8.277.17 image: vespaengine/vespa:8.277.17
@ -108,6 +165,8 @@ services:
options: options:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
nginx: nginx:
image: nginx:1.23.4-alpine image: nginx:1.23.4-alpine
restart: always restart: always
@ -141,6 +200,8 @@ services:
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template" && /etc/nginx/conf.d/run-nginx.sh app.conf.template"
env_file: env_file:
- .env.nginx - .env.nginx
# follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71 # follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71
certbot: certbot:
image: certbot/certbot image: certbot/certbot
@ -154,30 +215,8 @@ services:
max-size: "50m" max-size: "50m"
max-file: "6" max-file: "6"
entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'" entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
volumes: volumes:
local_dynamic_storage: local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them

View File

@ -43,9 +43,9 @@ data:
ASYM_PASSAGE_PREFIX: "" ASYM_PASSAGE_PREFIX: ""
ENABLE_RERANKING_REAL_TIME_FLOW: "" ENABLE_RERANKING_REAL_TIME_FLOW: ""
ENABLE_RERANKING_ASYNC_FLOW: "" ENABLE_RERANKING_ASYNC_FLOW: ""
MODEL_SERVER_HOST: "" MODEL_SERVER_HOST: "inference-model-server-service"
MODEL_SERVER_PORT: "" MODEL_SERVER_PORT: ""
INDEXING_MODEL_SERVER_HOST: "" INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service"
MIN_THREADS_ML_MODELS: "" MIN_THREADS_ML_MODELS: ""
# Indexing Configs # Indexing Configs
NUM_INDEXING_WORKERS: "" NUM_INDEXING_WORKERS: ""

View File

@ -0,0 +1,59 @@
apiVersion: v1
kind: Service
metadata:
name: indexing-model-server-service
spec:
selector:
app: indexing-model-server
ports:
- name: indexing-model-server-port
protocol: TCP
port: 9000
targetPort: 9000
type: ClusterIP
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: indexing-model-server-deployment
spec:
replicas: 1
selector:
matchLabels:
app: indexing-model-server
template:
metadata:
labels:
app: indexing-model-server
spec:
containers:
- name: indexing-model-server
image: danswer/danswer-model-server:latest
imagePullPolicy: IfNotPresent
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
ports:
- containerPort: 9000
envFrom:
- configMapRef:
name: env-configmap
env:
- name: INDEXING_ONLY
value: "True"
volumeMounts:
- name: indexing-model-storage
mountPath: /root/.cache
volumes:
- name: indexing-model-storage
persistentVolumeClaim:
claimName: indexing-model-pvc
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: indexing-model-pvc
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 3Gi

View File

@ -0,0 +1,56 @@
apiVersion: v1
kind: Service
metadata:
name: inference-model-server-service
spec:
selector:
app: inference-model-server
ports:
- name: inference-model-server-port
protocol: TCP
port: 9000
targetPort: 9000
type: ClusterIP
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: inference-model-server-deployment
spec:
replicas: 1
selector:
matchLabels:
app: inference-model-server
template:
metadata:
labels:
app: inference-model-server
spec:
containers:
- name: inference-model-server
image: danswer/danswer-model-server:latest
imagePullPolicy: IfNotPresent
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
ports:
- containerPort: 9000
envFrom:
- configMapRef:
name: env-configmap
volumeMounts:
- name: inference-model-storage
mountPath: /root/.cache
volumes:
- name: inference-model-storage
persistentVolumeClaim:
claimName: inference-model-pvc
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: inference-model-pvc
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 3Gi
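Backend pods reach these Deployments through the Service names wired into the configmap above. A sketch (assumed helper, not from this commit) of building the model-server base URLs from that environment:

```python
# Assumed helper: resolve model-server URLs from the configmap values above.
import os

MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "inference-model-server-service"
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or 9000)
MODEL_SERVER_URL = f"http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"

INDEXING_MODEL_SERVER_HOST = (
    os.environ.get("INDEXING_MODEL_SERVER_HOST") or "indexing-model-server-service"
)
INDEXING_MODEL_SERVER_URL = f"http://{INDEXING_MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
```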