From 514e7f6e41cd316e3dda0f971c54ff4f65fc5c1c Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Fri, 16 Feb 2024 18:35:01 -0800 Subject: [PATCH] Kill Index Attempts for previous model (#1088) --- backend/danswer/background/update.py | 4 + backend/danswer/db/index_attempt.py | 37 +++++-- backend/danswer/main.py | 96 ++++++++++--------- backend/danswer/server/documents/connector.py | 4 + 4 files changed, 86 insertions(+), 55 deletions(-) diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 4818ec8ab..851ada5d0 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -28,6 +28,7 @@ from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.embedding_model import update_embedding_model_status from danswer.db.engine import get_db_current_time from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempt @@ -381,6 +382,9 @@ def check_index_swap(db_session: Session) -> None: db_session=db_session, ) + # Expire jobs for the now past index/embedding model + cancel_indexing_attempts_past_model(db_session) + # Recount aggregates for cc_pair in all_cc_pairs: resync_cc_pair(cc_pair, db_session=db_session) diff --git a/backend/danswer/db/index_attempt.py b/backend/danswer/db/index_attempt.py index 0d674b1aa..7e08c167b 100644 --- a/backend/danswer/db/index_attempt.py +++ b/backend/danswer/db/index_attempt.py @@ -33,7 +33,7 @@ def get_index_attempt( def create_index_attempt( connector_id: int, credential_id: int, - embedding_model_id: int | None, + embedding_model_id: int, db_session: Session, from_beginning: bool = False, ) -> int: @@ -248,24 +248,41 @@ def cancel_indexing_attempts_for_connector( EmbeddingModel.status != IndexModelStatus.FUTURE ) - stmt = delete(IndexAttempt).where( - IndexAttempt.connector_id == connector_id, - IndexAttempt.status == IndexingStatus.NOT_STARTED, + stmt = ( + update(IndexAttempt) + .where( + IndexAttempt.connector_id == connector_id, + IndexAttempt.status == IndexingStatus.NOT_STARTED, + ) + .values(status=IndexingStatus.FAILED) ) if not include_secondary_index: - stmt = stmt.where( - or_( - IndexAttempt.embedding_model_id.is_(None), - IndexAttempt.embedding_model_id.in_(subquery), - ) - ) + stmt = stmt.where(IndexAttempt.embedding_model_id.in_(subquery)) db_session.execute(stmt) db_session.commit() +def cancel_indexing_attempts_past_model( + db_session: Session, +) -> None: + db_session.execute( + update(IndexAttempt) + .where( + IndexAttempt.status.in_( + [IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED] + ), + IndexAttempt.embedding_model_id == EmbeddingModel.id, + EmbeddingModel.status == IndexModelStatus.PAST, + ) + .values(status=IndexingStatus.FAILED) + ) + + db_session.commit() + + def count_unique_cc_pairs_with_index_attempts( embedding_model_id: int | None, db_session: Session, diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 87d137aeb..6268263ae 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -44,6 +44,7 @@ from danswer.db.credentials import create_initial_public_credential from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.engine import get_sqlalchemy_engine +from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.document_index.factory import get_default_document_index from danswer.llm.factory import get_default_llm from danswer.search.search_nlp_models import warm_up_models @@ -209,6 +210,8 @@ def get_application() -> FastAPI: @application.on_event("startup") def startup_event() -> None: + engine = get_sqlalchemy_engine() + verify_auth = fetch_versioned_implementation( "danswer.auth.users", "verify_auth_setting" ) @@ -242,66 +245,69 @@ def get_application() -> FastAPI: f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}" ) - with Session(get_sqlalchemy_engine()) as db_session: + with Session(engine) as db_session: db_embedding_model = get_current_db_embedding_model(db_session) secondary_db_embedding_model = get_secondary_db_embedding_model(db_session) - if ENABLE_RERANKING_REAL_TIME_FLOW: - logger.info("Reranking step of search flow is enabled.") + cancel_indexing_attempts_past_model(db_session) - logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"') - if db_embedding_model.query_prefix or db_embedding_model.passage_prefix: - logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"') - logger.info( - f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"' - ) + logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"') + if db_embedding_model.query_prefix or db_embedding_model.passage_prefix: + logger.info( + f'Query embedding prefix: "{db_embedding_model.query_prefix}"' + ) + logger.info( + f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"' + ) - if MODEL_SERVER_HOST: - logger.info( - f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}" - ) - else: - logger.info("Warming up local NLP models.") - warm_up_models( - model_name=db_embedding_model.model_name, - normalize=db_embedding_model.normalize, - skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW, - ) + if ENABLE_RERANKING_REAL_TIME_FLOW: + logger.info("Reranking step of search flow is enabled.") - if torch.cuda.is_available(): - logger.info("GPU is available") + if MODEL_SERVER_HOST: + logger.info( + f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}" + ) else: - logger.info("GPU is not available") - logger.info(f"Torch Threads: {torch.get_num_threads()}") + logger.info("Warming up local NLP models.") + warm_up_models( + model_name=db_embedding_model.model_name, + normalize=db_embedding_model.normalize, + skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW, + ) - logger.info("Verifying query preprocessing (NLTK) data is downloaded") - nltk.download("stopwords", quiet=True) - nltk.download("wordnet", quiet=True) - nltk.download("punkt", quiet=True) + if torch.cuda.is_available(): + logger.info("GPU is available") + else: + logger.info("GPU is not available") + logger.info(f"Torch Threads: {torch.get_num_threads()}") - logger.info("Verifying default connector/credential exist.") - with Session(get_sqlalchemy_engine()) as db_session: + logger.info("Verifying query preprocessing (NLTK) data is downloaded") + nltk.download("stopwords", quiet=True) + nltk.download("wordnet", quiet=True) + nltk.download("punkt", quiet=True) + + logger.info("Verifying default connector/credential exist.") create_initial_public_credential(db_session) create_initial_default_connector(db_session) associate_default_cc_pair(db_session) - logger.info("Loading default Prompts and Personas") - load_chat_yamls() + logger.info("Loading default Prompts and Personas") + load_chat_yamls() - logger.info("Verifying Document Index(s) is/are available.") + logger.info("Verifying Document Index(s) is/are available.") - document_index = get_default_document_index( - primary_index_name=db_embedding_model.index_name, - secondary_index_name=secondary_db_embedding_model.index_name - if secondary_db_embedding_model - else None, - ) - document_index.ensure_indices_exist( - index_embedding_dim=db_embedding_model.model_dim, - secondary_index_embedding_dim=secondary_db_embedding_model.model_dim - if secondary_db_embedding_model - else None, - ) + document_index = get_default_document_index( + primary_index_name=db_embedding_model.index_name, + secondary_index_name=secondary_db_embedding_model.index_name + if secondary_db_embedding_model + else None, + ) + document_index.ensure_indices_exist( + index_embedding_dim=db_embedding_model.model_dim, + secondary_index_embedding_dim=secondary_db_embedding_model.model_dim + if secondary_db_embedding_model + else None, + ) optional_telemetry( record_type=RecordType.VERSION, data={"version": __version__} diff --git a/backend/danswer/server/documents/connector.py b/backend/danswer/server/documents/connector.py index 4c2b55bc7..a5fa93d40 100644 --- a/backend/danswer/server/documents/connector.py +++ b/backend/danswer/server/documents/connector.py @@ -57,6 +57,7 @@ from danswer.db.document import get_document_cnts_for_cc_pairs from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.engine import get_session from danswer.db.index_attempt import cancel_indexing_attempts_for_connector +from danswer.db.index_attempt import cancel_indexing_attempts_past_model from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import get_index_attempts_for_cc_pair from danswer.db.index_attempt import get_latest_index_attempts @@ -456,6 +457,9 @@ def update_connector_from_model( if updated_connector.disabled: cancel_indexing_attempts_for_connector(connector_id, db_session) + # Just for good measure + cancel_indexing_attempts_past_model(db_session) + return ConnectorSnapshot( id=updated_connector.id, name=updated_connector.name,