import os
from abc import ABC
from abc import abstractmethod
from copy import copy

from transformers import logging as transformer_logging  # type:ignore

from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider

logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"


class BaseTokenizer(ABC):
    @abstractmethod
    def encode(self, string: str) -> list[int]:
        pass

    @abstractmethod
    def tokenize(self, string: str) -> list[str]:
        pass

    @abstractmethod
    def decode(self, tokens: list[int]) -> str:
        pass
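
# Any tokenizer used by this module only has to satisfy the three-method
# contract above. A toy illustration (not part of the original file) of a
# minimal conforming implementation:
#
#     class WhitespaceTokenizer(BaseTokenizer):
#         def __init__(self) -> None:
#             self._vocab: dict[str, int] = {}
#             self._words: list[str] = []
#
#         def encode(self, string: str) -> list[int]:
#             for word in string.split():
#                 if word not in self._vocab:
#                     self._vocab[word] = len(self._words)
#                     self._words.append(word)
#             return [self._vocab[word] for word in string.split()]
#
#         def tokenize(self, string: str) -> list[str]:
#             return string.split()
#
#         def decode(self, tokens: list[int]) -> str:
#             return " ".join(self._words[token] for token in tokens)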


class TiktokenTokenizer(BaseTokenizer):
    _instances: dict[str, "TiktokenTokenizer"] = {}

    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
        if model_name not in cls._instances:
            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
        return cls._instances[model_name]

    def __init__(self, model_name: str):
        if not hasattr(self, "encoder"):
            import tiktoken

            self.encoder = tiktoken.encoding_for_model(model_name)

    def encode(self, string: str) -> list[int]:
        # this ignores special tokens that the model is trained on; see encode_ordinary for details
        return self.encoder.encode_ordinary(string)

    def tokenize(self, string: str) -> list[str]:
        encoded = self.encode(string)
        decoded = [self.encoder.decode([token]) for token in encoded]

        if len(decoded) != len(encoded):
            logger.warning(
                f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
            )

        return decoded

    def decode(self, tokens: list[int]) -> str:
        return self.encoder.decode(tokens)
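
# TiktokenTokenizer is a per-model-name singleton: __new__ caches one instance
# per model name, and __init__ only builds the encoder when it is missing, so
# the tiktoken lookup runs once per model. A usage sketch (assumes tiktoken is
# installed and recognizes the model name; token ids are illustrative):
#
#     t1 = TiktokenTokenizer("gpt-4")
#     t2 = TiktokenTokenizer("gpt-4")
#     assert t1 is t2  # same cached instance
#     t1.encode("hello world")  # -> [15339, 1917] under cl100k_base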


class HuggingFaceTokenizer(BaseTokenizer):
    def __init__(self, model_name: str):
        from tokenizers import Tokenizer  # type: ignore

        self.encoder = Tokenizer.from_pretrained(model_name)

    def encode(self, string: str) -> list[int]:
        # this returns no special tokens
        return self.encoder.encode(string, add_special_tokens=False).ids

    def tokenize(self, string: str) -> list[str]:
        return self.encoder.encode(string, add_special_tokens=False).tokens

    def decode(self, tokens: list[int]) -> str:
        return self.encoder.decode(tokens)
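
# HuggingFaceTokenizer wraps the fast `tokenizers` Rust library directly rather
# than transformers.AutoTokenizer. A usage sketch (assumes the model's
# tokenizer is available on the Hugging Face Hub; ids are vocab-specific):
#
#     tok = HuggingFaceTokenizer("bert-base-uncased")
#     tok.tokenize("hello world")  # -> ["hello", "world"]
#     tok.encode("hello world")    # -> e.g. [7592, 2088]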


_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}


def _check_tokenizer_cache(
    model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
    global _TOKENIZER_CACHE
    id_tuple = (model_provider, model_name)

    if id_tuple not in _TOKENIZER_CACHE:
        tokenizer = None

        if model_name:
            tokenizer = _try_initialize_tokenizer(model_name, model_provider)

        if not tokenizer:
            logger.info(
                f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
            )
            tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

        _TOKENIZER_CACHE[id_tuple] = tokenizer

    return _TOKENIZER_CACHE[id_tuple]
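
# Note: the cache is keyed by the *requested* (provider, model_name) pair, so
# if initialization fails and we fall back to the default tokenizer, that
# fallback is cached under the original key and the failed model will not be
# retried on later calls.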


def _try_initialize_tokenizer(
    model_name: str, model_provider: EmbeddingProvider | None
) -> BaseTokenizer | None:
    tokenizer: BaseTokenizer | None = None

    if model_provider is not None:
        # Try using TiktokenTokenizer first if model_provider exists
        try:
            tokenizer = TiktokenTokenizer(model_name)
            logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
            return tokenizer
        except Exception as tiktoken_error:
            logger.debug(
                f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
            )
    else:
        # If no provider specified, try HuggingFaceTokenizer
        try:
            tokenizer = HuggingFaceTokenizer(model_name)
            logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
            return tokenizer
        except Exception as hf_error:
            logger.warning(
                f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
            )

    # If initialization on the attempted path failed, return None
    return None
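
# Note the asymmetry above: with a provider, only TiktokenTokenizer is
# attempted; without one, only HuggingFaceTokenizer is. The cross-cutting
# fallback to the default model happens one level up in _check_tokenizer_cache.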


_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
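
# Caution: this line runs at import time, so merely importing this module can
# trigger a download of the DOCUMENT_ENCODER_MODEL tokenizer from the Hugging
# Face Hub if it is not already cached locally.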


def get_tokenizer(
    model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
    if isinstance(provider_type, str):
        try:
            provider_type = EmbeddingProvider(provider_type)
        except ValueError:
            logger.debug(
                f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
            )
            return _DEFAULT_TOKENIZER
    return _check_tokenizer_cache(provider_type, model_name)
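
# Usage sketch (values are illustrative; valid provider strings depend on
# shared_configs.enums.EmbeddingProvider):
#
#     tok = get_tokenizer(model_name="text-embedding-3-small", provider_type="openai")
#     num_tokens = len(tok.encode("some document text"))
#
# An unrecognized provider string falls back to _DEFAULT_TOKENIZER instead of
# raising.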


def tokenizer_trim_content(
    content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
    tokens = tokenizer.encode(content)
    if len(tokens) > desired_length:
        content = tokenizer.decode(tokens[:desired_length])
    return content
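
# Example: with a 4-token budget, only the first four tokens are kept. A
# hedged sketch reusing the module default tokenizer (the exact cut point
# depends on the vocabulary):
#
#     tokenizer_trim_content("one two three four five", 4, _DEFAULT_TOKENIZER)
#     # -> roughly "one two three four"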


def tokenizer_trim_chunks(
    chunks: list[InferenceChunk],
    tokenizer: BaseTokenizer,
    max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> list[InferenceChunk]:
    new_chunks = copy(chunks)
    for ind, chunk in enumerate(new_chunks):
        new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
        if len(new_content) != len(chunk.content):
            new_chunk = copy(chunk)
            new_chunk.content = new_content
            new_chunks[ind] = new_chunk
    return new_chunks
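
# tokenizer_trim_chunks is deliberately copy-on-write: copy(chunks) gives a new
# list, and a chunk is only (shallow-)copied and swapped in when its content
# actually changed, so callers' InferenceChunk objects are never mutated. The
# length comparison is a cheap change check, valid because trimming can only
# ever shorten the content.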