From b23008289192a5eb71aaf4ba8ce82c6f0a9fc0df Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Fri, 9 Aug 2024 08:17:31 -0700
Subject: [PATCH] Openai encoding temp hotfix (#2094)

---
 .../search_nlp_models.py                     | 22 +++++++++++++++++++
 .../natural_language_processing/utils.py     | 16 +++++-----------
 2 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/backend/danswer/natural_language_processing/search_nlp_models.py b/backend/danswer/natural_language_processing/search_nlp_models.py
index 6757fada2ae0..db2f1181e0a8 100644
--- a/backend/danswer/natural_language_processing/search_nlp_models.py
+++ b/backend/danswer/natural_language_processing/search_nlp_models.py
@@ -1,3 +1,4 @@
+import re
 import time
 
 import requests
@@ -32,6 +33,25 @@ def clean_model_name(model_str: str) -> str:
     return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
 
 
+_WHITELIST = set(
+    " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
+)
+_INITIAL_FILTER = re.compile(
+    "["
+    "\U00000080-\U0000FFFF"  # All Unicode characters beyond ASCII
+    "\U00010000-\U0010FFFF"  # All Unicode characters in supplementary planes
+    "]+",
+    flags=re.UNICODE,
+)
+
+
+def clean_openai_text(text: str) -> str:
+    # First, remove all weird characters
+    cleaned = _INITIAL_FILTER.sub("", text)
+    # Then, keep only whitelisted characters
+    return "".join(char for char in cleaned if char in _WHITELIST)
+
+
 def build_model_server_url(
     model_server_host: str,
     model_server_port: int,
@@ -180,6 +200,8 @@ class EmbeddingModel:
         ]
 
         if self.provider_type:
+            if self.provider_type == "openai":
+                texts = [clean_openai_text(text) for text in texts]
             return self._encode_api_model(
                 texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
             )
diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py
index 30726033fa23..02d599ffcda1 100644
--- a/backend/danswer/natural_language_processing/utils.py
+++ b/backend/danswer/natural_language_processing/utils.py
@@ -111,18 +111,12 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
     return _TOKENIZER_CACHE[tokenizer_name]
 
 
+_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
+
+
 def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
-    if provider_type:
-        if provider_type.lower() == "openai":
-            # Used across ada and text-embedding-3 models
-            return _check_tokenizer_cache("openai")
-        # If we are given a cloud provider_type that isn't OpenAI, we default to trying to use the model_name
-        # this means we are approximating the token count which may leave some performance on the table
-
-    if not model_name:
-        raise ValueError("Need to provide a model_name or provider_type")
-
-    return _check_tokenizer_cache(model_name)
+    global _DEFAULT_TOKENIZER
+    return _DEFAULT_TOKENIZER
 
 
 def tokenizer_trim_content(
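
A note on the first change: clean_openai_text() strips document text down to printable ASCII before it is sent to OpenAI's embedding endpoint, first deleting everything above U+007F in one regex pass, then keeping only whitelisted characters. The self-contained sketch below mirrors the patched code so the behavior can be checked in isolation; the sample string and the demo harness are illustrative additions, not part of the patch.

    import re

    # Printable ASCII plus newline and tab; everything else is dropped.
    _WHITELIST = set(
        " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ"
        "[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
    )
    # Stage 1: remove every character above U+007F (BMP and supplementary planes).
    _INITIAL_FILTER = re.compile("[\U00000080-\U0000FFFF\U00010000-\U0010FFFF]+")


    def clean_openai_text(text: str) -> str:
        cleaned = _INITIAL_FILTER.sub("", text)
        # Stage 2 catches what the regex cannot: ASCII control characters
        # such as \r or \x00 are below U+0080 but not whitelisted.
        return "".join(char for char in cleaned if char in _WHITELIST)


    if __name__ == "__main__":
        print(repr(clean_openai_text("naïve café \U0001F680 ok\r\n")))
        # -> 'nave caf  ok\n' : accents and the emoji are stripped by stage 1,
        # the carriage return by stage 2; plain ASCII and \n survive.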
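
The second change is a blunter hotfix: get_tokenizer() now ignores both of its arguments and always returns one module-level HuggingFaceTokenizer built from DOCUMENT_ENCODER_MODEL, so token counts for OpenAI embedding models become approximations rather than exact counts. (The global statement is only required for assignment, not for reading _DEFAULT_TOKENIZER, so it is redundant but harmless.) The sketch below contrasts the two behaviors; it assumes the old "openai" cache entry was backed by tiktoken's cl100k_base encoding and uses intfloat/e5-base-v2 as a stand-in default model. Both are assumptions made for illustration, not facts stated in the patch.

    import tiktoken
    from transformers import AutoTokenizer

    text = "Danswer chunks and embeds documents before indexing."

    # Pre-hotfix behavior (assumed): provider_type "openai" resolved to an
    # exact OpenAI tokenizer via tiktoken.
    openai_tok = tiktoken.get_encoding("cl100k_base")

    # Post-hotfix behavior: every caller gets the same default tokenizer.
    # "intfloat/e5-base-v2" is a hypothetical stand-in for DOCUMENT_ENCODER_MODEL.
    default_tok = AutoTokenizer.from_pretrained("intfloat/e5-base-v2")

    print(len(openai_tok.encode(text)))   # exact count for ada / text-embedding-3
    print(len(default_tok.encode(text)))  # approximate count used everywhere now

Slightly over- or under-counting tokens matters mainly for tokenizer_trim_content(), the trimming helper visible at the end of the patch, which will now cut OpenAI-bound text at approximate rather than exact token boundaries.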