# danswer/backend/danswer/natural_language_processing/search_nlp_models.py

import re
import threading
import time
from collections.abc import Callable
from functools import wraps
from typing import Any
import requests
from httpx import HTTPError
from retry import retry
from danswer.configs.app_configs import LARGE_CHUNK_RATIO
from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from danswer.configs.model_configs import (
    BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import SearchSettings
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list

logger = setup_logger()

WARM_UP_STRINGS = [
    "Danswer is amazing!",
    "Check out our easy deployment guide at",
    "https://docs.danswer.dev/quickstart",
]

def clean_model_name(model_str: str) -> str:
    return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
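
# e.g. clean_model_name("intfloat/e5-base-v2") -> "intfloat_e5_base_v2"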

_WHITELIST = set(
    " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
)

_INITIAL_FILTER = re.compile(
    "["
    "\U00000080-\U0000FFFF"  # Non-ASCII characters in the Basic Multilingual Plane
    "\U00010000-\U0010FFFF"  # Characters in the supplementary planes
    "]+",
    flags=re.UNICODE,
)

def clean_openai_text(text: str) -> str:
    # First, remove all weird characters
    cleaned = _INITIAL_FILTER.sub("", text)

    # Then, keep only whitelisted characters
    return "".join(char for char in cleaned if char in _WHITELIST)

def build_model_server_url(
    model_server_host: str,
    model_server_port: int,
) -> str:
    model_server_url = f"{model_server_host}:{model_server_port}"

    # use protocol if provided
    if "http" in model_server_url:
        return model_server_url

    # otherwise default to http
    return f"http://{model_server_url}"

class EmbeddingModel:
    def __init__(
        self,
        server_host: str,  # Changes depending on indexing or inference
        server_port: int,
        model_name: str | None,
        normalize: bool,
        query_prefix: str | None,
        passage_prefix: str | None,
        api_key: str | None,
        api_url: str | None,
        provider_type: EmbeddingProvider | None,
        retrim_content: bool = False,
    ) -> None:
        self.api_key = api_key
        self.provider_type = provider_type
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.normalize = normalize
        self.model_name = model_name
        self.retrim_content = retrim_content
        self.api_url = api_url
        self.tokenizer = get_tokenizer(
            model_name=model_name, provider_type=provider_type
        )

        model_server_url = build_model_server_url(server_host, server_port)
        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
        def _make_request() -> EmbedResponse:
            response = requests.post(
                self.embed_server_endpoint, json=embed_request.model_dump()
            )
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                try:
                    error_detail = response.json().get("detail", str(e))
                except Exception:
                    error_detail = response.text
                raise HTTPError(f"HTTP error occurred: {error_detail}") from e
            except requests.RequestException as e:
                raise HTTPError(f"Request failed: {str(e)}") from e

            return EmbedResponse(**response.json())

        # only perform retries for the non-realtime embedding of passages (e.g. for indexing)
        if embed_request.text_type == EmbedTextType.PASSAGE:
            return retry(tries=3, delay=5)(_make_request)()
        else:
            return _make_request()
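
    # Note: the retry(tries=3, delay=5) wrapper above re-invokes _make_request up to
    # 3 times, sleeping 5 seconds between attempts, before raising the final error.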

    def _batch_encode_texts(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        batch_size: int,
        max_seq_length: int,
    ) -> list[Embedding]:
        text_batches = batch_list(texts, batch_size)

        logger.debug(
            f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
        )

        embeddings: list[Embedding] = []
        for idx, text_batch in enumerate(text_batches, start=1):
            logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
                max_context_length=max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
                api_url=self.api_url,
            )

            response = self._make_model_server_request(embed_request)
            embeddings.extend(response.embeddings)
        return embeddings
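
    # Ordering note: assuming batch_list chunks in order (e.g. batch_list([1, 2, 3, 4, 5], 2)
    # -> [[1, 2], [3, 4], [5]]), the returned embeddings are parallel to the input `texts`.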

    def encode(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        large_chunks_present: bool = False,
        local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
        api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
    ) -> list[Embedding]:
        if not texts or not all(texts):
            raise ValueError(f"Empty or missing text for embedding: {texts}")

        if large_chunks_present:
            max_seq_length *= LARGE_CHUNK_RATIO

        if self.retrim_content:
            # This is applied during indexing as a catch-all for overly long titles (or other uncapped fields).
            # Note that this uses just the default tokenizer, which may lead to very minor miscounts;
            # however, a slight miscount is very unlikely to have any material impact.
            texts = [
                tokenizer_trim_content(
                    content=text,
                    desired_length=max_seq_length,
                    tokenizer=self.tokenizer,
                )
                for text in texts
            ]

        if self.provider_type == EmbeddingProvider.OPENAI:
            # If the provider is OpenAI, we need to clean the text
            # as a temporary workaround for the OpenAI API
            texts = [clean_openai_text(text) for text in texts]

        batch_size = (
            api_embedding_batch_size
            if self.provider_type
            else local_embedding_batch_size
        )

        return self._batch_encode_texts(
            texts=texts,
            text_type=text_type,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        )

    @classmethod
    def from_db_model(
        cls,
        search_settings: SearchSettings,
        server_host: str,  # Changes depending on indexing or inference
        server_port: int,
        retrim_content: bool = False,
    ) -> "EmbeddingModel":
        return cls(
            server_host=server_host,
            server_port=server_port,
            model_name=search_settings.model_name,
            normalize=search_settings.normalize,
            query_prefix=search_settings.query_prefix,
            passage_prefix=search_settings.passage_prefix,
            api_key=search_settings.api_key,
            provider_type=search_settings.provider_type,
            api_url=search_settings.api_url,
            retrim_content=retrim_content,
        )
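
# A minimal usage sketch (hypothetical host/port and model name; real values come
# from SearchSettings and the deployment's model-server config):
#
#   model = EmbeddingModel(
#       server_host="localhost",
#       server_port=9000,
#       model_name="intfloat/e5-base-v2",
#       normalize=True,
#       query_prefix="query: ",
#       passage_prefix="passage: ",
#       api_key=None,
#       api_url=None,
#       provider_type=None,  # None -> local model server
#   )
#   embeddings = model.encode(texts=["some passage"], text_type=EmbedTextType.PASSAGE)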

class RerankingModel:
    def __init__(
        self,
        model_name: str,
        provider_type: RerankerProvider | None,
        api_key: str | None,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ) -> None:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
        self.model_name = model_name
        self.provider_type = provider_type
        self.api_key = api_key

    def predict(self, query: str, passages: list[str]) -> list[float]:
        rerank_request = RerankRequest(
            query=query,
            documents=passages,
            model_name=self.model_name,
            provider_type=self.provider_type,
            api_key=self.api_key,
        )

        response = requests.post(
            self.rerank_server_endpoint, json=rerank_request.model_dump()
        )
        response.raise_for_status()

        return RerankResponse(**response.json()).scores
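
# A minimal usage sketch (hypothetical model name; returns one score per passage,
# parallel to `passages`):
#
#   reranker = RerankingModel(
#       model_name="mixedbread-ai/mxbai-rerank-xsmall-v1",
#       provider_type=None,  # None -> local cross-encoder on the model server
#       api_key=None,
#   )
#   scores = reranker.predict("query text", ["passage one", "passage two"])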

class QueryAnalysisModel:
    def __init__(
        self,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
        # Lean heavily towards not throwing out keywords
        keyword_percent_threshold: float = 0.1,
        # Lean towards semantic, which is the default
        semantic_percent_threshold: float = 0.4,
    ) -> None:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.intent_server_endpoint = model_server_url + "/custom/query-analysis"
        self.keyword_percent_threshold = keyword_percent_threshold
        self.semantic_percent_threshold = semantic_percent_threshold

    def predict(
        self,
        query: str,
    ) -> tuple[bool, list[str]]:
        intent_request = IntentRequest(
            query=query,
            keyword_percent_threshold=self.keyword_percent_threshold,
            semantic_percent_threshold=self.semantic_percent_threshold,
        )

        response = requests.post(
            self.intent_server_endpoint, json=intent_request.model_dump()
        )
        response.raise_for_status()

        response_model = IntentResponse(**response.json())

        return response_model.is_keyword, response_model.keywords
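
# A minimal usage sketch (query is hypothetical): `is_keyword` flags keyword-style
# queries and `keywords` holds the extracted terms:
#
#   analyzer = QueryAnalysisModel()
#   is_keyword, keywords = analyzer.predict("danswer slack connector setup")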

class ConnectorClassificationModel:
    def __init__(
        self,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ):
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.connector_classification_endpoint = (
            model_server_url + "/custom/connector-classification"
        )

    def predict(
        self,
        query: str,
        available_connectors: list[str],
    ) -> list[str]:
        connector_classification_request = ConnectorClassificationRequest(
            available_connectors=available_connectors,
            query=query,
        )
        response = requests.post(
            self.connector_classification_endpoint,
            # model_dump() for consistency with the other requests
            # (Pydantic v2 deprecates .dict())
            json=connector_classification_request.model_dump(),
        )
        response.raise_for_status()

        response_model = ConnectorClassificationResponse(**response.json())

        return response_model.connectors
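
# A minimal usage sketch (connector names are hypothetical):
#
#   classifier = ConnectorClassificationModel()
#   relevant = classifier.predict(
#       query="where is our onboarding doc?",
#       available_connectors=["google_drive", "slack", "github"],
#   )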

def warm_up_retry(
    func: Callable[..., Any],
    tries: int = 20,
    delay: int = 5,
    *args: Any,
    **kwargs: Any,
) -> Callable[..., Any]:
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        exceptions = []
        for attempt in range(tries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                exceptions.append(e)
                logger.exception(
                    f"Attempt {attempt + 1} failed; retrying in {delay} seconds..."
                )
                time.sleep(delay)
        raise Exception(f"All retries failed: {exceptions}")

    return wrapper
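
# Usage sketch: wrap a callable, then invoke the wrapper with the original arguments
# (as done in warm_up_bi_encoder / warm_up_cross_encoder below):
#
#   retry_encode = warm_up_retry(embedding_model.encode, tries=20, delay=5)
#   retry_encode(texts=["hello"], text_type=EmbedTextType.QUERY)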

def warm_up_bi_encoder(
    embedding_model: EmbeddingModel,
    non_blocking: bool = False,
) -> None:
    warm_up_str = " ".join(WARM_UP_STRINGS)

    logger.debug(f"Warming up encoder model: {embedding_model.model_name}")
    get_tokenizer(
        model_name=embedding_model.model_name,
        provider_type=embedding_model.provider_type,
    ).encode(warm_up_str)

    def _warm_up() -> None:
        try:
            embedding_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
            logger.debug(
                f"Warm-up complete for encoder model: {embedding_model.model_name}"
            )
        except Exception as e:
            logger.warning(
                f"Warm-up request failed for encoder model {embedding_model.model_name}: {e}"
            )

    if non_blocking:
        threading.Thread(target=_warm_up, daemon=True).start()
        logger.debug(
            f"Started non-blocking warm-up for encoder model: {embedding_model.model_name}"
        )
    else:
        retry_encode = warm_up_retry(embedding_model.encode)
        retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)

def warm_up_cross_encoder(
    rerank_model_name: str,
    non_blocking: bool = False,
) -> None:
    logger.debug(f"Warming up reranking model: {rerank_model_name}")

    reranking_model = RerankingModel(
        model_name=rerank_model_name,
        provider_type=None,
        api_key=None,
    )

    def _warm_up() -> None:
        try:
            reranking_model.predict(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
            logger.debug(f"Warm-up complete for reranking model: {rerank_model_name}")
        except Exception as e:
            logger.warning(
                f"Warm-up request failed for reranking model {rerank_model_name}: {e}"
            )

    if non_blocking:
        threading.Thread(target=_warm_up, daemon=True).start()
        logger.debug(
            f"Started non-blocking warm-up for reranking model: {rerank_model_name}"
        )
    else:
        retry_rerank = warm_up_retry(reranking_model.predict)
        retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
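
# Warm-up usage sketch (typically invoked at startup; non_blocking=True runs the
# warm-up in a daemon thread instead of retrying inline):
#
#   warm_up_bi_encoder(embedding_model=embedding_model, non_blocking=True)
#   warm_up_cross_encoder("mixedbread-ai/mxbai-rerank-xsmall-v1")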