Trim Chunks if LLM tokenizer differs from Embedding tokenizer (#1143)

Author: Yuhong Sun
Date: 2024-02-28 13:01:32 -08:00
Committed by: GitHub
Parent: cd198ba368
Commit: c7d228e292
8 changed files with 158 additions and 130 deletions

View File

@@ -9,6 +9,7 @@ from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from sqlalchemy.orm import Session
+from tiktoken.core import Encoding

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
@@ -17,6 +18,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.chat import get_chat_messages_by_session
from danswer.db.models import ChatMessage
@@ -24,8 +26,10 @@ from danswer.db.models import Persona
from danswer.db.models import Prompt
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_default_llm_tokenizer
from danswer.llm.utils import get_default_llm_version
from danswer.llm.utils import get_max_input_tokens
+from danswer.llm.utils import tokenizer_trim_content
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -42,6 +46,9 @@ from danswer.prompts.token_counts import (
from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()

# Maps connector enum string to a more natural language representation for the LLM
# If not on the list, uses the original but slightly cleaned up, see below
@@ -270,6 +277,7 @@ def get_chunks_for_qa(
    chunks: list[InferenceChunk],
    llm_chunk_selection: list[bool],
    token_limit: int | None,
+    llm_tokenizer: Encoding | None = None,
    batch_offset: int = 0,
) -> list[int]:
    """
@@ -282,6 +290,7 @@ def get_chunks_for_qa(
    there's no way to know which chunks were included in the prior batches without recounting atm,
    this is somewhat slow as it requires tokenizing all the chunks again
    """
+    token_leeway = 50
    batch_index = 0
    latest_batch_indices: list[int] = []
    token_count = 0
@@ -296,8 +305,19 @@ def get_chunks_for_qa(
        # We calculate it live in case the user uses a different LLM + tokenizer
        chunk_token = check_number_of_tokens(chunk.content)
+        if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway:
+            logger.warning(
+                "Found more tokens in chunk than expected, "
+                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
+            )
+            chunk.content = tokenizer_trim_content(
+                content=chunk.content,
+                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
+                tokenizer=llm_tokenizer or get_default_llm_tokenizer(),
+            )
        # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
-        token_count += chunk_token + 50
+        token_count += chunk_token + token_leeway

        # Always use at least 1 chunk
        if (

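The new guard fires only when a chunk's LLM-side token count exceeds DOC_EMBEDDING_CONTEXT_SIZE plus the 50-token leeway; the chunk content is then re-trimmed with the LLM tokenizer so the running token budget stays accurate. The diff does not show the body of tokenizer_trim_content; a minimal sketch of what such a helper could look like, assuming a tiktoken Encoding with the usual encode/decode methods:

import tiktoken
from tiktoken.core import Encoding


def tokenizer_trim_content(content: str, desired_length: int, tokenizer: Encoding) -> str:
    # Encode with the LLM tokenizer, keep at most `desired_length` tokens,
    # then decode back to text. Assumed behavior, not the actual implementation.
    tokens = tokenizer.encode(content)
    if len(tokens) <= desired_length:
        return content
    return tokenizer.decode(tokens[:desired_length])


# Example: trim an oversized chunk down to the embedding context size (e.g. 512,
# the value of the removed CHUNK_SIZE constant).
trimmed = tokenizer_trim_content(
    content="some very long chunk text ...",
    desired_length=512,
    tokenizer=tiktoken.get_encoding("cl100k_base"),
)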
View File

@@ -26,7 +26,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.constants import DISABLED_GEN_AI_MSG
from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.chat import create_db_search_doc
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_chat_message
@@ -160,7 +160,7 @@ def stream_chat_message_objects(
    db_session: Session,
    # Needed to translate persona num_chunks to tokens to the LLM
    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
    # For flow with search, don't include as many chunks as possible since we need to leave space
    # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
    max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
@@ -468,6 +468,7 @@ def stream_chat_message_objects(
        chunks=top_chunks,
        llm_chunk_selection=llm_chunk_selection,
        token_limit=chunk_token_limit,
+        llm_tokenizer=llm_tokenizer,
    )
    llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
    llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]

View File

@@ -3,7 +3,6 @@ import os
#####
# Embedding/Reranking Model Configs
#####
-CHUNK_SIZE = 512
# Important considerations when choosing models
# Max tokens count needs to be high considering use case (at least 512)
# Models used must be MIT or Apache license

View File

@@ -7,7 +7,7 @@ from danswer.configs.app_configs import CHUNK_OVERLAP
from danswer.configs.app_configs import MINI_CHUNK_SIZE
from danswer.configs.constants import SECTION_SEPARATOR
from danswer.configs.constants import TITLE_SEPARATOR
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.connectors.models import Document
from danswer.indexing.models import DocAwareChunk
from danswer.search.search_nlp_models import get_default_tokenizer
@@ -37,7 +37,7 @@ def chunk_large_section(
    document: Document,
    start_chunk_id: int,
    tokenizer: "AutoTokenizer",
-    chunk_size: int = CHUNK_SIZE,
+    chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
    chunk_overlap: int = CHUNK_OVERLAP,
    blurb_size: int = BLURB_SIZE,
) -> list[DocAwareChunk]:
@@ -67,7 +67,7 @@ def chunk_large_section(
def chunk_document(
    document: Document,
-    chunk_tok_size: int = CHUNK_SIZE,
+    chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
    subsection_overlap: int = CHUNK_OVERLAP,
    blurb_size: int = BLURB_SIZE,
) -> list[DocAwareChunk]:

View File

@@ -1,3 +1,5 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
from typing import Any
from typing import cast
@@ -130,8 +132,124 @@ def include_router_with_global_prefix_prepended(
    application.include_router(router, **final_kwargs)

+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator:
+    engine = get_sqlalchemy_engine()
+    verify_auth = fetch_versioned_implementation(
+        "danswer.auth.users", "verify_auth_setting"
+    )
+    # Will throw exception if an issue is found
+    verify_auth()
+    # Danswer APIs key
+    api_key = get_danswer_api_key()
+    logger.info(f"Danswer API Key: {api_key}")
+    if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
+        logger.info("Both OAuth Client ID and Secret are configured.")
+    if DISABLE_GENERATIVE_AI:
+        logger.info("Generative AI Q&A disabled")
+    else:
+        logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
+        base, fast = get_default_llm_version()
+        logger.info(f"Using LLM Model Version: {base}")
+        if base != fast:
+            logger.info(f"Using Fast LLM Model Version: {fast}")
+        if GEN_AI_API_ENDPOINT:
+            logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
+        # Any additional model configs logged here
+        get_default_llm().log_model_configs()
+    if MULTILINGUAL_QUERY_EXPANSION:
+        logger.info(
+            f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
+        )
+    with Session(engine) as db_session:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+        secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
+        # Break bad state for thrashing indexes
+        if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
+            expire_index_attempts(
+                embedding_model_id=db_embedding_model.id, db_session=db_session
+            )
+            for cc_pair in get_connector_credential_pairs(db_session):
+                resync_cc_pair(cc_pair, db_session=db_session)
+        # Expire all old embedding models indexing attempts, technically redundant
+        cancel_indexing_attempts_past_model(db_session)
+        logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
+        if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
+            logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"')
+            logger.info(
+                f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
+            )
+        if ENABLE_RERANKING_REAL_TIME_FLOW:
+            logger.info("Reranking step of search flow is enabled.")
+        if MODEL_SERVER_HOST:
+            logger.info(
+                f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
+            )
+        else:
+            logger.info("Warming up local NLP models.")
+            warm_up_models(
+                model_name=db_embedding_model.model_name,
+                normalize=db_embedding_model.normalize,
+                skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
+            )
+            if torch.cuda.is_available():
+                logger.info("GPU is available")
+            else:
+                logger.info("GPU is not available")
+            logger.info(f"Torch Threads: {torch.get_num_threads()}")
+        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
+        nltk.download("stopwords", quiet=True)
+        nltk.download("wordnet", quiet=True)
+        nltk.download("punkt", quiet=True)
+        logger.info("Verifying default connector/credential exist.")
+        create_initial_public_credential(db_session)
+        create_initial_default_connector(db_session)
+        associate_default_cc_pair(db_session)
+        logger.info("Loading default Prompts and Personas")
+        delete_old_default_personas(db_session)
+        load_chat_yamls()
+        logger.info("Verifying Document Index(s) is/are available.")
+        document_index = get_default_document_index(
+            primary_index_name=db_embedding_model.index_name,
+            secondary_index_name=secondary_db_embedding_model.index_name
+            if secondary_db_embedding_model
+            else None,
+        )
+        document_index.ensure_indices_exist(
+            index_embedding_dim=db_embedding_model.model_dim,
+            secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+            if secondary_db_embedding_model
+            else None,
+        )
+    optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
+    yield


def get_application() -> FastAPI:
-    application = FastAPI(title="Danswer Backend", version=__version__)
+    application = FastAPI(
+        title="Danswer Backend", version=__version__, lifespan=lifespan
+    )

    include_router_with_global_prefix_prepended(application, chat_router)
    include_router_with_global_prefix_prepended(application, query_router)
@@ -220,121 +338,6 @@ def get_application() -> FastAPI:
    application.add_exception_handler(ValueError, value_error_handler)

-    @application.on_event("startup")
-    def startup_event() -> None:
-        engine = get_sqlalchemy_engine()
-        verify_auth = fetch_versioned_implementation(
-            "danswer.auth.users", "verify_auth_setting"
-        )
-        # Will throw exception if an issue is found
-        verify_auth()
-        # Danswer APIs key
-        api_key = get_danswer_api_key()
-        logger.info(f"Danswer API Key: {api_key}")
-        if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
-            logger.info("Both OAuth Client ID and Secret are configured.")
-        if DISABLE_GENERATIVE_AI:
-            logger.info("Generative AI Q&A disabled")
-        else:
-            logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
-            base, fast = get_default_llm_version()
-            logger.info(f"Using LLM Model Version: {base}")
-            if base != fast:
-                logger.info(f"Using Fast LLM Model Version: {fast}")
-            if GEN_AI_API_ENDPOINT:
-                logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
-            # Any additional model configs logged here
-            get_default_llm().log_model_configs()
-        if MULTILINGUAL_QUERY_EXPANSION:
-            logger.info(
-                f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
-            )
-        with Session(engine) as db_session:
-            db_embedding_model = get_current_db_embedding_model(db_session)
-            secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
-            # Break bad state for thrashing indexes
-            if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
-                expire_index_attempts(
-                    embedding_model_id=db_embedding_model.id, db_session=db_session
-                )
-                for cc_pair in get_connector_credential_pairs(db_session):
-                    resync_cc_pair(cc_pair, db_session=db_session)
-            # Expire all old embedding models indexing attempts, technically redundant
-            cancel_indexing_attempts_past_model(db_session)
-            logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
-            if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
-                logger.info(
-                    f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
-                )
-                logger.info(
-                    f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
-                )
-            if ENABLE_RERANKING_REAL_TIME_FLOW:
-                logger.info("Reranking step of search flow is enabled.")
-            if MODEL_SERVER_HOST:
-                logger.info(
-                    f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
-                )
-            else:
-                logger.info("Warming up local NLP models.")
-                warm_up_models(
-                    model_name=db_embedding_model.model_name,
-                    normalize=db_embedding_model.normalize,
-                    skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
-                )
-                if torch.cuda.is_available():
-                    logger.info("GPU is available")
-                else:
-                    logger.info("GPU is not available")
-                logger.info(f"Torch Threads: {torch.get_num_threads()}")
-            logger.info("Verifying query preprocessing (NLTK) data is downloaded")
-            nltk.download("stopwords", quiet=True)
-            nltk.download("wordnet", quiet=True)
-            nltk.download("punkt", quiet=True)
-            logger.info("Verifying default connector/credential exist.")
-            create_initial_public_credential(db_session)
-            create_initial_default_connector(db_session)
-            associate_default_cc_pair(db_session)
-            logger.info("Loading default Prompts and Personas")
-            delete_old_default_personas(db_session)
-            load_chat_yamls()
-            logger.info("Verifying Document Index(s) is/are available.")
-            document_index = get_default_document_index(
-                primary_index_name=db_embedding_model.index_name,
-                secondary_index_name=secondary_db_embedding_model.index_name
-                if secondary_db_embedding_model
-                else None,
-            )
-            document_index.ensure_indices_exist(
-                index_embedding_dim=db_embedding_model.model_dim,
-                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-                if secondary_db_embedding_model
-                else None,
-            )
-        optional_telemetry(
-            record_type=RecordType.VERSION, data={"version": __version__}
-        )

    application.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],  # Change this to the list of allowed origins if needed

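The other notable change in this file is the move from the older @application.on_event("startup") hook to FastAPI's lifespan context manager, passed via FastAPI(lifespan=lifespan). A minimal, stand-alone sketch of that pattern (placeholder startup/shutdown work, not Danswer's actual logic):

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    # Everything before the yield runs once, before the app starts serving requests.
    print("warming up models, verifying indices, ...")
    yield
    # Everything after the yield runs on shutdown.
    print("cleaning up")


app = FastAPI(lifespan=lifespan)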
View File

@@ -18,7 +18,7 @@ from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import QA_TIMEOUT
from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.chat import create_chat_session
from danswer.db.chat import create_new_chat_message
from danswer.db.chat import get_or_create_root_message
@@ -63,7 +63,7 @@ def stream_answer_objects(
    db_session: Session,
    # Needed to translate persona num_chunks to tokens to the LLM
    default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
    timeout: int = QA_TIMEOUT,
    bypass_acl: bool = False,
    retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]

View File

@@ -1,5 +1,4 @@
import gc
-import logging
import os
from enum import Enum
from typing import Optional
@@ -7,6 +6,7 @@ from typing import TYPE_CHECKING
import numpy as np
import requests
+from transformers import logging as transformer_logging  # type:ignore

from danswer.configs.app_configs import MODEL_SERVER_HOST
from danswer.configs.app_configs import MODEL_SERVER_PORT
@@ -26,10 +26,10 @@ from shared_models.model_server_models import RerankResponse
os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

logger = setup_logger()

-# Remove useless info about layer initialization
-logging.getLogger("transformers").setLevel(logging.ERROR)
+transformer_logging.set_verbosity_error()

if TYPE_CHECKING:

View File

@@ -18,7 +18,9 @@ FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])
def log_function_time(
-    func_name: str | None = None, print_only: bool = False
+    func_name: str | None = None,
+    print_only: bool = False,
+    debug_only: bool = False,
) -> Callable[[F], F]:
    def decorator(func: F) -> F:
        @wraps(func)
@@ -28,7 +30,10 @@ def log_function_time(
            result = func(*args, **kwargs)
            elapsed_time_str = str(time.time() - start_time)
            log_name = func_name or func.__name__
-            logger.info(f"{log_name} took {elapsed_time_str} seconds")
+            if debug_only:
+                logger.debug(f"{log_name} took {elapsed_time_str} seconds")
+            else:
+                logger.info(f"{log_name} took {elapsed_time_str} seconds")

            if not print_only:
                optional_telemetry(
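The new debug_only flag lets frequently-called helpers keep their timing output at DEBUG level rather than INFO. A hypothetical usage sketch (the import path and decorated function are assumptions for illustration):

from danswer.utils.timing import log_function_time


@log_function_time(debug_only=True)  # timing goes to logger.debug instead of logger.info
def rerank_chunks() -> None:
    ...  # hypothetical hot-path function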