Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-27 20:38:32 +02:00)
Global Tokenizer Fix (#1825)
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 
 
 def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
-    """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
+    """Unclear if the wordpiece/sentencepiece tokenizer used is actually tokenizing anything as the [UNK] token
     It splits up even foreign characters and unicode emojis without using UNK"""
     tokenized_text = tokenizer.tokenize(text)
     num_unk_tokens = len(
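The hunk above cuts off mid-statement, so here is a minimal, self-contained sketch of a [UNK] counter along the same lines. The function name count_unk_tokens_sketch, the list-comprehension body, and the bert-base-uncased model are illustrative assumptions, not taken from the commit:

    from transformers import AutoTokenizer

    def count_unk_tokens_sketch(text: str, tokenizer: AutoTokenizer) -> int:
        """Count how many tokens the tokenizer maps to its unknown-token symbol."""
        tokenized_text = tokenizer.tokenize(text)
        # Compare each produced token against the tokenizer's configured UNK symbol
        return len([t for t in tokenized_text if t == tokenizer.unk_token])

    if __name__ == "__main__":
        tok = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed model, for illustration only
        print(count_unk_tokens_sketch("hello world", tok))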
@@ -73,6 +73,7 @@ def recommend_search_flow(
     non_stopword_percent = len(non_stopwords) / len(words)
 
     # UNK tokens -> suggest Keyword (still may be valid QA)
+    # TODO do a better job with the classifier model and retire the heuristics
     if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0:
         if not keyword:
             heuristic_search_type = SearchType.KEYWORD
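The added TODO flags this UNK heuristic for eventual replacement by a classifier model. A stripped-down sketch of the idea follows; the SearchFlow enum and recommend_flow_sketch helper are illustrative names, not the project's API:

    from enum import Enum

    class SearchFlow(Enum):
        KEYWORD = "keyword"
        SEMANTIC = "semantic"

    def recommend_flow_sketch(query: str, tokenizer) -> SearchFlow:
        # If the tokenizer cannot represent parts of the query, the embedding of the
        # query will likely be lossy, so fall back to keyword search.
        unk_count = tokenizer.tokenize(query).count(tokenizer.unk_token)
        return SearchFlow.KEYWORD if unk_count > 0 else SearchFlow.SEMANTIC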
@@ -40,25 +40,22 @@ def clean_model_name(model_str: str) -> str:
     return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
 
 
-# NOTE: If None is used, it may not be using the "correct" tokenizer, for cases
-# where this is more important, be sure to refresh with the actual model name
-def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
+# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
+# for cases where this is more important, be sure to refresh with the actual model name
+# One case where it is not particularly important is in the document chunking flow,
+# they're basically all using the sentencepiece tokenizer and whether it's cased or
+# uncased does not really matter, they'll all generally end up with the same chunk lengths.
+def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
     # NOTE: doing a local import here to reduce memory usage caused by
     # processes importing this file despite not using any of this
     from transformers import AutoTokenizer  # type: ignore
 
     global _TOKENIZER
-    if _TOKENIZER[0] is None or (
-        _TOKENIZER[1] is not None and _TOKENIZER[1] != model_name
-    ):
+    if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
         if _TOKENIZER[0] is not None:
             del _TOKENIZER
             gc.collect()
 
-        if model_name is None:
-            # This could be inaccurate
-            model_name = DOCUMENT_ENCODER_MODEL
-
         _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
 
         if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
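The core of the fix is a single module-level cache that is rebuilt only when a different model name is requested, with the old tokenizer explicitly freed first. A self-contained sketch of that caching pattern; the function name get_cached_tokenizer and the DOCUMENT_ENCODER_MODEL value are placeholders that mirror, but do not copy, the project's code:

    import gc
    from transformers import AutoTokenizer

    DOCUMENT_ENCODER_MODEL = "intfloat/e5-base-v2"  # assumed default, for illustration only

    _TOKENIZER: tuple[AutoTokenizer | None, str | None] = (None, None)

    def get_cached_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> AutoTokenizer:
        global _TOKENIZER
        if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
            if _TOKENIZER[0] is not None:
                # Drop the previously loaded tokenizer before loading a new one
                # so only one tokenizer is held in memory at a time
                del _TOKENIZER
                gc.collect()
            _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
        return _TOKENIZER[0]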
@@ -184,6 +181,7 @@ def warm_up_encoders(
         "https://docs.danswer.dev/quickstart"
     )
 
+    # May not be the exact same tokenizer used for the indexing flow
     get_default_tokenizer(model_name=model_name)(warm_up_str)
 
     embed_model = EmbeddingModel(
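Running the tokenizer once at startup forces its files to be downloaded and initialized before the first real request hits it. A hedged usage sketch of that warm-up step; warm_up_tokenizer and the warm-up string are illustrative, and the repo's warm_up_encoders also warms the embedding model itself:

    from transformers import AutoTokenizer

    def warm_up_tokenizer(model_name: str) -> None:
        # Tokenizing any string triggers download/initialization of the tokenizer files
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer("Danswer is amazing")  # placeholder warm-up text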