mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-28 12:39:54 +02:00
Fix embedding model migration with existing index_attempts
This commit is contained in:
parent
4eaf2b1200
commit
e246ea9d3b
@ -7,7 +7,13 @@ Create Date: 2024-01-25 17:12:31.813160
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import table, column, String, Integer, Boolean
|
||||
|
||||
from danswer.db.embedding_model import (
|
||||
get_new_default_embedding_model,
|
||||
get_old_default_embedding_model,
|
||||
user_has_overridden_embedding_model,
|
||||
)
|
||||
from danswer.db.models import IndexModelStatus
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@ -34,6 +40,60 @@ def upgrade() -> None:
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# since all index attempts must be associated with an embedding model,
|
||||
# need to put something in here to avoid nulls. On server startup,
|
||||
# this value will be overridden
|
||||
EmbeddingModel = table(
|
||||
"embedding_model",
|
||||
column("id", Integer),
|
||||
column("model_name", String),
|
||||
column("model_dim", Integer),
|
||||
column("normalize", Boolean),
|
||||
column("query_prefix", String),
|
||||
column("passage_prefix", String),
|
||||
column("index_name", String),
|
||||
column(
|
||||
"status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False)
|
||||
),
|
||||
)
|
||||
# insert an embedding model row that corresponds to the embedding model
|
||||
# the user selected via env variables before this change. This is needed since
|
||||
# all index_attempts must be associated with an embedding model, so without this
|
||||
# we will run into violations of non-null constraints
|
||||
old_embedding_model = get_old_default_embedding_model()
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
{
|
||||
"model_name": old_embedding_model.model_name,
|
||||
"model_dim": old_embedding_model.model_dim,
|
||||
"normalize": old_embedding_model.normalize,
|
||||
"query_prefix": old_embedding_model.query_prefix,
|
||||
"passage_prefix": old_embedding_model.passage_prefix,
|
||||
"index_name": old_embedding_model.index_name,
|
||||
"status": old_embedding_model.status,
|
||||
}
|
||||
],
|
||||
)
|
||||
# if the user has not overridden the default embedding model via env variables,
|
||||
# insert the new default model into the database to auto-upgrade them
|
||||
if not user_has_overridden_embedding_model():
|
||||
new_embedding_model = get_new_default_embedding_model(is_present=False)
|
||||
op.bulk_insert(
|
||||
EmbeddingModel,
|
||||
[
|
||||
{
|
||||
"model_name": new_embedding_model.model_name,
|
||||
"model_dim": new_embedding_model.model_dim,
|
||||
"normalize": new_embedding_model.normalize,
|
||||
"query_prefix": new_embedding_model.query_prefix,
|
||||
"passage_prefix": new_embedding_model.passage_prefix,
|
||||
"index_name": new_embedding_model.index_name,
|
||||
"status": IndexModelStatus.FUTURE,
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
op.add_column(
|
||||
"index_attempt",
|
||||
sa.Column("embedding_model_id", sa.Integer(), nullable=True),
|
||||
|
@ -10,7 +10,6 @@ from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
|
||||
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
@ -77,53 +76,40 @@ def update_embedding_model_status(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def insert_initial_embedding_models(db_session: Session) -> None:
|
||||
"""Should be called on startup to ensure that the initial
|
||||
embedding model is present in the DB."""
|
||||
existing_embedding_models = db_session.scalars(select(EmbeddingModel)).all()
|
||||
if existing_embedding_models:
|
||||
logger.error(
|
||||
"Called `insert_initial_embedding_models` but models already exist in the DB. Skipping."
|
||||
)
|
||||
return
|
||||
def user_has_overridden_embedding_model() -> bool:
    """Return True when the configured document encoder is not the default.

    Compares the module-level `DOCUMENT_ENCODER_MODEL` setting against
    `DEFAULT_DOCUMENT_ENCODER_MODEL`; any difference counts as an override.
    """
    differs_from_default = DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
    return differs_from_default
|
||||
|
||||
existing_cc_pairs = get_connector_credential_pairs(db_session)
|
||||
|
||||
# if the user is overriding the `DOCUMENT_ENCODER_MODEL`, then
|
||||
# allow them to continue to use that model and do nothing fancy
|
||||
# in the background OR if the user has no connectors, then we can
|
||||
# also just use the new model immediately
|
||||
can_skip_upgrade = (
|
||||
DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
or not existing_cc_pairs
|
||||
def get_old_default_embedding_model() -> EmbeddingModel:
    """Build an `EmbeddingModel` describing the model in use before the upgrade.

    If the configured encoder differs from the default (see
    `user_has_overridden_embedding_model`), the configured name, dimension,
    normalization flag and asymmetric prefixes are used. Otherwise the
    `OLD_DEFAULT_*` constants are used with empty query/passage prefixes.

    The returned object always has status `IndexModelStatus.PRESENT` and
    index name "danswer_chunk". This function only constructs the object;
    it does not touch any database session.
    """
    if user_has_overridden_embedding_model():
        # Overridden: mirror exactly what the current configuration specifies.
        name = DOCUMENT_ENCODER_MODEL
        dim = DOC_EMBEDDING_DIM
        should_normalize = NORMALIZE_EMBEDDINGS
        q_prefix = ASYM_QUERY_PREFIX
        p_prefix = ASYM_PASSAGE_PREFIX
    else:
        # Not overridden: fall back to the old default model's settings.
        name = OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
        dim = OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
        should_normalize = OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
        q_prefix = ""
        p_prefix = ""

    return EmbeddingModel(
        model_name=name,
        model_dim=dim,
        normalize=should_normalize,
        query_prefix=q_prefix,
        passage_prefix=p_prefix,
        status=IndexModelStatus.PRESENT,
        index_name="danswer_chunk",
    )
|
||||
|
||||
# if we need to automatically upgrade the user, then create
|
||||
# an entry which will automatically be replaced by the
|
||||
# below desired model
|
||||
if not can_skip_upgrade:
|
||||
embedding_model_to_upgrade = EmbeddingModel(
|
||||
model_name=OLD_DEFAULT_DOCUMENT_ENCODER_MODEL,
|
||||
model_dim=OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM,
|
||||
normalize=OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS,
|
||||
query_prefix="",
|
||||
passage_prefix="",
|
||||
status=IndexModelStatus.PRESENT,
|
||||
index_name="danswer_chunk",
|
||||
)
|
||||
db_session.add(embedding_model_to_upgrade)
|
||||
|
||||
desired_embedding_model = EmbeddingModel(
|
||||
def get_new_default_embedding_model(is_present: bool) -> EmbeddingModel:
|
||||
return EmbeddingModel(
|
||||
model_name=DOCUMENT_ENCODER_MODEL,
|
||||
model_dim=DOC_EMBEDDING_DIM,
|
||||
normalize=NORMALIZE_EMBEDDINGS,
|
||||
query_prefix=ASYM_QUERY_PREFIX,
|
||||
passage_prefix=ASYM_PASSAGE_PREFIX,
|
||||
status=IndexModelStatus.PRESENT
|
||||
if can_skip_upgrade
|
||||
else IndexModelStatus.FUTURE,
|
||||
status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
|
||||
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
|
||||
)
|
||||
db_session.add(desired_embedding_model)
|
||||
|
||||
db_session.commit()
|
||||
|
@ -47,7 +47,6 @@ from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.embedding_model import insert_initial_embedding_models
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
@ -252,13 +251,7 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
|
||||
with Session(engine) as db_session:
|
||||
try:
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
except RuntimeError:
|
||||
logger.info("No embedding model's found in DB, creating initial model.")
|
||||
insert_initial_embedding_models(db_session)
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
|
||||
# Break bad state for thrashing indexes
|
||||
|
Loading…
x
Reference in New Issue
Block a user