mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-08 03:48:14 +02:00)
Fixed tokenizer logic (#1986)
commit 5307d38472 (parent d619602a6f)
@@ -32,10 +32,18 @@ class BaseTokenizer(ABC):


 class TiktokenTokenizer(BaseTokenizer):
-    def __init__(self, encoding_name: str = "cl100k_base"):
-        import tiktoken
+    _instances: dict[str, "TiktokenTokenizer"] = {}

-        self.encoder = tiktoken.get_encoding(encoding_name)
+    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
+        if encoding_name not in cls._instances:
+            cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
+        return cls._instances[encoding_name]
+
+    def __init__(self, encoding_name: str = "cl100k_base"):
+        if not hasattr(self, "encoder"):
+            import tiktoken
+
+            self.encoder = tiktoken.get_encoding(encoding_name)

     def encode(self, string: str) -> list[int]:
         # this returns no special tokens
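The hunk above turns TiktokenTokenizer into a per-encoding singleton: __new__ hands back one shared instance per encoding name, and __init__ only builds the encoder on first construction. A minimal usage sketch of that behavior (illustrative only; assumes the tiktoken package is installed):

    # Repeated construction with the same encoding name returns the same
    # cached object, so the encoding is only loaded once per process.
    a = TiktokenTokenizer("cl100k_base")
    b = TiktokenTokenizer("cl100k_base")
    assert a is b

    # A different encoding name gets its own cached instance.
    c = TiktokenTokenizer("p50k_base")
    assert c is not a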
@@ -68,40 +76,52 @@ class HuggingFaceTokenizer(BaseTokenizer):
 _TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}


-def _get_cached_tokenizer(
-    model_name: str | None = None, provider_type: str | None = None
-) -> BaseTokenizer:
+def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
     global _TOKENIZER_CACHE

-    if provider_type:
-        if not _TOKENIZER_CACHE.get(provider_type):
-            if provider_type.lower() == "openai":
-                _TOKENIZER_CACHE[provider_type] = TiktokenTokenizer()
-            elif provider_type.lower() == "cohere":
-                _TOKENIZER_CACHE[provider_type] = HuggingFaceTokenizer(
-                    "Cohere/command-nightly"
-                )
-            else:
-                _TOKENIZER_CACHE[
-                    provider_type
-                ] = TiktokenTokenizer()  # Default to OpenAI tokenizer
-        return _TOKENIZER_CACHE[provider_type]
-
-    if model_name:
-        if not _TOKENIZER_CACHE.get(model_name):
-            _TOKENIZER_CACHE[model_name] = HuggingFaceTokenizer(model_name)
-        return _TOKENIZER_CACHE[model_name]
-
-    raise ValueError("Need to provide a model_name or provider_type")
+    if tokenizer_name not in _TOKENIZER_CACHE:
+        if tokenizer_name == "openai":
+            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
+            return _TOKENIZER_CACHE[tokenizer_name]
+        try:
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
+            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+        except Exception as primary_error:
+            logger.error(
+                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+            )
+            logger.warning(
+                f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
+            )
+
+            try:
+                # Cache this tokenizer name to the default so we don't have to try to load it again
+                # and fail again
+                _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+                    DOCUMENT_ENCODER_MODEL
+                )
+            except Exception as fallback_error:
+                logger.error(
+                    f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
+                )
+                raise ValueError(
+                    f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+                ) from fallback_error
+
+    return _TOKENIZER_CACHE[tokenizer_name]


 def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
-    if provider_type is None and model_name is None:
-        model_name = DOCUMENT_ENCODER_MODEL
-    return _get_cached_tokenizer(
-        model_name=model_name,
-        provider_type=provider_type,
-    )
+    if provider_type:
+        if provider_type.lower() == "openai":
+            # Used across ada and text-embedding-3 models
+            return _check_tokenizer_cache("openai")
+
+    # If we are given a cloud provider_type that isn't OpenAI, we default to trying to use the model_name
+    if not model_name:
+        raise ValueError("Need to provide a model_name or provider_type")
+
+    return _check_tokenizer_cache(model_name)


 def tokenizer_trim_content(
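After this rework, get_tokenizer sends the OpenAI provider to the shared tiktoken tokenizer and resolves everything else by model name through _check_tokenizer_cache, which logs a failure and falls back to (and caches) the tokenizer for DOCUMENT_ENCODER_MODEL if the requested model cannot be loaded. A rough sketch of the call flow from a caller's point of view (the Hugging Face model name below is just an example):

    # OpenAI provider: cached TiktokenTokenizer, regardless of model name.
    openai_tokenizer = get_tokenizer(model_name=None, provider_type="openai")

    # Any other provider, or no provider at all: HuggingFaceTokenizer keyed by
    # model name, with the logged fallback to DOCUMENT_ENCODER_MODEL on failure.
    hf_tokenizer = get_tokenizer(model_name="intfloat/e5-base-v2", provider_type=None)

    # Providing neither a model name nor an OpenAI provider raises immediately.
    # get_tokenizer(model_name=None, provider_type=None)  # -> ValueError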
@@ -184,6 +184,7 @@ def create_credential(env_name: str) -> int:
     body = {
         "credential_json": {},
         "admin_public": True,
+        "source": DocumentSource.FILE,
     }
     response = requests.post(url, headers=GENERAL_HEADERS, json=body)
     if response.status_code == 200:
@@ -1,3 +1,4 @@
+import csv
 import os
 import tempfile
 import time
@@ -20,7 +21,12 @@ def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
     with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
         zip_ref.extractall(persistent_dir)

-    return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
+    file_paths = []
+    for root, _, files in os.walk(persistent_dir):
+        for file in sorted(files):
+            file_paths.append(os.path.join(root, file))
+
+    return file_paths


 def create_temp_zip_from_files(file_paths: list[str]) -> str:
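The os.walk replacement sorts file names within each directory, so repeated runs of the harness see the extracted files in a consistent order (Path.rglob makes no ordering guarantee), which presumably matters because the upload loop later slices this list by doc_count. A small standalone comparison, assuming a hypothetical extracted_docs directory:

    import os
    from pathlib import Path

    # No guaranteed order; depends on the underlying directory listing.
    unordered = [str(p) for p in Path("extracted_docs").rglob("*") if p.is_file()]

    # File names are sorted within each directory, giving a stable order
    # across runs.
    ordered: list[str] = []
    for root, _, files in os.walk("extracted_docs"):
        for name in sorted(files):
            ordered.append(os.path.join(root, name))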
@@ -48,6 +54,7 @@ def upload_test_files(zip_file_path: str, env_name: str) -> None:
 def manage_file_upload(zip_file_path: str, env_name: str) -> None:
     unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
     total_file_count = len(unzipped_file_paths)
+    problem_file_list: list[str] = []

     while True:
         doc_count, ongoing_index_attempts = check_indexing_status(env_name)
@@ -58,11 +65,16 @@ def manage_file_upload(zip_file_path: str, env_name: str) -> None:
             )
         elif not doc_count:
            print("No docs indexed, waiting for indexing to start")
-            upload_test_files(zip_file_path, env_name)
-        elif doc_count < total_file_count:
+            temp_zip_file_path = create_temp_zip_from_files(unzipped_file_paths)
+            upload_test_files(temp_zip_file_path, env_name)
+            os.unlink(temp_zip_file_path)
+        elif (doc_count + len(problem_file_list)) < total_file_count:
             print(f"No ongoing indexing attempts but only {doc_count} docs indexed")
-            remaining_files = unzipped_file_paths[doc_count:]
-            print(f"Grabbed last {len(remaining_files)} docs to try again")
+            remaining_files = unzipped_file_paths[doc_count + len(problem_file_list) :]
+            problem_file_list.append(remaining_files.pop(0))
+            print(
+                f"Removing first doc and grabbed last {len(remaining_files)} docs to try again"
+            )
             temp_zip_file_path = create_temp_zip_from_files(remaining_files)
             upload_test_files(temp_zip_file_path, env_name)
             os.unlink(temp_zip_file_path)
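The reworked branch keeps a problem_file_list of files it suspects will never index: when indexing stalls short of the total, it slices past the doc_count already indexed plus the previously flagged files, flags the next file, and re-uploads only the remainder. A condensed sketch of that bookkeeping, with made-up file names:

    unzipped_file_paths = ["a.txt", "b.txt", "c.txt", "d.txt"]
    problem_file_list: list[str] = []
    doc_count = 1  # e.g. only "a.txt" has been indexed so far

    # Skip the indexed docs and previously flagged files, flag the next one,
    # and retry the rest.
    remaining_files = unzipped_file_paths[doc_count + len(problem_file_list) :]
    problem_file_list.append(remaining_files.pop(0))

    # remaining_files == ["c.txt", "d.txt"]
    # problem_file_list == ["b.txt"]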
@@ -72,6 +84,13 @@ def manage_file_upload(zip_file_path: str, env_name: str) -> None:

         time.sleep(10)

+    problem_file_csv_path = os.path.join(current_dir, "problem_files.csv")
+    with open(problem_file_csv_path, "w", newline="") as csvfile:
+        csvwriter = csv.writer(csvfile)
+        csvwriter.writerow(["Problematic File Paths"])
+        for problem_file in problem_file_list:
+            csvwriter.writerow([problem_file])
+
     for file in unzipped_file_paths:
         os.unlink(file)

@@ -111,15 +111,14 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:


 def _process_question(question_data: dict, config: dict, question_number: int) -> dict:
+    print(f"On question number {question_number}")
+
     query = question_data["question"]
+    print(f"query: {query}")
     context_data_list, answer = get_answer_from_query(
         query=query,
         only_retrieve_docs=config["only_retrieve_docs"],
         env_name=config["env_name"],
     )
-    print(f"On question number {question_number}")
-    print(f"query: {query}")

     if not context_data_list:
         print("No answer or context found")