Openai encoding temp hotfix (#2094)

Author: hagen-danswer
Date: 2024-08-09 08:17:31 -07:00
Committed by: GitHub
Parent: 8cd1eda8b1
Commit: b230082891
2 changed files with 27 additions and 11 deletions

Changed file 1 of 2:

@@ -1,3 +1,4 @@
import re
import time
import requests
@@ -32,6 +33,25 @@ def clean_model_name(model_str: str) -> str:
    return model_str.replace("/", "_").replace("-", "_").replace(".", "_")


_WHITELIST = set(
    " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
)
_INITIAL_FILTER = re.compile(
    "["
    "\U00000080-\U0000FFFF"  # All Unicode characters beyond ASCII
    "\U00010000-\U0010FFFF"  # All Unicode characters in supplementary planes
    "]+",
    flags=re.UNICODE,
)


def clean_openai_text(text: str) -> str:
    # First, remove all weird characters
    cleaned = _INITIAL_FILTER.sub("", text)
    # Then, keep only whitelisted characters
    return "".join(char for char in cleaned if char in _WHITELIST)


def build_model_server_url(
    model_server_host: str,
    model_server_port: int,
@@ -180,6 +200,8 @@ class EmbeddingModel:
            ]

        if self.provider_type:
            if self.provider_type == "openai":
                texts = [clean_openai_text(text) for text in texts]
            return self._encode_api_model(
                texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
            )
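Net effect of this file's hunks: when the embedding provider is OpenAI, each text is first stripped of all characters at or above U+0080 and then filtered against the printable-ASCII whitelist (plus newline and tab) before being passed to _encode_api_model. A short sketch of that behavior (the sample strings are illustrative, not from the commit):

# Illustrative only: exercising clean_openai_text as defined above.
assert clean_openai_text("hello world") == "hello world"  # plain ASCII is untouched
assert clean_openai_text("héllo 👋 world") == "hllo  world"  # accents and emoji are dropped
assert clean_openai_text("tab\tand\nnewline") == "tab\tand\nnewline"  # \t and \n are whitelisted
assert clean_openai_text("bell\x07char") == "bellchar"  # other ASCII control chars fail the whitelist pass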

Changed file 2 of 2:

@@ -111,18 +111,12 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
    return _TOKENIZER_CACHE[tokenizer_name]


_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)


def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
    if provider_type:
        if provider_type.lower() == "openai":
            # Used across ada and text-embedding-3 models
            return _check_tokenizer_cache("openai")
        # If we are given a cloud provider_type that isn't OpenAI, we default to trying to use the model_name
        # this means we are approximating the token count which may leave some performance on the table
    if not model_name:
        raise ValueError("Need to provide a model_name or provider_type")
    return _check_tokenizer_cache(model_name)
    global _DEFAULT_TOKENIZER
    return _DEFAULT_TOKENIZER


def tokenizer_trim_content(
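For reference, a minimal usage sketch of the reworked get_tokenizer dispatch, assuming the function as shown above (the HuggingFace model name is an illustrative placeholder, not from the commit):

# Illustrative usage only.
openai_tokenizer = get_tokenizer(model_name=None, provider_type="OpenAI")
# provider_type is lowercased, so any casing of "openai" resolves to the shared
# OpenAI tokenizer cached under the "openai" key.

hf_tokenizer = get_tokenizer(model_name="intfloat/e5-base-v2", provider_type=None)
# With no provider_type, the model_name keys the tokenizer cache directly.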