Improved tokenizer fallback (#3132)

* silence warning

* improved fallback logic

* k

* minor cosmetic update

* minor logic update

* nit
Author: pablodanswer
Date: 2024-11-14 20:13:29 -08:00
Committed by: GitHub
Parent: ddff7ecc3f
Commit: 24be13c015


@@ -89,67 +89,70 @@ def _check_tokenizer_cache(
     model_provider: EmbeddingProvider | None, model_name: str | None
 ) -> BaseTokenizer:
     global _TOKENIZER_CACHE
 
     id_tuple = (model_provider, model_name)
 
     if id_tuple not in _TOKENIZER_CACHE:
-        if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
-            if model_name is None:
-                raise ValueError(
-                    "model_name is required for OPENAI and AZURE embeddings"
-                )
-            _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
-            return _TOKENIZER_CACHE[id_tuple]
-
-        try:
-            if model_name is None:
-                model_name = DOCUMENT_ENCODER_MODEL
-            logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
-            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
-        except Exception as primary_error:
-            logger.error(
-                f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
-            )
-            logger.warning(
+        tokenizer = None
+
+        if model_name:
+            tokenizer = _try_initialize_tokenizer(model_name, model_provider)
+
+        if not tokenizer:
+            logger.info(
                 f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
             )
-            try:
-                # Cache this tokenizer name to the default so we don't have to try to load it again
-                # and fail again
-                _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
-                    DOCUMENT_ENCODER_MODEL
-                )
-            except Exception as fallback_error:
-                logger.error(
-                    f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
-                )
-                raise ValueError(
-                    f"Failed to initialize tokenizer for {model_name} and fallback model"
-                ) from fallback_error
+            tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
+
+        _TOKENIZER_CACHE[id_tuple] = tokenizer
+
     return _TOKENIZER_CACHE[id_tuple]
 
 
+def _try_initialize_tokenizer(
+    model_name: str, model_provider: EmbeddingProvider | None
+) -> BaseTokenizer | None:
+    tokenizer: BaseTokenizer | None = None
+
+    if model_provider is not None:
+        # Try using TiktokenTokenizer first if model_provider exists
+        try:
+            tokenizer = TiktokenTokenizer(model_name)
+            logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
+            return tokenizer
+        except Exception as tiktoken_error:
+            logger.debug(
+                f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
+            )
+    else:
+        # If no provider specified, try HuggingFaceTokenizer
+        try:
+            tokenizer = HuggingFaceTokenizer(model_name)
+            logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
+            return tokenizer
+        except Exception as hf_error:
+            logger.warning(
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
+            )
+
+    # If both initializations fail, return None
+    return None
+
+
 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
 
 
 def get_tokenizer(
     model_name: str | None, provider_type: EmbeddingProvider | str | None
 ) -> BaseTokenizer:
-    if provider_type is not None:
-        if isinstance(provider_type, str):
-            try:
-                provider_type = EmbeddingProvider(provider_type)
-            except ValueError:
-                logger.debug(
-                    f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
-                )
-                return _DEFAULT_TOKENIZER
-        return _check_tokenizer_cache(provider_type, model_name)
-    return _DEFAULT_TOKENIZER
+    if isinstance(provider_type, str):
+        try:
+            provider_type = EmbeddingProvider(provider_type)
+        except ValueError:
+            logger.debug(
+                f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
+            )
+            return _DEFAULT_TOKENIZER
+
+    return _check_tokenizer_cache(provider_type, model_name)
 
 
 def tokenizer_trim_content(
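
Taken together, the diff replaces the old nested try/except chain with a two-step lookup: attempt to initialize a tokenizer for the requested model (Tiktoken when a provider is set, HuggingFace otherwise), and on any failure fall back to DOCUMENT_ENCODER_MODEL, caching the result under the original (provider, model) key so a bad model name is only resolved once. The following standalone sketch reproduces that pattern outside the repository; StubTokenizer, DEFAULT_MODEL, and the model names are hypothetical stand-ins, not code from this commit.

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Hypothetical stand-in for DOCUMENT_ENCODER_MODEL.
DEFAULT_MODEL = "default-encoder"


class StubTokenizer:
    """Toy tokenizer that only recognizes a fixed set of model names."""

    KNOWN_MODELS = {"default-encoder", "text-embedding-3-small"}

    def __init__(self, model_name: str) -> None:
        if model_name not in self.KNOWN_MODELS:
            raise ValueError(f"unknown model: {model_name}")
        self.model_name = model_name


_TOKENIZER_CACHE: dict[tuple[str | None, str | None], StubTokenizer] = {}


def _try_initialize_tokenizer(model_name: str) -> StubTokenizer | None:
    # Mirrors the commit's helper: swallow the init error, return None.
    try:
        return StubTokenizer(model_name)
    except Exception as err:
        logger.debug(f"Tokenizer init failed for {model_name}: {err}")
        return None


def get_tokenizer(model_name: str | None, provider: str | None) -> StubTokenizer:
    key = (provider, model_name)
    if key not in _TOKENIZER_CACHE:
        tokenizer = _try_initialize_tokenizer(model_name) if model_name else None
        if tokenizer is None:
            logger.info(f"Falling back to default model: {DEFAULT_MODEL}")
            tokenizer = StubTokenizer(DEFAULT_MODEL)
        # Cache under the original key so a failing model name never re-fails.
        _TOKENIZER_CACHE[key] = tokenizer
    return _TOKENIZER_CACHE[key]


# A bad model name falls back to the default once, then hits the cache.
first = get_tokenizer("no-such-model", "openai")
second = get_tokenizer("no-such-model", "openai")
assert first is second and first.model_name == DEFAULT_MODEL

Note that the cache key still includes the requested model name, so, as in the diff, each distinct failing name gets its own cached default-model tokenizer rather than sharing one global fallback instance.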