Mirror of https://github.com/danswer-ai/danswer.git
Trim Chunks if LLM tokenizer differs from Embedding tokenizer (#1143)
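In short: chunks are sized at indexing time with the embedding model's tokenizer (DOC_EMBEDDING_CONTEXT_SIZE, 512 by default), but at answer time their token count is measured with the LLM's tokenizer, which can count the same text differently. This commit re-counts each selected chunk with the LLM tokenizer and trims any chunk that comes out more than a small leeway (50 tokens) over the expected size. A hedged illustration of the premise (model names are examples, not taken from the diff):

# Illustration only, not Danswer code: the same text yields different token
# counts under different tokenizers, so a chunk that fits the embedding
# model's 512-token window may exceed 512 tokens for the LLM.
import tiktoken
from transformers import AutoTokenizer

text = "an example document chunk " * 100
llm_tokens = len(tiktoken.get_encoding("cl100k_base").encode(text))
emb_tokens = len(AutoTokenizer.from_pretrained("intfloat/e5-base-v2").tokenize(text))
print(llm_tokens, emb_tokens)  # the two counts generally differ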
@@ -9,6 +9,7 @@ from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from sqlalchemy.orm import Session
+from tiktoken.core import Encoding

 from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
@@ -17,6 +18,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.chat_configs import STOP_STREAM_PAT
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.models import ChatMessage
@@ -24,8 +26,10 @@ from danswer.db.models import Persona
 from danswer.db.models import Prompt
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import get_default_llm_version
 from danswer.llm.utils import get_max_input_tokens
+from danswer.llm.utils import tokenizer_trim_content
 from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
 from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -42,6 +46,9 @@ from danswer.prompts.token_counts import (
 from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
 from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
 from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
 from danswer.utils.logger import setup_logger

 logger = setup_logger()

 # Maps connector enum string to a more natural language representation for the LLM
 # If not on the list, uses the original but slightly cleaned up, see below
@@ -270,6 +277,7 @@ def get_chunks_for_qa(
     chunks: list[InferenceChunk],
     llm_chunk_selection: list[bool],
     token_limit: int | None,
+    llm_tokenizer: Encoding | None = None,
     batch_offset: int = 0,
 ) -> list[int]:
     """
@@ -282,6 +290,7 @@ def get_chunks_for_qa(
     there's no way to know which chunks were included in the prior batches without recounting atm,
     this is somewhat slow as it requires tokenizing all the chunks again
     """
+    token_leeway = 50
     batch_index = 0
     latest_batch_indices: list[int] = []
     token_count = 0
@@ -296,8 +305,19 @@ def get_chunks_for_qa(

         # We calculate it live in case the user uses a different LLM + tokenizer
         chunk_token = check_number_of_tokens(chunk.content)
+        if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway:
+            logger.warning(
+                "Found more tokens in chunk than expected, "
+                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
+            )
+            chunk.content = tokenizer_trim_content(
+                content=chunk.content,
+                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
+                tokenizer=llm_tokenizer or get_default_llm_tokenizer(),
+            )

         # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
-        token_count += chunk_token + 50
+        token_count += chunk_token + token_leeway

         # Always use at least 1 chunk
         if (
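The diff shows only the call site of tokenizer_trim_content, not its body; a minimal sketch of the shape such a helper plausibly has, assuming a tiktoken Encoding (this is not the verified implementation):

from tiktoken.core import Encoding

def tokenizer_trim_content(content: str, desired_length: int, tokenizer: Encoding) -> str:
    # Encode with the LLM tokenizer, keep at most desired_length tokens,
    # and decode back to text.
    tokens = tokenizer.encode(content)
    if len(tokens) > desired_length:
        content = tokenizer.decode(tokens[:desired_length])
    return content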
@@ -26,7 +26,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_chat_message
@@ -160,7 +160,7 @@ def stream_chat_message_objects(
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
     default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
@@ -468,6 +468,7 @@ def stream_chat_message_objects(
         chunks=top_chunks,
         llm_chunk_selection=llm_chunk_selection,
         token_limit=chunk_token_limit,
+        llm_tokenizer=llm_tokenizer,
     )
     llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
     llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
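The "translate persona num_chunks to tokens" comment above is simple arithmetic; a hedged sketch with the default values assumed (10 chunks and 512 tokens per chunk are assumptions here, and the real code also caps the result by the model's input window):

# Hedged sketch of the budget arithmetic behind default_num_chunks and
# default_chunk_size; the constants' values are assumed, not from the diff.
default_num_chunks = 10.0   # MAX_CHUNKS_FED_TO_CHAT (assumed default)
default_chunk_size = 512    # DOC_EMBEDDING_CONTEXT_SIZE (assumed default)

num_chunks = None           # persona.num_chunks, when the persona sets one
chunk_token_limit = int(
    (num_chunks if num_chunks is not None else default_num_chunks)
    * default_chunk_size
)  # e.g. 10 * 512 = 5120 tokens of chunk budget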
@@ -3,7 +3,6 @@ import os
 #####
 # Embedding/Reranking Model Configs
 #####
-CHUNK_SIZE = 512
 # Important considerations when choosing models
 # Max tokens count needs to be high considering use case (at least 512)
 # Models used must be MIT or Apache license
@@ -7,7 +7,7 @@ from danswer.configs.app_configs import CHUNK_OVERLAP
 from danswer.configs.app_configs import MINI_CHUNK_SIZE
 from danswer.configs.constants import SECTION_SEPARATOR
 from danswer.configs.constants import TITLE_SEPARATOR
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.connectors.models import Document
 from danswer.indexing.models import DocAwareChunk
 from danswer.search.search_nlp_models import get_default_tokenizer
@@ -37,7 +37,7 @@ def chunk_large_section(
     document: Document,
     start_chunk_id: int,
     tokenizer: "AutoTokenizer",
-    chunk_size: int = CHUNK_SIZE,
+    chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     chunk_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
@@ -67,7 +67,7 @@ def chunk_large_section(

 def chunk_document(
     document: Document,
-    chunk_tok_size: int = CHUNK_SIZE,
+    chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     subsection_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
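With both chunking defaults now pointing at the embedding window, chunk sizing and trimming use one constant. A minimal sketch of overlapping token-window chunking (not the Danswer chunker, which also tracks blurbs, sections, and mini-chunks):

# Minimal sketch: split a tokenized document into windows of chunk_size
# token ids, stepping back by chunk_overlap between consecutive windows.
def split_token_windows(
    token_ids: list[int], chunk_size: int = 512, chunk_overlap: int = 0
) -> list[list[int]]:
    step = max(chunk_size - chunk_overlap, 1)
    return [token_ids[i : i + chunk_size] for i in range(0, len(token_ids), step)]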
@@ -1,3 +1,5 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from typing import Any
 from typing import cast

@@ -130,8 +132,124 @@ def include_router_with_global_prefix_prepended(
     application.include_router(router, **final_kwargs)


+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator:
+    engine = get_sqlalchemy_engine()
+
+    verify_auth = fetch_versioned_implementation(
+        "danswer.auth.users", "verify_auth_setting"
+    )
+    # Will throw exception if an issue is found
+    verify_auth()
+
+    # Danswer APIs key
+    api_key = get_danswer_api_key()
+    logger.info(f"Danswer API Key: {api_key}")
+
+    if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
+        logger.info("Both OAuth Client ID and Secret are configured.")
+
+    if DISABLE_GENERATIVE_AI:
+        logger.info("Generative AI Q&A disabled")
+    else:
+        logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
+        base, fast = get_default_llm_version()
+        logger.info(f"Using LLM Model Version: {base}")
+        if base != fast:
+            logger.info(f"Using Fast LLM Model Version: {fast}")
+        if GEN_AI_API_ENDPOINT:
+            logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
+
+        # Any additional model configs logged here
+        get_default_llm().log_model_configs()
+
+    if MULTILINGUAL_QUERY_EXPANSION:
+        logger.info(
+            f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
+        )
+
+    with Session(engine) as db_session:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+        secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
+
+        # Break bad state for thrashing indexes
+        if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
+            expire_index_attempts(
+                embedding_model_id=db_embedding_model.id, db_session=db_session
+            )
+
+            for cc_pair in get_connector_credential_pairs(db_session):
+                resync_cc_pair(cc_pair, db_session=db_session)
+
+        # Expire all old embedding models indexing attempts, technically redundant
+        cancel_indexing_attempts_past_model(db_session)
+
+        logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
+        if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
+            logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"')
+            logger.info(
+                f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
+            )
+
+        if ENABLE_RERANKING_REAL_TIME_FLOW:
+            logger.info("Reranking step of search flow is enabled.")
+
+        if MODEL_SERVER_HOST:
+            logger.info(
+                f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
+            )
+        else:
+            logger.info("Warming up local NLP models.")
+            warm_up_models(
+                model_name=db_embedding_model.model_name,
+                normalize=db_embedding_model.normalize,
+                skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
+            )
+
+            if torch.cuda.is_available():
+                logger.info("GPU is available")
+            else:
+                logger.info("GPU is not available")
+            logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
+        nltk.download("stopwords", quiet=True)
+        nltk.download("wordnet", quiet=True)
+        nltk.download("punkt", quiet=True)
+
+        logger.info("Verifying default connector/credential exist.")
+        create_initial_public_credential(db_session)
+        create_initial_default_connector(db_session)
+        associate_default_cc_pair(db_session)
+
+        logger.info("Loading default Prompts and Personas")
+        delete_old_default_personas(db_session)
+        load_chat_yamls()
+
+        logger.info("Verifying Document Index(s) is/are available.")
+
+        document_index = get_default_document_index(
+            primary_index_name=db_embedding_model.index_name,
+            secondary_index_name=secondary_db_embedding_model.index_name
+            if secondary_db_embedding_model
+            else None,
+        )
+        document_index.ensure_indices_exist(
+            index_embedding_dim=db_embedding_model.model_dim,
+            secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+            if secondary_db_embedding_model
+            else None,
+        )
+
+    optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
+
+    yield
+
+
 def get_application() -> FastAPI:
-    application = FastAPI(title="Danswer Backend", version=__version__)
+    application = FastAPI(
+        title="Danswer Backend", version=__version__, lifespan=lifespan
+    )

     include_router_with_global_prefix_prepended(application, chat_router)
     include_router_with_global_prefix_prepended(application, query_router)
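The startup logic moves out of FastAPI's deprecated @app.on_event("startup") hook (removed in the next hunk) into a lifespan context manager passed to the FastAPI constructor: code before the yield runs at startup, and code after it would run at shutdown. A minimal standalone sketch of the pattern:

# Minimal sketch of the FastAPI lifespan pattern used above.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    print("startup work runs here")   # e.g. warm up models, verify indices
    yield
    print("shutdown work runs here")  # runs when the app stops

app = FastAPI(lifespan=lifespan)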
@@ -220,121 +338,6 @@ def get_application() -> FastAPI:

     application.add_exception_handler(ValueError, value_error_handler)

-    @application.on_event("startup")
-    def startup_event() -> None:
-        engine = get_sqlalchemy_engine()
-
-        verify_auth = fetch_versioned_implementation(
-            "danswer.auth.users", "verify_auth_setting"
-        )
-        # Will throw exception if an issue is found
-        verify_auth()
-
-        # Danswer APIs key
-        api_key = get_danswer_api_key()
-        logger.info(f"Danswer API Key: {api_key}")
-
-        if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
-            logger.info("Both OAuth Client ID and Secret are configured.")
-
-        if DISABLE_GENERATIVE_AI:
-            logger.info("Generative AI Q&A disabled")
-        else:
-            logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
-            base, fast = get_default_llm_version()
-            logger.info(f"Using LLM Model Version: {base}")
-            if base != fast:
-                logger.info(f"Using Fast LLM Model Version: {fast}")
-            if GEN_AI_API_ENDPOINT:
-                logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
-
-            # Any additional model configs logged here
-            get_default_llm().log_model_configs()
-
-        if MULTILINGUAL_QUERY_EXPANSION:
-            logger.info(
-                f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
-            )
-
-        with Session(engine) as db_session:
-            db_embedding_model = get_current_db_embedding_model(db_session)
-            secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
-
-            # Break bad state for thrashing indexes
-            if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
-                expire_index_attempts(
-                    embedding_model_id=db_embedding_model.id, db_session=db_session
-                )
-
-                for cc_pair in get_connector_credential_pairs(db_session):
-                    resync_cc_pair(cc_pair, db_session=db_session)
-
-            # Expire all old embedding models indexing attempts, technically redundant
-            cancel_indexing_attempts_past_model(db_session)
-
-            logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
-            if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
-                logger.info(
-                    f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
-                )
-                logger.info(
-                    f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
-                )
-
-            if ENABLE_RERANKING_REAL_TIME_FLOW:
-                logger.info("Reranking step of search flow is enabled.")
-
-            if MODEL_SERVER_HOST:
-                logger.info(
-                    f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
-                )
-            else:
-                logger.info("Warming up local NLP models.")
-                warm_up_models(
-                    model_name=db_embedding_model.model_name,
-                    normalize=db_embedding_model.normalize,
-                    skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
-                )
-
-                if torch.cuda.is_available():
-                    logger.info("GPU is available")
-                else:
-                    logger.info("GPU is not available")
-                logger.info(f"Torch Threads: {torch.get_num_threads()}")
-
-            logger.info("Verifying query preprocessing (NLTK) data is downloaded")
-            nltk.download("stopwords", quiet=True)
-            nltk.download("wordnet", quiet=True)
-            nltk.download("punkt", quiet=True)
-
-            logger.info("Verifying default connector/credential exist.")
-            create_initial_public_credential(db_session)
-            create_initial_default_connector(db_session)
-            associate_default_cc_pair(db_session)
-
-            logger.info("Loading default Prompts and Personas")
-            delete_old_default_personas(db_session)
-            load_chat_yamls()
-
-            logger.info("Verifying Document Index(s) is/are available.")
-
-            document_index = get_default_document_index(
-                primary_index_name=db_embedding_model.index_name,
-                secondary_index_name=secondary_db_embedding_model.index_name
-                if secondary_db_embedding_model
-                else None,
-            )
-            document_index.ensure_indices_exist(
-                index_embedding_dim=db_embedding_model.model_dim,
-                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-                if secondary_db_embedding_model
-                else None,
-            )
-
-        optional_telemetry(
-            record_type=RecordType.VERSION, data={"version": __version__}
-        )

     application.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],  # Change this to the list of allowed origins if needed
@@ -18,7 +18,7 @@ from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.chat import create_chat_session
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_or_create_root_message
@@ -63,7 +63,7 @@ def stream_answer_objects(
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
     default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     timeout: int = QA_TIMEOUT,
     bypass_acl: bool = False,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 from enum import Enum
 from typing import Optional
@@ -7,6 +6,7 @@ from typing import TYPE_CHECKING

 import numpy as np
 import requests
+from transformers import logging as transformer_logging  # type:ignore

 from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
@@ -26,10 +26,10 @@ from shared_models.model_server_models import RerankResponse


 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

 logger = setup_logger()
 # Remove useless info about layer initialization
-logging.getLogger("transformers").setLevel(logging.ERROR)
+transformer_logging.set_verbosity_error()


 if TYPE_CHECKING:
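This hunk swaps a stdlib-logger lookup for transformers' own verbosity API, which appears to be why the stdlib logging import can be dropped above. The equivalent shown standalone:

# Standalone equivalent of the swap above: silence transformers' log
# chatter via its verbosity API rather than by stdlib logger name.
from transformers import logging as transformer_logging

transformer_logging.set_verbosity_error()  # only errors from transformers now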
@@ -18,7 +18,9 @@ FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])


 def log_function_time(
-    func_name: str | None = None, print_only: bool = False
+    func_name: str | None = None,
+    print_only: bool = False,
+    debug_only: bool = False,
 ) -> Callable[[F], F]:
     def decorator(func: F) -> F:
         @wraps(func)
@@ -28,7 +30,10 @@ def log_function_time(
             result = func(*args, **kwargs)
             elapsed_time_str = str(time.time() - start_time)
             log_name = func_name or func.__name__
-            logger.info(f"{log_name} took {elapsed_time_str} seconds")
+            if debug_only:
+                logger.debug(f"{log_name} took {elapsed_time_str} seconds")
+            else:
+                logger.info(f"{log_name} took {elapsed_time_str} seconds")

             if not print_only:
                 optional_telemetry(
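With the new flag, callers can demote per-call timing lines to DEBUG for frequently-invoked functions; a hedged usage sketch (the decorated function is hypothetical, not from the diff):

# Hypothetical usage of the new debug_only flag: time a hot-path function
# without emitting an INFO-level log line on every call.
@log_function_time(debug_only=True)
def embed_query(query: str) -> list[float]:
    ...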