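"""Client-side wrappers around the Onyx model server: bi-encoder embedding,
cross-encoder reranking, query analysis, connector classification, and the
warm-up helpers used at startup."""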
import threading
import time
from collections.abc import Callable
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import wraps
from typing import Any

import requests
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry

from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import SKIP_WARM_UP
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from onyx.configs.model_configs import (
    BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
)
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.natural_language_processing.exceptions import (
    ModelServerRateLimitError,
)
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list

logger = setup_logger()


WARM_UP_STRINGS = [
    "Onyx is amazing!",
    "Check out our easy deployment guide at",
    "https://docs.onyx.app/quickstart",
]


def clean_model_name(model_str: str) -> str:
    return model_str.replace("/", "_").replace("-", "_").replace(".", "_")


def build_model_server_url(
    model_server_host: str,
    model_server_port: int,
) -> str:
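    """Return the base URL for the model server, prepending http:// when the
    host does not already include a protocol."""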
    model_server_url = f"{model_server_host}:{model_server_port}"

    # use protocol if provided
    if "http" in model_server_url:
        return model_server_url

    # otherwise default to http
    return f"http://{model_server_url}"


class EmbeddingModel:
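    """Client for the model server's bi-encoder embedding endpoint.

    Handles batching, optional content re-trimming, retries and rate-limit
    handling for passage embedding, and progress/stop signalling via an
    optional IndexingHeartbeatInterface callback.

    Illustrative usage (the `search_settings`, host, and port values below
    are placeholders):

        model = EmbeddingModel.from_db_model(
            search_settings=search_settings,
            server_host="localhost",
            server_port=9000,
        )
        embeddings = model.encode(
            texts=["chunk text"], text_type=EmbedTextType.PASSAGE
        )
    """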
    def __init__(
        self,
        server_host: str,  # Changes depending on indexing or inference
        server_port: int,
        model_name: str | None,
        normalize: bool,
        query_prefix: str | None,
        passage_prefix: str | None,
        api_key: str | None,
        api_url: str | None,
        provider_type: EmbeddingProvider | None,
        retrim_content: bool = False,
        callback: IndexingHeartbeatInterface | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
    ) -> None:
        self.api_key = api_key
        self.provider_type = provider_type
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.normalize = normalize
        self.model_name = model_name
        self.retrim_content = retrim_content
        self.api_url = api_url
        self.api_version = api_version
        self.deployment_name = deployment_name
        self.tokenizer = get_tokenizer(
            model_name=model_name, provider_type=provider_type
        )
        self.callback = callback

        model_server_url = build_model_server_url(server_host, server_port)
        self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"

    def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
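        """POST a single EmbedRequest to the model server and parse the response.

        Passage requests are wrapped with retries for transient request errors
        and a longer retry schedule for 429 rate-limit responses; query
        requests are sent without the extra retry wrappers.
        """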
        def _make_request() -> Response:
            response = requests.post(
                self.embed_server_endpoint, json=embed_request.model_dump()
            )
            # signify that this is a rate limit error
            if response.status_code == 429:
                raise ModelServerRateLimitError(response.text)

            response.raise_for_status()
            return response

        final_make_request_func = _make_request

        # if the text type is a passage, add some default
        # retries + handling for rate limiting
        if embed_request.text_type == EmbedTextType.PASSAGE:
            final_make_request_func = retry(
                tries=3,
                delay=5,
                exceptions=(RequestException, ValueError, JSONDecodeError),
            )(final_make_request_func)
            # use 10 second delay as per Azure suggestion
            final_make_request_func = retry(
                tries=10, delay=10, exceptions=ModelServerRateLimitError
            )(final_make_request_func)

        response: Response | None = None

        try:
            response = final_make_request_func()
            return EmbedResponse(**response.json())
        except requests.HTTPError as e:
            if response is None:
                raise HTTPError("HTTP error occurred - response is None.") from e

            try:
                error_detail = response.json().get("detail", str(e))
            except Exception:
                error_detail = response.text
            raise HTTPError(f"HTTP error occurred: {error_detail}") from e
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e

    def _batch_encode_texts(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        batch_size: int,
        max_seq_length: int,
        num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
    ) -> list[Embedding]:
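        """Split the texts into batches and embed each batch via the model server.

        Batches are processed in parallel with a thread pool when num_threads
        allows it, an API-based provider is configured, and there is more than
        one batch; otherwise they are processed sequentially. Results are
        re-ordered by batch index so the output matches the input order.
        """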
        text_batches = batch_list(texts, batch_size)

        logger.debug(
            f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
        )

        embeddings: list[Embedding] = []

        def process_batch(
            batch_idx: int, text_batch: list[str]
        ) -> tuple[int, list[Embedding]]:
            if self.callback:
                if self.callback.should_stop():
                    raise RuntimeError("_batch_encode_texts detected stop signal")

            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
                api_version=self.api_version,
                deployment_name=self.deployment_name,
                max_context_length=max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
                api_url=self.api_url,
            )

            start_time = time.time()
            response = self._make_model_server_request(embed_request)
            end_time = time.time()

            processing_time = end_time - start_time
            logger.info(
                f"Batch {batch_idx} processing time: {processing_time:.2f} seconds"
            )

            return batch_idx, response.embeddings

        # only multi-thread if:
        # 1. num_threads is greater than 1
        # 2. we are using an API-based embedding model (provider_type is not None)
        # 3. there is more than one batch (no point in threading if only 1)
        if num_threads > 1 and self.provider_type and len(text_batches) > 1:
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                future_to_batch = {
                    executor.submit(process_batch, idx, batch): idx
                    for idx, batch in enumerate(text_batches, start=1)
                }

                # Collect results in order
                batch_results: list[tuple[int, list[Embedding]]] = []
                for future in as_completed(future_to_batch):
                    try:
                        result = future.result()
                        batch_results.append(result)
                        if self.callback:
                            self.callback.progress("_batch_encode_texts", 1)
                    except Exception as e:
                        logger.exception("Embedding model failed to process batch")
                        raise e

                # Sort by batch index and extend embeddings
                batch_results.sort(key=lambda x: x[0])
                for _, batch_embeddings in batch_results:
                    embeddings.extend(batch_embeddings)
        else:
            # Original sequential processing
            for idx, text_batch in enumerate(text_batches, start=1):
                _, batch_embeddings = process_batch(idx, text_batch)
                embeddings.extend(batch_embeddings)
                if self.callback:
                    self.callback.progress("_batch_encode_texts", 1)

        return embeddings

    def encode(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        large_chunks_present: bool = False,
        local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
        api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
    ) -> list[Embedding]:
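        """Embed the given texts, choosing the batch size based on whether an
        API provider or the local model server is in use.

        Raises ValueError if the list is empty or contains an empty text.
        When large chunks are present, the max sequence length is scaled by
        LARGE_CHUNK_RATIO; with retrim_content, each text is re-trimmed to
        that length before embedding.
        """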
        if not texts or not all(texts):
            raise ValueError(f"Empty or missing text for embedding: {texts}")

        if large_chunks_present:
            max_seq_length *= LARGE_CHUNK_RATIO

        if self.retrim_content:
            # This is applied during indexing as a catch-all for overly long titles (or other uncapped fields).
            # Note that this uses just the default tokenizer, which may also lead to very minor miscounting;
            # however, this slight miscounting is very unlikely to have any material impact.
            texts = [
                tokenizer_trim_content(
                    content=text,
                    desired_length=max_seq_length,
                    tokenizer=self.tokenizer,
                )
                for text in texts
            ]

        batch_size = (
            api_embedding_batch_size
            if self.provider_type
            else local_embedding_batch_size
        )

        return self._batch_encode_texts(
            texts=texts,
            text_type=text_type,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        )

    @classmethod
    def from_db_model(
        cls,
        search_settings: SearchSettings,
        server_host: str,  # Changes depending on indexing or inference
        server_port: int,
        retrim_content: bool = False,
    ) -> "EmbeddingModel":
        return cls(
            server_host=server_host,
            server_port=server_port,
            model_name=search_settings.model_name,
            normalize=search_settings.normalize,
            query_prefix=search_settings.query_prefix,
            passage_prefix=search_settings.passage_prefix,
            api_key=search_settings.api_key,
            provider_type=search_settings.provider_type,
            api_url=search_settings.api_url,
            retrim_content=retrim_content,
            api_version=search_settings.api_version,
            deployment_name=search_settings.deployment_name,
        )


class RerankingModel:
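    """Client for the model server's cross-encoder reranking endpoint.

    Illustrative usage (the model name below is a placeholder):

        reranker = RerankingModel(
            model_name="example-cross-encoder",
            provider_type=None,
            api_key=None,
            api_url=None,
        )
        scores = reranker.predict("query", ["passage one", "passage two"])
    """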
    def __init__(
        self,
        model_name: str,
        provider_type: RerankerProvider | None,
        api_key: str | None,
        api_url: str | None,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ) -> None:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
        self.model_name = model_name
        self.provider_type = provider_type
        self.api_key = api_key
        self.api_url = api_url

    def predict(self, query: str, passages: list[str]) -> list[float]:
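        """Return the model server's relevance scores for the passages given
        the query."""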
        rerank_request = RerankRequest(
            query=query,
            documents=passages,
            model_name=self.model_name,
            provider_type=self.provider_type,
            api_key=self.api_key,
            api_url=self.api_url,
        )

        response = requests.post(
            self.rerank_server_endpoint, json=rerank_request.model_dump()
        )
        response.raise_for_status()

        return RerankResponse(**response.json()).scores


class QueryAnalysisModel:
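    """Client for the query-analysis endpoint, which classifies a query as
    keyword-style vs. semantic and extracts keywords.

    The default thresholds lean towards keeping keywords and towards treating
    queries as semantic.
    """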
    def __init__(
        self,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
        # Lean heavily towards not throwing out keywords
        keyword_percent_threshold: float = 0.1,
        # Lean towards semantic which is the default
        semantic_percent_threshold: float = 0.4,
    ) -> None:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.intent_server_endpoint = model_server_url + "/custom/query-analysis"
        self.keyword_percent_threshold = keyword_percent_threshold
        self.semantic_percent_threshold = semantic_percent_threshold

    def predict(
        self,
        query: str,
    ) -> tuple[bool, list[str]]:
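        """Return (is_keyword, keywords) for the query as determined by the
        model server."""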
        intent_request = IntentRequest(
            query=query,
            keyword_percent_threshold=self.keyword_percent_threshold,
            semantic_percent_threshold=self.semantic_percent_threshold,
        )

        response = requests.post(
            self.intent_server_endpoint, json=intent_request.model_dump()
        )
        response.raise_for_status()

        response_model = IntentResponse(**response.json())

        return response_model.is_keyword, response_model.keywords


class ConnectorClassificationModel:
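    """Client for the connector-classification endpoint, which selects the
    connectors relevant to a query from a provided list."""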
    def __init__(
        self,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ):
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.connector_classification_endpoint = (
            model_server_url + "/custom/connector-classification"
        )

    def predict(
        self,
        query: str,
        available_connectors: list[str],
    ) -> list[str]:
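        """Return the connectors the model server considers relevant to the
        query, chosen from available_connectors."""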
        connector_classification_request = ConnectorClassificationRequest(
            available_connectors=available_connectors,
            query=query,
        )
        response = requests.post(
            self.connector_classification_endpoint,
            json=connector_classification_request.model_dump(),
        )
        response.raise_for_status()

        response_model = ConnectorClassificationResponse(**response.json())

        return response_model.connectors


def warm_up_retry(
    func: Callable[..., Any],
    tries: int = 20,
    delay: int = 5,
    *args: Any,
    **kwargs: Any,
) -> Callable[..., Any]:
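    """Wrap func so that it is retried up to `tries` times with `delay`
    seconds between attempts, raising a combined Exception if every attempt
    fails."""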
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        exceptions = []
        for attempt in range(tries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                exceptions.append(e)
                logger.info(
                    f"Attempt {attempt + 1}/{tries} failed; retrying in {delay} seconds..."
                )
                time.sleep(delay)
        raise Exception(f"All retries failed: {exceptions}")

    return wrapper


def warm_up_bi_encoder(
    embedding_model: EmbeddingModel,
    non_blocking: bool = False,
) -> None:
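    """Send a small embedding request so the bi-encoder is loaded and ready.

    Skipped entirely when SKIP_WARM_UP is set. With non_blocking=True the
    warm-up runs on a daemon thread; otherwise it blocks, retrying the
    request until it succeeds or the retry budget is exhausted.
    """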
    if SKIP_WARM_UP:
        return

    warm_up_str = " ".join(WARM_UP_STRINGS)

    logger.debug(f"Warming up encoder model: {embedding_model.model_name}")
    get_tokenizer(
        model_name=embedding_model.model_name,
        provider_type=embedding_model.provider_type,
    ).encode(warm_up_str)

    def _warm_up() -> None:
        try:
            embedding_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
            logger.debug(
                f"Warm-up complete for encoder model: {embedding_model.model_name}"
            )
        except Exception as e:
            logger.warning(
                f"Warm-up request failed for encoder model {embedding_model.model_name}: {e}"
            )

    if non_blocking:
        threading.Thread(target=_warm_up, daemon=True).start()
        logger.debug(
            f"Started non-blocking warm-up for encoder model: {embedding_model.model_name}"
        )
    else:
        retry_encode = warm_up_retry(embedding_model.encode)
        retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)


def warm_up_cross_encoder(
    rerank_model_name: str,
    non_blocking: bool = False,
) -> None:
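    """Send a small reranking request so the cross-encoder is loaded and ready.

    With non_blocking=True the warm-up runs on a daemon thread; otherwise it
    blocks, retrying the request until it succeeds or the retry budget is
    exhausted.
    """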
    logger.debug(f"Warming up reranking model: {rerank_model_name}")

    reranking_model = RerankingModel(
        model_name=rerank_model_name,
        provider_type=None,
        api_url=None,
        api_key=None,
    )

    def _warm_up() -> None:
        try:
            reranking_model.predict(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
            logger.debug(f"Warm-up complete for reranking model: {rerank_model_name}")
        except Exception as e:
            logger.warning(
                f"Warm-up request failed for reranking model {rerank_model_name}: {e}"
            )

    if non_blocking:
        threading.Thread(target=_warm_up, daemon=True).start()
        logger.debug(
            f"Started non-blocking warm-up for reranking model: {rerank_model_name}"
        )
    else:
        retry_rerank = warm_up_retry(reranking_model.predict)
        retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])