Model Server (#695)

Adds the ability to split the NLP models out into a separate model server, which can then be hosted on a GPU instance if desired.
Authored by Yuhong Sun on 2023-11-06 16:36:09 -08:00, committed by GitHub
parent fe938b6fc6 · commit 7433dddac3
20 changed files with 614 additions and 85 deletions
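
At a high level, each NLP model wrapper introduced in this commit dispatches between an in-process model and an HTTP call to the model server, keyed off MODEL_SERVER_HOST. A minimal sketch of that pattern, condensed from the EmbeddingModel class in this diff (the host value is illustrative):

import requests

MODEL_SERVER_HOST = None  # unset -> run models in-process; e.g. "10.0.0.5" -> offload
MODEL_SERVER_PORT = 9000

def encode(texts: list[str]) -> list[list[float]]:
    if MODEL_SERVER_HOST:
        # Remote mode: POST to the model server's bi-encoder endpoint
        response = requests.post(
            f"http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}/encoder/bi-encoder-embed",
            json={"texts": texts},
        )
        response.raise_for_status()
        return response.json()["embeddings"]
    # Local mode (elided here): lazily load the SentenceTransformer and encode
    # in-process, as EmbeddingModel.encode does in search_nlp_models.py below
    ...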

View File

@@ -0,0 +1,36 @@
name: Build and Push Model Server Image on Tagging

on:
  push:
    tags:
      - '*'

jobs:
  build-and-push:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v2

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v1

      - name: Login to Docker Hub
        uses: docker/login-action@v1
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Model Server Image Docker Build and Push
        uses: docker/build-push-action@v2
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64,linux/arm64
          push: true
          tags: |
            danswer/danswer-model-server:${{ github.ref_name }}
            danswer/danswer-model-server:latest
          build-args: |
            DANSWER_VERSION=${{ github.ref_name }}

View File

@@ -45,6 +45,7 @@ RUN apt-get remove -y linux-libc-dev && \
 # Set up application files
 WORKDIR /app
 COPY ./danswer /app/danswer
+COPY ./shared_models /app/shared_models
 COPY ./alembic /app/alembic
 COPY ./alembic.ini /app/alembic.ini
 COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf

View File

@@ -0,0 +1,27 @@
FROM python:3.11.4-slim-bookworm

# Default DANSWER_VERSION, typically overridden during builds by GitHub Actions.
ARG DANSWER_VERSION=0.2-dev
ENV DANSWER_VERSION=${DANSWER_VERSION}

COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN pip install --no-cache-dir --upgrade -r /tmp/requirements.txt

WORKDIR /app

# Needed for model configs and defaults
COPY ./danswer/configs /app/danswer/configs

# Utils used by model server
COPY ./danswer/utils/logger.py /app/danswer/utils/logger.py
COPY ./danswer/utils/timing.py /app/danswer/utils/timing.py

# Version information
COPY ./danswer/__init__.py /app/danswer/__init__.py

# Shared implementations for running NLP models locally
COPY ./danswer/search/search_nlp_models.py /app/danswer/search/search_nlp_models.py

# Request/Response models
COPY ./shared_models /app/shared_models

# Model Server main code
COPY ./model_server /app/model_server

ENV PYTHONPATH /app

CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"]

View File

@@ -13,6 +13,7 @@ from danswer.background.indexing.job_client import SimpleJob
 from danswer.background.indexing.job_client import SimpleJobClient
 from danswer.background.indexing.run_indexing import run_indexing_entrypoint
 from danswer.configs.app_configs import EXPERIMENTAL_SIMPLE_JOB_CLIENT_ENABLED
+from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import NUM_INDEXING_WORKERS
 from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
 from danswer.db.connector import fetch_connectors
@@ -290,7 +291,8 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non
 if __name__ == "__main__":
-    logger.info("Warming up Embedding Model(s)")
-    warm_up_models(indexer_only=True)
+    if not MODEL_SERVER_HOST:
+        logger.info("Warming up Embedding Model(s)")
+        warm_up_models(indexer_only=True)

     logger.info("Starting Indexing Loop")
     update_loop()

View File

@@ -3,6 +3,7 @@ import os
 from danswer.configs.constants import AuthType
 from danswer.configs.constants import DocumentIndexType
+
 #####
 # App Configs
 #####
@@ -19,6 +20,7 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = 86400 # 1 day
 # Use this if you want to use Danswer as a search engine only without the LLM capabilities
 DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
+
 #####
 # Web Configs
 #####
@@ -56,7 +58,6 @@ VALID_EMAIL_DOMAINS = (
     if _VALID_EMAIL_DOMAINS_STR
     else []
 )
-
 # OAuth Login Flow
 # Used for both Google OAuth2 and OIDC flows
 OAUTH_CLIENT_ID = (
@@ -200,12 +201,13 @@ MINI_CHUNK_SIZE = 150
 #####
-# Encoder Model Endpoint Configs (Currently unused, running the models in memory)
+# Model Server Configs
 #####
-BI_ENCODER_HOST = "localhost"
-BI_ENCODER_PORT = 9000
-CROSS_ENCODER_HOST = "localhost"
-CROSS_ENCODER_PORT = 9000
+# If MODEL_SERVER_HOST is set, the NLP models required for Danswer are offloaded to that
+# server via requests. Set it to a bare hostname or IP; endpoints are built as
+# http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}/...
+MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or None
+MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_ALLOWED_HOST") or "0.0.0.0"
+MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")

 #####
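
Note the division of labor among these settings: MODEL_SERVER_HOST is read by the Danswer containers to locate the server, while MODEL_SERVER_ALLOWED_HOST is the interface the model server itself binds to (see model_server/main.py below). A small sketch with hypothetical values:

import os

# Client side (api_server / background containers): where requests are sent
os.environ["MODEL_SERVER_HOST"] = "gpu-box.internal"  # hypothetical hostname

# Server side (model_server container): the interface uvicorn binds to
os.environ["MODEL_SERVER_ALLOWED_HOST"] = "0.0.0.0"

# Shared by both sides
os.environ["MODEL_SERVER_PORT"] = "9000"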

View File

@@ -1,16 +1,14 @@
-import numpy
 from sentence_transformers import SentenceTransformer  # type: ignore

 from danswer.configs.app_configs import ENABLE_MINI_CHUNK
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
-from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.indexing.chunker import split_chunk_text_into_mini_chunks
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
 from danswer.search.models import Embedder
-from danswer.search.search_nlp_models import get_default_embedding_model
+from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.utils.timing import log_function_time
@@ -24,7 +22,7 @@ def encode_chunks(
 ) -> list[IndexChunk]:
     embedded_chunks: list[IndexChunk] = []
     if embedding_model is None:
-        embedding_model = get_default_embedding_model()
+        embedding_model = EmbeddingModel()

     chunk_texts = []
     chunk_mini_chunks_count = {}
@@ -43,15 +41,10 @@ def encode_chunks(
         chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size)
     ]

-    embeddings_np: list[numpy.ndarray] = []
+    embeddings: list[list[float]] = []
     for text_batch in text_batches:
         # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss
-        embeddings_np.extend(
-            embedding_model.encode(
-                text_batch, normalize_embeddings=NORMALIZE_EMBEDDINGS
-            )
-        )
-
-    embeddings: list[list[float]] = [embedding.tolist() for embedding in embeddings_np]
+        embeddings.extend(embedding_model.encode(text_batch))

     embedding_ind_start = 0
     for chunk_ind, chunk in enumerate(chunks):

View File

@@ -1,4 +1,5 @@
 import nltk  # type:ignore
+import torch
 import uvicorn
 from fastapi import FastAPI
 from fastapi import Request
@@ -7,6 +8,7 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse
 from httpx_oauth.clients.google import GoogleOAuth2

+from danswer import __version__
 from danswer.auth.schemas import UserCreate
 from danswer.auth.schemas import UserRead
 from danswer.auth.schemas import UserUpdate
@@ -17,6 +19,8 @@ from danswer.configs.app_configs import APP_HOST
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import AUTH_TYPE
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.app_configs import OAUTH_CLIENT_ID
 from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
 from danswer.configs.app_configs import SECRET
@@ -72,7 +76,7 @@ def value_error_handler(_: Request, exc: ValueError) -> JSONResponse:
 def get_application() -> FastAPI:
-    application = FastAPI(title="Internal Search QA Backend", debug=True, version="0.1")
+    application = FastAPI(title="Danswer Backend", version=__version__)
     application.include_router(backend_router)
     application.include_router(chat_router)
     application.include_router(event_processing_router)
@@ -176,11 +180,23 @@ def get_application() -> FastAPI:
         logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"')
         logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"')

-        logger.info("Warming up local NLP models.")
-        warm_up_models()
-        qa_model = get_default_qa_model()
+        if MODEL_SERVER_HOST:
+            logger.info(
+                f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
+            )
+        else:
+            logger.info("Warming up local NLP models.")
+
+            if torch.cuda.is_available():
+                logger.info("GPU is available")
+            else:
+                logger.info("GPU is not available")
+
+            logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+            warm_up_models()

         # This is for the LLM, most LLMs will not need warming up
-        qa_model.warm_up_model()
+        # It logs for itself
+        get_default_qa_model().warm_up_model()

         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         nltk.download("stopwords", quiet=True)

View File

@@ -1,12 +1,9 @@
-import numpy as np
-import tensorflow as tf  # type:ignore
 from transformers import AutoTokenizer  # type:ignore

 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
-from danswer.search.search_nlp_models import get_default_intent_model
-from danswer.search.search_nlp_models import get_default_intent_model_tokenizer
 from danswer.search.search_nlp_models import get_default_tokenizer
+from danswer.search.search_nlp_models import IntentModel
 from danswer.search.search_runner import remove_stop_words
 from danswer.server.models import HelperResponse
 from danswer.utils.logger import setup_logger
@@ -28,15 +25,11 @@ def count_unk_tokens(text: str, tokenizer: AutoTokenizer) -> int:
 @log_function_time()
 def query_intent(query: str) -> tuple[SearchType, QueryFlow]:
-    tokenizer = get_default_intent_model_tokenizer()
-    intent_model = get_default_intent_model()
-    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)
-
-    predictions = intent_model(model_input)[0]
-    probabilities = tf.nn.softmax(predictions, axis=-1)
-    class_percentages = np.round(probabilities.numpy() * 100, 2)
-
-    keyword, semantic, qa = class_percentages.tolist()[0]
+    intent_model = IntentModel()
+    class_probs = intent_model.predict(query)
+    keyword = class_probs[0]
+    semantic = class_probs[1]
+    qa = class_probs[2]

     # Heavily bias towards QA, from user perspective, answering a statement is not as bad as not answering a question
     if qa > 20:
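
For concreteness, predict() returns rounded percentages in [keyword, semantic, qa] order (see classify_intent in model_server/custom_models.py below), so the routing here can be illustrated with an assumed output:

class_probs = IntentModel().predict("how do I set up SSO?")  # hypothetical query
keyword, semantic, qa = class_probs  # e.g. [4.1, 15.9, 80.0] -- percentages
# qa > 20, so this query would be biased into the question-answering flow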

View File

@@ -1,15 +1,30 @@
+import numpy as np
+import requests
+import tensorflow as tf  # type: ignore
 from sentence_transformers import CrossEncoder  # type: ignore
 from sentence_transformers import SentenceTransformer  # type: ignore
 from transformers import AutoTokenizer  # type: ignore
 from transformers import TFDistilBertForSequenceClassification  # type: ignore

+from danswer.configs.app_configs import MODEL_SERVER_HOST
+from danswer.configs.app_configs import MODEL_SERVER_PORT
 from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE
 from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
 from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
 from danswer.configs.model_configs import INTENT_MODEL_VERSION
+from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
 from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE
 from danswer.configs.model_configs import SKIP_RERANKING
+from danswer.utils.logger import setup_logger
+from shared_models.model_server_models import EmbedRequest
+from shared_models.model_server_models import EmbedResponse
+from shared_models.model_server_models import IntentRequest
+from shared_models.model_server_models import IntentResponse
+from shared_models.model_server_models import RerankRequest
+from shared_models.model_server_models import RerankResponse

+logger = setup_logger()

 _TOKENIZER: None | AutoTokenizer = None
@@ -26,39 +41,46 @@ def get_default_tokenizer() -> AutoTokenizer:
     return _TOKENIZER


-def get_default_embedding_model() -> SentenceTransformer:
+def get_local_embedding_model(
+    model_name: str = DOCUMENT_ENCODER_MODEL,
+    max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+) -> SentenceTransformer:
     global _EMBED_MODEL
     if _EMBED_MODEL is None:
-        _EMBED_MODEL = SentenceTransformer(DOCUMENT_ENCODER_MODEL)
-        _EMBED_MODEL.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE
+        _EMBED_MODEL = SentenceTransformer(model_name)
+        _EMBED_MODEL.max_seq_length = max_context_length
     return _EMBED_MODEL


-def get_default_reranking_model_ensemble() -> list[CrossEncoder]:
+def get_local_reranking_model_ensemble(
+    model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+    max_context_length: int = CROSS_EMBED_CONTEXT_SIZE,
+) -> list[CrossEncoder]:
     global _RERANK_MODELS
     if _RERANK_MODELS is None:
-        _RERANK_MODELS = [
-            CrossEncoder(model_name) for model_name in CROSS_ENCODER_MODEL_ENSEMBLE
-        ]
+        _RERANK_MODELS = [CrossEncoder(model_name) for model_name in model_names]
         for model in _RERANK_MODELS:
-            model.max_length = CROSS_EMBED_CONTEXT_SIZE
+            model.max_length = max_context_length
     return _RERANK_MODELS


-def get_default_intent_model_tokenizer() -> AutoTokenizer:
+def get_intent_model_tokenizer(model_name: str = INTENT_MODEL_VERSION) -> AutoTokenizer:
     global _INTENT_TOKENIZER
     if _INTENT_TOKENIZER is None:
-        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(INTENT_MODEL_VERSION)
+        _INTENT_TOKENIZER = AutoTokenizer.from_pretrained(model_name)
     return _INTENT_TOKENIZER


-def get_default_intent_model() -> TFDistilBertForSequenceClassification:
+def get_local_intent_model(
+    model_name: str = INTENT_MODEL_VERSION,
+    max_context_length: int = QUERY_MAX_CONTEXT_SIZE,
+) -> TFDistilBertForSequenceClassification:
     global _INTENT_MODEL
     if _INTENT_MODEL is None:
         _INTENT_MODEL = TFDistilBertForSequenceClassification.from_pretrained(
-            INTENT_MODEL_VERSION
+            model_name
         )
-        _INTENT_MODEL.max_seq_length = QUERY_MAX_CONTEXT_SIZE
+        _INTENT_MODEL.max_seq_length = max_context_length
     return _INTENT_MODEL
@@ -67,20 +89,183 @@ def warm_up_models(
 ) -> None:
     warm_up_str = "Danswer is amazing"
     get_default_tokenizer()(warm_up_str)
-    get_default_embedding_model().encode(warm_up_str)
+    get_local_embedding_model().encode(warm_up_str)

     if indexer_only:
         return

     if not skip_cross_encoders:
-        cross_encoders = get_default_reranking_model_ensemble()
+        cross_encoders = get_local_reranking_model_ensemble()
         [
             cross_encoder.predict((warm_up_str, warm_up_str))
             for cross_encoder in cross_encoders
         ]

-    intent_tokenizer = get_default_intent_model_tokenizer()
+    intent_tokenizer = get_intent_model_tokenizer()
     inputs = intent_tokenizer(
         warm_up_str, return_tensors="tf", truncation=True, padding=True
     )
-    get_default_intent_model()(inputs)
+    get_local_intent_model()(inputs)
+
+
+class EmbeddingModel:
+    def __init__(
+        self,
+        model_name: str = DOCUMENT_ENCODER_MODEL,
+        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_name = model_name
+        self.max_seq_length = max_seq_length
+        self.embed_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/encoder/bi-encoder-embed"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> SentenceTransformer | None:
+        if self.embed_server_endpoint:
+            return None
+
+        return get_local_embedding_model(
+            model_name=self.model_name, max_context_length=self.max_seq_length
+        )
+
+    def encode(
+        self, texts: list[str], normalize_embeddings: bool = NORMALIZE_EMBEDDINGS
+    ) -> list[list[float]]:
+        if self.embed_server_endpoint:
+            embed_request = EmbedRequest(texts=texts)
+
+            try:
+                response = requests.post(
+                    self.embed_server_endpoint, json=embed_request.dict()
+                )
+                response.raise_for_status()
+
+                return EmbedResponse(**response.json()).embeddings
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Embedding: {e}")
+                raise
+
+        local_model = self.load_model()
+        if local_model is None:
+            raise RuntimeError("Failed to load local Embedding Model")
+
+        return local_model.encode(
+            texts, normalize_embeddings=normalize_embeddings
+        ).tolist()
+
+
+class CrossEncoderEnsembleModel:
+    def __init__(
+        self,
+        model_names: list[str] = CROSS_ENCODER_MODEL_ENSEMBLE,
+        max_seq_length: int = CROSS_EMBED_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_names = model_names
+        self.max_seq_length = max_seq_length
+        self.rerank_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/encoder/cross-encoder-scores"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> list[CrossEncoder] | None:
+        if self.rerank_server_endpoint:
+            return None
+
+        return get_local_reranking_model_ensemble(
+            model_names=self.model_names, max_context_length=self.max_seq_length
+        )
+
+    def predict(self, query: str, passages: list[str]) -> list[list[float]]:
+        if self.rerank_server_endpoint:
+            rerank_request = RerankRequest(query=query, documents=passages)
+
+            try:
+                response = requests.post(
+                    self.rerank_server_endpoint, json=rerank_request.dict()
+                )
+                response.raise_for_status()
+
+                return RerankResponse(**response.json()).scores
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Reranking Scores: {e}")
+                raise
+
+        local_models = self.load_model()
+        if local_models is None:
+            raise RuntimeError("Failed to load local Reranking Model Ensemble")
+
+        scores = [
+            cross_encoder.predict([(query, passage) for passage in passages]).tolist()  # type: ignore
+            for cross_encoder in local_models
+        ]
+        return scores
+
+
+class IntentModel:
+    def __init__(
+        self,
+        model_name: str = INTENT_MODEL_VERSION,
+        max_seq_length: int = QUERY_MAX_CONTEXT_SIZE,
+        model_server_host: str | None = MODEL_SERVER_HOST,
+        model_server_port: int = MODEL_SERVER_PORT,
+    ) -> None:
+        self.model_name = model_name
+        self.max_seq_length = max_seq_length
+        self.intent_server_endpoint = (
+            f"http://{model_server_host}:{model_server_port}/custom/intent-model"
+            if model_server_host
+            else None
+        )
+
+    def load_model(self) -> TFDistilBertForSequenceClassification | None:
+        if self.intent_server_endpoint:
+            return None
+
+        return get_local_intent_model(
+            model_name=self.model_name, max_context_length=self.max_seq_length
+        )
+
+    def predict(
+        self,
+        query: str,
+    ) -> list[float]:
+        if self.intent_server_endpoint:
+            intent_request = IntentRequest(query=query)
+
+            try:
+                response = requests.post(
+                    self.intent_server_endpoint, json=intent_request.dict()
+                )
+                response.raise_for_status()
+
+                return IntentResponse(**response.json()).class_probs
+            except requests.RequestException as e:
+                logger.exception(f"Failed to get Intent prediction: {e}")
+                raise
+
+        tokenizer = get_intent_model_tokenizer()
+        local_model = self.load_model()
+        if local_model is None:
+            raise RuntimeError("Failed to load local Intent Model")
+
+        model_input = tokenizer(
+            query, return_tensors="tf", truncation=True, padding=True
+        )
+        predictions = local_model(model_input)[0]
+        probabilities = tf.nn.softmax(predictions, axis=-1)
+        class_percentages = np.round(probabilities.numpy() * 100, 2)
+        return list(class_percentages.tolist()[0])
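
Taken together, the three wrappers give callers a single interface regardless of where the models run; a brief usage sketch (inputs are illustrative):

embeddings = EmbeddingModel().encode(["Danswer is amazing"])
scores = CrossEncoderEnsembleModel().predict(
    query="What is Danswer?", passages=["Danswer is an open source QA tool."]
)
class_probs = IntentModel().predict("What is Danswer?")
# Each call POSTs to the model server when MODEL_SERVER_HOST is set and
# otherwise falls back to the lazily loaded local model singletons.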

View File

@@ -1,4 +1,5 @@
 from collections.abc import Callable
+from typing import cast

 import numpy
 from nltk.corpus import stopwords  # type:ignore
@@ -29,8 +30,8 @@ from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
 from danswer.search.models import SearchQuery
 from danswer.search.models import SearchType
-from danswer.search.search_nlp_models import get_default_embedding_model
-from danswer.search.search_nlp_models import get_default_reranking_model_ensemble
+from danswer.search.search_nlp_models import CrossEncoderEnsembleModel
+from danswer.search.search_nlp_models import EmbeddingModel
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchDoc
 from danswer.utils.logger import setup_logger
@@ -67,14 +68,11 @@ def embed_query(
     prefix: str = ASYM_QUERY_PREFIX,
     normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
 ) -> list[float]:
-    model = embedding_model or get_default_embedding_model()
+    model = embedding_model or EmbeddingModel()
     prefixed_query = prefix + query
-    query_embedding = model.encode(
-        prefixed_query, normalize_embeddings=normalize_embeddings
-    )
-
-    if not isinstance(query_embedding, list):
-        query_embedding = query_embedding.tolist()
+    query_embedding = model.encode(
+        [prefixed_query], normalize_embeddings=normalize_embeddings
+    )[0]

     return query_embedding
@@ -104,6 +102,31 @@ def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc
     return search_docs


+@log_function_time()
+def doc_index_retrieval(
+    query: SearchQuery, document_index: DocumentIndex
+) -> list[InferenceChunk]:
+    if query.search_type == SearchType.KEYWORD:
+        top_chunks = document_index.keyword_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+    elif query.search_type == SearchType.SEMANTIC:
+        top_chunks = document_index.semantic_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+    elif query.search_type == SearchType.HYBRID:
+        top_chunks = document_index.hybrid_retrieval(
+            query.query, query.filters, query.favor_recent, query.num_hits
+        )
+    else:
+        raise RuntimeError("Invalid Search Flow")
+
+    return top_chunks
+
+
 @log_function_time()
 def semantic_reranking(
     query: str,
@@ -112,13 +135,13 @@ def semantic_reranking(
     model_min: int = CROSS_ENCODER_RANGE_MIN,
     model_max: int = CROSS_ENCODER_RANGE_MAX,
 ) -> list[InferenceChunk]:
-    cross_encoders = get_default_reranking_model_ensemble()
-    sim_scores = [
-        encoder.predict([(query, chunk.content) for chunk in chunks])  # type: ignore
-        for encoder in cross_encoders
-    ]
-    raw_sim_scores = sum(sim_scores) / len(sim_scores)
+    cross_encoders = CrossEncoderEnsembleModel()
+    passages = [chunk.content for chunk in chunks]
+    sim_scores_floats = cross_encoders.predict(query=query, passages=passages)
+
+    sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
+
+    raw_sim_scores = cast(numpy.ndarray, sum(sim_scores) / len(sim_scores))

     cross_models_min = numpy.min(sim_scores)
@@ -270,23 +293,7 @@ def search_chunks(
     ]
     logger.info(f"Top links from {search_flow} search: {', '.join(top_links)}")

-    if query.search_type == SearchType.KEYWORD:
-        top_chunks = document_index.keyword_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-    elif query.search_type == SearchType.SEMANTIC:
-        top_chunks = document_index.semantic_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-    elif query.search_type == SearchType.HYBRID:
-        top_chunks = document_index.hybrid_retrieval(
-            query.query, query.filters, query.favor_recent, query.num_hits
-        )
-    else:
-        raise RuntimeError("Invalid Search Flow")
-
+    top_chunks = doc_index_retrieval(query=query, document_index=document_index)
     if not top_chunks:
         logger.info(
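
Since CrossEncoderEnsembleModel.predict now returns plain floats, the averaging and min/max normalization stay on the caller's side in numpy; a worked example with assumed scores from two cross-encoders over three passages:

import numpy

sim_scores_floats = [[0.2, 0.9, 0.4], [0.4, 0.7, 0.6]]  # assumed ensemble output
sim_scores = [numpy.array(scores) for scores in sim_scores_floats]
raw_sim_scores = sum(sim_scores) / len(sim_scores)  # array([0.3, 0.8, 0.5])
cross_models_min = numpy.min(sim_scores)  # 0.2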

View File

View File

@@ -0,0 +1,40 @@
import numpy as np
import tensorflow as tf  # type:ignore
from fastapi import APIRouter

from danswer.search.search_nlp_models import get_intent_model_tokenizer
from danswer.search.search_nlp_models import get_local_intent_model
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import IntentRequest
from shared_models.model_server_models import IntentResponse

router = APIRouter(prefix="/custom")


@log_function_time()
def classify_intent(query: str) -> list[float]:
    tokenizer = get_intent_model_tokenizer()
    intent_model = get_local_intent_model()
    model_input = tokenizer(query, return_tensors="tf", truncation=True, padding=True)

    predictions = intent_model(model_input)[0]
    probabilities = tf.nn.softmax(predictions, axis=-1)
    class_percentages = np.round(probabilities.numpy() * 100, 2)

    return list(class_percentages.tolist()[0])


@router.post("/intent-model")
def process_intent_request(
    intent_request: IntentRequest,
) -> IntentResponse:
    class_percentages = classify_intent(intent_request.query)
    return IntentResponse(class_probs=class_percentages)


def warm_up_intent_model() -> None:
    intent_tokenizer = get_intent_model_tokenizer()
    inputs = intent_tokenizer(
        "danswer", return_tensors="tf", truncation=True, padding=True
    )
    get_local_intent_model()(inputs)

View File

@@ -0,0 +1,81 @@
from fastapi import APIRouter
from fastapi import HTTPException

from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.search.search_nlp_models import get_local_embedding_model
from danswer.search.search_nlp_models import get_local_reranking_model_ensemble
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse
from shared_models.model_server_models import RerankRequest
from shared_models.model_server_models import RerankResponse

logger = setup_logger()

WARM_UP_STRING = "Danswer is amazing"

router = APIRouter(prefix="/encoder")


@log_function_time()
def embed_text(
    texts: list[str],
    normalize_embeddings: bool = NORMALIZE_EMBEDDINGS,
) -> list[list[float]]:
    model = get_local_embedding_model()
    embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings)

    if not isinstance(embeddings, list):
        embeddings = embeddings.tolist()

    return embeddings


@log_function_time()
def calc_sim_scores(query: str, docs: list[str]) -> list[list[float]]:
    cross_encoders = get_local_reranking_model_ensemble()
    sim_scores = [
        encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore
        for encoder in cross_encoders
    ]
    return sim_scores


@router.post("/bi-encoder-embed")
def process_embed_request(
    embed_request: EmbedRequest,
) -> EmbedResponse:
    try:
        embeddings = embed_text(texts=embed_request.texts)
        return EmbedResponse(embeddings=embeddings)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@router.post("/cross-encoder-scores")
def process_rerank_request(embed_request: RerankRequest) -> RerankResponse:
    try:
        sim_scores = calc_sim_scores(
            query=embed_request.query, docs=embed_request.documents
        )
        return RerankResponse(scores=sim_scores)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def warm_up_bi_encoder() -> None:
    logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}")
    get_local_embedding_model().encode(WARM_UP_STRING)


def warm_up_cross_encoders() -> None:
    logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}")

    cross_encoders = get_local_reranking_model_ensemble()
    [
        cross_encoder.predict((WARM_UP_STRING, WARM_UP_STRING))
        for cross_encoder in cross_encoders
    ]

View File

@@ -0,0 +1,51 @@
import torch
import uvicorn
from fastapi import FastAPI

from danswer import __version__
from danswer.configs.app_configs import MODEL_SERVER_ALLOWED_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
from danswer.configs.model_configs import MIN_THREADS_ML_MODELS
from danswer.utils.logger import setup_logger
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.encoders import warm_up_bi_encoder
from model_server.encoders import warm_up_cross_encoders

logger = setup_logger()


def get_model_app() -> FastAPI:
    application = FastAPI(title="Danswer Model Server", version=__version__)

    application.include_router(encoders_router)
    application.include_router(custom_models_router)

    @application.on_event("startup")
    def startup_event() -> None:
        if torch.cuda.is_available():
            logger.info("GPU is available")
        else:
            logger.info("GPU is not available")

        torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
        logger.info(f"Torch Threads: {torch.get_num_threads()}")

        warm_up_bi_encoder()
        warm_up_cross_encoders()
        warm_up_intent_model()

    return application


app = get_model_app()


if __name__ == "__main__":
    logger.info(
        f"Starting Danswer Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
    )
    logger.info(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)

View File

@@ -0,0 +1,8 @@
fastapi==0.103.0
pydantic==1.10.7
safetensors==0.3.1
sentence-transformers==2.2.2
tensorflow==2.13.0
torch==2.0.1
transformers==4.30.1
uvicorn==0.21.1

View File

View File

@@ -0,0 +1,26 @@
from pydantic import BaseModel


class EmbedRequest(BaseModel):
    texts: list[str]


class EmbedResponse(BaseModel):
    embeddings: list[list[float]]


class RerankRequest(BaseModel):
    query: str
    documents: list[str]


class RerankResponse(BaseModel):
    scores: list[list[float]]


class IntentRequest(BaseModel):
    query: str


class IntentResponse(BaseModel):
    class_probs: list[float]
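
These pydantic models define the wire format for all three endpoints. A hedged client-side example against a model server assumed to be listening on localhost:9000:

import requests

from shared_models.model_server_models import EmbedRequest
from shared_models.model_server_models import EmbedResponse

request = EmbedRequest(texts=["Danswer is amazing"])
response = requests.post(
    "http://localhost:9000/encoder/bi-encoder-embed", json=request.dict()
)
response.raise_for_status()
embeddings = EmbedResponse(**response.json()).embeddings  # list[list[float]]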

View File

@@ -42,6 +42,8 @@ services:
       - SKIP_RERANKING=${SKIP_RERANKING:-}
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:
@@ -94,6 +96,8 @@ services:
       - QA_PROMPT_OVERRIDE=${QA_PROMPT_OVERRIDE:-}
       - EDIT_KEYWORD_QUERY=${EDIT_KEYWORD_QUERY:-}
       - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      - MODEL_SERVER_HOST=${MODEL_SERVER_HOST:-}
+      - MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
       # Set to debug to get more fine-grained logs
       - LOG_LEVEL=${LOG_LEVEL:-info}
     volumes:
@@ -157,6 +161,25 @@ services:
/bin/sh -c "sleep 10 && /bin/sh -c "sleep 10 &&
envsubst '$$\{DOMAIN\}' < /etc/nginx/conf.d/app.conf.template.dev > /etc/nginx/conf.d/app.conf && envsubst '$$\{DOMAIN\}' < /etc/nginx/conf.d/app.conf.template.dev > /etc/nginx/conf.d/app.conf &&
while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\"" while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
# Run with --profile model-server to bring up the danswer-model-server container
model_server:
image: danswer/danswer-model-server:latest
build:
context: ../../backend
dockerfile: Dockerfile.model_server
profiles:
- "model-server"
command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
restart: always
environment:
- DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
- NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
- MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
# Set to debug to get more fine-grained logs
- LOG_LEVEL=${LOG_LEVEL:-info}
volumes:
- model_cache_torch:/root/.cache/torch/
- model_cache_huggingface:/root/.cache/huggingface/
volumes: volumes:
local_dynamic_storage: local_dynamic_storage:
file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them file_connector_tmp_storage: # used to store files uploaded by the user temporarily while we are indexing them

View File

@@ -101,6 +101,25 @@ services:
       while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\""
     env_file:
       - .env.nginx
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    profiles:
+      - "model-server"
+    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
+    restart: always
+    environment:
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them

View File

@@ -110,6 +110,25 @@ services:
       - ../data/certbot/conf:/etc/letsencrypt
       - ../data/certbot/www:/var/www/certbot
     entrypoint: "/bin/sh -c 'trap exit TERM; while :; do certbot renew; sleep 12h & wait $${!}; done;'"
+  # Run with --profile model-server to bring up the danswer-model-server container
+  model_server:
+    image: danswer/danswer-model-server:latest
+    build:
+      context: ../../backend
+      dockerfile: Dockerfile.model_server
+    profiles:
+      - "model-server"
+    command: uvicorn model_server.main:app --host 0.0.0.0 --port 9000
+    restart: always
+    environment:
+      - DOCUMENT_ENCODER_MODEL=${DOCUMENT_ENCODER_MODEL:-}
+      - NORMALIZE_EMBEDDINGS=${NORMALIZE_EMBEDDINGS:-}
+      - MIN_THREADS_ML_MODELS=${MIN_THREADS_ML_MODELS:-}
+      # Set to debug to get more fine-grained logs
+      - LOG_LEVEL=${LOG_LEVEL:-info}
+    volumes:
+      - model_cache_torch:/root/.cache/torch/
+      - model_cache_huggingface:/root/.cache/huggingface/
 volumes:
   local_dynamic_storage:
   file_connector_tmp_storage:  # used to store files uploaded by the user temporarily while we are indexing them