Model Server (#695)

Provides the ability to pull out the NLP models into a separate model server which can then be hosted on a GPU instance if desired.
This commit is contained in:
Yuhong Sun
2023-11-06 16:36:09 -08:00
committed by GitHub
parent fe938b6fc6
commit 7433dddac3
20 changed files with 614 additions and 85 deletions

View File

@@ -0,0 +1,36 @@
name: Build and Push Backend Images on Tagging
on:
push:
tags:
- '*'
jobs:
build-and-push:
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v1
- name: Login to Docker Hub
uses: docker/login-action@v1
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Backend Image Docker Build and Push
uses: docker/build-push-action@v2
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64,linux/arm64
push: true
tags: |
danswer/danswer-model-server:${{ github.ref_name }}
danswer/danswer-model-server:latest
build-args: |
DANSWER_VERSION: ${{ github.ref_name }}

View File

@@ -45,6 +45,7 @@ RUN apt-get remove -y linux-libc-dev && \
# Set up application files
WORKDIR /app
COPY ./danswer /app/danswer
COPY ./shared_models /app/shared_models
COPY ./alembic /app/alembic
COPY ./alembic.ini /app/alembic.ini
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf

View File

@@ -0,0 +1,27 @@
FROM python:3.11.4-slim-bookworm
# Default DANSWER_VERSION, typically overriden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.2-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}
COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt
WORKDIR /app
# Needed for model configs and defaults
COPY ./danswer/configs /app/danswer/configs
# Utils used by model server
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py
# Version information
COPY ./danswer/__init__.py /app/danswer/__init__.py
# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py
# Request/Response models
COPY ./shared_models /app/shared_models
# Model Server main code
COPY ./model_server /app/model_server
ENV PYTHONPATH /app
CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -13,6 +13,7 @@ from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.db.connector import fetch_connectors
@@ -290,7 +291,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
if __name__ == "__main__":
logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True)
if not MODEL_SERVER_HOST:
logger.info("Warming up Embedding Model(s)")
warm_up_models(indexer_only=True)
logger.info("Starting Indexing Loop")
update_loop()

View File

@@ -3,6 +3,7 @@ import os
from danswer.configs.constants import AuthType
from danswer.configs.constants import DocumentIndexType
#####
# App Configs
#####
@@ -19,6 +20,7 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
# Use this if you want to use Danswer as a search engine only without the LLM capabilities
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
#####
# Web Configs
#####
@@ -56,7 +58,6 @@ VALID_EMAIL_DOMAINS = (
if _VALID_EMAIL_DOMAINS_STR
else []
)
# OAuth Login Flow
# Used for both Google OAuth2 and OIDC flows
OAUTH_CLIENT_ID = (
@@ -200,12 +201,13 @@ MINI_CHUNK_SIZE = 150
#####
# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
# Model Server Configs
#####
BI_ENCODER_HOST = "localhost"
BI_ENCODER_PORT = 9000
CROSS_ENCODER_HOST = "localhost"
CROSS_ENCODER_PORT = 9000
# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to the server via
# requests. Be sure to include the scheme in the MODEL_SERVER_HOST value.
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
#####

View File

@@ -1,16 +1,14 @@
import numpy
from sentence_transformers import SentenceTransformer # type: ignore
from danswer.configs.app_configs import ENABLE_MINI_CHUNK
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
from danswer.search.models import Embedder
from danswer.search.search_nlp_models import get_default_embedding_model
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.utils.timing import log_function_time
@@ -24,7 +22,7 @@ def encode_chunks(
) -> list[IndexChunk]:
embedded_chunks: list[IndexChunk] = []
if embedding_model is None:
embedding_model = get_default_embedding_model()
embedding_model = EmbeddingModel()
chunk_texts = []
chunk_mini_chunks_count = {}
@@ -43,15 +41,10 @@ def encode_chunks(
chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
]
embeddings_np: list[numpy.ndarray] = []
embeddings: list[list[float]] = []
for text_batch in text_batches:
# Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
embeddings_np.extend(
embedding_model.encode(
text_batch, normalize_embeddings=NORMALIZE_EMBEDDINGS
)
)
embeddings: list[list[float]] = [embedding.tolist() for embedding in embeddings_np]
embeddings.extend(embedding_model.encode(text_batch))
embedding_ind_start = 0
for chunk_ind, chunk in enumerate(chunks):

View File

@@ -1,4 +1,5 @@
import nltk # type:ignore
import torch
import uvicorn
from fastapi import FastAPI
from fastapi import Request
@@ -7,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from danswer import __version__
from danswer.auth.schemas import UserCreate
from danswer.auth.schemas import UserRead
from danswer.auth.schemas import UserUpdate
@@ -17,6 +19,8 @@ from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import SECRET
@@ -72,7 +76,7 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
def get_application() -> FastAPI:
application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
application = FastAPI(title="Danswer Backend", version=__version__)
application.include_router(backend_router)
application.include_router(chat_router)
application.include_router(event_processing_router)
@@ -176,11 +180,23 @@ def get_application() -> FastAPI:
logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')
logger.info("Warming up local NLP models.")
warm_up_models()
qa_model = get_default_qa_model()
if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("Warming up local NLP models.")
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
warm_up_models()
# This is for the LLM, most LLMs will not need warming up
qa_model.warm_up_model()
# It logs for itself
get_default_qa_model().warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)

View File

@@ -1,12 +1,9 @@
import numpy as np
import tensorflow as tf # type:ignore
from transformers import AutoTokenizer # type:ignore
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import get_default_intent_model
from danswer.search.search_nlp_models import get_default_intent_model_tokenizer
from danswer.search.search_nlp_models import get_default_tokenizer
from danswer.search.search_nlp_models import IntentModel
from danswer.search.search_runner import remove_stop_words
from danswer.server.models import HelperResponse
from danswer.utils.logger import setup_logger
@@ -28,15 +25,11 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
@log_function_time()
def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
tokenizer = get_default_intent_model_tokenizer()
intent_model = get_default_intent_model()
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
predictions = intent_model(model_input)[0]
probabilities = tf.nn.softmax(predictions, axis=-1)
class_percentages = np.round(probabilities.numpy() * 100, 2)
keyword, semantic, qa = class_percentages.tolist()[0]
intent_model = IntentModel()
class_probs = intent_model.predict(query)
keyword = class_probs[0]
semantic = class_probs[1]
qa = class_probs[2]
# Heavily bias towards QA, from user perspective, answering a statement is not as bad as not answering a question
if qa > 20:

View File

@@ -1,15 +1,30 @@
import numpy as np
import requests
import tensorflow as tf # type: ignore
from sentence_transformers import CrossEncoder # type: ignore
from sentence_transformers import SentenceTransformer # type: ignore
from transformers import AutoTokenizer # type: ignore
from transformers import TFDistilBertForSequenceClassification # type: ignore
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import INTENT_MODEL_VERSION
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
from danswer.configs.model_configs import SKIP_RERANKING
from danswer.utils.logger import setup_logger
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
logger = setup_logger()
_TOKENIZER: None | AutoTokenizer = None
@@ -26,39 +41,46 @@ def get_default_tokenizer() -> AutoTokenizer:
return _TOKENIZER
def get_default_embedding_model() -> SentenceTransformer:
def get_local_embedding_model(
model_name: str = DOCUMENT_ENCODER_MODEL,
max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> SentenceTransformer:
global _EMBED_MODEL
if _EMBED_MODEL is None:
_EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
_EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
_EMBED_MODEL = SentenceTransformer(model_name)
_EMBED_MODEL.max_seq_length = max_context_length
return _EMBED_MODEL
def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
def get_local_reranking_model_ensemble(
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
) -> list[CrossEncoder]:
global _RERANK_MODELS
if _RERANK_MODELS is None:
_RERANK_MODELS = [
CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
]
_RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
for model in _RERANK_MODELS:
model.max_length = CROSS_EMBED_CONTEXT_SIZE
model.max_length = max_context_length
return _RERANK_MODELS
def get_default_intent_model_tokenizer() -> AutoTokenizer:
def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTokenizer:
global _INTENT_TOKENIZER
if _INTENT_TOKENIZER is None:
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
_INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
return _INTENT_TOKENIZER
def get_default_intent_model() -> TFDistilBertForSequenceClassification:
def get_local_intent_model(
model_name: str = INTENT_MODEL_VERSION,
max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
) -> TFDistilBertForSequenceClassification:
global _INTENT_MODEL
if _INTENT_MODEL is None:
_INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
INTENT_MODEL_VERSION
model_name
)
_INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
_INTENT_MODEL.max_seq_length = max_context_length
return _INTENT_MODEL
@@ -67,20 +89,183 @@ def warm_up_models(
) -> None:
warm_up_str = "Danswer is amazing"
get_default_tokenizer()(warm_up_str)
get_default_embedding_model().encode(warm_up_str)
get_local_embedding_model().encode(warm_up_str)
if indexer_only:
return
if not skip_cross_encoders:
cross_encoders = get_default_reranking_model_ensemble()
cross_encoders = get_local_reranking_model_ensemble()
[
cross_encoder.predict((warm_up_str, warm_up_str))
for cross_encoder in cross_encoders
]
intent_tokenizer = get_default_intent_model_tokenizer()
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
warm_up_str, return_tensors="tf", truncation=True, padding=True
)
get_default_intent_model()(inputs)
get_local_intent_model()(inputs)
class EmbeddingModel:
def __init__(
self,
model_name: str = DOCUMENT_ENCODER_MODEL,
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
self.model_name = model_name
self.max_seq_length = max_seq_length
self.embed_server_endpoint = (
f"http://{model_server_host}:{model_server_port}/encoder/bi-encoder-embed"
if model_server_host
else None
)
def load_model(self) -> SentenceTransformer | None:
if self.embed_server_endpoint:
return None
return get_local_embedding_model(
model_name=self.model_name, max_context_length=self.max_seq_length
)
def encode(
self, texts: list[str], normalize_embeddings: bool = NORMALIZE_EMBEDDINGS
) -> list[list[float]]:
if self.embed_server_endpoint:
embed_request = EmbedRequest(texts=texts)
try:
response = requests.post(
self.embed_server_endpoint, json=embed_request.dict()
)
response.raise_for_status()
return EmbedResponse(**response.json()).embeddings
except requests.RequestException as e:
logger.exception(f"Failed to get Embedding: {e}")
raise
local_model = self.load_model()
if local_model is None:
raise RuntimeError("Failed to load local Embedding Model")
return local_model.encode(
texts, normalize_embeddings=normalize_embeddings
).tolist()
class CrossEncoderEnsembleModel:
def __init__(
self,
model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
self.model_names = model_names
self.max_seq_length = max_seq_length
self.rerank_server_endpoint = (
f"http://{model_server_host}:{model_server_port}/encoder/cross-encoder-scores"
if model_server_host
else None
)
def load_model(self) -> list[CrossEncoder] | None:
if self.rerank_server_endpoint:
return None
return get_local_reranking_model_ensemble(
model_names=self.model_names, max_context_length=self.max_seq_length
)
def predict(self, query: str, passages: list[str]) -> list[list[float]]:
if self.rerank_server_endpoint:
rerank_request = RerankRequest(query=query, documents=passages)
try:
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.dict()
)
response.raise_for_status()
return RerankResponse(**response.json()).scores
except requests.RequestException as e:
logger.exception(f"Failed to get Reranking Scores: {e}")
raise
local_models = self.load_model()
if local_models is None:
raise RuntimeError("Failed to load local Reranking Model Ensemble")
scores = [
cross_encoder.predict([(query, passage) for passage in passages]).tolist() # type: ignore
for cross_encoder in local_models
]
return scores
class IntentModel:
def __init__(
self,
model_name: str = INTENT_MODEL_VERSION,
max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
model_server_host: str | None = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
self.model_name = model_name
self.max_seq_length = max_seq_length
self.intent_server_endpoint = (
f"http://{model_server_host}:{model_server_port}/custom/intent-model"
if model_server_host
else None
)
def load_model(self) -> SentenceTransformer | None:
if self.intent_server_endpoint:
return None
return get_local_intent_model(
model_name=self.model_name, max_context_length=self.max_seq_length
)
def predict(
self,
query: str,
) -> list[float]:
if self.intent_server_endpoint:
intent_request = IntentRequest(query=query)
try:
response = requests.post(
self.intent_server_endpoint, json=intent_request.dict()
)
response.raise_for_status()
return IntentResponse(**response.json()).class_probs
except requests.RequestException as e:
logger.exception(f"Failed to get Embedding: {e}")
raise
tokenizer = get_intent_model_tokenizer()
local_model = self.load_model()
if local_model is None:
raise RuntimeError("Failed to load local Intent Model")
intent_model = get_local_intent_model()
model_input = tokenizer(
query, return_tensors="tf", truncation=True, padding=True
)
predictions = intent_model(model_input)[0]
probabilities = tf.nn.softmax(predictions, axis=-1)
class_percentages = np.round(probabilities.numpy() * 100, 2)
return list(class_percentages.tolist()[0])

View File

@@ -1,4 +1,5 @@
from collections.abc import Callable
from typing import cast
import numpy
from nltk.corpus import stopwords # type:ignore
@@ -29,8 +30,8 @@ from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.models import SearchType
from danswer.search.search_nlp_models import get_default_embedding_model
from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
from danswer.search.search_nlp_models import EmbeddingModel
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
from danswer.utils.logger import setup_logger
@@ -67,14 +68,11 @@ def embed_query(
prefix: str = ASYM_QUERY_PREFIX,
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[float]:
model = embedding_model or get_default_embedding_model()
model = embedding_model or EmbeddingModel()
prefixed_query = prefix + query
query_embedding = model.encode(
prefixed_query, normalize_embeddings=normalize_embeddings
)
if not isinstance(query_embedding, list):
query_embedding = query_embedding.tolist()
[prefixed_query], normalize_embeddings=normalize_embeddings
)[0]
return query_embedding
@@ -104,6 +102,31 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
return search_docs
@log_function_time()
def doc_index_retrieval(
query: SearchQuery, document_index: DocumentIndex
) -> list[InferenceChunk]:
if query.search_type == SearchType.KEYWORD:
top_chunks = document_index.keyword_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
elif query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
else:
raise RuntimeError("Invalid Search Flow")
return top_chunks
@log_function_time()
def semantic_reranking(
query: str,
@@ -112,13 +135,13 @@ def semantic_reranking(
model_min: int = CROSS_ENCODER_RANGE_MIN,
model_max: int = CROSS_ENCODER_RANGE_MAX,
) -> list[InferenceChunk]:
cross_encoders = get_default_reranking_model_ensemble()
sim_scores = [
encoder.predict([(query, chunk.content) for chunk in chunks]) # type: ignore
for encoder in cross_encoders
]
cross_encoders = CrossEncoderEnsembleModel()
passages = [chunk.content for chunk in chunks]
sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
raw_sim_scores = sum(sim_scores) / len(sim_scores)
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))
cross_models_min = numpy.min(sim_scores)
@@ -270,23 +293,7 @@ def search_chunks(
]
logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")
if query.search_type == SearchType.KEYWORD:
top_chunks = document_index.keyword_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
elif query.search_type == SearchType.SEMANTIC:
top_chunks = document_index.semantic_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
elif query.search_type == SearchType.HYBRID:
top_chunks = document_index.hybrid_retrieval(
query.query, query.filters, query.favor_recent, query.num_hits
)
else:
raise RuntimeError("Invalid Search Flow")
top_chunks = doc_index_retrieval(query=query, document_index=document_index)
if not top_chunks:
logger.info(

View File

View File

@@ -0,0 +1,40 @@
import numpy as np
import tensorflow as tf # type:ignore
from fastapi import APIRouter
from danswer.search.search_nlp_models import get_intent_model_tokenizer
from danswer.search.search_nlp_models import get_local_intent_model
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse
router = APIRouter(prefix="/custom")
@log_function_time()
def classify_intent(query: str) -> list[float]:
tokenizer = get_intent_model_tokenizer()
intent_model = get_local_intent_model()
model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
predictions = intent_model(model_input)[0]
probabilities = tf.nn.softmax(predictions, axis=-1)
class_percentages = np.round(probabilities.numpy() * 100, 2)
return list(class_percentages.tolist()[0])
@router.post("/intent-model")
def process_intent_request(
intent_request: IntentRequest,
) -> IntentResponse:
class_percentages = classify_intent(intent_request.query)
return IntentResponse(class_probs=class_percentages)
def warm_up_intent_model() -> None:
intent_tokenizer = get_intent_model_tokenizer()
inputs = intent_tokenizer(
"danswer", return_tensors="tf", truncation=True, padding=True
)
get_local_intent_model()(inputs)

View File

@@ -0,0 +1,81 @@
from fastapi import APIRouter
from fastapi import HTTPException
from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.search.search_nlp_models import get_local_embedding_model
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse
logger = setup_logger()
WARM_UP_STRING = "Danswer is amazing"
router = APIRouter(prefix="/encoder")
@log_function_time()
def embed_text(
texts: list[str],
normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[list[float]]:
model = get_local_embedding_model()
embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)
if not isinstance(embeddings, list):
embeddings = embeddings.tolist()
return embeddings
@log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
cross_encoders = get_local_reranking_model_ensemble()
sim_scores = [
encoder.predict([(query, doc) for doc in docs]).tolist() # type: ignore
for encoder in cross_encoders
]
return sim_scores
@router.post("/bi-encoder-embed")
def process_embed_request(
embed_request: EmbedRequest,
) -> EmbedResponse:
try:
embeddings = embed_text(texts=embed_request.texts)
return EmbedResponse(embeddings=embeddings)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
try:
sim_scores = calc_sim_scores(
query=embed_request.query, docs=embed_request.documents
)
return RerankResponse(scores=sim_scores)
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))
def warm_up_bi_encoder() -> None:
logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
get_local_embedding_model().encode(WARM_UP_STRING)
def warm_up_cross_encoders() -> None:
logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")
cross_encoders = get_local_reranking_model_ensemble()
[
cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
for cross_encoder in cross_encoders
]

View File

@@ -0,0 +1,51 @@
import torch
import uvicorn
from fastapi import FastAPI
from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_bi_encoder
from model_server.encoders import warm_up_cross_encoders
logger = setup_logger()
def get_model_app() -> FastAPI:
application = FastAPI(title="Danswer Model Server", version=__version__)
application.include_router(encoders_router)
application.include_router(custom_models_router)
@application.on_event("startup")
def startup_event() -> None:
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
logger.info(f"Torch Threads: {torch.get_num_threads()}")
warm_up_bi_encoder()
warm_up_cross_encoders()
warm_up_intent_model()
return application
app = get_model_app()
if __name__ == "__main__":
logger.info(
f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
)
logger.info(f"Model Server Version: {__version__}")
uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)

View File

@@ -0,0 +1,8 @@
fastapi==0.103.0
pydantic==1.10.7
safetensors==0.3.1
sentence-transformers==2.2.2
tensorflow==2.13.0
torch==2.0.1
transformers==4.30.1
uvicorn==0.21.1

View File

View File

@@ -0,0 +1,26 @@
from pydantic import BaseModel
class EmbedRequest(BaseModel):
texts: list[str]
class EmbedResponse(BaseModel):
embeddings: list[list[float]]
class RerankRequest(BaseModel):
query: str
documents: list[str]
class RerankResponse(BaseModel):
scores: list[list[float]]
class IntentRequest(BaseModel):
query: str
class IntentResponse(BaseModel):
class_probs: list[float]

View File

@@ -42,6 +42,8 @@ services:
- SKIP_RERANKING=${SKIP_RERANKING:-}
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
@@ -94,6 +96,8 @@ services:
- QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
- EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
- MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
@@ -157,6 +161,25 @@ services:
/bin/sh -c "sleep 10 &&
envsubst '$$\{DOMAIN\}' < /etc/nginx/conf.d/app.conf.template.dev > /etc/nginx/conf.d/app.conf &&
while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
volumes:
local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them

View File

@@ -101,6 +101,25 @@ services:
while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
env_file:
- .env.nginx
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
volumes:
local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them

View File

@@ -110,6 +110,25 @@ services:
- ../data/certbot/conf:/etc/letsencrypt
- ../data/certbot/www:/var/www/certbot
entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
volumes:
local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them