Global Tokenizer Fix (#1825)

Yuhong Sun authored 2024-07-14 11:37:10 -07:00, committed by GitHub
parent e7f81d1688
commit 017af052be
2 changed files with 10 additions and 11 deletions

View File

@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
-    """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
+    """Unclear if the wordpiece/sentencepiece tokenizer used is actually tokenizing anything as the [UNK] token
     It splits up even foreign characters and unicode emojis without using UNK"""
     tokenized_text = tokenizer.tokenize(text)
     num_unk_tokens = len(
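
For context, here is a minimal standalone sketch of the check this docstring describes: counting how many tokens a HuggingFace tokenizer maps to its unknown token. The model name, helper name, and sample strings are illustrative assumptions, not taken from the repository.

from transformers import AutoTokenizer

def count_unk_tokens_demo(text: str, tokenizer) -> int:
    # tokenizer.unk_token is e.g. "[UNK]" for WordPiece models or "<unk>" for
    # SentencePiece models; count how many produced tokens equal it
    tokens = tokenizer.tokenize(text)
    return sum(1 for tok in tokens if tok == tokenizer.unk_token)

if __name__ == "__main__":
    tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # stand-in model
    # Depending on the vocabulary, emojis and non-Latin scripts may be split
    # into sub-pieces or mapped to the unknown token
    for sample in ["hello world", "こんにちは 🙂", "résumé"]:
        print(sample, "->", count_unk_tokens_demo(sample, tok))
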
@@ -73,6 +73,7 @@ def recommend_search_flow(
     non_stopword_percent = len(non_stopwords) / len(words)
     # UNK tokens -> suggest Keyword (still may be valid QA)
+    # TODO do a better job with the classifier model and retire the heuristics
     if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0:
         if not keyword:
            heuristic_search_type = SearchType.KEYWORD
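
The heuristic around the added TODO can be read as: if the query contains anything the embedding model's tokenizer can only represent as UNK, semantic scores are unreliable, so recommend keyword search instead. A hedged sketch of that logic, with SearchType and the UNK count reimplemented here purely for illustration:

from enum import Enum

class SearchType(str, Enum):
    SEMANTIC = "semantic"
    KEYWORD = "keyword"

def suggest_search_type(query: str, tokenizer) -> SearchType:
    tokens = tokenizer.tokenize(query)
    num_unk = sum(1 for tok in tokens if tok == tokenizer.unk_token)
    # Any UNK token means part of the query is invisible to the embedding
    # model, so a lexical keyword match is the safer recommendation
    return SearchType.KEYWORD if num_unk > 0 else SearchType.SEMANTIC
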

View File

@@ -40,25 +40,22 @@ def clean_model_name(model_str: str) -> str:
     return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
-# NOTE: If None is used, it may not be using the "correct" tokenizer, for cases
-# where this is more important, be sure to refresh with the actual model name
-def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
+# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
+# for cases where this is more important, be sure to refresh with the actual model name
+# One case where it is not particularly important is in the document chunking flow,
+# they're basically all using the sentencepiece tokenizer and whether it's cased or
+# uncased does not really matter, they'll all generally end up with the same chunk lengths.
+def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
     # NOTE: doing a local import here to avoid reduce memory usage caused by
     # processes importing this file despite not using any of this
     from transformers import AutoTokenizer  # type: ignore
     global _TOKENIZER
-    if _TOKENIZER[0] is None or (
-        _TOKENIZER[1] is not None and _TOKENIZER[1] != model_name
-    ):
+    if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
         if _TOKENIZER[0] is not None:
             del _TOKENIZER
             gc.collect()
-        if model_name is None:
-            # This could be inaccurate
-            model_name = DOCUMENT_ENCODER_MODEL
         _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
         if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
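
As far as the diff shows, the bug was in the cache check: the old version stored the resolved DOCUMENT_ENCODER_MODEL name in _TOKENIZER[1], so every later call that left model_name as None compared that stored name against None and reloaded the tokenizer from scratch. Defaulting the parameter to DOCUMENT_ENCODER_MODEL keeps the comparison stable. Below is a self-contained sketch of the resulting caching pattern; the default model string is a placeholder, not the project's actual DOCUMENT_ENCODER_MODEL.

import gc

# Placeholder default; the real DOCUMENT_ENCODER_MODEL lives in the project's config
DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"

_TOKENIZER: tuple = (None, None)

def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL):
    # Local import so that merely importing this module stays lightweight
    from transformers import AutoTokenizer  # type: ignore

    global _TOKENIZER
    if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
        if _TOKENIZER[0] is not None:
            # Drop the previously cached tokenizer before loading a new one
            del _TOKENIZER
            gc.collect()
        _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
    return _TOKENIZER[0]
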
@@ -184,6 +181,7 @@ def warm_up_encoders(
         "https://docs.danswer.dev/quickstart"
     )
+    # May not be the exact same tokenizer used for the indexing flow
     get_default_tokenizer(model_name=model_name)(warm_up_str)
     embed_model = EmbeddingModel(
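
The warm-up call above exists so the tokenizer (and, right after it, the embedding model) is loaded before the first real request arrives. A rough illustration of the cost it front-loads; the model name and warm-up string are arbitrary examples, not the project's values.

import time
from transformers import AutoTokenizer

warm_up_str = "The quick brown fox jumps over the lazy dog"

start = time.perf_counter()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # example model
tokenizer(warm_up_str)
print(f"cold load + encode: {time.perf_counter() - start:.2f}s")

start = time.perf_counter()
tokenizer(warm_up_str)
print(f"warm encode: {time.perf_counter() - start:.4f}s")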