From 5307d38472c220b043ecc484394616d033ca39e5 Mon Sep 17 00:00:00 2001
From: hagen-danswer
Date: Wed, 31 Jul 2024 09:59:45 -0700
Subject: [PATCH] Fixed tokenizer logic (#1986)

---
 .../natural_language_processing/utils.py      | 80 ++++++++++++-------
 .../regression/answer_quality/api_utils.py    |  1 +
 .../answer_quality/file_uploader.py           | 29 +++++--
 .../tests/regression/answer_quality/run_qa.py |  5 +-
 4 files changed, 77 insertions(+), 38 deletions(-)

diff --git a/backend/danswer/natural_language_processing/utils.py b/backend/danswer/natural_language_processing/utils.py
index 6e0acc3b7..beef56833 100644
--- a/backend/danswer/natural_language_processing/utils.py
+++ b/backend/danswer/natural_language_processing/utils.py
@@ -32,10 +32,18 @@ class BaseTokenizer(ABC):
 
 
 class TiktokenTokenizer(BaseTokenizer):
-    def __init__(self, encoding_name: str = "cl100k_base"):
-        import tiktoken
+    _instances: dict[str, "TiktokenTokenizer"] = {}
 
-        self.encoder = tiktoken.get_encoding(encoding_name)
+    def __new__(cls, encoding_name: str = "cl100k_base") -> "TiktokenTokenizer":
+        if encoding_name not in cls._instances:
+            cls._instances[encoding_name] = super().__new__(cls)
+        return cls._instances[encoding_name]
+
+    def __init__(self, encoding_name: str = "cl100k_base"):
+        if not hasattr(self, "encoder"):
+            import tiktoken
+
+            self.encoder = tiktoken.get_encoding(encoding_name)
 
     def encode(self, string: str) -> list[int]:
         # this returns no special tokens
@@ -68,40 +76,52 @@ class HuggingFaceTokenizer(BaseTokenizer):
 
 _TOKENIZER_CACHE: dict[str, BaseTokenizer] = {}
 
 
-def _get_cached_tokenizer(
-    model_name: str | None = None, provider_type: str | None = None
-) -> BaseTokenizer:
+def _check_tokenizer_cache(tokenizer_name: str) -> BaseTokenizer:
     global _TOKENIZER_CACHE
-    if provider_type:
-        if not _TOKENIZER_CACHE.get(provider_type):
-            if provider_type.lower() == "openai":
-                _TOKENIZER_CACHE[provider_type] = TiktokenTokenizer()
-            elif provider_type.lower() == "cohere":
-                _TOKENIZER_CACHE[provider_type] = HuggingFaceTokenizer(
-                    "Cohere/command-nightly"
+
+    if tokenizer_name not in _TOKENIZER_CACHE:
+        if tokenizer_name == "openai":
+            _TOKENIZER_CACHE[tokenizer_name] = TiktokenTokenizer("cl100k_base")
+            return _TOKENIZER_CACHE[tokenizer_name]
+        try:
+            logger.debug(f"Initializing HuggingFaceTokenizer for: {tokenizer_name}")
+            _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(tokenizer_name)
+        except Exception as primary_error:
+            logger.error(
+                f"Error initializing HuggingFaceTokenizer for {tokenizer_name}: {primary_error}"
+            )
+            logger.warning(
+                f"Falling back to default embedding model: {DOCUMENT_ENCODER_MODEL}"
+            )
+
+            try:
+                # Cache this tokenizer name to the default so we don't have to try to load it again
+                # and fail again
+                _TOKENIZER_CACHE[tokenizer_name] = HuggingFaceTokenizer(
+                    DOCUMENT_ENCODER_MODEL
                 )
-            else:
-                _TOKENIZER_CACHE[
-                    provider_type
-                ] = TiktokenTokenizer()  # Default to OpenAI tokenizer
-        return _TOKENIZER_CACHE[provider_type]
+            except Exception as fallback_error:
+                logger.error(
+                    f"Error initializing fallback HuggingFaceTokenizer: {fallback_error}"
+                )
+                raise ValueError(
+                    f"Failed to initialize tokenizer for {tokenizer_name} and fallback model"
+                ) from fallback_error
 
-    if model_name:
-        if not _TOKENIZER_CACHE.get(model_name):
-            _TOKENIZER_CACHE[model_name] = HuggingFaceTokenizer(model_name)
-        return _TOKENIZER_CACHE[model_name]
-
-    raise ValueError("Need to provide a model_name or provider_type")
+    return _TOKENIZER_CACHE[tokenizer_name]
 
 
 def get_tokenizer(model_name: str | None, provider_type: str | None) -> BaseTokenizer:
-    if provider_type is None and model_name is None:
-        model_name = DOCUMENT_ENCODER_MODEL
-    return _get_cached_tokenizer(
-        model_name=model_name,
-        provider_type=provider_type,
-    )
+    if provider_type:
+        if provider_type.lower() == "openai":
+            # Used across ada and text-embedding-3 models
+            return _check_tokenizer_cache("openai")
+
+        # If we are given a cloud provider_type that isn't OpenAI, we fall back to the model_name
+    if not model_name:
+        raise ValueError("Need to provide a model_name or provider_type")
+
+    return _check_tokenizer_cache(model_name)
 
 
 def tokenizer_trim_content(
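A minimal sketch (not part of the patch) of the caching behavior introduced above: TiktokenTokenizer is now a per-encoding singleton, get_tokenizer resolves any OpenAI provider_type to one shared "openai" cache entry, and HuggingFace tokenizers are cached per model_name. The import path matches the patched module; the model names below are illustrative.

from danswer.natural_language_processing.utils import get_tokenizer

# Any casing of the OpenAI provider_type hits the same "openai" cache entry.
first = get_tokenizer(model_name=None, provider_type="OpenAI")
second = get_tokenizer(model_name="text-embedding-3-small", provider_type="openai")
assert first is second

# HuggingFace tokenizers are cached per model name; an unloadable name now
# falls back to DOCUMENT_ENCODER_MODEL instead of raising on every call.
hf = get_tokenizer(model_name="intfloat/e5-base-v2", provider_type=None)
assert hf is get_tokenizer(model_name="intfloat/e5-base-v2", provider_type=None)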
diff --git a/backend/tests/regression/answer_quality/api_utils.py b/backend/tests/regression/answer_quality/api_utils.py
index 86b4748ae..19b61315a 100644
--- a/backend/tests/regression/answer_quality/api_utils.py
+++ b/backend/tests/regression/answer_quality/api_utils.py
@@ -184,6 +184,7 @@ def create_credential(env_name: str) -> int:
     body = {
         "credential_json": {},
         "admin_public": True,
+        "source": DocumentSource.FILE,
     }
     response = requests.post(url, headers=GENERAL_HEADERS, json=body)
     if response.status_code == 200:
diff --git a/backend/tests/regression/answer_quality/file_uploader.py b/backend/tests/regression/answer_quality/file_uploader.py
index 0b423c5cd..bab78629d 100644
--- a/backend/tests/regression/answer_quality/file_uploader.py
+++ b/backend/tests/regression/answer_quality/file_uploader.py
@@ -1,3 +1,4 @@
+import csv
 import os
 import tempfile
 import time
@@ -20,7 +21,12 @@ def unzip_and_get_file_paths(zip_file_path: str) -> list[str]:
     with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
         zip_ref.extractall(persistent_dir)
 
-    return [str(path) for path in Path(persistent_dir).rglob("*") if path.is_file()]
+    file_paths = []
+    for root, _, files in os.walk(persistent_dir):
+        for file in sorted(files):
+            file_paths.append(os.path.join(root, file))
+
+    return file_paths
 
 
 def create_temp_zip_from_files(file_paths: list[str]) -> str:
@@ -48,6 +54,7 @@ def upload_test_files(zip_file_path: str, env_name: str) -> None:
 def manage_file_upload(zip_file_path: str, env_name: str) -> None:
     unzipped_file_paths = unzip_and_get_file_paths(zip_file_path)
     total_file_count = len(unzipped_file_paths)
+    problem_file_list: list[str] = []
 
     while True:
         doc_count, ongoing_index_attempts = check_indexing_status(env_name)
@@ -58,11 +65,16 @@ def manage_file_upload(zip_file_path: str, env_name: str) -> None:
             )
         elif not doc_count:
             print("No docs indexed, waiting for indexing to start")
-            upload_test_files(zip_file_path, env_name)
-        elif doc_count < total_file_count:
+            temp_zip_file_path = create_temp_zip_from_files(unzipped_file_paths)
+            upload_test_files(temp_zip_file_path, env_name)
+            os.unlink(temp_zip_file_path)
+        elif (doc_count + len(problem_file_list)) < total_file_count:
             print(f"No ongooing indexing attempts but only {doc_count} docs indexed")
-            remaining_files = unzipped_file_paths[doc_count:]
-            print(f"Grabbed last {len(remaining_files)} docs to try agian")
+            remaining_files = unzipped_file_paths[doc_count + len(problem_file_list) :]
+            problem_file_list.append(remaining_files.pop(0))
+            print(
+                f"Removing first doc and grabbing last {len(remaining_files)} docs to try again"
+            )
             temp_zip_file_path = create_temp_zip_from_files(remaining_files)
             upload_test_files(temp_zip_file_path, env_name)
             os.unlink(temp_zip_file_path)
@@ -72,6 +84,13 @@ def manage_file_upload(zip_file_path: str, env_name: str) -> None:
 
         time.sleep(10)
 
+    problem_file_csv_path = os.path.join(current_dir, "problem_files.csv")
+    with open(problem_file_csv_path, "w", newline="") as csvfile:
+        csvwriter = csv.writer(csvfile)
+        csvwriter.writerow(["Problematic File Paths"])
+        for problem_file in problem_file_list:
+            csvwriter.writerow([problem_file])
+
     for file in unzipped_file_paths:
         os.unlink(file)
 
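A standalone sketch (not part of the patch; the numbers are made up) of the retry bookkeeping added to manage_file_upload above: each stalled pass quarantines the first unindexed file into problem_file_list, so the re-upload set keeps shrinking and one bad file cannot stall the loop forever.

# One simulated stalled pass over five docs, three already indexed.
unzipped_file_paths = [f"doc_{i}.txt" for i in range(5)]
problem_file_list: list[str] = []
doc_count = 3  # e.g. indexing stopped making progress after 3 of 5 docs

if (doc_count + len(problem_file_list)) < len(unzipped_file_paths):
    remaining_files = unzipped_file_paths[doc_count + len(problem_file_list) :]
    problem_file_list.append(remaining_files.pop(0))  # quarantine "doc_3.txt"
    assert remaining_files == ["doc_4.txt"]  # only these get re-uploaded

This is also why unzip_and_get_file_paths switches from Path.rglob to os.walk with sorted(files): the slice arithmetic assumes the upload order is deterministic across passes.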
diff --git a/backend/tests/regression/answer_quality/run_qa.py b/backend/tests/regression/answer_quality/run_qa.py
index 44eaf4bd6..5de034b37 100644
--- a/backend/tests/regression/answer_quality/run_qa.py
+++ b/backend/tests/regression/answer_quality/run_qa.py
@@ -111,15 +111,14 @@ def _initialize_files(config: dict) -> tuple[str, list[dict]]:
 
 def _process_question(question_data: dict, config: dict, question_number: int) -> dict:
-    print(f"On question number {question_number}")
-
     query = question_data["question"]
-    print(f"query: {query}")
     context_data_list, answer = get_answer_from_query(
         query=query,
         only_retrieve_docs=config["only_retrieve_docs"],
         env_name=config["env_name"],
     )
+    print(f"On question number {question_number}")
+    print(f"query: {query}")
 
     if not context_data_list:
         print("No answer or context found")
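One possible way (not part of the patch) to consume the problem_files.csv report that manage_file_upload now writes; the relative path is illustrative, since the script builds the real path from its own current_dir.

import csv

# Read back the problem-file report written during the upload loop.
with open("problem_files.csv", newline="") as csvfile:
    rows = list(csv.reader(csvfile))

# rows[0] is the "Problematic File Paths" header row; the rest are file paths.
problem_paths = [row[0] for row in rows[1:] if row]
print(f"{len(problem_paths)} files were quarantined during indexing")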