Mirror of https://github.com/danswer-ai/danswer.git
Global Tokenizer Fix (#1825)
@@ -15,7 +15,7 @@ if TYPE_CHECKING:
 
 
 def count_unk_tokens(text: str, tokenizer: "AutoTokenizer") -> int:
-    """Unclear if the wordpiece tokenizer used is actually tokenizing anything as the [UNK] token
+    """Unclear if the wordpiece/sentencepiece tokenizer used is actually tokenizing anything as the [UNK] token
     It splits up even foreign characters and unicode emojis without using UNK"""
     tokenized_text = tokenizer.tokenize(text)
     num_unk_tokens = len(
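Note: the hunk above is cut off inside count_unk_tokens. As a point of reference only, a minimal self-contained sketch of the same idea (counting how often a HuggingFace tokenizer falls back to its unknown token) could look like the following; the repository's actual body may differ.

# Illustrative sketch, not the repository's exact implementation.
from transformers import AutoTokenizer


def count_unk_tokens_sketch(text: str, tokenizer: AutoTokenizer) -> int:
    """Count how many tokens the tokenizer maps to its unknown ([UNK]/<unk>) token."""
    tokenized_text = tokenizer.tokenize(text)
    return len([token for token in tokenized_text if token == tokenizer.unk_token])

As the updated docstring says, wordpiece/sentencepiece tokenizers generally split unfamiliar characters and emojis into sub-pieces instead of emitting UNK, so this count is usually zero.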
@@ -73,6 +73,7 @@ def recommend_search_flow(
     non_stopword_percent = len(non_stopwords) / len(words)
 
     # UNK tokens -> suggest Keyword (still may be valid QA)
+    # TODO do a better job with the classifier model and retire the heuristics
     if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0:
         if not keyword:
             heuristic_search_type = SearchType.KEYWORD
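To exercise the heuristic above by hand, any HuggingFace tokenizer can be loaded and checked against a query; the model name below is purely illustrative, and count_unk_tokens_sketch refers to the sketch shown earlier.

# Quick manual check of the UNK heuristic (illustrative model name).
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    for query in ["how do I configure connectors", "🦆🦆🦆"]:
        # Any UNK hit pushes the recommendation toward SearchType.KEYWORD,
        # unless the caller already forced a keyword search.
        print(query, count_unk_tokens_sketch(query, tok))

The rationale for the fallback: a query the embedding tokenizer cannot represent is unlikely to embed meaningfully, so keyword retrieval is the safer default even though QA may still work.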
@@ -40,25 +40,22 @@ def clean_model_name(model_str: str) -> str:
     return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
 
 
-# NOTE: If None is used, it may not be using the "correct" tokenizer, for cases
-# where this is more important, be sure to refresh with the actual model name
-def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer":
+# NOTE: If no model_name is specified, it may not be using the "correct" tokenizer
+# for cases where this is more important, be sure to refresh with the actual model name
+# One case where it is not particularly important is in the document chunking flow,
+# they're basically all using the sentencepiece tokenizer and whether it's cased or
+# uncased does not really matter, they'll all generally end up with the same chunk lengths.
+def get_default_tokenizer(model_name: str = DOCUMENT_ENCODER_MODEL) -> "AutoTokenizer":
     # NOTE: doing a local import here to avoid reduce memory usage caused by
     # processes importing this file despite not using any of this
     from transformers import AutoTokenizer  # type: ignore
 
     global _TOKENIZER
-    if _TOKENIZER[0] is None or (
-        _TOKENIZER[1] is not None and _TOKENIZER[1] != model_name
-    ):
+    if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
         if _TOKENIZER[0] is not None:
             del _TOKENIZER
             gc.collect()
 
-        if model_name is None:
-            # This could be inaccurate
-            model_name = DOCUMENT_ENCODER_MODEL
-
         _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
 
     if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
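The core of the fix is the module-level (tokenizer, model_name) cache: with the default argument now DOCUMENT_ENCODER_MODEL instead of None, the reload condition collapses to a single comparison and the None-handling branch disappears. A minimal sketch of that caching pattern, with a placeholder model name standing in for danswer's configured default (the repo also defers the transformers import into the function body to avoid the memory cost for processes that never use it):

# Sketch of the module-level tokenizer cache shown in the diff above.
import gc

from transformers import AutoTokenizer

DOCUMENT_ENCODER_MODEL = "sentence-transformers/all-MiniLM-L6-v2"  # placeholder, not danswer's real default

_TOKENIZER: tuple = (None, None)  # (tokenizer instance, model name it was loaded for)


def get_default_tokenizer_sketch(model_name: str = DOCUMENT_ENCODER_MODEL) -> AutoTokenizer:
    global _TOKENIZER
    # Reload only when nothing is cached yet or a different model is requested.
    if _TOKENIZER[0] is None or _TOKENIZER[1] != model_name:
        if _TOKENIZER[0] is not None:
            # Drop the previously cached tokenizer before loading a new one.
            del _TOKENIZER
            gc.collect()
        _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
    return _TOKENIZER[0]

Callers that pass no model name now share one tokenizer cached for the document encoder, and passing a different model name evicts and replaces the cache instead of reusing whatever happened to be loaded first.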
@@ -184,6 +181,7 @@ def warm_up_encoders(
         "https://docs.danswer.dev/quickstart"
     )
 
+    # May not be the exact same tokenizer used for the indexing flow
     get_default_tokenizer(model_name=model_name)(warm_up_str)
 
     embed_model = EmbeddingModel(
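The comment added here acknowledges that the warm-up tokenizer may not be the exact one used at indexing time; the point of the call is simply to pay the download and initialization cost up front. A hedged sketch of that warm-up step follows, with the warm-up text abridged to the URL visible in the hunk and the embedding pass reduced to a comment, since EmbeddingModel's constructor arguments are not shown in this diff.

# Illustrative warm-up sketch; relies on get_default_tokenizer_sketch from the
# earlier sketch, and the string here is a placeholder.
def warm_up_encoders_sketch(model_name: str) -> None:
    warm_up_str = "See https://docs.danswer.dev/quickstart"

    # May not be the exact tokenizer used for indexing, but calling it once
    # forces the tokenizer files to be downloaded and loaded into memory.
    get_default_tokenizer_sketch(model_name)(warm_up_str)

    # A real implementation would then run one embedding pass over the same
    # string so the encoder weights are also loaded before serving traffic.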