From 53387ab3eb12e8274b5e767f0775f879b98230ec Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Tue, 20 Aug 2024 17:31:00 -0700 Subject: [PATCH] Simplify index and model name swap logic (#2188) --- ...5951_remove__dim_suffix_from_model_name.py | 31 +++++++++++++++++++ backend/danswer/db/embedding_model.py | 2 +- backend/danswer/indexing/models.py | 5 +-- .../danswer/server/manage/search_settings.py | 12 +++++-- backend/model_server/encoders.py | 14 +++------ 5 files changed, 48 insertions(+), 16 deletions(-) create mode 100644 backend/alembic/versions/d9ec13955951_remove__dim_suffix_from_model_name.py diff --git a/backend/alembic/versions/d9ec13955951_remove__dim_suffix_from_model_name.py b/backend/alembic/versions/d9ec13955951_remove__dim_suffix_from_model_name.py new file mode 100644 index 0000000000..0e84d5fe85 --- /dev/null +++ b/backend/alembic/versions/d9ec13955951_remove__dim_suffix_from_model_name.py @@ -0,0 +1,31 @@ +"""Remove _alt suffix from model_name + +Revision ID: d9ec13955951 +Revises: da4c21c69164 +Create Date: 2024-08-20 16:31:32.955686 + +""" + +from alembic import op + + +# revision identifiers, used by Alembic. +revision = "d9ec13955951" +down_revision = "da4c21c69164" +branch_labels: None = None +depends_on: None = None + + +def upgrade() -> None: + op.execute( + """ + UPDATE embedding_model + SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '') + WHERE model_name LIKE '%__danswer_alt_index' + """ + ) + + +def downgrade() -> None: + # We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty + pass diff --git a/backend/danswer/db/embedding_model.py b/backend/danswer/db/embedding_model.py index 1709286108..1af6b7b7ef 100644 --- a/backend/danswer/db/embedding_model.py +++ b/backend/danswer/db/embedding_model.py @@ -39,7 +39,7 @@ def create_embedding_model( cloud_provider_id=model_details.cloud_provider_id, # Every single embedding model except the initial one from migrations has this name # The initial one from migration is called "danswer_chunk" - index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}", + index_name=model_details.index_name, ) db_session.add(embedding_model) diff --git a/backend/danswer/indexing/models.py b/backend/danswer/indexing/models.py index c70176b8f9..6357056eab 100644 --- a/backend/danswer/indexing/models.py +++ b/backend/danswer/indexing/models.py @@ -5,7 +5,6 @@ from pydantic import BaseModel from danswer.access.models import DocumentAccess from danswer.connectors.models import Document from danswer.utils.logger import setup_logger -from shared_configs.configs import ALT_INDEX_SUFFIX from shared_configs.model_server_models import Embedding if TYPE_CHECKING: @@ -102,6 +101,7 @@ class EmbeddingModelDetail(BaseModel): passage_prefix: str | None cloud_provider_id: int | None = None cloud_provider_name: str | None = None + index_name: str | None = None @classmethod def from_model( @@ -111,10 +111,11 @@ class EmbeddingModelDetail(BaseModel): return cls( # When constructing EmbeddingModel Detail for user-facing flows, strip the # unneeded additional data after the `_`s - model_name=embedding_model.model_name.removesuffix(ALT_INDEX_SUFFIX), + model_name=embedding_model.model_name, model_dim=embedding_model.model_dim, normalize=embedding_model.normalize, query_prefix=embedding_model.query_prefix, passage_prefix=embedding_model.passage_prefix, cloud_provider_id=embedding_model.cloud_provider_id, + index_name=embedding_model.index_name, ) diff --git a/backend/danswer/server/manage/search_settings.py b/backend/danswer/server/manage/search_settings.py index 39356521d2..14a612e516 100644 --- a/backend/danswer/server/manage/search_settings.py +++ b/backend/danswer/server/manage/search_settings.py @@ -20,6 +20,7 @@ from danswer.db.models import IndexModelStatus from danswer.db.models import User from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import EmbeddingModelDetail +from danswer.natural_language_processing.search_nlp_models import clean_model_name from danswer.search.models import SavedSearchSettings from danswer.search.search_settings import get_search_settings from danswer.search.search_settings import update_search_settings @@ -56,10 +57,15 @@ def set_new_embedding_model( embed_model_details.cloud_provider_id = cloud_id + embed_model_details.index_name = ( + f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}" + ) # account for same model name being indexed with two different configurations - if embed_model_details.model_name == current_model.model_name: - if not current_model.model_name.endswith(ALT_INDEX_SUFFIX): - embed_model_details.model_name += ALT_INDEX_SUFFIX + if ( + embed_model_details.model_name == current_model.model_name + and not current_model.index_name.endswith(ALT_INDEX_SUFFIX) + ): + embed_model_details.index_name += ALT_INDEX_SUFFIX secondary_model = get_secondary_db_embedding_model(db_session) diff --git a/backend/model_server/encoders.py b/backend/model_server/encoders.py index 13e79aece1..4e97bd00f2 100644 --- a/backend/model_server/encoders.py +++ b/backend/model_server/encoders.py @@ -23,7 +23,6 @@ from model_server.constants import DEFAULT_VOYAGE_MODEL from model_server.constants import EmbeddingModelTextType from model_server.constants import EmbeddingProvider from model_server.utils import simple_log_function_time -from shared_configs.configs import ALT_INDEX_SUFFIX from shared_configs.configs import INDEXING_ONLY from shared_configs.enums import EmbedTextType from shared_configs.enums import RerankerProvider @@ -250,11 +249,6 @@ def embed_text( if not all(texts): raise ValueError("Empty strings are not allowed for embedding.") - # strip additional metadata from model name right before constructing embedding requests - stripped_model_name = ( - model_name.removesuffix(ALT_INDEX_SUFFIX) if model_name else None - ) - # Third party API based embedding model if not texts: raise ValueError("No texts provided for embedding.") @@ -272,11 +266,11 @@ def embed_text( ) cloud_model = CloudEmbedding( - api_key=api_key, provider=provider_type, model=stripped_model_name + api_key=api_key, provider=provider_type, model=model_name ) embeddings = cloud_model.embed( texts=texts, - model_name=stripped_model_name, + model_name=model_name, text_type=text_type, ) @@ -287,11 +281,11 @@ def embed_text( error_message += "\n".join(texts) raise ValueError(error_message) - elif stripped_model_name is not None: + elif model_name is not None: prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts local_model = get_embedding_model( - model_name=stripped_model_name, max_context_length=max_context_length + model_name=model_name, max_context_length=max_context_length ) embeddings_vectors = local_model.encode( prefixed_texts, normalize_embeddings=normalize_embeddings