import os
from abc import ABC
from abc import abstractmethod
from copy import copy

from transformers import logging as transformer_logging  # type: ignore

from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider

logger = setup_logger()
transformer_logging.set_verbosity_error()
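
# Keep tokenizer libraries quiet in server logs: disable fork-related
# tokenizer parallelism warnings, Hugging Face Hub telemetry, and
# transformers' advisory warnings.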
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
os.environ["TRANSFORMERS_NO_ADVISORY_WARNINGS"] = "1"


class BaseTokenizer(ABC):
    """Common interface over the tokenizer backends used for token counting,
    trimming, and inspection."""

    @abstractmethod
    def encode(self, string: str) -> list[int]:
        pass

    @abstractmethod
    def tokenize(self, string: str) -> list[str]:
        pass

    @abstractmethod
    def decode(self, tokens: list[int]) -> str:
        pass


class TiktokenTokenizer(BaseTokenizer):
    """Tokenizer backed by tiktoken, with one shared instance per model name."""

    _instances: dict[str, "TiktokenTokenizer"] = {}

    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
        # Reuse one instance per model name; tiktoken encodings are
        # immutable, so sharing is safe.
        if model_name not in cls._instances:
            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
        return cls._instances[model_name]

    def __init__(self, model_name: str):
        # __init__ runs on every construction, including cache hits, so only
        # build the encoder the first time.
        if not hasattr(self, "encoder"):
            import tiktoken

            self.encoder = tiktoken.encoding_for_model(model_name)

    def encode(self, string: str) -> list[int]:
        # encode_ordinary ignores the special tokens the model is trained on;
        # see tiktoken's encode_ordinary for details
        return self.encoder.encode_ordinary(string)

    def tokenize(self, string: str) -> list[str]:
        encoded = self.encode(string)
        decoded = [self.encoder.decode([token]) for token in encoded]

        # Defensive check; by construction the two lists should always match
        if len(decoded) != len(encoded):
            logger.warning(
                f"OpenAI tokenized length {len(decoded)} does not match encoded "
                f"length {len(encoded)} for string: {string}"
            )

        return decoded

    def decode(self, tokens: list[int]) -> str:
        return self.encoder.decode(tokens)
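
# A quick sketch of the caching behavior (assumes `tiktoken` is installed and
# recognizes the model name):
#
#     tok = TiktokenTokenizer("gpt-4")
#     tok is TiktokenTokenizer("gpt-4")  # True: one instance per model name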


class HuggingFaceTokenizer(BaseTokenizer):
    """Tokenizer backed by the Hugging Face `tokenizers` package."""

    def __init__(self, model_name: str):
        from tokenizers import Tokenizer  # type: ignore

        self.encoder = Tokenizer.from_pretrained(model_name)

    def encode(self, string: str) -> list[int]:
        # Skip special tokens (e.g. [CLS]/[SEP]) so counts reflect content only
        return self.encoder.encode(string, add_special_tokens=False).ids

    def tokenize(self, string: str) -> list[str]:
        return self.encoder.encode(string, add_special_tokens=False).tokens

    def decode(self, tokens: list[int]) -> str:
        return self.encoder.decode(tokens)
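
# A round-trip sketch (assumes `tokenizers` is installed and the model is
# available locally or from the Hub; exact round-tripping depends on the
# model's normalization rules):
#
#     tok = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
#     ids = tok.encode("hello world")
#     text = tok.decode(ids)  # ~"hello world", modulo normalization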


_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}


def _check_tokenizer_cache(
    model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
    global _TOKENIZER_CACHE

    id_tuple = (model_provider, model_name)

    if id_tuple not in _TOKENIZER_CACHE:
        tokenizer = None

        if model_name:
            tokenizer = _try_initialize_tokenizer(model_name, model_provider)

        if not tokenizer:
            logger.info(
                f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
            )
            tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)

        _TOKENIZER_CACHE[id_tuple] = tokenizer

    return _TOKENIZER_CACHE[id_tuple]


def _try_initialize_tokenizer(
    model_name: str, model_provider: EmbeddingProvider | None
) -> BaseTokenizer | None:
    tokenizer: BaseTokenizer | None = None

    if model_provider is not None:
        # When a provider is specified, try TiktokenTokenizer first
        try:
            tokenizer = TiktokenTokenizer(model_name)
            logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
            return tokenizer
        except Exception as tiktoken_error:
            logger.debug(
                f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
            )
    else:
        # With no provider specified, try HuggingFaceTokenizer
        try:
            tokenizer = HuggingFaceTokenizer(model_name)
            logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
            return tokenizer
        except Exception as hf_error:
            logger.warning(
                f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
            )

    # Only one backend is attempted per call; if it fails, return None so the
    # caller can fall back to the default embedding model's tokenizer
    return None


_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)


def get_tokenizer(
    model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
    """Resolve (and cache) a tokenizer for the given model/provider pair. An
    unrecognized provider string falls back to the default tokenizer instead
    of raising."""
    if isinstance(provider_type, str):
        try:
            provider_type = EmbeddingProvider(provider_type)
        except ValueError:
            logger.debug(
                f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
            )
            return _DEFAULT_TOKENIZER

    return _check_tokenizer_cache(provider_type, model_name)
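
# A typical call site (a sketch; the provider and model values depend on
# deployment config):
#
#     tokenizer = get_tokenizer(
#         model_name="text-embedding-3-small", provider_type="openai"
#     )
#     num_tokens = len(tokenizer.encode(query))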


def tokenizer_trim_content(
    content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
    tokens = tokenizer.encode(content)
    if len(tokens) > desired_length:
        content = tokenizer.decode(tokens[:desired_length])
    return content


def tokenizer_trim_chunks(
    chunks: list[InferenceChunk],
    tokenizer: BaseTokenizer,
    max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE,
) -> list[InferenceChunk]:
    # Shallow-copy the list and only copy the individual chunks that actually
    # change, leaving the caller's chunks untouched
    new_chunks = copy(chunks)
    for ind, chunk in enumerate(new_chunks):
        new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer)
        if len(new_content) != len(chunk.content):
            new_chunk = copy(chunk)
            new_chunk.content = new_content
            new_chunks[ind] = new_chunk
    return new_chunks
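

if __name__ == "__main__":
    # Minimal smoke test, a sketch rather than part of the library API: trim
    # an overlong string with the default tokenizer. Assumes `tokenizers` is
    # installed and DOCUMENT_ENCODER_MODEL is available locally or from the Hub.
    sample = "some repeated text " * 500
    trimmed = tokenizer_trim_content(sample, 32, _DEFAULT_TOKENIZER)
    print(
        f"{len(_DEFAULT_TOKENIZER.encode(sample))} tokens "
        f"-> {len(_DEFAULT_TOKENIZER.encode(trimmed))} tokens"
    )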