Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-09-26 20:08:38 +02:00).
Commit: "Openai encoding temp hotfix (#2094)".
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
import time
|
||||
|
||||
import requests
|
||||
@@ -32,6 +33,25 @@ def clean_model_name(model_str: str) -> str:
|
||||
return model_str.replace("/", "_").replace("-", "_").replace(".", "_")
|
||||
|
||||
|
||||
# Printable ASCII (plus newline and tab) that is allowed through to the
# OpenAI embedding API.
_WHITELIST = set(
    " !\"#$%&'()*+,-./0123456789:;<=>?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\\]^_`abcdefghijklmnopqrstuvwxyz{|}~\n\t"
)
# Pre-pass that strips every code point outside the ASCII range in one shot.
_INITIAL_FILTER = re.compile(
    "["
    "\U00000080-\U0000FFFF"  # All Unicode characters beyond ASCII
    "\U00010000-\U0010FFFF"  # All Unicode characters in supplementary planes
    "]+",
    flags=re.UNICODE,
)


def clean_openai_text(text: str) -> str:
    """Sanitize text before sending it to the OpenAI embedding API.

    Runs two passes: a regex sweep that drops all non-ASCII code points,
    followed by a whitelist pass that keeps only printable ASCII plus
    newline and tab.
    """
    ascii_only = _INITIAL_FILTER.sub("", text)
    return "".join(filter(_WHITELIST.__contains__, ascii_only))
|
||||
|
||||
|
||||
def build_model_server_url(
|
||||
model_server_host: str,
|
||||
model_server_port: int,
|
||||
@@ -180,6 +200,8 @@ class EmbeddingModel:
|
||||
]
|
||||
|
||||
if self.provider_type:
|
||||
if self.provider_type == "openai":
|
||||
texts = [clean_openai_text(text) for text in texts]
|
||||
return self._encode_api_model(
|
||||
texts=texts, text_type=text_type, batch_size=api_embedding_batch_size
|
||||
)
|
||||
|
@@ -111,18 +111,12 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
|
||||
return _TOKENIZER_CACHE[tokenizer_name]
|
||||
|
||||
|
||||
# Fallback tokenizer, built eagerly at import time from the configured
# document encoder model.
# NOTE(review): the only visible use (in get_tokenizer) sits after an
# unconditional return, so this may be dead after the hotfix — confirm
# against the full file before removing.
_DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
|
||||
|
||||
|
||||
def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
    """Return a cached tokenizer for the given model/provider combination.

    OpenAI providers share one tokenizer across ada and text-embedding-3
    models. For any other cloud provider we fall back to ``model_name``,
    which only approximates the true token count.

    Args:
        model_name: name of the embedding model, used as the cache key when
            no recognized provider is given.
        provider_type: cloud provider identifier (e.g. "openai"); compared
            case-insensitively.

    Raises:
        ValueError: if neither an OpenAI provider_type nor a model_name is
            available to select a tokenizer.
    """
    if provider_type:
        if provider_type.lower() == "openai":
            # Used across ada and text-embedding-3 models
            return _check_tokenizer_cache("openai")
        # If we are given a cloud provider_type that isn't OpenAI, we default
        # to trying to use the model_name; this means we are approximating
        # the token count which may leave some performance on the table

    if not model_name:
        raise ValueError("Need to provide a model_name or provider_type")

    return _check_tokenizer_cache(model_name)
    # Removed unreachable dead code that followed this return
    # (global _DEFAULT_TOKENIZER / return _DEFAULT_TOKENIZER).
|
||||
|
||||
|
||||
def tokenizer_trim_content(
|
||||
|
Reference in New Issue
Block a user