Fixed tokenizer logic (#1986)

hagen-danswer authored 2024-07-31 09:59:45 -07:00, committed by GitHub
parent d619602a6f
commit 5307d38472
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 77 additions and 38 deletions

View File

@@ -32,10 +32,18 @@ class BaseTokenizer(ABC):
 class TiktokenTokenizer(BaseTokenizer):
-    def __init__(self, encoding_name: str = "cl100k_base"):
-        import tiktoken
+    _instances: dict[str, "TiktokenTokenizer"] = {}
 
-        self.encoder = tiktoken.get_encoding(encoding_name)
+    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
+        if encoding_name not in cls._instances:
+            cls._instances[encoding_name] = super(TiktokenTokenizer, cls).__new__(cls)
+        return cls._instances[encoding_name]
+
+    def __init__(self, encoding_name: str = "cl100k_base"):
+        if not hasattr(self, "encoder"):
+            import tiktoken
+
+            self.encoder = tiktoken.get_encoding(encoding_name)
 
     def encode(self, string: str) -> list[int]:
         # this returns no special tokens
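
The hunk above makes TiktokenTokenizer a per-encoding-name singleton, so a tiktoken encoding is loaded once and reused. A minimal usage sketch of that behavior (illustrative only; it assumes the tiktoken package is installed and the class is imported from this module):

    a = TiktokenTokenizer("cl100k_base")
    b = TiktokenTokenizer("cl100k_base")
    assert a is b  # same cached instance for the same encoding name
    assert a.encode("hello world") == b.encode("hello world")

    c = TiktokenTokenizer("p50k_base")  # a different encoding name gets its own cached instance
    assert c is not a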
@@ -68,40 +76,52 @@ class HuggingFaceTokenizer(BaseTokenizer):
 _TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
 
 
-def _get_cached_tokenizer(
-    model_name: str | None = None, provider_type: str | None = None
-) -> BaseTokenizer:
+def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
     global _TOKENIZER_CACHE
-    if provider_type:
-        if not _TOKENIZER_CACHE.get(provider_type):
-            if provider_type.lower() == "openai":
-                _TOKENIZER_CACHE[provider_type] = TiktokenTokenizer()
-            elif provider_type.lower() == "cohere":
-                _TOKENIZER_CACHE[provider_type] = HuggingFaceTokenizer(
-                    "Cohere/command-nightly"
-                )
-            else:
-                _TOKENIZER_CACHE[
-                    provider_type
-                ] = TiktokenTokenizer()  # Default to OpenAI tokenizer
-        return _TOKENIZER_CACHE[provider_type]
-
-    if model_name:
-        if not _TOKENIZER_CACHE.get(model_name):
-            _TOKENIZER_CACHE[model_name] = HuggingFaceTokenizer(model_name)
-        return _TOKENIZER_CACHE[model_name]
-
-    raise ValueError("Need to provide a model_name or provider_type")
+
+    if tokenizer_name not in _TOKENIZER_CACHE:
+        if tokenizer_name == "openai":
+            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
+            return _TOKENIZER_CACHE[tokenizer_name]
+        try:
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
+            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+        except Exception as primary_error:
+            logger.error(
+                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+            )
+            logger.warning(
+                f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
+            )
+            try:
+                # Cache this tokenizer name to the default so we don't have to try to load it again
+                # and fail again
+                _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+                    DOCUMENT_ENCODER_MODEL
+                )
+            except Exception as fallback_error:
+                logger.error(
+                    f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
+                )
+                raise ValueError(
+                    f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+                ) from fallback_error
+
+    return _TOKENIZER_CACHE[tokenizer_name]
 
 
 def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
-    if provider_type is None and model_name is None:
-        model_name = DOCUMENT_ENCODER_MODEL
-    return _get_cached_tokenizer(
-        model_name=model_name,
-        provider_type=provider_type,
-    )
+    if provider_type:
+        if provider_type.lower() == "openai":
+            # Used across ada and text-embedding-3 models
+            return _check_tokenizer_cache("openai")
+        # If we are given a cloud provider_type that isn't OpenAI, we default to trying to use the model_name
+
+    if not model_name:
+        raise ValueError("Need to provide a model_name or provider_type")
+
+    return _check_tokenizer_cache(model_name)
 
 
 def tokenizer_trim_content(
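
Sketch of how the reworked get_tokenizer is meant to be called (illustrative only; the model names below are examples, and DOCUMENT_ENCODER_MODEL is whatever default the deployment configures):

    # Any OpenAI provider_type shares the single cached "openai" tiktoken tokenizer.
    openai_tokenizer = get_tokenizer(model_name="text-embedding-3-small", provider_type="OpenAI")

    # Everything else is keyed by model_name; if that HuggingFace tokenizer fails to load,
    # the default embedding model's tokenizer is cached under the same key instead.
    hf_tokenizer = get_tokenizer(model_name="intfloat/e5-base-v2", provider_type=None)

    print(len(hf_tokenizer.encode("How many tokens is this?")))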

View File

@@ -184,6 +184,7 @@ def create_credential(env_name: str) -> int:
     body = {
         "credential_json": {},
         "admin_public": True,
+        "source": DocumentSource.FILE,
     }
     response = requests.post(url, headers=GENERAL_HEADERS, json=body)
     if response.status_code == 200:

View File

@@ -1,3 +1,4 @@
+import csv
 import os
 import tempfile
 import time
@@ -20,7 +21,12 @@ def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
     with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
         zip_ref.extractall(persistent_dir)
-    return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
+
+    file_paths = []
+    for root, _, files in os.walk(persistent_dir):
+        for file in sorted(files):
+            file_paths.append(os.path.join(root, file))
+    return file_paths
 
 
 def create_temp_zip_from_files(file_paths: list[str]) -> str:
@@ -48,6 +54,7 @@ def upload_test_files(zip_file_path: str, env_name: str) -> None:
 def manage_file_upload(zip_file_path: str, env_name: str) -> None:
     unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
     total_file_count = len(unzipped_file_paths)
+    problem_file_list: list[str] = []
 
     while True:
         doc_count, ongoing_index_attempts = check_indexing_status(env_name)
@@ -58,11 +65,16 @@ def manage_file_upload(zip_file_path: str, env_name: str) -> None:
             )
         elif not doc_count:
            print("No docs indexed, waiting for indexing to start")
-            upload_test_files(zip_file_path, env_name)
-        elif doc_count < total_file_count:
+            temp_zip_file_path = create_temp_zip_from_files(unzipped_file_paths)
+            upload_test_files(temp_zip_file_path, env_name)
+            os.unlink(temp_zip_file_path)
+        elif (doc_count + len(problem_file_list)) < total_file_count:
             print(f"No ongooing indexing attempts but only {doc_count} docs indexed")
-            remaining_files = unzipped_file_paths[doc_count:]
-            print(f"Grabbed last {len(remaining_files)} docs to try agian")
+            remaining_files = unzipped_file_paths[doc_count + len(problem_file_list) :]
+            problem_file_list.append(remaining_files.pop(0))
+            print(
+                f"Removing first doc and grabbed last {len(remaining_files)} docs to try agian"
+            )
             temp_zip_file_path = create_temp_zip_from_files(remaining_files)
             upload_test_files(temp_zip_file_path, env_name)
             os.unlink(temp_zip_file_path)
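
The retry branch above assumes files index in list order: everything already indexed (plus files flagged on earlier passes) is sliced off, the first leftover is treated as the blocker and recorded, and only the rest are re-zipped and re-uploaded. A standalone sketch of just that bookkeeping, with hypothetical paths and counts (no uploads performed):

    unzipped_file_paths = [f"/tmp/extracted/doc_{i}.txt" for i in range(6)]
    problem_file_list = ["/tmp/extracted/doc_2.txt"]  # flagged on an earlier pass
    doc_count = 3  # docs the server reports as indexed

    remaining_files = unzipped_file_paths[doc_count + len(problem_file_list) :]
    problem_file_list.append(remaining_files.pop(0))  # assume the first leftover caused the stall
    print(problem_file_list)  # now holds two flagged paths
    print(remaining_files)    # files that would be re-zipped and re-uploaded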
@@ -72,6 +84,13 @@ def manage_file_upload(zip_file_path: str, env_name: str) -> None:
 
         time.sleep(10)
 
+    problem_file_csv_path = os.path.join(current_dir, "problem_files.csv")
+    with open(problem_file_csv_path, "w", newline="") as csvfile:
+        csvwriter = csv.writer(csvfile)
+        csvwriter.writerow(["Problematic File Paths"])
+        for problem_file in problem_file_list:
+            csvwriter.writerow([problem_file])
+
     for file in unzipped_file_paths:
         os.unlink(file)

View File

@@ -111,15 +111,14 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
 def _process_question(question_data: dict, config: dict, question_number: int) -> dict:
+    print(f"On question number {question_number}")
     query = question_data["question"]
+    print(f"query: {query}")
     context_data_list, answer = get_answer_from_query(
         query=query,
         only_retrieve_docs=config["only_retrieve_docs"],
         env_name=config["env_name"],
     )
-    print(f"On question number {question_number}")
-    print(f"query: {query}")
-
     if not context_data_list:
         print("No answer or context found")