Always Use Model Server (#1306)

Yuhong Sun 2024-04-07 21:25:06 -07:00 committed by GitHub
parent 795243283d
commit 2db906b7a2
35 changed files with 724 additions and 550 deletions

View File

@ -20,10 +20,12 @@ jobs:
cache-dependency-path: |
backend/requirements/default.txt
backend/requirements/dev.txt
backend/requirements/model_server.txt
- run: |
python -m pip install --upgrade pip
pip install -r backend/requirements/default.txt
pip install -r backend/requirements/dev.txt
pip install -r backend/requirements/model_server.txt
- name: Run MyPy
run: |

View File

@ -85,6 +85,7 @@ Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/model_server.txt
```
Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
@ -117,7 +118,19 @@ To start the frontend, navigate to `danswer/web` and run:
npm run dev
```
The first time you run Danswer, you will also need to run the DB migrations for Postgres.
Next, start the model server which runs the local NLP models.
Navigate to `danswer/backend` and run:
```bash
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
uvicorn model_server.main:app --reload --port 9000
"
```
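Once the model server is running, you can sanity-check it directly over HTTP. The snippet below is a minimal sketch, assuming the server listens on `localhost:9000` and using the `/encoder/bi-encoder-embed` route and `EmbedRequest` fields introduced in this change; the model name shown is only a placeholder for whatever `DOCUMENT_ENCODER_MODEL` your deployment uses.
```python
import requests

# Hypothetical sanity check against a locally running model server (assumed on port 9000)
embed_request = {
    "texts": ["query: hello world"],  # the server expects any prefixes to already be applied
    "model_name": "intfloat/e5-base-v2",  # placeholder; use your configured DOCUMENT_ENCODER_MODEL
    "max_context_length": 512,
    "normalize_embeddings": True,
}
response = requests.post("http://localhost:9000/encoder/bi-encoder-embed", json=embed_request)
response.raise_for_status()
print(len(response.json()["embeddings"][0]))  # embedding dimension of the configured model
```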
The first time you run Danswer, you will need to run the DB migrations for Postgres.
After the first time, this is no longer required unless the DB models change.
Navigate to `danswer/backend` and with the venv active, run:

View File

@ -40,7 +40,7 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma
# Set up application files
WORKDIR /app
COPY ./danswer /app/danswer
COPY ./shared_models /app/shared_models
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf

View File

@ -25,11 +25,8 @@ COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
# Place to fetch version information
COPY ./danswer/__init__.py /app/danswer/__init__.py
# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
# Request/Response models
COPY ./shared_models /app/shared_models
COPY ./shared_configs /app/shared_configs
# Model Server main code
COPY ./model_server /app/model_server

View File

@ -6,18 +6,15 @@ NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing import Process
from typing import Any
from typing import Literal
from typing import Optional
from typing import TYPE_CHECKING
from danswer.utils.logger import setup_logger
logger = setup_logger()
if TYPE_CHECKING:
from torch.multiprocessing import Process
JobStatusType = (
Literal["error"]
| Literal["finished"]
@ -89,8 +86,6 @@ class SimpleJobClient:
def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
"""NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
from torch.multiprocessing import Process
self._cleanup_completed_jobs()
if len(self.jobs) >= self.n_workers:
logger.debug("No available workers to run job")

View File

@ -330,20 +330,15 @@ def _run_indexing(
)
def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
def run_indexing_entrypoint(index_attempt_id: int) -> None:
"""Entrypoint for indexing run when using dask distributed.
Wraps the actual logic in a `try` block so that we can catch any exceptions
and mark the attempt as failed."""
import torch
try:
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
logger.info(f"Setting task to use {num_threads} threads")
torch.set_num_threads(num_threads)
with Session(get_sqlalchemy_engine()) as db_session:
attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id

View File

@ -15,9 +15,10 @@ from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
@ -43,6 +44,7 @@ from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger
logger = setup_logger()
@ -56,18 +58,6 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
)
"""Util funcs"""
def _get_num_threads() -> int:
"""Get # of "threads" to use for ML models in an indexing job. By default uses
the torch implementation, which returns the # of physical cores on the machine.
"""
import torch
return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())
def _should_create_new_indexing(
connector: Connector,
last_index: IndexAttempt | None,
@ -346,12 +336,10 @@ def kickoff_indexing_jobs(
if use_secondary_index:
run = secondary_client.submit(
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
run_indexing_entrypoint, attempt.id, pure=False
)
else:
run = client.submit(
run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
)
run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)
if run:
secondary_str = "(secondary index) " if use_secondary_index else ""
@ -409,6 +397,20 @@ def check_index_swap(db_session: Session) -> None:
def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
db_embedding_model = get_current_db_embedding_model(db_session)
# So that first-time users aren't surprised by the really slow speed of the first
# batch of documents indexed
logger.info("Running a first inference to warm up embedding model")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
@ -435,7 +437,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
client_secondary = SimpleJobClient(n_workers=num_workers)
existing_jobs: dict[int, Future | SimpleJob] = {}
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
# Previous version did not always clean up cc-pairs well, leaving some connectors undeletable
@ -472,14 +473,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
def update__main() -> None:
# needed for CUDA to work with multiprocessing
# NOTE: needs to be done on application startup
# before any other torch code has been run
import torch
if not DASK_JOB_CLIENT_ENABLED:
torch.multiprocessing.set_start_method("spawn")
logger.info("Starting Indexing Loop")
update_loop()

View File

@ -207,15 +207,11 @@ DISABLE_DOCUMENT_CLEANUP = (
#####
# Model Server Configs
#####
# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_ALLOWED_HOST") or "0.0.0.0"
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
# specify this env variable directly to have a different model server for the background
# indexing job vs the api server so that background indexing does not affect query-time
# performance
# The indexing job should use a separate model server so that indexing does not
# introduce latency for inference
INDEXING_MODEL_SERVER_HOST = (
os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
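For illustration, a rough sketch of how these values resolve into the request URL. The scheme-defaulting behavior below is an assumption inferred from the `http://` log lines elsewhere in this change, not a verbatim copy of `build_model_server_url`:
```python
import os

# Illustrative resolution of the model server address, mirroring the configs above
model_server_host = os.environ.get("MODEL_SERVER_HOST") or "localhost"
model_server_port = int(os.environ.get("MODEL_SERVER_PORT") or "9000")

model_server_url = f"{model_server_host}:{model_server_port}"
if not model_server_url.startswith("http"):
    # Assumed default when no scheme is included in MODEL_SERVER_HOST
    model_server_url = f"http://{model_server_url}"

print(model_server_url)  # e.g. "http://localhost:9000"
```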

View File

@ -37,33 +37,13 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.
MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
# Cross Encoder Settings
ENABLE_RERANKING_ASYNC_FLOW = (
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# Only using one for now
CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
# For score normalizing purposes, only way is to know the expected ranges
# For score display purposes, the only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 12
CROSS_ENCODER_RANGE_MIN = -12
CROSS_EMBED_CONTEXT_SIZE = 512
# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0
# Intent model max context size
QUERY_MAX_CONTEXT_SIZE = 256
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
#####
# Generative AI Model Configs

View File

@ -22,7 +22,6 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@ -52,6 +51,7 @@ from danswer.search.models import BaseFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import setup_logger
from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
logger_base = setup_logger()

View File

@ -10,10 +10,11 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
@ -43,7 +44,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.search_nlp_models import warm_up_models
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger
@ -390,10 +391,11 @@ if __name__ == "__main__":
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
warm_up_models(
warm_up_encoders(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
slack_bot_tokens = latest_slack_bot_tokens

View File

@ -16,8 +16,9 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import EmbedTextType
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger
@ -73,6 +74,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embed_dict: dict[str, list[float]] = {}
embedded_chunks: list[IndexChunk] = []
# Create Mini Chunks for more precise matching of details
# Off by default unless enabled in the settings
chunk_texts = []
chunk_mini_chunks_count = {}
for chunk_ind, chunk in enumerate(chunks):
@ -85,23 +88,41 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
chunk_texts.extend(mini_chunk_texts)
chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)
text_batches = [
chunk_texts[i : i + batch_size]
for i in range(0, len(chunk_texts), batch_size)
]
# Batching for embedding
text_batches = batch_list(chunk_texts, batch_size)
embeddings: list[list[float]] = []
len_text_batches = len(text_batches)
for idx, text_batch in enumerate(text_batches, start=1):
logger.debug(f"Embedding text batch {idx} of {len_text_batches}")
# Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
# Whether to normalize embeddings is configured only via model_configs.py; be sure to use
# the right value for the model's training loss
embeddings.extend(
self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
)
# Replace line above with the line below for easy debugging of indexing flow, skipping the actual model
# Replace line above with the line below for easy debugging of indexing flow
# skipping the actual model
# embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])
chunk_titles = {
chunk.source_document.get_title_for_document_index() for chunk in chunks
}
chunk_titles.discard(None)
# Embed Titles in batches
title_batches = batch_list(list(chunk_titles), batch_size)
len_title_batches = len(title_batches)
for ind_batch, title_batch in enumerate(title_batches, start=1):
logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
title_embeddings = self.embedding_model.encode(
title_batch, text_type=EmbedTextType.PASSAGE
)
title_embed_dict.update(
{title: vector for title, vector in zip(title_batch, title_embeddings)}
)
# Mapping embeddings to chunks
embedding_ind_start = 0
for chunk_ind, chunk in enumerate(chunks):
num_embeddings = chunk_mini_chunks_count[chunk_ind]
@ -114,9 +135,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
title_embedding = None
if title:
if title in title_embed_dict:
# Using cached value for speedup
# Using cached value to avoid recalculating for every chunk
title_embedding = title_embed_dict[title]
else:
logger.error(
"Title had to be embedded separately, this should not happen!"
)
title_embedding = self.embedding_model.encode(
[title], text_type=EmbedTextType.PASSAGE
)[0]

View File

@ -1,10 +1,10 @@
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from typing import cast
import nltk # type:ignore
import torch # Import here is fine, API server needs torch anyway and nothing imports main.py
import uvicorn
from fastapi import APIRouter
from fastapi import FastAPI
@ -36,7 +36,6 @@ from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.db.chat import delete_old_default_personas
@ -54,7 +53,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_version
from danswer.search.search_nlp_models import warm_up_models
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.danswer_api.ingestion import get_danswer_api_key
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@ -82,6 +81,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
logger = setup_logger()
@ -204,24 +204,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("Warming up local NLP models.")
warm_up_models(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
)
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
@ -237,19 +219,34 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
load_chat_yamls()
logger.info("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
else None,
)
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
# Vespa startup is a bit slow, so give it a few seconds
wait_time = 5
for attempt in range(5):
try:
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
break
except Exception:
logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
time.sleep(wait_time)
logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
warm_up_encoders(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})

View File

@ -28,3 +28,8 @@ class SearchType(str, Enum):
class QueryFlow(str, Enum):
SEARCH = "search"
QUESTION_ANSWER = "question-answer"
class EmbedTextType(str, Enum):
QUERY = "query"
PASSAGE = "passage"
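For context, a small illustrative sketch (not the actual implementation) of how this enum is used by `EmbeddingModel.encode` to pick the asymmetric prefix, using the default `ASYM_QUERY_PREFIX`/`ASYM_PASSAGE_PREFIX` values shown earlier in this diff:
```python
from danswer.search.enums import EmbedTextType


def apply_prefix(texts: list[str], text_type: EmbedTextType) -> list[str]:
    # Illustrative only: mirrors the prefixing branch in EmbeddingModel.encode
    prefix = "query: " if text_type == EmbedTextType.QUERY else "passage: "
    return [prefix + text for text in texts]


print(apply_prefix(["how do I deploy Danswer?"], EmbedTextType.QUERY))
# ['query: how do I deploy Danswer?']
```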

View File

@ -8,10 +8,10 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.db.models import Persona
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
MAX_METRICS_CONTENT = (

View File

@ -5,7 +5,6 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.db.models import User
from danswer.search.enums import QueryFlow
from danswer.search.enums import RecencyBiasSetting
@ -22,6 +21,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
logger = setup_logger()

View File

@ -14,6 +14,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.enums import EmbedTextType
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import MAX_METRICS_CONTENT
@ -21,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import EmbedTextType
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

View File

@ -1,56 +1,38 @@
import gc
import os
from enum import Enum
import time
from typing import Optional
from typing import TYPE_CHECKING
import numpy as np
import requests
from transformers import logging as transformer_logging # type:ignore
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from danswer.search.enums import EmbedTextType
from danswer.utils.logger import setup_logger
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
logger = setup_logger()
transformer_logging.set_verbosity_error()
if TYPE_CHECKING:
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from transformers import AutoTokenizer # type: ignore
from transformers import TFDistilBertForSequenceClassification # type: ignore
_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
_EMBED_MODEL: tuple[Optional["SentenceTransformer"], str | None] = (None, None)
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
_INTENT_TOKENIZER: Optional["AutoTokenizer"] = None
_INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None
class EmbedTextType(str, Enum):
QUERY = "query"
PASSAGE = "passage"
def clean_model_name(model_str: str) -> str:
@ -84,89 +66,10 @@ def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
return _TOKENIZER[0]
def get_local_embedding_model(
model_name: str,
max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> "SentenceTransformer":
# NOTE: doing a local import here to avoid increased memory usage caused by
# processes importing this file despite not using any of this
from sentence_transformers import SentenceTransformer # type: ignore
global _EMBED_MODEL
if (
_EMBED_MODEL[0] is None
or max_context_length != _EMBED_MODEL[0].max_seq_length
or model_name != _EMBED_MODEL[1]
):
if _EMBED_MODEL[0] is not None:
del _EMBED_MODEL
gc.collect()
logger.info(f"Loading {model_name}")
_EMBED_MODEL = (SentenceTransformer(model_name), model_name)
_EMBED_MODEL[0].max_seq_length = max_context_length
return _EMBED_MODEL[0]
def get_local_reranking_model_ensemble(
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
) -> list["CrossEncoder"]:
# NOTE: doing a local import here to avoid increased memory usage caused by
# processes importing this file despite not using any of this
from sentence_transformers import CrossEncoder
global _RERANK_MODELS
if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
del _RERANK_MODELS
gc.collect()
_RERANK_MODELS = []
for model_name in model_names:
logger.info(f"Loading {model_name}")
model = CrossEncoder(model_name)
model.max_length = max_context_length
_RERANK_MODELS.append(model)
return _RERANK_MODELS
def get_intent_model_tokenizer(
model_name: str = INTENT_MODEL_VERSION,
) -> "AutoTokenizer":
# NOTE: doing a local import here to avoid increased memory usage caused by
# processes importing this file despite not using any of this
from transformers import AutoTokenizer # type: ignore
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
return _INTENT_TOKENIZER
def get_local_intent_model(
model_name: str = INTENT_MODEL_VERSION,
max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
) -> "TFDistilBertForSequenceClassification":
# NOTE: doing a local import here to avoid increased memory usage caused by
# processes importing this file despite not using any of this
from transformers import TFDistilBertForSequenceClassification # type: ignore
global _INTENT_MODEL
if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
_INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
model_name
)
_INTENT_MODEL.max_seq_length = max_context_length
return _INTENT_MODEL
def build_model_server_url(
model_server_host: str | None,
model_server_port: int | None,
) -> str | None:
if not model_server_host or model_server_port is None:
return None
model_server_host: str,
model_server_port: int,
) -> str:
model_server_url = f"{model_server_host}:{model_server_port}"
# use protocol if provided
@ -184,8 +87,8 @@ class EmbeddingModel:
query_prefix: str | None,
passage_prefix: str | None,
normalize: bool,
server_host: str | None, # Changes depending on indexing or inference
server_port: int | None,
server_host: str, # Changes depending on indexing or inference
server_port: int,
# The following globals are currently not configurable
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> None:
@ -196,17 +99,7 @@ class EmbeddingModel:
self.normalize = normalize
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = (
f"{model_server_url}/encoder/bi-encoder-embed" if model_server_url else None
)
def load_model(self) -> Optional["SentenceTransformer"]:
if self.embed_server_endpoint:
return None
return get_local_embedding_model(
model_name=self.model_name, max_context_length=self.max_seq_length
)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
if text_type == EmbedTextType.QUERY and self.query_prefix:
@ -216,166 +109,67 @@ class EmbeddingModel:
else:
prefixed_texts = texts
if self.embed_server_endpoint:
embed_request = EmbedRequest(
texts=prefixed_texts,
model_name=self.model_name,
normalize_embeddings=self.normalize,
)
embed_request = EmbedRequest(
texts=prefixed_texts,
model_name=self.model_name,
max_context_length=self.max_seq_length,
normalize_embeddings=self.normalize,
)
try:
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
response.raise_for_status()
response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
response.raise_for_status()
return EmbedResponse(**response.json()).embeddings
except requests.RequestException as e:
logger.exception(f"Failed to get Embedding: {e}")
raise
local_model = self.load_model()
if local_model is None:
raise RuntimeError("Failed to load local Embedding Model")
return local_model.encode(
prefixed_texts, normalize_embeddings=self.normalize
).tolist()
return EmbedResponse(**response.json()).embeddings
class CrossEncoderEnsembleModel:
def __init__(
self,
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
self.model_names = model_names
self.max_seq_length = max_seq_length
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = (
model_server_url + "/encoder/cross-encoder-scores"
if model_server_url
else None
)
def load_model(self) -> list["CrossEncoder"] | None:
if (
ENABLE_RERANKING_REAL_TIME_FLOW is False
and ENABLE_RERANKING_ASYNC_FLOW is False
):
logger.warning(
"Running rerankers but they are globally disabled."
"Was this specified explicitly via an API?"
)
if self.rerank_server_endpoint:
return None
return get_local_reranking_model_ensemble(
model_names=self.model_names, max_context_length=self.max_seq_length
)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
if self.rerank_server_endpoint:
rerank_request = RerankRequest(query=query, documents=passages)
rerank_request = RerankRequest(query=query, documents=passages)
try:
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.dict()
)
response.raise_for_status()
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.dict()
)
response.raise_for_status()
return RerankResponse(**response.json()).scores
except requests.RequestException as e:
logger.exception(f"Failed to get Reranking Scores: {e}")
raise
local_models = self.load_model()
if local_models is None:
raise RuntimeError("Failed to load local Reranking Model Ensemble")
scores = [
cross_encoder.predict([(query, passage) for passage in passages]).tolist() # type: ignore
for cross_encoder in local_models
]
return scores
return RerankResponse(**response.json()).scores
class IntentModel:
def __init__(
self,
model_name: str = INTENT_MODEL_VERSION,
max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
self.model_name = model_name
self.max_seq_length = max_seq_length
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.intent_server_endpoint = (
model_server_url + "/custom/intent-model" if model_server_url else None
)
def load_model(self) -> Optional["SentenceTransformer"]:
if self.intent_server_endpoint:
return None
return get_local_intent_model(
model_name=self.model_name, max_context_length=self.max_seq_length
)
self.intent_server_endpoint = model_server_url + "/custom/intent-model"
def predict(
self,
query: str,
) -> list[float]:
# NOTE: doing a local import here to avoid increased memory usage caused by
# processes importing this file despite not using any of this
import tensorflow as tf # type: ignore
intent_request = IntentRequest(query=query)
if self.intent_server_endpoint:
intent_request = IntentRequest(query=query)
try:
response = requests.post(
self.intent_server_endpoint, json=intent_request.dict()
)
response.raise_for_status()
return IntentResponse(**response.json()).class_probs
except requests.RequestException as e:
logger.exception(f"Failed to get Embedding: {e}")
raise
tokenizer = get_intent_model_tokenizer()
local_model = self.load_model()
if local_model is None:
raise RuntimeError("Failed to load local Intent Model")
intent_model = get_local_intent_model()
model_input = tokenizer(
query, return_tensors="tf", truncation=True, padding=True
response = requests.post(
self.intent_server_endpoint, json=intent_request.dict()
)
response.raise_for_status()
predictions = intent_model(model_input)[0]
probabilities = tf.nn.softmax(predictions, axis=-1)
class_percentages = np.round(probabilities.numpy() * 100, 2)
return list(class_percentages.tolist()[0])
return IntentResponse(**response.json()).class_probs
def warm_up_models(
def warm_up_encoders(
model_name: str,
normalize: bool,
skip_cross_encoders: bool = True,
indexer_only: bool = False,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
warm_up_str = (
"Danswer is amazing! Check out our easy deployment guide at "
@ -387,23 +181,23 @@ def warm_up_models(
embed_model = EmbeddingModel(
model_name=model_name,
normalize=normalize,
# These don't matter, if it's a remote model, this function shouldn't be called
# Not a big deal if prefix is incorrect
query_prefix=None,
passage_prefix=None,
server_host=None,
server_port=None,
server_host=model_server_host,
server_port=model_server_port,
)
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
if indexer_only:
return
if not skip_cross_encoders:
CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
warm_up_str, return_tensors="tf", truncation=True, padding=True
)
get_local_intent_model()(inputs)
# The first time, downloading the models may take even longer, so just in case,
# keep retrying against the server for a while
wait_time = 5
for attempt in range(20):
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
return
except Exception:
logger.info(
f"Failed to run test embedding, retrying in {wait_time} seconds..."
)
time.sleep(wait_time)
raise Exception("Failed to run test embedding.")

View File

@ -21,3 +21,10 @@ def batch_generator(
if pre_batch_yield:
pre_batch_yield(batch)
yield batch
def batch_list(
lst: list[T],
batch_size: int,
) -> list[list[T]]:
return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
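A quick usage sketch of the new `batch_list` helper, with hypothetical inputs and the default `BATCH_SIZE_ENCODE_CHUNKS` of 8:
```python
from danswer.utils.batching import batch_list

# Hypothetical example: batching 19 texts with a batch size of 8
chunk_texts = [f"chunk {i}" for i in range(19)]
batches = batch_list(chunk_texts, batch_size=8)
print([len(batch) for batch in batches])  # [8, 8, 3]
```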

View File

@ -0,0 +1 @@
MODEL_WARM_UP_STRING = "hi " * 512

View File

@ -1,19 +1,58 @@
import numpy as np
from fastapi import APIRouter
from typing import Optional
import numpy as np
import tensorflow as tf # type: ignore
from fastapi import APIRouter
from transformers import AutoTokenizer # type: ignore
from transformers import TFDistilBertForSequenceClassification
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.nlp_model_configs import INDEXING_ONLY
from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION
from danswer.search.search_nlp_models import get_intent_model_tokenizer
from danswer.search.search_nlp_models import get_local_intent_model
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse
router = APIRouter(prefix="/custom")
_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None
@log_function_time(print_only=True)
def get_intent_model_tokenizer(
model_name: str = INTENT_MODEL_VERSION,
) -> "AutoTokenizer":
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
return _INTENT_TOKENIZER
def get_local_intent_model(
model_name: str = INTENT_MODEL_VERSION,
max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
) -> TFDistilBertForSequenceClassification:
global _INTENT_MODEL
if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
_INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
model_name
)
_INTENT_MODEL.max_seq_length = max_context_length
return _INTENT_MODEL
def warm_up_intent_model() -> None:
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
)
get_local_intent_model()(inputs)
@simple_log_function_time()
def classify_intent(query: str) -> list[float]:
import tensorflow as tf # type:ignore
tokenizer = get_intent_model_tokenizer()
intent_model = get_local_intent_model()
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
@ -26,16 +65,11 @@ def classify_intent(query: str) -> list[float]:
@router.post("/intent-model")
def process_intent_request(
async def process_intent_request(
intent_request: IntentRequest,
) -> IntentResponse:
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call intent endpoint")
class_percentages = classify_intent(intent_request.query)
return IntentResponse(class_probs=class_percentages)
def warm_up_intent_model() -> None:
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
"danswer", return_tensors="tf", truncation=True, padding=True
)
get_local_intent_model()(inputs)

View File

@ -1,34 +1,33 @@
from typing import TYPE_CHECKING
import gc
from typing import Optional
from fastapi import APIRouter
from fastapi import HTTPException
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
if TYPE_CHECKING:
from sentence_transformers import SentenceTransformer # type: ignore
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE
from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from shared_configs.nlp_model_configs import INDEXING_ONLY
logger = setup_logger()
WARM_UP_STRING = "Danswer is amazing"
router = APIRouter(prefix="/encoder")
_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
def get_embedding_model(
model_name: str,
max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
max_context_length: int,
) -> "SentenceTransformer":
from sentence_transformers import SentenceTransformer # type: ignore
@ -48,11 +47,44 @@ def get_embedding_model(
return _GLOBAL_MODELS_DICT[model_name]
@log_function_time(print_only=True)
def get_local_reranking_model_ensemble(
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
) -> list[CrossEncoder]:
global _RERANK_MODELS
if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
del _RERANK_MODELS
gc.collect()
_RERANK_MODELS = []
for model_name in model_names:
logger.info(f"Loading {model_name}")
model = CrossEncoder(model_name)
model.max_length = max_context_length
_RERANK_MODELS.append(model)
return _RERANK_MODELS
def warm_up_cross_encoders() -> None:
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
cross_encoders = get_local_reranking_model_ensemble()
[
cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
for cross_encoder in cross_encoders
]
@simple_log_function_time()
def embed_text(
texts: list[str], model_name: str, normalize_embeddings: bool
texts: list[str],
model_name: str,
max_context_length: int,
normalize_embeddings: bool,
) -> list[list[float]]:
model = get_embedding_model(model_name=model_name)
model = get_embedding_model(
model_name=model_name, max_context_length=max_context_length
)
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
if not isinstance(embeddings, list):
@ -61,7 +93,7 @@ def embed_text(
return embeddings
@log_function_time(print_only=True)
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
cross_encoders = get_local_reranking_model_ensemble()
sim_scores = [
@ -72,13 +104,14 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
@router.post("/bi-encoder-embed")
def process_embed_request(
async def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
try:
embeddings = embed_text(
texts=embed_request.texts,
model_name=embed_request.model_name,
max_context_length=embed_request.max_context_length,
normalize_embeddings=embed_request.normalize_embeddings,
)
return EmbedResponse(embeddings=embeddings)
@ -87,7 +120,11 @@ def process_embed_request(
@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
"""Cross encoders can be purely black box from the app perspective"""
if INDEXING_ONLY:
raise RuntimeError("Indexing model server should not call rerank endpoint")
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
@ -95,13 +132,3 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
return RerankResponse(scores=sim_scores)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def warm_up_cross_encoders() -> None:
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
cross_encoders = get_local_reranking_model_ensemble()
[
cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
for cross_encoder in cross_encoders
]
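With the local fallback removed, reranking always goes through the model server's `/encoder/cross-encoder-scores` route. A hedged sketch of a raw request, assuming an inference model server on `localhost:9000` and the `RerankRequest`/`RerankResponse` field names from `shared_configs`:
```python
import requests

# Hypothetical direct call to the rerank route of a locally running inference model server
rerank_request = {
    "query": "how do I deploy Danswer?",
    "documents": ["Danswer deployment guide ...", "An unrelated passage ..."],
}
response = requests.post(
    "http://localhost:9000/encoder/cross-encoder-scores", json=rerank_request
)
response.raise_for_status()
print(response.json()["scores"])  # one list of scores per cross-encoder in the ensemble
```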

View File

@ -1,40 +1,61 @@
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
import torch
import uvicorn
from fastapi import FastAPI
from transformers import logging as transformer_logging # type:ignore
from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_cross_encoders
from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.nlp_model_configs import INDEXING_ONLY
from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
transformer_logging.set_verbosity_error()
logger = setup_logger()
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
warm_up_intent_model()
if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
warm_up_cross_encoders()
else:
logger.info("This model server should only run document indexing.")
yield
def get_model_app() -> FastAPI:
application = FastAPI(title="Danswer Model Server", version=__version__)
application = FastAPI(
title="Danswer Model Server", version=__version__, lifespan=lifespan
)
application.include_router(encoders_router)
application.include_router(custom_models_router)
@application.on_event("startup")
def startup_event() -> None:
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
warm_up_cross_encoders()
warm_up_intent_model()
return application

View File

@ -0,0 +1,41 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
from danswer.utils.logger import setup_logger
logger = setup_logger()
F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
def simple_log_function_time(
func_name: str | None = None,
debug_only: bool = False,
include_args: bool = False,
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
start_time = time.time()
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
if debug_only:
logger.debug(final_log)
else:
logger.info(final_log)
return result
return cast(F, wrapped_func)
return decorator
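A short usage sketch of the new decorator; the decorated function below is hypothetical and exists only to show the timing log:
```python
from model_server.utils import simple_log_function_time


@simple_log_function_time()
def embed_batch(texts: list[str]) -> int:
    # Hypothetical function used only to demonstrate the timing log
    return len(texts)


embed_batch(["a", "b"])  # logs roughly: "embed_batch took 0.0000xx seconds"
```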

View File

@ -54,19 +54,12 @@ requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0
rt==3.1.2
# need to pin `safetensors` version, since the latest versions require
# building from source using Rust
safetensors==0.4.2
sentence-transformers==2.6.1
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.36.3
supervisor==4.2.5
tensorflow==2.15.0
tiktoken==0.4.0
timeago==1.0.16
torch==2.0.1
torchvision==0.15.2
transformers==4.39.2
uvicorn==0.21.1
zulip==0.8.2

View File

@ -1,8 +1,8 @@
fastapi==0.109.1
fastapi==0.109.2
pydantic==1.10.7
safetensors==0.3.1
sentence-transformers==2.2.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
torch==2.0.1
transformers==4.36.2
transformers==4.39.2
uvicorn==0.21.1

View File

@ -2,8 +2,10 @@ from pydantic import BaseModel
class EmbedRequest(BaseModel):
# This already includes any prefixes; the text is passed directly to the model
texts: list[str]
model_name: str
max_context_length: int
normalize_embeddings: bool

View File

@ -0,0 +1,26 @@
import os
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
INTENT_MODEL_CONTEXT_SIZE = 256
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512
# Cross Encoder Settings
ENABLE_RERANKING_ASYNC_FLOW = (
os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# Only using one cross-encoder for now
CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
CROSS_EMBED_CONTEXT_SIZE = 512
# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.
MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)
INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"

View File

@ -67,7 +67,7 @@ services:
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
- ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
- ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Leave this on pretty please? Nothing sensitive is collected!
# https://docs.danswer.dev/more/telemetry
@ -80,9 +80,7 @@ services:
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@ -90,6 +88,8 @@ services:
options:
max-size: "50m"
max-file: "6"
background:
image: danswer/danswer-backend:latest
build:
@ -137,10 +137,9 @@ services:
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot
- ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
# Indexing Configs
- NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
- DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-}
@ -174,9 +173,7 @@ services:
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@ -184,6 +181,8 @@ services:
options:
max-size: "50m"
max-file: "6"
web_server:
image: danswer/danswer-web-server:latest
build:
@ -198,6 +197,63 @@ services:
environment:
- INTERNAL_URL=http://api_server:8080
- WEB_DOMAIN=${WEB_DOMAIN:-}
inference_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
relational_db:
image: postgres:15.2-alpine
restart: always
@ -208,6 +264,8 @@ services:
- "5432:5432"
volumes:
- db_volume:/var/lib/postgresql/data
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
image: vespaengine/vespa:8.277.17
@ -222,6 +280,8 @@ services:
options:
max-size: "50m"
max-file: "6"
nginx:
image: nginx:1.23.4-alpine
restart: always
@ -250,32 +310,8 @@ services:
command: >
/bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"
# Run with --profile model-server to bring up the danswer-model-server container
# Be sure to change MODEL_SERVER_HOST (see above) as well
# ie. MODEL_SERVER_HOST="model_server" docker compose -f docker-compose.dev.yml -p danswer-stack --profile model-server up -d --build
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
volumes:
local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them

View File

@ -22,9 +22,7 @@ services:
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@ -32,6 +30,8 @@ services:
options:
max-size: "50m"
max-file: "6"
background:
image: danswer/danswer-backend:latest
build:
@ -51,9 +51,7 @@ services:
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@ -61,6 +59,8 @@ services:
options:
max-size: "50m"
max-file: "6"
web_server:
image: danswer/danswer-web-server:latest
build:
@ -81,6 +81,63 @@ services:
options:
max-size: "50m"
max-file: "6"
inference_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
relational_db:
image: postgres:15.2-alpine
restart: always
@ -94,6 +151,8 @@ services:
options:
max-size: "50m"
max-file: "6"
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
image: vespaengine/vespa:8.277.17
@ -108,6 +167,8 @@ services:
options:
max-size: "50m"
max-file: "6"
nginx:
image: nginx:1.23.4-alpine
restart: always
@ -137,30 +198,8 @@ services:
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template.no-letsencrypt"
env_file:
- .env.nginx
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
volumes:
local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
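The single profiled `model_server` is replaced by two always-on services, `inference_model_server` and `indexing_model_server`, both restarting on failure rather than unconditionally. A sketch for confirming they came up (or logged "Skipping service..." when disabled); run it from the directory holding this compose file with the same `-p` project name used for `up`, assumed here to be `danswer-stack`:

```bash
# Check container state, then tail the startup logs of both model server services
docker compose -p danswer-stack ps inference_model_server indexing_model_server
docker compose -p danswer-stack logs --tail=50 inference_model_server indexing_model_server
```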

View File

@ -22,9 +22,7 @@ services:
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@ -32,6 +30,8 @@ services:
options:
max-size: "50m"
max-file: "6"
background:
image: danswer/danswer-backend:latest
build:
@ -51,9 +51,7 @@ services:
volumes:
- local_dynamic_storage:/home/storage
- file_connector_tmp_storage:/home/file_connector_storage
- model_cache_torch:/root/.cache/torch/
- model_cache_nltk:/root/nltk_data/
- model_cache_huggingface:/root/.cache/huggingface/
extra_hosts:
- "host.docker.internal:host-gateway"
logging:
@ -61,6 +59,8 @@ services:
options:
max-size: "50m"
max-file: "6"
web_server:
image: danswer/danswer-web-server:latest
build:
@ -94,6 +94,63 @@ services:
options:
max-size: "50m"
max-file: "6"
inference_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
indexing_model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
command: >
/bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
echo 'Skipping service...';
exit 0;
else
exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
fi"
restart: on-failure
environment:
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- INDEXING_ONLY=True
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
image: vespaengine/vespa:8.277.17
@ -108,6 +165,8 @@ services:
options:
max-size: "50m"
max-file: "6"
nginx:
image: nginx:1.23.4-alpine
restart: always
@ -141,6 +200,8 @@ services:
&& /etc/nginx/conf.d/run-nginx.sh app.conf.template"
env_file:
- .env.nginx
# follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71
certbot:
image: certbot/certbot
@ -154,30 +215,8 @@ services:
max-size: "50m"
max-file: "6"
entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
logging:
driver: json-file
options:
max-size: "50m"
max-file: "6"
volumes:
local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
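Because the model servers are no longer behind a profile, changes to `Dockerfile.model_server` take effect the next time the services are rebuilt. A sketch for rebuilding only those two services in the production stack; the compose file name here is an assumption:

```bash
# Rebuild and restart just the model server services (assumed file name docker-compose.prod.yml)
docker compose -f docker-compose.prod.yml -p danswer-stack up -d --build \
  inference_model_server indexing_model_server
```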

View File

@ -43,9 +43,9 @@ data:
ASYM_PASSAGE_PREFIX: ""
ENABLE_RERANKING_REAL_TIME_FLOW: ""
ENABLE_RERANKING_ASYNC_FLOW: ""
MODEL_SERVER_HOST: ""
MODEL_SERVER_HOST: "inference-model-server-service"
MODEL_SERVER_PORT: ""
INDEXING_MODEL_SERVER_HOST: ""
INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service"
MIN_THREADS_ML_MODELS: ""
# Indexing Configs
NUM_INDEXING_WORKERS: ""
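The configmap now points the backend at the in-cluster Services defined in the new manifests below, presumably `http://inference-model-server-service:9000` for query-time requests and `http://indexing-model-server-service:9000` for indexing. A hypothetical in-cluster smoke test, assuming FastAPI's default `/docs` endpoint is exposed by the model server app:

```bash
# Hypothetical check from inside the cluster; /docs is only an assumption
# (FastAPI's default), the service name and port come from the manifests below
kubectl run model-server-check --rm -it --restart=Never \
  --image=curlimages/curl --command -- \
  curl -sv http://inference-model-server-service:9000/docs
```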

View File

@ -0,0 +1,59 @@
apiVersion: v1
kind: Service
metadata:
name: indexing-model-server-service
spec:
selector:
app: indexing-model-server
ports:
- name: indexing-model-server-port
protocol: TCP
port: 9000
targetPort: 9000
type: ClusterIP
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: indexing-model-server-deployment
spec:
replicas: 1
selector:
matchLabels:
app: indexing-model-server
template:
metadata:
labels:
app: indexing-model-server
spec:
containers:
- name: indexing-model-server
image: danswer/danswer-model-server:latest
imagePullPolicy: IfNotPresent
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
ports:
- containerPort: 9000
envFrom:
- configMapRef:
name: env-configmap
env:
- name: INDEXING_ONLY
value: "True"
volumeMounts:
- name: indexing-model-storage
mountPath: /root/.cache
volumes:
- name: indexing-model-storage
persistentVolumeClaim:
claimName: indexing-model-pvc
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: indexing-model-pvc
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 3Gi

View File

@ -0,0 +1,56 @@
apiVersion: v1
kind: Service
metadata:
name: inference-model-server-service
spec:
selector:
app: inference-model-server
ports:
- name: inference-model-server-port
protocol: TCP
port: 9000
targetPort: 9000
type: ClusterIP
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: inference-model-server-deployment
spec:
replicas: 1
selector:
matchLabels:
app: inference-model-server
template:
metadata:
labels:
app: inference-model-server
spec:
containers:
- name: inference-model-server
image: danswer/danswer-model-server:latest
imagePullPolicy: IfNotPresent
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
ports:
- containerPort: 9000
envFrom:
- configMapRef:
name: env-configmap
volumeMounts:
- name: inference-model-storage
mountPath: /root/.cache
volumes:
- name: inference-model-storage
persistentVolumeClaim:
claimName: inference-model-pvc
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
name: inference-model-pvc
spec:
accessModes:
- ReadWriteOnce
resources:
requests:
storage: 3Gi
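The two manifests are essentially identical apart from the service and deployment names, the `INDEXING_ONLY=True` environment variable, and their separate model-cache PVCs. A sketch for rolling them out; the manifest file names are assumptions, while the deployment names come from the YAML above:

```bash
# Apply the two new model server manifests and wait for both rollouts to finish
kubectl apply -f indexing_model_server.yaml -f inference_model_server.yaml
kubectl rollout status deployment/indexing-model-server-deployment
kubectl rollout status deployment/inference-model-server-deployment
```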