Kill Index Attempts for previous model (#1088)

This commit is contained in:
Yuhong Sun 2024-02-16 18:35:01 -08:00 committed by GitHub
parent 269431cc9d
commit 514e7f6e41
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 86 additions and 55 deletions

View File

@ -28,6 +28,7 @@ from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_unique_cc_pairs_with_index_attempts
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
@ -381,6 +382,9 @@ def check_index_swap(db_session: Session) -> None:
db_session=db_session,
)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)

View File

@ -33,7 +33,7 @@ def get_index_attempt(
def create_index_attempt(
connector_id: int,
credential_id: int,
embedding_model_id: int | None,
embedding_model_id: int,
db_session: Session,
from_beginning: bool = False,
) -> int:
@ -248,24 +248,41 @@ def cancel_indexing_attempts_for_connector(
EmbeddingModel.status != IndexModelStatus.FUTURE
)
stmt = delete(IndexAttempt).where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED,
stmt = (
update(IndexAttempt)
.where(
IndexAttempt.connector_id == connector_id,
IndexAttempt.status == IndexingStatus.NOT_STARTED,
)
.values(status=IndexingStatus.FAILED)
)
if not include_secondary_index:
stmt = stmt.where(
or_(
IndexAttempt.embedding_model_id.is_(None),
IndexAttempt.embedding_model_id.in_(subquery),
)
)
stmt = stmt.where(IndexAttempt.embedding_model_id.in_(subquery))
db_session.execute(stmt)
db_session.commit()
def cancel_indexing_attempts_past_model(
db_session: Session,
) -> None:
db_session.execute(
update(IndexAttempt)
.where(
IndexAttempt.status.in_(
[IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
),
IndexAttempt.embedding_model_id == EmbeddingModel.id,
EmbeddingModel.status == IndexModelStatus.PAST,
)
.values(status=IndexingStatus.FAILED)
)
db_session.commit()
def count_unique_cc_pairs_with_index_attempts(
embedding_model_id: int | None,
db_session: Session,

View File

@ -44,6 +44,7 @@ from danswer.db.credentials import create_initial_public_credential
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.document_index.factory import get_default_document_index
from danswer.llm.factory import get_default_llm
from danswer.search.search_nlp_models import warm_up_models
@ -209,6 +210,8 @@ def get_application() -> FastAPI:
@application.on_event("startup")
def startup_event() -> None:
engine = get_sqlalchemy_engine()
verify_auth = fetch_versioned_implementation(
"danswer.auth.users", "verify_auth_setting"
)
@ -242,66 +245,69 @@ def get_application() -> FastAPI:
f"Using multilingual flow with languages: {MULTILINGUAL_QUERY_EXPANSION}"
)
with Session(get_sqlalchemy_engine()) as db_session:
with Session(engine) as db_session:
db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
cancel_indexing_attempts_past_model(db_session)
logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.info(f'Query embedding prefix: "{db_embedding_model.query_prefix}"')
logger.info(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
)
logger.info(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.info(
f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
)
logger.info(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
)
if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("Warming up local NLP models.")
warm_up_models(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
)
if ENABLE_RERANKING_REAL_TIME_FLOW:
logger.info("Reranking step of search flow is enabled.")
if torch.cuda.is_available():
logger.info("GPU is available")
if MODEL_SERVER_HOST:
logger.info(
f"Using Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
)
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
logger.info("Warming up local NLP models.")
warm_up_models(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
skip_cross_encoders=not ENABLE_RERANKING_REAL_TIME_FLOW,
)
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
if torch.cuda.is_available():
logger.info("GPU is available")
else:
logger.info("GPU is not available")
logger.info(f"Torch Threads: {torch.get_num_threads()}")
logger.info("Verifying default connector/credential exist.")
with Session(get_sqlalchemy_engine()) as db_session:
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords", quiet=True)
nltk.download("wordnet", quiet=True)
nltk.download("punkt", quiet=True)
logger.info("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.info("Loading default Prompts and Personas")
load_chat_yamls()
logger.info("Loading default Prompts and Personas")
load_chat_yamls()
logger.info("Verifying Document Index(s) is/are available.")
logger.info("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
else None,
)
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
else None,
)
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
else None,
)
optional_telemetry(
record_type=RecordType.VERSION, data={"version": __version__}

View File

@ -57,6 +57,7 @@ from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.index_attempt import cancel_indexing_attempts_for_connector
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
from danswer.db.index_attempt import get_latest_index_attempts
@ -456,6 +457,9 @@ def update_connector_from_model(
if updated_connector.disabled:
cancel_indexing_attempts_for_connector(connector_id, db_session)
# Just for good measure
cancel_indexing_attempts_past_model(db_session)
return ConnectorSnapshot(
id=updated_connector.id,
name=updated_connector.name,