From cb2169f2a39adf4caaadad984d7f3d9a7c584ff9 Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Thu, 12 Sep 2024 15:12:17 -0700 Subject: [PATCH] Warm up reranker on model switch (#2408) * warm up reranker on model switch * properly type * fix issue * Update search_settings.py --- backend/danswer/db/search_settings.py | 9 +++++++++ web/src/app/admin/embeddings/RerankingFormPage.tsx | 4 +++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/backend/danswer/db/search_settings.py b/backend/danswer/db/search_settings.py index 01f458493f7d..bb869c471dc6 100644 --- a/backend/danswer/db/search_settings.py +++ b/backend/danswer/db/search_settings.py @@ -20,6 +20,7 @@ from danswer.db.models import IndexModelStatus from danswer.db.models import SearchSettings from danswer.indexing.models import IndexingSetting from danswer.natural_language_processing.search_nlp_models import clean_model_name +from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder from danswer.search.models import SavedSearchSettings from danswer.server.manage.embedding.models import ( CloudEmbeddingProvider as ServerCloudEmbeddingProvider, @@ -180,6 +181,14 @@ def update_current_search_settings( logger.warning("No current search settings found to update") return + # Whenever we update the current search settings, we should ensure that the local reranking model is warmed up. + if ( + current_settings.provider_type is None + and search_settings.rerank_model_name is not None + and current_settings.rerank_model_name != search_settings.rerank_model_name + ): + warm_up_cross_encoder(search_settings.rerank_model_name) + update_search_settings(current_settings, search_settings, preserved_fields) db_session.commit() logger.info("Current search settings updated successfully") diff --git a/web/src/app/admin/embeddings/RerankingFormPage.tsx b/web/src/app/admin/embeddings/RerankingFormPage.tsx index d798498e8b33..eadcadc9654c 100644 --- a/web/src/app/admin/embeddings/RerankingFormPage.tsx +++ b/web/src/app/admin/embeddings/RerankingFormPage.tsx @@ -66,6 +66,7 @@ const RerankingDetailsForm = forwardRef< > {({ values, setFieldValue, resetForm }) => { const resetRerankingValues = () => { + setRerankingDetails(originalRerankingDetails); resetForm(); }; @@ -191,7 +192,8 @@ const RerankingDetailsForm = forwardRef< {card.rerank_provider_type === RerankerProvider.LITELLM ? ( - ) : RerankerProvider.COHERE ? ( + ) : card.rerank_provider_type === + RerankerProvider.COHERE ? ( ) : (