Kill Index Attempts for previous model (#1088)

This commit is contained in:
Yuhong Sun
2024-02-16 18:35:01 -08:00
committed by GitHub
parent 269431cc9d
commit 514e7f6e41
4 changed files with 86 additions and 55 deletions

View File

@ -28,6 +28,7 @@ from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_db_current_time from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt from danswer.db.index_attempt import get_index_attempt
@ -381,6 +382,9 @@ def check_index_swap(db_session: Session) -> None:
db_session=db_session, db_session=db_session,
) )
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates # Recount aggregates
for cc_pair in all_cc_pairs: for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session) resync_cc_pair(cc_pair, db_session=db_session)

View File

@ -33,7 +33,7 @@ def get_index_attempt(
def create_index_attempt( def create_index_attempt(
connector_id: int, connector_id: int,
credential_id: int, credential_id: int,
embedding_model_id: int | None, embedding_model_id: int,
db_session: Session, db_session: Session,
from_beginning: bool = False, from_beginning: bool = False,
) -> int: ) -> int:
@ -248,24 +248,41 @@ def cancel_indexing_attempts_for_connector(
EmbeddingModel.status != IndexModelStatus.FUTURE EmbeddingModel.status != IndexModelStatus.FUTURE
) )
stmt = delete(IndexAttempt).where( stmt = (
update(IndexAttempt)
.where(
IndexAttempt.connector_id == connector_id, IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED, IndexAttempt.status == IndexingStatus.NOT_STARTED,
) )
.values(status=IndexingStatus.FAILED)
)
if not include_secondary_index: if not include_secondary_index:
stmt = stmt.where( stmt = stmt.where(IndexAttempt.embedding_model_id.in_(subquery))
or_(
IndexAttempt.embedding_model_id.is_(None),
IndexAttempt.embedding_model_id.in_(subquery),
)
)
db_session.execute(stmt) db_session.execute(stmt)
db_session.commit() db_session.commit()
def cancel_indexing_attempts_past_model(
db_session: Session,
) -> None:
db_session.execute(
update(IndexAttempt)
.where(
IndexAttempt.status.in_(
[IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
),
IndexAttempt.embedding_model_id == EmbeddingModel.id,
EmbeddingModel.status == IndexModelStatus.PAST,
)
.values(status=IndexingStatus.FAILED)
)
db_session.commit()
def count_unique_cc_pairs_with_index_attempts( def count_unique_cc_pairs_with_index_attempts(
embedding_model_id: int | None, embedding_model_id: int | None,
db_session: Session, db_session: Session,

View File

@ -44,6 +44,7 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm from danswer.llm.factory import get_default_llm
from danswer.search.search_nlp_models import warm_up_models from danswer.search.search_nlp_models import warm_up_models
@ -209,6 +210,8 @@ def get_application() -> FastAPI:
@application.on_event("startup") @application.on_event("startup")
def startup_event() -> None: def startup_event() -> None:
engine = get_sqlalchemy_engine()
verify_auth = fetch_versioned_implementation( verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting" "danswer.auth.users", "verify_auth_setting"
) )
@ -242,20 +245,24 @@ def get_application() -> FastAPI:
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}" f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
) )
with Session(get_sqlalchemy_engine()) as db_session: with Session(engine) as db_session:
db_embedding_model = get_current_db_embedding_model(db_session) db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session) secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
if ENABLE_RERANKING_REAL_TIME_FLOW: cancel_indexing_attempts_past_model(db_session)
logger.info("Reranking step of search flow is enabled.")
logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"') logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix: if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"') logger.info(
f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
)
logger.info( logger.info(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"' f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
) )
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
if MODEL_SERVER_HOST: if MODEL_SERVER_HOST:
logger.info( logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}" f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
@ -280,7 +287,6 @@ def get_application() -> FastAPI:
nltk.download("punkt", quiet=True) nltk.download("punkt", quiet=True)
logger.info("Verifying default connector/credential exist.") logger.info("Verifying default connector/credential exist.")
with Session(get_sqlalchemy_engine()) as db_session:
create_initial_public_credential(db_session) create_initial_public_credential(db_session)
create_initial_default_connector(db_session) create_initial_default_connector(db_session)
associate_default_cc_pair(db_session) associate_default_cc_pair(db_session)

View File

@ -57,6 +57,7 @@ from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.embedding_model import get_current_db_embedding_model from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.index_attempt import cancel_indexing_attempts_for_connector from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import create_index_attempt from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempts from danswer.db.index_attempt import get_latest_index_attempts
@ -456,6 +457,9 @@ def update_connector_from_model(
if updated_connector.disabled: if updated_connector.disabled:
cancel_indexing_attempts_for_connector(connector_id, db_session) cancel_indexing_attempts_for_connector(connector_id, db_session)
# Just for good measure
cancel_indexing_attempts_past_model(db_session)
return ConnectorSnapshot( return ConnectorSnapshot(
id=updated_connector.id, id=updated_connector.id,
name=updated_connector.name, name=updated_connector.name,