Trim Chunks if LLM tokenizer differs from Embedding tokenizer (#1143)
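
This commit addresses a unit mismatch: chunks are sized at indexing time with the embedding model's tokenizer, but prompt budgets are counted at answer time with the LLM's tokenizer. A chunk that fits in DOC_EMBEDDING_CONTEXT_SIZE embedding tokens can therefore exceed that count in LLM tokens, so get_chunks_for_qa now trims oversized chunks with the LLM tokenizer before packing them into the prompt. The commit also drops the duplicate CHUNK_SIZE constant in favor of DOC_EMBEDDING_CONTEXT_SIZE and migrates FastAPI startup logic to the lifespan API.

The snippet below is not part of the diff; it is a minimal sketch of the underlying problem, assuming tiktoken's cl100k_base on the LLM side and an illustrative Hugging Face model on the embedding side:

import tiktoken
from transformers import AutoTokenizer

text = "Some chunk content pulled from an indexed document. " * 40

llm_tokenizer = tiktoken.get_encoding("cl100k_base")  # LLM-side tokenizer
embedding_tokenizer = AutoTokenizer.from_pretrained("intfloat/e5-base-v2")  # embedding side (illustrative)

# The two vocabularies differ, so the same text yields different token counts;
# a chunk capped at 512 embedding tokens may exceed 512 LLM tokens.
print("LLM tokens:      ", len(llm_tokenizer.encode(text)))
print("Embedding tokens:", len(embedding_tokenizer.encode(text, add_special_tokens=False)))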

@@ -9,6 +9,7 @@ from langchain.schema.messages import BaseMessage
 from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from sqlalchemy.orm import Session
+from tiktoken.core import Encoding

 from danswer.chat.models import CitationInfo
 from danswer.chat.models import DanswerAnswerPiece
@@ -17,6 +18,7 @@ from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
 from danswer.configs.chat_configs import STOP_STREAM_PAT
 from danswer.configs.constants import DocumentSource
 from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
 from danswer.db.chat import get_chat_messages_by_session
 from danswer.db.models import ChatMessage
@@ -24,8 +26,10 @@ from danswer.db.models import Persona
 from danswer.db.models import Prompt
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.utils import check_number_of_tokens
+from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.llm.utils import get_default_llm_version
 from danswer.llm.utils import get_max_input_tokens
+from danswer.llm.utils import tokenizer_trim_content
 from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
 from danswer.prompts.chat_prompts import CHAT_USER_PROMPT
 from danswer.prompts.chat_prompts import CITATION_REMINDER
@@ -42,6 +46,9 @@ from danswer.prompts.token_counts import (
 from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
 from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
 from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()

 # Maps connector enum string to a more natural language representation for the LLM
 # If not on the list, uses the original but slightly cleaned up, see below
@@ -270,6 +277,7 @@ def get_chunks_for_qa(
     chunks: list[InferenceChunk],
     llm_chunk_selection: list[bool],
     token_limit: int | None,
+    llm_tokenizer: Encoding | None = None,
     batch_offset: int = 0,
 ) -> list[int]:
     """
@@ -282,6 +290,7 @@ def get_chunks_for_qa(
     there's no way to know which chunks were included in the prior batches without recounting atm,
     this is somewhat slow as it requires tokenizing all the chunks again
     """
+    token_leeway = 50
     batch_index = 0
     latest_batch_indices: list[int] = []
     token_count = 0
@@ -296,8 +305,19 @@ def get_chunks_for_qa(

         # We calculate it live in case the user uses a different LLM + tokenizer
         chunk_token = check_number_of_tokens(chunk.content)
+        if chunk_token > DOC_EMBEDDING_CONTEXT_SIZE + token_leeway:
+            logger.warning(
+                "Found more tokens in chunk than expected, "
+                "likely mismatch between embedding and LLM tokenizers. Trimming content..."
+            )
+            chunk.content = tokenizer_trim_content(
+                content=chunk.content,
+                desired_length=DOC_EMBEDDING_CONTEXT_SIZE,
+                tokenizer=llm_tokenizer or get_default_llm_tokenizer(),
+            )
+
         # 50 for an approximate/slight overestimate for # tokens for metadata for the chunk
-        token_count += chunk_token + 50
+        token_count += chunk_token + token_leeway

         # Always use at least 1 chunk
         if (
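
For context, tokenizer_trim_content is imported above from danswer.llm.utils. A minimal sketch of what such a helper plausibly does, assuming it simply truncates at the token level (this body is an illustration, not the repository's implementation):

from tiktoken.core import Encoding

def tokenizer_trim_content(content: str, desired_length: int, tokenizer: Encoding) -> str:
    # Hypothetical sketch: keep only the first `desired_length` LLM tokens.
    tokens = tokenizer.encode(content)
    if len(tokens) <= desired_length:
        return content
    return tokenizer.decode(tokens[:desired_length])

Note that token_leeway doubles as the per-chunk metadata overestimate, so only chunks that overshoot the embedding context by more than that slack get trimmed.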

@@ -26,7 +26,7 @@ from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.constants import DISABLED_GEN_AI_MSG
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.chat import create_db_search_doc
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_chat_message
@@ -160,7 +160,7 @@ def stream_chat_message_objects(
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
     default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     # For flow with search, don't include as many chunks as possible since we need to leave space
     # for the chat history, for smaller models, we likely won't get MAX_CHUNKS_FED_TO_CHAT chunks
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
@@ -468,6 +468,7 @@ def stream_chat_message_objects(
         chunks=top_chunks,
         llm_chunk_selection=llm_chunk_selection,
         token_limit=chunk_token_limit,
+        llm_tokenizer=llm_tokenizer,
     )
     llm_chunks = [top_chunks[i] for i in llm_chunks_indices]
     llm_docs = [llm_doc_from_inference_chunk(chunk) for chunk in llm_chunks]

@@ -3,7 +3,6 @@ import os
 #####
 # Embedding/Reranking Model Configs
 #####
-CHUNK_SIZE = 512
 # Important considerations when choosing models
 # Max tokens count needs to be high considering use case (at least 512)
 # Models used must be MIT or Apache license
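
With trimming keyed to the embedding context, the standalone CHUNK_SIZE constant becomes redundant: the remaining hunks switch its users in the chunker, chat, and one-shot answer flows over to DOC_EMBEDDING_CONTEXT_SIZE (presumably the same 512-token value), leaving a single source of truth for chunk sizing.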

@@ -7,7 +7,7 @@ from danswer.configs.app_configs import CHUNK_OVERLAP
 from danswer.configs.app_configs import MINI_CHUNK_SIZE
 from danswer.configs.constants import SECTION_SEPARATOR
 from danswer.configs.constants import TITLE_SEPARATOR
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.connectors.models import Document
 from danswer.indexing.models import DocAwareChunk
 from danswer.search.search_nlp_models import get_default_tokenizer
@@ -37,7 +37,7 @@ def chunk_large_section(
     document: Document,
     start_chunk_id: int,
     tokenizer: "AutoTokenizer",
-    chunk_size: int = CHUNK_SIZE,
+    chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     chunk_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:
@@ -67,7 +67,7 @@ def chunk_large_section(

 def chunk_document(
     document: Document,
-    chunk_tok_size: int = CHUNK_SIZE,
+    chunk_tok_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     subsection_overlap: int = CHUNK_OVERLAP,
     blurb_size: int = BLURB_SIZE,
 ) -> list[DocAwareChunk]:

@@ -1,3 +1,5 @@
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
 from typing import Any
 from typing import cast

@@ -130,8 +132,124 @@ def include_router_with_global_prefix_prepended(
     application.include_router(router, **final_kwargs)


+@asynccontextmanager
+async def lifespan(app: FastAPI) -> AsyncGenerator:
+    engine = get_sqlalchemy_engine()
+
+    verify_auth = fetch_versioned_implementation(
+        "danswer.auth.users", "verify_auth_setting"
+    )
+    # Will throw exception if an issue is found
+    verify_auth()
+
+    # Danswer APIs key
+    api_key = get_danswer_api_key()
+    logger.info(f"Danswer API Key: {api_key}")
+
+    if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
+        logger.info("Both OAuth Client ID and Secret are configured.")
+
+    if DISABLE_GENERATIVE_AI:
+        logger.info("Generative AI Q&A disabled")
+    else:
+        logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
+        base, fast = get_default_llm_version()
+        logger.info(f"Using LLM Model Version: {base}")
+        if base != fast:
+            logger.info(f"Using Fast LLM Model Version: {fast}")
+        if GEN_AI_API_ENDPOINT:
+            logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
+
+    # Any additional model configs logged here
+    get_default_llm().log_model_configs()
+
+    if MULTILINGUAL_QUERY_EXPANSION:
+        logger.info(
+            f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
+        )
+
+    with Session(engine) as db_session:
+        db_embedding_model = get_current_db_embedding_model(db_session)
+        secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
+
+        # Break bad state for thrashing indexes
+        if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
+            expire_index_attempts(
+                embedding_model_id=db_embedding_model.id, db_session=db_session
+            )
+
+            for cc_pair in get_connector_credential_pairs(db_session):
+                resync_cc_pair(cc_pair, db_session=db_session)
+
+        # Expire all old embedding models indexing attempts, technically redundant
+        cancel_indexing_attempts_past_model(db_session)
+
+        logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
+        if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
+            logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"')
+            logger.info(
+                f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
+            )
+
+        if ENABLE_RERANKING_REAL_TIME_FLOW:
+            logger.info("Reranking step of search flow is enabled.")
+
+        if MODEL_SERVER_HOST:
+            logger.info(
+                f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
+            )
+        else:
+            logger.info("Warming up local NLP models.")
+            warm_up_models(
+                model_name=db_embedding_model.model_name,
+                normalize=db_embedding_model.normalize,
+                skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
+            )
+
+            if torch.cuda.is_available():
+                logger.info("GPU is available")
+            else:
+                logger.info("GPU is not available")
+            logger.info(f"Torch Threads: {torch.get_num_threads()}")
+
+        logger.info("Verifying query preprocessing (NLTK) data is downloaded")
+        nltk.download("stopwords", quiet=True)
+        nltk.download("wordnet", quiet=True)
+        nltk.download("punkt", quiet=True)
+
+        logger.info("Verifying default connector/credential exist.")
+        create_initial_public_credential(db_session)
+        create_initial_default_connector(db_session)
+        associate_default_cc_pair(db_session)
+
+        logger.info("Loading default Prompts and Personas")
+        delete_old_default_personas(db_session)
+        load_chat_yamls()
+
+        logger.info("Verifying Document Index(s) is/are available.")
+
+        document_index = get_default_document_index(
+            primary_index_name=db_embedding_model.index_name,
+            secondary_index_name=secondary_db_embedding_model.index_name
+            if secondary_db_embedding_model
+            else None,
+        )
+        document_index.ensure_indices_exist(
+            index_embedding_dim=db_embedding_model.model_dim,
+            secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
+            if secondary_db_embedding_model
+            else None,
+        )
+
+    optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
+
+    yield
+
+
 def get_application() -> FastAPI:
-    application = FastAPI(title="Danswer Backend", version=__version__)
+    application = FastAPI(
+        title="Danswer Backend", version=__version__, lifespan=lifespan
+    )

     include_router_with_global_prefix_prepended(application, chat_router)
     include_router_with_global_prefix_prepended(application, query_router)
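
The lifespan body above is the former startup_event hook moved essentially verbatim, dedented one level out of get_application, with the optional_telemetry call collapsed onto one line; the old hook is deleted in the next hunk.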

@@ -220,121 +338,6 @@ def get_application() -> FastAPI:

     application.add_exception_handler(ValueError, value_error_handler)

-    @application.on_event("startup")
-    def startup_event() -> None:
-        engine = get_sqlalchemy_engine()
-
-        verify_auth = fetch_versioned_implementation(
-            "danswer.auth.users", "verify_auth_setting"
-        )
-        # Will throw exception if an issue is found
-        verify_auth()
-
-        # Danswer APIs key
-        api_key = get_danswer_api_key()
-        logger.info(f"Danswer API Key: {api_key}")
-
-        if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET:
-            logger.info("Both OAuth Client ID and Secret are configured.")
-
-        if DISABLE_GENERATIVE_AI:
-            logger.info("Generative AI Q&A disabled")
-        else:
-            logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
-            base, fast = get_default_llm_version()
-            logger.info(f"Using LLM Model Version: {base}")
-            if base != fast:
-                logger.info(f"Using Fast LLM Model Version: {fast}")
-            if GEN_AI_API_ENDPOINT:
-                logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
-
-        # Any additional model configs logged here
-        get_default_llm().log_model_configs()
-
-        if MULTILINGUAL_QUERY_EXPANSION:
-            logger.info(
-                f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
-            )
-
-        with Session(engine) as db_session:
-            db_embedding_model = get_current_db_embedding_model(db_session)
-            secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
-
-            # Break bad state for thrashing indexes
-            if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
-                expire_index_attempts(
-                    embedding_model_id=db_embedding_model.id, db_session=db_session
-                )
-
-                for cc_pair in get_connector_credential_pairs(db_session):
-                    resync_cc_pair(cc_pair, db_session=db_session)
-
-            # Expire all old embedding models indexing attempts, technically redundant
-            cancel_indexing_attempts_past_model(db_session)
-
-            logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
-            if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
-                logger.info(
-                    f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
-                )
-                logger.info(
-                    f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
-                )
-
-            if ENABLE_RERANKING_REAL_TIME_FLOW:
-                logger.info("Reranking step of search flow is enabled.")
-
-            if MODEL_SERVER_HOST:
-                logger.info(
-                    f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
-                )
-            else:
-                logger.info("Warming up local NLP models.")
-                warm_up_models(
-                    model_name=db_embedding_model.model_name,
-                    normalize=db_embedding_model.normalize,
-                    skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
-                )
-
-                if torch.cuda.is_available():
-                    logger.info("GPU is available")
-                else:
-                    logger.info("GPU is not available")
-                logger.info(f"Torch Threads: {torch.get_num_threads()}")
-
-            logger.info("Verifying query preprocessing (NLTK) data is downloaded")
-            nltk.download("stopwords", quiet=True)
-            nltk.download("wordnet", quiet=True)
-            nltk.download("punkt", quiet=True)
-
-            logger.info("Verifying default connector/credential exist.")
-            create_initial_public_credential(db_session)
-            create_initial_default_connector(db_session)
-            associate_default_cc_pair(db_session)
-
-            logger.info("Loading default Prompts and Personas")
-            delete_old_default_personas(db_session)
-            load_chat_yamls()
-
-            logger.info("Verifying Document Index(s) is/are available.")
-
-            document_index = get_default_document_index(
-                primary_index_name=db_embedding_model.index_name,
-                secondary_index_name=secondary_db_embedding_model.index_name
-                if secondary_db_embedding_model
-                else None,
-            )
-            document_index.ensure_indices_exist(
-                index_embedding_dim=db_embedding_model.model_dim,
-                secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
-                if secondary_db_embedding_model
-                else None,
-            )
-
-        optional_telemetry(
-            record_type=RecordType.VERSION, data={"version": __version__}
-        )
-
     application.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],  # Change this to the list of allowed origins if needed
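
This follows FastAPI's standard lifespan pattern, which supersedes the @app.on_event("startup") hook. A self-contained sketch of the pattern (names here are illustrative, not from the repository):

from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    # Everything before `yield` runs once, before the app starts serving
    # (the commit moves all former startup work here).
    print("startup: warm up models, verify indices, etc.")
    yield
    # Everything after `yield` runs at shutdown; this commit needs nothing here.

app = FastAPI(title="Example", lifespan=lifespan)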

@@ -18,7 +18,7 @@ from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from danswer.configs.chat_configs import QA_TIMEOUT
 from danswer.configs.constants import MessageType
-from danswer.configs.model_configs import CHUNK_SIZE
+from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
 from danswer.db.chat import create_chat_session
 from danswer.db.chat import create_new_chat_message
 from danswer.db.chat import get_or_create_root_message
@@ -63,7 +63,7 @@ def stream_answer_objects(
     db_session: Session,
     # Needed to translate persona num_chunks to tokens to the LLM
     default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
-    default_chunk_size: int = CHUNK_SIZE,
+    default_chunk_size: int = DOC_EMBEDDING_CONTEXT_SIZE,
     timeout: int = QA_TIMEOUT,
     bypass_acl: bool = False,
     retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]

@@ -1,5 +1,4 @@
 import gc
-import logging
 import os
 from enum import Enum
 from typing import Optional
@@ -7,6 +6,7 @@ from typing import TYPE_CHECKING

 import numpy as np
 import requests
+from transformers import logging as transformer_logging  # type:ignore

 from danswer.configs.app_configs import MODEL_SERVER_HOST
 from danswer.configs.app_configs import MODEL_SERVER_PORT
@@ -26,10 +26,10 @@ from shared_models.model_server_models import RerankResponse


 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

 logger = setup_logger()
-# Remove useless info about layer initialization
-logging.getLogger("transformers").setLevel(logging.ERROR)
+transformer_logging.set_verbosity_error()


 if TYPE_CHECKING:
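
Two small cleanups ride along in the NLP model module: the HF_HUB_DISABLE_TELEMETRY=1 environment variable turns off Hugging Face Hub telemetry, and transformers' own helper (transformer_logging.set_verbosity_error()) replaces reaching into the stdlib logging module to silence layer-initialization noise, which is what lets the import logging line go.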

@@ -18,7 +18,9 @@ FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])


 def log_function_time(
-    func_name: str | None = None, print_only: bool = False
+    func_name: str | None = None,
+    print_only: bool = False,
+    debug_only: bool = False,
 ) -> Callable[[F], F]:
     def decorator(func: F) -> F:
         @wraps(func)
@@ -28,7 +30,10 @@ def log_function_time(
             result = func(*args, **kwargs)
             elapsed_time_str = str(time.time() - start_time)
             log_name = func_name or func.__name__
-            logger.info(f"{log_name} took {elapsed_time_str} seconds")
+            if debug_only:
+                logger.debug(f"{log_name} took {elapsed_time_str} seconds")
+            else:
+                logger.info(f"{log_name} took {elapsed_time_str} seconds")

             if not print_only:
                 optional_telemetry(
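
With the new debug_only flag, hot-path helpers can demote their timing logs from INFO to DEBUG. A hypothetical usage sketch (the decorator and flag are from the diff; the module path and decorated function are assumptions):

from danswer.utils.timing import log_function_time

@log_function_time(debug_only=True)
def rerank_chunks() -> None:
    ...  # elapsed time now logged at DEBUG instead of INFO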