diff --git a/backend/danswer/chat/chat_utils.py b/backend/danswer/chat/chat_utils.py
index d837e0654621..26533d704b4a 100644
--- a/backend/danswer/chat/chat_utils.py
+++ b/backend/danswer/chat/chat_utils.py
@@ -9,6 +9,7 @@ from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from sqlalchemy.orm import Session
+from tiktoken.core import Encoding

 from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
@@ -17,6 +18,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.chat_configs import STOP_STREAM_PAT
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.models import ChatMessage
@@ -24,8 +26,10 @@ from danswer.db.models import Persona
 from danswer.db.models import Prompt
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import get_default_llm_version
 from danswer.llm.utils import get_max_input_tokens
+from danswer.llm.utils import tokenizer_trim_content
 from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
 from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -42,6 +46,9 @@ from danswer.prompts.token_counts import (
 from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
 from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
 from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()

 # Maps connector enum string to a more natural language representation for the LLM
 # If not on the list, uses the original but slightly cleaned up, see below
@@ -270,6 +277,7 @@ def get_chunks_for_qa(
     chunks: list[InferenceChunk],
     llm_chunk_selection: list[bool],
     token_limit: int | None,
+    llm_tokenizer: Encoding | None = None,
     batch_offset: int = 0,
 ) -> list[int]:
     """
@@ -282,6 +290,7 @@
     there's no way to know which chunks were included in the prior batches without recounting atm,
     this is somewhat slow as it requires tokenizing all the chunks again
     """
+    token_leeway = 50
     batch_index = 0
     latest_batch_indices: list[int] = []
     token_count = 0
@@ -296,8 +305,19 @@
         # We calculate it live in case the user uses a different LLM + tokenizer
         chunk_token = check_number_of_tokens(chunk.content)
+        if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway:
+            logger.warning(
+                "Found more tokens in chunk than expected, "
+                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
+            )
+            chunk.content = tokenizer_trim_content(
+                content=chunk.content,
+                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
+                tokenizer=llm_tokenizer or get_default_llm_tokenizer(),
+            )
+
         # 50 for an approximate/slight overestimate for
         # tokens for metadata for the chunk
-        token_count += chunk_token + 50
+        token_count += chunk_token + token_leeway

         # Always use at least 1 chunk
         if (
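Note: `tokenizer_trim_content` is imported above but not defined in this diff. A minimal sketch of what such a helper presumably does, using tiktoken's `Encoding` (the type hinted in the new signature); the real implementation in `danswer.llm.utils` may differ in details.

```python
# Sketch (assumption): trim text down to a token budget, then decode it back.
from tiktoken import get_encoding
from tiktoken.core import Encoding


def trim_content_to_tokens(content: str, desired_length: int, tokenizer: Encoding) -> str:
    tokens = tokenizer.encode(content)
    if len(tokens) <= desired_length:
        return content
    # Keep only the first desired_length tokens and decode back to text
    return tokenizer.decode(tokens[:desired_length])


# Usage: trim an over-long chunk to the embedding context size (512 here)
tokenizer = get_encoding("cl100k_base")
trimmed = trim_content_to_tokens("some very long chunk text...", 512, tokenizer)
```

diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 6a47f9c144b7..479feb2579c0 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -26,7 +26,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_chat_message
@@ -160,7 +160,7 @@ def stream_chat_message_objects(
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
     default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
@@ -468,6 +468,7 @@ def stream_chat_message_objects(
             chunks=top_chunks,
             llm_chunk_selection=llm_chunk_selection,
             token_limit=chunk_token_limit,
+            llm_tokenizer=llm_tokenizer,
         )
         llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
         llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]
diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 1f5a4efef81a..f6cd71f31db8 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -3,7 +3,6 @@ import os
 #####
 # Embedding/Reranking Model Configs
 #####
-CHUNK_SIZE = 512
 # Important considerations when choosing models
 # Max tokens count needs to be high considering use case (at least 512)
 # Models used must be MIT or Apache license
diff --git a/backend/danswer/indexing/chunker.py b/backend/danswer/indexing/chunker.py
index c30162327adc..9be9348b9f9b 100644
--- a/backend/danswer/indexing/chunker.py
+++ b/backend/danswer/indexing/chunker.py
@@ -7,7 +7,7 @@ from danswer.configs.app_configs import CHUNK_OVERLAP
 from danswer.configs.app_configs import MINI_CHUNK_SIZE
 from danswer.configs.constants import SECTION_SEPARATOR
 from danswer.configs.constants import TITLE_SEPARATOR
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.connectors.models import Document
 from danswer.indexing.models import DocAwareChunk
 from danswer.search.search_nlp_models import get_default_tokenizer
@@ -37,7 +37,7 @@ def chunk_large_section(
     document: Document,
     start_chunk_id: int,
     tokenizer: "AutoTokenizer",
-    chunk_size: int = CHUNK_SIZE,
+    chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     chunk_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
@@ -67,7 +67,7 @@

 def chunk_document(
     document: Document,
-    chunk_tok_size: int = CHUNK_SIZE,
+    chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     subsection_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]: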
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 11337fe30861..ad7bb14b5ce9 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -1,3 +1,5 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from typing import Any
 from typing import cast

@@ -130,8 +132,124 @@ def include_router_with_global_prefix_prepended(
     application.include_router(router, **final_kwargs)


+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator:
+    engine = get_sqlalchemy_engine()
+
+    verify_auth = fetch_versioned_implementation(
+        "danswer.auth.users", "verify_auth_setting"
+    )
+    # Will throw exception if an issue is found
+    verify_auth()
+
+    # Danswer APIs key
+    api_key = get_danswer_api_key()
+    logger.info(f"Danswer API Key: {api_key}")
+
+    if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
+        logger.info("Both OAuth Client ID and Secret are configured.")
+
+    if DISABLE_GENERATIVE_AI:
+        logger.info("Generative AI Q&A disabled")
+    else:
+        logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
+        base, fast = get_default_llm_version()
+        logger.info(f"Using LLM Model Version: {base}")
+        if base != fast:
+            logger.info(f"Using Fast LLM Model Version: {fast}")
+        if GEN_AI_API_ENDPOINT:
+            logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
+
+        # Any additional model configs logged here
+        get_default_llm().log_model_configs()
+
+    if MULTILINGUAL_QUERY_EXPANSION:
+        logger.info(
+            f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
+        )
+
+    with Session(engine) as db_session:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+        secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
+
+        # Break bad state for thrashing indexes
+        if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
+            expire_index_attempts(
+                embedding_model_id=db_embedding_model.id, db_session=db_session
+            )
+
+            for cc_pair in get_connector_credential_pairs(db_session):
+                resync_cc_pair(cc_pair, db_session=db_session)
+
+        # Expire all old embedding models indexing attempts, technically redundant
+        cancel_indexing_attempts_past_model(db_session)
+
+        logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
+        if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
+            logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"')
+            logger.info(
+                f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
+            )
+
+        if ENABLE_RERANKING_REAL_TIME_FLOW:
+            logger.info("Reranking step of search flow is enabled.")
+
+        if MODEL_SERVER_HOST:
+            logger.info(
+                f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
+            )
+        else:
+            logger.info("Warming up local NLP models.")
+            warm_up_models(
+                model_name=db_embedding_model.model_name,
+                normalize=db_embedding_model.normalize,
+                skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
+            )
+
+            if torch.cuda.is_available():
+                logger.info("GPU is available")
+            else:
+                logger.info("GPU is not available")
+            logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
+        nltk.download("stopwords", quiet=True)
+        nltk.download("wordnet", quiet=True)
+        nltk.download("punkt", quiet=True)
+
+        logger.info("Verifying default connector/credential exist.")
+        create_initial_public_credential(db_session)
+        create_initial_default_connector(db_session)
+        associate_default_cc_pair(db_session)
+
+        logger.info("Loading default Prompts and Personas")
+        delete_old_default_personas(db_session)
+        load_chat_yamls()
+
+        logger.info("Verifying Document Index(s) is/are available.")
+
+        document_index = get_default_document_index(
+            primary_index_name=db_embedding_model.index_name,
+            secondary_index_name=secondary_db_embedding_model.index_name
+            if secondary_db_embedding_model
+            else None,
+        )
+        document_index.ensure_indices_exist(
+            index_embedding_dim=db_embedding_model.model_dim,
+            secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+            if secondary_db_embedding_model
+            else None,
+        )
+
+    optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
+
+    yield
+
+
 def get_application() -> FastAPI:
-    application = FastAPI(title="Danswer Backend", version=__version__)
+    application = FastAPI(
+        title="Danswer Backend", version=__version__, lifespan=lifespan
+    )

     include_router_with_global_prefix_prepended(application, chat_router)
     include_router_with_global_prefix_prepended(application, query_router)
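The rename makes explicit that chunks are sized to the embedding model's context window rather than an unrelated "chunk size" constant. For illustration only, a minimal sketch of token-window chunking against that limit; the overlap value is a hypothetical stand-in, and Danswer's real chunker above is section-aware and also handles blurbs and mini-chunks.

```python
# Illustrative sketch: split text into windows of at most
# DOC_EMBEDDING_CONTEXT_SIZE tokens with a fixed token overlap.
from transformers import AutoTokenizer

DOC_EMBEDDING_CONTEXT_SIZE = 512  # tokens the embedding model can attend to
CHUNK_OVERLAP = 0  # hypothetical overlap value


def split_by_token_window(
    text: str,
    tokenizer: "AutoTokenizer",
    chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
    overlap: int = CHUNK_OVERLAP,
) -> list[str]:
    token_ids = tokenizer.encode(text, add_special_tokens=False)
    step = max(chunk_size - overlap, 1)  # guard against overlap >= chunk_size
    chunks = []
    for start in range(0, len(token_ids), step):
        window = token_ids[start : start + chunk_size]
        chunks.append(tokenizer.decode(window))
        if start + chunk_size >= len(token_ids):
            break
    return chunks


# e.g. tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-base-v2")  # illustrative model
```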
@@ -220,121 +338,6 @@ def get_application() -> FastAPI:

     application.add_exception_handler(ValueError, value_error_handler)

-    @application.on_event("startup")
-    def startup_event() -> None:
-        engine = get_sqlalchemy_engine()
-
-        verify_auth = fetch_versioned_implementation(
-            "danswer.auth.users", "verify_auth_setting"
-        )
-        # Will throw exception if an issue is found
-        verify_auth()
-
-        # Danswer APIs key
-        api_key = get_danswer_api_key()
-        logger.info(f"Danswer API Key: {api_key}")
-
-        if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
-            logger.info("Both OAuth Client ID and Secret are configured.")
-
-        if DISABLE_GENERATIVE_AI:
-            logger.info("Generative AI Q&A disabled")
-        else:
-            logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
-            base, fast = get_default_llm_version()
-            logger.info(f"Using LLM Model Version: {base}")
-            if base != fast:
-                logger.info(f"Using Fast LLM Model Version: {fast}")
-            if GEN_AI_API_ENDPOINT:
-                logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
-
-            # Any additional model configs logged here
-            get_default_llm().log_model_configs()
-
-        if MULTILINGUAL_QUERY_EXPANSION:
-            logger.info(
-                f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
-            )
-
-        with Session(engine) as db_session:
-            db_embedding_model = get_current_db_embedding_model(db_session)
-            secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
-
-            # Break bad state for thrashing indexes
-            if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
-                expire_index_attempts(
-                    embedding_model_id=db_embedding_model.id, db_session=db_session
-                )
-
-                for cc_pair in get_connector_credential_pairs(db_session):
-                    resync_cc_pair(cc_pair, db_session=db_session)
-
-            # Expire all old embedding models indexing attempts, technically redundant
-            cancel_indexing_attempts_past_model(db_session)
-
-            logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
-            if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
-                logger.info(
-                    f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
-                )
-                logger.info(
-                    f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
-                )
-
-            if ENABLE_RERANKING_REAL_TIME_FLOW:
-                logger.info("Reranking step of search flow is enabled.")
-
-            if MODEL_SERVER_HOST:
-                logger.info(
-                    f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
-                )
-            else:
-                logger.info("Warming up local NLP models.")
-                warm_up_models(
-                    model_name=db_embedding_model.model_name,
-                    normalize=db_embedding_model.normalize,
-                    skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
-                )
-
-                if torch.cuda.is_available():
-                    logger.info("GPU is available")
-                else:
-                    logger.info("GPU is not available")
-                logger.info(f"Torch Threads: {torch.get_num_threads()}")
-
-            logger.info("Verifying query preprocessing (NLTK) data is downloaded")
-            nltk.download("stopwords", quiet=True)
-            nltk.download("wordnet", quiet=True)
-            nltk.download("punkt", quiet=True)
-
-            logger.info("Verifying default connector/credential exist.")
-            create_initial_public_credential(db_session)
-            create_initial_default_connector(db_session)
-            associate_default_cc_pair(db_session)
-
-            logger.info("Loading default Prompts and Personas")
-            delete_old_default_personas(db_session)
-            load_chat_yamls()
-
-            logger.info("Verifying Document Index(s) is/are available.")
-
-            document_index = get_default_document_index(
-                primary_index_name=db_embedding_model.index_name,
-                secondary_index_name=secondary_db_embedding_model.index_name
-                if secondary_db_embedding_model
-                else None,
-            )
-            document_index.ensure_indices_exist(
-                index_embedding_dim=db_embedding_model.model_dim,
-                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-                if secondary_db_embedding_model
-                else None,
-            )
-
-            optional_telemetry(
-                record_type=RecordType.VERSION, data={"version": __version__}
-            )
-
     application.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],  # Change this to the list of allowed origins if needed
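For reference, the startup logic above moves from FastAPI's deprecated `@app.on_event("startup")` hook into the lifespan context-manager pattern. A self-contained sketch of the migration; the handler bodies here are illustrative stand-ins for Danswer's real startup work.

```python
# Sketch of the on_event -> lifespan migration shown in the hunks above.
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    # Everything before `yield` runs at startup, replacing @app.on_event("startup")
    print("starting up: verify auth, warm models, check indices...")
    yield
    # Everything after `yield` runs at shutdown, replacing @app.on_event("shutdown")
    print("shutting down")


app = FastAPI(lifespan=lifespan)
```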
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 535b19af01bc..03292eec15f5 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -18,7 +18,7 @@ from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.chat import create_chat_session
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_or_create_root_message
@@ -63,7 +63,7 @@ def stream_answer_objects(
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
     default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     timeout: int = QA_TIMEOUT,
     bypass_acl: bool = False,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py
index 8b1f3b59feb2..bc5a6fac42d0 100644
--- a/backend/danswer/search/search_nlp_models.py
+++ b/backend/danswer/search/search_nlp_models.py
@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 from enum import Enum
 from typing import Optional
@@ -7,6 +6,7 @@ from typing import TYPE_CHECKING

 import numpy as np
 import requests
+from transformers import logging as transformer_logging  # type:ignore

 from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
@@ -26,10 +26,10 @@ from shared_models.model_server_models import RerankResponse

 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

 logger = setup_logger()

-# Remove useless info about layer initialization
-logging.getLogger("transformers").setLevel(logging.ERROR)
+transformer_logging.set_verbosity_error()

 if TYPE_CHECKING:
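The change above swaps the stdlib-logging workaround for the transformers library's own verbosity API and opts out of Hugging Face Hub telemetry. A minimal standalone sketch of the same setup; note the env vars should be set before the relevant Hugging Face imports do their work.

```python
# Quiet down HuggingFace tooling, mirroring the change above.
import os

# Opt out of HF Hub telemetry; read by huggingface_hub at request time
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"
# Silence fork-related tokenizer warnings
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from transformers import logging as transformer_logging  # noqa: E402

# Emit only errors from transformers (drops layer-initialization chatter)
transformer_logging.set_verbosity_error()
```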
diff --git a/backend/danswer/utils/timing.py b/backend/danswer/utils/timing.py
index 664656aa799a..a98cc7e351e3 100644
--- a/backend/danswer/utils/timing.py
+++ b/backend/danswer/utils/timing.py
@@ -18,7 +18,9 @@ FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])


 def log_function_time(
-    func_name: str | None = None, print_only: bool = False
+    func_name: str | None = None,
+    print_only: bool = False,
+    debug_only: bool = False,
 ) -> Callable[[F], F]:
     def decorator(func: F) -> F:
         @wraps(func)
@@ -28,7 +30,10 @@ def log_function_time(
             result = func(*args, **kwargs)
             elapsed_time_str = str(time.time() - start_time)
             log_name = func_name or func.__name__
-            logger.info(f"{log_name} took {elapsed_time_str} seconds")
+            if debug_only:
+                logger.debug(f"{log_name} took {elapsed_time_str} seconds")
+            else:
+                logger.info(f"{log_name} took {elapsed_time_str} seconds")

             if not print_only:
                 optional_telemetry(
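A hypothetical usage of the new `debug_only` flag: per-call timings for chatty code paths get logged at DEBUG instead of INFO, so they can be filtered out of production logs. The decorated function below is illustrative only.

```python
# Hypothetical example: route a hot path's timing logs to DEBUG.
from danswer.utils.timing import log_function_time


@log_function_time(debug_only=True)
def score_chunks(scores: list[float]) -> float:
    # Illustrative body; "score_chunks took X seconds" now logs at DEBUG
    return sum(scores) / max(len(scores), 1)
```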