Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-07-28 13:53:28 +02:00)

Model Server (#695)

Provides the ability to pull the NLP models out into a separate model server, which can then be hosted on a GPU instance if desired.
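For orientation: when MODEL_SERVER_HOST is set, the backend forwards embedding, reranking, and intent-classification calls to the server's HTTP endpoints instead of loading the models in-process, and the endpoints speak the small pydantic models under shared_models. A minimal sketch of calling the embedding endpoint directly, assuming a server is already running on localhost:9000 (the default port):

import requests

# Endpoint path and payload shape come from model_server/encoders.py and
# shared_models/model_server_models.py in this commit.
response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed",
    json={"texts": ["Danswer is amazing"]},  # EmbedRequest
)
response.raise_for_status()
embeddings = response.json()["embeddings"]  # EmbedResponse: one vector per input text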
The commit touches the following files:
.github/workflows/docker-build-push-model-server-container-on-tag.yml (new file, 36 lines)

@@ -0,0 +1,36 @@
name: Build and Push Model Server Image on Tagging

on:
  push:
    tags:
      - '*'

jobs:
  build-and-push:
    runs-on: ubuntu-latest

    steps:
    - name: Checkout code
      uses: actions/checkout@v2

    - name: Set up Docker Buildx
      uses: docker/setup-buildx-action@v1

    - name: Login to Docker Hub
      uses: docker/login-action@v1
      with:
        username: ${{ secrets.DOCKER_USERNAME }}
        password: ${{ secrets.DOCKER_TOKEN }}

    - name: Model Server Image Docker Build and Push
      uses: docker/build-push-action@v2
      with:
        context: ./backend
        file: ./backend/Dockerfile.model_server
        platforms: linux/amd64,linux/arm64
        push: true
        tags: |
          danswer/danswer-model-server:${{ github.ref_name }}
          danswer/danswer-model-server:latest
        build-args: |
          DANSWER_VERSION=${{ github.ref_name }}
backend/Dockerfile

@@ -45,6 +45,7 @@ RUN apt-get remove -y linux-libc-dev && \
 # Set up application files
 WORKDIR /app
 COPY ./danswer /app/danswer
+COPY ./shared_models /app/shared_models
 COPY ./alembic /app/alembic
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
backend/Dockerfile.model_server (new file, 27 lines)

@@ -0,0 +1,27 @@
FROM python:3.11.4-slim-bookworm

# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.2-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}

COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt

WORKDIR /app
# Needed for model configs and defaults
COPY ./danswer/configs /app/danswer/configs
# Utils used by model server
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
# Version information
COPY ./danswer/__init__.py /app/danswer/__init__.py
# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
# Request/Response models
COPY ./shared_models /app/shared_models
# Model Server main code
COPY ./model_server /app/model_server

ENV PYTHONPATH /app

CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
backend/danswer/background/update.py

@@ -13,6 +13,7 @@ from danswer.background.indexing.job_client import SimpleJob
 from danswer.background.indexing.job_client import SimpleJobClient
 from danswer.background.indexing.run_indexing import run_indexing_entrypoint
 from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
+from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.db.connector import fetch_connectors

@@ -290,7 +291,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None


 if __name__ == "__main__":
-    logger.info("Warming up Embedding Model(s)")
-    warm_up_models(indexer_only=True)
+    if not MODEL_SERVER_HOST:
+        logger.info("Warming up Embedding Model(s)")
+        warm_up_models(indexer_only=True)
     logger.info("Starting Indexing Loop")
     update_loop()
backend/danswer/configs/app_configs.py

@@ -3,6 +3,7 @@ import os
 from danswer.configs.constants import AuthType
 from danswer.configs.constants import DocumentIndexType

+
 #####
 # App Configs
 #####

@@ -19,6 +20,7 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400  # 1 day
 # Use this if you want to use Danswer as a search engine only without the LLM capabilities
 DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"

+
 #####
 # Web Configs
 #####

@@ -56,7 +58,6 @@ VALID_EMAIL_DOMAINS = (
     if _VALID_EMAIL_DOMAINS_STR
     else []
 )
-
 # OAuth Login Flow
 # Used for both Google OAuth2 and OIDC flows
 OAUTH_CLIENT_ID = (

@@ -200,12 +201,13 @@ MINI_CHUNK_SIZE = 150


 #####
-# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
+# Model Server Configs
 #####
-BI_ENCODER_HOST = "localhost"
-BI_ENCODER_PORT = 9000
-CROSS_ENCODER_HOST = "localhost"
-CROSS_ENCODER_PORT = 9000
+# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to this server
+# via requests. Use a bare hostname; http:// is prepended when the endpoint URLs are built.
+MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
+MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_ALLOWED_HOST") or "0.0.0.0"
+MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")


 #####
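A detail worth noting: these settings fall back with `or` rather than with os.environ.get defaults because the Compose files later in this commit pass variables through as ${VAR:-}, which expands to an empty string when the variable is unset; an empty string would defeat a plain default argument. A small illustration:

import os

os.environ["MODEL_SERVER_PORT"] = ""  # what ${MODEL_SERVER_PORT:-} expands to when unset

# A default argument keeps the empty string (and int("") would then raise):
print(repr(os.environ.get("MODEL_SERVER_PORT", "9000")))   # ''
# The `or` idiom used above falls through to the default instead:
print(int(os.environ.get("MODEL_SERVER_PORT") or "9000"))  # 9000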
backend/danswer/indexing/embedder.py

@@ -1,16 +1,14 @@
-import numpy
 from sentence_transformers import SentenceTransformer  # type: ignore

 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
-from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
 from danswer.search.models import Embedder
-from danswer.search.search_nlp_models import get_default_embedding_model
+from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.utils.timing import log_function_time

@@ -24,7 +22,7 @@ def encode_chunks(
 ) -> list[IndexChunk]:
     embedded_chunks: list[IndexChunk] = []
     if embedding_model is None:
-        embedding_model = get_default_embedding_model()
+        embedding_model = EmbeddingModel()

     chunk_texts = []
     chunk_mini_chunks_count = {}

@@ -43,15 +41,10 @@
         chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
     ]

-    embeddings_np: list[numpy.ndarray] = []
+    embeddings: list[list[float]] = []
     for text_batch in text_batches:
         # Normalize embeddings is only configured via model_configs.py, be sure to use the right value for the set loss
-        embeddings_np.extend(
-            embedding_model.encode(
-                text_batch, normalize_embeddings=NORMALIZE_EMBEDDINGS
-            )
-        )
-    embeddings: list[list[float]] = [embedding.tolist() for embedding in embeddings_np]
+        embeddings.extend(embedding_model.encode(text_batch))

     embedding_ind_start = 0
     for chunk_ind, chunk in enumerate(chunks):
backend/danswer/main.py

@@ -1,4 +1,5 @@
 import nltk  # type:ignore
+import torch
 import uvicorn
 from fastapi import FastAPI
 from fastapi import Request

@@ -7,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from httpx_oauth.clients.google import GoogleOAuth2

+from danswer import __version__
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRead
 from danswer.auth.schemas import UserUpdate

@@ -17,6 +19,8 @@ from danswer.configs.app_configs import APP_HOST
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import SECRET

@@ -72,7 +76,7 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:


 def get_application() -> FastAPI:
-    application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
+    application = FastAPI(title="Danswer Backend", version=__version__)
     application.include_router(backend_router)
     application.include_router(chat_router)
     application.include_router(event_processing_router)

@@ -176,11 +180,23 @@
         logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
         logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')

-        logger.info("Warming up local NLP models.")
-        warm_up_models()
-        qa_model = get_default_qa_model()
+        if MODEL_SERVER_HOST:
+            logger.info(
+                f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
+            )
+        else:
+            logger.info("Warming up local NLP models.")
+            if torch.cuda.is_available():
+                logger.info("GPU is available")
+            else:
+                logger.info("GPU is not available")
+            logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+            warm_up_models()

         # This is for the LLM, most LLMs will not need warming up
-        qa_model.warm_up_model()
+        # It logs for itself
+        get_default_qa_model().warm_up_model()

         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         nltk.download("stopwords", quiet=True)
backend/danswer/search/danswer_helper.py

@@ -1,12 +1,9 @@
-import numpy as np
-import tensorflow as tf  # type:ignore
 from transformers import AutoTokenizer  # type:ignore

 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
-from danswer.search.search_nlp_models import get_default_intent_model
-from danswer.search.search_nlp_models import get_default_intent_model_tokenizer
 from danswer.search.search_nlp_models import get_default_tokenizer
+from danswer.search.search_nlp_models import IntentModel
 from danswer.search.search_runner import remove_stop_words
 from danswer.server.models import HelperResponse
 from danswer.utils.logger import setup_logger

@@ -28,15 +25,11 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:

 @log_function_time()
 def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
-    tokenizer = get_default_intent_model_tokenizer()
-    intent_model = get_default_intent_model()
-    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
-
-    predictions = intent_model(model_input)[0]
-    probabilities = tf.nn.softmax(predictions, axis=-1)
-    class_percentages = np.round(probabilities.numpy() * 100, 2)
-
-    keyword, semantic, qa = class_percentages.tolist()[0]
+    intent_model = IntentModel()
+    class_probs = intent_model.predict(query)
+    keyword = class_probs[0]
+    semantic = class_probs[1]
+    qa = class_probs[2]

     # Heavily bias towards QA, from user perspective, answering a statement is not as bad as not answering a question
     if qa > 20:
backend/danswer/search/search_nlp_models.py

@@ -1,15 +1,30 @@
+import numpy as np
+import requests
+import tensorflow as tf  # type: ignore
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 from transformers import AutoTokenizer  # type: ignore
 from transformers import TFDistilBertForSequenceClassification  # type: ignore

+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.configs.model_configs import INTENT_MODEL_VERSION
+from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
 from danswer.configs.model_configs import SKIP_RERANKING
+from danswer.utils.logger import setup_logger
+from shared_models.model_server_models import EmbedRequest
+from shared_models.model_server_models import EmbedResponse
+from shared_models.model_server_models import IntentRequest
+from shared_models.model_server_models import IntentResponse
+from shared_models.model_server_models import RerankRequest
+from shared_models.model_server_models import RerankResponse
+
+logger = setup_logger()


 _TOKENIZER: None | AutoTokenizer = None

@@ -26,39 +41,46 @@ def get_default_tokenizer() -> AutoTokenizer:
     return _TOKENIZER


-def get_default_embedding_model() -> SentenceTransformer:
+def get_local_embedding_model(
+    model_name: str = DOCUMENT_ENCODER_MODEL,
+    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+) -> SentenceTransformer:
     global _EMBED_MODEL
     if _EMBED_MODEL is None:
-        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
-        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
+        _EMBED_MODEL = SentenceTransformer(model_name)
+        _EMBED_MODEL.max_seq_length = max_context_length
     return _EMBED_MODEL


-def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
+def get_local_reranking_model_ensemble(
+    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
+) -> list[CrossEncoder]:
     global _RERANK_MODELS
     if _RERANK_MODELS is None:
-        _RERANK_MODELS = [
-            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
-        ]
+        _RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
         for model in _RERANK_MODELS:
-            model.max_length = CROSS_EMBED_CONTEXT_SIZE
+            model.max_length = max_context_length
     return _RERANK_MODELS


-def get_default_intent_model_tokenizer() -> AutoTokenizer:
+def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
+        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
     return _INTENT_TOKENIZER


-def get_default_intent_model() -> TFDistilBertForSequenceClassification:
+def get_local_intent_model(
+    model_name: str = INTENT_MODEL_VERSION,
+    max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
+) -> TFDistilBertForSequenceClassification:
     global _INTENT_MODEL
     if _INTENT_MODEL is None:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
-            INTENT_MODEL_VERSION
+            model_name
         )
-        _INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
+        _INTENT_MODEL.max_seq_length = max_context_length
     return _INTENT_MODEL

@@ -67,20 +89,183 @@ def warm_up_models(
 ) -> None:
     warm_up_str = "Danswer is amazing"
     get_default_tokenizer()(warm_up_str)
-    get_default_embedding_model().encode(warm_up_str)
+    get_local_embedding_model().encode(warm_up_str)

     if indexer_only:
         return

     if not skip_cross_encoders:
-        cross_encoders = get_default_reranking_model_ensemble()
+        cross_encoders = get_local_reranking_model_ensemble()
         [
             cross_encoder.predict((warm_up_str, warm_up_str))
             for cross_encoder in cross_encoders
         ]

-    intent_tokenizer = get_default_intent_model_tokenizer()
+    intent_tokenizer = get_intent_model_tokenizer()
     inputs = intent_tokenizer(
         warm_up_str, return_tensors="tf", truncation=True, padding=True
     )
-    get_default_intent_model()(inputs)
+    get_local_intent_model()(inputs)
+
+
+class EmbeddingModel:
+    def __init__(
+        self,
+        model_name: str = DOCUMENT_ENCODER_MODEL,
+        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_name = model_name
+        self.max_seq_length = max_seq_length
+        self.embed_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/encoder/bi-encoder-embed"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> SentenceTransformer | None:
+        if self.embed_server_endpoint:
+            return None
+
+        return get_local_embedding_model(
+            model_name=self.model_name, max_context_length=self.max_seq_length
+        )
+
+    def encode(
+        self, texts: list[str], normalize_embeddings: bool = NORMALIZE_EMBEDDINGS
+    ) -> list[list[float]]:
+        if self.embed_server_endpoint:
+            embed_request = EmbedRequest(texts=texts)
+
+            try:
+                response = requests.post(
+                    self.embed_server_endpoint, json=embed_request.dict()
+                )
+                response.raise_for_status()
+
+                return EmbedResponse(**response.json()).embeddings
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Embedding: {e}")
+                raise
+
+        local_model = self.load_model()
+
+        if local_model is None:
+            raise RuntimeError("Failed to load local Embedding Model")
+
+        return local_model.encode(
+            texts, normalize_embeddings=normalize_embeddings
+        ).tolist()
+
+
+class CrossEncoderEnsembleModel:
+    def __init__(
+        self,
+        model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+        max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_names = model_names
+        self.max_seq_length = max_seq_length
+        self.rerank_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/encoder/cross-encoder-scores"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> list[CrossEncoder] | None:
+        if self.rerank_server_endpoint:
+            return None
+
+        return get_local_reranking_model_ensemble(
+            model_names=self.model_names, max_context_length=self.max_seq_length
+        )
+
+    def predict(self, query: str, passages: list[str]) -> list[list[float]]:
+        if self.rerank_server_endpoint:
+            rerank_request = RerankRequest(query=query, documents=passages)
+
+            try:
+                response = requests.post(
+                    self.rerank_server_endpoint, json=rerank_request.dict()
+                )
+                response.raise_for_status()
+
+                return RerankResponse(**response.json()).scores
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Reranking Scores: {e}")
+                raise
+
+        local_models = self.load_model()
+
+        if local_models is None:
+            raise RuntimeError("Failed to load local Reranking Model Ensemble")
+
+        scores = [
+            cross_encoder.predict([(query, passage) for passage in passages]).tolist()  # type: ignore
+            for cross_encoder in local_models
+        ]
+
+        return scores
+
+
+class IntentModel:
+    def __init__(
+        self,
+        model_name: str = INTENT_MODEL_VERSION,
+        max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_name = model_name
+        self.max_seq_length = max_seq_length
+        self.intent_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/custom/intent-model"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> TFDistilBertForSequenceClassification | None:
+        if self.intent_server_endpoint:
+            return None
+
+        return get_local_intent_model(
+            model_name=self.model_name, max_context_length=self.max_seq_length
+        )
+
+    def predict(
+        self,
+        query: str,
+    ) -> list[float]:
+        if self.intent_server_endpoint:
+            intent_request = IntentRequest(query=query)
+
+            try:
+                response = requests.post(
+                    self.intent_server_endpoint, json=intent_request.dict()
+                )
+                response.raise_for_status()
+
+                return IntentResponse(**response.json()).class_probs
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Intent Classification: {e}")
+                raise
+
+        tokenizer = get_intent_model_tokenizer()
+        local_model = self.load_model()
+
+        if local_model is None:
+            raise RuntimeError("Failed to load local Intent Model")
+
+        model_input = tokenizer(
+            query, return_tensors="tf", truncation=True, padding=True
+        )
+
+        predictions = local_model(model_input)[0]
+        probabilities = tf.nn.softmax(predictions, axis=-1)
+        class_percentages = np.round(probabilities.numpy() * 100, 2)
+
+        return list(class_percentages.tolist()[0])
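With these classes, callers no longer need to know where a model runs: the same call either loads the model in-process or round-trips through the model server, depending on whether MODEL_SERVER_HOST is set. A usage sketch (when running locally, the first call downloads the models):

from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.search.search_nlp_models import IntentModel

embeddings = EmbeddingModel().encode(["Danswer is amazing"])  # list[list[float]]
scores = CrossEncoderEnsembleModel().predict(
    query="What is Danswer?", passages=["Danswer is an open source search tool."]
)  # one list of scores per cross-encoder in the ensemble
class_probs = IntentModel().predict("What is Danswer?")  # [keyword, semantic, qa] percentages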
backend/danswer/search/search_runner.py

@@ -1,4 +1,5 @@
 from collections.abc import Callable
+from typing import cast

 import numpy
 from nltk.corpus import stopwords  # type:ignore

@@ -29,8 +30,8 @@ from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SearchQuery
 from danswer.search.models import SearchType
-from danswer.search.search_nlp_models import get_default_embedding_model
-from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
+from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
+from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchDoc
 from danswer.utils.logger import setup_logger

@@ -67,14 +68,11 @@ def embed_query(
     prefix: str = ASYM_QUERY_PREFIX,
     normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
-    model = embedding_model or get_default_embedding_model()
+    model = embedding_model or EmbeddingModel()
     prefixed_query = prefix + query
     query_embedding = model.encode(
-        prefixed_query, normalize_embeddings=normalize_embeddings
-    )
-
-    if not isinstance(query_embedding, list):
-        query_embedding = query_embedding.tolist()
+        [prefixed_query], normalize_embeddings=normalize_embeddings
+    )[0]

     return query_embedding

@@ -104,6 +102,31 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]:
     return search_docs


+@log_function_time()
+def doc_index_retrieval(
+    query: SearchQuery, document_index: DocumentIndex
+) -> list[InferenceChunk]:
+    if query.search_type == SearchType.KEYWORD:
+        top_chunks = document_index.keyword_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+
+    elif query.search_type == SearchType.SEMANTIC:
+        top_chunks = document_index.semantic_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+
+    elif query.search_type == SearchType.HYBRID:
+        top_chunks = document_index.hybrid_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+
+    else:
+        raise RuntimeError("Invalid Search Flow")
+
+    return top_chunks
+
+
 @log_function_time()
 def semantic_reranking(
     query: str,

@@ -112,13 +135,13 @@ def semantic_reranking(
     model_min: int = CROSS_ENCODER_RANGE_MIN,
     model_max: int = CROSS_ENCODER_RANGE_MAX,
 ) -> list[InferenceChunk]:
-    cross_encoders = get_default_reranking_model_ensemble()
-    sim_scores = [
-        encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
-        for encoder in cross_encoders
-    ]
-
-    raw_sim_scores = sum(sim_scores) / len(sim_scores)
+    cross_encoders = CrossEncoderEnsembleModel()
+    passages = [chunk.content for chunk in chunks]
+    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
+
+    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
+
+    raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))

     cross_models_min = numpy.min(sim_scores)

@@ -270,23 +293,7 @@ def search_chunks(
         ]
         logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")

-    if query.search_type == SearchType.KEYWORD:
-        top_chunks = document_index.keyword_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-
-    elif query.search_type == SearchType.SEMANTIC:
-        top_chunks = document_index.semantic_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-
-    elif query.search_type == SearchType.HYBRID:
-        top_chunks = document_index.hybrid_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-
-    else:
-        raise RuntimeError("Invalid Search Flow")
-
+    top_chunks = doc_index_retrieval(query=query, document_index=document_index)

     if not top_chunks:
         logger.info(
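The reranking arithmetic is unchanged: every cross-encoder scores each passage against the query, and the ensemble score is the element-wise mean, now computed over the plain float lists that CrossEncoderEnsembleModel.predict returns. In miniature, with hypothetical scores from a two-encoder ensemble over three passages:

import numpy

sim_scores_floats = [[0.9, 0.1, 0.4], [0.7, 0.3, 0.2]]  # hypothetical per-encoder scores

sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
raw_sim_scores = sum(sim_scores) / len(sim_scores)  # element-wise mean: [0.8, 0.2, 0.3]
print(raw_sim_scores)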
backend/model_server/__init__.py (new, empty file)
backend/model_server/custom_models.py (new file, 40 lines)

@@ -0,0 +1,40 @@
import numpy as np
import tensorflow as tf  # type:ignore
from fastapi import APIRouter

from danswer.search.search_nlp_models import get_intent_model_tokenizer
from danswer.search.search_nlp_models import get_local_intent_model
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse

router = APIRouter(prefix="/custom")


@log_function_time()
def classify_intent(query: str) -> list[float]:
    tokenizer = get_intent_model_tokenizer()
    intent_model = get_local_intent_model()
    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)

    predictions = intent_model(model_input)[0]
    probabilities = tf.nn.softmax(predictions, axis=-1)

    class_percentages = np.round(probabilities.numpy() * 100, 2)
    return list(class_percentages.tolist()[0])


@router.post("/intent-model")
def process_intent_request(
    intent_request: IntentRequest,
) -> IntentResponse:
    class_percentages = classify_intent(intent_request.query)
    return IntentResponse(class_probs=class_percentages)


def warm_up_intent_model() -> None:
    intent_tokenizer = get_intent_model_tokenizer()
    inputs = intent_tokenizer(
        "danswer", return_tensors="tf", truncation=True, padding=True
    )
    get_local_intent_model()(inputs)
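With the /custom prefix above, the full path served by model_server/main.py is /custom/intent-model. A request sketch, assuming a server on localhost:9000:

import requests

response = requests.post(
    "http://localhost:9000/custom/intent-model",
    json={"query": "How do I reset my password?"},  # IntentRequest
)
response.raise_for_status()
class_probs = response.json()["class_probs"]  # IntentResponse: [keyword, semantic, qa] percentages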
backend/model_server/encoders.py (new file, 81 lines)

@@ -0,0 +1,81 @@
from fastapi import APIRouter
from fastapi import HTTPException

from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.search.search_nlp_models import get_local_embedding_model
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse

logger = setup_logger()

WARM_UP_STRING = "Danswer is amazing"

router = APIRouter(prefix="/encoder")


@log_function_time()
def embed_text(
    texts: list[str],
    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[list[float]]:
    model = get_local_embedding_model()
    embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)

    if not isinstance(embeddings, list):
        embeddings = embeddings.tolist()

    return embeddings


@log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
    cross_encoders = get_local_reranking_model_ensemble()
    sim_scores = [
        encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
        for encoder in cross_encoders
    ]
    return sim_scores


@router.post("/bi-encoder-embed")
def process_embed_request(
    embed_request: EmbedRequest,
) -> EmbedResponse:
    try:
        embeddings = embed_text(texts=embed_request.texts)
        return EmbedResponse(embeddings=embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
    try:
        sim_scores = calc_sim_scores(
            query=embed_request.query, docs=embed_request.documents
        )
        return RerankResponse(scores=sim_scores)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def warm_up_bi_encoder() -> None:
    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
    get_local_embedding_model().encode(WARM_UP_STRING)


def warm_up_cross_encoders() -> None:
    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")

    cross_encoders = get_local_reranking_model_ensemble()
    [
        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
        for cross_encoder in cross_encoders
    ]
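The corresponding call to the reranking endpoint, again assuming a local server; the response carries one row of scores per cross-encoder in the ensemble:

import requests

response = requests.post(
    "http://localhost:9000/encoder/cross-encoder-scores",
    json={
        "query": "What is Danswer?",
        "documents": ["Danswer is an open source search tool."],
    },  # RerankRequest
)
response.raise_for_status()
scores = response.json()["scores"]  # RerankResponse: list[list[float]]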
backend/model_server/main.py (new file, 51 lines)

@@ -0,0 +1,51 @@
import torch
import uvicorn
from fastapi import FastAPI

from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_bi_encoder
from model_server.encoders import warm_up_cross_encoders


logger = setup_logger()


def get_model_app() -> FastAPI:
    application = FastAPI(title="Danswer Model Server", version=__version__)

    application.include_router(encoders_router)
    application.include_router(custom_models_router)

    @application.on_event("startup")
    def startup_event() -> None:
        if torch.cuda.is_available():
            logger.info("GPU is available")
        else:
            logger.info("GPU is not available")

        torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
        logger.info(f"Torch Threads: {torch.get_num_threads()}")

        warm_up_bi_encoder()
        warm_up_cross_encoders()
        warm_up_intent_model()

    return application


app = get_model_app()


if __name__ == "__main__":
    logger.info(
        f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
    )
    logger.info(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
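Because the server is a plain FastAPI app, it can also be exercised without the network layer, for example in tests. A sketch using FastAPI's TestClient (which additionally needs the requests package at this FastAPI version); note that entering the client runs the startup hook, which warms, and on first run downloads, all three model types:

from fastapi.testclient import TestClient

from model_server.main import app

with TestClient(app) as client:  # triggers the startup event
    response = client.post(
        "/encoder/bi-encoder-embed", json={"texts": ["Danswer is amazing"]}
    )
    assert response.status_code == 200
    assert len(response.json()["embeddings"]) == 1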
backend/requirements/model_server.txt (new file, 8 lines)

@@ -0,0 +1,8 @@
fastapi==0.103.0
pydantic==1.10.7
safetensors==0.3.1
sentence-transformers==2.2.2
tensorflow==2.13.0
torch==2.0.1
transformers==4.30.1
uvicorn==0.21.1
backend/shared_models/__init__.py (new, empty file)
backend/shared_models/model_server_models.py (new file, 26 lines)

@@ -0,0 +1,26 @@
from pydantic import BaseModel


class EmbedRequest(BaseModel):
    texts: list[str]


class EmbedResponse(BaseModel):
    embeddings: list[list[float]]


class RerankRequest(BaseModel):
    query: str
    documents: list[str]


class RerankResponse(BaseModel):
    scores: list[list[float]]


class IntentRequest(BaseModel):
    query: str


class IntentResponse(BaseModel):
    class_probs: list[float]
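These six models are the entire wire contract between the backend and the model server; both sides serialize with pydantic's .dict() and re-validate on receipt (pydantic 1.x, per the pinned requirements above). A round-trip sketch with a hypothetical response body:

from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse

payload = EmbedRequest(texts=["Danswer is amazing"]).dict()  # what the client POSTs
parsed = EmbedResponse(**{"embeddings": [[0.1, 0.2, 0.3]]})  # hypothetical server reply
assert parsed.embeddings[0] == [0.1, 0.2, 0.3]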
deployment/docker_compose/docker-compose.dev.yml

@@ -42,6 +42,8 @@ services:
       - SKIP_RERANKING=${SKIP_RERANKING:-}
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:

@@ -94,6 +96,8 @@ services:
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
       - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:

@@ -157,6 +161,25 @@ services:
         /bin/sh -c "sleep 10 &&
         envsubst '$$\{DOMAIN\}' < /etc/nginx/conf.d/app.conf.template.dev > /etc/nginx/conf.d/app.conf &&
         while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    profiles:
+      - "model-server"
+    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
+    restart: always
+    environment:
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them

The same model_server service block is added to the two production Compose files as well. In the first, after the nginx service:

@@ -101,6 +101,25 @@ services:
         while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
     env_file:
       - .env.nginx
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    ... (same 19-line service definition as above)
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them

And in the second, after the certbot service:

@@ -110,6 +110,25 @@ services:
       - ../data/certbot/conf:/etc/letsencrypt
       - ../data/certbot/www:/var/www/certbot
     entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    ... (same 19-line service definition as above)
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them
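Usage note: the profile guard means a plain docker compose up is unaffected, while passing --profile model-server starts the container alongside the rest of the stack. No host port is published, so the backend containers reach the server over the Compose network by service name, e.g. MODEL_SERVER_HOST=model_server and MODEL_SERVER_PORT=9000. A quick liveness check from inside any container on that network:

import requests

# FastAPI serves its interactive docs page by default, which makes a cheap health probe.
assert requests.get("http://model_server:9000/docs").status_code == 200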