mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-18 19:43:26 +02:00
Model Server (#695)
Provides the ability to pull out the NLP models into a separate model server which can then be hosted on a GPU instance if desired.
This commit is contained in:
36
.github/workflows/docker-build-push-model-server-container-on-tag.yml
vendored
Normal file
36
.github/workflows/docker-build-push-model-server-container-on-tag.yml
vendored
Normal file
@@ -0,0 +1,36 @@
|
||||
name: Build and Push Backend Images on Tagging
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
jobs:
|
||||
build-and-push:
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v2
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v1
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v1
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: Backend Image Docker Build and Push
|
||||
uses: docker/build-push-action@v2
|
||||
with:
|
||||
context: ./backend
|
||||
file: ./backend/Dockerfile.model_server
|
||||
platforms: linux/amd64,linux/arm64
|
||||
push: true
|
||||
tags: |
|
||||
danswer/danswer-model-server:${{ github.ref_name }}
|
||||
danswer/danswer-model-server:latest
|
||||
build-args: |
|
||||
DANSWER_VERSION: ${{ github.ref_name }}
|
@@ -45,6 +45,7 @@ RUN apt-get remove -y linux-libc-dev && \
|
||||
# Set up application files
|
||||
WORKDIR /app
|
||||
COPY ./danswer /app/danswer
|
||||
COPY ./shared_models /app/shared_models
|
||||
COPY ./alembic /app/alembic
|
||||
COPY ./alembic.ini /app/alembic.ini
|
||||
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf
|
||||
|
27
backend/Dockerfile.model_server
Normal file
27
backend/Dockerfile.model_server
Normal file
@@ -0,0 +1,27 @@
|
||||
FROM python:3.11.4-slim-bookworm
|
||||
|
||||
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
|
||||
ARG DANSWER_VERSION=0.2-dev
|
||||
ENV DANSWER_VERSION=${DANSWER_VERSION}
|
||||
|
||||
COPY ./requirements/model_server.txt /tmp/requirements.txt
|
||||
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
|
||||
|
||||
WORKDIR /app
|
||||
# Needed for model configs and defaults
|
||||
COPY ./danswer/configs /app/danswer/configs
|
||||
# Utils used by model server
|
||||
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
|
||||
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
|
||||
# Version information
|
||||
COPY ./danswer/__init__.py /app/danswer/__init__.py
|
||||
# Shared implementations for running NLP models locally
|
||||
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
|
||||
# Request/Response models
|
||||
COPY ./shared_models /app/shared_models
|
||||
# Model Server main code
|
||||
COPY ./model_server /app/model_server
|
||||
|
||||
ENV PYTHONPATH /app
|
||||
|
||||
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]
|
@@ -13,6 +13,7 @@ from danswer.background.indexing.job_client import SimpleJob
|
||||
from danswer.background.indexing.job_client import SimpleJobClient
|
||||
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
|
||||
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
|
||||
from danswer.configs.app_configs import MODEL_SERVER_HOST
|
||||
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
|
||||
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
|
||||
from danswer.db.connector import fetch_connectors
|
||||
@@ -290,7 +291,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info("Warming up Embedding Model(s)")
|
||||
warm_up_models(indexer_only=True)
|
||||
if not MODEL_SERVER_HOST:
|
||||
logger.info("Warming up Embedding Model(s)")
|
||||
warm_up_models(indexer_only=True)
|
||||
logger.info("Starting Indexing Loop")
|
||||
update_loop()
|
||||
|
@@ -3,6 +3,7 @@ import os
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import DocumentIndexType
|
||||
|
||||
|
||||
#####
|
||||
# App Configs
|
||||
#####
|
||||
@@ -19,6 +20,7 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
|
||||
# Use this if you want to use Danswer as a search engine only without the LLM capabilities
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
#####
|
||||
@@ -56,7 +58,6 @@ VALID_EMAIL_DOMAINS = (
|
||||
if _VALID_EMAIL_DOMAINS_STR
|
||||
else []
|
||||
)
|
||||
|
||||
# OAuth Login Flow
|
||||
# Used for both Google OAuth2 and OIDC flows
|
||||
OAUTH_CLIENT_ID = (
|
||||
@@ -200,12 +201,13 @@ MINI_CHUNK_SIZE = 150
|
||||
|
||||
|
||||
#####
|
||||
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
|
||||
# Model Server Configs
|
||||
#####
|
||||
BI_ENCODER_HOST = "localhost"
|
||||
BI_ENCODER_PORT = 9000
|
||||
CROSS_ENCODER_HOST = "localhost"
|
||||
CROSS_ENCODER_PORT = 9000
|
||||
# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
|
||||
# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
|
||||
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
|
||||
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
|
||||
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
|
||||
|
||||
|
||||
#####
|
||||
|
@@ -1,16 +1,14 @@
|
||||
import numpy
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
|
||||
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
|
||||
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
|
||||
from danswer.indexing.models import ChunkEmbedding
|
||||
from danswer.indexing.models import DocAwareChunk
|
||||
from danswer.indexing.models import IndexChunk
|
||||
from danswer.search.models import Embedder
|
||||
from danswer.search.search_nlp_models import get_default_embedding_model
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -24,7 +22,7 @@ def encode_chunks(
|
||||
) -> list[IndexChunk]:
|
||||
embedded_chunks: list[IndexChunk] = []
|
||||
if embedding_model is None:
|
||||
embedding_model = get_default_embedding_model()
|
||||
embedding_model = EmbeddingModel()
|
||||
|
||||
chunk_texts = []
|
||||
chunk_mini_chunks_count = {}
|
||||
@@ -43,15 +41,10 @@ def encode_chunks(
|
||||
chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
|
||||
]
|
||||
|
||||
embeddings_np: list[numpy.ndarray] = []
|
||||
embeddings: list[list[float]] = []
|
||||
for text_batch in text_batches:
|
||||
# Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
|
||||
embeddings_np.extend(
|
||||
embedding_model.encode(
|
||||
text_batch, normalize_embeddings=NORMALIZE_EMBEDDINGS
|
||||
)
|
||||
)
|
||||
embeddings: list[list[float]] = [embedding.tolist() for embedding in embeddings_np]
|
||||
embeddings.extend(embedding_model.encode(text_batch))
|
||||
|
||||
embedding_ind_start = 0
|
||||
for chunk_ind, chunk in enumerate(chunks):
|
||||
|
@@ -1,4 +1,5 @@
|
||||
import nltk # type:ignore
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from fastapi import Request
|
||||
@@ -7,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
|
||||
from danswer import __version__
|
||||
from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
@@ -17,6 +19,8 @@ from danswer.configs.app_configs import APP_HOST
|
||||
from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import MODEL_SERVER_HOST
|
||||
from danswer.configs.app_configs import MODEL_SERVER_PORT
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import SECRET
|
||||
@@ -72,7 +76,7 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
|
||||
|
||||
|
||||
def get_application() -> FastAPI:
|
||||
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
|
||||
application = FastAPI(title="Danswer Backend", version=__version__)
|
||||
application.include_router(backend_router)
|
||||
application.include_router(chat_router)
|
||||
application.include_router(event_processing_router)
|
||||
@@ -176,11 +180,23 @@ def get_application() -> FastAPI:
|
||||
logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
|
||||
logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
|
||||
|
||||
logger.info("Warming up local NLP models.")
|
||||
warm_up_models()
|
||||
qa_model = get_default_qa_model()
|
||||
if MODEL_SERVER_HOST:
|
||||
logger.info(
|
||||
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
|
||||
)
|
||||
else:
|
||||
logger.info("Warming up local NLP models.")
|
||||
if torch.cuda.is_available():
|
||||
logger.info("GPU is available")
|
||||
else:
|
||||
logger.info("GPU is not available")
|
||||
logger.info(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
warm_up_models()
|
||||
|
||||
# This is for the LLM, most LLMs will not need warming up
|
||||
qa_model.warm_up_model()
|
||||
# It logs for itself
|
||||
get_default_qa_model().warm_up_model()
|
||||
|
||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
nltk.download("stopwords", quiet=True)
|
||||
|
@@ -1,12 +1,9 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf # type:ignore
|
||||
from transformers import AutoTokenizer # type:ignore
|
||||
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import get_default_intent_model
|
||||
from danswer.search.search_nlp_models import get_default_intent_model_tokenizer
|
||||
from danswer.search.search_nlp_models import get_default_tokenizer
|
||||
from danswer.search.search_nlp_models import IntentModel
|
||||
from danswer.search.search_runner import remove_stop_words
|
||||
from danswer.server.models import HelperResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -28,15 +25,11 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
|
||||
|
||||
@log_function_time()
|
||||
def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
|
||||
tokenizer = get_default_intent_model_tokenizer()
|
||||
intent_model = get_default_intent_model()
|
||||
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
|
||||
|
||||
predictions = intent_model(model_input)[0]
|
||||
probabilities = tf.nn.softmax(predictions, axis=-1)
|
||||
class_percentages = np.round(probabilities.numpy() * 100, 2)
|
||||
|
||||
keyword, semantic, qa = class_percentages.tolist()[0]
|
||||
intent_model = IntentModel()
|
||||
class_probs = intent_model.predict(query)
|
||||
keyword = class_probs[0]
|
||||
semantic = class_probs[1]
|
||||
qa = class_probs[2]
|
||||
|
||||
# Heavily bias towards QA, from user perspective, answering a statement is not as bad as not answering a question
|
||||
if qa > 20:
|
||||
|
@@ -1,15 +1,30 @@
|
||||
import numpy as np
|
||||
import requests
|
||||
import tensorflow as tf # type: ignore
|
||||
from sentence_transformers import CrossEncoder # type: ignore
|
||||
from sentence_transformers import SentenceTransformer # type: ignore
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import TFDistilBertForSequenceClassification # type: ignore
|
||||
|
||||
from danswer.configs.app_configs import MODEL_SERVER_HOST
|
||||
from danswer.configs.app_configs import MODEL_SERVER_PORT
|
||||
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import INTENT_MODEL_VERSION
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
|
||||
from danswer.configs.model_configs import SKIP_RERANKING
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_models.model_server_models import EmbedRequest
|
||||
from shared_models.model_server_models import EmbedResponse
|
||||
from shared_models.model_server_models import IntentRequest
|
||||
from shared_models.model_server_models import IntentResponse
|
||||
from shared_models.model_server_models import RerankRequest
|
||||
from shared_models.model_server_models import RerankResponse
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_TOKENIZER: None | AutoTokenizer = None
|
||||
@@ -26,39 +41,46 @@ def get_default_tokenizer() -> AutoTokenizer:
|
||||
return _TOKENIZER
|
||||
|
||||
|
||||
def get_default_embedding_model() -> SentenceTransformer:
|
||||
def get_local_embedding_model(
|
||||
model_name: str = DOCUMENT_ENCODER_MODEL,
|
||||
max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
) -> SentenceTransformer:
|
||||
global _EMBED_MODEL
|
||||
if _EMBED_MODEL is None:
|
||||
_EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
|
||||
_EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
|
||||
_EMBED_MODEL = SentenceTransformer(model_name)
|
||||
_EMBED_MODEL.max_seq_length = max_context_length
|
||||
return _EMBED_MODEL
|
||||
|
||||
|
||||
def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
|
||||
def get_local_reranking_model_ensemble(
|
||||
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
|
||||
max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
|
||||
) -> list[CrossEncoder]:
|
||||
global _RERANK_MODELS
|
||||
if _RERANK_MODELS is None:
|
||||
_RERANK_MODELS = [
|
||||
CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
|
||||
]
|
||||
_RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
|
||||
for model in _RERANK_MODELS:
|
||||
model.max_length = CROSS_EMBED_CONTEXT_SIZE
|
||||
model.max_length = max_context_length
|
||||
return _RERANK_MODELS
|
||||
|
||||
|
||||
def get_default_intent_model_tokenizer() -> AutoTokenizer:
|
||||
def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTokenizer:
|
||||
global _INTENT_TOKENIZER
|
||||
if _INTENT_TOKENIZER is None:
|
||||
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
|
||||
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
|
||||
return _INTENT_TOKENIZER
|
||||
|
||||
|
||||
def get_default_intent_model() -> TFDistilBertForSequenceClassification:
|
||||
def get_local_intent_model(
|
||||
model_name: str = INTENT_MODEL_VERSION,
|
||||
max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
|
||||
) -> TFDistilBertForSequenceClassification:
|
||||
global _INTENT_MODEL
|
||||
if _INTENT_MODEL is None:
|
||||
_INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
|
||||
INTENT_MODEL_VERSION
|
||||
model_name
|
||||
)
|
||||
_INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
|
||||
_INTENT_MODEL.max_seq_length = max_context_length
|
||||
return _INTENT_MODEL
|
||||
|
||||
|
||||
@@ -67,20 +89,183 @@ def warm_up_models(
|
||||
) -> None:
|
||||
warm_up_str = "Danswer is amazing"
|
||||
get_default_tokenizer()(warm_up_str)
|
||||
get_default_embedding_model().encode(warm_up_str)
|
||||
get_local_embedding_model().encode(warm_up_str)
|
||||
|
||||
if indexer_only:
|
||||
return
|
||||
|
||||
if not skip_cross_encoders:
|
||||
cross_encoders = get_default_reranking_model_ensemble()
|
||||
cross_encoders = get_local_reranking_model_ensemble()
|
||||
[
|
||||
cross_encoder.predict((warm_up_str, warm_up_str))
|
||||
for cross_encoder in cross_encoders
|
||||
]
|
||||
|
||||
intent_tokenizer = get_default_intent_model_tokenizer()
|
||||
intent_tokenizer = get_intent_model_tokenizer()
|
||||
inputs = intent_tokenizer(
|
||||
warm_up_str, return_tensors="tf", truncation=True, padding=True
|
||||
)
|
||||
get_default_intent_model()(inputs)
|
||||
get_local_intent_model()(inputs)
|
||||
|
||||
|
||||
class EmbeddingModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = DOCUMENT_ENCODER_MODEL,
|
||||
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
|
||||
model_server_host: str | None = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.max_seq_length = max_seq_length
|
||||
self.embed_server_endpoint = (
|
||||
f"http://{model_server_host}:{model_server_port}/encoder/bi-encoder-embed"
|
||||
if model_server_host
|
||||
else None
|
||||
)
|
||||
|
||||
def load_model(self) -> SentenceTransformer | None:
|
||||
if self.embed_server_endpoint:
|
||||
return None
|
||||
|
||||
return get_local_embedding_model(
|
||||
model_name=self.model_name, max_context_length=self.max_seq_length
|
||||
)
|
||||
|
||||
def encode(
|
||||
self, texts: list[str], normalize_embeddings: bool = NORMALIZE_EMBEDDINGS
|
||||
) -> list[list[float]]:
|
||||
if self.embed_server_endpoint:
|
||||
embed_request = EmbedRequest(texts=texts)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.embed_server_endpoint, json=embed_request.dict()
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return EmbedResponse(**response.json()).embeddings
|
||||
except requests.RequestException as e:
|
||||
logger.exception(f"Failed to get Embedding: {e}")
|
||||
raise
|
||||
|
||||
local_model = self.load_model()
|
||||
|
||||
if local_model is None:
|
||||
raise RuntimeError("Failed to load local Embedding Model")
|
||||
|
||||
return local_model.encode(
|
||||
texts, normalize_embeddings=normalize_embeddings
|
||||
).tolist()
|
||||
|
||||
|
||||
class CrossEncoderEnsembleModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
|
||||
max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
|
||||
model_server_host: str | None = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
self.model_names = model_names
|
||||
self.max_seq_length = max_seq_length
|
||||
self.rerank_server_endpoint = (
|
||||
f"http://{model_server_host}:{model_server_port}/encoder/cross-encoder-scores"
|
||||
if model_server_host
|
||||
else None
|
||||
)
|
||||
|
||||
def load_model(self) -> list[CrossEncoder] | None:
|
||||
if self.rerank_server_endpoint:
|
||||
return None
|
||||
|
||||
return get_local_reranking_model_ensemble(
|
||||
model_names=self.model_names, max_context_length=self.max_seq_length
|
||||
)
|
||||
|
||||
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
|
||||
if self.rerank_server_endpoint:
|
||||
rerank_request = RerankRequest(query=query, documents=passages)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.rerank_server_endpoint, json=rerank_request.dict()
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return RerankResponse(**response.json()).scores
|
||||
except requests.RequestException as e:
|
||||
logger.exception(f"Failed to get Reranking Scores: {e}")
|
||||
raise
|
||||
|
||||
local_models = self.load_model()
|
||||
|
||||
if local_models is None:
|
||||
raise RuntimeError("Failed to load local Reranking Model Ensemble")
|
||||
|
||||
scores = [
|
||||
cross_encoder.predict([(query, passage) for passage in passages]).tolist() # type: ignore
|
||||
for cross_encoder in local_models
|
||||
]
|
||||
|
||||
return scores
|
||||
|
||||
|
||||
class IntentModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = INTENT_MODEL_VERSION,
|
||||
max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
|
||||
model_server_host: str | None = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
self.model_name = model_name
|
||||
self.max_seq_length = max_seq_length
|
||||
self.intent_server_endpoint = (
|
||||
f"http://{model_server_host}:{model_server_port}/custom/intent-model"
|
||||
if model_server_host
|
||||
else None
|
||||
)
|
||||
|
||||
def load_model(self) -> SentenceTransformer | None:
|
||||
if self.intent_server_endpoint:
|
||||
return None
|
||||
|
||||
return get_local_intent_model(
|
||||
model_name=self.model_name, max_context_length=self.max_seq_length
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
query: str,
|
||||
) -> list[float]:
|
||||
if self.intent_server_endpoint:
|
||||
intent_request = IntentRequest(query=query)
|
||||
|
||||
try:
|
||||
response = requests.post(
|
||||
self.intent_server_endpoint, json=intent_request.dict()
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
return IntentResponse(**response.json()).class_probs
|
||||
except requests.RequestException as e:
|
||||
logger.exception(f"Failed to get Embedding: {e}")
|
||||
raise
|
||||
|
||||
tokenizer = get_intent_model_tokenizer()
|
||||
local_model = self.load_model()
|
||||
|
||||
if local_model is None:
|
||||
raise RuntimeError("Failed to load local Intent Model")
|
||||
|
||||
intent_model = get_local_intent_model()
|
||||
model_input = tokenizer(
|
||||
query, return_tensors="tf", truncation=True, padding=True
|
||||
)
|
||||
|
||||
predictions = intent_model(model_input)[0]
|
||||
probabilities = tf.nn.softmax(predictions, axis=-1)
|
||||
class_percentages = np.round(probabilities.numpy() * 100, 2)
|
||||
|
||||
return list(class_percentages.tolist()[0])
|
||||
|
@@ -1,4 +1,5 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
import numpy
|
||||
from nltk.corpus import stopwords # type:ignore
|
||||
@@ -29,8 +30,8 @@ from danswer.search.models import RerankMetricsContainer
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.search_nlp_models import get_default_embedding_model
|
||||
from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
|
||||
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
|
||||
from danswer.search.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.server.models import SearchDoc
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -67,14 +68,11 @@ def embed_query(
|
||||
prefix: str = ASYM_QUERY_PREFIX,
|
||||
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
|
||||
) -> list[float]:
|
||||
model = embedding_model or get_default_embedding_model()
|
||||
model = embedding_model or EmbeddingModel()
|
||||
prefixed_query = prefix + query
|
||||
query_embedding = model.encode(
|
||||
prefixed_query, normalize_embeddings=normalize_embeddings
|
||||
)
|
||||
|
||||
if not isinstance(query_embedding, list):
|
||||
query_embedding = query_embedding.tolist()
|
||||
[prefixed_query], normalize_embeddings=normalize_embeddings
|
||||
)[0]
|
||||
|
||||
return query_embedding
|
||||
|
||||
@@ -104,6 +102,31 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
|
||||
return search_docs
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery, document_index: DocumentIndex
|
||||
) -> list[InferenceChunk]:
|
||||
if query.search_type == SearchType.KEYWORD:
|
||||
top_chunks = document_index.keyword_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
elif query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError("Invalid Search Flow")
|
||||
|
||||
return top_chunks
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def semantic_reranking(
|
||||
query: str,
|
||||
@@ -112,13 +135,13 @@ def semantic_reranking(
|
||||
model_min: int = CROSS_ENCODER_RANGE_MIN,
|
||||
model_max: int = CROSS_ENCODER_RANGE_MAX,
|
||||
) -> list[InferenceChunk]:
|
||||
cross_encoders = get_default_reranking_model_ensemble()
|
||||
sim_scores = [
|
||||
encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
|
||||
for encoder in cross_encoders
|
||||
]
|
||||
cross_encoders = CrossEncoderEnsembleModel()
|
||||
passages = [chunk.content for chunk in chunks]
|
||||
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
|
||||
|
||||
raw_sim_scores = sum(sim_scores) / len(sim_scores)
|
||||
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
|
||||
|
||||
raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
|
||||
|
||||
cross_models_min = numpy.min(sim_scores)
|
||||
|
||||
@@ -270,23 +293,7 @@ def search_chunks(
|
||||
]
|
||||
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
|
||||
|
||||
if query.search_type == SearchType.KEYWORD:
|
||||
top_chunks = document_index.keyword_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
elif query.search_type == SearchType.SEMANTIC:
|
||||
top_chunks = document_index.semantic_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
elif query.search_type == SearchType.HYBRID:
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query.query, query.filters, query.favor_recent, query.num_hits
|
||||
)
|
||||
|
||||
else:
|
||||
raise RuntimeError("Invalid Search Flow")
|
||||
top_chunks = doc_index_retrieval(query=query, document_index=document_index)
|
||||
|
||||
if not top_chunks:
|
||||
logger.info(
|
||||
|
0
backend/model_server/__init__.py
Normal file
0
backend/model_server/__init__.py
Normal file
40
backend/model_server/custom_models.py
Normal file
40
backend/model_server/custom_models.py
Normal file
@@ -0,0 +1,40 @@
|
||||
import numpy as np
|
||||
import tensorflow as tf # type:ignore
|
||||
from fastapi import APIRouter
|
||||
|
||||
from danswer.search.search_nlp_models import get_intent_model_tokenizer
|
||||
from danswer.search.search_nlp_models import get_local_intent_model
|
||||
from danswer.utils.timing import log_function_time
|
||||
from shared_models.model_server_models import IntentRequest
|
||||
from shared_models.model_server_models import IntentResponse
|
||||
|
||||
router = APIRouter(prefix="/custom")
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def classify_intent(query: str) -> list[float]:
|
||||
tokenizer = get_intent_model_tokenizer()
|
||||
intent_model = get_local_intent_model()
|
||||
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
|
||||
|
||||
predictions = intent_model(model_input)[0]
|
||||
probabilities = tf.nn.softmax(predictions, axis=-1)
|
||||
|
||||
class_percentages = np.round(probabilities.numpy() * 100, 2)
|
||||
return list(class_percentages.tolist()[0])
|
||||
|
||||
|
||||
@router.post("/intent-model")
|
||||
def process_intent_request(
|
||||
intent_request: IntentRequest,
|
||||
) -> IntentResponse:
|
||||
class_percentages = classify_intent(intent_request.query)
|
||||
return IntentResponse(class_probs=class_percentages)
|
||||
|
||||
|
||||
def warm_up_intent_model() -> None:
|
||||
intent_tokenizer = get_intent_model_tokenizer()
|
||||
inputs = intent_tokenizer(
|
||||
"danswer", return_tensors="tf", truncation=True, padding=True
|
||||
)
|
||||
get_local_intent_model()(inputs)
|
81
backend/model_server/encoders.py
Normal file
81
backend/model_server/encoders.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import HTTPException
|
||||
|
||||
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
|
||||
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.search.search_nlp_models import get_local_embedding_model
|
||||
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from shared_models.model_server_models import EmbedRequest
|
||||
from shared_models.model_server_models import EmbedResponse
|
||||
from shared_models.model_server_models import RerankRequest
|
||||
from shared_models.model_server_models import RerankResponse
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
WARM_UP_STRING = "Danswer is amazing"
|
||||
|
||||
router = APIRouter(prefix="/encoder")
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def embed_text(
|
||||
texts: list[str],
|
||||
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
|
||||
) -> list[list[float]]:
|
||||
model = get_local_embedding_model()
|
||||
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
|
||||
|
||||
if not isinstance(embeddings, list):
|
||||
embeddings = embeddings.tolist()
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
|
||||
cross_encoders = get_local_reranking_model_ensemble()
|
||||
sim_scores = [
|
||||
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
|
||||
for encoder in cross_encoders
|
||||
]
|
||||
return sim_scores
|
||||
|
||||
|
||||
@router.post("/bi-encoder-embed")
|
||||
def process_embed_request(
|
||||
embed_request: EmbedRequest,
|
||||
) -> EmbedResponse:
|
||||
try:
|
||||
embeddings = embed_text(texts=embed_request.texts)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/cross-encoder-scores")
|
||||
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
|
||||
try:
|
||||
sim_scores = calc_sim_scores(
|
||||
query=embed_request.query, docs=embed_request.documents
|
||||
)
|
||||
return RerankResponse(scores=sim_scores)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def warm_up_bi_encoder() -> None:
|
||||
logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
|
||||
get_local_embedding_model().encode(WARM_UP_STRING)
|
||||
|
||||
|
||||
def warm_up_cross_encoders() -> None:
|
||||
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
|
||||
|
||||
cross_encoders = get_local_reranking_model_ensemble()
|
||||
[
|
||||
cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
|
||||
for cross_encoder in cross_encoders
|
||||
]
|
51
backend/model_server/main.py
Normal file
51
backend/model_server/main.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI
|
||||
|
||||
from danswer import __version__
|
||||
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
|
||||
from danswer.configs.app_configs import MODEL_SERVER_PORT
|
||||
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
|
||||
from danswer.utils.logger import setup_logger
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.encoders import warm_up_bi_encoder
|
||||
from model_server.encoders import warm_up_cross_encoders
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_model_app() -> FastAPI:
|
||||
application = FastAPI(title="Danswer Model Server", version=__version__)
|
||||
|
||||
application.include_router(encoders_router)
|
||||
application.include_router(custom_models_router)
|
||||
|
||||
@application.on_event("startup")
|
||||
def startup_event() -> None:
|
||||
if torch.cuda.is_available():
|
||||
logger.info("GPU is available")
|
||||
else:
|
||||
logger.info("GPU is not available")
|
||||
|
||||
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
|
||||
logger.info(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
warm_up_bi_encoder()
|
||||
warm_up_cross_encoders()
|
||||
warm_up_intent_model()
|
||||
|
||||
return application
|
||||
|
||||
|
||||
app = get_model_app()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logger.info(
|
||||
f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
|
||||
)
|
||||
logger.info(f"Model Server Version: {__version__}")
|
||||
uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)
|
8
backend/requirements/model_server.txt
Normal file
8
backend/requirements/model_server.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
fastapi==0.103.0
|
||||
pydantic==1.10.7
|
||||
safetensors==0.3.1
|
||||
sentence-transformers==2.2.2
|
||||
tensorflow==2.13.0
|
||||
torch==2.0.1
|
||||
transformers==4.30.1
|
||||
uvicorn==0.21.1
|
0
backend/shared_models/__init__.py
Normal file
0
backend/shared_models/__init__.py
Normal file
26
backend/shared_models/model_server_models.py
Normal file
26
backend/shared_models/model_server_models.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class EmbedRequest(BaseModel):
|
||||
texts: list[str]
|
||||
|
||||
|
||||
class EmbedResponse(BaseModel):
|
||||
embeddings: list[list[float]]
|
||||
|
||||
|
||||
class RerankRequest(BaseModel):
|
||||
query: str
|
||||
documents: list[str]
|
||||
|
||||
|
||||
class RerankResponse(BaseModel):
|
||||
scores: list[list[float]]
|
||||
|
||||
|
||||
class IntentRequest(BaseModel):
|
||||
query: str
|
||||
|
||||
|
||||
class IntentResponse(BaseModel):
|
||||
class_probs: list[float]
|
@@ -42,6 +42,8 @@ services:
|
||||
- SKIP_RERANKING=${SKIP_RERANKING:-}
|
||||
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
|
||||
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
|
||||
# Set to debug to get more fine-grained logs
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
volumes:
|
||||
@@ -94,6 +96,8 @@ services:
|
||||
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
|
||||
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
|
||||
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
|
||||
# Set to debug to get more fine-grained logs
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
volumes:
|
||||
@@ -157,6 +161,25 @@ services:
|
||||
/bin/sh -c "sleep 10 &&
|
||||
envsubst '$$\{DOMAIN\}' < /etc/nginx/conf.d/app.conf.template.dev > /etc/nginx/conf.d/app.conf &&
|
||||
while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
|
||||
# Run with --profile model-server to bring up the danswer-model-server container
|
||||
model_server:
|
||||
image: danswer/danswer-model-server:latest
|
||||
build:
|
||||
context: ../../backend
|
||||
dockerfile: Dockerfile.model_server
|
||||
profiles:
|
||||
- "model-server"
|
||||
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
|
||||
restart: always
|
||||
environment:
|
||||
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
|
||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
||||
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||
# Set to debug to get more fine-grained logs
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
volumes:
|
||||
- model_cache_torch:/root/.cache/torch/
|
||||
- model_cache_huggingface:/root/.cache/huggingface/
|
||||
volumes:
|
||||
local_dynamic_storage:
|
||||
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
|
||||
|
@@ -101,6 +101,25 @@ services:
|
||||
while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
|
||||
env_file:
|
||||
- .env.nginx
|
||||
# Run with --profile model-server to bring up the danswer-model-server container
|
||||
model_server:
|
||||
image: danswer/danswer-model-server:latest
|
||||
build:
|
||||
context: ../../backend
|
||||
dockerfile: Dockerfile.model_server
|
||||
profiles:
|
||||
- "model-server"
|
||||
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
|
||||
restart: always
|
||||
environment:
|
||||
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
|
||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
||||
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||
# Set to debug to get more fine-grained logs
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
volumes:
|
||||
- model_cache_torch:/root/.cache/torch/
|
||||
- model_cache_huggingface:/root/.cache/huggingface/
|
||||
volumes:
|
||||
local_dynamic_storage:
|
||||
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
|
||||
|
@@ -110,6 +110,25 @@ services:
|
||||
- ../data/certbot/conf:/etc/letsencrypt
|
||||
- ../data/certbot/www:/var/www/certbot
|
||||
entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
|
||||
# Run with --profile model-server to bring up the danswer-model-server container
|
||||
model_server:
|
||||
image: danswer/danswer-model-server:latest
|
||||
build:
|
||||
context: ../../backend
|
||||
dockerfile: Dockerfile.model_server
|
||||
profiles:
|
||||
- "model-server"
|
||||
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
|
||||
restart: always
|
||||
environment:
|
||||
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
|
||||
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
|
||||
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
|
||||
# Set to debug to get more fine-grained logs
|
||||
- LOG_LEVEL=${LOG_LEVEL:-info}
|
||||
volumes:
|
||||
- model_cache_torch:/root/.cache/torch/
|
||||
- model_cache_huggingface:/root/.cache/huggingface/
|
||||
volumes:
|
||||
local_dynamic_storage:
|
||||
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them
|
||||
|
Reference in New Issue
Block a user