welcome to onyx

pablodanswer
2024-12-13 09:48:43 -08:00
parent 54dcbfa288
commit 21ec5ed795
813 changed files with 7021 additions and 6824 deletions

View File

@@ -0,0 +1,4 @@
class ModelServerRateLimitError(Exception):
"""
Exception raised for rate limiting errors from the model server.
"""

View File

@@ -0,0 +1,440 @@
import threading
import time
from collections.abc import Callable
from functools import wraps
from typing import Any
import requests
from httpx import HTTPError
from requests import JSONDecodeError
from requests import RequestException
from requests import Response
from retry import retry
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS
from onyx.configs.model_configs import (
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
)
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.db.models import SearchSettings
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.natural_language_processing.exceptions import (
ModelServerRateLimitError,
)
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
from shared_configs.model_server_models import RerankResponse
from shared_configs.utils import batch_list
logger = setup_logger()
WARM_UP_STRINGS = [
"Onyx is amazing!",
"Check out our easy deployment guide at",
"https://docs.onyx.app/quickstart",
]
def clean_model_name(model_str: str) -> str:
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
def build_model_server_url(
model_server_host: str,
model_server_port: int,
) -> str:
model_server_url = f"{model_server_host}:{model_server_port}"
# use protocol if provided
if "http" in model_server_url:
return model_server_url
# otherwise default to http
return f"http://{model_server_url}"
class EmbeddingModel:
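"""
Client-side wrapper around the model server's bi-encoder embedding endpoint.
Handles batching, optional content re-trimming, retries, and rate-limit
handling for both local and API-based embedding providers.

Illustrative usage (host/port values are deployment-specific assumptions):
    model = EmbeddingModel.from_db_model(search_settings, server_host="localhost", server_port=9000)
    embeddings = model.encode(["some text"], text_type=EmbedTextType.QUERY)
"""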
def __init__(
self,
server_host: str, # Changes depending on indexing or inference
server_port: int,
model_name: str | None,
normalize: bool,
query_prefix: str | None,
passage_prefix: str | None,
api_key: str | None,
api_url: str | None,
provider_type: EmbeddingProvider | None,
retrim_content: bool = False,
callback: IndexingHeartbeatInterface | None = None,
api_version: str | None = None,
deployment_name: str | None = None,
) -> None:
self.api_key = api_key
self.provider_type = provider_type
self.query_prefix = query_prefix
self.passage_prefix = passage_prefix
self.normalize = normalize
self.model_name = model_name
self.retrim_content = retrim_content
self.api_url = api_url
self.api_version = api_version
self.deployment_name = deployment_name
self.tokenizer = get_tokenizer(
model_name=model_name, provider_type=provider_type
)
self.callback = callback
model_server_url = build_model_server_url(server_host, server_port)
self.embed_server_endpoint = f"{model_server_url}/encoder/bi-encoder-embed"
def _make_model_server_request(self, embed_request: EmbedRequest) -> EmbedResponse:
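"""
POST a single EmbedRequest to the model server and parse the EmbedResponse.
Passage requests get automatic retries for transient failures plus a longer
retry schedule for 429 rate-limit responses; query requests fail fast.
"""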
def _make_request() -> Response:
response = requests.post(
self.embed_server_endpoint, json=embed_request.model_dump()
)
# raise a dedicated error type for 429s so the rate-limit retry policy below can handle them separately
if response.status_code == 429:
raise ModelServerRateLimitError(response.text)
response.raise_for_status()
return response
final_make_request_func = _make_request
# if the text type is a passage, add some default
# retries + handling for rate limiting
if embed_request.text_type == EmbedTextType.PASSAGE:
final_make_request_func = retry(
tries=3,
delay=5,
exceptions=(RequestException, ValueError, JSONDecodeError),
)(final_make_request_func)
# use a 10 second delay between attempts, per Azure's rate-limit guidance
final_make_request_func = retry(
tries=10, delay=10, exceptions=ModelServerRateLimitError
)(final_make_request_func)
try:
response = final_make_request_func()
return EmbedResponse(**response.json())
except requests.HTTPError as e:
# `response` may be unbound here (raise_for_status() raised inside the
# request helper), so read the error body off the exception itself
try:
error_detail = e.response.json().get("detail", str(e))
except Exception:
error_detail = e.response.text if e.response is not None else str(e)
raise HTTPError(f"HTTP error occurred: {error_detail}") from e
except requests.RequestException as e:
raise HTTPError(f"Request failed: {str(e)}") from e
def _batch_encode_texts(
self,
texts: list[str],
text_type: EmbedTextType,
batch_size: int,
max_seq_length: int,
) -> list[Embedding]:
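"""
Split texts into batches of batch_size, embed each batch via the model
server, and surface progress/stop signals through the optional callback.
"""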
text_batches = batch_list(texts, batch_size)
logger.debug(
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
)
embeddings: list[Embedding] = []
for idx, text_batch in enumerate(text_batches, start=1):
if self.callback:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
api_version=self.api_version,
deployment_name=self.deployment_name,
max_context_length=max_seq_length,
normalize_embeddings=self.normalize,
api_key=self.api_key,
provider_type=self.provider_type,
text_type=text_type,
manual_query_prefix=self.query_prefix,
manual_passage_prefix=self.passage_prefix,
api_url=self.api_url,
)
response = self._make_model_server_request(embed_request)
embeddings.extend(response.embeddings)
if self.callback:
self.callback.progress("_batch_encode_texts", 1)
return embeddings
def encode(
self,
texts: list[str],
text_type: EmbedTextType,
large_chunks_present: bool = False,
local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> list[Embedding]:
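"""
Embed a list of non-empty texts. When retrim_content is set, each text is
first trimmed to max_seq_length tokens; the batch size depends on whether
an external embedding provider or the local model server is used.
"""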
if not texts or not all(texts):
raise ValueError(f"Empty or missing text for embedding: {texts}")
if large_chunks_present:
max_seq_length *= LARGE_CHUNK_RATIO
if self.retrim_content:
# Applied during indexing as a catch-all for overly long titles (or other uncapped fields).
# Note that this uses just the default tokenizer, which may lead to very minor miscounting,
# but that slight miscounting is very unlikely to have any material impact.
texts = [
tokenizer_trim_content(
content=text,
desired_length=max_seq_length,
tokenizer=self.tokenizer,
)
for text in texts
]
batch_size = (
api_embedding_batch_size
if self.provider_type
else local_embedding_batch_size
)
return self._batch_encode_texts(
texts=texts,
text_type=text_type,
batch_size=batch_size,
max_seq_length=max_seq_length,
)
@classmethod
def from_db_model(
cls,
search_settings: SearchSettings,
server_host: str, # Changes depending on indexing or inference
server_port: int,
retrim_content: bool = False,
) -> "EmbeddingModel":
return cls(
server_host=server_host,
server_port=server_port,
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
api_key=search_settings.api_key,
provider_type=search_settings.provider_type,
api_url=search_settings.api_url,
retrim_content=retrim_content,
api_version=search_settings.api_version,
deployment_name=search_settings.deployment_name,
)
class RerankingModel:
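"""
Thin client for the model server's cross-encoder scoring endpoint.

Illustrative usage (the model name below is an assumption, not a project default):
    reranker = RerankingModel(model_name="some-cross-encoder", provider_type=None, api_key=None, api_url=None)
    scores = reranker.predict("query", ["passage one", "passage two"])
"""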
def __init__(
self,
model_name: str,
provider_type: RerankerProvider | None,
api_key: str | None,
api_url: str | None,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.rerank_server_endpoint = model_server_url + "/encoder/cross-encoder-scores"
self.model_name = model_name
self.provider_type = provider_type
self.api_key = api_key
self.api_url = api_url
def predict(self, query: str, passages: list[str]) -> list[float]:
rerank_request = RerankRequest(
query=query,
documents=passages,
model_name=self.model_name,
provider_type=self.provider_type,
api_key=self.api_key,
api_url=self.api_url,
)
response = requests.post(
self.rerank_server_endpoint, json=rerank_request.model_dump()
)
response.raise_for_status()
return RerankResponse(**response.json()).scores
class QueryAnalysisModel:
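"""
Client for the model server's query-analysis endpoint; predicts whether a
query should be treated as keyword search and extracts its keywords.
"""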
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
# Lean heavily towards not throwing out keywords
keyword_percent_threshold: float = 0.1,
# Lean towards semantic search, which is the default
semantic_percent_threshold: float = 0.4,
) -> None:
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.intent_server_endpoint = model_server_url + "/custom/query-analysis"
self.keyword_percent_threshold = keyword_percent_threshold
self.semantic_percent_threshold = semantic_percent_threshold
def predict(
self,
query: str,
) -> tuple[bool, list[str]]:
intent_request = IntentRequest(
query=query,
keyword_percent_threshold=self.keyword_percent_threshold,
semantic_percent_threshold=self.semantic_percent_threshold,
)
response = requests.post(
self.intent_server_endpoint, json=intent_request.model_dump()
)
response.raise_for_status()
response_model = IntentResponse(**response.json())
return response_model.is_keyword, response_model.keywords
class ConnectorClassificationModel:
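"""
Client for the model server's connector-classification endpoint; returns
the connectors from available_connectors judged relevant to the query.
"""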
def __init__(
self,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
):
model_server_url = build_model_server_url(model_server_host, model_server_port)
self.connector_classification_endpoint = (
model_server_url + "/custom/connector-classification"
)
def predict(
self,
query: str,
available_connectors: list[str],
) -> list[str]:
connector_classification_request = ConnectorClassificationRequest(
available_connectors=available_connectors,
query=query,
)
response = requests.post(
self.connector_classification_endpoint,
json=connector_classification_request.model_dump(),
)
response.raise_for_status()
response_model = ConnectorClassificationResponse(**response.json())
return response_model.connectors
def warm_up_retry(
func: Callable[..., Any],
tries: int = 20,
delay: int = 5,
*args: Any,
**kwargs: Any,
) -> Callable[..., Any]:
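"""
Wrap func so it is retried up to `tries` times with a fixed `delay` between
attempts, raising a single exception summarizing all failures at the end.
"""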
@wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any:
exceptions = []
for attempt in range(tries):
try:
return func(*args, **kwargs)
except Exception as e:
exceptions.append(e)
logger.info(
f"Attempt {attempt + 1}/{tries} failed; retrying in {delay} seconds..."
)
time.sleep(delay)
raise Exception(f"All retries failed: {exceptions}")
return wrapper
def warm_up_bi_encoder(
embedding_model: EmbeddingModel,
non_blocking: bool = False,
) -> None:
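"""
Send a small tokenize + embed request so the encoder model is loaded before
real traffic arrives. With non_blocking=True the warm-up runs in a daemon
thread; otherwise it blocks and retries until it succeeds or gives up.
"""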
warm_up_str = " ".join(WARM_UP_STRINGS)
logger.debug(f"Warming up encoder model: {embedding_model.model_name}")
get_tokenizer(
model_name=embedding_model.model_name,
provider_type=embedding_model.provider_type,
).encode(warm_up_str)
def _warm_up() -> None:
try:
embedding_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
logger.debug(
f"Warm-up complete for encoder model: {embedding_model.model_name}"
)
except Exception as e:
logger.warning(
f"Warm-up request failed for encoder model {embedding_model.model_name}: {e}"
)
if non_blocking:
threading.Thread(target=_warm_up, daemon=True).start()
logger.debug(
f"Started non-blocking warm-up for encoder model: {embedding_model.model_name}"
)
else:
retry_encode = warm_up_retry(embedding_model.encode)
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
def warm_up_cross_encoder(
rerank_model_name: str,
non_blocking: bool = False,
) -> None:
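"""
Send a small rerank request so the cross-encoder model is loaded before real
traffic arrives; runs in a daemon thread when non_blocking is True, otherwise
blocks and retries until the warm-up call succeeds.
"""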
logger.debug(f"Warming up reranking model: {rerank_model_name}")
reranking_model = RerankingModel(
model_name=rerank_model_name,
provider_type=None,
api_url=None,
api_key=None,
)
def _warm_up() -> None:
try:
reranking_model.predict(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
logger.debug(f"Warm-up complete for reranking model: {rerank_model_name}")
except Exception as e:
logger.warning(
f"Warm-up request failed for reranking model {rerank_model_name}: {e}"
)
if non_blocking:
threading.Thread(target=_warm_up, daemon=True).start()
logger.debug(
f"Started non-blocking warm-up for reranking model: {rerank_model_name}"
)
else:
retry_rerank = warm_up_retry(reranking_model.predict)
retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])

View File

@@ -0,0 +1,179 @@
import os
from abc import ABC
from abc import abstractmethod
from copy import copy
from transformers import logging as transformer_logging # type:ignore
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"
class BaseTokenizer(ABC):
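"""
Minimal tokenizer interface shared by the tiktoken- and HuggingFace-backed
implementations below.
"""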
@abstractmethod
def encode(self, string: str) -> list[int]:
pass
@abstractmethod
def tokenize(self, string: str) -> list[str]:
pass
@abstractmethod
def decode(self, tokens: list[int]) -> str:
pass
class TiktokenTokenizer(BaseTokenizer):
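"""
tiktoken-backed tokenizer with one cached instance per model name; encoding
ignores the model's special tokens.
"""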
_instances: dict[str, "TiktokenTokenizer"] = {}
def __new__(cls, model_name: str) -> "TiktokenTokenizer":
if model_name not in cls._instances:
cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
return cls._instances[model_name]
def __init__(self, model_name: str):
if not hasattr(self, "encoder"):
import tiktoken
self.encoder = tiktoken.encoding_for_model(model_name)
def encode(self, string: str) -> list[int]:
# this ignores special tokens that the model is trained on, see encode_ordinary for details
return self.encoder.encode_ordinary(string)
def tokenize(self, string: str) -> list[str]:
encoded = self.encode(string)
decoded = [self.encoder.decode([token]) for token in encoded]
if len(decoded) != len(encoded):
logger.warning(
f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
)
return decoded
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
class HuggingFaceTokenizer(BaseTokenizer):
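"""
Tokenizer backed by the HuggingFace `tokenizers` library; encodes without
adding special tokens so counts reflect the raw content.
"""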
def __init__(self, model_name: str):
from tokenizers import Tokenizer # type: ignore
self.encoder = Tokenizer.from_pretrained(model_name)
def encode(self, string: str) -> list[int]:
# this returns no special tokens
return self.encoder.encode(string, add_special_tokens=False).ids
def tokenize(self, string: str) -> list[str]:
return self.encoder.encode(string, add_special_tokens=False).tokens
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
def _check_tokenizer_cache(
model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
global _TOKENIZER_CACHE
id_tuple = (model_provider, model_name)
if id_tuple not in _TOKENIZER_CACHE:
tokenizer = None
if model_name:
tokenizer = _try_initialize_tokenizer(model_name, model_provider)
if not tokenizer:
logger.info(
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
)
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
_TOKENIZER_CACHE[id_tuple] = tokenizer
return _TOKENIZER_CACHE[id_tuple]
def _try_initialize_tokenizer(
model_name: str, model_provider: EmbeddingProvider | None
) -> BaseTokenizer | None:
tokenizer: BaseTokenizer | None = None
if model_provider is not None:
# Try using TiktokenTokenizer first if model_provider exists
try:
tokenizer = TiktokenTokenizer(model_name)
logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
return tokenizer
except Exception as tiktoken_error:
logger.debug(
f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
)
else:
# If no provider specified, try HuggingFaceTokenizer
try:
tokenizer = HuggingFaceTokenizer(model_name)
logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
return tokenizer
except Exception as hf_error:
logger.warning(
f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
)
# initialization failed (or was never attempted), so let the caller fall back to the default
return None
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
def get_tokenizer(
model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
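"""
Resolve a cached tokenizer for the given provider/model pair; invalid
provider strings and failed initializations fall back to the default
DOCUMENT_ENCODER_MODEL tokenizer.

Illustrative usage:
    tokenizer = get_tokenizer(model_name=None, provider_type=None)  # falls back to the default encoder's tokenizer
    token_count = len(tokenizer.encode("some text"))
"""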
if isinstance(provider_type, str):
try:
provider_type = EmbeddingProvider(provider_type)
except ValueError:
logger.debug(
f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
)
return _DEFAULT_TOKENIZER
return _check_tokenizer_cache(provider_type, model_name)
def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
def tokenizer_trim_chunks(
chunks: list[InferenceChunk],
tokenizer: BaseTokenizer,
max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> list[InferenceChunk]:
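"""
Return a copy of chunks in which any chunk whose content exceeds
max_chunk_toks tokens is replaced with a trimmed copy; chunks that already
fit are passed through unchanged.
"""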
new_chunks = copy(chunks)
for ind, chunk in enumerate(new_chunks):
new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
if len(new_content) != len(chunk.content):
new_chunk = copy(chunk)
new_chunk.content = new_content
new_chunks[ind] = new_chunk
return new_chunks