danswer/backend/onyx/db/search_settings.py
Chris Weaver f25e1e80f6
Add option to not re-index (#4157)
* Add option to not re-index

* Add quantizaton / dimensionality override support

* Fix build / ut
2025-03-03 10:54:11 -08:00

338 lines
12 KiB
Python

from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.configs.model_configs import ASYM_PASSAGE_PREFIX
from onyx.configs.model_configs import ASYM_QUERY_PREFIX
from onyx.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import DOC_EMBEDDING_DIM
from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import NORMALIZE_EMBEDDINGS
from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_embedding_provider
from onyx.db.models import CloudEmbeddingProvider
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.indexing.models import IndexingSetting
from onyx.natural_language_processing.search_nlp_models import clean_model_name
from onyx.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from onyx.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from onyx.utils.logger import setup_logger
from shared_configs.configs import PRESERVED_SEARCH_FIELDS
from shared_configs.enums import EmbeddingProvider
# Module-level logger, obtained from the project's setup_logger helper.
logger = setup_logger()
class ActiveSearchSettings:
    """Simple holder pairing a primary SearchSettings with an optional secondary."""

    # Always set by the constructor.
    primary: SearchSettings
    # None when no secondary settings were supplied.
    secondary: SearchSettings | None

    def __init__(
        self, primary: SearchSettings, secondary: SearchSettings | None
    ) -> None:
        """Store the given primary and (possibly None) secondary settings."""
        self.primary, self.secondary = primary, secondary
def create_search_settings(
    search_settings: SavedSearchSettings,
    db_session: Session,
    status: IndexModelStatus = IndexModelStatus.FUTURE,
) -> SearchSettings:
    """Create and commit a SearchSettings row copied from ``search_settings``.

    Args:
        search_settings: Source of every copied field.
        db_session: Session used to add and commit the new row.
        status: Status stored on the new row (defaults to FUTURE).

    Returns:
        The SearchSettings instance that was added and committed.
    """
    src = search_settings
    new_row = SearchSettings(
        # Embedding model fields.
        model_name=src.model_name,
        model_dim=src.model_dim,
        normalize=src.normalize,
        query_prefix=src.query_prefix,
        passage_prefix=src.passage_prefix,
        # Status comes from the argument, not from the saved settings.
        status=status,
        index_name=src.index_name,
        provider_type=src.provider_type,
        multipass_indexing=src.multipass_indexing,
        embedding_precision=src.embedding_precision,
        reduced_dimension=src.reduced_dimension,
        multilingual_expansion=src.multilingual_expansion,
        # Reranking fields.
        disable_rerank_for_streaming=src.disable_rerank_for_streaming,
        rerank_model_name=src.rerank_model_name,
        rerank_provider_type=src.rerank_provider_type,
        rerank_api_key=src.rerank_api_key,
        num_rerank=src.num_rerank,
        background_reindex_enabled=src.background_reindex_enabled,
    )

    db_session.add(new_row)
    db_session.commit()
    return new_row
def get_embedding_provider_from_provider_type(
    db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProvider | None:
    """Return the first CloudEmbeddingProvider row matching ``provider_type``.

    Returns None when the query yields no row.
    """
    matching_rows = db_session.execute(
        select(CloudEmbeddingProvider).where(
            CloudEmbeddingProvider.provider_type == provider_type
        )
    ).scalars()
    return matching_rows.first() or None
def get_current_db_embedding_provider(
    db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
    """Return the embedding provider referenced by the current search settings.

    Returns:
        None when the current settings have no ``provider_type``; otherwise the
        server-side provider model built from the fetched provider.

    Raises:
        RuntimeError: If a provider type is set but no provider is found for it.
    """
    provider_type = get_current_search_settings(db_session=db_session).provider_type
    if provider_type is None:
        return None

    provider_model = fetch_embedding_provider(
        db_session=db_session,
        provider_type=provider_type,
    )
    if provider_model is None:
        raise RuntimeError("No embedding provider exists for this model.")

    return ServerCloudEmbeddingProvider.from_request(
        cloud_provider_model=provider_model
    )
def delete_search_settings(db_session: Session, search_settings_id: int) -> None:
    """Delete a search settings row together with its index attempts.

    Args:
        db_session: Session used to execute the deletes and commit.
        search_settings_id: ID of the SearchSettings row to remove.

    Raises:
        ValueError: If the ID belongs to the current search settings.
    """
    if get_current_search_settings(db_session).id == search_settings_id:
        raise ValueError("Cannot delete currently active search settings")

    # Index attempts that point at these settings are removed first.
    db_session.execute(
        delete(IndexAttempt).where(
            IndexAttempt.search_settings_id == search_settings_id
        )
    )

    # The settings row itself is only deleted when its status is not PRESENT.
    db_session.execute(
        delete(SearchSettings).where(
            and_(
                SearchSettings.id == search_settings_id,
                SearchSettings.status != IndexModelStatus.PRESENT,
            )
        )
    )

    db_session.commit()
def get_current_search_settings(db_session: Session) -> SearchSettings:
    """Return the PRESENT search settings row with the highest ID.

    Raises:
        RuntimeError: If no row with PRESENT status exists.
    """
    present_newest_first = (
        select(SearchSettings)
        .where(SearchSettings.status == IndexModelStatus.PRESENT)
        .order_by(SearchSettings.id.desc())
    )
    current = db_session.execute(present_newest_first).scalars().first()

    if current:
        return current
    raise RuntimeError("No search settings specified, DB is not in a valid state")
def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
    """Return the FUTURE search settings row with the highest ID, or None."""
    future_newest_first = (
        select(SearchSettings)
        .where(SearchSettings.status == IndexModelStatus.FUTURE)
        .order_by(SearchSettings.id.desc())
    )
    return db_session.execute(future_newest_first).scalars().first()
def get_active_search_settings(db_session: Session) -> ActiveSearchSettings:
    """Returns active search settings. Secondary search settings may be None."""
    # Keyword arguments are evaluated left to right: primary first, then secondary.
    return ActiveSearchSettings(
        primary=get_current_search_settings(db_session),
        secondary=get_secondary_search_settings(db_session),
    )
def get_active_search_settings_list(db_session: Session) -> list[SearchSettings]:
    """Returns active search settings as a list. Primary settings are the first element,
    and if secondary search settings exist, they will be the second element."""
    active = get_active_search_settings(db_session)

    if active.secondary:
        return [active.primary, active.secondary]
    return [active.primary]
def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
    """Return every SearchSettings row, ordered by ID descending."""
    newest_first = select(SearchSettings).order_by(SearchSettings.id.desc())
    return list(db_session.execute(newest_first).scalars().all())
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
    """Return ``multilingual_expansion`` from the current search settings.

    Args:
        db_session: Session to query with. When None, a session from
            ``get_session_with_current_tenant`` is opened for the lookup.

    Returns:
        The ``multilingual_expansion`` value, or an empty list if the looked-up
        settings object is falsy.
    """
    if db_session is not None:
        current = get_current_search_settings(db_session)
    else:
        with get_session_with_current_tenant() as tenant_session:
            current = get_current_search_settings(tenant_session)

    return current.multilingual_expansion if current else []
def update_search_settings(
    current_settings: SearchSettings,
    updated_settings: SavedSearchSettings,
    preserved_fields: list[str],
) -> None:
    """Copy fields from ``updated_settings`` onto ``current_settings`` in place.

    Every key of ``updated_settings.dict()`` is set as an attribute on
    ``current_settings``, except keys listed in ``preserved_fields``.
    Nothing is committed here.
    """
    for name, new_value in updated_settings.dict().items():
        if name in preserved_fields:
            continue
        setattr(current_settings, name, new_value)
def update_current_search_settings(
    db_session: Session,
    search_settings: SavedSearchSettings,
    preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS,
) -> None:
    """Apply ``search_settings`` to the current search settings and commit.

    Args:
        db_session: Session used to look up the current settings and commit.
        search_settings: New values to copy onto the current settings.
        preserved_fields: Field names that must not be overwritten.
    """
    current_settings = get_current_search_settings(db_session)
    if not current_settings:
        logger.warning("No current search settings found to update")
        return

    # Whenever we update the current search settings, we should ensure that
    # the local reranking model is warmed up. This only runs when no rerank
    # provider type is set and the rerank model name actually changes.
    new_reranker = search_settings.rerank_model_name
    if (
        search_settings.rerank_provider_type is None
        and new_reranker is not None
        and new_reranker != current_settings.rerank_model_name
    ):
        warm_up_cross_encoder(new_reranker)

    update_search_settings(current_settings, search_settings, preserved_fields)
    db_session.commit()
    logger.info("Current search settings updated successfully")
def update_secondary_search_settings(
    db_session: Session,
    search_settings: SavedSearchSettings,
    preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS,
) -> None:
    """Apply ``search_settings`` to the secondary search settings and commit.

    If no secondary settings exist, a warning is logged and nothing is changed.

    Args:
        db_session: Session used to look up the secondary settings and commit.
        search_settings: New values to copy onto the secondary settings.
        preserved_fields: Field names that must not be overwritten. Defaults to
            PRESERVED_SEARCH_FIELDS.
    """
    secondary_settings = get_secondary_search_settings(db_session)
    if not secondary_settings:
        logger.warning("No secondary search settings found to update")
        return

    # Use the caller-supplied preserved_fields as-is. It must not be reset to
    # PRESERVED_SEARCH_FIELDS here, otherwise the parameter has no effect.
    update_search_settings(secondary_settings, search_settings, preserved_fields)
    db_session.commit()
    logger.info("Secondary search settings updated successfully")
def update_search_settings_status(
    search_settings: SearchSettings, new_status: IndexModelStatus, db_session: Session
) -> None:
    """Set ``search_settings.status`` to ``new_status`` and commit the session."""
    search_settings.status = new_status
    db_session.commit()
def user_has_overridden_embedding_model() -> bool:
    """Return True when DOCUMENT_ENCODER_MODEL differs from the default model."""
    return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
def get_old_default_search_settings() -> SearchSettings:
    """Build (without persisting) PRESENT SearchSettings for the "danswer_chunk" index.

    When the embedding model has been overridden, the configured model values
    and prefixes are used; otherwise the OLD_DEFAULT_* values are used with
    empty prefixes.
    """
    if user_has_overridden_embedding_model():
        model_name = DOCUMENT_ENCODER_MODEL
        model_dim = DOC_EMBEDDING_DIM
        normalize = NORMALIZE_EMBEDDINGS
        query_prefix = ASYM_QUERY_PREFIX
        passage_prefix = ASYM_PASSAGE_PREFIX
    else:
        model_name = OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
        model_dim = OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
        normalize = OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
        query_prefix = ""
        passage_prefix = ""

    return SearchSettings(
        model_name=model_name,
        model_dim=model_dim,
        normalize=normalize,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
        status=IndexModelStatus.PRESENT,
        index_name="danswer_chunk",
    )
def get_new_default_search_settings(is_present: bool) -> SearchSettings:
    """Build (without persisting) SearchSettings from the configured model values.

    Args:
        is_present: When True the status is PRESENT, otherwise FUTURE.
    """
    if is_present:
        status = IndexModelStatus.PRESENT
    else:
        status = IndexModelStatus.FUTURE

    # The index name is derived from the cleaned configured model name.
    index_name = f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}"

    return SearchSettings(
        model_name=DOCUMENT_ENCODER_MODEL,
        model_dim=DOC_EMBEDDING_DIM,
        normalize=NORMALIZE_EMBEDDINGS,
        query_prefix=ASYM_QUERY_PREFIX,
        passage_prefix=ASYM_PASSAGE_PREFIX,
        status=status,
        index_name=index_name,
    )
def get_old_default_embedding_model() -> IndexingSetting:
    """Build an IndexingSetting for the "danswer_chunk" index.

    When the embedding model has been overridden, the configured model values
    and prefixes are used; otherwise the OLD_DEFAULT_* values are used with
    empty prefixes. Precision is FLOAT, multipass indexing is off, and no API
    URL is set.
    """
    if user_has_overridden_embedding_model():
        model_name = DOCUMENT_ENCODER_MODEL
        model_dim = DOC_EMBEDDING_DIM
        normalize = NORMALIZE_EMBEDDINGS
        query_prefix = ASYM_QUERY_PREFIX
        passage_prefix = ASYM_PASSAGE_PREFIX
    else:
        model_name = OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
        model_dim = OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
        normalize = OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
        query_prefix = ""
        passage_prefix = ""

    return IndexingSetting(
        model_name=model_name,
        model_dim=model_dim,
        embedding_precision=EmbeddingPrecision.FLOAT,
        normalize=normalize,
        query_prefix=query_prefix,
        passage_prefix=passage_prefix,
        index_name="danswer_chunk",
        multipass_indexing=False,
        api_url=None,
    )
def get_new_default_embedding_model() -> IndexingSetting:
    """Build an IndexingSetting from the configured model values.

    Precision is FLOAT, multipass indexing is off, and no API URL is set.
    """
    # The index name is derived from the cleaned configured model name.
    index_name = f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}"

    return IndexingSetting(
        model_name=DOCUMENT_ENCODER_MODEL,
        model_dim=DOC_EMBEDDING_DIM,
        embedding_precision=EmbeddingPrecision.FLOAT,
        normalize=NORMALIZE_EMBEDDINGS,
        query_prefix=ASYM_QUERY_PREFIX,
        passage_prefix=ASYM_PASSAGE_PREFIX,
        index_name=index_name,
        multipass_indexing=False,
        api_url=None,
    )