Trim Chunks if LLM tokenizer differs from Embedding tokenizer (#1143)

Yuhong Sun
2024-02-28 13:01:32 -08:00
committed by GitHub
parent cd198ba368
commit c7d228e292
8 changed files with 158 additions and 130 deletions
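Chunks are sized with the embedding model's tokenizer at indexing time, but prompt budgeting uses the LLM's tokenizer, which can count noticeably more tokens for the same text. When a chunk's LLM token count exceeds DOC_EMBEDDING_CONTEXT_SIZE plus a small leeway, it is now trimmed back down before being packed into the prompt (see tokenizer_trim_content in the first file below). A self-contained sketch of the idea, using tiktoken and a made-up trim_to_budget helper rather than the actual Danswer utilities:

# Illustrative sketch (not the Danswer code): the same text can exceed the embedding
# chunk budget when re-counted with the LLM tokenizer, so it gets trimmed first.
import tiktoken

DOC_EMBEDDING_CONTEXT_SIZE = 512  # chunk budget, in tokens
TOKEN_LEEWAY = 50                 # slack for metadata and tokenizer drift

llm_tokenizer = tiktoken.get_encoding("cl100k_base")

def trim_to_budget(content: str, budget: int) -> str:
    """Cut content down to `budget` tokens as counted by the LLM tokenizer."""
    tokens = llm_tokenizer.encode(content)
    return llm_tokenizer.decode(tokens[:budget]) if len(tokens) > budget else content

chunk_content = "some chunk text " * 400
if len(llm_tokenizer.encode(chunk_content)) > DOC_EMBEDDING_CONTEXT_SIZE + TOKEN_LEEWAY:
    # embedding and LLM tokenizers disagree on this text; trim before prompt packing
    chunk_content = trim_to_budget(chunk_content, DOC_EMBEDDING_CONTEXT_SIZE)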

View File

@@ -9,6 +9,7 @@ from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
+ from tiktoken.core import Encoding
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
@@ -17,6 +18,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
+ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
@@ -24,8 +26,10 @@ from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
+ from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
+ from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -42,6 +46,9 @@ from danswer.prompts.token_counts import (
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.utils.logger import setup_logger
logger = setup_logger()
# Maps connector enum string to a more natural language representation for the LLM
# If not on the list, uses the original but slightly cleaned up, see below
@@ -270,6 +277,7 @@ def get_chunks_for_qa(
chunks: list[InferenceChunk],
llm_chunk_selection: list[bool],
token_limit: int | None,
+ llm_tokenizer: Encoding | None = None,
batch_offset: int = 0,
) -> list[int]:
"""
@@ -282,6 +290,7 @@ def get_chunks_for_qa(
there's no way to know which chunks were included in the prior batches without recounting atm,
this is somewhat slow as it requires tokenizing all the chunks again
"""
+ token_leeway = 50
batch_index = 0
latest_batch_indices: list[int] = []
token_count = 0
@@ -296,8 +305,19 @@ def get_chunks_for_qa(
# We calculate it live in case the user uses a different LLM + tokenizer
chunk_token = check_number_of_tokens(chunk.content)
+ if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway:
+     logger.warning(
+         "Found more tokens in chunk than expected, "
+         "likely mismatch between embedding and LLM tokenizers. Trimming content..."
+     )
+     chunk.content = tokenizer_trim_content(
+         content=chunk.content,
+         desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
+         tokenizer=llm_tokenizer or get_default_llm_tokenizer(),
+     )
# 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
- token_count += chunk_token + 50
+ token_count += chunk_token + token_leeway
# Always use at least 1 chunk
if (

View File

@@ -26,7 +26,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
- from danswer.configs.model_configs import CHUNK_SIZE
+ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -160,7 +160,7 @@ def stream_chat_message_objects(
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
- default_chunk_size: int = CHUNK_SIZE,
+ default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
# For flow with search, don't include as many chunks as possible since we need to leave space
# for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
@@ -468,6 +468,7 @@ def stream_chat_message_objects(
chunks=top_chunks,
llm_chunk_selection=llm_chunk_selection,
token_limit=chunk_token_limit,
+ llm_tokenizer=llm_tokenizer,
)
llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]

View File

@@ -3,7 +3,6 @@ import os
#####
# Embedding/Reranking Model Configs
#####
- CHUNK_SIZE = 512
# Important considerations when choosing models
# Max tokens count needs to be high considering use case (at least 512)
# Models used must be MIT or Apache license

View File

@@ -7,7 +7,7 @@ from danswer.configs.app_configs import CHUNK_OVERLAP
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.constants import SECTION_SEPARATOR
from danswer.configs.constants import TITLE_SEPARATOR
- from danswer.configs.model_configs import CHUNK_SIZE
+ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
@@ -37,7 +37,7 @@ def chunk_large_section(
document: Document,
start_chunk_id: int,
tokenizer: "AutoTokenizer",
- chunk_size: int = CHUNK_SIZE,
+ chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
chunk_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
) -> list[DocAwareChunk]:
@@ -67,7 +67,7 @@ def chunk_large_section(
def chunk_document(
document: Document,
- chunk_tok_size: int = CHUNK_SIZE,
+ chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
subsection_overlap: int = CHUNK_OVERLAP,
blurb_size: int = BLURB_SIZE,
) -> list[DocAwareChunk]:
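The chunker's defaults now key off the same DOC_EMBEDDING_CONTEXT_SIZE constant used for prompt budgeting. A rough, self-contained sketch of token-window chunking with overlap, purely to illustrate what chunk_size and chunk_overlap control (not the actual Danswer chunker, which uses the embedding model's own tokenizer):

# Conceptual sketch: each chunk holds at most `chunk_size` tokens and overlaps
# its neighbor by `chunk_overlap` tokens.
import tiktoken

def split_into_chunks(text: str, chunk_size: int = 512, chunk_overlap: int = 0) -> list[str]:
    enc = tiktoken.get_encoding("cl100k_base")
    tokens = enc.encode(text)
    step = max(chunk_size - chunk_overlap, 1)
    return [enc.decode(tokens[i : i + chunk_size]) for i in range(0, len(tokens), step)]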

View File

@@ -1,3 +1,5 @@
+ from collections.abc import AsyncGenerator
+ from contextlib import asynccontextmanager
from typing import Any
from typing import cast
@@ -130,8 +132,124 @@ def include_router_with_global_prefix_prepended(
application.include_router(router, **final_kwargs)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
engine = get_sqlalchemy_engine()
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
# Will throw exception if an issue is found
verify_auth()
# Danswer APIs key
api_key = get_danswer_api_key()
logger.info(f"Danswer API Key: {api_key}")
if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
logger.info("Both OAuth Client ID and Secret are configured.")
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
base, fast = get_default_llm_version()
logger.info(f"Using LLM Model Version: {base}")
if base != fast:
logger.info(f"Using Fast LLM Model Version: {fast}")
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
# Any additional model configs logged here
get_default_llm().log_model_configs()
if MULTILINGUAL_QUERY_EXPANSION:
logger.info(
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)
with Session(engine) as db_session:
db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
# Break bad state for thrashing indexes
if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
embedding_model_id=db_embedding_model.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"')
logger.info(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
)
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("Warming up local NLP models.")
warm_up_models(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
)
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
logger.info("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.info("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
logger.info("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
else None,
)
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
yield
def get_application() -> FastAPI:
- application = FastAPI(title="Danswer Backend", version=__version__)
+ application = FastAPI(
+     title="Danswer Backend", version=__version__, lifespan=lifespan
+ )
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, query_router)
@@ -220,121 +338,6 @@ def get_application() -> FastAPI:
application.add_exception_handler(ValueError, value_error_handler)
@application.on_event("startup")
def startup_event() -> None:
engine = get_sqlalchemy_engine()
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
# Will throw exception if an issue is found
verify_auth()
# Danswer APIs key
api_key = get_danswer_api_key()
logger.info(f"Danswer API Key: {api_key}")
if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
logger.info("Both OAuth Client ID and Secret are configured.")
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
base, fast = get_default_llm_version()
logger.info(f"Using LLM Model Version: {base}")
if base != fast:
logger.info(f"Using Fast LLM Model Version: {fast}")
if GEN_AI_API_ENDPOINT:
logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
# Any additional model configs logged here
get_default_llm().log_model_configs()
if MULTILINGUAL_QUERY_EXPANSION:
logger.info(
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)
with Session(engine) as db_session:
db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
# Break bad state for thrashing indexes
if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
embedding_model_id=db_embedding_model.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.info(
f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
)
logger.info(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
)
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("Warming up local NLP models.")
warm_up_models(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
)
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
logger.info("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.info("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls()
logger.info("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
else None,
)
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
optional_telemetry(
record_type=RecordType.VERSION, data={"version": __version__}
)
application.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Change this to the list of allowed origins if needed
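The diff above replaces the deprecated @application.on_event("startup") hook with FastAPI's lifespan context manager; the removed startup_event body is the same logic now living in lifespan. A minimal, self-contained sketch of the lifespan pattern itself (placeholder app, not the Danswer application):

# Minimal lifespan sketch: startup work runs before the yield, shutdown work after it.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    print("startup: warm up models, verify DB state, etc.")
    yield
    print("shutdown: release resources")

app = FastAPI(title="Example", version="0.1", lifespan=lifespan)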

View File

@@ -18,7 +18,7 @@ from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
- from danswer.configs.model_configs import CHUNK_SIZE
+ from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message
@@ -63,7 +63,7 @@ def stream_answer_objects(
db_session: Session,
# Needed to translate persona num_chunks to tokens to the LLM
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
- default_chunk_size: int = CHUNK_SIZE,
+ default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
timeout: int = QA_TIMEOUT,
bypass_acl: bool = False,
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]

View File

@@ -1,5 +1,4 @@
import gc
- import logging
import os
from enum import Enum
from typing import Optional
@@ -7,6 +6,7 @@ from typing import TYPE_CHECKING
import numpy as np
import requests
+ from transformers import logging as transformer_logging  # type:ignore
from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
@@ -26,10 +26,10 @@ from shared_models.model_server_models import RerankResponse
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
logger = setup_logger()
# Remove useless info about layer initialization
- logging.getLogger("transformers").setLevel(logging.ERROR)
+ transformer_logging.set_verbosity_error()
if TYPE_CHECKING:

View File

@@ -18,7 +18,9 @@ FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
def log_function_time(
- func_name: str | None = None, print_only: bool = False
+ func_name: str | None = None,
+ print_only: bool = False,
+ debug_only: bool = False,
) -> Callable[[F], F]:
def decorator(func: F) -> F:
@wraps(func)
@@ -28,7 +30,10 @@ def log_function_time(
result = func(*args, **kwargs)
elapsed_time_str = str(time.time() - start_time)
log_name = func_name or func.__name__
- logger.info(f"{log_name} took {elapsed_time_str} seconds")
+ if debug_only:
+     logger.debug(f"{log_name} took {elapsed_time_str} seconds")
+ else:
+     logger.info(f"{log_name} took {elapsed_time_str} seconds")
if not print_only:
optional_telemetry(
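With debug_only=True, timing for chatty code paths is logged at DEBUG instead of INFO. A hypothetical call site, assuming the decorator lives at danswer.utils.timing (the decorated function is made up for illustration):

# Hypothetical usage: timing for this function now goes to the DEBUG log level
from danswer.utils.timing import log_function_time

@log_function_time(func_name="score_chunks", debug_only=True)
def score_chunks(chunks: list[str]) -> list[float]:
    return [float(len(chunk)) for chunk in chunks]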