* k

* update enum imports

* add functional types + model swaps

* remove a log

* remove kv

* fully functional + robustified for kv swap

* validated with hosted + cloud

* ensure not updating current search settings when reindexing

* add instance check

* revert back to updating search settings (will need a slight refactor for endpoint)

* protect advanced config override1

* run pretty

* fix typing

* update typing

* remove unnecessary function

* update model name

* clearer interface names

* validated foreign key constraint

* proper migration

* squash

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
pablodanswer authored 2024-08-26 21:26:51 -07:00 (committed by GitHub)
parent 5f12b7ad58
commit 97ba71e1b3
54 changed files with 1078 additions and 673 deletions

View File

@ -0,0 +1,135 @@
"""embedding model -> search settings
Revision ID: 1f60f60c3401
Revises: f17bf3b0d9f1
Create Date: 2024-08-25 12:39:51.731632
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
down_revision = "f17bf3b0d9f1"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.drop_constraint(
"index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
)
# Rename the table
op.rename_table("embedding_model", "search_settings")
# Add new columns
op.add_column(
"search_settings",
sa.Column(
"multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
),
)
op.add_column(
"search_settings",
sa.Column(
"multilingual_expansion",
postgresql.ARRAY(sa.String()),
nullable=False,
server_default="{}",
),
)
op.add_column(
"search_settings",
sa.Column(
"disable_rerank_for_streaming",
sa.Boolean(),
nullable=False,
server_default="false",
),
)
op.add_column(
"search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
)
op.add_column(
"search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
)
op.add_column(
"search_settings",
sa.Column(
"num_rerank",
sa.Integer(),
nullable=False,
server_default=str(NUM_POSTPROCESSED_RESULTS),
),
)
# Add the new column as nullable initially
op.add_column(
"index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
)
# Populate the new column with data from the existing embedding_model_id
op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")
# Create the foreign key constraint
op.create_foreign_key(
"fk_index_attempt_search_settings",
"index_attempt",
"search_settings",
["search_settings_id"],
["id"],
)
# Make the new column non-nullable
op.alter_column("index_attempt", "search_settings_id", nullable=False)
# Drop the old embedding_model_id column
op.drop_column("index_attempt", "embedding_model_id")
def downgrade() -> None:
# Add back the embedding_model_id column
op.add_column(
"index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
)
# Populate the old column with data from search_settings_id
op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")
# Make the old column non-nullable
op.alter_column("index_attempt", "embedding_model_id", nullable=False)
# Drop the foreign key constraint
op.drop_constraint(
"fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
)
# Drop the new search_settings_id column
op.drop_column("index_attempt", "search_settings_id")
# Rename the table back
op.rename_table("search_settings", "embedding_model")
# Remove added columns
op.drop_column("embedding_model", "num_rerank")
op.drop_column("embedding_model", "rerank_api_key")
op.drop_column("embedding_model", "rerank_provider_type")
op.drop_column("embedding_model", "rerank_model_name")
op.drop_column("embedding_model", "disable_rerank_for_streaming")
op.drop_column("embedding_model", "multilingual_expansion")
op.drop_column("embedding_model", "multipass_indexing")
op.create_foreign_key(
"index_attempt__embedding_model_fk",
"index_attempt",
"embedding_model",
["embedding_model_id"],
["id"],
)
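
For local verification, a migration like this can be applied and rolled back through Alembic's programmatic command API. A minimal sketch, assuming the repository's own alembic.ini is on the current path (the config path is illustrative):

    from alembic import command
    from alembic.config import Config

    # Assumes the backend's alembic.ini; adjust the path as needed.
    cfg = Config("alembic.ini")

    # Apply the table rename and new columns up to this revision...
    command.upgrade(cfg, "1f60f60c3401")
    # ...then exercise the downgrade path to confirm it is reversible.
    command.downgrade(cfg, "f17bf3b0d9f1")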

View File

@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "351faebd379d"
down_revision = "ee3f4b47fad5"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -10,7 +10,7 @@ import sqlalchemy as sa
from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
from danswer.search.models import SearchType
from danswer.search.enums import SearchType
# revision identifiers, used by Alembic.
revision = "776b3bbe9092"

View File

@ -9,7 +9,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean
from danswer.db.embedding_model import (
from danswer.db.search_settings import (
get_new_default_embedding_model,
get_old_default_embedding_model,
user_has_overridden_embedding_model,
@ -71,14 +71,14 @@ def upgrade() -> None:
"query_prefix": old_embedding_model.query_prefix,
"passage_prefix": old_embedding_model.passage_prefix,
"index_name": old_embedding_model.index_name,
"status": old_embedding_model.status,
"status": IndexModelStatus.PRESENT,
}
],
)
# if the user has not overridden the default embedding model via env variables,
# insert the new default model into the database to auto-upgrade them
if not user_has_overridden_embedding_model():
new_embedding_model = get_new_default_embedding_model(is_present=False)
new_embedding_model = get_new_default_embedding_model()
op.bulk_insert(
EmbeddingModel,
[

View File

@ -13,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "2d2304e27d8c"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -13,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f17bf3b0d9f1"
down_revision = "351faebd379d"
branch_labels = None
depends_on = None
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:

View File

@ -90,20 +90,20 @@ def _run_indexing(
"""
start_time = time.time()
db_embedding_model = index_attempt.embedding_model
index_name = db_embedding_model.index_name
search_settings = index_attempt.search_settings
index_name = search_settings.index_name
# Only update cc-pair status for primary index jobs
# Secondary index syncs at the end when swapping
is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
is_primary = search_settings.status == IndexModelStatus.PRESENT
# Indexing is only done into one index at a time
document_index = get_default_document_index(
primary_index_name=index_name, secondary_index_name=None
)
embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
db_embedding_model
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)
indexing_pipeline = build_indexing_pipeline(
@ -111,7 +111,7 @@ def _run_indexing(
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=index_attempt.from_beginning
or (db_embedding_model.status == IndexModelStatus.FUTURE),
or (search_settings.status == IndexModelStatus.FUTURE),
db_session=db_session,
)
@ -128,7 +128,7 @@ def _run_indexing(
else get_last_successful_attempt_time(
connector_id=db_connector.id,
credential_id=db_credential.id,
embedding_model=index_attempt.embedding_model,
search_settings=index_attempt.search_settings,
db_session=db_session,
)
)
@ -185,7 +185,7 @@ def _run_indexing(
if (
(
db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
and db_embedding_model.status != IndexModelStatus.FUTURE
and search_settings.status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING

View File

@ -20,8 +20,6 @@ from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
@ -32,11 +30,14 @@ from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@ -60,7 +61,7 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
def _should_create_new_indexing(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
model: EmbeddingModel,
search_settings_instance: SearchSettings,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@ -69,11 +70,14 @@ def _should_create_new_indexing(
# User can still manually create single indexing attempts via the UI for the
# currently in use index
if DISABLE_INDEX_UPDATE_ON_SWAP:
if model.status == IndexModelStatus.PRESENT and secondary_index_building:
if (
search_settings_instance.status == IndexModelStatus.PRESENT
and secondary_index_building
):
return False
# When switching over models, always index at least once
if model.status == IndexModelStatus.FUTURE:
if search_settings_instance.status == IndexModelStatus.FUTURE:
if last_index:
# No new index if the last index attempt succeeded
# Once is enough. The model will never be able to swap otherwise.
@ -160,35 +164,42 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
ongoing.add(
(
attempt.connector_credential_pair_id,
attempt.embedding_model_id,
attempt.search_settings_id,
)
)
embedding_models = [get_current_db_embedding_model(db_session)]
secondary_embedding_model = get_secondary_db_embedding_model(db_session)
if secondary_embedding_model is not None:
embedding_models.append(secondary_embedding_model)
# Get the primary search settings
primary_search_settings = get_current_search_settings(db_session)
search_settings = [primary_search_settings]
# Check for secondary search settings
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings is not None:
# If secondary settings exist, add them to the list
search_settings.append(secondary_search_settings)
all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair in all_connector_credential_pairs:
for model in embedding_models:
for search_settings_instance in search_settings:
# Check if there is an ongoing indexing attempt for this connector credential pair
if (cc_pair.id, model.id) in ongoing:
if (cc_pair.id, search_settings_instance.id) in ongoing:
continue
last_attempt = get_last_attempt_for_cc_pair(
cc_pair.id, model.id, db_session
cc_pair.id, search_settings_instance.id, db_session
)
if not _should_create_new_indexing(
cc_pair=cc_pair,
last_index=last_attempt,
model=model,
secondary_index_building=len(embedding_models) > 1,
search_settings_instance=search_settings_instance,
secondary_index_building=len(search_settings) > 1,
db_session=db_session,
):
continue
create_index_attempt(cc_pair.id, model.id, db_session)
create_index_attempt(
cc_pair.id, search_settings_instance.id, db_session
)
def cleanup_indexing_jobs(
@ -285,7 +296,7 @@ def kickoff_indexing_jobs(
# get_not_started_index_attempts orders its returned results from oldest to newest
# we must process attempts in a FIFO manner to prevent connector starvation
new_indexing_attempts = [
(attempt, attempt.embedding_model)
(attempt, attempt.search_settings)
for attempt in get_not_started_index_attempts(db_session)
if attempt.id not in existing_jobs
]
@ -297,10 +308,10 @@ def kickoff_indexing_jobs(
indexing_attempt_count = 0
for attempt, embedding_model in new_indexing_attempts:
for attempt, search_settings in new_indexing_attempts:
use_secondary_index = (
embedding_model.status == IndexModelStatus.FUTURE
if embedding_model is not None
search_settings.status == IndexModelStatus.FUTURE
if search_settings is not None
else False
)
if attempt.connector_credential_pair.connector is None:
@ -373,17 +384,21 @@ def update_loop(
engine = get_sqlalchemy_engine()
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
if db_embedding_model.provider_type is None:
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=db_embedding_model,
model_server_host=INDEXING_MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
embedding_model=embedding_model,
)
client_primary: Client | SimpleJobClient

View File

@ -32,13 +32,13 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@ -331,9 +331,9 @@ def stream_chat_message_objects(
Callable[[str], list[int]], llm_tokenizer.encode
)
embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
primary_index_name=search_settings.index_name, secondary_index_name=None
)
# Every chat Session begins with an empty root message

View File

@ -37,6 +37,7 @@ from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
@ -48,8 +49,8 @@ from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
from danswer.search.search_settings import get_search_settings
from danswer.utils.logger import DanswerLoggingAdapter
@ -228,7 +229,8 @@ def handle_regular_answer(
)
# Always apply reranking settings if it exists, this is the non-streaming flow
saved_search_settings = get_search_settings()
with Session(get_sqlalchemy_engine()) as db_session:
saved_search_settings = get_current_search_settings(db_session)
# This includes throwing out answer via reflexion
answer = _get_answer(
@ -241,7 +243,7 @@ def handle_regular_answer(
persona_id=persona.id if persona is not None else 0,
retrieval_options=retrieval_details,
chain_of_thought=not disable_cot,
rerank_settings=saved_search_settings.to_reranking_detail()
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
if saved_search_settings
else None,
)

View File

@ -45,9 +45,10 @@ from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
@ -468,13 +469,16 @@ if __name__ == "__main__":
# This happens on the very first time the listener process comes up
# or the tokens have updated (set up for the first time)
with Session(get_sqlalchemy_engine()) as db_session:
embedding_model = get_current_db_embedding_model(db_session)
if embedding_model.provider_type is None:
warm_up_bi_encoder(
embedding_model=embedding_model,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
)
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(
embedding_model=embedding_model,
)
slack_bot_tokens = latest_slack_bot_tokens
# potentially may cause a message to be dropped, but it is complicated

View File

@ -14,10 +14,10 @@ from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
@ -159,12 +159,12 @@ def get_connector_credential_pair_from_id(
def get_last_successful_attempt_time(
connector_id: int,
credential_id: int,
embedding_model: EmbeddingModel,
search_settings: SearchSettings,
db_session: Session,
) -> float:
"""Gets the timestamp of the last successful index run stored in
the CC Pair row in the database"""
if embedding_model.status == IndexModelStatus.PRESENT:
if search_settings.status == IndexModelStatus.PRESENT:
connector_credential_pair = get_connector_credential_pair(
connector_id, credential_id, db_session
)
@ -186,7 +186,7 @@ def get_last_successful_attempt_time(
.filter(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model.id,
IndexAttempt.search_settings_id == search_settings.id,
IndexAttempt.status == IndexingStatus.SUCCESS,
)
.order_by(IndexAttempt.time_started.desc())
@ -445,11 +445,11 @@ def resync_cc_pair(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
.filter(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
EmbeddingModel.status == IndexModelStatus.PRESENT,
SearchSettings.status == IndexModelStatus.PRESENT,
)
)

View File

@ -1,9 +1,9 @@
from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexingStatus
from danswer.db.search_settings import get_current_search_settings
def check_deletion_attempt_is_allowed(
@ -28,12 +28,12 @@ def check_deletion_attempt_is_allowed(
connector_id = connector_credential_pair.connector_id
credential_id = connector_credential_pair.credential_id
current_embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
last_indexing = get_last_attempt(
connector_id=connector_id,
credential_id=credential_id,
embedding_model_id=current_embedding_model.id,
search_settings_id=search_settings.id,
db_session=db_session,
)

View File

@ -1,157 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelCreateRequest
from danswer.indexing.models import EmbeddingModelDetail
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
def create_embedding_model(
create_embed_model_details: EmbeddingModelCreateRequest,
db_session: Session,
status: IndexModelStatus = IndexModelStatus.FUTURE,
) -> EmbeddingModel:
embedding_model = EmbeddingModel(
model_name=create_embed_model_details.model_name,
model_dim=create_embed_model_details.model_dim,
normalize=create_embed_model_details.normalize,
query_prefix=create_embed_model_details.query_prefix,
passage_prefix=create_embed_model_details.passage_prefix,
status=status,
provider_type=create_embed_model_details.provider_type,
# Every single embedding model except the initial one from migrations has this name
# The initial one from migration is called "danswer_chunk"
index_name=create_embed_model_details.index_name,
)
db_session.add(embedding_model)
db_session.commit()
return embedding_model
def get_embedding_provider_from_provider_type(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProvider | None:
query = select(CloudEmbeddingProvider).where(
CloudEmbeddingProvider.provider_type == provider_type
)
provider = db_session.execute(query).scalars().first()
return provider if provider else None
def get_current_db_embedding_provider(
db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
current_embedding_model = EmbeddingModelDetail.from_model(
get_current_db_embedding_model(db_session=db_session)
)
if current_embedding_model is None or current_embedding_model.provider_type is None:
return None
embedding_provider = fetch_embedding_provider(
db_session=db_session,
provider_type=current_embedding_model.provider_type,
)
if embedding_provider is None:
raise RuntimeError("No embedding provider exists for this model.")
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
cloud_provider_model=embedding_provider
)
return current_embedding_provider
def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
query = (
select(EmbeddingModel)
.where(EmbeddingModel.status == IndexModelStatus.PRESENT)
.order_by(EmbeddingModel.id.desc())
)
result = db_session.execute(query)
latest_model = result.scalars().first()
if not latest_model:
raise RuntimeError("No embedding model selected, DB is not in a valid state")
return latest_model
def get_secondary_db_embedding_model(db_session: Session) -> EmbeddingModel | None:
query = (
select(EmbeddingModel)
.where(EmbeddingModel.status == IndexModelStatus.FUTURE)
.order_by(EmbeddingModel.id.desc())
)
result = db_session.execute(query)
latest_model = result.scalars().first()
return latest_model
def update_embedding_model_status(
embedding_model: EmbeddingModel, new_status: IndexModelStatus, db_session: Session
) -> None:
embedding_model.status = new_status
db_session.commit()
def user_has_overridden_embedding_model() -> bool:
return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
def get_old_default_embedding_model() -> EmbeddingModel:
is_overridden = user_has_overridden_embedding_model()
return EmbeddingModel(
model_name=(
DOCUMENT_ENCODER_MODEL
if is_overridden
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
),
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
),
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
status=IndexModelStatus.PRESENT,
index_name="danswer_chunk",
)
def get_new_default_embedding_model(is_present: bool) -> EmbeddingModel:
return EmbeddingModel(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,
status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
)

View File

@ -11,11 +11,11 @@ from sqlalchemy.orm import Session
from danswer.connectors.models import Document
from danswer.connectors.models import DocumentErrorSummary
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexAttemptError
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
@ -27,14 +27,14 @@ logger = setup_logger()
def get_last_attempt_for_cc_pair(
cc_pair_id: int,
embedding_model_id: int,
search_settings_id: int,
db_session: Session,
) -> IndexAttempt | None:
return (
db_session.query(IndexAttempt)
.filter(
IndexAttempt.connector_credential_pair_id == cc_pair_id,
IndexAttempt.embedding_model_id == embedding_model_id,
IndexAttempt.search_settings_id == search_settings_id,
)
.order_by(IndexAttempt.time_updated.desc())
.first()
@ -50,13 +50,13 @@ def get_index_attempt(
def create_index_attempt(
connector_credential_pair_id: int,
embedding_model_id: int,
search_settings_id: int,
db_session: Session,
from_beginning: bool = False,
) -> int:
new_attempt = IndexAttempt(
connector_credential_pair_id=connector_credential_pair_id,
embedding_model_id=embedding_model_id,
search_settings_id=search_settings_id,
from_beginning=from_beginning,
status=IndexingStatus.NOT_STARTED,
)
@ -162,7 +162,7 @@ def update_docs_indexed(
def get_last_attempt(
connector_id: int,
credential_id: int,
embedding_model_id: int | None,
search_settings_id: int | None,
db_session: Session,
) -> IndexAttempt | None:
stmt = (
@ -171,7 +171,7 @@ def get_last_attempt(
.where(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
IndexAttempt.embedding_model_id == embedding_model_id,
IndexAttempt.search_settings_id == search_settings_id,
)
)
@ -188,12 +188,12 @@ def get_latest_index_attempts(
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
if secondary_index:
ids_stmt = ids_stmt.where(EmbeddingModel.status == IndexModelStatus.FUTURE)
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(EmbeddingModel.status == IndexModelStatus.PRESENT)
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@ -229,8 +229,8 @@ def get_index_attempts_for_connector(
)
)
if only_current:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(IndexAttempt.time_created.desc())
@ -250,12 +250,12 @@ def get_latest_finished_index_attempt_for_cc_pair(
),
)
if secondary_index:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.FUTURE
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
@ -286,8 +286,8 @@ def get_index_attempts_for_cc_pair(
)
)
if only_current:
stmt = stmt.join(EmbeddingModel).where(
EmbeddingModel.status == IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(IndexAttempt.time_created.desc())
@ -309,19 +309,19 @@ def delete_index_attempts(
def expire_index_attempts(
embedding_model_id: int,
search_settings_id: int,
db_session: Session,
) -> None:
delete_query = (
delete(IndexAttempt)
.where(IndexAttempt.embedding_model_id == embedding_model_id)
.where(IndexAttempt.search_settings_id == search_settings_id)
.where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
)
db_session.execute(delete_query)
update_query = (
update(IndexAttempt)
.where(IndexAttempt.embedding_model_id == embedding_model_id)
.where(IndexAttempt.search_settings_id == search_settings_id)
.where(IndexAttempt.status != IndexingStatus.SUCCESS)
.values(
status=IndexingStatus.FAILED,
@ -345,10 +345,10 @@ def cancel_indexing_attempts_for_ccpair(
)
if not include_secondary_index:
subquery = select(EmbeddingModel.id).where(
EmbeddingModel.status != IndexModelStatus.FUTURE
subquery = select(SearchSettings.id).where(
SearchSettings.status != IndexModelStatus.FUTURE
)
stmt = stmt.where(IndexAttempt.embedding_model_id.in_(subquery))
stmt = stmt.where(IndexAttempt.search_settings_id.in_(subquery))
db_session.execute(stmt)
@ -366,8 +366,8 @@ def cancel_indexing_attempts_past_model(
IndexAttempt.status.in_(
[IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
),
IndexAttempt.embedding_model_id == EmbeddingModel.id,
EmbeddingModel.status == IndexModelStatus.PAST,
IndexAttempt.search_settings_id == SearchSettings.id,
SearchSettings.status == IndexModelStatus.PAST,
)
.values(status=IndexingStatus.FAILED)
)
@ -376,7 +376,7 @@ def cancel_indexing_attempts_past_model(
def count_unique_cc_pairs_with_successful_index_attempts(
embedding_model_id: int | None,
search_settings_id: int | None,
db_session: Session,
) -> int:
"""Collect all of the Index Attempts that are successful and for the specified embedding model
@ -386,7 +386,7 @@ def count_unique_cc_pairs_with_successful_index_attempts(
db_session.query(IndexAttempt.connector_credential_pair_id)
.join(ConnectorCredentialPair)
.filter(
IndexAttempt.embedding_model_id == embedding_model_id,
IndexAttempt.search_settings_id == search_settings_id,
IndexAttempt.status == IndexingStatus.SUCCESS,
)
.distinct()
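
With the rename in place, callers create index attempts keyed on a search settings id rather than an embedding model id. A minimal usage sketch (function names are taken from this file and from danswer.db.search_settings; the session setup and cc-pair id are assumptions for illustration):

    from sqlalchemy.orm import Session

    from danswer.db.engine import get_sqlalchemy_engine
    from danswer.db.index_attempt import create_index_attempt
    from danswer.db.search_settings import get_current_search_settings

    cc_pair_id = 1  # id of an existing ConnectorCredentialPair (illustrative)

    with Session(get_sqlalchemy_engine()) as db_session:
        search_settings = get_current_search_settings(db_session)
        attempt_id = create_index_attempt(
            connector_credential_pair_id=cc_pair_id,
            search_settings_id=search_settings.id,
            db_session=db_session,
        )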

View File

@ -34,6 +34,7 @@ from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator
from danswer.auth.schemas import UserRole
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
@ -56,6 +57,7 @@ from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import RerankerProvider
class Base(DeclarativeBase):
@ -543,33 +545,47 @@ class Credential(Base):
user: Mapped[User | None] = relationship("User", back_populates="credentials")
class EmbeddingModel(Base):
__tablename__ = "embedding_model"
class SearchSettings(Base):
__tablename__ = "search_settings"
id: Mapped[int] = mapped_column(primary_key=True)
model_name: Mapped[str] = mapped_column(String)
model_dim: Mapped[int] = mapped_column(Integer)
normalize: Mapped[bool] = mapped_column(Boolean)
query_prefix: Mapped[str] = mapped_column(String)
passage_prefix: Mapped[str] = mapped_column(String)
query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
status: Mapped[IndexModelStatus] = mapped_column(
Enum(IndexModelStatus, native_enum=False)
)
index_name: Mapped[str] = mapped_column(String)
# New field for cloud provider relationship
provider_type: Mapped[EmbeddingProvider | None] = mapped_column(
ForeignKey("embedding_provider.provider_type"), nullable=True
)
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
multilingual_expansion: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), default=[]
)
# Reranking settings
disable_rerank_for_streaming: Mapped[bool] = mapped_column(Boolean, default=False)
rerank_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
rerank_provider_type: Mapped[RerankerProvider | None] = mapped_column(
Enum(RerankerProvider, native_enum=False), nullable=True
)
rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
"CloudEmbeddingProvider",
back_populates="embedding_models",
back_populates="search_settings",
foreign_keys=[provider_type],
)
index_attempts: Mapped[list["IndexAttempt"]] = relationship(
"IndexAttempt", back_populates="embedding_model"
"IndexAttempt", back_populates="search_settings"
)
__table_args__ = (
@ -628,8 +644,8 @@ class IndexAttempt(Base):
# only filled if status = "failed" AND an unhandled exception caused the failure
full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
# Nullable because in the past, we didn't allow swapping out embedding models live
embedding_model_id: Mapped[int] = mapped_column(
ForeignKey("embedding_model.id"),
search_settings_id: Mapped[int] = mapped_column(
ForeignKey("search_settings.id"),
nullable=False,
)
time_created: Mapped[datetime.datetime] = mapped_column(
@ -651,8 +667,8 @@ class IndexAttempt(Base):
"ConnectorCredentialPair", back_populates="index_attempts"
)
embedding_model: Mapped[EmbeddingModel] = relationship(
"EmbeddingModel", back_populates="index_attempts"
search_settings: Mapped[SearchSettings] = relationship(
"SearchSettings", back_populates="index_attempts"
)
error_rows = relationship("IndexAttemptError", back_populates="index_attempt")
@ -1070,10 +1086,10 @@ class CloudEmbeddingProvider(Base):
Enum(EmbeddingProvider), primary_key=True
)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
"EmbeddingModel",
search_settings: Mapped[list["SearchSettings"]] = relationship(
"SearchSettings",
back_populates="cloud_provider",
foreign_keys="EmbeddingModel.provider_type",
foreign_keys="SearchSettings.provider_type",
)
def __repr__(self) -> str:

View File

@ -0,0 +1,249 @@
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.server.manage.embedding.models import (
CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
from shared_configs.configs import PRESERVED_SEARCH_FIELDS
from shared_configs.enums import EmbeddingProvider
logger = setup_logger()
def create_search_settings(
search_settings: SavedSearchSettings,
db_session: Session,
status: IndexModelStatus = IndexModelStatus.FUTURE,
) -> SearchSettings:
embedding_model = SearchSettings(
model_name=search_settings.model_name,
model_dim=search_settings.model_dim,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
status=status,
index_name=search_settings.index_name,
provider_type=search_settings.provider_type,
multipass_indexing=search_settings.multipass_indexing,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
)
db_session.add(embedding_model)
db_session.commit()
return embedding_model
def get_embedding_provider_from_provider_type(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProvider | None:
query = select(CloudEmbeddingProvider).where(
CloudEmbeddingProvider.provider_type == provider_type
)
provider = db_session.execute(query).scalars().first()
return provider if provider else None
def get_current_db_embedding_provider(
db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
search_settings = get_current_search_settings(db_session=db_session)
if search_settings.provider_type is None:
return None
embedding_provider = fetch_embedding_provider(
db_session=db_session,
provider_type=search_settings.provider_type,
)
if embedding_provider is None:
raise RuntimeError("No embedding provider exists for this model.")
current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
cloud_provider_model=embedding_provider
)
return current_embedding_provider
def get_current_search_settings(db_session: Session) -> SearchSettings:
query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.PRESENT)
.order_by(SearchSettings.id.desc())
)
result = db_session.execute(query)
latest_settings = result.scalars().first()
if not latest_settings:
raise RuntimeError("No search settings specified, DB is not in a valid state")
return latest_settings
def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.FUTURE)
.order_by(SearchSettings.id.desc())
)
result = db_session.execute(query)
latest_settings = result.scalars().first()
return latest_settings
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with Session(get_sqlalchemy_engine()) as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)
if not search_settings:
return []
return search_settings.multilingual_expansion
def update_search_settings(
current_settings: SearchSettings,
updated_settings: SavedSearchSettings,
preserved_fields: list[str],
) -> None:
for field, value in updated_settings.dict().items():
if field not in preserved_fields:
setattr(current_settings, field, value)
def update_current_search_settings(
db_session: Session,
search_settings: SavedSearchSettings,
preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS,
) -> None:
current_settings = get_current_search_settings(db_session)
if not current_settings:
logger.warning("No current search settings found to update")
return
update_search_settings(current_settings, search_settings, preserved_fields)
db_session.commit()
logger.info("Current search settings updated successfully")
def update_secondary_search_settings(
db_session: Session,
search_settings: SavedSearchSettings,
preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS,
) -> None:
secondary_settings = get_secondary_search_settings(db_session)
if not secondary_settings:
logger.warning("No secondary search settings found to update")
return
preserved_fields = PRESERVED_SEARCH_FIELDS
update_search_settings(secondary_settings, search_settings, preserved_fields)
db_session.commit()
logger.info("Secondary search settings updated successfully")
def update_search_settings_status(
search_settings: SearchSettings, new_status: IndexModelStatus, db_session: Session
) -> None:
search_settings.status = new_status
db_session.commit()
def user_has_overridden_embedding_model() -> bool:
return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL
def get_old_default_search_settings() -> SearchSettings:
is_overridden = user_has_overridden_embedding_model()
return SearchSettings(
model_name=(
DOCUMENT_ENCODER_MODEL
if is_overridden
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
),
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
),
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
status=IndexModelStatus.PRESENT,
index_name="danswer_chunk",
)
def get_new_default_search_settings(is_present: bool) -> SearchSettings:
return SearchSettings(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,
status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
)
def get_old_default_embedding_model() -> IndexingSetting:
is_overridden = user_has_overridden_embedding_model()
return IndexingSetting(
model_name=(
DOCUMENT_ENCODER_MODEL
if is_overridden
else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
),
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
),
query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
)
def get_new_default_embedding_model() -> IndexingSetting:
return IndexingSetting(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
)
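
The getters above replace get_current_db_embedding_model / get_secondary_db_embedding_model. A minimal sketch of how a caller resolves both the active and the upcoming settings (engine helper as used elsewhere in this commit):

    from sqlalchemy.orm import Session

    from danswer.db.engine import get_sqlalchemy_engine
    from danswer.db.search_settings import (
        get_current_search_settings,
        get_secondary_search_settings,
    )

    with Session(get_sqlalchemy_engine()) as db_session:
        current = get_current_search_settings(db_session)  # status == PRESENT
        upcoming = get_secondary_search_settings(db_session)  # status == FUTURE, may be None
        print(current.index_name, upcoming.index_name if upcoming else None)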

View File

@ -3,14 +3,14 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.utils.logger import setup_logger
@ -24,13 +24,13 @@ def check_index_swap(db_session: Session) -> None:
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
embedding_model = get_secondary_db_embedding_model(db_session)
search_settings = get_secondary_search_settings(db_session)
if not embedding_model:
if not search_settings:
return
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
embedding_model_id=embedding_model.id, db_session=db_session
search_settings_id=search_settings.id, db_session=db_session
)
# Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
@ -40,15 +40,15 @@ def check_index_swap(db_session: Session) -> None:
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
now_old_embedding_model = get_current_db_embedding_model(db_session)
update_embedding_model_status(
embedding_model=now_old_embedding_model,
now_old_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=now_old_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_embedding_model_status(
embedding_model=embedding_model,
update_search_settings_status(
search_settings=search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)

View File

@ -3,8 +3,8 @@ import uuid
from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.models import IndexChunk
from danswer.search.models import InferenceChunk
@ -14,13 +14,13 @@ DEFAULT_INDEX_NAME = "danswer_chunk"
def get_both_index_names(db_session: Session) -> tuple[str, str | None]:
model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
model_new = get_secondary_db_embedding_model(db_session)
if not model_new:
return model.index_name, None
search_settings_new = get_secondary_search_settings(db_session)
if not search_settings_new:
return search_settings.index_name, None
return model.index_name, model_new.index_name
return search_settings.index_name, search_settings_new.index_name
def translate_boost_count_to_multiplier(boost: int) -> float:

View File

@ -3,10 +3,10 @@ from abc import abstractmethod
from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.models import EmbeddingModel as DbEmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.models import ChunkEmbedding
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import IndexChunk
@ -169,37 +169,37 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
return embedded_chunks
@classmethod
def from_db_embedding_model(
cls, embedding_model: DbEmbeddingModel
def from_db_search_settings(
cls, search_settings: SearchSettings
) -> "DefaultIndexingEmbedder":
return cls(
model_name=embedding_model.model_name,
normalize=embedding_model.normalize,
query_prefix=embedding_model.query_prefix,
passage_prefix=embedding_model.passage_prefix,
provider_type=embedding_model.provider_type,
api_key=embedding_model.api_key,
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
)
def get_embedding_model_from_db_embedding_model(
def get_embedding_model_from_search_settings(
db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
) -> IndexingEmbedder:
db_embedding_model: DbEmbeddingModel | None
search_settings: SearchSettings | None
if index_model_status == IndexModelStatus.PRESENT:
db_embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
elif index_model_status == IndexModelStatus.FUTURE:
db_embedding_model = get_secondary_db_embedding_model(db_session)
if not db_embedding_model:
search_settings = get_secondary_search_settings(db_session)
if not search_settings:
raise RuntimeError("No secondary index configured")
else:
raise RuntimeError("Not supporting embedding model rollbacks")
return DefaultIndexingEmbedder(
model_name=db_embedding_model.model_name,
normalize=db_embedding_model.normalize,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
provider_type=db_embedding_model.provider_type,
api_key=db_embedding_model.api_key,
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
)
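
Indexing callers (see _run_indexing above) now build the embedder directly from a SearchSettings row. A condensed sketch, assuming DefaultIndexingEmbedder lives at danswer.indexing.embedder (module path assumed; the classmethod signature is as shown in this diff):

    from sqlalchemy.orm import Session

    from danswer.db.engine import get_sqlalchemy_engine
    from danswer.db.search_settings import get_current_search_settings
    from danswer.indexing.embedder import DefaultIndexingEmbedder  # path assumed

    with Session(get_sqlalchemy_engine()) as db_session:
        search_settings = get_current_search_settings(db_session)
        embedder = DefaultIndexingEmbedder.from_db_search_settings(
            search_settings=search_settings
        )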

View File

@ -21,6 +21,7 @@ from danswer.db.document import upsert_documents_complete
from danswer.db.document_set import fetch_document_sets_for_documents
from danswer.db.index_attempt import create_index_attempt_error
from danswer.db.models import Document as DBDocument
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tag import create_or_add_document_tag
from danswer.db.tag import create_or_add_document_tag_list
from danswer.document_index.interfaces import DocumentIndex
@ -29,7 +30,6 @@ from danswer.indexing.chunker import Chunker
from danswer.indexing.embedder import IndexingEmbedder
from danswer.indexing.models import DocAwareChunk
from danswer.indexing.models import DocMetadataAwareIndexChunk
from danswer.search.search_settings import get_search_settings
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from shared_configs.enums import EmbeddingProvider
@ -360,7 +360,7 @@ def build_indexing_pipeline(
attempt_id: int | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_search_settings()
search_settings = get_current_search_settings(db_session)
multipass = (
search_settings.multipass_indexing
if search_settings

View File

@ -9,7 +9,7 @@ from shared_configs.enums import EmbeddingProvider
from shared_configs.model_server_models import Embedding
if TYPE_CHECKING:
from danswer.db.models import EmbeddingModel
from danswer.db.models import SearchSettings
logger = setup_logger()
@ -96,26 +96,42 @@ class DocMetadataAwareIndexChunk(IndexChunk):
class EmbeddingModelDetail(BaseModel):
model_name: str
model_dim: int
normalize: bool
query_prefix: str | None
passage_prefix: str | None
provider_type: EmbeddingProvider | None = None
api_key: str | None = None
@classmethod
def from_model(
def from_db_model(
cls,
embedding_model: "EmbeddingModel",
search_settings: "SearchSettings",
) -> "EmbeddingModelDetail":
return cls(
model_name=embedding_model.model_name,
model_dim=embedding_model.model_dim,
normalize=embedding_model.normalize,
query_prefix=embedding_model.query_prefix,
passage_prefix=embedding_model.passage_prefix,
provider_type=embedding_model.provider_type,
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
api_key=search_settings.api_key,
)
class EmbeddingModelCreateRequest(EmbeddingModelDetail):
index_name: str
# Additional info needed for indexing time
class IndexingSetting(EmbeddingModelDetail):
model_dim: int
index_name: str | None
multipass_indexing: bool
@classmethod
def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
return cls(
model_name=search_settings.model_name,
model_dim=search_settings.model_dim,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
)

View File

@ -5,6 +5,7 @@ from danswer.chat.models import LlmDoc
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
@ -28,19 +29,17 @@ from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
from danswer.search.models import InferenceChunk
from danswer.search.search_settings import get_multilingual_expansion
def get_prompt_tokens(prompt_config: PromptConfig) -> int:
# Note: currently custom prompts do not allow datetime aware, only default prompts
multilingual_expansion = get_multilingual_expansion()
return (
check_number_of_tokens(prompt_config.system_prompt)
+ check_number_of_tokens(prompt_config.task_prompt)
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
+ CITATION_STATEMENT_TOKEN_CNT
+ CITATION_REMINDER_TOKEN_CNT
+ (LANGUAGE_HINT_TOKEN_CNT if multilingual_expansion else 0)
+ (LANGUAGE_HINT_TOKEN_CNT if get_multilingual_expansion() else 0)
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
)

View File

@ -3,6 +3,7 @@ from langchain.schema.messages import HumanMessage
from danswer.chat.models import LlmDoc
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
@ -11,7 +12,6 @@ from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import build_complete_context_str
from danswer.search.models import InferenceChunk
from danswer.search.search_settings import get_search_settings
def _build_weak_llm_quotes_prompt(
@ -48,10 +48,7 @@ def _build_strong_llm_quotes_prompt(
history_str: str,
prompt: PromptConfig,
) -> HumanMessage:
search_settings = get_search_settings()
use_language_hint = (
bool(search_settings.multilingual_expansion) if search_settings else False
)
use_language_hint = bool(get_multilingual_expansion())
context_block = ""
if context_docs:

View File

@ -27,16 +27,14 @@ from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import USER_AUTH_SECRET
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.constants import AuthType
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
from danswer.db.connector import check_connectors_exist
from danswer.db.connector import create_initial_default_connector
@ -45,28 +43,29 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.document import check_docs_exist
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
from danswer.db.engine import warm_up_connections
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.models import EmbeddingModel
from danswer.db.persona import delete_old_default_personas
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_secondary_search_settings
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.swap_index import check_index_swap
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import IndexingSetting
from danswer.llm.llm_initialization import load_llm_providers
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.search.models import SavedSearchSettings
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.search.search_settings import get_search_settings
from danswer.search.search_settings import update_search_settings
from danswer.server.auth_check import check_router_auth
from danswer.server.danswer_api.ingestion import router as danswer_api_router
from danswer.server.documents.cc_pair import router as cc_pair_router
@ -116,13 +115,8 @@ from danswer.utils.telemetry import RecordType
from danswer.utils.variable_functionality import fetch_versioned_implementation
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import DEFAULT_CROSS_ENCODER_API_KEY
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
from shared_configs.configs import DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
from shared_configs.configs import DISABLE_RERANK_FOR_STREAMING
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import RerankerProvider
logger = setup_logger()
@ -198,6 +192,49 @@ def setup_postgres(db_session: Session) -> None:
auto_add_search_tool_to_personas(db_session)
def translate_saved_search_settings(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
try:
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
if isinstance(search_settings_dict, dict):
# Update current search settings
current_settings = get_current_search_settings(db_session)
# Update non-preserved fields
if current_settings:
current_settings_dict = SavedSearchSettings.from_db_model(
current_settings
).dict()
new_current_settings = SavedSearchSettings(
**{**current_settings_dict, **search_settings_dict}
)
update_current_search_settings(db_session, new_current_settings)
# Update secondary search settings
secondary_settings = get_secondary_search_settings(db_session)
if secondary_settings:
secondary_settings_dict = SavedSearchSettings.from_db_model(
secondary_settings
).dict()
new_secondary_settings = SavedSearchSettings(
**{**secondary_settings_dict, **search_settings_dict}
)
update_secondary_search_settings(
db_session,
new_secondary_settings,
)
# Delete the KV store entry after successful update
kv_store.delete(KV_SEARCH_SETTINGS)
logger.notice("Search settings updated and KV store entry deleted.")
else:
logger.notice("KV store search settings is empty.")
except ConfigNotFoundError:
logger.notice("No search config found in KV store.")
def mark_reindex_flag(db_session: Session) -> None:
kv_store = get_dynamic_config_store()
try:
@ -221,17 +258,17 @@ def mark_reindex_flag(db_session: Session) -> None:
def setup_vespa(
document_index: DocumentIndex,
db_embedding_model: EmbeddingModel,
secondary_db_embedding_model: EmbeddingModel | None,
index_setting: IndexingSetting,
secondary_index_setting: IndexingSetting | None,
) -> None:
# Vespa startup is a bit slow, so give it a few seconds
wait_time = 5
for _ in range(5):
try:
document_index.ensure_indices_exist(
index_embedding_dim=db_embedding_model.model_dim,
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
if secondary_db_embedding_model
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim
if secondary_index_setting
else None,
)
break
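
The hunk cuts off before the except branch, but the shape is a standard wait-and-retry readiness probe: call ensure_indices_exist up to five times, sleeping wait_time seconds between failures. A generic version under that assumption (the caught exception type is not shown in the diff):

import time
from collections.abc import Callable

def wait_until_ready(check: Callable[[], None], attempts: int = 5, wait_time: int = 5) -> None:
    for attempt in range(attempts):
        try:
            check()  # e.g. document_index.ensure_indices_exist(...)
            return
        except Exception:  # the real code presumably catches a narrower error
            if attempt < attempts - 1:
                time.sleep(wait_time)
    raise RuntimeError(f"service not ready after {attempts} attempts")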
@ -262,13 +299,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
embedding_model_id=db_embedding_model.id, db_session=db_session
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
@ -277,16 +314,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.notice(f'Using Embedding model: "{db_embedding_model.model_name}"')
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
)
logger.notice(
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
search_settings = get_search_settings()
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
@ -295,29 +329,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
else:
if DEFAULT_CROSS_ENCODER_MODEL_NAME:
logger.notice("Reranking is enabled.")
if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
raise ValueError("No reranking model specified.")
search_settings = SavedSearchSettings(
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
else None,
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
num_rerank=NUM_POSTPROCESSED_RESULTS,
multilingual_expansion=[
s.strip()
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
if s.strip()
]
if MULTILINGUAL_QUERY_EXPANSION
else [],
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
)
update_search_settings(search_settings)
if search_settings.rerank_model_name and not search_settings.rerank_provider_type:
warm_up_cross_encoder(search_settings.rerank_model_name)
@ -328,6 +339,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# setup Postgres with default credential, llm providers, etc.
setup_postgres(db_session)
translate_saved_search_settings(db_session)
# Mark in the KV store whether the user needs to trigger a reindex to bring
# the document index into a good state
mark_reindex_flag(db_session)
@ -335,19 +348,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
setup_vespa(
document_index,
IndexingSetting.from_db_model(search_settings),
IndexingSetting.from_db_model(secondary_search_settings)
if secondary_search_settings
else None,
)
setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
if db_embedding_model.provider_type is None:
if search_settings.provider_type is None:
warm_up_bi_encoder(
embedding_model=db_embedding_model,
model_server_host=MODEL_SERVER_HOST,
model_server_port=MODEL_SERVER_PORT,
embedding_model=EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
),
)
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})

View File

@ -15,7 +15,7 @@ from danswer.configs.model_configs import (
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
)
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.db.models import EmbeddingModel as DBEmbeddingModel
from danswer.db.models import SearchSettings
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
from danswer.utils.logger import setup_logger
@ -209,6 +209,26 @@ class EmbeddingModel:
max_seq_length=max_seq_length,
)
@classmethod
def from_db_model(
cls,
search_settings: SearchSettings,
server_host: str, # Changes depending on indexing or inference
server_port: int,
retrim_content: bool = False,
) -> "EmbeddingModel":
return cls(
server_host=server_host,
server_port=server_port,
model_name=search_settings.model_name,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
api_key=search_settings.api_key,
provider_type=search_settings.provider_type,
retrim_content=retrim_content,
)
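
from_db_model here is an alternate constructor: the client class stays ignorant of SQLAlchemy, while call sites get a one-line conversion and only supply the transport details that vary (indexing vs. inference server). The same pattern in miniature, with stand-in names:

from dataclasses import dataclass

@dataclass
class Row:  # stand-in for the SearchSettings ORM row
    model_name: str
    normalize: bool

class Encoder:  # stand-in for the EmbeddingModel client
    def __init__(self, model_name: str, normalize: bool, server_host: str, server_port: int) -> None:
        self.model_name = model_name
        self.normalize = normalize
        self.endpoint = f"http://{server_host}:{server_port}"

    @classmethod
    def from_db_model(cls, row: Row, server_host: str, server_port: int) -> "Encoder":
        # Model config comes from the row; only the server differs per call site
        return cls(row.model_name, row.normalize, server_host, server_port)

encoder = Encoder.from_db_model(Row("intfloat/e5-base-v2", True), "127.0.0.1", 9000)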
class RerankingModel:
def __init__(
@ -302,47 +322,35 @@ def warm_up_retry(
def warm_up_bi_encoder(
embedding_model: DBEmbeddingModel,
model_server_host: str = MODEL_SERVER_HOST,
model_server_port: int = MODEL_SERVER_PORT,
embedding_model: EmbeddingModel,
non_blocking: bool = False,
) -> None:
model_name = embedding_model.model_name
normalize = embedding_model.normalize
provider_type = embedding_model.provider_type
warm_up_str = " ".join(WARM_UP_STRINGS)
logger.debug(f"Warming up encoder model: {model_name}")
get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
warm_up_str
)
embed_model = EmbeddingModel(
model_name=model_name,
normalize=normalize,
provider_type=provider_type,
# Not a big deal if prefix is incorrect
query_prefix=None,
passage_prefix=None,
server_host=model_server_host,
server_port=model_server_port,
api_key=None,
)
logger.debug(f"Warming up encoder model: {embedding_model.model_name}")
get_tokenizer(
model_name=embedding_model.model_name,
provider_type=embedding_model.provider_type,
).encode(warm_up_str)
def _warm_up() -> None:
try:
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
logger.debug(f"Warm-up complete for encoder model: {model_name}")
embedding_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
logger.debug(
f"Warm-up complete for encoder model: {embedding_model.model_name}"
)
except Exception as e:
logger.warning(
f"Warm-up request failed for encoder model {model_name}: {e}"
f"Warm-up request failed for encoder model {embedding_model.model_name}: {e}"
)
if non_blocking:
threading.Thread(target=_warm_up, daemon=True).start()
logger.debug(f"Started non-blocking warm-up for encoder model: {model_name}")
logger.debug(
f"Started non-blocking warm-up for encoder model: {embedding_model.model_name}"
)
else:
retry_encode = warm_up_retry(embed_model.encode)
retry_encode = warm_up_retry(embedding_model.encode)
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
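
warm_up_retry is used above but not defined in this diff. A plausible sketch of such a wrapper, a fixed-count retry with a short sleep; the real implementation and its parameters are an assumption:

import time
from collections.abc import Callable
from typing import Any

def warm_up_retry(func: Callable[..., Any], tries: int = 3, delay: float = 2.0) -> Callable[..., Any]:
    # Retries an encode call so transient model-server startup errors do not abort warm-up
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        for attempt in range(tries):
            try:
                return func(*args, **kwargs)
            except Exception:
                if attempt == tries - 1:
                    raise
                time.sleep(delay)
    return wrapper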

View File

@ -7,7 +7,9 @@ from pydantic import validator
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.configs.constants import DocumentSource
from danswer.db.models import Persona
from danswer.db.models import SearchSettings
from danswer.indexing.models import BaseChunk
from danswer.indexing.models import IndexingSetting
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import OptionalSearchSetting
from danswer.search.enums import SearchType
@ -22,28 +24,61 @@ MAX_METRICS_CONTENT = (
class RerankingDetails(BaseModel):
# If model is None (or num_rerank is 0), then reranking is turned off
rerank_model_name: str | None
provider_type: RerankerProvider | None
api_key: str | None
rerank_provider_type: RerankerProvider | None
rerank_api_key: str | None
num_rerank: int
class SavedSearchSettings(RerankingDetails):
# Empty for no additional expansion
multilingual_expansion: list[str]
# Encompasses both mini and large chunks
multipass_indexing: bool
# For faster flows where the results should start immediately
# this more time intensive step can be skipped
disable_rerank_for_streaming: bool
disable_rerank_for_streaming: bool = False
def to_reranking_detail(self) -> RerankingDetails:
return RerankingDetails(
rerank_model_name=self.rerank_model_name,
provider_type=self.provider_type,
api_key=self.api_key,
num_rerank=self.num_rerank,
)
@classmethod
def from_db_model(cls, search_settings: SearchSettings) -> "RerankingDetails":
return cls(
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
)
class InferenceSettings(RerankingDetails):
# Empty for no additional expansion
multilingual_expansion: list[str]
class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
@classmethod
def from_db_model(
cls, search_settings: SearchSettings
) -> "SearchSettingsCreationRequest":
inference_settings = InferenceSettings.from_db_model(search_settings)
indexing_setting = IndexingSetting.from_db_model(search_settings)
return cls(**inference_settings.dict(), **indexing_setting.dict())
class SavedSearchSettings(InferenceSettings, IndexingSetting):
@classmethod
def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings":
return cls(
# Indexing Setting
model_name=search_settings.model_name,
model_dim=search_settings.model_dim,
normalize=search_settings.normalize,
query_prefix=search_settings.query_prefix,
passage_prefix=search_settings.passage_prefix,
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
# Multilingual Expansion
multilingual_expansion=search_settings.multilingual_expansion,
)
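
SearchSettingsCreationRequest and SavedSearchSettings compose InferenceSettings with IndexingSetting through multiple inheritance, and cls(**a.dict(), **b.dict()) reassembles an instance from the two partial views. The mechanics in a toy example (pydantic v1 .dict(), matching the diff):

from pydantic import BaseModel

class A(BaseModel):
    x: int

class B(BaseModel):
    y: str

class AB(A, B):  # field set is the union of A's and B's
    pass

ab = AB(**A(x=1).dict(), **B(y="hi").dict())
assert ab.x == 1 and ab.y == "hi"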

View File

@ -7,8 +7,8 @@ from sqlalchemy.orm import Session
from danswer.chat.models import SectionRelevancePiece
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.llm.answering.models import PromptConfig
@ -65,9 +65,9 @@ class SearchPipeline:
self.retrieval_metrics_callback = retrieval_metrics_callback
self.rerank_metrics_callback = rerank_metrics_callback
self.embedding_model = get_current_db_embedding_model(db_session)
self.search_settings = get_current_search_settings(db_session)
self.document_index = get_default_document_index(
primary_index_name=self.embedding_model.index_name,
primary_index_name=self.search_settings.index_name,
secondary_index_name=None,
)
self.prompt_config: PromptConfig | None = prompt_config

View File

@ -98,8 +98,8 @@ def semantic_reranking(
cross_encoder = RerankingModel(
model_name=rerank_settings.rerank_model_name,
provider_type=rerank_settings.provider_type,
api_key=rerank_settings.api_key,
provider_type=rerank_settings.rerank_provider_type,
api_key=rerank_settings.rerank_api_key,
)
passages = [

View File

@ -10,18 +10,19 @@ from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.chat_configs import NUM_RETURNED_HITS
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.llm.interfaces import LLM
from danswer.natural_language_processing.search_nlp_models import QueryAnalysisModel
from danswer.search.enums import LLMEvaluationType
from danswer.search.enums import RecencyBiasSetting
from danswer.search.enums import SearchType
from danswer.search.models import BaseFilters
from danswer.search.models import IndexFilters
from danswer.search.models import RerankingDetails
from danswer.search.models import SearchQuery
from danswer.search.models import SearchRequest
from danswer.search.models import SearchType
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
from danswer.search.search_settings import get_search_settings
from danswer.secondary_llm_flows.source_filter import extract_source_filter
from danswer.secondary_llm_flows.time_filter import extract_time_filter
from danswer.utils.logger import setup_logger
@ -179,12 +180,11 @@ def retrieval_preprocessing(
rerank_settings = search_request.rerank_settings
# If not explicitly specified by the query, use the current settings
if rerank_settings is None:
saved_search_settings = get_search_settings()
if not saved_search_settings:
rerank_settings = None
search_settings = get_current_search_settings(db_session)
# For non-streaming flows, the rerank settings are applied at the search_request level
elif not saved_search_settings.disable_rerank_for_streaming:
rerank_settings = saved_search_settings.to_reranking_detail()
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
# Decays at 1 / (1 + (multiplier * num years))
if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:

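The decay comment corresponds to a score multiplier of 1 / (1 + multiplier * age_in_years): with a multiplier of 0.5, a three-year-old document keeps 1 / 2.5 = 0.4 of its score. Transcribed directly:

def recency_multiplier(decay_multiplier: float, age_years: float) -> float:
    # Decays at 1 / (1 + (multiplier * num years)), per the comment above
    return 1 / (1 + decay_multiplier * age_years)

assert abs(recency_multiplier(0.5, 3.0) - 0.4) < 1e-9
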
View File

@ -7,7 +7,8 @@ from nltk.stem import WordNetLemmatizer # type:ignore
from nltk.tokenize import word_tokenize # type:ignore
from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_multilingual_expansion
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.document_index.vespa.shared_utils.utils import (
@ -23,7 +24,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
from danswer.search.models import RetrievalMetricsContainer
from danswer.search.models import SearchQuery
from danswer.search.postprocessing.postprocessing import cleanup_chunks
from danswer.search.search_settings import get_multilingual_expansion
from danswer.search.utils import inference_section_from_chunks
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
from danswer.utils.logger import setup_logger
@ -121,15 +121,10 @@ def doc_index_retrieval(
from the large chunks to the referenced chunks,
dedupes the chunks, and cleans the chunks.
"""
db_embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel(
model_name=db_embedding_model.model_name,
query_prefix=db_embedding_model.query_prefix,
passage_prefix=db_embedding_model.passage_prefix,
normalize=db_embedding_model.normalize,
api_key=db_embedding_model.api_key,
provider_type=db_embedding_model.provider_type,
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set; this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
@ -230,7 +225,7 @@ def retrieve_chunks(
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
multilingual_expansion = get_multilingual_expansion()
multilingual_expansion = get_multilingual_expansion(db_session)
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
top_chunks = doc_index_retrieval(

View File

@ -9,12 +9,7 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def get_multilingual_expansion() -> list[str]:
search_settings = get_search_settings()
return search_settings.multilingual_expansion if search_settings else []
def get_search_settings() -> SavedSearchSettings | None:
def get_kv_search_settings() -> SavedSearchSettings | None:
"""Get all user configured search settings which affect the search pipeline
Note: KV store is used in this case since there is no need to rollback the value or any need to audit past values
@ -33,8 +28,3 @@ def get_search_settings() -> SavedSearchSettings | None:
# or the user can set it via the API/UI
kv_store.delete(KV_SEARCH_SETTINGS)
return None
def update_search_settings(settings: SavedSearchSettings) -> None:
kv_store = get_dynamic_config_store()
kv_store.store(KV_SEARCH_SETTINGS, settings.dict())
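
The removed helpers show the KV-backed variant this commit retires: settings were read from the key-value store, and a malformed entry was deleted so that defaults or an API/UI update could repopulate it. That load-or-reset pattern in a simplified form (a plain dict stands in for the real store interface):

def load_or_reset(store: dict, key: str) -> dict | None:
    value = store.get(key)
    if value is None:
        return None
    if not isinstance(value, dict):
        del store[key]  # drop the bad entry; the next write starts clean
        return None
    return value

store = {"search_settings": ["not", "a", "dict"]}
assert load_or_reset(store, "search_settings") is None
assert "search_settings" not in store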

View File

@ -2,11 +2,11 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.chat_configs import LANGUAGE_CHAT_NAMING_HINT
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import CHAT_NAMING
from danswer.search.search_settings import get_multilingual_expansion
from danswer.utils.logger import setup_logger
logger = setup_logger()

View File

@ -9,10 +9,10 @@ from danswer.connectors.models import IndexAttemptMetadata
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
from danswer.db.document import get_documents_by_cc_pair
from danswer.db.document import get_ingestion_documents
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.embedder import DefaultIndexingEmbedder
@ -90,10 +90,10 @@ def upsert_ingestion_doc(
primary_index_name=curr_ind_name, secondary_index_name=None
)
db_embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
index_embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
db_embedding_model
index_embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)
indexing_pipeline = build_indexing_pipeline(
@ -117,16 +117,16 @@ def upsert_ingestion_doc(
primary_index_name=curr_ind_name, secondary_index_name=None
)
sec_db_embedding_model = get_secondary_db_embedding_model(db_session)
sec_search_settings = get_secondary_search_settings(db_session)
if sec_db_embedding_model is None:
if sec_search_settings is None:
# Should not ever happen
raise RuntimeError(
"Secondary index exists but no embedding model configured"
"Secondary index exists but no search settings configured"
)
new_index_embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
sec_db_embedding_model
new_index_embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=sec_search_settings
)
sec_ind_pipeline = build_indexing_pipeline(

View File

@ -63,7 +63,6 @@ from danswer.db.credentials import delete_google_drive_service_account_credentia
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.document import get_document_cnts_for_cc_pairs
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
@ -71,6 +70,7 @@ from danswer.db.index_attempt import get_latest_finished_index_attempt_for_cc_pa
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.models import User
from danswer.db.models import UserRole
from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.server.documents.models import AuthStatus
@ -705,7 +705,7 @@ def connector_run_once(
)
]
embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
connector_credential_pairs = [
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
@ -716,7 +716,7 @@ def connector_run_once(
index_attempt_ids = [
create_index_attempt(
connector_credential_pair_id=connector_credential_pair.id,
embedding_model_id=embedding_model.id,
search_settings_id=search_settings.id,
from_beginning=run_info.from_beginning,
db_session=db_session,
)

View File

@ -5,9 +5,9 @@ from fastapi import Query
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.natural_language_processing.utils import get_tokenizer
@ -29,10 +29,10 @@ def get_document_info(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> DocumentInfo:
embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
primary_index_name=search_settings.index_name, secondary_index_name=None
)
user_acl_filters = build_access_filters_for_user(user, db_session)
@ -51,8 +51,8 @@ def get_document_info(
# get actual document context used for LLM
first_chunk = inference_chunks[0]
tokenizer_encode = get_tokenizer(
provider_type=embedding_model.provider_type,
model_name=embedding_model.model_name,
provider_type=search_settings.provider_type,
model_name=search_settings.model_name,
).encode
full_context_str = build_doc_context_str(
semantic_identifier=first_chunk.semantic_identifier,
@ -76,10 +76,10 @@ def get_chunk_info(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> ChunkInfo:
embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
primary_index_name=search_settings.index_name, secondary_index_name=None
)
user_acl_filters = build_access_filters_for_user(user, db_session)
@ -100,8 +100,8 @@ def get_chunk_info(
chunk_content = inference_chunks[0].content
tokenizer_encode = get_tokenizer(
provider_type=embedding_model.provider_type,
model_name=embedding_model.model_name,
provider_type=search_settings.provider_type,
model_name=search_settings.model_name,
).encode
return ChunkInfo(

View File

@ -4,12 +4,12 @@ from fastapi import HTTPException
from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.db.embedding_model import get_current_db_embedding_provider
from danswer.db.engine import get_session
from danswer.db.llm import fetch_existing_embedding_providers
from danswer.db.llm import remove_embedding_provider
from danswer.db.llm import upsert_cloud_embedding_provider
from danswer.db.models import User
from danswer.db.search_settings import get_current_db_embedding_provider
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest

View File

@ -17,7 +17,7 @@ from danswer.db.models import SlackBotResponseType
from danswer.db.models import StandardAnswer as StandardAnswerModel
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
from danswer.db.models import User
from danswer.indexing.models import EmbeddingModelDetail
from danswer.search.models import SavedSearchSettings
from danswer.server.features.persona.models import PersonaSnapshot
from danswer.server.models import FullUserSnapshot
from danswer.server.models import InvitedUserSnapshot
@ -246,8 +246,8 @@ class SlackBotConfig(BaseModel):
class FullModelVersionResponse(BaseModel):
current_model: EmbeddingModelDetail
secondary_model: EmbeddingModelDetail | None
current_settings: SavedSearchSettings
secondary_settings: SavedSearchSettings | None
class AllUsersResponse(BaseModel):

View File

@ -9,111 +9,108 @@ from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.embedding_model import create_embedding_model
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_embedding_provider_from_provider_type
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.engine import get_session
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.models import IndexModelStatus
from danswer.db.models import User
from danswer.db.search_settings import create_search_settings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_embedding_provider_from_provider_type
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.search_settings import update_current_search_settings
from danswer.db.search_settings import update_search_settings_status
from danswer.document_index.factory import get_default_document_index
from danswer.indexing.models import EmbeddingModelCreateRequest
from danswer.indexing.models import EmbeddingModelDetail
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.search.search_settings import get_search_settings
from danswer.search.search_settings import update_search_settings
from danswer.search.models import SearchSettingsCreationRequest
from danswer.server.manage.models import FullModelVersionResponse
from danswer.server.models import IdReturn
from danswer.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
router = APIRouter(prefix="/search-settings")
logger = setup_logger()
@router.post("/set-new-embedding-model")
def set_new_embedding_model(
embed_model_details: EmbeddingModelDetail,
@router.post("/set-new-search-settings")
def set_new_search_settings(
search_settings_new: SearchSettingsCreationRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> IdReturn:
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
Gives an error if the same model name is used as the current or secondary index
"""
if search_settings_new.index_name:
logger.warning("Index name was specified by request, which is not recommended")
# Validate cloud provider exists
if embed_model_details.provider_type is not None:
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
db_session, provider_type=embed_model_details.provider_type
db_session, provider_type=search_settings_new.provider_type
)
if cloud_provider is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"No embedding provider exists for cloud embedding type {embed_model_details.provider_type}",
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
)
current_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
# We define index name here
index_name = f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}"
if (
embed_model_details.model_name == current_model.model_name
and not current_model.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
if search_settings_new.index_name is None:
# We define index name here
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
if (
search_settings_new.model_name == search_settings.model_name
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
):
index_name += ALT_INDEX_SUFFIX
search_values = search_settings_new.dict()
search_values["index_name"] = index_name
new_search_settings_request = SavedSearchSettings(**search_values)
else:
new_search_settings_request = SavedSearchSettings(**search_settings_new.dict())
create_embed_model_details = EmbeddingModelCreateRequest(
**embed_model_details.dict(), index_name=index_name
)
secondary_model = get_secondary_db_embedding_model(db_session)
if secondary_model:
if embed_model_details.model_name == secondary_model.model_name:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Already reindexing with {secondary_model.model_name}",
)
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_search_settings:
# Cancel any background indexing jobs
expire_index_attempts(
embedding_model_id=secondary_model.id, db_session=db_session
search_settings_id=secondary_search_settings.id, db_session=db_session
)
# Mark previous model as a past model directly
update_embedding_model_status(
embedding_model=secondary_model,
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
new_model = create_embedding_model(
create_embed_model_details=create_embed_model_details, db_session=db_session
new_search_settings = create_search_settings(
search_settings=new_search_settings_request, db_session=db_session
)
# Ensure Vespa has the new index immediately
document_index = get_default_document_index(
primary_index_name=current_model.index_name,
secondary_index_name=new_model.index_name,
primary_index_name=search_settings.index_name,
secondary_index_name=new_search_settings.index_name,
)
document_index.ensure_indices_exist(
index_embedding_dim=current_model.model_dim,
secondary_index_embedding_dim=new_model.model_dim,
index_embedding_dim=search_settings.model_dim,
secondary_index_embedding_dim=new_search_settings.model_dim,
)
# Pause index attempts for the currently in use index to preserve resources
if DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
embedding_model_id=current_model.id, db_session=db_session
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
return IdReturn(id=new_model.id)
return IdReturn(id=new_search_settings.id)
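
The endpoint derives the Vespa index name from the model name, appending ALT_INDEX_SUFFIX when the chosen model is the one already live so the new index never collides with the old. That naming rule distilled into a standalone function; clean_model_name's body and the suffix value are assumptions based on the identifiers in the diff:

ALT_INDEX_SUFFIX = "__danswer_alt_index"  # assumed value; the real constant lives in shared_configs

def clean_model_name(model_name: str) -> str:
    # Assumed normalization: make the model name safe inside an index name
    return model_name.replace("/", "_").replace("-", "_").replace(".", "_")

def derive_index_name(new_model: str, current_model: str, current_index: str) -> str:
    index_name = f"danswer_chunk_{clean_model_name(new_model)}"
    if new_model == current_model and not current_index.endswith(ALT_INDEX_SUFFIX):
        index_name += ALT_INDEX_SUFFIX
    return index_name

# Re-selecting the live model flips to the alternate index name
assert derive_index_name("e5", "e5", "danswer_chunk_e5").endswith(ALT_INDEX_SUFFIX)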
@router.post("/cancel-new-embedding")
@ -121,66 +118,63 @@ def cancel_new_embedding(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
secondary_model = get_secondary_db_embedding_model(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
if secondary_model:
if secondary_search_settings:
expire_index_attempts(
embedding_model_id=secondary_model.id, db_session=db_session
search_settings_id=secondary_search_settings.id, db_session=db_session
)
update_embedding_model_status(
embedding_model=secondary_model,
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
@router.get("/get-current-embedding-model")
def get_current_embedding_model(
@router.get("/get-current-search-settings")
def get_curr_search_settings(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> EmbeddingModelDetail:
current_model = get_current_db_embedding_model(db_session)
return EmbeddingModelDetail.from_model(current_model)
) -> SavedSearchSettings:
current_search_settings = get_current_search_settings(db_session)
return SavedSearchSettings.from_db_model(current_search_settings)
@router.get("/get-secondary-embedding-model")
def get_secondary_embedding_model(
@router.get("/get-secondary-search-settings")
def get_sec_search_settings(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> EmbeddingModelDetail | None:
next_model = get_secondary_db_embedding_model(db_session)
if not next_model:
) -> SavedSearchSettings | None:
secondary_search_settings = get_secondary_search_settings(db_session)
if not secondary_search_settings:
return None
return EmbeddingModelDetail.from_model(next_model)
return SavedSearchSettings.from_db_model(secondary_search_settings)
@router.get("/get-embedding-models")
def get_embedding_models(
@router.get("/get-all-search-settings")
def get_all_search_settings(
_: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> FullModelVersionResponse:
current_model = get_current_db_embedding_model(db_session)
next_model = get_secondary_db_embedding_model(db_session)
current_search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
return FullModelVersionResponse(
current_model=EmbeddingModelDetail.from_model(current_model),
secondary_model=EmbeddingModelDetail.from_model(next_model)
if next_model
current_settings=SavedSearchSettings.from_db_model(current_search_settings),
secondary_settings=SavedSearchSettings.from_db_model(secondary_search_settings)
if secondary_search_settings
else None,
)
@router.get("/get-search-settings")
def get_saved_search_settings(
_: User | None = Depends(current_admin_user),
) -> SavedSearchSettings | None:
return get_search_settings()
@router.post("/update-search-settings")
# Updates current non-reindex search settings
@router.post("/update-inference-settings")
def update_saved_search_settings(
search_settings: SavedSearchSettings,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_search_settings(search_settings)
update_current_search_settings(
search_settings=search_settings, db_session=db_session
)

View File

@ -15,9 +15,9 @@ from danswer.db.chat import get_first_messages_for_chat_sessions
from danswer.db.chat import get_search_docs_for_chat_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tag import get_tags_by_value_prefix_for_source_types
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
@ -63,9 +63,9 @@ def admin_search(
tags=question.filters.tags,
access_control_list=user_acl_filters,
)
embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
primary_index_name=search_settings.index_name, secondary_index_name=None
)
if not isinstance(document_index, VespaIndex):
raise HTTPException(

View File

@ -2,8 +2,8 @@ from sqlalchemy.orm import Session
from danswer.access.access import get_access_for_documents
from danswer.db.document import prepare_to_modify_documents
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
@ -47,13 +47,13 @@ def _sync_user_group_batch(
def sync_user_groups(user_group_id: int, db_session: Session) -> None:
"""Sync the status of Postgres for the specified user group"""
db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)

View File

@ -54,3 +54,17 @@ LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "danswer"
DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
# notset, debug, info, notice, warning, error, or critical
LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
# Fields which should only be set on a new search settings row
PRESERVED_SEARCH_FIELDS = [
"provider_type",
"api_key",
"model_name",
"index_name",
"multipass_indexing",
"model_dim",
"normalize",
"passage_prefix",
"query_prefix",
]
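
PRESERVED_SEARCH_FIELDS names the embedding-side fields that are fixed once a search-settings row exists (changing any of them implies building a new index). One way an update payload could be scrubbed against the list before being applied; the helper itself is hypothetical, not part of this diff:

PRESERVED = [
    "provider_type", "api_key", "model_name", "index_name", "multipass_indexing",
    "model_dim", "normalize", "passage_prefix", "query_prefix",
]

def strip_preserved_fields(update: dict) -> dict:
    # Hypothetical helper: drop keys whose values are fixed at creation time
    return {k: v for k, v in update.items() if k not in PRESERVED}

assert strip_preserved_fields({"model_name": "x", "num_rerank": 30}) == {"num_rerank": 30}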

View File

@ -10,13 +10,14 @@ from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import SYNC_DB_API
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import IndexingSetting
from danswer.main import setup_postgres
from danswer.main import setup_vespa
from tests.integration.common_utils.llm import seed_default_openai_provider
@ -127,13 +128,13 @@ def reset_vespa() -> None:
# swap to the correct default model
check_index_swap(db_session)
current_model = get_current_db_embedding_model(db_session)
index_name = current_model.index_name
search_settings = get_current_search_settings(db_session)
index_name = search_settings.index_name
setup_vespa(
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
db_embedding_model=current_model,
secondary_db_embedding_model=None,
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
for _ in range(5):

View File

@ -3,8 +3,8 @@ from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.vespa import TestVespaClient
@ -17,8 +17,8 @@ def db_session() -> Generator[Session, None, None]:
@pytest.fixture
def vespa_client(db_session: Session) -> TestVespaClient:
current_model = get_current_db_embedding_model(db_session)
return TestVespaClient(index_name=current_model.index_name)
search_settings = get_current_search_settings(db_session)
return TestVespaClient(index_name=search_settings.index_name)
@pytest.fixture

View File

@ -40,7 +40,7 @@ export default function UpgradingPage({
method: "POST",
});
if (response.ok) {
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
} else {
alert(
`Failed to cancel embedding model update - ${await response.text()}`

View File

@ -36,14 +36,14 @@ function Main() {
isLoading: isLoadingCurrentModel,
error: currentEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-current-embedding-model",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
"/api/search-settings/get-search-settings",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@ -53,7 +53,7 @@ function Main() {
isLoading: isLoadingFutureModel,
error: futureEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-secondary-embedding-model",
"/api/search-settings/get-secondary-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);

View File

@ -93,10 +93,10 @@ export function EmbeddingModelSelection({
const onConfirmSelection = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/search-settings/set-new-embedding-model",
"/api/search-settings/set-new-search-settings",
{
method: "POST",
body: JSON.stringify(model),
body: JSON.stringify({ ...model, index_name: null }),
headers: {
"Content-Type": "application/json",
},
@ -104,7 +104,7 @@ export function EmbeddingModelSelection({
);
if (response.ok) {
setShowTentativeModel(null);
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
if (!connectors || !connectors.length) {
setShowAddConnectorPopup(true);
}

View File

@ -114,32 +114,36 @@ const RerankingDetailsForm = forwardRef<
)
).map((card) => {
const isSelected =
values.provider_type === card.provider &&
values.rerank_provider_type === card.rerank_provider_type &&
values.rerank_model_name === card.modelName;
return (
<div
key={`${card.provider}-${card.modelName}`}
key={`${card.rerank_provider_type}-${card.modelName}`}
className={`p-4 border rounded-lg cursor-pointer transition-all duration-200 ${
isSelected
? "border-blue-500 bg-blue-50 shadow-md"
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
}`}
onClick={() => {
if (card.provider) {
if (card.rerank_provider_type) {
setIsApiKeyModalOpen(true);
}
setRerankingDetails({
...values,
provider_type: card.provider!,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName,
});
setFieldValue("provider_type", card.provider);
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}}
>
<div className="flex items-center justify-between mb-3">
<div className="flex items-center">
{card.provider === RerankerProvider.COHERE ? (
{card.rerank_provider_type ===
RerankerProvider.COHERE ? (
<CohereIcon size={24} className="mr-2" />
) : (
<MixedBreadIcon size={24} className="mr-2" />

View File

@ -1,6 +1,9 @@
import { EmbeddingProvider } from "@/components/embedding/interfaces";
export interface RerankingDetails {
rerank_model_name: string | null;
provider_type: RerankerProvider | null;
rerank_provider_type: RerankerProvider | null;
api_key: string | null;
num_rerank: number;
}
@ -8,20 +11,33 @@ export interface RerankingDetails {
export enum RerankerProvider {
COHERE = "cohere",
}
export interface AdvancedDetails {
multilingual_expansion: string[];
export interface AdvancedSearchConfiguration {
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
index_name: string | null;
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
}
export interface SavedSearchSettings extends RerankingDetails {
multilingual_expansion: string[];
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
index_name: string | null;
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
provider_type: EmbeddingProvider | null;
}
export interface RerankingModel {
provider?: RerankerProvider;
rerank_provider_type: RerankerProvider | null;
modelName: string;
displayName: string;
description: string;
@ -31,6 +47,7 @@ export interface RerankingModel {
export const rerankingModels: RerankingModel[] = [
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-xsmall-v1",
displayName: "MixedBread XSmall",
@ -38,6 +55,7 @@ export const rerankingModels: RerankingModel[] = [
link: "https://huggingface.co/mixedbread-ai/mxbai-rerank-xsmall-v1",
},
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-base-v1",
displayName: "MixedBread Base",
@ -45,6 +63,7 @@ export const rerankingModels: RerankingModel[] = [
link: "https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1",
},
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-large-v1",
displayName: "MixedBread Large",
@ -53,7 +72,7 @@ export const rerankingModels: RerankingModel[] = [
},
{
cloud: true,
provider: RerankerProvider.COHERE,
rerank_provider_type: RerankerProvider.COHERE,
modelName: "rerank-english-v3.0",
displayName: "Cohere English",
description: "High-performance English-focused reranking model.",
@ -61,7 +80,7 @@ export const rerankingModels: RerankingModel[] = [
},
{
cloud: true,
provider: RerankerProvider.COHERE,
rerank_provider_type: RerankerProvider.COHERE,
modelName: "rerank-multilingual-v3.0",
displayName: "Cohere Multilingual",
description: "Powerful multilingual reranking model.",

View File

@ -5,14 +5,14 @@ import { EditingValue } from "@/components/credentials/EditingValue";
import CredentialSubText from "@/components/credentials/CredentialFields";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedDetails, RerankingDetails } from "../interfaces";
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";
interface AdvancedEmbeddingFormPageProps {
updateAdvancedEmbeddingDetails: (
key: keyof AdvancedDetails,
key: keyof AdvancedSearchConfiguration,
value: any
) => void;
advancedEmbeddingDetails: AdvancedDetails;
advancedEmbeddingDetails: AdvancedSearchConfiguration;
setRerankingDetails: Dispatch<SetStateAction<RerankingDetails>>;
numRerank: number;
}

View File

@ -8,7 +8,7 @@ import { Button, Card, Text } from "@tremor/react";
import { ArrowLeft, ArrowRight, WarningCircle } from "@phosphor-icons/react";
import {
CloudEmbeddingModel,
EmbeddingModelDescriptor,
EmbeddingProvider,
HostedEmbeddingModel,
} from "../../../../components/embedding/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
@ -17,7 +17,8 @@ import useSWR, { mutate } from "swr";
import { ThreeDotsLoader } from "@/components/Loading";
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
import {
AdvancedDetails,
AdvancedSearchConfiguration,
RerankerProvider,
RerankingDetails,
SavedSearchSettings,
} from "../interfaces";
@ -30,21 +31,27 @@ export default function EmbeddingForm() {
const { popup, setPopup } = usePopup();
const [advancedEmbeddingDetails, setAdvancedEmbeddingDetails] =
useState<AdvancedDetails>({
disable_rerank_for_streaming: false,
multilingual_expansion: [],
useState<AdvancedSearchConfiguration>({
model_name: "",
model_dim: 0,
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
multipass_indexing: true,
multilingual_expansion: [],
disable_rerank_for_streaming: false,
});
const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
api_key: "",
num_rerank: 0,
provider_type: null,
rerank_provider_type: null,
rerank_model_name: "",
});
const updateAdvancedEmbeddingDetails = (
key: keyof AdvancedDetails,
key: keyof AdvancedSearchConfiguration,
value: any
) => {
setAdvancedEmbeddingDetails((values) => ({ ...values, [key]: value }));
@ -52,7 +59,7 @@ export default function EmbeddingForm() {
async function updateSearchSettings(searchSettings: SavedSearchSettings) {
const response = await fetch(
"/api/search-settings/update-search-settings",
"/api/search-settings/update-inference-settings",
{
method: "POST",
headers: {
@ -80,7 +87,7 @@ export default function EmbeddingForm() {
isLoading: isLoadingCurrentModel,
error: currentEmbeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-current-embedding-model",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@ -91,7 +98,7 @@ export default function EmbeddingForm() {
const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
"/api/search-settings/get-search-settings",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@ -99,31 +106,37 @@ export default function EmbeddingForm() {
useEffect(() => {
if (searchSettings) {
setAdvancedEmbeddingDetails({
model_name: searchSettings.model_name,
model_dim: searchSettings.model_dim,
normalize: searchSettings.normalize,
query_prefix: searchSettings.query_prefix,
passage_prefix: searchSettings.passage_prefix,
index_name: searchSettings.index_name,
multipass_indexing: searchSettings.multipass_indexing,
multilingual_expansion: searchSettings.multilingual_expansion,
disable_rerank_for_streaming:
searchSettings.disable_rerank_for_streaming,
multilingual_expansion: searchSettings.multilingual_expansion,
multipass_indexing: searchSettings.multipass_indexing,
});
setRerankingDetails({
api_key: searchSettings.api_key,
num_rerank: searchSettings.num_rerank,
provider_type: searchSettings.provider_type,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
});
}
}, [searchSettings]);
const originalRerankingDetails = searchSettings
const originalRerankingDetails: RerankingDetails = searchSettings
? {
api_key: searchSettings.api_key,
num_rerank: searchSettings.num_rerank,
provider_type: searchSettings.provider_type,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
}
: {
api_key: "",
num_rerank: 0,
provider_type: null,
rerank_provider_type: null,
rerank_model_name: "",
};
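
The reranking fields get the same treatment: the provider type and model name gain a rerank_ prefix to match the new settings shape. A sketch of RerankingDetails as consumed here, assuming rerank_provider_type is the imported RerankerProvider enum or null:

// Inferred shape; api_key may come back null from saved settings.
interface RerankingDetails {
  api_key: string | null;
  num_rerank: number;
  rerank_provider_type: RerankerProvider | null;
  rerank_model_name: string;
}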
@ -149,14 +162,17 @@ export default function EmbeddingForm() {
let values: SavedSearchSettings = {
...rerankingDetails,
...advancedEmbeddingDetails,
provider_type:
selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null,
};
const response = await updateSearchSettings(values);
if (response.ok) {
setPopup({
message: "Updated search settings succesffuly",
type: "success",
});
mutate("/api/search-settings/get-search-settings");
mutate("/api/search-settings/get-current-search-settings");
return true;
} else {
setPopup({ message: "Failed to update search settings", type: "error" });
@ -165,29 +181,37 @@ export default function EmbeddingForm() {
};
const onConfirm = async () => {
let newModel: EmbeddingModelDescriptor;
if (!selectedProvider) {
return;
}
let newModel: SavedSearchSettings;
if ("provider_type" in selectedProvider) {
// This is a CloudEmbeddingModel
if (selectedProvider.provider_type != null) {
// This is a cloud model
newModel = {
...advancedEmbeddingDetails,
...selectedProvider,
...rerankingDetails,
model_name: selectedProvider.model_name,
provider_type: selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0],
provider_type:
(selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0] as EmbeddingProvider) || null,
};
} else {
// This is an EmbeddingModelDescriptor
// This is a locally hosted model
newModel = {
...advancedEmbeddingDetails,
...selectedProvider,
...rerankingDetails,
model_name: selectedProvider.model_name!,
description: "",
provider_type: null,
};
}
newModel.index_name = null;
const response = await fetch(
"/api/search-settings/set-new-embedding-model",
"/api/search-settings/set-new-search-settings",
{
method: "POST",
body: JSON.stringify(newModel),
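
Note on the changed branch condition: hosted models now carry an explicit provider_type: null (see the EmbeddingModelDescriptor change later in this diff), so the old "provider_type" in selectedProvider check no longer distinguishes cloud from hosted; a null check does. A minimal type-guard sketch under that assumption:

// provider_type now exists on both shapes, so key presence is no longer
// a discriminator; null-ness is.
function isCloudModel(
  model: CloudEmbeddingModel | HostedEmbeddingModel
): model is CloudEmbeddingModel {
  return model.provider_type != null;
}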
@ -201,7 +225,7 @@ export default function EmbeddingForm() {
message: "Changed provider suceessfully. Redirecing to embedding page",
type: "success",
});
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
setTimeout(() => {
window.open("/admin/configuration/search", "_self");
}, 2000);
@ -217,14 +241,14 @@ export default function EmbeddingForm() {
searchSettings?.multipass_indexing !=
advancedEmbeddingDetails.multipass_indexing;
const ReIndxingButton = () => {
return (
const ReIndexingButton = ({ needsReIndex }: { needsReIndex: boolean }) => {
return needsReIndex ? (
<div className="flex mx-auto gap-x-1 ml-auto items-center">
<button
className="enabled:cursor-pointer disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
const updated = await updateSearch();
if (updated) {
const update = await updateSearch();
if (update) {
await onConfirm();
}
}}
@ -251,6 +275,15 @@ export default function EmbeddingForm() {
</div>
</div>
</div>
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
);
};
@ -361,18 +394,7 @@ export default function EmbeddingForm() {
Previous
</button>
{needsReIndex ? (
<ReIndxingButton />
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
)}
<ReIndexingButton needsReIndex={needsReIndex} />
<div className="flex w-full justify-end">
<button
@ -410,20 +432,7 @@ export default function EmbeddingForm() {
Previous
</button>
{needsReIndex ? (
<ReIndxingButton />
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50
disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center
text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
)}
<ReIndexingButton needsReIndex={needsReIndex} />
</div>
</>
)}

View File

@ -49,7 +49,7 @@ export default async function Home() {
fetchSS("/manage/document-set"),
fetchAssistantsSS(),
fetchSS("/query/valid-tags"),
fetchSS("/search-settings/get-embedding-models"),
fetchSS("/search-settings/get-all-search-settings"),
fetchSS("/query/user-searches"),
];

View File

@ -35,7 +35,13 @@ export function CustomModelForm({
normalize: Yup.boolean().required(),
})}
onSubmit={async (values, formikHelpers) => {
onSubmit({ ...values, model_dim: parseInt(values.model_dim) });
onSubmit({
...values,
model_dim: parseInt(values.model_dim),
api_key: null,
provider_type: null,
index_name: null,
});
}}
>
{({ isSubmitting, setFieldValue }) => (

View File

@ -41,8 +41,10 @@ export interface EmbeddingModelDescriptor {
normalize: boolean;
query_prefix: string;
passage_prefix: string;
provider_type?: string | null;
provider_type: string | null;
description: string;
api_key: string | null;
index_name: string | null;
}
export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
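
Since provider_type, api_key, and index_name are now required on the descriptor, every model literal below must spell them out, which is what the repetitive hunks that follow do. A hypothetical helper (not in this codebase) that would centralize the hosted defaults:

// Hypothetical convenience object; the diff below writes these three
// fields out explicitly on each entry instead.
const HOSTED_MODEL_DEFAULTS = {
  provider_type: null,
  api_key: null,
  index_name: "",
} as const;

// Usage sketch: { ...HOSTED_MODEL_DEFAULTS, model_name: "intfloat/e5-base-v2", ... }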
@ -82,6 +84,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/nomic-ai/nomic-embed-text-v1",
query_prefix: "search_query: ",
passage_prefix: "search_document: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/e5-base-v2",
@ -92,6 +97,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/e5-base-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/e5-small-v2",
@ -102,6 +110,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/e5-small-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/multilingual-e5-base",
@ -112,6 +123,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/multilingual-e5-small",
@ -122,6 +136,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
];
@ -150,6 +167,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
model_name: "embed-english-light-v3.0",
@ -164,6 +183,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},
@ -190,6 +211,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
mtebScore: 64.6,
maxContext: 8191,
enabled: false,
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.OPENAI,
@ -204,6 +227,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
enabled: false,
mtebScore: 62.3,
maxContext: 8191,
index_name: "",
api_key: null,
},
],
},
@ -231,6 +256,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.GOOGLE,
@ -244,6 +271,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},
@ -270,6 +299,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.VOYAGE,
@ -284,6 +315,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},