From 4b45164496135fb9965b00637e99f098b0b7f69b Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Sun, 28 Jan 2024 23:14:20 -0800 Subject: [PATCH] Background Index Attempt Creation (#1010) --- .../versions/dbaa756c2ccf_embedding_models.py | 43 ++++ .../background/indexing/run_indexing.py | 111 +++++++--- backend/danswer/background/update.py | 176 +++++++++++++--- backend/danswer/chat/process_message.py | 7 +- backend/danswer/configs/model_configs.py | 5 +- backend/danswer/danswerbot/slack/listener.py | 14 +- .../danswer/db/connector_credential_pair.py | 53 +++++ backend/danswer/db/embedding_model.py | 32 ++- backend/danswer/db/index_attempt.py | 91 +++++++- backend/danswer/db/models.py | 6 +- .../document_index/document_index_utils.py | 46 +--- backend/danswer/document_index/interfaces.py | 4 +- .../vespa/app_config/schemas/danswer_chunk.sd | 4 +- .../vespa/app_config/validation-overrides.xml | 5 + backend/danswer/document_index/vespa/index.py | 25 ++- backend/danswer/indexing/embedder.py | 199 ++++++++++++------ backend/danswer/indexing/indexing_pipeline.py | 21 +- backend/danswer/main.py | 62 +++--- .../one_shot_answer/answer_question.py | 7 +- backend/danswer/search/danswer_helper.py | 3 +- backend/danswer/search/models.py | 7 - backend/danswer/search/search_nlp_models.py | 121 ++++++++--- backend/danswer/search/search_runner.py | 87 +++++--- .../danswer/server/danswer_api/ingestion.py | 35 ++- backend/danswer/server/documents/connector.py | 31 ++- backend/danswer/server/documents/document.py | 10 +- backend/danswer/server/gpts/api.py | 13 +- .../danswer/server/manage/secondary_index.py | 92 +++++++- backend/danswer/server/models.py | 4 + .../server/query_and_chat/query_backend.py | 15 +- backend/model_server/encoders.py | 50 +++-- backend/model_server/main.py | 2 - backend/shared_models/model_server_models.py | 2 + .../regression/search_quality/eval_search.py | 7 +- .../docker_compose/docker-compose.dev.yml | 2 +- 35 files changed, 1022 insertions(+), 370 deletions(-) create mode 100644 backend/danswer/document_index/vespa/app_config/validation-overrides.xml diff --git a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py index 42135664578..592614176ca 100644 --- a/backend/alembic/versions/dbaa756c2ccf_embedding_models.py +++ b/backend/alembic/versions/dbaa756c2ccf_embedding_models.py @@ -7,7 +7,13 @@ Create Date: 2024-01-25 17:12:31.813160 """ from alembic import op import sqlalchemy as sa +from sqlalchemy import table, column, String, Integer, Boolean +from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL +from danswer.configs.model_configs import DOC_EMBEDDING_DIM +from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS +from danswer.configs.model_configs import ASYM_QUERY_PREFIX +from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX from danswer.db.models import IndexModelStatus # revision identifiers, used by Alembic. 
@@ -26,6 +32,7 @@ def upgrade() -> None: sa.Column("normalize", sa.Boolean(), nullable=False), sa.Column("query_prefix", sa.String(), nullable=False), sa.Column("passage_prefix", sa.String(), nullable=False), + sa.Column("index_name", sa.String(), nullable=False), sa.Column( "status", sa.Enum(IndexModelStatus, native=False), @@ -33,10 +40,46 @@ def upgrade() -> None: ), sa.PrimaryKeyConstraint("id"), ) + EmbeddingModel = table( + "embedding_model", + column("id", Integer), + column("model_name", String), + column("model_dim", Integer), + column("normalize", Boolean), + column("query_prefix", String), + column("passage_prefix", String), + column("index_name", String), + column( + "status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False) + ), + ) + op.bulk_insert( + EmbeddingModel, + [ + { + "model_name": DOCUMENT_ENCODER_MODEL, + "model_dim": DOC_EMBEDDING_DIM, + "normalize": NORMALIZE_EMBEDDINGS, + "query_prefix": ASYM_QUERY_PREFIX, + "passage_prefix": ASYM_PASSAGE_PREFIX, + "index_name": "danswer_chunk", + "status": IndexModelStatus.PRESENT, + } + ], + ) op.add_column( "index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True), ) + op.execute( + "UPDATE index_attempt SET embedding_model_id=1 WHERE embedding_model_id IS NULL" + ) + op.alter_column( + "index_attempt", + "embedding_model_id", + existing_type=sa.Integer(), + nullable=False, + ) op.create_foreign_key( "index_attempt__embedding_model_fk", "index_attempt", diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index 496e6a15572..ebf752b0157 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -26,6 +26,9 @@ from danswer.db.index_attempt import mark_attempt_succeeded from danswer.db.index_attempt import update_docs_indexed from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus +from danswer.db.models import IndexModelStatus +from danswer.document_index.factory import get_default_document_index +from danswer.indexing.embedder import DefaultIndexingEmbedder from danswer.indexing.indexing_pipeline import build_indexing_pipeline from danswer.utils.logger import IndexAttemptSingleton from danswer.utils.logger import setup_logger @@ -93,17 +96,40 @@ def _run_indexing( """ start_time = time.time() - # mark as started + db_embedding_model = index_attempt.embedding_model + index_name = db_embedding_model.index_name + + # Only update cc-pair status for primary index jobs + # Secondary index syncs at the end when swapping + is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT + + # Mark as started mark_attempt_in_progress(index_attempt, db_session) - update_connector_credential_pair( - db_session=db_session, - connector_id=index_attempt.connector.id, - credential_id=index_attempt.credential.id, - attempt_status=IndexingStatus.IN_PROGRESS, + if is_primary: + update_connector_credential_pair( + db_session=db_session, + connector_id=index_attempt.connector.id, + credential_id=index_attempt.credential.id, + attempt_status=IndexingStatus.IN_PROGRESS, + ) + + # Indexing is only done into one index at a time + document_index = get_default_document_index( + primary_index_name=index_name, secondary_index_name=None ) - # TODO UPDATE THIS FOR SECONDARY INDEXING - indexing_pipeline = build_indexing_pipeline() + embedding_model = DefaultIndexingEmbedder( + model_name=db_embedding_model.model_name, + 
normalize=db_embedding_model.normalize, + query_prefix=db_embedding_model.query_prefix, + passage_prefix=db_embedding_model.passage_prefix, + ) + + indexing_pipeline = build_indexing_pipeline( + embedder=embedding_model, + document_index=document_index, + ignore_time_skip=(db_embedding_model.status == IndexModelStatus.FUTURE), + ) db_connector = index_attempt.connector db_credential = index_attempt.credential @@ -139,12 +165,22 @@ def _run_indexing( try: for doc_batch in doc_batch_generator: - # check if connector is disabled mid run and stop if so + # Check if connector is disabled mid run and stop if so unless it's the secondary + # index being built. We want to populate it even for paused connectors + # Often paused connectors are sources that aren't updated frequently but the + # contents still need to be initially pulled. db_session.refresh(db_connector) - if db_connector.disabled: + if ( + db_connector.disabled + and db_embedding_model.status != IndexModelStatus.FUTURE + ): # let the `except` block handle this raise RuntimeError("Connector was disabled mid run") + db_session.refresh(index_attempt) + if index_attempt.status != IndexingStatus.IN_PROGRESS: + raise RuntimeError("Index Attempt was canceled") + logger.debug( f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}" ) @@ -176,14 +212,15 @@ def _run_indexing( ) run_end_dt = window_end - update_connector_credential_pair( - db_session=db_session, - connector_id=db_connector.id, - credential_id=db_credential.id, - attempt_status=IndexingStatus.IN_PROGRESS, - net_docs=net_doc_change, - run_dt=run_end_dt, - ) + if is_primary: + update_connector_credential_pair( + db_session=db_session, + connector_id=db_connector.id, + credential_id=db_credential.id, + attempt_status=IndexingStatus.IN_PROGRESS, + net_docs=net_doc_change, + run_dt=run_end_dt, + ) except Exception as e: logger.info( f"Connector run ran into exception after elapsed time: {time.time() - start_time} seconds" @@ -195,15 +232,20 @@ def _run_indexing( # # NOTE: if the connector is manually disabled, we should mark it as a failure regardless # to give better clarity in the UI, as the next run will never happen. - if ind == 0 or db_connector.disabled: + if ( + ind == 0 + or db_connector.disabled + or index_attempt.status != IndexingStatus.IN_PROGRESS + ): mark_attempt_failed(index_attempt, db_session, failure_reason=str(e)) - update_connector_credential_pair( - db_session=db_session, - connector_id=index_attempt.connector.id, - credential_id=index_attempt.credential.id, - attempt_status=IndexingStatus.FAILED, - net_docs=net_doc_change, - ) + if is_primary: + update_connector_credential_pair( + db_session=db_session, + connector_id=index_attempt.connector.id, + credential_id=index_attempt.credential.id, + attempt_status=IndexingStatus.FAILED, + net_docs=net_doc_change, + ) raise e # break => similar to success case. 
As mentioned above, if the next run fails for the same @@ -211,14 +253,15 @@ def _run_indexing( break mark_attempt_succeeded(index_attempt, db_session) - update_connector_credential_pair( - db_session=db_session, - connector_id=db_connector.id, - credential_id=db_credential.id, - attempt_status=IndexingStatus.SUCCESS, - net_docs=net_doc_change, - run_dt=run_end_dt, - ) + if is_primary: + update_connector_credential_pair( + db_session=db_session, + connector_id=db_connector.id, + credential_id=db_credential.id, + attempt_status=IndexingStatus.SUCCESS, + net_docs=net_doc_change, + run_dt=run_end_dt, + ) logger.info( f"Indexed or refreshed {document_count} total documents for a total of {chunk_count} indexed chunks" diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 20c2c33a253..8414d32d302 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -19,10 +19,16 @@ from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.configs.model_configs import MIN_THREADS_ML_MODELS from danswer.db.connector import fetch_connectors +from danswer.db.connector_credential_pair import get_connector_credential_pairs from danswer.db.connector_credential_pair import mark_all_in_progress_cc_pairs_failed +from danswer.db.connector_credential_pair import resync_cc_pair from danswer.db.connector_credential_pair import update_connector_credential_pair +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model +from danswer.db.embedding_model import update_embedding_model_status from danswer.db.engine import get_db_current_time from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_inprogress_index_attempts @@ -30,8 +36,10 @@ from danswer.db.index_attempt import get_last_attempt from danswer.db.index_attempt import get_not_started_index_attempts from danswer.db.index_attempt import mark_attempt_failed from danswer.db.models import Connector +from danswer.db.models import EmbeddingModel from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus +from danswer.db.models import IndexModelStatus from danswer.utils.logger import setup_logger logger = setup_logger() @@ -56,8 +64,17 @@ def _get_num_threads() -> int: def _should_create_new_indexing( - connector: Connector, last_index: IndexAttempt | None, db_session: Session + connector: Connector, + last_index: IndexAttempt | None, + model: EmbeddingModel, + db_session: Session, ) -> bool: + # When switching over models, always index at least once + if model.status == IndexModelStatus.FUTURE and not last_index: + if connector.id == 0: # Ingestion API + return False + return True + if connector.refresh_freq is None: return False if not last_index: @@ -66,6 +83,7 @@ def _should_create_new_indexing( # Only one scheduled job per connector at a time # Can schedule another one if the current one is already running however # Because the currently running one will not be until the latest time + # Note, this last index is for the given embedding model if last_index.status == IndexingStatus.NOT_STARTED: return False @@ -101,6 +119,7 @@ def _mark_run_failed( if ( index_attempt.connector_id is not None 
and index_attempt.credential_id is not None + and index_attempt.embedding_model.status == IndexModelStatus.PRESENT ): update_connector_credential_pair( db_session=db_session, @@ -120,7 +139,7 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None: 3. There is not already an ongoing indexing attempt for this pair """ with Session(get_sqlalchemy_engine()) as db_session: - ongoing_pairs: set[tuple[int | None, int | None]] = set() + ongoing: set[tuple[int | None, int | None, int]] = set() for attempt_id in existing_jobs: attempt = get_index_attempt( db_session=db_session, index_attempt_id=attempt_id @@ -131,28 +150,50 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None: "indexing jobs" ) continue - ongoing_pairs.add((attempt.connector_id, attempt.credential_id)) - - enabled_connectors = fetch_connectors(db_session, disabled_status=False) - for connector in enabled_connectors: - for association in connector.credentials: - credential = association.credential - - # check if there is an ongoing indexing attempt for this connector + credential pair - if (connector.id, credential.id) in ongoing_pairs: - continue - - last_attempt = get_last_attempt(connector.id, credential.id, db_session) - if not _should_create_new_indexing(connector, last_attempt, db_session): - continue - create_index_attempt(connector.id, credential.id, db_session) - - update_connector_credential_pair( - db_session=db_session, - connector_id=connector.id, - credential_id=credential.id, - attempt_status=IndexingStatus.NOT_STARTED, + ongoing.add( + ( + attempt.connector_id, + attempt.credential_id, + attempt.embedding_model_id, ) + ) + + embedding_models = [get_current_db_embedding_model(db_session)] + secondary_embedding_model = get_secondary_db_embedding_model(db_session) + if secondary_embedding_model is not None: + embedding_models.append(secondary_embedding_model) + + all_connectors = fetch_connectors(db_session) + for connector in all_connectors: + for association in connector.credentials: + for model in embedding_models: + credential = association.credential + + # Check if there is an ongoing indexing attempt for this connector + credential pair + if (connector.id, credential.id, model.id) in ongoing: + continue + + last_attempt = get_last_attempt( + connector.id, credential.id, model.id, db_session + ) + if not _should_create_new_indexing( + connector, last_attempt, model, db_session + ): + continue + + create_index_attempt( + connector.id, credential.id, model.id, db_session + ) + + # CC-Pair will have the status that it should for the primary index + # Will be re-sync-ed once the indices are swapped + if model.status == IndexModelStatus.PRESENT: + update_connector_credential_pair( + db_session=db_session, + connector_id=connector.id, + credential_id=credential.id, + attempt_status=IndexingStatus.NOT_STARTED, + ) def cleanup_indexing_jobs( @@ -233,6 +274,7 @@ def cleanup_indexing_jobs( def kickoff_indexing_jobs( existing_jobs: dict[int, Future | SimpleJob], client: Client | SimpleJobClient, + secondary_client: Client | SimpleJobClient, ) -> dict[int, Future | SimpleJob]: existing_jobs_copy = existing_jobs.copy() engine = get_sqlalchemy_engine() @@ -241,7 +283,7 @@ def kickoff_indexing_jobs( # Also (rarely) don't include for jobs that started but haven't updated the indexing tables yet with Session(engine) as db_session: new_indexing_attempts = [ - attempt + (attempt, attempt.embedding_model) for attempt in get_not_started_index_attempts(db_session) if attempt.id not 
in existing_jobs ] @@ -251,7 +293,12 @@ def kickoff_indexing_jobs( if not new_indexing_attempts: return existing_jobs - for attempt in new_indexing_attempts: + for attempt, embedding_model in new_indexing_attempts: + use_secondary_index = ( + embedding_model.status == IndexModelStatus.FUTURE + if embedding_model is not None + else False + ) if attempt.connector is None: logger.warning( f"Skipping index attempt as Connector has been deleted: {attempt}" @@ -271,12 +318,20 @@ def kickoff_indexing_jobs( ) continue - run = client.submit( - run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False - ) + if use_secondary_index: + run = secondary_client.submit( + run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False + ) + else: + run = client.submit( + run_indexing_entrypoint, attempt.id, _get_num_threads(), pure=False + ) + if run: + secondary_str = "(secondary index) " if use_secondary_index else "" logger.info( - f"Kicked off indexing attempt for connector: '{attempt.connector.name}', " + f"Kicked off {secondary_str}" + f"indexing attempt for connector: '{attempt.connector.name}', " f"with config: '{attempt.connector.connector_specific_config}', and " f"with credentials: '{attempt.credential_id}'" ) @@ -285,10 +340,50 @@ def kickoff_indexing_jobs( return existing_jobs_copy +def check_index_swap(db_session: Session) -> None: + """Get count of cc-pairs and count of index_attempts for the new model grouped by + connector + credential, if it's the same, then assume new index is done building. + This does not take into consideration if the attempt failed or not""" + # Default CC-pair created for Ingestion API unused here + all_cc_pairs = get_connector_credential_pairs(db_session) + cc_pair_count = len(all_cc_pairs) - 1 + embedding_model = get_secondary_db_embedding_model(db_session) + + if not embedding_model: + return + + unique_cc_indexings = count_unique_cc_pairs_with_index_attempts( + embedding_model_id=embedding_model.id, db_session=db_session + ) + + if unique_cc_indexings > cc_pair_count: + raise RuntimeError("More unique indexings than cc pairs, should not occur") + + if cc_pair_count == unique_cc_indexings: + # Swap indices + now_old_embedding_model = get_current_db_embedding_model(db_session) + update_embedding_model_status( + embedding_model=now_old_embedding_model, + new_status=IndexModelStatus.PAST, + db_session=db_session, + ) + + update_embedding_model_status( + embedding_model=embedding_model, + new_status=IndexModelStatus.PRESENT, + db_session=db_session, + ) + + # Recount aggregates + for cc_pair in all_cc_pairs: + resync_cc_pair(cc_pair, db_session=db_session) + + def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> None: - client: Client | SimpleJobClient + client_primary: Client | SimpleJobClient + client_secondary: Client | SimpleJobClient if DASK_JOB_CLIENT_ENABLED: - cluster = LocalCluster( + cluster_primary = LocalCluster( n_workers=num_workers, threads_per_worker=1, # there are warning about high memory usage + "Event loop unresponsive" @@ -297,11 +392,18 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non # the event loop silence_logs=logging.ERROR, ) - client = Client(cluster) + cluster_secondary = LocalCluster( + n_workers=num_workers, + threads_per_worker=1, + silence_logs=logging.ERROR, + ) + client_primary = Client(cluster_primary) + client_secondary = Client(cluster_secondary) if LOG_LEVEL.lower() == "debug": - client.register_worker_plugin(ResourceLogger()) + 
client_primary.register_worker_plugin(ResourceLogger()) else: - client = SimpleJobClient(n_workers=num_workers) + client_primary = SimpleJobClient(n_workers=num_workers) + client_secondary = SimpleJobClient(n_workers=num_workers) existing_jobs: dict[int, Future | SimpleJob] = {} engine = get_sqlalchemy_engine() @@ -324,10 +426,14 @@ def update_loop(delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS) -> Non ) try: + with Session(get_sqlalchemy_engine()) as db_session: + check_index_swap(db_session) existing_jobs = cleanup_indexing_jobs(existing_jobs=existing_jobs) create_indexing_jobs(existing_jobs=existing_jobs) existing_jobs = kickoff_indexing_jobs( - existing_jobs=existing_jobs, client=client + existing_jobs=existing_jobs, + client=client_primary, + secondary_client=client_secondary, ) except Exception as e: logger.exception(f"Failed to run update due to {e}") diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py index f1283482e10..4bde92cbd0d 100644 --- a/backend/danswer/chat/process_message.py +++ b/backend/danswer/chat/process_message.py @@ -32,11 +32,11 @@ from danswer.db.chat import get_doc_query_identifiers_from_model from danswer.db.chat import get_or_create_root_message from danswer.db.chat import translate_db_message_to_chat_message_detail from danswer.db.chat import translate_db_search_doc_to_server_search_doc +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.models import ChatMessage from danswer.db.models import Persona from danswer.db.models import SearchDoc as DbSearchDoc from danswer.db.models import User -from danswer.document_index.document_index_utils import get_index_name from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk from danswer.llm.exceptions import GenAIDisabledException @@ -196,8 +196,10 @@ def stream_chat_message( llm = None llm_tokenizer = get_default_llm_token_encode() + + embedding_model = get_current_db_embedding_model(db_session) document_index = get_default_document_index( - primary_index_name=get_index_name(db_session), secondary_index_name=None + primary_index_name=embedding_model.index_name, secondary_index_name=None ) # Every chat Session begins with an empty root message @@ -308,6 +310,7 @@ def stream_chat_message( documents_generator = full_chunk_search_generator( search_query=retrieval_request, document_index=document_index, + db_session=db_session, ) time_cutoff = retrieval_request.filters.time_cutoff recency_bias_multiplier = retrieval_request.recency_bias_multiplier diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index 0edfa1d304e..b1b2725e398 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -11,7 +11,10 @@ CHUNK_SIZE = 512 # https://huggingface.co/DOCUMENT_ENCODER_MODEL # The useable models configured as below must be SentenceTransformer compatible DOCUMENT_ENCODER_MODEL = ( - os.environ.get("DOCUMENT_ENCODER_MODEL") or "thenlper/gte-small" + # This is not a good model anymore, but this default needs to be kept for not breaking existing + # deployments, will eventually be retired/swapped for a different default model + os.environ.get("DOCUMENT_ENCODER_MODEL") + or "thenlper/gte-small" ) # If the below is changed, Vespa deployment must also be changed DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 384) diff --git a/backend/danswer/danswerbot/slack/listener.py 
b/backend/danswer/danswerbot/slack/listener.py index bcdb9a5503d..4f8404bac72 100644 --- a/backend/danswer/danswerbot/slack/listener.py +++ b/backend/danswer/danswerbot/slack/listener.py @@ -1,9 +1,9 @@ -import nltk import time from threading import Event from typing import Any from typing import cast +import nltk # type: ignore from slack_sdk import WebClient from slack_sdk.socket_mode import SocketModeClient from slack_sdk.socket_mode.request import SocketModeRequest @@ -38,6 +38,7 @@ from danswer.danswerbot.slack.utils import get_danswer_bot_app_id from danswer.danswerbot.slack.utils import read_slack_thread from danswer.danswerbot.slack.utils import remove_danswer_bot_tag from danswer.danswerbot.slack.utils import respond_in_thread +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_sqlalchemy_engine from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.one_shot_answer.models import ThreadMessage @@ -354,14 +355,21 @@ def _initialize_socket_client(socket_client: SocketModeClient) -> None: # NOTE: we are using Web Sockets so that you can run this from within a firewalled VPC # without issue. if __name__ == "__main__": - warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW) + with Session(get_sqlalchemy_engine()) as db_session: + embedding_model = get_current_db_embedding_model(db_session) + + warm_up_models( + model_name=embedding_model.model_name, + normalize=embedding_model.normalize, + skip_cross_encoders=not ENABLE_RERANKING_ASYNC_FLOW, + ) slack_bot_tokens: SlackBotTokens | None = None socket_client: SocketModeClient | None = None logger.info("Verifying query preprocessing (NLTK) data is downloaded") nltk.download("stopwords", quiet=True) - nltk.download('punkt', quiet=True) + nltk.download("punkt", quiet=True) while True: try: diff --git a/backend/danswer/db/connector_credential_pair.py b/backend/danswer/db/connector_credential_pair.py index 3a21ad31fb3..dd8da7856e7 100644 --- a/backend/danswer/db/connector_credential_pair.py +++ b/backend/danswer/db/connector_credential_pair.py @@ -2,6 +2,7 @@ from datetime import datetime from fastapi import HTTPException from sqlalchemy import delete +from sqlalchemy import desc from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session @@ -9,7 +10,10 @@ from sqlalchemy.orm import Session from danswer.db.connector import fetch_connector_by_id from danswer.db.credentials import fetch_credential_by_id from danswer.db.models import ConnectorCredentialPair +from danswer.db.models import EmbeddingModel +from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus +from danswer.db.models import IndexModelStatus from danswer.db.models import User from danswer.server.models import StatusResponse from danswer.utils.logger import setup_logger @@ -233,3 +237,52 @@ def remove_credential_from_connector( message=f"Connector already does not have Credential {credential_id}", data=connector_id, ) + + +def resync_cc_pair( + cc_pair: ConnectorCredentialPair, + db_session: Session, +) -> None: + def find_latest_index_attempt( + connector_id: int, + credential_id: int, + only_include_success: bool, + db_session: Session, + ) -> IndexAttempt | None: + query = ( + db_session.query(IndexAttempt) + .join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id) + .filter( + IndexAttempt.connector_id == connector_id, + IndexAttempt.credential_id == credential_id, + EmbeddingModel.status == 
IndexModelStatus.PRESENT, + ) + ) + + if only_include_success: + query = query.filter(IndexAttempt.status == IndexingStatus.SUCCESS) + + latest_index_attempt = query.order_by(desc(IndexAttempt.time_updated)).first() + + return latest_index_attempt + + last_success = find_latest_index_attempt( + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + only_include_success=True, + db_session=db_session, + ) + + cc_pair.last_successful_index_time = ( + last_success.time_updated if last_success else None + ) + + last_run = find_latest_index_attempt( + connector_id=cc_pair.connector_id, + credential_id=cc_pair.credential_id, + only_include_success=False, + db_session=db_session, + ) + cc_pair.last_attempt_status = last_run.status if last_run else None + + db_session.commit() diff --git a/backend/danswer/db/embedding_model.py b/backend/danswer/db/embedding_model.py index abacc9c8556..79076331069 100644 --- a/backend/danswer/db/embedding_model.py +++ b/backend/danswer/db/embedding_model.py @@ -4,6 +4,7 @@ from sqlalchemy.orm import Session from danswer.db.models import EmbeddingModel from danswer.db.models import IndexModelStatus from danswer.indexing.models import EmbeddingModelDetail +from danswer.search.search_nlp_models import clean_model_name from danswer.utils.logger import setup_logger logger = setup_logger() @@ -21,6 +22,9 @@ def create_embedding_model( query_prefix=model_details.query_prefix, passage_prefix=model_details.passage_prefix, status=status, + # Every single embedding model except the initial one from migrations has this name + # The initial one from migration is called "danswer_chunk" + index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}", ) db_session.add(embedding_model) @@ -29,15 +33,35 @@ def create_embedding_model( return embedding_model -def get_latest_embedding_model_by_status( - status: IndexModelStatus, db_session: Session -) -> EmbeddingModel | None: +def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel: query = ( select(EmbeddingModel) - .where(EmbeddingModel.status == status) + .where(EmbeddingModel.status == IndexModelStatus.PRESENT) + .order_by(EmbeddingModel.id.desc()) + ) + result = db_session.execute(query) + latest_model = result.scalars().first() + + if not latest_model: + raise RuntimeError("No embedding model selected, DB is not in a valid state") + + return latest_model + + +def get_secondary_db_embedding_model(db_session: Session) -> EmbeddingModel | None: + query = ( + select(EmbeddingModel) + .where(EmbeddingModel.status == IndexModelStatus.FUTURE) .order_by(EmbeddingModel.id.desc()) ) result = db_session.execute(query) latest_model = result.scalars().first() return latest_model + + +def update_embedding_model_status( + embedding_model: EmbeddingModel, new_status: IndexModelStatus, db_session: Session +) -> None: + embedding_model.status = new_status + db_session.commit() diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index 04f3c58a6ba..e478b17b3ad 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -7,11 +7,14 @@ from sqlalchemy import desc from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import select +from sqlalchemy import update from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session +from danswer.db.models import EmbeddingModel from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus +from danswer.db.models import IndexModelStatus from 
danswer.server.documents.models import ConnectorCredentialPairIdentifier from danswer.utils.logger import setup_logger from danswer.utils.telemetry import optional_telemetry @@ -30,11 +33,13 @@ def get_index_attempt( def create_index_attempt( connector_id: int, credential_id: int, + embedding_model_id: int | None, db_session: Session, ) -> int: new_attempt = IndexAttempt( connector_id=connector_id, credential_id=credential_id, + embedding_model_id=embedding_model_id, status=IndexingStatus.NOT_STARTED, ) db_session.add(new_attempt) @@ -115,11 +120,14 @@ def update_docs_indexed( def get_last_attempt( connector_id: int, credential_id: int, + embedding_model_id: int | None, db_session: Session, ) -> IndexAttempt | None: - stmt = select(IndexAttempt) - stmt = stmt.where(IndexAttempt.connector_id == connector_id) - stmt = stmt.where(IndexAttempt.credential_id == credential_id) + stmt = select(IndexAttempt).where( + IndexAttempt.connector_id == connector_id, + IndexAttempt.credential_id == credential_id, + IndexAttempt.embedding_model_id == embedding_model_id, + ) # Note, the below is using time_created instead of time_updated stmt = stmt.order_by(desc(IndexAttempt.time_created)) @@ -128,13 +136,19 @@ def get_last_attempt( def get_latest_index_attempts( connector_credential_pair_identifiers: list[ConnectorCredentialPairIdentifier], + secondary_index: bool, db_session: Session, ) -> Sequence[IndexAttempt]: ids_stmt = select( IndexAttempt.connector_id, IndexAttempt.credential_id, func.max(IndexAttempt.time_created).label("max_time_created"), - ) + ).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id) + + if secondary_index: + ids_stmt = ids_stmt.where(EmbeddingModel.status == IndexModelStatus.FUTURE) + else: + ids_stmt = ids_stmt.where(EmbeddingModel.status == IndexModelStatus.PRESENT) where_stmts: list[ColumnElement] = [] for connector_credential_pair_identifier in connector_credential_pair_identifiers: @@ -162,12 +176,14 @@ def get_latest_index_attempts( ) .where(IndexAttempt.time_created == ids_subqery.c.max_time_created) ) + return db_session.execute(stmt).scalars().all() def get_index_attempts_for_cc_pair( db_session: Session, cc_pair_identifier: ConnectorCredentialPairIdentifier, + only_current: bool = True, disinclude_finished: bool = False, ) -> Sequence[IndexAttempt]: stmt = select(IndexAttempt).where( @@ -182,6 +198,10 @@ def get_index_attempts_for_cc_pair( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ) ) + if only_current: + stmt = stmt.join(EmbeddingModel).where( + EmbeddingModel.status == IndexModelStatus.PRESENT + ) stmt = stmt.order_by(IndexAttempt.time_created.desc()) return db_session.execute(stmt).scalars().all() @@ -197,3 +217,66 @@ def delete_index_attempts( IndexAttempt.credential_id == credential_id, ) db_session.execute(stmt) + + +def expire_index_attempts( + embedding_model_id: int, + db_session: Session, +) -> None: + update_query = ( + update(IndexAttempt) + .where(IndexAttempt.embedding_model_id == embedding_model_id) + .where(IndexAttempt.status != IndexingStatus.SUCCESS) + .values(status=IndexingStatus.FAILED, error_msg="Embedding model swapped") + ) + db_session.execute(update_query) + db_session.commit() + + +def cancel_indexing_attempts_for_connector( + connector_id: int, + db_session: Session, + include_secondary_index: bool = False, +) -> None: + subquery = select(EmbeddingModel.id).where( + EmbeddingModel.status != IndexModelStatus.FUTURE + ) + + stmt = delete(IndexAttempt).where( + IndexAttempt.connector_id == connector_id, + 
IndexAttempt.status == IndexingStatus.NOT_STARTED, + ) + + if not include_secondary_index: + stmt = stmt.where( + or_( + IndexAttempt.embedding_model_id.is_(None), + IndexAttempt.embedding_model_id.in_(subquery), + ) + ) + + db_session.execute(stmt) + + db_session.commit() + + +def count_unique_cc_pairs_with_index_attempts( + embedding_model_id: int | None, + db_session: Session, +) -> int: + unique_pairs_count = ( + db_session.query(IndexAttempt.connector_id, IndexAttempt.credential_id) + .filter( + IndexAttempt.embedding_model_id == embedding_model_id, + # Should not be able to hang since indexing jobs expire after a limit + # It will then be marked failed, and the next cycle it will be in a completed state + or_( + IndexAttempt.status == IndexingStatus.SUCCESS, + IndexAttempt.status == IndexingStatus.FAILED, + ), + ) + .distinct() + .count() + ) + + return unique_pairs_count diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index fcc46682dab..50867c1c2e2 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -379,6 +379,7 @@ class EmbeddingModel(Base): query_prefix: Mapped[str] = mapped_column(String) passage_prefix: Mapped[str] = mapped_column(String) status: Mapped[IndexModelStatus] = mapped_column(Enum(IndexModelStatus)) + index_name: Mapped[str] = mapped_column(String) index_attempts: Mapped[List["IndexAttempt"]] = relationship( "IndexAttempt", back_populates="embedding_model" @@ -419,15 +420,16 @@ class IndexAttempt(Base): nullable=True, ) status: Mapped[IndexingStatus] = mapped_column(Enum(IndexingStatus)) + # The two below may be slightly out of sync if user switches Embedding Model new_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) total_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0) error_msg: Mapped[str | None] = mapped_column( Text, default=None ) # only filled if status = "failed" # Nullable because in the past, we didn't allow swapping out embedding models live - embedding_model_id: Mapped[int | None] = mapped_column( + embedding_model_id: Mapped[int] = mapped_column( ForeignKey("embedding_model.id"), - nullable=True, + nullable=False, ) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), diff --git a/backend/danswer/document_index/document_index_utils.py b/backend/danswer/document_index/document_index_utils.py index dd369b49920..51e6433cbc4 100644 --- a/backend/danswer/document_index/document_index_utils.py +++ b/backend/danswer/document_index/document_index_utils.py @@ -3,56 +3,24 @@ import uuid from sqlalchemy.orm import Session -from danswer.db.embedding_model import get_latest_embedding_model_by_status -from danswer.db.models import IndexModelStatus +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.indexing.models import IndexChunk from danswer.indexing.models import InferenceChunk DEFAULT_BATCH_SIZE = 30 - - -def clean_model_name(model_str: str) -> str: - return model_str.replace("/", "_").replace("-", "_").replace(".", "_") - - -def get_index_name( - db_session: Session, - secondary_index: bool = False, -) -> str: - if secondary_index: - model = get_latest_embedding_model_by_status( - status=IndexModelStatus.FUTURE, db_session=db_session - ) - if model is None: - raise RuntimeError("No secondary index being built") - return f"danswer_chunk_{clean_model_name(model.model_name)}" - - model = get_latest_embedding_model_by_status( - 
status=IndexModelStatus.PRESENT, db_session=db_session - ) - if not model: - return "danswer_chunk" - return f"danswer_chunk_{clean_model_name(model.model_name)}" +DEFAULT_INDEX_NAME = "danswer_chunk" def get_both_index_names(db_session: Session) -> tuple[str, str | None]: - model = get_latest_embedding_model_by_status( - status=IndexModelStatus.PRESENT, db_session=db_session - ) - curr_index = ( - "danswer_chunk" - if not model - else f"danswer_chunk_{clean_model_name(model.model_name)}" - ) + model = get_current_db_embedding_model(db_session) - model_new = get_latest_embedding_model_by_status( - status=IndexModelStatus.FUTURE, db_session=db_session - ) + model_new = get_secondary_db_embedding_model(db_session) if not model_new: - return curr_index, None + return model.index_name, None - return curr_index, f"danswer_chunk_{clean_model_name(model_new.model_name)}" + return model.index_name, model_new.index_name def translate_boost_count_to_multiplier(boost: int) -> float: diff --git a/backend/danswer/document_index/interfaces.py b/backend/danswer/document_index/interfaces.py index 51fe0366dd7..e528504aaec 100644 --- a/backend/danswer/document_index/interfaces.py +++ b/backend/danswer/document_index/interfaces.py @@ -117,7 +117,8 @@ class VectorCapable(abc.ABC): @abc.abstractmethod def semantic_retrieval( self, - query: str, + query: str, # Needed for matching purposes + query_embedding: list[float], filters: IndexFilters, time_decay_multiplier: float, num_to_retrieve: int, @@ -131,6 +132,7 @@ class HybridCapable(abc.ABC): def hybrid_retrieval( self, query: str, + query_embedding: list[float], filters: IndexFilters, time_decay_multiplier: float, num_to_retrieve: int, diff --git a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd index 2e545530d1a..355918506a0 100644 --- a/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd +++ b/backend/danswer/document_index/vespa/app_config/schemas/danswer_chunk.sd @@ -145,7 +145,7 @@ schema DANSWER_CHUNK_NAME { match-features: recency_bias } - rank-profile hybrid_search inherits default, default_rank { + rank-profile hybrid_searchVARIABLE_DIM inherits default, default_rank { inputs { query(query_embedding) tensor(x[VARIABLE_DIM]) } @@ -227,7 +227,7 @@ schema DANSWER_CHUNK_NAME { match-features: recency_bias document_boost bm25(content) } - rank-profile semantic_search inherits default, default_rank { + rank-profile semantic_searchVARIABLE_DIM inherits default, default_rank { inputs { query(query_embedding) tensor(x[VARIABLE_DIM]) } diff --git a/backend/danswer/document_index/vespa/app_config/validation-overrides.xml b/backend/danswer/document_index/vespa/app_config/validation-overrides.xml new file mode 100644 index 00000000000..58bb2a0ce71 --- /dev/null +++ b/backend/danswer/document_index/vespa/app_config/validation-overrides.xml @@ -0,0 +1,5 @@ + + schema-removal + diff --git a/backend/danswer/document_index/vespa/index.py b/backend/danswer/document_index/vespa/index.py index 192af61b7e6..6cb84dd4c78 100644 --- a/backend/danswer/document_index/vespa/index.py +++ b/backend/danswer/document_index/vespa/index.py @@ -64,7 +64,6 @@ from danswer.document_index.vespa.utils import remove_invalid_unicode_chars from danswer.indexing.models import DocMetadataAwareIndexChunk from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters -from danswer.search.search_runner import embed_query from 
danswer.search.search_runner import query_processing from danswer.search.search_runner import remove_stop_words_and_punctuation from danswer.utils.batching import batch_generator @@ -76,6 +75,7 @@ logger = setup_logger() VESPA_DIM_REPLACEMENT_PAT = "VARIABLE_DIM" DANSWER_CHUNK_REPLACEMENT_PAT = "DANSWER_CHUNK_NAME" DOCUMENT_REPLACEMENT_PAT = "DOCUMENT_REPLACEMENT" +DATE_REPLACEMENT = "DATE_REPLACEMENT" VESPA_CONFIG_SERVER_URL = f"http://{VESPA_HOST}:{VESPA_TENANT_PORT}" VESPA_APP_CONTAINER_URL = f"http://{VESPA_HOST}:{VESPA_PORT}" VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2" @@ -660,6 +660,7 @@ class VespaIndex(DocumentIndex): ) schema_file = os.path.join(vespa_schema_path, "schemas", "danswer_chunk.sd") services_file = os.path.join(vespa_schema_path, "services.xml") + overrides_file = os.path.join(vespa_schema_path, "validation-overrides.xml") with open(services_file, "r") as services_f: services_template = services_f.read() @@ -669,8 +670,20 @@ class VespaIndex(DocumentIndex): doc_lines = _create_document_xml_lines(schema_names) services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines) + with open(overrides_file, "r") as overrides_f: + overrides_template = overrides_f.read() + + # Vespa requires an override to erase data including the indices we're no longer using + # It also has a 30 day cap from current so we set it to 7 dynamically + now = datetime.now() + date_in_7_days = now + timedelta(days=7) + formatted_date = date_in_7_days.strftime("%Y-%m-%d") + + overrides = overrides_template.replace(DATE_REPLACEMENT, formatted_date) + zip_dict = { "services.xml": services.encode("utf-8"), + "validation-overrides.xml": overrides.encode("utf-8"), } with open(schema_file, "r") as schema_f: @@ -887,6 +900,7 @@ class VespaIndex(DocumentIndex): def semantic_retrieval( self, query: str, + query_embedding: list[float], filters: IndexFilters, time_decay_multiplier: float, num_to_retrieve: int = NUM_RETURNED_HITS, @@ -906,8 +920,6 @@ class VespaIndex(DocumentIndex): + f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))' ) - query_embedding = embed_query(query) - query_keywords = ( " ".join(remove_stop_words_and_punctuation(query)) if edit_keyword_query @@ -921,7 +933,7 @@ class VespaIndex(DocumentIndex): "input.query(decay_factor)": str(DOC_TIME_DECAY * time_decay_multiplier), "hits": num_to_retrieve, "offset": offset, - "ranking.profile": "semantic_search", + "ranking.profile": f"hybrid_search{len(query_embedding)}", "timeout": _VESPA_TIMEOUT, } @@ -930,6 +942,7 @@ class VespaIndex(DocumentIndex): def hybrid_retrieval( self, query: str, + query_embedding: list[float], filters: IndexFilters, time_decay_multiplier: float, num_to_retrieve: int, @@ -951,8 +964,6 @@ class VespaIndex(DocumentIndex): + f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))' ) - query_embedding = embed_query(query) - query_keywords = ( " ".join(remove_stop_words_and_punctuation(query)) if edit_keyword_query @@ -972,7 +983,7 @@ class VespaIndex(DocumentIndex): else TITLE_CONTENT_RATIO, "hits": num_to_retrieve, "offset": offset, - "ranking.profile": "hybrid_search", + "ranking.profile": f"hybrid_search{len(query_embedding)}", "timeout": _VESPA_TIMEOUT, } diff --git a/backend/danswer/indexing/embedder.py b/backend/danswer/indexing/embedder.py index 6801fcbbdce..3be10f5b41c 100644 --- a/backend/danswer/indexing/embedder.py +++ b/backend/danswer/indexing/embedder.py @@ -1,96 +1,157 @@ -from typing import Optional -from typing import TYPE_CHECKING +from abc import ABC 
+from abc import abstractmethod + +from sqlalchemy.orm import Session from danswer.configs.app_configs import ENABLE_MINI_CHUNK -from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX +from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST +from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS +from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model +from danswer.db.models import EmbeddingModel as DbEmbeddingModel +from danswer.db.models import IndexModelStatus from danswer.indexing.chunker import split_chunk_text_into_mini_chunks from danswer.indexing.models import ChunkEmbedding from danswer.indexing.models import DocAwareChunk from danswer.indexing.models import IndexChunk -from danswer.search.models import Embedder from danswer.search.search_nlp_models import EmbeddingModel +from danswer.search.search_nlp_models import EmbedTextType from danswer.utils.logger import setup_logger -if TYPE_CHECKING: - from sentence_transformers import SentenceTransformer # type: ignore - logger = setup_logger() -def embed_chunks( - chunks: list[DocAwareChunk], - embedding_model: Optional["SentenceTransformer"] = None, - batch_size: int = BATCH_SIZE_ENCODE_CHUNKS, - enable_mini_chunk: bool = ENABLE_MINI_CHUNK, - passage_prefix: str = ASYM_PASSAGE_PREFIX, -) -> list[IndexChunk]: - # Cache the Title embeddings to only have to do it once - title_embed_dict: dict[str, list[float]] = {} +class IndexingEmbedder(ABC): + def __init__( + self, + model_name: str, + normalize: bool, + query_prefix: str | None, + passage_prefix: str | None, + ): + self.model_name = model_name + self.normalize = normalize + self.query_prefix = query_prefix + self.passage_prefix = passage_prefix - embedded_chunks: list[IndexChunk] = [] - if embedding_model is None: - embedding_model = EmbeddingModel() + @abstractmethod + def embed_chunks(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]: + raise NotImplementedError - chunk_texts = [] - chunk_mini_chunks_count = {} - for chunk_ind, chunk in enumerate(chunks): - chunk_texts.append(passage_prefix + chunk.content) - mini_chunk_texts = ( - split_chunk_text_into_mini_chunks(chunk.content) - if enable_mini_chunk - else [] + +class DefaultIndexingEmbedder(IndexingEmbedder): + def __init__( + self, + model_name: str, + normalize: bool, + query_prefix: str | None, + passage_prefix: str | None, + ): + super().__init__(model_name, normalize, query_prefix, passage_prefix) + self.max_seq_length = DOC_EMBEDDING_CONTEXT_SIZE # Currently not customizable + + self.embedding_model = EmbeddingModel( + model_name=model_name, + query_prefix=query_prefix, + passage_prefix=passage_prefix, + normalize=normalize, + # The below are globally set, this flow always uses the indexing one + server_host=INDEXING_MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, ) - prefixed_mini_chunk_texts = [passage_prefix + text for text in mini_chunk_texts] - chunk_texts.extend(prefixed_mini_chunk_texts) - chunk_mini_chunks_count[chunk_ind] = 1 + len(prefixed_mini_chunk_texts) - text_batches = [ - chunk_texts[i : i + batch_size] for i in range(0, len(chunk_texts), batch_size) - ] + def embed_chunks( + self, + chunks: list[DocAwareChunk], + batch_size: int = BATCH_SIZE_ENCODE_CHUNKS, + enable_mini_chunk: bool = ENABLE_MINI_CHUNK, + ) -> list[IndexChunk]: + # Cache the Title 
embeddings to only have to do it once + title_embed_dict: dict[str, list[float]] = {} + embedded_chunks: list[IndexChunk] = [] - embeddings: list[list[float]] = [] - len_text_batches = len(text_batches) - for idx, text_batch in enumerate(text_batches, start=1): - logger.debug(f"Embedding text batch {idx} of {len_text_batches}") - # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss - embeddings.extend(embedding_model.encode(text_batch)) + chunk_texts = [] + chunk_mini_chunks_count = {} + for chunk_ind, chunk in enumerate(chunks): + chunk_texts.append(chunk.content) + mini_chunk_texts = ( + split_chunk_text_into_mini_chunks(chunk.content) + if enable_mini_chunk + else [] + ) + chunk_texts.extend(mini_chunk_texts) + chunk_mini_chunks_count[chunk_ind] = 1 + len(mini_chunk_texts) - # Replace line above with the line below for easy debugging of indexing flow, skipping the actual model - # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))]) - - embedding_ind_start = 0 - for chunk_ind, chunk in enumerate(chunks): - num_embeddings = chunk_mini_chunks_count[chunk_ind] - chunk_embeddings = embeddings[ - embedding_ind_start : embedding_ind_start + num_embeddings + text_batches = [ + chunk_texts[i : i + batch_size] + for i in range(0, len(chunk_texts), batch_size) ] - title = chunk.source_document.get_title_for_document_index() + embeddings: list[list[float]] = [] + len_text_batches = len(text_batches) + for idx, text_batch in enumerate(text_batches, start=1): + logger.debug(f"Embedding text batch {idx} of {len_text_batches}") + # Normalize embeddings is only configured via model_configs.py, be sure to use right value for the set loss + embeddings.extend( + self.embedding_model.encode(text_batch, text_type=EmbedTextType.PASSAGE) + ) - title_embedding = None - if title: - if title in title_embed_dict: - title_embedding = title_embed_dict[title] - else: - title_embedding = embedding_model.encode([title])[0] - title_embed_dict[title] = title_embedding + # Replace line above with the line below for easy debugging of indexing flow, skipping the actual model + # embeddings.extend([[0.0] * 384 for _ in range(len(text_batch))]) - new_embedded_chunk = IndexChunk( - **{k: getattr(chunk, k) for k in chunk.__dataclass_fields__}, - embeddings=ChunkEmbedding( - full_embedding=chunk_embeddings[0], - mini_chunk_embeddings=chunk_embeddings[1:], - ), - title_embedding=title_embedding, - ) - embedded_chunks.append(new_embedded_chunk) - embedding_ind_start += num_embeddings + embedding_ind_start = 0 + for chunk_ind, chunk in enumerate(chunks): + num_embeddings = chunk_mini_chunks_count[chunk_ind] + chunk_embeddings = embeddings[ + embedding_ind_start : embedding_ind_start + num_embeddings + ] - return embedded_chunks + title = chunk.source_document.get_title_for_document_index() + + title_embedding = None + if title: + if title in title_embed_dict: + # Using cached value for speedup + title_embedding = title_embed_dict[title] + else: + title_embedding = self.embedding_model.encode( + [title], text_type=EmbedTextType.PASSAGE + )[0] + title_embed_dict[title] = title_embedding + + new_embedded_chunk = IndexChunk( + **{k: getattr(chunk, k) for k in chunk.__dataclass_fields__}, + embeddings=ChunkEmbedding( + full_embedding=chunk_embeddings[0], + mini_chunk_embeddings=chunk_embeddings[1:], + ), + title_embedding=title_embedding, + ) + embedded_chunks.append(new_embedded_chunk) + embedding_ind_start += num_embeddings + + return embedded_chunks -class 
DefaultEmbedder(Embedder): - def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]: - return embed_chunks(chunks) +def get_embedding_model_from_db_embedding_model( + db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT +) -> IndexingEmbedder: + db_embedding_model: DbEmbeddingModel | None + if index_model_status == IndexModelStatus.PRESENT: + db_embedding_model = get_current_db_embedding_model(db_session) + elif index_model_status == IndexModelStatus.FUTURE: + db_embedding_model = get_secondary_db_embedding_model(db_session) + if not db_embedding_model: + raise RuntimeError("No secondary index configured") + else: + raise RuntimeError("Not supporting embedding model rollbacks") + + return DefaultIndexingEmbedder( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + query_prefix=db_embedding_model.query_prefix, + passage_prefix=db_embedding_model.passage_prefix, + ) diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 1b6891a3202..fa51995638e 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -19,16 +19,13 @@ from danswer.db.document_set import fetch_document_sets_for_documents from danswer.db.engine import get_sqlalchemy_engine from danswer.db.tag import create_or_add_document_tag from danswer.db.tag import create_or_add_document_tag_list -from danswer.document_index.document_index_utils import get_index_name -from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import DocumentMetadata from danswer.indexing.chunker import Chunker from danswer.indexing.chunker import DefaultChunker -from danswer.indexing.embedder import DefaultEmbedder +from danswer.indexing.embedder import IndexingEmbedder from danswer.indexing.models import DocAwareChunk from danswer.indexing.models import DocMetadataAwareIndexChunk -from danswer.search.models import Embedder from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time @@ -95,7 +92,7 @@ def upsert_documents_in_db( def index_doc_batch( *, chunker: Chunker, - embedder: Embedder, + embedder: IndexingEmbedder, document_index: DocumentIndex, documents: list[Document], index_attempt_metadata: IndexAttemptMetadata, @@ -152,7 +149,7 @@ def index_doc_batch( ) logger.debug("Starting embedding") - chunks_with_embeddings = embedder.embed(chunks=chunks) + chunks_with_embeddings = embedder.embed_chunks(chunks=chunks) # Attach the latest status from Postgres (source of truth for access) to each # chunk. 
This access status will be attached to each chunk in the document index @@ -213,22 +210,14 @@ def index_doc_batch( def build_indexing_pipeline( *, + embedder: IndexingEmbedder, + document_index: DocumentIndex, chunker: Chunker | None = None, - embedder: Embedder | None = None, - document_index: DocumentIndex | None = None, ignore_time_skip: bool = False, ) -> IndexingPipelineProtocol: """Builds a pipline which takes in a list (batch) of docs and indexes them.""" chunker = chunker or DefaultChunker() - embedder = embedder or DefaultEmbedder() - - if not document_index: - with Session(get_sqlalchemy_engine()) as db_session: - document_index = get_default_document_index( - primary_index_name=get_index_name(db_session), secondary_index_name=None - ) - return partial( index_doc_batch, chunker=chunker, diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 295dc3173c2..87d137aebdc 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -33,10 +33,6 @@ from danswer.configs.app_configs import SECRET from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.constants import AuthType -from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX -from danswer.configs.model_configs import ASYM_QUERY_PREFIX -from danswer.configs.model_configs import DOC_EMBEDDING_DIM -from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION from danswer.configs.model_configs import GEN_AI_API_ENDPOINT @@ -45,10 +41,9 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.db.connector import create_initial_default_connector from danswer.db.connector_credential_pair import associate_default_cc_pair from danswer.db.credentials import create_initial_public_credential -from danswer.db.embedding_model import get_latest_embedding_model_by_status +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.engine import get_sqlalchemy_engine -from danswer.db.models import IndexModelStatus -from danswer.document_index.document_index_utils import clean_model_name from danswer.document_index.factory import get_default_document_index from danswer.llm.factory import get_default_llm from danswer.search.search_nlp_models import warm_up_models @@ -65,6 +60,7 @@ from danswer.server.features.prompt.api import basic_router as prompt_router from danswer.server.gpts.api import router as gpts_router from danswer.server.manage.administrative import router as admin_router from danswer.server.manage.get_state import router as state_router +from danswer.server.manage.secondary_index import router as secondary_index_router from danswer.server.manage.slack_bot import router as slack_bot_management_router from danswer.server.manage.users import router as user_router from danswer.server.query_and_chat.chat_backend import router as chat_router @@ -134,6 +130,7 @@ def get_application() -> FastAPI: include_router_with_global_prefix_prepended(application, credential_router) include_router_with_global_prefix_prepended(application, cc_pair_router) include_router_with_global_prefix_prepended(application, document_set_router) + include_router_with_global_prefix_prepended(application, secondary_index_router) include_router_with_global_prefix_prepended( application, 
slack_bot_management_router ) @@ -245,13 +242,19 @@ def get_application() -> FastAPI: f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}" ) + with Session(get_sqlalchemy_engine()) as db_session: + db_embedding_model = get_current_db_embedding_model(db_session) + secondary_db_embedding_model = get_secondary_db_embedding_model(db_session) + if ENABLE_RERANKING_REAL_TIME_FLOW: logger.info("Reranking step of search flow is enabled.") - logger.info(f'Using Embedding model: "{DOCUMENT_ENCODER_MODEL}"') - if ASYM_QUERY_PREFIX or ASYM_PASSAGE_PREFIX: - logger.info(f'Query embedding prefix: "{ASYM_QUERY_PREFIX}"') - logger.info(f'Passage embedding prefix: "{ASYM_PASSAGE_PREFIX}"') + logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"') + if db_embedding_model.query_prefix or db_embedding_model.passage_prefix: + logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"') + logger.info( + f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"' + ) if MODEL_SERVER_HOST: logger.info( @@ -259,7 +262,11 @@ def get_application() -> FastAPI: ) else: logger.info("Warming up local NLP models.") - warm_up_models(skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW) + warm_up_models( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW, + ) if torch.cuda.is_available(): logger.info("GPU is available") @@ -273,7 +280,7 @@ def get_application() -> FastAPI: nltk.download("punkt", quiet=True) logger.info("Verifying default connector/credential exist.") - with Session(get_sqlalchemy_engine(), expire_on_commit=False) as db_session: + with Session(get_sqlalchemy_engine()) as db_session: create_initial_public_credential(db_session) create_initial_default_connector(db_session) associate_default_cc_pair(db_session) @@ -282,32 +289,17 @@ def get_application() -> FastAPI: load_chat_yamls() logger.info("Verifying Document Index(s) is/are available.") - primary_embedding_model = get_latest_embedding_model_by_status( - status=IndexModelStatus.PRESENT, db_session=db_session - ) - secondary_embedding_model = get_latest_embedding_model_by_status( - status=IndexModelStatus.FUTURE, db_session=db_session - ) - primary_index = ( - f"danswer_chunk_{clean_model_name(primary_embedding_model.model_name)}" - if primary_embedding_model - else "danswer_chunk" - ) - second_index = ( - f"danswer_chunk_{clean_model_name(secondary_embedding_model.model_name)}" - if secondary_embedding_model - else None - ) document_index = get_default_document_index( - primary_index_name=primary_index, secondary_index_name=second_index + primary_index_name=db_embedding_model.index_name, + secondary_index_name=secondary_db_embedding_model.index_name + if secondary_db_embedding_model + else None, ) document_index.ensure_indices_exist( - index_embedding_dim=primary_embedding_model.model_dim - if primary_embedding_model - else DOC_EMBEDDING_DIM, - secondary_index_embedding_dim=secondary_embedding_model.model_dim - if secondary_embedding_model + index_embedding_dim=db_embedding_model.model_dim, + secondary_index_embedding_dim=secondary_db_embedding_model.model_dim + if secondary_db_embedding_model else None, ) diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py index 4593b4f6518..04f95fbdbb2 100644 --- a/backend/danswer/one_shot_answer/answer_question.py +++ b/backend/danswer/one_shot_answer/answer_question.py @@ -21,8 +21,8 @@ from 
danswer.db.chat import get_or_create_root_message from danswer.db.chat import get_persona_by_id from danswer.db.chat import get_prompt_by_id from danswer.db.chat import translate_db_message_to_chat_message_detail +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.models import User -from danswer.document_index.document_index_utils import get_index_name from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk from danswer.llm.utils import get_default_llm_token_encode @@ -91,8 +91,10 @@ def stream_answer_objects( llm_tokenizer = get_default_llm_token_encode() + embedding_model = get_current_db_embedding_model(db_session) + document_index = get_default_document_index( - primary_index_name=get_index_name(db_session), secondary_index_name=None + primary_index_name=embedding_model.index_name, secondary_index_name=None ) # Create a chat session which will just store the root message, the query, and the AI response @@ -124,6 +126,7 @@ def stream_answer_objects( documents_generator = full_chunk_search_generator( search_query=retrieval_request, document_index=document_index, + db_session=db_session, retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) diff --git a/backend/danswer/search/danswer_helper.py b/backend/danswer/search/danswer_helper.py index 047ca78d47b..e3de6f92339 100644 --- a/backend/danswer/search/danswer_helper.py +++ b/backend/danswer/search/danswer_helper.py @@ -57,6 +57,7 @@ def query_intent(query: str) -> tuple[SearchType, QueryFlow]: def recommend_search_flow( query: str, + model_name: str, keyword: bool = False, max_percent_stopwords: float = 0.30, # ~Every third word max, ie "effects of caffeine" still viable keyword search ) -> HelperResponse: @@ -69,7 +70,7 @@ def recommend_search_flow( non_stopword_percent = len(non_stopwords) / len(words) # UNK tokens -> suggest Keyword (still may be valid QA) - if count_unk_tokens(query, get_default_tokenizer()) > 0: + if count_unk_tokens(query, get_default_tokenizer(model_name=model_name)) > 0: if not keyword: heuristic_search_type = SearchType.KEYWORD message = "Unknown tokens in query." 
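(For orientation: after this change, callers resolve the active embedding model from Postgres and thread its name into the helper above, as the query_backend.py hunk later in this patch does. A minimal sketch, not part of the patch; `suggest_flow` is a hypothetical wrapper and `db_session` is assumed to be an open SQLAlchemy session.)

# Illustrative sketch only
from sqlalchemy.orm import Session

from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.search.danswer_helper import recommend_search_flow


def suggest_flow(query: str, db_session: Session):
    # Look up the currently active model row; its name keys the tokenizer cache
    embedding_model = get_current_db_embedding_model(db_session)
    return recommend_search_flow(query, model_name=embedding_model.model_name)
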
diff --git a/backend/danswer/search/models.py b/backend/danswer/search/models.py index 2ee08736f4c..186c451aad1 100644 --- a/backend/danswer/search/models.py +++ b/backend/danswer/search/models.py @@ -9,8 +9,6 @@ from danswer.configs.chat_configs import NUM_RERANKED_RESULTS from danswer.configs.chat_configs import NUM_RETURNED_HITS from danswer.configs.constants import DocumentSource from danswer.configs.model_configs import ENABLE_RERANKING_REAL_TIME_FLOW -from danswer.indexing.models import DocAwareChunk -from danswer.indexing.models import IndexChunk MAX_METRICS_CONTENT = ( 200 # Just need enough characters to identify where in the doc the chunk is @@ -43,11 +41,6 @@ class QueryFlow(str, Enum): QUESTION_ANSWER = "question-answer" -class Embedder: - def embed(self, chunks: list[DocAwareChunk]) -> list[IndexChunk]: - raise NotImplementedError - - class Tag(BaseModel): tag_key: str tag_value: str diff --git a/backend/danswer/search/search_nlp_models.py b/backend/danswer/search/search_nlp_models.py index a37a7a1021a..81425518084 100644 --- a/backend/danswer/search/search_nlp_models.py +++ b/backend/danswer/search/search_nlp_models.py @@ -1,13 +1,13 @@ +import gc import logging import os +from enum import Enum from typing import Optional from typing import TYPE_CHECKING import numpy as np import requests -from danswer.configs.app_configs import CURRENT_PROCESS_IS_AN_INDEXING_JOB -from danswer.configs.app_configs import INDEXING_MODEL_SERVER_HOST from danswer.configs.app_configs import MODEL_SERVER_HOST from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.model_configs import CROSS_EMBED_CONTEXT_SIZE @@ -15,7 +15,6 @@ from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL from danswer.configs.model_configs import INTENT_MODEL_VERSION -from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS from danswer.configs.model_configs import QUERY_MAX_CONTEXT_SIZE from danswer.utils.logger import setup_logger from shared_models.model_server_models import EmbedRequest @@ -37,28 +36,51 @@ if TYPE_CHECKING: from transformers import TFDistilBertForSequenceClassification # type: ignore -_TOKENIZER: Optional["AutoTokenizer"] = None -_EMBED_MODEL: Optional["SentenceTransformer"] = None +_TOKENIZER: tuple[Optional["AutoTokenizer"], str | None] = (None, None) +_EMBED_MODEL: tuple[Optional["SentenceTransformer"], str | None] = (None, None) _RERANK_MODELS: Optional[list["CrossEncoder"]] = None _INTENT_TOKENIZER: Optional["AutoTokenizer"] = None _INTENT_MODEL: Optional["TFDistilBertForSequenceClassification"] = None -def get_default_tokenizer() -> "AutoTokenizer": +class EmbedTextType(str, Enum): + QUERY = "query" + PASSAGE = "passage" + + +def clean_model_name(model_str: str) -> str: + return model_str.replace("/", "_").replace("-", "_").replace(".", "_") + + +# NOTE: If None is used, it may not be using the "correct" tokenizer, for cases +# where this is more important, be sure to refresh with the actual model name +def get_default_tokenizer(model_name: str | None = None) -> "AutoTokenizer": # NOTE: doing a local import here to avoid reduce memory usage caused by # processes importing this file despite not using any of this from transformers import AutoTokenizer # type: ignore global _TOKENIZER - if _TOKENIZER is None: - _TOKENIZER = AutoTokenizer.from_pretrained(DOCUMENT_ENCODER_MODEL) - if hasattr(_TOKENIZER, "is_fast") and 
_TOKENIZER.is_fast:
+    if _TOKENIZER[0] is None or (
+        _TOKENIZER[1] is not None and _TOKENIZER[1] != model_name
+    ):
+        if _TOKENIZER[0] is not None:
+            del _TOKENIZER
+            gc.collect()
+
+        if model_name is None:
+            # This could be inaccurate
+            model_name = DOCUMENT_ENCODER_MODEL
+
+        _TOKENIZER = (AutoTokenizer.from_pretrained(model_name), model_name)
+
+        if hasattr(_TOKENIZER[0], "is_fast") and _TOKENIZER[0].is_fast:
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
-    return _TOKENIZER
+
+    return _TOKENIZER[0]


 def get_local_embedding_model(
-    model_name: str = DOCUMENT_ENCODER_MODEL,
+    model_name: str,
     max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
 ) -> "SentenceTransformer":
     # NOTE: doing a local import here to avoid reduce memory usage caused by
@@ -66,11 +88,19 @@ def get_local_embedding_model(
     from sentence_transformers import SentenceTransformer  # type: ignore

     global _EMBED_MODEL
-    if _EMBED_MODEL is None or max_context_length != _EMBED_MODEL.max_seq_length:
+    if (
+        _EMBED_MODEL[0] is None
+        or max_context_length != _EMBED_MODEL[0].max_seq_length
+        or model_name != _EMBED_MODEL[1]
+    ):
+        if _EMBED_MODEL[0] is not None:
+            del _EMBED_MODEL
+            gc.collect()
+
         logger.info(f"Loading {model_name}")
-        _EMBED_MODEL = SentenceTransformer(model_name)
-        _EMBED_MODEL.max_seq_length = max_context_length
-    return _EMBED_MODEL
+        _EMBED_MODEL = (SentenceTransformer(model_name), model_name)
+        _EMBED_MODEL[0].max_seq_length = max_context_length
+    return _EMBED_MODEL[0]


 def get_local_reranking_model_ensemble(
@@ -142,25 +172,24 @@ def build_model_server_url(


 class EmbeddingModel:
     def __init__(
         self,
-        model_name: str = DOCUMENT_ENCODER_MODEL,
+        model_name: str,
+        query_prefix: str | None,
+        passage_prefix: str | None,
+        normalize: bool,
+        server_host: str | None,  # Changes depending on indexing or inference
+        server_port: int | None,
+        # The following are globals that are currently not configurable
         max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
-        model_server_host: str | None = MODEL_SERVER_HOST,
-        indexing_model_server_host: str | None = INDEXING_MODEL_SERVER_HOST,
-        model_server_port: int = MODEL_SERVER_PORT,
-        is_indexing: bool = CURRENT_PROCESS_IS_AN_INDEXING_JOB,
     ) -> None:
         self.model_name = model_name
         self.max_seq_length = max_seq_length
+        self.query_prefix = query_prefix
+        self.passage_prefix = passage_prefix
+        self.normalize = normalize

-        used_model_server_host = (
-            indexing_model_server_host if is_indexing else model_server_host
-        )
-
-        model_server_url = build_model_server_url(
-            used_model_server_host, model_server_port
-        )
+        model_server_url = build_model_server_url(server_host, server_port)
         self.embed_server_endpoint = (
-            model_server_url + "/encoder/bi-encoder-embed" if model_server_url else None
+            f"{model_server_url}/encoder/bi-encoder-embed" if model_server_url else None
         )

     def load_model(self) -> Optional["SentenceTransformer"]:
@@ -171,11 +200,20 @@ class EmbeddingModel:
             model_name=self.model_name, max_context_length=self.max_seq_length
         )

-    def encode(
-        self, texts: list[str], normalize_embeddings: bool = NORMALIZE_EMBEDDINGS
-    ) -> list[list[float]]:
+    def encode(self, texts: list[str], text_type: EmbedTextType) -> list[list[float]]:
+        if text_type == EmbedTextType.QUERY and self.query_prefix:
+            prefixed_texts = [self.query_prefix + text for text in texts]
+        elif text_type == EmbedTextType.PASSAGE and self.passage_prefix:
+            prefixed_texts = [self.passage_prefix + text for text in texts]
+        else:
+            prefixed_texts = texts
+
         if self.embed_server_endpoint:
-            embed_request =
EmbedRequest(texts=texts) + embed_request = EmbedRequest( + texts=prefixed_texts, + model_name=self.model_name, + normalize_embeddings=self.normalize, + ) try: response = requests.post( @@ -194,7 +232,7 @@ class EmbeddingModel: raise RuntimeError("Failed to load local Embedding Model") return local_model.encode( - texts, normalize_embeddings=normalize_embeddings + prefixed_texts, normalize_embeddings=self.normalize ).tolist() @@ -317,6 +355,8 @@ class IntentModel: def warm_up_models( + model_name: str, + normalize: bool, skip_cross_encoders: bool = False, indexer_only: bool = False, ) -> None: @@ -324,9 +364,20 @@ def warm_up_models( "Danswer is amazing! Check out our easy deployment guide at " "https://docs.danswer.dev/quickstart" ) - get_default_tokenizer()(warm_up_str) - EmbeddingModel().encode(texts=[warm_up_str]) + get_default_tokenizer(model_name=model_name)(warm_up_str) + + embed_model = EmbeddingModel( + model_name=model_name, + normalize=normalize, + # These don't matter, if it's a remote model, this function shouldn't be called + query_prefix=None, + passage_prefix=None, + server_host=None, + server_port=None, + ) + + embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY) if indexer_only: return diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py index 54e69c31e4b..5cbd8b17bd3 100644 --- a/backend/danswer/search/search_runner.py +++ b/backend/danswer/search/search_runner.py @@ -7,16 +7,19 @@ import numpy from nltk.corpus import stopwords # type:ignore from nltk.stem import WordNetLemmatizer # type:ignore from nltk.tokenize import word_tokenize # type:ignore +from sqlalchemy.orm import Session from danswer.chat.models import LlmDoc +from danswer.configs.app_configs import MODEL_SERVER_HOST +from danswer.configs.app_configs import MODEL_SERVER_PORT from danswer.configs.chat_configs import HYBRID_ALPHA from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION from danswer.configs.chat_configs import NUM_RERANKED_RESULTS -from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.document_index.document_index_utils import ( translate_boost_count_to_multiplier, ) @@ -32,6 +35,7 @@ from danswer.search.models import SearchQuery from danswer.search.models import SearchType from danswer.search.search_nlp_models import CrossEncoderEnsembleModel from danswer.search.search_nlp_models import EmbeddingModel +from danswer.search.search_nlp_models import EmbedTextType from danswer.secondary_llm_flows.chunk_usefulness import llm_batch_eval_chunks from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion from danswer.utils.logger import setup_logger @@ -76,15 +80,6 @@ def query_processing( return query -@log_function_time(print_only=True) -def embed_query( - query: str, - prefix: str = ASYM_QUERY_PREFIX, -) -> list[float]: - prefixed_query = prefix + query - return EmbeddingModel().encode([prefixed_query])[0] - - def chunks_to_search_docs(chunks: list[InferenceChunk] | None) -> list[SearchDoc]: search_docs = ( [ @@ -140,6 +135,7 @@ def combine_retrieval_results( def doc_index_retrieval( query: SearchQuery, document_index: DocumentIndex, + db_session: Session, 
hybrid_alpha: float = HYBRID_ALPHA, ) -> list[InferenceChunk]: if query.search_type == SearchType.KEYWORD: @@ -149,27 +145,43 @@ def doc_index_retrieval( time_decay_multiplier=query.recency_bias_multiplier, num_to_retrieve=query.num_hits, ) - - elif query.search_type == SearchType.SEMANTIC: - top_chunks = document_index.semantic_retrieval( - query=query.query, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - ) - - elif query.search_type == SearchType.HYBRID: - top_chunks = document_index.hybrid_retrieval( - query=query.query, - filters=query.filters, - time_decay_multiplier=query.recency_bias_multiplier, - num_to_retrieve=query.num_hits, - offset=query.offset, - hybrid_alpha=hybrid_alpha, - ) - else: - raise RuntimeError("Invalid Search Flow") + db_embedding_model = get_current_db_embedding_model(db_session) + + model = EmbeddingModel( + model_name=db_embedding_model.model_name, + query_prefix=db_embedding_model.query_prefix, + passage_prefix=db_embedding_model.passage_prefix, + normalize=db_embedding_model.normalize, + # The below are globally set, this flow always uses the indexing one + server_host=MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ) + + query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0] + + if query.search_type == SearchType.SEMANTIC: + top_chunks = document_index.semantic_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + ) + + elif query.search_type == SearchType.HYBRID: + top_chunks = document_index.hybrid_retrieval( + query=query.query, + query_embedding=query_embedding, + filters=query.filters, + time_decay_multiplier=query.recency_bias_multiplier, + num_to_retrieve=query.num_hits, + offset=query.offset, + hybrid_alpha=hybrid_alpha, + ) + + else: + raise RuntimeError("Invalid Search Flow") return top_chunks @@ -347,6 +359,7 @@ def _simplify_text(text: str) -> str: def retrieve_chunks( query: SearchQuery, document_index: DocumentIndex, + db_session: Session, hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] @@ -356,7 +369,10 @@ def retrieve_chunks( # Don't do query expansion on complex queries, rephrasings likely would not work well if not multilingual_expansion_str or "\n" in query.query or "\r" in query.query: top_chunks = doc_index_retrieval( - query=query, document_index=document_index, hybrid_alpha=hybrid_alpha + query=query, + document_index=document_index, + db_session=db_session, + hybrid_alpha=hybrid_alpha, ) else: simplified_queries = set() @@ -378,7 +394,10 @@ def retrieve_chunks( q_copy = query.copy(update={"query": rephrase}, deep=True) run_queries.append( - (doc_index_retrieval, (q_copy, document_index, hybrid_alpha)) + ( + doc_index_retrieval, + (q_copy, document_index, db_session, hybrid_alpha), + ) ) parallel_search_results = run_functions_tuples_in_parallel(run_queries) top_chunks = combine_retrieval_results(parallel_search_results) @@ -459,6 +478,7 @@ def filter_chunks( def full_chunk_search( query: SearchQuery, document_index: DocumentIndex, + db_session: Session, hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, retrieval_metrics_callback: 
Callable[[RetrievalMetricsContainer], None] @@ -471,6 +491,7 @@ def full_chunk_search( search_generator = full_chunk_search_generator( search_query=query, document_index=document_index, + db_session=db_session, hybrid_alpha=hybrid_alpha, multilingual_expansion_str=multilingual_expansion_str, retrieval_metrics_callback=retrieval_metrics_callback, @@ -489,6 +510,7 @@ def empty_search_generator() -> Iterator[list[InferenceChunk] | list[bool]]: def full_chunk_search_generator( search_query: SearchQuery, document_index: DocumentIndex, + db_session: Session, hybrid_alpha: float = HYBRID_ALPHA, # Only applicable to hybrid search multilingual_expansion_str: str | None = MULTILINGUAL_QUERY_EXPANSION, retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] @@ -503,6 +525,7 @@ def full_chunk_search_generator( retrieved_chunks = retrieve_chunks( query=search_query, document_index=document_index, + db_session=db_session, hybrid_alpha=hybrid_alpha, multilingual_expansion_str=multilingual_expansion_str, retrieval_metrics_callback=retrieval_metrics_callback, diff --git a/backend/danswer/server/danswer_api/ingestion.py b/backend/danswer/server/danswer_api/ingestion.py index 2a02816577a..8856e20d644 100644 --- a/backend/danswer/server/danswer_api/ingestion.py +++ b/backend/danswer/server/danswer_api/ingestion.py @@ -14,11 +14,14 @@ from danswer.db.connector import fetch_connector_by_id from danswer.db.connector import fetch_ingestion_connector_by_name from danswer.db.connector_credential_pair import get_connector_credential_pair from danswer.db.credentials import fetch_credential_by_id +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.engine import get_session from danswer.document_index.document_index_utils import get_both_index_names from danswer.document_index.factory import get_default_document_index from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.indexing.embedder import DefaultIndexingEmbedder from danswer.indexing.indexing_pipeline import build_indexing_pipeline from danswer.server.danswer_api.models import IngestionDocument from danswer.server.danswer_api.models import IngestionResult @@ -149,8 +152,19 @@ def document_ingestion( primary_index_name=curr_ind_name, secondary_index_name=None ) + db_embedding_model = get_current_db_embedding_model(db_session) + + index_embedding_model = DefaultIndexingEmbedder( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + query_prefix=db_embedding_model.query_prefix, + passage_prefix=db_embedding_model.passage_prefix, + ) + indexing_pipeline = build_indexing_pipeline( - ignore_time_skip=True, document_index=curr_doc_index + embedder=index_embedding_model, + document_index=curr_doc_index, + ignore_time_skip=True, ) new_doc, chunks = indexing_pipeline( @@ -167,8 +181,25 @@ def document_ingestion( primary_index_name=curr_ind_name, secondary_index_name=None ) + sec_db_embedding_model = get_secondary_db_embedding_model(db_session) + + if sec_db_embedding_model is None: + # Should not ever happen + raise RuntimeError( + "Secondary index exists but no embedding model configured" + ) + + new_index_embedding_model = DefaultIndexingEmbedder( + model_name=sec_db_embedding_model.model_name, + normalize=sec_db_embedding_model.normalize, + query_prefix=sec_db_embedding_model.query_prefix, + 
passage_prefix=sec_db_embedding_model.passage_prefix, + ) + sec_ind_pipeline = build_indexing_pipeline( - ignore_time_skip=True, document_index=sec_doc_index + embedder=new_index_embedding_model, + document_index=sec_doc_index, + ignore_time_skip=True, ) sec_ind_pipeline( diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 28bb8d5036a..af129a5cf37 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -54,7 +54,10 @@ from danswer.db.credentials import delete_google_drive_service_account_credentia from danswer.db.credentials import fetch_credential_by_id from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed from danswer.db.document import get_document_cnts_for_cc_pairs +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.engine import get_session +from danswer.db.index_attempt import cancel_indexing_attempts_for_connector from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempts_for_cc_pair from danswer.db.index_attempt import get_latest_index_attempts @@ -347,6 +350,7 @@ def upload_files( @router.get("/admin/connector/indexing-status") def get_connector_indexing_status( + secondary_index: bool = False, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[ConnectorIndexingStatus]: @@ -362,8 +366,9 @@ def get_connector_indexing_status( ] latest_index_attempts = get_latest_index_attempts( - db_session=db_session, connector_credential_pair_identifiers=cc_pair_identifiers, + secondary_index=secondary_index, + db_session=db_session, ) cc_pair_to_latest_index_attempt = { (index_attempt.connector_id, index_attempt.credential_id): index_attempt @@ -449,6 +454,9 @@ def update_connector_from_model( status_code=404, detail=f"Connector {connector_id} does not exist" ) + if updated_connector.disabled: + cancel_indexing_attempts_for_connector(connector_id, db_session) + return ConnectorSnapshot( id=updated_connector.id, name=updated_connector.name, @@ -526,12 +534,31 @@ def connector_run_once( ) ] + embedding_model = get_current_db_embedding_model(db_session) + + secondary_embedding_model = get_secondary_db_embedding_model(db_session) + index_attempt_ids = [ - create_index_attempt(run_info.connector_id, credential_id, db_session) + create_index_attempt( + run_info.connector_id, credential_id, embedding_model.id, db_session + ) for credential_id in credential_ids if credential_id not in skipped_credentials ] + if secondary_embedding_model is not None: + # Secondary index doesn't have to be returned + [ + create_index_attempt( + run_info.connector_id, + credential_id, + secondary_embedding_model.id, + db_session, + ) + for credential_id in credential_ids + if credential_id not in skipped_credentials + ] + if not index_attempt_ids: raise HTTPException( status_code=400, diff --git a/backend/danswer/server/documents/document.py b/backend/danswer/server/documents/document.py index 57f515bff98..2778beaa042 100644 --- a/backend/danswer/server/documents/document.py +++ b/backend/danswer/server/documents/document.py @@ -5,9 +5,9 @@ from fastapi import Query from sqlalchemy.orm import Session from danswer.auth.users import current_user +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session from danswer.db.models import 
User -from danswer.document_index.document_index_utils import get_index_name from danswer.document_index.factory import get_default_document_index from danswer.llm.utils import get_default_llm_token_encode from danswer.search.access_filters import build_access_filters_for_user @@ -27,8 +27,10 @@ def get_document_info( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> DocumentInfo: + embedding_model = get_current_db_embedding_model(db_session) + document_index = get_default_document_index( - primary_index_name=get_index_name(db_session), secondary_index_name=None + primary_index_name=embedding_model.index_name, secondary_index_name=None ) user_acl_filters = build_access_filters_for_user(user, db_session) @@ -61,8 +63,10 @@ def get_chunk_info( user: User | None = Depends(current_user), db_session: Session = Depends(get_session), ) -> ChunkInfo: + embedding_model = get_current_db_embedding_model(db_session) + document_index = get_default_document_index( - primary_index_name=get_index_name(db_session), secondary_index_name=None + primary_index_name=embedding_model.index_name, secondary_index_name=None ) user_acl_filters = build_access_filters_for_user(user, db_session) diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py index 772726b2fcb..9800032520e 100644 --- a/backend/danswer/server/gpts/api.py +++ b/backend/danswer/server/gpts/api.py @@ -6,8 +6,8 @@ from fastapi import Depends from pydantic import BaseModel from sqlalchemy.orm import Session +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session -from danswer.document_index.document_index_utils import get_index_name from danswer.document_index.factory import get_default_document_index from danswer.search.access_filters import build_access_filters_for_user from danswer.search.models import IndexFilters @@ -73,9 +73,7 @@ def gpt_search( query = search_request.query user_acl_filters = build_access_filters_for_user(None, db_session) - final_filters = IndexFilters( - access_control_list=user_acl_filters, - ) + final_filters = IndexFilters(access_control_list=user_acl_filters) search_query = SearchQuery( query=query, @@ -84,13 +82,14 @@ def gpt_search( skip_llm_chunk_filter=True, ) + embedding_model = get_current_db_embedding_model(db_session) + document_index = get_default_document_index( - primary_index_name=get_index_name(db_session), secondary_index_name=None + primary_index_name=embedding_model.index_name, secondary_index_name=None ) top_chunks, __ = full_chunk_search( - query=search_query, - document_index=document_index, + query=search_query, document_index=document_index, db_session=db_session ) return GptSearchResponse( diff --git a/backend/danswer/server/manage/secondary_index.py b/backend/danswer/server/manage/secondary_index.py index 198730e3de7..d1c5ffdb278 100644 --- a/backend/danswer/server/manage/secondary_index.py +++ b/backend/danswer/server/manage/secondary_index.py @@ -1,10 +1,22 @@ from fastapi import APIRouter from fastapi import Depends +from fastapi import HTTPException +from fastapi import status +from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user +from danswer.db.embedding_model import create_embedding_model +from danswer.db.embedding_model import get_current_db_embedding_model +from danswer.db.embedding_model import get_secondary_db_embedding_model +from danswer.db.embedding_model import update_embedding_model_status +from danswer.db.engine import 
get_session +from danswer.db.index_attempt import expire_index_attempts +from danswer.db.models import IndexModelStatus from danswer.db.models import User +from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import EmbeddingModelDetail from danswer.server.manage.models import ModelVersionResponse +from danswer.server.models import IdReturn from danswer.utils.logger import setup_logger router = APIRouter(prefix="/secondary-index") @@ -15,19 +27,93 @@ logger = setup_logger() def set_new_embedding_model( embed_model_details: EmbeddingModelDetail, _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), +) -> IdReturn: + """Creates a new EmbeddingModel row and cancels the previous secondary indexing if any + Gives an error if the same model name is used as the current or secondary index + """ + current_model = get_current_db_embedding_model(db_session) + + if embed_model_details.model_name == current_model.model_name: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="New embedding model is the same as the currently active one.", + ) + + secondary_model = get_secondary_db_embedding_model(db_session) + + if secondary_model: + if embed_model_details.model_name == secondary_model.model_name: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Already reindexing with {secondary_model.model_name}", + ) + + # Cancel any background indexing jobs + expire_index_attempts( + embedding_model_id=secondary_model.id, db_session=db_session + ) + + # Mark previous model as a past model directly + update_embedding_model_status( + embedding_model=secondary_model, + new_status=IndexModelStatus.PAST, + db_session=db_session, + ) + + new_model = create_embedding_model( + model_details=embed_model_details, + db_session=db_session, + ) + + # Ensure Vespa has the new index immediately + document_index = get_default_document_index( + primary_index_name=current_model.index_name, + secondary_index_name=new_model.index_name, + ) + document_index.ensure_indices_exist( + index_embedding_dim=current_model.model_dim, + secondary_index_embedding_dim=new_model.model_dim, + ) + + return IdReturn(id=new_model.id) + + +@router.post("/cancel-new-embedding") +def cancel_new_embedding( + _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), ) -> None: - raise NotImplementedError() + secondary_model = get_secondary_db_embedding_model(db_session) + + if secondary_model: + expire_index_attempts( + embedding_model_id=secondary_model.id, db_session=db_session + ) + + update_embedding_model_status( + embedding_model=secondary_model, + new_status=IndexModelStatus.PAST, + db_session=db_session, + ) @router.get("/get-current-embedding-model") def get_current_embedding_model( _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), ) -> ModelVersionResponse: - raise NotImplementedError() + current_model = get_current_db_embedding_model(db_session) + return ModelVersionResponse(model_name=current_model.model_name) @router.get("/get-secondary-embedding-model") def get_secondary_embedding_model( _: User | None = Depends(current_admin_user), + db_session: Session = Depends(get_session), ) -> ModelVersionResponse: - raise NotImplementedError() + next_model = get_secondary_db_embedding_model(db_session) + + return ModelVersionResponse( + model_name=next_model.model_name if next_model else None + ) diff --git a/backend/danswer/server/models.py 
b/backend/danswer/server/models.py index df3143629e5..d616edd4f86 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -17,3 +17,7 @@ class StatusResponse(GenericModel, Generic[DataT]): class ApiKey(BaseModel): api_key: str + + +class IdReturn(BaseModel): + id: int diff --git a/backend/danswer/server/query_and_chat/query_backend.py b/backend/danswer/server/query_and_chat/query_backend.py index 552bf266aad..0c27ecaf388 100644 --- a/backend/danswer/server/query_and_chat/query_backend.py +++ b/backend/danswer/server/query_and_chat/query_backend.py @@ -7,10 +7,10 @@ from sqlalchemy.orm import Session from danswer.auth.users import current_admin_user from danswer.auth.users import current_user from danswer.configs.constants import DocumentSource +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session from danswer.db.models import User from danswer.db.tag import get_tags_by_value_prefix_for_source_types -from danswer.document_index.document_index_utils import get_index_name from danswer.document_index.factory import get_default_document_index from danswer.document_index.vespa.index import VespaIndex from danswer.one_shot_answer.answer_question import stream_search_answer @@ -54,8 +54,10 @@ def admin_search( access_control_list=user_acl_filters, ) + embedding_model = get_current_db_embedding_model(db_session) + document_index = get_default_document_index( - primary_index_name=get_index_name(db_session), secondary_index_name=None + primary_index_name=embedding_model.index_name, secondary_index_name=None ) if not isinstance(document_index, VespaIndex): @@ -106,10 +108,15 @@ def get_tags( @basic_router.post("/search-intent") def get_search_type( - simple_query: SimpleQueryRequest, _: User = Depends(current_user) + simple_query: SimpleQueryRequest, + _: User = Depends(current_user), + db_session: Session = Depends(get_session), ) -> HelperResponse: logger.info(f"Calculating intent for {simple_query.query}") - return recommend_search_flow(simple_query.query) + embedding_model = get_current_db_embedding_model(db_session) + return recommend_search_flow( + simple_query.query, model_name=embedding_model.model_name + ) @basic_router.post("/query-validation") diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 1eb27ee96a6..1220736dea7 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -1,10 +1,10 @@ +from typing import TYPE_CHECKING + from fastapi import APIRouter from fastapi import HTTPException from danswer.configs.model_configs import CROSS_ENCODER_MODEL_ENSEMBLE -from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL -from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS -from danswer.search.search_nlp_models import get_local_embedding_model +from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE from danswer.search.search_nlp_models import get_local_reranking_model_ensemble from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time @@ -13,19 +13,46 @@ from shared_models.model_server_models import EmbedResponse from shared_models.model_server_models import RerankRequest from shared_models.model_server_models import RerankResponse +if TYPE_CHECKING: + from sentence_transformers import SentenceTransformer # type: ignore + + logger = setup_logger() WARM_UP_STRING = "Danswer is amazing" router = APIRouter(prefix="/encoder") +_GLOBAL_MODELS_DICT: dict[str, 
"SentenceTransformer"] = {} + + +def get_embedding_model( + model_name: str, + max_context_length: int = DOC_EMBEDDING_CONTEXT_SIZE, +) -> "SentenceTransformer": + from sentence_transformers import SentenceTransformer # type: ignore + + global _GLOBAL_MODELS_DICT # A dictionary to store models + + if _GLOBAL_MODELS_DICT is None: + _GLOBAL_MODELS_DICT = {} + + if model_name not in _GLOBAL_MODELS_DICT: + logger.info(f"Loading {model_name}") + model = SentenceTransformer(model_name) + model.max_seq_length = max_context_length + _GLOBAL_MODELS_DICT[model_name] = model + elif max_context_length != _GLOBAL_MODELS_DICT[model_name].max_seq_length: + _GLOBAL_MODELS_DICT[model_name].max_seq_length = max_context_length + + return _GLOBAL_MODELS_DICT[model_name] + @log_function_time(print_only=True) def embed_text( - texts: list[str], - normalize_embeddings: bool = NORMALIZE_EMBEDDINGS, + texts: list[str], model_name: str, normalize_embeddings: bool ) -> list[list[float]]: - model = get_local_embedding_model() + model = get_embedding_model(model_name=model_name) embeddings = model.encode(texts, normalize_embeddings=normalize_embeddings) if not isinstance(embeddings, list): @@ -49,7 +76,11 @@ def process_embed_request( embed_request: EmbedRequest, ) -> EmbedResponse: try: - embeddings = embed_text(texts=embed_request.texts) + embeddings = embed_text( + texts=embed_request.texts, + model_name=embed_request.model_name, + normalize_embeddings=embed_request.normalize_embeddings, + ) return EmbedResponse(embeddings=embeddings) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) @@ -66,11 +97,6 @@ def process_rerank_request(embed_request: RerankRequest) -> RerankResponse: raise HTTPException(status_code=500, detail=str(e)) -def warm_up_bi_encoder() -> None: - logger.info(f"Warming up Bi-Encoders: {DOCUMENT_ENCODER_MODEL}") - get_local_embedding_model().encode(WARM_UP_STRING) - - def warm_up_cross_encoders() -> None: logger.info(f"Warming up Cross-Encoders: {CROSS_ENCODER_MODEL_ENSEMBLE}") diff --git a/backend/model_server/main.py b/backend/model_server/main.py index 3b7ed5747c6..dead931dcdf 100644 --- a/backend/model_server/main.py +++ b/backend/model_server/main.py @@ -10,7 +10,6 @@ from danswer.utils.logger import setup_logger from model_server.custom_models import router as custom_models_router from model_server.custom_models import warm_up_intent_model from model_server.encoders import router as encoders_router -from model_server.encoders import warm_up_bi_encoder from model_server.encoders import warm_up_cross_encoders @@ -33,7 +32,6 @@ def get_model_app() -> FastAPI: torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads())) logger.info(f"Torch Threads: {torch.get_num_threads()}") - warm_up_bi_encoder() warm_up_cross_encoders() warm_up_intent_model() diff --git a/backend/shared_models/model_server_models.py b/backend/shared_models/model_server_models.py index 263b2b1f5ea..e3b04557d2a 100644 --- a/backend/shared_models/model_server_models.py +++ b/backend/shared_models/model_server_models.py @@ -3,6 +3,8 @@ from pydantic import BaseModel class EmbedRequest(BaseModel): texts: list[str] + model_name: str + normalize_embeddings: bool class EmbedResponse(BaseModel): diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py index 9561ca9f6e8..7cd3e6068c6 100644 --- a/backend/tests/regression/search_quality/eval_search.py +++ b/backend/tests/regression/search_quality/eval_search.py @@ -8,8 +8,8 @@ from 
typing import TextIO from sqlalchemy.orm import Session from danswer.chat.chat_utils import get_chunks_for_qa +from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_sqlalchemy_engine -from danswer.document_index.document_index_utils import get_index_name from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk from danswer.search.models import IndexFilters @@ -97,15 +97,16 @@ def get_search_results( rerank_metrics = MetricsHander[RerankMetricsContainer]() with Session(get_sqlalchemy_engine()) as db_session: - ind_name = get_index_name(db_session) + embedding_model = get_current_db_embedding_model(db_session) document_index = get_default_document_index( - primary_index_name=ind_name, secondary_index_name=None + primary_index_name=embedding_model.index_name, secondary_index_name=None ) top_chunks, llm_chunk_selection = full_chunk_search( query=search_query, document_index=document_index, + db_session=db_session, retrieval_metrics_callback=retrieval_metrics.record_metric, rerank_metrics_callback=rerank_metrics.record_metric, ) diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml index 5f391f67eb0..c533b124639 100644 --- a/deployment/docker_compose/docker-compose.dev.yml +++ b/deployment/docker_compose/docker-compose.dev.yml @@ -228,7 +228,7 @@ services: while :; do sleep 6h & wait $${!}; nginx -s reload; done & nginx -g \"daemon off;\"" # Run with --profile model-server to bring up the danswer-model-server container # Be sure to change MODEL_SERVER_HOST (see above) as well - # ie. MODEL_SERVER_HOST="model_server" docker-compose -f docker-compose.dev.yml -p danswer-stack --profile model-server up -d --build + # ie. MODEL_SERVER_HOST="model_server" docker compose -f docker-compose.dev.yml -p danswer-stack --profile model-server up -d --build model_server: image: danswer/danswer-model-server:latest build:
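
(Taken together, the retrieval call sites touched by this patch converge on one pattern: resolve the active embedding model row from Postgres, build the document index from its `index_name`, and pass the session down so the query can be embedded with that same model. A minimal sketch mirroring the eval_search.py hunk above; `run_search` is a hypothetical helper and the `SearchQuery` construction is assumed.)

# Illustrative sketch only
from sqlalchemy.orm import Session

from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.document_index.factory import get_default_document_index
from danswer.search.models import SearchQuery
from danswer.search.search_runner import full_chunk_search


def run_search(search_query: SearchQuery):
    with Session(get_sqlalchemy_engine()) as db_session:
        embedding_model = get_current_db_embedding_model(db_session)
        document_index = get_default_document_index(
            primary_index_name=embedding_model.index_name, secondary_index_name=None
        )
        # full_chunk_search embeds the query itself, resolving the current model via db_session
        top_chunks, llm_chunk_selection = full_chunk_search(
            query=search_query,
            document_index=document_index,
            db_session=db_session,
        )
        return top_chunks, llm_chunk_selection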