Always Use Model Server (#1306)

parent 795243283d
commit 2db906b7a2

.github/workflows/pr-python-checks.yml (vendored)
@@ -20,10 +20,12 @@ jobs:
          cache-dependency-path: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt
      - run: |
          python -m pip install --upgrade pip
          pip install -r backend/requirements/default.txt
          pip install -r backend/requirements/dev.txt
          pip install -r backend/requirements/model_server.txt

      - name: Run MyPy
        run: |
@@ -85,6 +85,7 @@ Install the required python dependencies:
```bash
pip install -r danswer/backend/requirements/default.txt
pip install -r danswer/backend/requirements/dev.txt
pip install -r danswer/backend/requirements/model_server.txt
```

Install [Node.js and npm](https://docs.npmjs.com/downloading-and-installing-node-js-and-npm) for the frontend.
@@ -117,7 +118,19 @@ To start the frontend, navigate to `danswer/web` and run:
npm run dev
```

The first time running Danswer, you will also need to run the DB migrations for Postgres.
Next, start the model server which runs the local NLP models.
Navigate to `danswer/backend` and run:
```bash
uvicorn model_server.main:app --reload --port 9000
```
_For Windows (for compatibility with both PowerShell and Command Prompt):_
```bash
powershell -Command "
uvicorn model_server.main:app --reload --port 9000
"
```

The first time running Danswer, you will need to run the DB migrations for Postgres.
After the first time, this is no longer required unless the DB models change.

Navigate to `danswer/backend` and with the venv active, run:
@@ -40,7 +40,7 @@ RUN apt-get remove -y --allow-remove-essential perl-base xserver-common xvfb cma
# Set up application files
WORKDIR /app
COPY ./danswer /app/danswer
COPY ./shared_models /app/shared_models
COPY ./shared_configs /app/shared_configs
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
@@ -25,11 +25,8 @@ COPY ./danswer/utils/telemetry.py /app/danswer/utils/telemetry.py
# Place to fetch version information
COPY ./danswer/__init__.py /app/danswer/__init__.py

# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py

# Request/Response models
COPY ./shared_models /app/shared_models
COPY ./shared_configs /app/shared_configs

# Model Server main code
COPY ./model_server /app/model_server
@@ -6,18 +6,15 @@ NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing import Process
from typing import Any
from typing import Literal
from typing import Optional
from typing import TYPE_CHECKING

from danswer.utils.logger import setup_logger

logger = setup_logger()

if TYPE_CHECKING:
    from torch.multiprocessing import Process

JobStatusType = (
    Literal["error"]
    | Literal["finished"]
@@ -89,8 +86,6 @@ class SimpleJobClient:

    def submit(self, func: Callable, *args: Any, pure: bool = True) -> SimpleJob | None:
        """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask"""
        from torch.multiprocessing import Process

        self._cleanup_completed_jobs()
        if len(self.jobs) >= self.n_workers:
            logger.debug("No available workers to run job")
@@ -330,20 +330,15 @@ def _run_indexing(
    )


def run_indexing_entrypoint(index_attempt_id: int, num_threads: int) -> None:
def run_indexing_entrypoint(index_attempt_id: int) -> None:
    """Entrypoint for indexing run when using dask distributed.
    Wraps the actual logic in a `try` block so that we can catch any exceptions
    and mark the attempt as failed."""
    import torch

    try:
        # set the indexing attempt ID so that all log messages from this process
        # will have it added as a prefix
        IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)

        logger.info(f"Setting task to use {num_threads} threads")
        torch.set_num_threads(num_threads)

        with Session(get_sqlalchemy_engine()) as db_session:
            attempt = get_index_attempt(
                db_session=db_session, index_attempt_id=index_attempt_id
@@ -15,9 +15,10 @@ from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed
@@ -43,6 +44,7 @@ from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -56,18 +58,6 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
)


"""Util funcs"""


def _get_num_threads() -> int:
    """Get # of "threads" to use for ML models in an indexing job. By default uses
    the torch implementation, which returns the # of physical cores on the machine.
    """
    import torch

    return max(MIN_THREADS_ML_MODELS, torch.get_num_threads())


def _should_create_new_indexing(
    connector: Connector,
    last_index: IndexAttempt | None,
@@ -346,12 +336,10 @@ def kickoff_indexing_jobs(

        if use_secondary_index:
            run = secondary_client.submit(
                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
                run_indexing_entrypoint, attempt.id, pure=False
            )
        else:
            run = client.submit(
                run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False
            )
            run = client.submit(run_indexing_entrypoint, attempt.id, pure=False)

        if run:
            secondary_str = "(secondary index) " if use_secondary_index else ""
@@ -409,6 +397,20 @@ def check_index_swap(db_session: Session) -> None:


def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None:
    engine = get_sqlalchemy_engine()
    with Session(engine) as db_session:
        db_embedding_model = get_current_db_embedding_model(db_session)

    # So that the first time users aren't surprised by really slow speed of first
    # batch of documents indexed
    logger.info("Running a first inference to warm up embedding model")
    warm_up_encoders(
        model_name=db_embedding_model.model_name,
        normalize=db_embedding_model.normalize,
        model_server_host=INDEXING_MODEL_SERVER_HOST,
        model_server_port=MODEL_SERVER_PORT,
    )

    client_primary: Client | SimpleJobClient
    client_secondary: Client | SimpleJobClient
    if DASK_JOB_CLIENT_ENABLED:
@@ -435,7 +437,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
        client_secondary = SimpleJobClient(n_workers=num_workers)

    existing_jobs: dict[int, Future | SimpleJob] = {}
    engine = get_sqlalchemy_engine()

    with Session(engine) as db_session:
        # Previous version did not always clean up cc-pairs well leaving some connectors undeleteable
@@ -472,14 +473,6 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non


def update__main() -> None:
    # needed for CUDA to work with multiprocessing
    # NOTE: needs to be done on application startup
    # before any other torch code has been run
    import torch

    if not DASK_JOB_CLIENT_ENABLED:
        torch.multiprocessing.set_start_method("spawn")

    logger.info("Starting Indexing Loop")
    update_loop()
@@ -207,15 +207,11 @@ DISABLE_DOCUMENT_CLEANUP = (
#####
# Model Server Configs
#####
# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")

# specify this env variable directly to have a different model server for the background
# indexing job vs the api server so that background indexing does not effect query-time
# performance
# Model server for indexing should use a separate one to not allow indexing to introduce delay
# for inference
INDEXING_MODEL_SERVER_HOST = (
    os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
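For readers wiring these values together, the sketch below shows how a base URL could be assembled from `MODEL_SERVER_HOST` and `MODEL_SERVER_PORT`. It is a minimal illustration, not the repository's `build_model_server_url` implementation; the scheme handling (falling back to `http://` when the host value does not already include one) is an assumption based on the comment above about including the scheme in `MODEL_SERVER_HOST`.

```python
# Hypothetical sketch: combine host + port into a URL, assuming "http://" is the
# default scheme when MODEL_SERVER_HOST does not already include one.
import os


def _model_server_url(host: str, port: int) -> str:
    base = f"{host}:{port}"
    # Assumption: prepend a scheme only if the host did not provide one.
    if not host.startswith(("http://", "https://")):
        base = f"http://{base}"
    return base


host = os.environ.get("MODEL_SERVER_HOST") or "localhost"
port = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
print(_model_server_url(host, port))  # e.g. http://localhost:9000
```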
@@ -37,33 +37,13 @@ ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "query: ")
ASYM_PASSAGE_PREFIX = os.environ.get("ASYM_PASSAGE_PREFIX", "passage: ")
# Purely an optimization, memory limitation consideration
BATCH_SIZE_ENCODE_CHUNKS = 8
# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.
MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)

# Cross Encoder Settings
ENABLE_RERANKING_ASYNC_FLOW = (
    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# Only using one for now
CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
# For score normalizing purposes, only way is to know the expected ranges
# For score display purposes, only way is to know the expected ranges
CROSS_ENCODER_RANGE_MAX = 12
CROSS_ENCODER_RANGE_MIN = -12
CROSS_EMBED_CONTEXT_SIZE = 512

# Unused currently, can't be used with the current default encoder model due to its output range
SEARCH_DISTANCE_CUTOFF = 0

# Intent model max context size
QUERY_MAX_CONTEXT_SIZE = 256

# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"


#####
# Generative AI Model Configs
@@ -22,7 +22,6 @@ from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
from danswer.configs.danswerbot_configs import DISABLE_DANSWER_BOT_FILTER_DETECT
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.danswerbot.slack.blocks import build_documents_blocks
from danswer.danswerbot.slack.blocks import build_follow_up_block
from danswer.danswerbot.slack.blocks import build_qa_response_blocks
@@ -52,6 +51,7 @@ from danswer.search.models import BaseFilters
from danswer.search.models import OptionalSearchSetting
from danswer.search.models import RetrievalDetails
from danswer.utils.logger import setup_logger
from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW

logger_base = setup_logger()
@@ -10,10 +10,11 @@ from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session

from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.constants import MessageType
from danswer.configs.danswerbot_configs import DANSWER_BOT_RESPOND_EVERY_CHANNEL
from danswer.configs.danswerbot_configs import NOTIFY_SLACKBOT_NO_ANSWER
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.danswerbot.slack.config import get_slack_bot_config_for_channel
from danswer.danswerbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
from danswer.danswerbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID
@@ -43,7 +44,7 @@ from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.search_nlp_models import warm_up_models
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.manage.models import SlackBotTokens
from danswer.utils.logger import setup_logger

@@ -390,10 +391,11 @@ if __name__ == "__main__":
    with Session(get_sqlalchemy_engine()) as db_session:
        embedding_model = get_current_db_embedding_model(db_session)

    warm_up_models(
    warm_up_encoders(
        model_name=embedding_model.model_name,
        normalize=embedding_model.normalize,
        skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW,
        model_server_host=MODEL_SERVER_HOST,
        model_server_port=MODEL_SERVER_PORT,
    )

    slack_bot_tokens = latest_slack_bot_tokens
@@ -16,8 +16,9 @@ from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import EmbedTextType
from danswer.utils.batching import batch_list
from danswer.utils.logger import setup_logger


@@ -73,6 +74,8 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
        title_embed_dict: dict[str, list[float]] = {}
        embedded_chunks: list[IndexChunk] = []

        # Create Mini Chunks for more precise matching of details
        # Off by default with unedited settings
        chunk_texts = []
        chunk_mini_chunks_count = {}
        for chunk_ind, chunk in enumerate(chunks):
@@ -85,23 +88,41 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            chunk_texts.extend(mini_chunk_texts)
            chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts)

        text_batches = [
            chunk_texts[i : i + batch_size]
            for i in range(0, len(chunk_texts), batch_size)
        ]
        # Batching for embedding
        text_batches = batch_list(chunk_texts, batch_size)

        embeddings: list[list[float]] = []
        len_text_batches = len(text_batches)
        for idx, text_batch in enumerate(text_batches, start=1):
            logger.debug(f"Embedding text batch {idx} of {len_text_batches}")
            # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
            logger.debug(f"Embedding Content Texts batch {idx} of {len_text_batches}")
            # Normalize embeddings is only configured via model_configs.py, be sure to use right
            # value for the set loss
            embeddings.extend(
                self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE)
            )

            # Replace line above with the line below for easy debugging of indexing flow, skipping the actual model
            # Replace line above with the line below for easy debugging of indexing flow
            # skipping the actual model
            # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))])

        chunk_titles = {
            chunk.source_document.get_title_for_document_index() for chunk in chunks
        }
        chunk_titles.discard(None)

        # Embed Titles in batches
        title_batches = batch_list(list(chunk_titles), batch_size)
        len_title_batches = len(title_batches)
        for ind_batch, title_batch in enumerate(title_batches, start=1):
            logger.debug(f"Embedding Titles batch {ind_batch} of {len_title_batches}")
            title_embeddings = self.embedding_model.encode(
                title_batch, text_type=EmbedTextType.PASSAGE
            )
            title_embed_dict.update(
                {title: vector for title, vector in zip(title_batch, title_embeddings)}
            )

        # Mapping embeddings to chunks
        embedding_ind_start = 0
        for chunk_ind, chunk in enumerate(chunks):
            num_embeddings = chunk_mini_chunks_count[chunk_ind]
@@ -114,9 +135,12 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
            title_embedding = None
            if title:
                if title in title_embed_dict:
                    # Using cached value for speedup
                    # Using cached value to avoid recalculating for every chunk
                    title_embedding = title_embed_dict[title]
                else:
                    logger.error(
                        "Title had to be embedded separately, this should not happen!"
                    )
                    title_embedding = self.embedding_model.encode(
                        [title], text_type=EmbedTextType.PASSAGE
                    )[0]
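To make the offset bookkeeping above concrete, here is a small illustrative sketch of how a flat list of embeddings could be sliced back per chunk using the per-chunk mini-chunk counts. The slicing detail is an assumption for illustration; the hunk above does not show the exact slice the embedder performs.

```python
# Hypothetical illustration: each chunk produced 1 full-chunk embedding plus N
# mini-chunk embeddings, so its share of the flat `embeddings` list is a
# contiguous slice of length chunk_mini_chunks_count[chunk_ind].
embeddings = [[0.1], [0.2], [0.3], [0.4], [0.5]]  # flat list from batched encode calls
chunk_mini_chunks_count = {0: 2, 1: 3}            # chunk 0 -> 2 vectors, chunk 1 -> 3 vectors

embedding_ind_start = 0
for chunk_ind in sorted(chunk_mini_chunks_count):
    num_embeddings = chunk_mini_chunks_count[chunk_ind]
    chunk_embeddings = embeddings[embedding_ind_start : embedding_ind_start + num_embeddings]
    print(chunk_ind, chunk_embeddings)
    embedding_ind_start += num_embeddings
```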
@@ -1,10 +1,10 @@
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from typing import cast

import nltk  # type:ignore
import torch  # Import here is fine, API server needs torch anyway and nothing imports main.py
import uvicorn
from fastapi import APIRouter
from fastapi import FastAPI
@@ -36,7 +36,6 @@ from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.constants import AuthType
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.db.chat import delete_old_default_personas
@@ -54,7 +53,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.port_configs import port_filesystem_to_postgres
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_default_llm_version
from danswer.search.search_nlp_models import warm_up_models
from danswer.search.search_nlp_models import warm_up_encoders
from danswer.server.danswer_api.ingestion import get_danswer_api_key
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@@ -82,6 +81,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.telemetry import optional_telemetry
from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW


logger = setup_logger()
@@ -204,24 +204,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
    if ENABLE_RERANKING_REAL_TIME_FLOW:
        logger.info("Reranking step of search flow is enabled.")

    if MODEL_SERVER_HOST:
        logger.info(
            f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
        )
    else:
        logger.info("Warming up local NLP models.")
        warm_up_models(
            model_name=db_embedding_model.model_name,
            normalize=db_embedding_model.normalize,
            skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
        )

    if torch.cuda.is_available():
        logger.info("GPU is available")
    else:
        logger.info("GPU is not available")
    logger.info(f"Torch Threads: {torch.get_num_threads()}")

    logger.info("Verifying query preprocessing (NLTK) data is downloaded")
    nltk.download("stopwords", quiet=True)
    nltk.download("wordnet", quiet=True)
@@ -237,19 +219,34 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
    load_chat_yamls()

    logger.info("Verifying Document Index(s) is/are available.")

    document_index = get_default_document_index(
        primary_index_name=db_embedding_model.index_name,
        secondary_index_name=secondary_db_embedding_model.index_name
        if secondary_db_embedding_model
        else None,
    )
    document_index.ensure_indices_exist(
        index_embedding_dim=db_embedding_model.model_dim,
        secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
        if secondary_db_embedding_model
        else None,
    )
    # Vespa startup is a bit slow, so give it a few seconds
    wait_time = 5
    for attempt in range(5):
        try:
            document_index.ensure_indices_exist(
                index_embedding_dim=db_embedding_model.model_dim,
                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
                if secondary_db_embedding_model
                else None,
            )
            break
        except Exception:
            logger.info(f"Waiting on Vespa, retrying in {wait_time} seconds...")
            time.sleep(wait_time)

    logger.info(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
    warm_up_encoders(
        model_name=db_embedding_model.model_name,
        normalize=db_embedding_model.normalize,
        model_server_host=MODEL_SERVER_HOST,
        model_server_port=MODEL_SERVER_PORT,
    )

    optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
@@ -28,3 +28,8 @@ class SearchType(str, Enum):
class QueryFlow(str, Enum):
    SEARCH = "search"
    QUESTION_ANSWER = "question-answer"


class EmbedTextType(str, Enum):
    QUERY = "query"
    PASSAGE = "passage"
@@ -8,10 +8,10 @@ from danswer.configs.chat_configs import HYBRID_ALPHA
from danswer.configs.chat_configs import NUM_RERANKED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.db.models import Persona
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW


MAX_METRICS_CONTENT = (
@@ -5,7 +5,6 @@ from danswer.configs.chat_configs import DISABLE_LLM_CHUNK_FILTER
from danswer.configs.chat_configs import DISABLE_LLM_FILTER_EXTRACTION
from danswer.configs.chat_configs import FAVOR_RECENT_DECAY_MULTIPLIER
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.db.models import User
from danswer.search.enums import QueryFlow
from danswer.search.enums import RecencyBiasSetting
@@ -22,6 +21,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW


logger = setup_logger()
@@ -14,6 +14,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.document_index.interfaces import DocumentIndex
from danswer.indexing.models import InferenceChunk
from danswer.search.enums import EmbedTextType
from danswer.search.models import ChunkMetric
from danswer.search.models import IndexFilters
from danswer.search.models import MAX_METRICS_CONTENT
@@ -21,7 +22,6 @@ from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import EmbedTextType
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -1,56 +1,38 @@
import gc
import os
from enum import Enum
import time
from typing import Optional
from typing import TYPE_CHECKING

import numpy as np
import requests
from transformers import logging as transformer_logging  # type:ignore

from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import ENABLE_RERANKING_ASYNC_FLOW
from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from danswer.search.enums import EmbedTextType
from danswer.utils.logger import setup_logger
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse

transformer_logging.set_verbosity_error()

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

logger = setup_logger()
transformer_logging.set_verbosity_error()


if TYPE_CHECKING:
    from sentence_transformers import CrossEncoder  # type: ignore
    from sentence_transformers import SentenceTransformer  # type: ignore
    from transformers import AutoTokenizer  # type: ignore
    from transformers import TFDistilBertForSequenceClassification  # type: ignore


_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None)
_EMBED_MODEL: tuple[Optional["SentenceTransformer"], str | None] = (None, None)
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None
_INTENT_TOKENIZER: Optional["AutoTokenizer"] = None
_INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None


class EmbedTextType(str, Enum):
    QUERY = "query"
    PASSAGE = "passage"


def clean_model_name(model_str: str) -> str:
@@ -84,89 +66,10 @@ def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
    return _TOKENIZER[0]


def get_local_embedding_model(
    model_name: str,
    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> "SentenceTransformer":
    # NOTE: doing a local import here to avoid reduce memory usage caused by
    # processes importing this file despite not using any of this
    from sentence_transformers import SentenceTransformer  # type: ignore

    global _EMBED_MODEL
    if (
        _EMBED_MODEL[0] is None
        or max_context_length != _EMBED_MODEL[0].max_seq_length
        or model_name != _EMBED_MODEL[1]
    ):
        if _EMBED_MODEL[0] is not None:
            del _EMBED_MODEL
            gc.collect()

        logger.info(f"Loading {model_name}")
        _EMBED_MODEL = (SentenceTransformer(model_name), model_name)
        _EMBED_MODEL[0].max_seq_length = max_context_length
    return _EMBED_MODEL[0]


def get_local_reranking_model_ensemble(
    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
) -> list["CrossEncoder"]:
    # NOTE: doing a local import here to avoid reduce memory usage caused by
    # processes importing this file despite not using any of this
    from sentence_transformers import CrossEncoder

    global _RERANK_MODELS
    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
        del _RERANK_MODELS
        gc.collect()

        _RERANK_MODELS = []
        for model_name in model_names:
            logger.info(f"Loading {model_name}")
            model = CrossEncoder(model_name)
            model.max_length = max_context_length
            _RERANK_MODELS.append(model)
    return _RERANK_MODELS


def get_intent_model_tokenizer(
    model_name: str = INTENT_MODEL_VERSION,
) -> "AutoTokenizer":
    # NOTE: doing a local import here to avoid reduce memory usage caused by
    # processes importing this file despite not using any of this
    from transformers import AutoTokenizer  # type: ignore

    global _INTENT_TOKENIZER
    if _INTENT_TOKENIZER is None:
        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    return _INTENT_TOKENIZER


def get_local_intent_model(
    model_name: str = INTENT_MODEL_VERSION,
    max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
) -> "TFDistilBertForSequenceClassification":
    # NOTE: doing a local import here to avoid reduce memory usage caused by
    # processes importing this file despite not using any of this
    from transformers import TFDistilBertForSequenceClassification  # type: ignore

    global _INTENT_MODEL
    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
            model_name
        )
        _INTENT_MODEL.max_seq_length = max_context_length
    return _INTENT_MODEL


def build_model_server_url(
    model_server_host: str | None,
    model_server_port: int | None,
) -> str | None:
    if not model_server_host or model_server_port is None:
        return None

    model_server_host: str,
    model_server_port: int,
) -> str:
    model_server_url = f"{model_server_host}:{model_server_port}"

    # use protocol if provided
@@ -184,8 +87,8 @@ class EmbeddingModel:
        query_prefix: str | None,
        passage_prefix: str | None,
        normalize: bool,
        server_host: str | None,  # Changes depending on indexing or inference
        server_port: int | None,
        server_host: str,  # Changes depending on indexing or inference
        server_port: int,
        # The following are globals are currently not configurable
        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
    ) -> None:
@@ -196,17 +99,7 @@ class EmbeddingModel:
        self.normalize = normalize

        model_server_url = build_model_server_url(server_host, server_port)
        self.embed_server_endpoint = (
            f"{model_server_url}/encoder/bi-encoder-embed" if model_server_url else None
        )

    def load_model(self) -> Optional["SentenceTransformer"]:
        if self.embed_server_endpoint:
            return None

        return get_local_embedding_model(
            model_name=self.model_name, max_context_length=self.max_seq_length
        )
        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

    def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
        if text_type == EmbedTextType.QUERY and self.query_prefix:
@@ -216,166 +109,67 @@ class EmbeddingModel:
        else:
            prefixed_texts = texts

        if self.embed_server_endpoint:
            embed_request = EmbedRequest(
                texts=prefixed_texts,
                model_name=self.model_name,
                normalize_embeddings=self.normalize,
            )
        embed_request = EmbedRequest(
            texts=prefixed_texts,
            model_name=self.model_name,
            max_context_length=self.max_seq_length,
            normalize_embeddings=self.normalize,
        )

            try:
                response = requests.post(
                    self.embed_server_endpoint, json=embed_request.dict()
                )
                response.raise_for_status()
        response = requests.post(self.embed_server_endpoint, json=embed_request.dict())
        response.raise_for_status()

                return EmbedResponse(**response.json()).embeddings
            except requests.RequestException as e:
                logger.exception(f"Failed to get Embedding: {e}")
                raise

        local_model = self.load_model()

        if local_model is None:
            raise RuntimeError("Failed to load local Embedding Model")

        return local_model.encode(
            prefixed_texts, normalize_embeddings=self.normalize
        ).tolist()
        return EmbedResponse(**response.json()).embeddings


class CrossEncoderEnsembleModel:
    def __init__(
        self,
        model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
        max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
        model_server_host: str | None = MODEL_SERVER_HOST,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ) -> None:
        self.model_names = model_names
        self.max_seq_length = max_seq_length

        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.rerank_server_endpoint = (
            model_server_url + "/encoder/cross-encoder-scores"
            if model_server_url
            else None
        )

    def load_model(self) -> list["CrossEncoder"] | None:
        if (
            ENABLE_RERANKING_REAL_TIME_FLOW is False
            and ENABLE_RERANKING_ASYNC_FLOW is False
        ):
            logger.warning(
                "Running rerankers but they are globally disabled."
                "Was this specified explicitly via an API?"
            )

        if self.rerank_server_endpoint:
            return None

        return get_local_reranking_model_ensemble(
            model_names=self.model_names, max_context_length=self.max_seq_length
        )
        self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"

    def predict(self, query: str, passages: list[str]) -> list[list[float]]:
        if self.rerank_server_endpoint:
            rerank_request = RerankRequest(query=query, documents=passages)
        rerank_request = RerankRequest(query=query, documents=passages)

            try:
                response = requests.post(
                    self.rerank_server_endpoint, json=rerank_request.dict()
                )
                response.raise_for_status()
        response = requests.post(
            self.rerank_server_endpoint, json=rerank_request.dict()
        )
        response.raise_for_status()

                return RerankResponse(**response.json()).scores
            except requests.RequestException as e:
                logger.exception(f"Failed to get Reranking Scores: {e}")
                raise

        local_models = self.load_model()

        if local_models is None:
            raise RuntimeError("Failed to load local Reranking Model Ensemble")

        scores = [
            cross_encoder.predict([(query, passage) for passage in passages]).tolist()  # type: ignore
            for cross_encoder in local_models
        ]

        return scores
        return RerankResponse(**response.json()).scores


class IntentModel:
    def __init__(
        self,
        model_name: str = INTENT_MODEL_VERSION,
        max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
        model_server_host: str | None = MODEL_SERVER_HOST,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ) -> None:
        self.model_name = model_name
        self.max_seq_length = max_seq_length

        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.intent_server_endpoint = (
            model_server_url + "/custom/intent-model" if model_server_url else None
        )

    def load_model(self) -> Optional["SentenceTransformer"]:
        if self.intent_server_endpoint:
            return None

        return get_local_intent_model(
            model_name=self.model_name, max_context_length=self.max_seq_length
        )
        self.intent_server_endpoint = model_server_url + "/custom/intent-model"

    def predict(
        self,
        query: str,
    ) -> list[float]:
        # NOTE: doing a local import here to avoid reduce memory usage caused by
        # processes importing this file despite not using any of this
        import tensorflow as tf  # type: ignore
        intent_request = IntentRequest(query=query)

        if self.intent_server_endpoint:
            intent_request = IntentRequest(query=query)

            try:
                response = requests.post(
                    self.intent_server_endpoint, json=intent_request.dict()
                )
                response.raise_for_status()

                return IntentResponse(**response.json()).class_probs
            except requests.RequestException as e:
                logger.exception(f"Failed to get Embedding: {e}")
                raise

        tokenizer = get_intent_model_tokenizer()
        local_model = self.load_model()

        if local_model is None:
            raise RuntimeError("Failed to load local Intent Model")

        intent_model = get_local_intent_model()
        model_input = tokenizer(
            query, return_tensors="tf", truncation=True, padding=True
        response = requests.post(
            self.intent_server_endpoint, json=intent_request.dict()
        )
        response.raise_for_status()

        predictions = intent_model(model_input)[0]
        probabilities = tf.nn.softmax(predictions, axis=-1)
        class_percentages = np.round(probabilities.numpy() * 100, 2)

        return list(class_percentages.tolist()[0])
        return IntentResponse(**response.json()).class_probs


def warm_up_models(
def warm_up_encoders(
    model_name: str,
    normalize: bool,
    skip_cross_encoders: bool = True,
    indexer_only: bool = False,
    model_server_host: str = MODEL_SERVER_HOST,
    model_server_port: int = MODEL_SERVER_PORT,
) -> None:
    warm_up_str = (
        "Danswer is amazing! Check out our easy deployment guide at "
@@ -387,23 +181,23 @@ def warm_up_models(
    embed_model = EmbeddingModel(
        model_name=model_name,
        normalize=normalize,
        # These don't matter, if it's a remote model, this function shouldn't be called
        # Not a big deal if prefix is incorrect
        query_prefix=None,
        passage_prefix=None,
        server_host=None,
        server_port=None,
        server_host=model_server_host,
        server_port=model_server_port,
    )

    embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)

    if indexer_only:
        return

    if not skip_cross_encoders:
        CrossEncoderEnsembleModel().predict(query=warm_up_str, passages=[warm_up_str])

    intent_tokenizer = get_intent_model_tokenizer()
    inputs = intent_tokenizer(
        warm_up_str, return_tensors="tf", truncation=True, padding=True
    )
    get_local_intent_model()(inputs)
    # First time downloading the models it may take even longer, but just in case,
    # retry the whole server
    wait_time = 5
    for attempt in range(20):
        try:
            embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
            return
        except Exception:
            logger.info(
                f"Failed to run test embedding, retrying in {wait_time} seconds..."
            )
            time.sleep(wait_time)
    raise Exception("Failed to run test embedding.")
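As a usage illustration of the now remote-only `EmbeddingModel` above, the hedged sketch below constructs the client and embeds a query against a running model server. The model name, prefixes, and host values are placeholder assumptions, not values taken from this diff, and any constructor arguments beyond those visible in the hunk are assumed.

```python
# Hypothetical usage sketch: embed a query via the model server.
from danswer.search.enums import EmbedTextType
from danswer.search.search_nlp_models import EmbeddingModel

embedding_model = EmbeddingModel(
    model_name="intfloat/e5-base-v2",  # placeholder model name
    normalize=True,
    query_prefix="query: ",            # placeholder prefixes
    passage_prefix="passage: ",
    server_host="localhost",           # assumes a model server on localhost:9000
    server_port=9000,
)
vectors = embedding_model.encode(["what is danswer?"], text_type=EmbedTextType.QUERY)
print(len(vectors), len(vectors[0]))
```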
@@ -21,3 +21,10 @@ def batch_generator(
        if pre_batch_yield:
            pre_batch_yield(batch)
        yield batch


def batch_list(
    lst: list[T],
    batch_size: int,
) -> list[list[T]]:
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]
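A quick usage sketch for the new `batch_list` helper; the input list is made up for illustration.

```python
# batch_list splits a flat list into consecutive sublists of at most batch_size items.
from danswer.utils.batching import batch_list

texts = ["a", "b", "c", "d", "e"]
assert batch_list(texts, batch_size=2) == [["a", "b"], ["c", "d"], ["e"]]
```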
backend/model_server/constants.py (new file)
@@ -0,0 +1 @@
MODEL_WARM_UP_STRING = "hi " * 512
@@ -1,19 +1,58 @@
import numpy as np
from fastapi import APIRouter
from typing import Optional

import numpy as np
import tensorflow as tf  # type: ignore
from fastapi import APIRouter
from transformers import AutoTokenizer  # type: ignore
from transformers import TFDistilBertForSequenceClassification

from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.nlp_model_configs import INDEXING_ONLY
from shared_configs.nlp_model_configs import INTENT_MODEL_CONTEXT_SIZE
from shared_configs.nlp_model_configs import INTENT_MODEL_VERSION

from danswer.search.search_nlp_models import get_intent_model_tokenizer
from danswer.search.search_nlp_models import get_local_intent_model
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse

router = APIRouter(prefix="/custom")

_INTENT_TOKENIZER: Optional[AutoTokenizer] = None
_INTENT_MODEL: Optional[TFDistilBertForSequenceClassification] = None

@log_function_time(print_only=True)

def get_intent_model_tokenizer(
    model_name: str = INTENT_MODEL_VERSION,
) -> "AutoTokenizer":
    global _INTENT_TOKENIZER
    if _INTENT_TOKENIZER is None:
        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
    return _INTENT_TOKENIZER


def get_local_intent_model(
    model_name: str = INTENT_MODEL_VERSION,
    max_context_length: int = INTENT_MODEL_CONTEXT_SIZE,
) -> TFDistilBertForSequenceClassification:
    global _INTENT_MODEL
    if _INTENT_MODEL is None or max_context_length != _INTENT_MODEL.max_seq_length:
        _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
            model_name
        )
        _INTENT_MODEL.max_seq_length = max_context_length
    return _INTENT_MODEL


def warm_up_intent_model() -> None:
    intent_tokenizer = get_intent_model_tokenizer()
    inputs = intent_tokenizer(
        MODEL_WARM_UP_STRING, return_tensors="tf", truncation=True, padding=True
    )
    get_local_intent_model()(inputs)


@simple_log_function_time()
def classify_intent(query: str) -> list[float]:
    import tensorflow as tf  # type:ignore

    tokenizer = get_intent_model_tokenizer()
    intent_model = get_local_intent_model()
    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
@@ -26,16 +65,11 @@ def classify_intent(query: str) -> list[float]:


@router.post("/intent-model")
def process_intent_request(
async def process_intent_request(
    intent_request: IntentRequest,
) -> IntentResponse:
    if INDEXING_ONLY:
        raise RuntimeError("Indexing model server should not call intent endpoint")

    class_percentages = classify_intent(intent_request.query)
    return IntentResponse(class_probs=class_percentages)


def warm_up_intent_model() -> None:
    intent_tokenizer = get_intent_model_tokenizer()
    inputs = intent_tokenizer(
        "danswer", return_tensors="tf", truncation=True, padding=True
    )
    get_local_intent_model()(inputs)
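For a concrete picture of how the API server consumes this router, here is a hedged client-side sketch that posts an `IntentRequest` to the `/custom/intent-model` endpoint and reads `class_probs` from the `IntentResponse`. The host and port are placeholder assumptions.

```python
# Hypothetical client sketch for the intent endpoint exposed above.
import requests

resp = requests.post(
    "http://localhost:9000/custom/intent-model",  # assumed model server address
    json={"query": "how do I reset my password?"},
)
resp.raise_for_status()
print(resp.json()["class_probs"])  # per-class percentages from IntentResponse
```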
@@ -1,34 +1,33 @@
from typing import TYPE_CHECKING
import gc
from typing import Optional

from fastapi import APIRouter
from fastapi import HTTPException
from sentence_transformers import CrossEncoder  # type: ignore
from sentence_transformers import SentenceTransformer  # type: ignore

from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse

if TYPE_CHECKING:
    from sentence_transformers import SentenceTransformer  # type: ignore

from model_server.constants import MODEL_WARM_UP_STRING
from model_server.utils import simple_log_function_time
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.nlp_model_configs import CROSS_EMBED_CONTEXT_SIZE
from shared_configs.nlp_model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from shared_configs.nlp_model_configs import INDEXING_ONLY

logger = setup_logger()

WARM_UP_STRING = "Danswer is amazing"

router = APIRouter(prefix="/encoder")

_GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {}
_RERANK_MODELS: Optional[list["CrossEncoder"]] = None


def get_embedding_model(
    model_name: str,
    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
    max_context_length: int,
) -> "SentenceTransformer":
    from sentence_transformers import SentenceTransformer  # type: ignore

@@ -48,11 +47,44 @@ def get_embedding_model(
    return _GLOBAL_MODELS_DICT[model_name]


@log_function_time(print_only=True)
def get_local_reranking_model_ensemble(
    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
) -> list[CrossEncoder]:
    global _RERANK_MODELS
    if _RERANK_MODELS is None or max_context_length != _RERANK_MODELS[0].max_length:
        del _RERANK_MODELS
        gc.collect()

        _RERANK_MODELS = []
        for model_name in model_names:
            logger.info(f"Loading {model_name}")
            model = CrossEncoder(model_name)
            model.max_length = max_context_length
            _RERANK_MODELS.append(model)
    return _RERANK_MODELS


def warm_up_cross_encoders() -> None:
    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")

    cross_encoders = get_local_reranking_model_ensemble()
    [
        cross_encoder.predict((MODEL_WARM_UP_STRING, MODEL_WARM_UP_STRING))
        for cross_encoder in cross_encoders
    ]


@simple_log_function_time()
def embed_text(
    texts: list[str], model_name: str, normalize_embeddings: bool
    texts: list[str],
    model_name: str,
    max_context_length: int,
    normalize_embeddings: bool,
) -> list[list[float]]:
    model = get_embedding_model(model_name=model_name)
    model = get_embedding_model(
        model_name=model_name, max_context_length=max_context_length
    )
    embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)

    if not isinstance(embeddings, list):
@@ -61,7 +93,7 @@ def embed_text(
    return embeddings


@log_function_time(print_only=True)
@simple_log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
    cross_encoders = get_local_reranking_model_ensemble()
    sim_scores = [
@@ -72,13 +104,14 @@ def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:


@router.post("/bi-encoder-embed")
def process_embed_request(
async def process_embed_request(
    embed_request: EmbedRequest,
) -> EmbedResponse:
    try:
        embeddings = embed_text(
            texts=embed_request.texts,
            model_name=embed_request.model_name,
            max_context_length=embed_request.max_context_length,
            normalize_embeddings=embed_request.normalize_embeddings,
        )
        return EmbedResponse(embeddings=embeddings)
@@ -87,7 +120,11 @@ def process_embed_request(


@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
async def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
    """Cross encoders can be purely black box from the app perspective"""
    if INDEXING_ONLY:
        raise RuntimeError("Indexing model server should not call intent endpoint")

    try:
        sim_scores = calc_sim_scores(
            query=embed_request.query, docs=embed_request.documents
@@ -95,13 +132,3 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
        return RerankResponse(scores=sim_scores)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def warm_up_cross_encoders() -> None:
    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")

    cross_encoders = get_local_reranking_model_ensemble()
    [
        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
        for cross_encoder in cross_encoders
    ]
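To show how the `/encoder/bi-encoder-embed` route above is exercised over HTTP, here is a hedged client sketch built from the `EmbedRequest` and `EmbedResponse` fields visible in this diff; the host, port, and model name are placeholder assumptions.

```python
# Hypothetical client sketch for the bi-encoder embedding endpoint.
import requests

payload = {
    "texts": ["passage: Danswer ties all your company knowledge together."],
    "model_name": "intfloat/e5-base-v2",  # placeholder model name
    "max_context_length": 512,
    "normalize_embeddings": True,
}
resp = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed",  # assumed model server address
    json=payload,
)
resp.raise_for_status()
embeddings = resp.json()["embeddings"]  # list of float vectors from EmbedResponse
print(len(embeddings), len(embeddings[0]))
```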
@@ -1,40 +1,61 @@
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

import torch
import uvicorn
from fastapi import FastAPI
from transformers import logging as transformer_logging  # type:ignore

from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_cross_encoders
from shared_configs.nlp_model_configs import ENABLE_RERANKING_ASYNC_FLOW
from shared_configs.nlp_model_configs import ENABLE_RERANKING_REAL_TIME_FLOW
from shared_configs.nlp_model_configs import INDEXING_ONLY
from shared_configs.nlp_model_configs import MIN_THREADS_ML_MODELS

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

transformer_logging.set_verbosity_error()

logger = setup_logger()


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    if torch.cuda.is_available():
        logger.info("GPU is available")
    else:
        logger.info("GPU is not available")

    torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
    logger.info(f"Torch Threads: {torch.get_num_threads()}")

    if not INDEXING_ONLY:
        warm_up_intent_model()
        if ENABLE_RERANKING_REAL_TIME_FLOW or ENABLE_RERANKING_ASYNC_FLOW:
            warm_up_cross_encoders()
    else:
        logger.info("This model server should only run document indexing.")

    yield


def get_model_app() -> FastAPI:
    application = FastAPI(title="Danswer Model Server", version=__version__)
    application = FastAPI(
        title="Danswer Model Server", version=__version__, lifespan=lifespan
    )

    application.include_router(encoders_router)
    application.include_router(custom_models_router)

    @application.on_event("startup")
    def startup_event() -> None:
        if torch.cuda.is_available():
            logger.info("GPU is available")
        else:
            logger.info("GPU is not available")

        torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
        logger.info(f"Torch Threads: {torch.get_num_threads()}")

        warm_up_cross_encoders()
        warm_up_intent_model()

    return application
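As a hedged sketch of how this app might be launched programmatically (the actual `__main__` block is not part of this hunk, so the exact entrypoint is an assumption), the values below reuse the config names imported at the top of the file.

```python
# Hypothetical launch sketch; the real entrypoint may differ.
import uvicorn

from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from model_server.main import get_model_app

if __name__ == "__main__":
    uvicorn.run(get_model_app(), host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
```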
41
backend/model_server/utils.py
Normal file
41
backend/model_server/utils.py
Normal file
@ -0,0 +1,41 @@
import time
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar

from danswer.utils.logger import setup_logger

logger = setup_logger()

F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])


def simple_log_function_time(
    func_name: str | None = None,
    debug_only: bool = False,
    include_args: bool = False,
) -> Callable[[F], F]:
    def decorator(func: F) -> F:
        @wraps(func)
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed_time_str = str(time.time() - start_time)
            log_name = func_name or func.__name__
            args_str = f" args={args} kwargs={kwargs}" if include_args else ""
            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
            if debug_only:
                logger.debug(final_log)
            else:
                logger.info(final_log)

            return result

        return cast(F, wrapped_func)

    return decorator
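`simple_log_function_time` is a standard parameterized decorator: calling it returns a decorator that wraps the target function and logs its wall-clock duration. A small usage sketch, assuming the file above is importable as `model_server.utils`; the decorated function below is purely illustrative:

```python
import time

from model_server.utils import simple_log_function_time  # assumes backend/ is on PYTHONPATH


@simple_log_function_time(func_name="embed_batch")
def embed_batch(texts: list[str]) -> list[list[float]]:
    time.sleep(0.1)  # stand-in for real model inference
    return [[0.0, 0.0, 0.0] for _ in texts]


embed_batch(["hello", "world"])
# Logs roughly: "embed_batch took 0.10... seconds" at INFO level
# (or at DEBUG level if debug_only=True is passed to the decorator).
```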
@ -54,19 +54,12 @@ requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0
rt==3.1.2
# need to pin `safetensors` version, since the latest versions require
# building from source using Rust
safetensors==0.4.2
sentence-transformers==2.6.1
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
starlette==0.36.3
supervisor==4.2.5
tensorflow==2.15.0
tiktoken==0.4.0
timeago==1.0.16
torch==2.0.1
torchvision==0.15.2
transformers==4.39.2
uvicorn==0.21.1
zulip==0.8.2
@ -1,8 +1,8 @@
fastapi==0.109.1
fastapi==0.109.2
pydantic==1.10.7
safetensors==0.3.1
sentence-transformers==2.2.2
safetensors==0.4.2
sentence-transformers==2.6.1
tensorflow==2.15.0
torch==2.0.1
transformers==4.36.2
transformers==4.39.2
uvicorn==0.21.1
@ -2,8 +2,10 @@ from pydantic import BaseModel


class EmbedRequest(BaseModel):
    # These texts already include any prefixes; they are passed directly to the model
    texts: list[str]
    model_name: str
    max_context_length: int
    normalize_embeddings: bool
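`EmbedRequest` is the request schema the model server expects for embedding calls. Below is a self-contained sketch of building and serializing that payload with Pydantic v1 (matching the pinned `pydantic==1.10.7`); the class is re-declared locally so the snippet runs on its own, the field values are placeholders, and the exact HTTP route the model server exposes is not shown in this diff:

```python
from pydantic import BaseModel


class EmbedRequest(BaseModel):  # re-declared here for a self-contained example
    texts: list[str]
    model_name: str
    max_context_length: int
    normalize_embeddings: bool


payload = EmbedRequest(
    texts=["query: how do I set up Danswer?"],  # prefixes, if any, are already included
    model_name="intfloat/e5-base-v2",           # placeholder model name
    max_context_length=512,
    normalize_embeddings=True,
)
print(payload.json())  # Pydantic v1 serialization of the request body
```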
26 backend/shared_configs/nlp_model_configs.py Normal file
@ -0,0 +1,26 @@
import os


# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
INTENT_MODEL_CONTEXT_SIZE = 256

# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512

# Cross Encoder Settings
ENABLE_RERANKING_ASYNC_FLOW = (
    os.environ.get("ENABLE_RERANKING_ASYNC_FLOW", "").lower() == "true"
)
ENABLE_RERANKING_REAL_TIME_FLOW = (
    os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true"
)
# Only using one cross-encoder for now
CROSS_ENCODER_MODEL_ENSEMBLE = ["mixedbread-ai/mxbai-rerank-xsmall-v1"]
CROSS_EMBED_CONTEXT_SIZE = 512

# This controls the minimum number of pytorch "threads" to allocate to the embedding
# model. If torch finds more threads on its own, this value is not used.
MIN_THREADS_ML_MODELS = int(os.environ.get("MIN_THREADS_ML_MODELS") or 1)

INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"
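These boolean flags are parsed by lower-casing the environment variable and comparing it to the literal string "true", so any casing of "true" enables a flag while anything else (including an unset variable) leaves it off. That is why the compose files below can pass `INDEXING_ONLY=True` and still have the flag take effect. A small sketch of the same parsing rule:

```python
import os

# Any casing of "true" enables the flag.
os.environ["ENABLE_RERANKING_REAL_TIME_FLOW"] = "True"
print(os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true")  # True

# Other truthy-looking values such as "1" do NOT enable it.
os.environ["ENABLE_RERANKING_REAL_TIME_FLOW"] = "1"
print(os.environ.get("ENABLE_RERANKING_REAL_TIME_FLOW", "").lower() == "true")  # False
```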
@ -67,7 +67,7 @@ services:
      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-}
      - ENABLE_RERANKING_REAL_TIME_FLOW=${ENABLE_RERANKING_REAL_TIME_FLOW:-}
      - ENABLE_RERANKING_ASYNC_FLOW=${ENABLE_RERANKING_ASYNC_FLOW:-}
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
      # Leave this on pretty please? Nothing sensitive is collected!
      # https://docs.danswer.dev/more/telemetry
@ -80,9 +80,7 @@ services:
    volumes:
      - local_dynamic_storage:/home/storage
      - file_connector_tmp_storage:/home/file_connector_storage
      - model_cache_torch:/root/.cache/torch/
      - model_cache_nltk:/root/nltk_data/
      - model_cache_huggingface:/root/.cache/huggingface/
    extra_hosts:
      - "host.docker.internal:host-gateway"
    logging:
@ -90,6 +88,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  background:
    image: danswer/danswer-backend:latest
    build:
@ -137,10 +137,9 @@ services:
      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
      - ASYM_QUERY_PREFIX=${ASYM_QUERY_PREFIX:-} # Needed by DanswerBot
      - ASYM_PASSAGE_PREFIX=${ASYM_PASSAGE_PREFIX:-}
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-inference_model_server}
      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-}
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      - INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
      # Indexing Configs
      - NUM_INDEXING_WORKERS=${NUM_INDEXING_WORKERS:-}
      - DISABLE_INDEX_UPDATE_ON_SWAP=${DISABLE_INDEX_UPDATE_ON_SWAP:-}
@ -174,9 +173,7 @@ services:
    volumes:
      - local_dynamic_storage:/home/storage
      - file_connector_tmp_storage:/home/file_connector_storage
      - model_cache_torch:/root/.cache/torch/
      - model_cache_nltk:/root/nltk_data/
      - model_cache_huggingface:/root/.cache/huggingface/
    extra_hosts:
      - "host.docker.internal:host-gateway"
    logging:
@ -184,6 +181,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  web_server:
    image: danswer/danswer-web-server:latest
    build:
@ -198,6 +197,63 @@ services:
    environment:
      - INTERNAL_URL=http://api_server:8080
      - WEB_DOMAIN=${WEB_DOMAIN:-}

  inference_model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
        echo 'Skipping service...';
        exit 0;
      else
        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  indexing_model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
        echo 'Skipping service...';
        exit 0;
      else
        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      - INDEXING_ONLY=True
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  relational_db:
    image: postgres:15.2-alpine
    restart: always
@ -208,6 +264,8 @@ services:
      - "5432:5432"
    volumes:
      - db_volume:/var/lib/postgresql/data

  # This container name cannot have an underscore in it due to Vespa expectations of the URL
  index:
    image: vespaengine/vespa:8.277.17
@ -222,6 +280,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  nginx:
    image: nginx:1.23.4-alpine
    restart: always
@ -250,32 +310,8 @@ services:
    command: >
      /bin/sh -c "dos2unix /etc/nginx/conf.d/run-nginx.sh
      && /etc/nginx/conf.d/run-nginx.sh app.conf.template.dev"
  # Run with --profile model-server to bring up the danswer-model-server container
  # Be sure to change MODEL_SERVER_HOST (see above) as well
  # ie. MODEL_SERVER_HOST="model_server" docker compose -f docker-compose.dev.yml -p danswer-stack --profile model-server up -d --build
  model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    profiles:
      - "model-server"
    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
    restart: always
    environment:
      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

volumes:
  local_dynamic_storage:
  file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
@ -22,9 +22,7 @@ services:
    volumes:
      - local_dynamic_storage:/home/storage
      - file_connector_tmp_storage:/home/file_connector_storage
      - model_cache_torch:/root/.cache/torch/
      - model_cache_nltk:/root/nltk_data/
      - model_cache_huggingface:/root/.cache/huggingface/
    extra_hosts:
      - "host.docker.internal:host-gateway"
    logging:
@ -32,6 +30,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  background:
    image: danswer/danswer-backend:latest
    build:
@ -51,9 +51,7 @@ services:
    volumes:
      - local_dynamic_storage:/home/storage
      - file_connector_tmp_storage:/home/file_connector_storage
      - model_cache_torch:/root/.cache/torch/
      - model_cache_nltk:/root/nltk_data/
      - model_cache_huggingface:/root/.cache/huggingface/
    extra_hosts:
      - "host.docker.internal:host-gateway"
    logging:
@ -61,6 +59,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  web_server:
    image: danswer/danswer-web-server:latest
    build:
@ -81,6 +81,63 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  inference_model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
        echo 'Skipping service...';
        exit 0;
      else
        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  indexing_model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
        echo 'Skipping service...';
        exit 0;
      else
        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      - INDEXING_ONLY=True
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  relational_db:
    image: postgres:15.2-alpine
    restart: always
@ -94,6 +151,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  # This container name cannot have an underscore in it due to Vespa expectations of the URL
  index:
    image: vespaengine/vespa:8.277.17
@ -108,6 +167,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  nginx:
    image: nginx:1.23.4-alpine
    restart: always
@ -137,30 +198,8 @@ services:
      && /etc/nginx/conf.d/run-nginx.sh app.conf.template.no-letsencrypt"
    env_file:
      - .env.nginx
  # Run with --profile model-server to bring up the danswer-model-server container
  model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    profiles:
      - "model-server"
    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
    restart: always
    environment:
      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

volumes:
  local_dynamic_storage:
  file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
@ -22,9 +22,7 @@ services:
    volumes:
      - local_dynamic_storage:/home/storage
      - file_connector_tmp_storage:/home/file_connector_storage
      - model_cache_torch:/root/.cache/torch/
      - model_cache_nltk:/root/nltk_data/
      - model_cache_huggingface:/root/.cache/huggingface/
    extra_hosts:
      - "host.docker.internal:host-gateway"
    logging:
@ -32,6 +30,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  background:
    image: danswer/danswer-backend:latest
    build:
@ -51,9 +51,7 @@ services:
    volumes:
      - local_dynamic_storage:/home/storage
      - file_connector_tmp_storage:/home/file_connector_storage
      - model_cache_torch:/root/.cache/torch/
      - model_cache_nltk:/root/nltk_data/
      - model_cache_huggingface:/root/.cache/huggingface/
    extra_hosts:
      - "host.docker.internal:host-gateway"
    logging:
@ -61,6 +59,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  web_server:
    image: danswer/danswer-web-server:latest
    build:
@ -94,6 +94,63 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  inference_model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
        echo 'Skipping service...';
        exit 0;
      else
        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  indexing_model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    command: >
      /bin/sh -c "if [ \"${DISABLE_MODEL_SERVER:-false}\" = \"True\" ]; then
        echo 'Skipping service...';
        exit 0;
      else
        exec uvicorn model_server.main:app --host 0.0.0.0 --port 9000;
      fi"
    restart: on-failure
    environment:
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      - INDEXING_ONLY=True
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

  # This container name cannot have an underscore in it due to Vespa expectations of the URL
  index:
    image: vespaengine/vespa:8.277.17
@ -108,6 +165,8 @@ services:
      options:
        max-size: "50m"
        max-file: "6"

  nginx:
    image: nginx:1.23.4-alpine
    restart: always
@ -141,6 +200,8 @@ services:
      && /etc/nginx/conf.d/run-nginx.sh app.conf.template"
    env_file:
      - .env.nginx

  # follows https://pentacent.medium.com/nginx-and-lets-encrypt-with-docker-in-less-than-5-minutes-b4b8a60d3a71
  certbot:
    image: certbot/certbot
@ -154,30 +215,8 @@ services:
        max-size: "50m"
        max-file: "6"
    entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
  # Run with --profile model-server to bring up the danswer-model-server container
  model_server:
    image: danswer/danswer-model-server:latest
    build:
      context: ../../backend
      dockerfile: Dockerfile.model_server
    profiles:
      - "model-server"
    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
    restart: always
    environment:
      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
      # Set to debug to get more fine-grained logs
      - LOG_LEVEL=${LOG_LEVEL:-info}
    volumes:
      - model_cache_torch:/root/.cache/torch/
      - model_cache_huggingface:/root/.cache/huggingface/
    logging:
      driver: json-file
      options:
        max-size: "50m"
        max-file: "6"

volumes:
  local_dynamic_storage:
  file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
@ -43,9 +43,9 @@ data:
  ASYM_PASSAGE_PREFIX: ""
  ENABLE_RERANKING_REAL_TIME_FLOW: ""
  ENABLE_RERANKING_ASYNC_FLOW: ""
  MODEL_SERVER_HOST: ""
  MODEL_SERVER_HOST: "inference-model-server-service"
  MODEL_SERVER_PORT: ""
  INDEXING_MODEL_SERVER_HOST: ""
  INDEXING_MODEL_SERVER_HOST: "indexing-model-server-service"
  MIN_THREADS_ML_MODELS: ""
  # Indexing Configs
  NUM_INDEXING_WORKERS: ""
@ -0,0 +1,59 @@
apiVersion: v1
kind: Service
metadata:
  name: indexing-model-server-service
spec:
  selector:
    app: indexing-model-server
  ports:
    - name: indexing-model-server-port
      protocol: TCP
      port: 9000
      targetPort: 9000
  type: ClusterIP
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: indexing-model-server-deployment
spec:
  replicas: 1
  selector:
    matchLabels:
      app: indexing-model-server
  template:
    metadata:
      labels:
        app: indexing-model-server
    spec:
      containers:
        - name: indexing-model-server
          image: danswer/danswer-model-server:latest
          imagePullPolicy: IfNotPresent
          command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
          ports:
            - containerPort: 9000
          envFrom:
            - configMapRef:
                name: env-configmap
          env:
            - name: INDEXING_ONLY
              value: "True"
          volumeMounts:
            - name: indexing-model-storage
              mountPath: /root/.cache
      volumes:
        - name: indexing-model-storage
          persistentVolumeClaim:
            claimName: indexing-model-pvc
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: indexing-model-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 3Gi
@ -0,0 +1,56 @@
apiVersion: v1
kind: Service
metadata:
  name: inference-model-server-service
spec:
  selector:
    app: inference-model-server
  ports:
    - name: inference-model-server-port
      protocol: TCP
      port: 9000
      targetPort: 9000
  type: ClusterIP
---
apiVersion: apps/v1
kind: Deployment
metadata:
  name: inference-model-server-deployment
spec:
  replicas: 1
  selector:
    matchLabels:
      app: inference-model-server
  template:
    metadata:
      labels:
        app: inference-model-server
    spec:
      containers:
        - name: inference-model-server
          image: danswer/danswer-model-server:latest
          imagePullPolicy: IfNotPresent
          command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000" ]
          ports:
            - containerPort: 9000
          envFrom:
            - configMapRef:
                name: env-configmap
          volumeMounts:
            - name: inference-model-storage
              mountPath: /root/.cache
      volumes:
        - name: inference-model-storage
          persistentVolumeClaim:
            claimName: inference-model-pvc
---
apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: inference-model-pvc
spec:
  accessModes:
    - ReadWriteOnce
  resources:
    requests:
      storage: 3Gi