mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 20:38:32 +02:00
Improved tokenizer fallback (#3132)
* silence warning * improved fallback logic * k * minor cosmetic update * minor logic update * nit
This commit is contained in:
@@ -89,67 +89,70 @@ def _check_tokenizer_cache(
|
||||
model_provider: EmbeddingProvider | None, model_name: str | None
|
||||
) -> BaseTokenizer:
|
||||
global _TOKENIZER_CACHE
|
||||
|
||||
id_tuple = (model_provider, model_name)
|
||||
|
||||
if id_tuple not in _TOKENIZER_CACHE:
|
||||
if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
|
||||
if model_name is None:
|
||||
raise ValueError(
|
||||
"model_name is required for OPENAI and AZURE embeddings"
|
||||
)
|
||||
tokenizer = None
|
||||
|
||||
_TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
|
||||
return _TOKENIZER_CACHE[id_tuple]
|
||||
if model_name:
|
||||
tokenizer = _try_initialize_tokenizer(model_name, model_provider)
|
||||
|
||||
try:
|
||||
if model_name is None:
|
||||
model_name = DOCUMENT_ENCODER_MODEL
|
||||
|
||||
logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
|
||||
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
|
||||
except Exception as primary_error:
|
||||
logger.error(
|
||||
f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
|
||||
)
|
||||
logger.warning(
|
||||
if not tokenizer:
|
||||
logger.info(
|
||||
f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
|
||||
)
|
||||
tokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
|
||||
try:
|
||||
# Cache this tokenizer name to the default so we don't have to try to load it again
|
||||
# and fail again
|
||||
_TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
|
||||
DOCUMENT_ENCODER_MODEL
|
||||
)
|
||||
except Exception as fallback_error:
|
||||
logger.error(
|
||||
f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Failed to initialize tokenizer for {model_name} and fallback model"
|
||||
) from fallback_error
|
||||
_TOKENIZER_CACHE[id_tuple] = tokenizer
|
||||
|
||||
return _TOKENIZER_CACHE[id_tuple]
|
||||
|
||||
|
||||
def _try_initialize_tokenizer(
    model_name: str, model_provider: EmbeddingProvider | None
) -> BaseTokenizer | None:
    """Make one best-effort attempt to build a tokenizer for ``model_name``.

    Exactly one backend is tried per call, selected by whether a provider was
    supplied: ``TiktokenTokenizer`` when ``model_provider`` is not ``None``,
    ``HuggingFaceTokenizer`` otherwise. There is no cross-backend retry here;
    any exception raised during the attempt is logged and swallowed so the
    caller can decide how to fall back.

    Args:
        model_name: Name passed to the tokenizer constructor.
        model_provider: Embedding provider, or ``None`` when unspecified. Only
            its presence is checked; its value is not otherwise used.

    Returns:
        The constructed tokenizer, or ``None`` if the attempt raised.
    """
    if model_provider is None:
        # No provider given: the model name is handed to HuggingFaceTokenizer.
        try:
            hf_tokenizer = HuggingFaceTokenizer(model_name)
            logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
            return hf_tokenizer
        except Exception as hf_error:
            logger.warning(
                f"Error initializing HuggingFaceTokenizer for {model_name}: {hf_error}"
            )
        return None

    # A provider was given: only TiktokenTokenizer is attempted. Failure is
    # logged at debug level (not warning), unlike the HuggingFace path above.
    try:
        tiktoken_tokenizer = TiktokenTokenizer(model_name)
        logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
        return tiktoken_tokenizer
    except Exception as tiktoken_error:
        logger.debug(
            f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
        )
    return None
|
||||
|
||||
|
||||
# Shared tokenizer built from DOCUMENT_ENCODER_MODEL by this module-level
# statement. get_tokenizer() returns it when a provider string cannot be
# converted to an EmbeddingProvider.
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
|
||||
|
||||
def get_tokenizer(
    model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
    """Return a tokenizer for the given model name and embedding provider.

    A string ``provider_type`` is first converted to ``EmbeddingProvider``.
    If the string is not a valid provider value, the problem is logged at
    debug level and the module-level ``_DEFAULT_TOKENIZER`` is returned.
    In every other case, including ``provider_type`` being ``None``, the
    lookup is delegated to ``_check_tokenizer_cache`` so that ``model_name``
    is still taken into account.

    Args:
        model_name: Name of the model to tokenize for, or ``None``.
        provider_type: Provider as an ``EmbeddingProvider``, its string
            value, or ``None`` when no provider is specified.

    Returns:
        The tokenizer produced by ``_check_tokenizer_cache``, or
        ``_DEFAULT_TOKENIZER`` when ``provider_type`` is an unrecognised
        string.
    """
    if isinstance(provider_type, str):
        try:
            provider_type = EmbeddingProvider(provider_type)
        except ValueError:
            # Unknown provider string: do not raise, use the default tokenizer.
            logger.debug(
                f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
            )
            return _DEFAULT_TOKENIZER

    # provider_type is now an EmbeddingProvider or None; both go through the
    # cache so a model-specific tokenizer can be used when one is available.
    return _check_tokenizer_cache(provider_type, model_name)
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
|
Reference in New Issue
Block a user