Reenable OpenAI Tokenizer (#3062)
* k
* clean up test embeddings
* nit
* minor update to ensure consistency
* minor organizational update
* minor updates

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
This commit is contained in:
parent 2bbc5d5d07
commit 4fb65dcf73
@@ -35,23 +35,31 @@ class BaseTokenizer(ABC):
 class TiktokenTokenizer(BaseTokenizer):
     _instances: dict[str, "TiktokenTokenizer"] = {}
 
-    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
-        if encoding_name not in cls._instances:
-            cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
-        return cls._instances[encoding_name]
+    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
+        if model_name not in cls._instances:
+            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
+        return cls._instances[model_name]
 
-    def __init__(self, encoding_name: str = "cl100k_base"):
+    def __init__(self, model_name: str):
         if not hasattr(self, "encoder"):
             import tiktoken
 
-            self.encoder = tiktoken.get_encoding(encoding_name)
+            self.encoder = tiktoken.encoding_for_model(model_name)
 
     def encode(self, string: str) -> list[int]:
-        # this returns no special tokens
+        # this ignores special tokens that the model is trained on, see encode_ordinary for details
         return self.encoder.encode_ordinary(string)
 
     def tokenize(self, string: str) -> list[str]:
-        return [self.encoder.decode([token]) for token in self.encode(string)]
+        encoded = self.encode(string)
+        decoded = [self.encoder.decode([token]) for token in encoded]
+
+        if len(decoded) != len(encoded):
+            logger.warning(
+                f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
+            )
+
+        return decoded
 
     def decode(self, tokens: list[int]) -> str:
         return self.encoder.decode(tokens)
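The tokenizer now keys its singleton cache by model name and asks tiktoken for the model's own encoding instead of hard-coding cl100k_base. A minimal sketch of the two tiktoken calls involved (the model name is only an example; `encoding_for_model` and `encode_ordinary` are real tiktoken APIs):

```python
import tiktoken

# encoding_for_model resolves a model name to its encoding
# (text-embedding-3-small currently maps to cl100k_base).
enc = tiktoken.encoding_for_model("text-embedding-3-small")

# encode_ordinary ignores special tokens entirely, so marker text like
# <|endoftext|> is tokenized as plain text instead of raising an error.
tokens = enc.encode_ordinary("hello <|endoftext|>")
assert enc.decode(tokens) == "hello <|endoftext|>"
```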
@@ -74,22 +82,35 @@ class HuggingFaceTokenizer(BaseTokenizer):
         return self.encoder.decode(tokens)
 
 
-_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
+_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}
 
 
-def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
+def _check_tokenizer_cache(
+    model_provider: EmbeddingProvider | None, model_name: str | None
+) -> BaseTokenizer:
     global _TOKENIZER_CACHE
 
-    if tokenizer_name not in _TOKENIZER_CACHE:
-        if tokenizer_name == "openai":
-            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
-            return _TOKENIZER_CACHE[tokenizer_name]
+    id_tuple = (model_provider, model_name)
+
+    if id_tuple not in _TOKENIZER_CACHE:
+        if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
+            if model_name is None:
+                raise ValueError(
+                    "model_name is required for OPENAI and AZURE embeddings"
+                )
+
+            _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
+            return _TOKENIZER_CACHE[id_tuple]
+
         try:
-            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+            if model_name is None:
+                model_name = DOCUMENT_ENCODER_MODEL
+
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
         except Exception as primary_error:
             logger.error(
-                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
             )
             logger.warning(
                 f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -98,7 +119,7 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
         try:
             # Cache this tokenizer name to the default so we don't have to try to load it again
             # and fail again
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
                 DOCUMENT_ENCODER_MODEL
             )
         except Exception as fallback_error:
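The hunk above preserves the existing negative-caching behavior under the new key: a model that fails to load is cached as the fallback, so the expensive failure is not retried on every lookup. A generic sketch of the pattern (all names here are illustrative, not from the diff):

```python
def get_or_fallback(cache: dict, key, load, load_default):
    """Load a value once per key; on failure, cache the fallback instead."""
    if key not in cache:
        try:
            cache[key] = load()
        except Exception:
            # Cache the fallback under the same key so the failing load
            # is attempted at most once.
            cache[key] = load_default()
    return cache[key]


cache: dict[str, str] = {}

def broken_load() -> str:
    raise RuntimeError("model download failed")

print(get_or_fallback(cache, "broken-model", broken_load, lambda: "default-tokenizer"))
print(get_or_fallback(cache, "broken-model", broken_load, lambda: "default-tokenizer"))
# Both calls print "default-tokenizer"; broken_load ran only on the first call.
```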
@@ -106,10 +127,10 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
                 f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
             )
             raise ValueError(
-                f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+                f"Failed to initialize tokenizer for {model_name} and fallback model"
             ) from fallback_error
 
-    return _TOKENIZER_CACHE[tokenizer_name]
+    return _TOKENIZER_CACHE[id_tuple]
 
 
 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
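With this change the module-level cache is keyed by a `(provider, model)` tuple rather than a bare string, so two models from the same provider get separate tokenizers. A small, self-contained illustration of the keying; the `EmbeddingProvider` enum here is a stand-in for the sketch, not danswer's actual definition:

```python
from enum import Enum


class EmbeddingProvider(str, Enum):  # stand-in for danswer's enum
    OPENAI = "openai"
    AZURE = "azure"


_cache: dict[tuple[EmbeddingProvider | None, str | None], str] = {}

# Distinct models get distinct entries instead of colliding on "openai".
_cache[(EmbeddingProvider.OPENAI, "text-embedding-3-small")] = "tokenizer-a"
_cache[(EmbeddingProvider.OPENAI, "text-embedding-3-large")] = "tokenizer-b"
assert len(_cache) == 2
```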
@@ -118,11 +139,16 @@ _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
 def get_tokenizer(
     model_name: str | None, provider_type: EmbeddingProvider | str | None
 ) -> BaseTokenizer:
-    # Currently all of the viable models use the same sentencepiece tokenizer
-    # OpenAI uses a different one but currently it's not supported due to quality issues
-    # the inconsistent chunking makes using the sentencepiece tokenizer default better for now
-    # LLM tokenizers are specified by strings
     global _DEFAULT_TOKENIZER
+    if provider_type is not None:
+        if isinstance(provider_type, str):
+            try:
+                provider_type = EmbeddingProvider(provider_type)
+            except ValueError:
+                logger.debug(
+                    f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
+                )
+                return _DEFAULT_TOKENIZER
+        return _check_tokenizer_cache(provider_type, model_name)
     return _DEFAULT_TOKENIZER
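`get_tokenizer` now coerces a string `provider_type` to the enum and falls back to the default tokenizer on an unrecognized value instead of raising. A hedged sketch of that coercion, reusing the stand-in `EmbeddingProvider` from the sketch above:

```python
def resolve_provider(
    provider_type: EmbeddingProvider | str | None,
) -> EmbeddingProvider | None:
    # Mirrors the str -> enum coercion in get_tokenizer; None signals
    # "fall back to the default tokenizer".
    if isinstance(provider_type, str):
        try:
            return EmbeddingProvider(provider_type)
        except ValueError:
            return None
    return provider_type


assert resolve_provider("openai") is EmbeddingProvider.OPENAI
assert resolve_provider("not-a-provider") is None
```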
@@ -11,6 +11,7 @@ import {
   LLM_PROVIDERS_ADMIN_URL,
 } from "../../configuration/llm/constants";
 import { mutate } from "swr";
+import { testEmbedding } from "../pages/utils";
 
 export function ChangeCredentialsModal({
   provider,
@@ -112,16 +113,15 @@ export function ChangeCredentialsModal({
   const normalizedProviderType = provider.provider_type
     .toLowerCase()
     .split(" ")[0];
 
   try {
-    const testResponse = await fetch("/api/admin/embedding/test-embedding", {
-      method: "POST",
-      headers: { "Content-Type": "application/json" },
-      body: JSON.stringify({
-        provider_type: normalizedProviderType,
-        api_key: apiKey,
-        api_url: apiUrl,
-        model_name: modelName,
-      }),
+    const testResponse = await testEmbedding({
+      provider_type: normalizedProviderType,
+      modelName,
+      apiKey,
+      apiUrl,
+      apiVersion: null,
+      deploymentName: null,
     });
 
     if (!testResponse.ok) {
@@ -110,20 +110,27 @@ export function ProviderCreationModal({
     setErrorMsg("");
     try {
       const customConfig = Object.fromEntries(values.custom_config);
+      const providerType = values.provider_type.toLowerCase().split(" ")[0];
+      const isOpenAI = providerType === "openai";
+
+      const testModelName =
+        isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
+
+      const testEmbeddingPayload = {
+        provider_type: providerType,
+        api_key: values.api_key,
+        api_url: values.api_url,
+        model_name: testModelName,
+        api_version: values.api_version,
+        deployment_name: values.deployment_name,
+      };
 
       const initialResponse = await fetch(
         "/api/admin/embedding/test-embedding",
         {
           method: "POST",
           headers: { "Content-Type": "application/json" },
-          body: JSON.stringify({
-            provider_type: values.provider_type.toLowerCase().split(" ")[0],
-            api_key: values.api_key,
-            api_url: values.api_url,
-            model_name: values.model_name,
-            api_version: values.api_version,
-            deployment_name: values.deployment_name,
-          }),
+          body: JSON.stringify(testEmbeddingPayload),
         }
       );
@@ -8,3 +8,37 @@ export const deleteSearchSettings = async (search_settings_id: number) => {
   });
   return response;
 };
+
+export const testEmbedding = async ({
+  provider_type,
+  modelName,
+  apiKey,
+  apiUrl,
+  apiVersion,
+  deploymentName,
+}: {
+  provider_type: string;
+  modelName: string;
+  apiKey: string | null;
+  apiUrl: string | null;
+  apiVersion: string | null;
+  deploymentName: string | null;
+}) => {
+  const testModelName =
+    provider_type === "openai" ? "text-embedding-3-small" : modelName;
+
+  const testResponse = await fetch("/api/admin/embedding/test-embedding", {
+    method: "POST",
+    headers: { "Content-Type": "application/json" },
+    body: JSON.stringify({
+      provider_type: provider_type,
+      api_key: apiKey,
+      api_url: apiUrl,
+      model_name: testModelName,
+      api_version: apiVersion,
+      deployment_name: deploymentName,
+    }),
+  });
+
+  return testResponse;
+};
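For reference, the request the new shared `testEmbedding` helper issues could be reproduced in a script roughly as below. The endpoint path, payload fields, and the OpenAI model override come from the diff; `base_url`, the function name, and the absence of auth handling are assumptions of this sketch:

```python
import requests


def test_embedding_request(
    base_url: str,
    provider_type: str,
    model_name: str,
    api_key: str | None = None,
    api_url: str | None = None,
    api_version: str | None = None,
    deployment_name: str | None = None,
) -> requests.Response:
    # Mirror the frontend override: OpenAI is always probed with a known model.
    if provider_type == "openai":
        model_name = "text-embedding-3-small"
    return requests.post(
        f"{base_url}/api/admin/embedding/test-embedding",
        json={
            "provider_type": provider_type,
            "api_key": api_key,
            "api_url": api_url,
            "model_name": model_name,
            "api_version": api_version,
            "deployment_name": deployment_name,
        },
    )
```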