Reenable OpenAI Tokenizer (#3062)
* k
* clean up test embeddings
* nit
* minor update to ensure consistency
* minor organizational update
* minor updates

---------

Co-authored-by: pablodanswer <pablo@danswer.ai>
parent 2bbc5d5d07
commit 4fb65dcf73
@@ -35,23 +35,31 @@ class BaseTokenizer(ABC):
 class TiktokenTokenizer(BaseTokenizer):
     _instances: dict[str, "TiktokenTokenizer"] = {}

-    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
-        if encoding_name not in cls._instances:
-            cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
-        return cls._instances[encoding_name]
+    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
+        if model_name not in cls._instances:
+            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
+        return cls._instances[model_name]

-    def __init__(self, encoding_name: str = "cl100k_base"):
+    def __init__(self, model_name: str):
         if not hasattr(self, "encoder"):
             import tiktoken

-            self.encoder = tiktoken.get_encoding(encoding_name)
+            self.encoder = tiktoken.encoding_for_model(model_name)

     def encode(self, string: str) -> list[int]:
-        # this returns no special tokens
+        # this ignores special tokens that the model is trained on, see encode_ordinary for details
         return self.encoder.encode_ordinary(string)

     def tokenize(self, string: str) -> list[str]:
-        return [self.encoder.decode([token]) for token in self.encode(string)]
+        encoded = self.encode(string)
+        decoded = [self.encoder.decode([token]) for token in encoded]
+
+        if len(decoded) != len(encoded):
+            logger.warning(
+                f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
+            )
+
+        return decoded

     def decode(self, tokens: list[int]) -> str:
         return self.encoder.decode(tokens)
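The hunk above keeps one tokenizer instance per model name (via the `__new__` override) and switches from a fixed `cl100k_base` encoding to tiktoken's model-specific lookup. A minimal sketch of both behaviors, runnable if tiktoken is installed; the model name here is an example, not something the diff pins down:

import tiktoken

# encoding_for_model resolves a model name to its encoding
# (e.g. "text-embedding-3-small" currently maps to cl100k_base)
enc = tiktoken.encoding_for_model("text-embedding-3-small")

# encode_ordinary skips special tokens such as <|endoftext|>
ids = enc.encode_ordinary("hello world")
assert enc.decode(ids) == "hello world"

# the __new__ override makes construction idempotent per model name:
# TiktokenTokenizer("text-embedding-3-small") is TiktokenTokenizer("text-embedding-3-small")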
@@ -74,22 +82,35 @@ class HuggingFaceTokenizer(BaseTokenizer):
         return self.encoder.decode(tokens)


-_TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
+_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}


-def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
+def _check_tokenizer_cache(
+    model_provider: EmbeddingProvider | None, model_name: str | None
+) -> BaseTokenizer:
     global _TOKENIZER_CACHE

-    if tokenizer_name not in _TOKENIZER_CACHE:
-        if tokenizer_name == "openai":
-            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
-            return _TOKENIZER_CACHE[tokenizer_name]
+    id_tuple = (model_provider, model_name)
+
+    if id_tuple not in _TOKENIZER_CACHE:
+        if model_provider in [EmbeddingProvider.OPENAI, EmbeddingProvider.AZURE]:
+            if model_name is None:
+                raise ValueError(
+                    "model_name is required for OPENAI and AZURE embeddings"
+                )
+
+            _TOKENIZER_CACHE[id_tuple] = TiktokenTokenizer(model_name)
+            return _TOKENIZER_CACHE[id_tuple]
+
         try:
-            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
-            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+            if model_name is None:
+                model_name = DOCUMENT_ENCODER_MODEL
+
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {model_name}")
+            _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(model_name)
         except Exception as primary_error:
             logger.error(
-                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+                f"Error initializing HuggingFaceTokenizer for {model_name}: {primary_error}"
             )
             logger.warning(
                 f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
@@ -98,7 +119,7 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
             try:
                 # Cache this tokenizer name to the default so we don't have to try to load it again
                 # and fail again
-                _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+                _TOKENIZER_CACHE[id_tuple] = HuggingFaceTokenizer(
                     DOCUMENT_ENCODER_MODEL
                 )
             except Exception as fallback_error:
@@ -106,10 +127,10 @@ def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
                     f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
                 )
                 raise ValueError(
-                    f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+                    f"Failed to initialize tokenizer for {model_name} and fallback model"
                 ) from fallback_error

-    return _TOKENIZER_CACHE[tokenizer_name]
+    return _TOKENIZER_CACHE[id_tuple]


 _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
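`_check_tokenizer_cache` now keys its cache on a `(provider, model_name)` tuple instead of a bare string, so an OpenAI model and a HuggingFace model with the same name can no longer collide. A stripped-down sketch of the pattern, with hypothetical stand-in names rather than the repo's actual module:

from enum import Enum


class Provider(Enum):  # hypothetical stand-in for EmbeddingProvider
    OPENAI = "openai"
    AZURE = "azure"


_CACHE: dict[tuple[Provider | None, str | None], object] = {}


def get_cached_tokenizer(provider: Provider | None, model: str | None) -> object:
    key = (provider, model)  # None is hashable, so partially-specified keys work
    if key not in _CACHE:
        _CACHE[key] = object()  # stands in for the expensive tokenizer init
    return _CACHE[key]


assert get_cached_tokenizer(Provider.OPENAI, "m") is get_cached_tokenizer(Provider.OPENAI, "m")
assert get_cached_tokenizer(None, "m") is not get_cached_tokenizer(Provider.OPENAI, "m")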
@@ -118,11 +139,16 @@ _DEFAULT_TOKENIZER: BaseTokenizer = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
 def get_tokenizer(
     model_name: str | None, provider_type: EmbeddingProvider | str | None
 ) -> BaseTokenizer:
-    # Currently all of the viable models use the same sentencepiece tokenizer
-    # OpenAI uses a different one but currently it's not supported due to quality issues
-    # the inconsistent chunking makes using the sentencepiece tokenizer default better for now
-    # LLM tokenizers are specified by strings
-    global _DEFAULT_TOKENIZER
+    if provider_type is not None:
+        if isinstance(provider_type, str):
+            try:
+                provider_type = EmbeddingProvider(provider_type)
+            except ValueError:
+                logger.debug(
+                    f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
+                )
+                return _DEFAULT_TOKENIZER
+        return _check_tokenizer_cache(provider_type, model_name)
     return _DEFAULT_TOKENIZER
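`get_tokenizer` now accepts the provider either as an enum member or as a plain string and coerces the string through the enum constructor, returning the default tokenizer when the string names no known provider. The coercion idiom in isolation; the member set here is assumed for illustration, the real one lives in the repo:

from enum import Enum


class EmbeddingProvider(Enum):  # simplified stand-in
    OPENAI = "openai"
    AZURE = "azure"


def coerce_provider(value: EmbeddingProvider | str | None) -> EmbeddingProvider | None:
    if isinstance(value, str):
        try:
            return EmbeddingProvider(value)  # "openai" -> EmbeddingProvider.OPENAI
        except ValueError:
            return None  # unknown string: caller falls back to the default tokenizer
    return value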
@@ -11,6 +11,7 @@ import {
   LLM_PROVIDERS_ADMIN_URL,
 } from "../../configuration/llm/constants";
 import { mutate } from "swr";
+import { testEmbedding } from "../pages/utils";

 export function ChangeCredentialsModal({
   provider,
@@ -112,16 +113,15 @@ export function ChangeCredentialsModal({
     const normalizedProviderType = provider.provider_type
       .toLowerCase()
       .split(" ")[0];

     try {
-      const testResponse = await fetch("/api/admin/embedding/test-embedding", {
-        method: "POST",
-        headers: { "Content-Type": "application/json" },
-        body: JSON.stringify({
-          provider_type: normalizedProviderType,
-          api_key: apiKey,
-          api_url: apiUrl,
-          model_name: modelName,
-        }),
+      const testResponse = await testEmbedding({
+        provider_type: normalizedProviderType,
+        modelName,
+        apiKey,
+        apiUrl,
+        apiVersion: null,
+        deploymentName: null,
       });

       if (!testResponse.ok) {
@@ -110,20 +110,27 @@ export function ProviderCreationModal({
     setErrorMsg("");
     try {
       const customConfig = Object.fromEntries(values.custom_config);
+      const providerType = values.provider_type.toLowerCase().split(" ")[0];
+      const isOpenAI = providerType === "openai";
+
+      const testModelName =
+        isOpenAI || isAzure ? "text-embedding-3-small" : values.model_name;
+
+      const testEmbeddingPayload = {
+        provider_type: providerType,
+        api_key: values.api_key,
+        api_url: values.api_url,
+        model_name: testModelName,
+        api_version: values.api_version,
+        deployment_name: values.deployment_name,
+      };

       const initialResponse = await fetch(
         "/api/admin/embedding/test-embedding",
         {
           method: "POST",
           headers: { "Content-Type": "application/json" },
-          body: JSON.stringify({
-            provider_type: values.provider_type.toLowerCase().split(" ")[0],
-            api_key: values.api_key,
-            api_url: values.api_url,
-            model_name: values.model_name,
-            api_version: values.api_version,
-            deployment_name: values.deployment_name,
-          }),
+          body: JSON.stringify(testEmbeddingPayload),
         }
       );

@@ -8,3 +8,37 @@ export const deleteSearchSettings = async (search_settings_id: number) => {
   });
   return response;
 };
+
+export const testEmbedding = async ({
+  provider_type,
+  modelName,
+  apiKey,
+  apiUrl,
+  apiVersion,
+  deploymentName,
+}: {
+  provider_type: string;
+  modelName: string;
+  apiKey: string | null;
+  apiUrl: string | null;
+  apiVersion: string | null;
+  deploymentName: string | null;
+}) => {
+  const testModelName =
+    provider_type === "openai" ? "text-embedding-3-small" : modelName;
+
+  const testResponse = await fetch("/api/admin/embedding/test-embedding", {
+    method: "POST",
+    headers: { "Content-Type": "application/json" },
+    body: JSON.stringify({
+      provider_type: provider_type,
+      api_key: apiKey,
+      api_url: apiUrl,
+      model_name: testModelName,
+      api_version: apiVersion,
+      deployment_name: deploymentName,
+    }),
+  });
+
+  return testResponse;
+};
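The helper above centralizes the POST to /api/admin/embedding/test-embedding that both modals previously issued inline, substituting text-embedding-3-small as the test model for OpenAI. For reference, a hedged sketch of exercising the same endpoint from Python: the base URL and credential are deployment-specific assumptions (an authenticated admin session would normally be required); only the payload fields come from the helper above.

import requests

resp = requests.post(
    "http://localhost:3000/api/admin/embedding/test-embedding",  # assumed local dev URL
    json={
        "provider_type": "openai",
        "model_name": "text-embedding-3-small",
        "api_key": "sk-...",  # placeholder credential
        "api_url": None,
        "api_version": None,
        "deployment_name": None,
    },
)
print(resp.status_code, resp.text)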