mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-06-30 01:30:45 +02:00
Simplify index and model name swap logic (#2188)
This commit is contained in:
@ -0,0 +1,31 @@
|
|||||||
|
"""Remove _alt suffix from model_name
|
||||||
|
|
||||||
|
Revision ID: d9ec13955951
|
||||||
|
Revises: da4c21c69164
|
||||||
|
Create Date: 2024-08-20 16:31:32.955686
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "d9ec13955951"
|
||||||
|
down_revision = "da4c21c69164"
|
||||||
|
branch_labels: None = None
|
||||||
|
depends_on: None = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
UPDATE embedding_model
|
||||||
|
SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '')
|
||||||
|
WHERE model_name LIKE '%__danswer_alt_index'
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty
|
||||||
|
pass
|
@ -39,7 +39,7 @@ def create_embedding_model(
|
|||||||
cloud_provider_id=model_details.cloud_provider_id,
|
cloud_provider_id=model_details.cloud_provider_id,
|
||||||
# Every single embedding model except the initial one from migrations has this name
|
# Every single embedding model except the initial one from migrations has this name
|
||||||
# The initial one from migration is called "danswer_chunk"
|
# The initial one from migration is called "danswer_chunk"
|
||||||
index_name=f"danswer_chunk_{clean_model_name(model_details.model_name)}",
|
index_name=model_details.index_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
db_session.add(embedding_model)
|
db_session.add(embedding_model)
|
||||||
|
@ -5,7 +5,6 @@ from pydantic import BaseModel
|
|||||||
from danswer.access.models import DocumentAccess
|
from danswer.access.models import DocumentAccess
|
||||||
from danswer.connectors.models import Document
|
from danswer.connectors.models import Document
|
||||||
from danswer.utils.logger import setup_logger
|
from danswer.utils.logger import setup_logger
|
||||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
|
||||||
from shared_configs.model_server_models import Embedding
|
from shared_configs.model_server_models import Embedding
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@ -102,6 +101,7 @@ class EmbeddingModelDetail(BaseModel):
|
|||||||
passage_prefix: str | None
|
passage_prefix: str | None
|
||||||
cloud_provider_id: int | None = None
|
cloud_provider_id: int | None = None
|
||||||
cloud_provider_name: str | None = None
|
cloud_provider_name: str | None = None
|
||||||
|
index_name: str | None = None
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_model(
|
def from_model(
|
||||||
@ -111,10 +111,11 @@ class EmbeddingModelDetail(BaseModel):
|
|||||||
return cls(
|
return cls(
|
||||||
# When constructing EmbeddingModel Detail for user-facing flows, strip the
|
# When constructing EmbeddingModel Detail for user-facing flows, strip the
|
||||||
# unneeded additional data after the `_`s
|
# unneeded additional data after the `_`s
|
||||||
model_name=embedding_model.model_name.removesuffix(ALT_INDEX_SUFFIX),
|
model_name=embedding_model.model_name,
|
||||||
model_dim=embedding_model.model_dim,
|
model_dim=embedding_model.model_dim,
|
||||||
normalize=embedding_model.normalize,
|
normalize=embedding_model.normalize,
|
||||||
query_prefix=embedding_model.query_prefix,
|
query_prefix=embedding_model.query_prefix,
|
||||||
passage_prefix=embedding_model.passage_prefix,
|
passage_prefix=embedding_model.passage_prefix,
|
||||||
cloud_provider_id=embedding_model.cloud_provider_id,
|
cloud_provider_id=embedding_model.cloud_provider_id,
|
||||||
|
index_name=embedding_model.index_name,
|
||||||
)
|
)
|
||||||
|
@ -20,6 +20,7 @@ from danswer.db.models import IndexModelStatus
|
|||||||
from danswer.db.models import User
|
from danswer.db.models import User
|
||||||
from danswer.document_index.factory import get_default_document_index
|
from danswer.document_index.factory import get_default_document_index
|
||||||
from danswer.indexing.models import EmbeddingModelDetail
|
from danswer.indexing.models import EmbeddingModelDetail
|
||||||
|
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||||
from danswer.search.models import SavedSearchSettings
|
from danswer.search.models import SavedSearchSettings
|
||||||
from danswer.search.search_settings import get_search_settings
|
from danswer.search.search_settings import get_search_settings
|
||||||
from danswer.search.search_settings import update_search_settings
|
from danswer.search.search_settings import update_search_settings
|
||||||
@ -56,10 +57,15 @@ def set_new_embedding_model(
|
|||||||
|
|
||||||
embed_model_details.cloud_provider_id = cloud_id
|
embed_model_details.cloud_provider_id = cloud_id
|
||||||
|
|
||||||
|
embed_model_details.index_name = (
|
||||||
|
f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}"
|
||||||
|
)
|
||||||
# account for same model name being indexed with two different configurations
|
# account for same model name being indexed with two different configurations
|
||||||
if embed_model_details.model_name == current_model.model_name:
|
if (
|
||||||
if not current_model.model_name.endswith(ALT_INDEX_SUFFIX):
|
embed_model_details.model_name == current_model.model_name
|
||||||
embed_model_details.model_name += ALT_INDEX_SUFFIX
|
and not current_model.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||||
|
):
|
||||||
|
embed_model_details.index_name += ALT_INDEX_SUFFIX
|
||||||
|
|
||||||
secondary_model = get_secondary_db_embedding_model(db_session)
|
secondary_model = get_secondary_db_embedding_model(db_session)
|
||||||
|
|
||||||
|
@ -23,7 +23,6 @@ from model_server.constants import DEFAULT_VOYAGE_MODEL
|
|||||||
from model_server.constants import EmbeddingModelTextType
|
from model_server.constants import EmbeddingModelTextType
|
||||||
from model_server.constants import EmbeddingProvider
|
from model_server.constants import EmbeddingProvider
|
||||||
from model_server.utils import simple_log_function_time
|
from model_server.utils import simple_log_function_time
|
||||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
|
||||||
from shared_configs.configs import INDEXING_ONLY
|
from shared_configs.configs import INDEXING_ONLY
|
||||||
from shared_configs.enums import EmbedTextType
|
from shared_configs.enums import EmbedTextType
|
||||||
from shared_configs.enums import RerankerProvider
|
from shared_configs.enums import RerankerProvider
|
||||||
@ -250,11 +249,6 @@ def embed_text(
|
|||||||
if not all(texts):
|
if not all(texts):
|
||||||
raise ValueError("Empty strings are not allowed for embedding.")
|
raise ValueError("Empty strings are not allowed for embedding.")
|
||||||
|
|
||||||
# strip additional metadata from model name right before constructing embedding requests
|
|
||||||
stripped_model_name = (
|
|
||||||
model_name.removesuffix(ALT_INDEX_SUFFIX) if model_name else None
|
|
||||||
)
|
|
||||||
|
|
||||||
# Third party API based embedding model
|
# Third party API based embedding model
|
||||||
if not texts:
|
if not texts:
|
||||||
raise ValueError("No texts provided for embedding.")
|
raise ValueError("No texts provided for embedding.")
|
||||||
@ -272,11 +266,11 @@ def embed_text(
|
|||||||
)
|
)
|
||||||
|
|
||||||
cloud_model = CloudEmbedding(
|
cloud_model = CloudEmbedding(
|
||||||
api_key=api_key, provider=provider_type, model=stripped_model_name
|
api_key=api_key, provider=provider_type, model=model_name
|
||||||
)
|
)
|
||||||
embeddings = cloud_model.embed(
|
embeddings = cloud_model.embed(
|
||||||
texts=texts,
|
texts=texts,
|
||||||
model_name=stripped_model_name,
|
model_name=model_name,
|
||||||
text_type=text_type,
|
text_type=text_type,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -287,11 +281,11 @@ def embed_text(
|
|||||||
error_message += "\n".join(texts)
|
error_message += "\n".join(texts)
|
||||||
raise ValueError(error_message)
|
raise ValueError(error_message)
|
||||||
|
|
||||||
elif stripped_model_name is not None:
|
elif model_name is not None:
|
||||||
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts
|
||||||
|
|
||||||
local_model = get_embedding_model(
|
local_model = get_embedding_model(
|
||||||
model_name=stripped_model_name, max_context_length=max_context_length
|
model_name=model_name, max_context_length=max_context_length
|
||||||
)
|
)
|
||||||
embeddings_vectors = local_model.encode(
|
embeddings_vectors = local_model.encode(
|
||||||
prefixed_texts, normalize_embeddings=normalize_embeddings
|
prefixed_texts, normalize_embeddings=normalize_embeddings
|
||||||
|
Reference in New Issue
Block a user