Reenable OpenAI Tokenizer (#3062)

* k

* clean up test embeddings

* nit

* minor update to ensure consistency

* minor organizational update

* minor updates

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
Yuhong Sun committed 2024-11-08 14:54:15 -08:00 (via GitHub)
parent 2bbc5d5d07
commit 4fb65dcf73
4 changed files with 109 additions and 42 deletions


@@ -35,23 +35,31 @@ class BaseTokenizer(ABC):
 class TiktokenTokenizer(BaseTokenizer):
     _instances: dict[str, "TiktokenTokenizer"] = {}
 
-    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
-        if encoding_name not in cls._instances:
-            cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
-        return cls._instances[encoding_name]
+    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
+        if model_name not in cls._instances:
+            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
+        return cls._instances[model_name]
 
-    def __init__(self, encoding_name: str = "cl100k_base"):
+    def __init__(self, model_name: str):
         if not hasattr(self, "encoder"):
             import tiktoken
 
-            self.encoder = tiktoken.get_encoding(encoding_name)
+            self.encoder = tiktoken.encoding_for_model(model_name)
 
     def encode(self, string: str) -> list[int]:
-        # this returns no special tokens
+        # this ignores special tokens that the model is trained on, see encode_ordinary for details
         return self.encoder.encode_ordinary(string)
 
     def tokenize(self, string: str) -> list[str]:
-        return [self.encoder.decode([token]) for token in self.encode(string)]
+        encoded = self.encode(string)
+        decoded = [self.encoder.decode([token]) for token in encoded]
+
+        if len(decoded) != len(encoded):
+            logger.warning(
+                f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
+            )
+
+        return decoded
 
     def decode(self, tokens: list[int]) -> str:
         return self.encoder.decode(tokens)
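
For reference, the two tiktoken behaviors this hunk relies on: encoding_for_model() resolves the encoding from the model name instead of hardcoding cl100k_base, and encode_ordinary() treats special-token text as plain text. A minimal sketch, assuming only that the tiktoken package is installed (model and encoding names are tiktoken's own):

import tiktoken

# Old behavior: a fixed encoding, regardless of which OpenAI model is in use.
fixed = tiktoken.get_encoding("cl100k_base")

# New behavior: resolve the encoding from the model name.
enc = tiktoken.encoding_for_model("text-embedding-3-small")

# encode_ordinary() encodes special-token text (e.g. "<|endoftext|>") as
# plain text; plain encode() would raise unless allowed_special is set.
tokens = enc.encode_ordinary("hello <|endoftext|> world")

# Per-token decode, mirroring the new tokenize() implementation.
print([enc.decode([t]) for t in tokens])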
@@ -74,22 +82,35 @@ class HuggingFaceTokenizer(BaseTokenizer):
         return self.encoder.decode(tokens)
 
 
-_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
+_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
 
 
-def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
+def _check_tokenizer_cache(
+    model_provider: EmbeddingProvider | None, model_name: str | None
+) -> BaseTokenizer:
     global _TOKENIZER_CACHE
 
-    if tokenizer_name not in _TOKENIZER_CACHE:
-        if tokenizer_name == "openai":
-            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
-            return _TOKENIZER_CACHE[tokenizer_name]
+    id_tuple = (model_provider, model_name)
+
+    if id_tuple not in _TOKENIZER_CACHE:
+        if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
+            if model_name is None:
+                raise ValueError(
+                    "model_name is required for OPENAI and AZURE embeddings"
+                )
+            _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
+            return _TOKENIZER_CACHE[id_tuple]
         try:
-            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+            if model_name is None:
+                model_name = DOCUMENT_ENCODER_MODEL
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
         except Exception as primary_error:
             logger.error(
-                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
             )
             logger.warning(
                 f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -98,7 +119,7 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
         try:
             # Cache this tokenizer name to the default so we don't have to try to load it again
             # and fail again
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
                 DOCUMENT_ENCODER_MODEL
             )
         except Exception as fallback_error:
@@ -106,10 +127,10 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
                 f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
             )
             raise ValueError(
-                f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+                f"Failed to initialize tokenizer for {model_name} and fallback model"
             ) from fallback_error
 
-    return _TOKENIZER_CACHE[tokenizer_name]
+    return _TOKENIZER_CACHE[id_tuple]
 
 
 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
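
The cache is now keyed by a (provider, model_name) tuple instead of a bare string, so the same model name under two providers can no longer collide. A minimal standalone illustration of the keying, with plain strings standing in for the EmbeddingProvider enum members:

_CACHE: dict[tuple[str | None, str | None], object] = {}

def lookup(provider: str | None, model: str | None) -> object:
    key = (provider, model)
    if key not in _CACHE:
        _CACHE[key] = object()  # stand-in for constructing a real tokenizer
    return _CACHE[key]

assert lookup("openai", "text-embedding-3-small") is lookup(
    "openai", "text-embedding-3-small"
)  # second call with the same pair is a cache hit
assert lookup("openai", None) is not lookup(None, None)  # distinct keys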
@@ -118,11 +139,16 @@ _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
 def get_tokenizer(
     model_name: str | None, provider_type: EmbeddingProvider | str | None
 ) -> BaseTokenizer:
-    # Currently all of the viable models use the same sentencepiece tokenizer
-    # OpenAI uses a different one but currently it's not supported due to quality issues
-    # the inconsistent chunking makes using the sentencepiece tokenizer default better for now
-    # LLM tokenizers are specified by strings
-    global _DEFAULT_TOKENIZER
+    if provider_type is not None:
+        if isinstance(provider_type, str):
+            try:
+                provider_type = EmbeddingProvider(provider_type)
+            except ValueError:
+                logger.debug(
+                    f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
+                )
+                return _DEFAULT_TOKENIZER
+        return _check_tokenizer_cache(provider_type, model_name)
     return _DEFAULT_TOKENIZER
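
get_tokenizer() now accepts the provider as either an enum member or a raw string, coercing strings via the enum constructor. A hypothetical stand-in for the repo's EmbeddingProvider enum shows the two paths the try/except distinguishes:

from enum import Enum

class EmbeddingProvider(str, Enum):
    OPENAI = "openai"
    AZURE = "azure"

print(EmbeddingProvider("openai"))  # value lookup succeeds -> OPENAI member

try:
    EmbeddingProvider("not-a-provider")
except ValueError:
    # this is the case get_tokenizer logs before returning _DEFAULT_TOKENIZER
    print("unknown provider string; default tokenizer is used")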


@@ -11,6 +11,7 @@ import {
   LLM_PROVIDERS_ADMIN_URL,
 } from "../../configuration/llm/constants";
 import { mutate } from "swr";
+import { testEmbedding } from "../pages/utils";
 
 export function ChangeCredentialsModal({
   provider,
@@ -112,16 +113,15 @@ export function ChangeCredentialsModal({
     const normalizedProviderType = provider.provider_type
       .toLowerCase()
       .split(" ")[0];
     try {
-      const testResponse = await fetch("/api/admin/embedding/test-embedding", {
-        method: "POST",
-        headers: { "Content-Type": "application/json" },
-        body: JSON.stringify({
-          provider_type: normalizedProviderType,
-          api_key: apiKey,
-          api_url: apiUrl,
-          model_name: modelName,
-        }),
+      const testResponse = await testEmbedding({
+        provider_type: normalizedProviderType,
+        modelName,
+        apiKey,
+        apiUrl,
+        apiVersion: null,
+        deploymentName: null,
       });
 
       if (!testResponse.ok) {


@@ -110,20 +110,27 @@ export function ProviderCreationModal({
     setErrorMsg("");
     try {
       const customConfig = Object.fromEntries(values.custom_config);
 
+      const providerType = values.provider_type.toLowerCase().split(" ")[0];
+      const isOpenAI = providerType === "openai";
+      const testModelName =
+        isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
+
+      const testEmbeddingPayload = {
+        provider_type: providerType,
+        api_key: values.api_key,
+        api_url: values.api_url,
+        model_name: testModelName,
+        api_version: values.api_version,
+        deployment_name: values.deployment_name,
+      };
+
       const initialResponse = await fetch(
         "/api/admin/embedding/test-embedding",
         {
           method: "POST",
           headers: { "Content-Type": "application/json" },
-          body: JSON.stringify({
-            provider_type: values.provider_type.toLowerCase().split(" ")[0],
-            api_key: values.api_key,
-            api_url: values.api_url,
-            model_name: values.model_name,
-            api_version: values.api_version,
-            deployment_name: values.deployment_name,
-          }),
+          body: JSON.stringify(testEmbeddingPayload),
         }
       );


@@ -8,3 +8,37 @@ export const deleteSearchSettings = async (search_settings_id: number) => {
   });
   return response;
 };
+
+export const testEmbedding = async ({
+  provider_type,
+  modelName,
+  apiKey,
+  apiUrl,
+  apiVersion,
+  deploymentName,
+}: {
+  provider_type: string;
+  modelName: string;
+  apiKey: string | null;
+  apiUrl: string | null;
+  apiVersion: string | null;
+  deploymentName: string | null;
+}) => {
+  const testModelName =
+    provider_type === "openai" ? "text-embedding-3-small" : modelName;
+
+  const testResponse = await fetch("/api/admin/embedding/test-embedding", {
+    method: "POST",
+    headers: { "Content-Type": "application/json" },
+    body: JSON.stringify({
+      provider_type: provider_type,
+      api_key: apiKey,
+      api_url: apiUrl,
+      model_name: testModelName,
+      api_version: apiVersion,
+      deployment_name: deploymentName,
+    }),
+  });
+
+  return testResponse;
+};
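
For manual testing, the endpoint this helper wraps can be exercised directly. A rough Python sketch, assuming a locally running instance at localhost:3000, the requests library, and an authenticated admin session (auth headers/cookies omitted); the path and payload fields are taken from the diff above:

import requests

payload = {
    "provider_type": "openai",
    "api_key": "sk-placeholder",  # placeholder credential
    "api_url": None,
    "model_name": "text-embedding-3-small",  # forced for OpenAI, as in testEmbedding
    "api_version": None,
    "deployment_name": None,
}

resp = requests.post(
    "http://localhost:3000/api/admin/embedding/test-embedding",
    json=payload,
    timeout=30,
)
print(resp.status_code, resp.text)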