Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-09 20:39:29 +02:00)
Db search (#2235)
* k
* update enum imports
* add functional types + model swaps
* remove a log
* remove kv
* fully functional + robustified for kv swap
* validated with hosted + cloud
* ensure not updating current search settings when reindexing
* add instance check
* revert back to updating search settings (will need a slight refactor for endpoint)
* protect advanced config override1
* run pretty
* fix typing
* update typing
* remove unnecessary function
* update model name
* clearer interface names
* validated foreign key constraint
* proper migration
* squash

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
This commit is contained in:
parent 5f12b7ad58
commit 97ba71e1b3
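To try the schema change locally, the new revision can be applied through Alembic's command API. A minimal sketch, assuming the standard alembic.ini layout used by the backend (the config path is an assumption, not part of this commit):

from alembic import command
from alembic.config import Config

# Assumed location of the project's Alembic config; adjust to your checkout.
alembic_cfg = Config("backend/alembic.ini")

# Apply all pending revisions, including 1f60f60c3401 below.
command.upgrade(alembic_cfg, "head")

# Roll back just this revision if needed.
# command.downgrade(alembic_cfg, "f17bf3b0d9f1")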
@@ -0,0 +1,135 @@
"""embedding model -> search settings

Revision ID: 1f60f60c3401
Revises: f17bf3b0d9f1
Create Date: 2024-08-25 12:39:51.731632

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS

# revision identifiers, used by Alembic.
revision = "1f60f60c3401"
down_revision = "f17bf3b0d9f1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_constraint(
        "index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey"
    )
    # Rename the table
    op.rename_table("embedding_model", "search_settings")

    # Add new columns
    op.add_column(
        "search_settings",
        sa.Column(
            "multipass_indexing", sa.Boolean(), nullable=False, server_default="true"
        ),
    )
    op.add_column(
        "search_settings",
        sa.Column(
            "multilingual_expansion",
            postgresql.ARRAY(sa.String()),
            nullable=False,
            server_default="{}",
        ),
    )
    op.add_column(
        "search_settings",
        sa.Column(
            "disable_rerank_for_streaming",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )
    op.add_column(
        "search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True)
    )
    op.add_column(
        "search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True)
    )
    op.add_column(
        "search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True)
    )
    op.add_column(
        "search_settings",
        sa.Column(
            "num_rerank",
            sa.Integer(),
            nullable=False,
            server_default=str(NUM_POSTPROCESSED_RESULTS),
        ),
    )

    # Add the new column as nullable initially
    op.add_column(
        "index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True)
    )

    # Populate the new column with data from the existing embedding_model_id
    op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id")

    # Create the foreign key constraint
    op.create_foreign_key(
        "fk_index_attempt_search_settings",
        "index_attempt",
        "search_settings",
        ["search_settings_id"],
        ["id"],
    )

    # Make the new column non-nullable
    op.alter_column("index_attempt", "search_settings_id", nullable=False)

    # Drop the old embedding_model_id column
    op.drop_column("index_attempt", "embedding_model_id")


def downgrade() -> None:
    # Add back the embedding_model_id column
    op.add_column(
        "index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True)
    )

    # Populate the old column with data from search_settings_id
    op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id")

    # Make the old column non-nullable
    op.alter_column("index_attempt", "embedding_model_id", nullable=False)

    # Drop the foreign key constraint
    op.drop_constraint(
        "fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
    )

    # Drop the new search_settings_id column
    op.drop_column("index_attempt", "search_settings_id")

    # Rename the table back
    op.rename_table("search_settings", "embedding_model")

    # Remove added columns
    op.drop_column("embedding_model", "num_rerank")
    op.drop_column("embedding_model", "rerank_api_key")
    op.drop_column("embedding_model", "rerank_provider_type")
    op.drop_column("embedding_model", "rerank_model_name")
    op.drop_column("embedding_model", "disable_rerank_for_streaming")
    op.drop_column("embedding_model", "multilingual_expansion")
    op.drop_column("embedding_model", "multipass_indexing")

    op.create_foreign_key(
        "index_attempt__embedding_model_fk",
        "index_attempt",
        "embedding_model",
        ["embedding_model_id"],
        ["id"],
    )
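Because the upgrade uses the add-nullable, backfill, then make-NOT-NULL pattern, a quick sanity check after upgrading can confirm both the rename and the backfill. A hedged verification sketch with plain SQLAlchemy (the connection URL is a placeholder, not from this commit):

from sqlalchemy import create_engine, inspect, text

engine = create_engine("postgresql://user:pass@localhost/danswer")  # placeholder URL

with engine.connect() as conn:
    tables = inspect(conn).get_table_names()
    assert "search_settings" in tables and "embedding_model" not in tables

    missing = conn.execute(
        text("SELECT count(*) FROM index_attempt WHERE search_settings_id IS NULL")
    ).scalar()
    assert missing == 0  # every attempt was backfilled from embedding_model_id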
@@ -11,8 +11,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "351faebd379d"
down_revision = "ee3f4b47fad5"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None


def upgrade() -> None:
@@ -10,7 +10,7 @@ import sqlalchemy as sa

from danswer.db.models import IndexModelStatus
from danswer.search.enums import RecencyBiasSetting
-from danswer.search.models import SearchType
+from danswer.search.enums import SearchType

# revision identifiers, used by Alembic.
revision = "776b3bbe9092"
@@ -9,7 +9,7 @@ from alembic import op
import sqlalchemy as sa
from sqlalchemy import table, column, String, Integer, Boolean

-from danswer.db.embedding_model import (
+from danswer.db.search_settings import (
    get_new_default_embedding_model,
    get_old_default_embedding_model,
    user_has_overridden_embedding_model,
@@ -71,14 +71,14 @@ def upgrade() -> None:
                "query_prefix": old_embedding_model.query_prefix,
                "passage_prefix": old_embedding_model.passage_prefix,
                "index_name": old_embedding_model.index_name,
-               "status": old_embedding_model.status,
+               "status": IndexModelStatus.PRESENT,
            }
        ],
    )
    # if the user has not overridden the default embedding model via env variables,
    # insert the new default model into the database to auto-upgrade them
    if not user_has_overridden_embedding_model():
-       new_embedding_model = get_new_default_embedding_model(is_present=False)
+       new_embedding_model = get_new_default_embedding_model()
        op.bulk_insert(
            EmbeddingModel,
            [
@@ -13,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "2d2304e27d8c"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None


def upgrade() -> None:
@@ -13,8 +13,8 @@ import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "f17bf3b0d9f1"
down_revision = "351faebd379d"
-branch_labels = None
-depends_on = None
+branch_labels: None = None
+depends_on: None = None


def upgrade() -> None:
@@ -90,20 +90,20 @@ def _run_indexing(
    """
    start_time = time.time()

-   db_embedding_model = index_attempt.embedding_model
-   index_name = db_embedding_model.index_name
+   search_settings = index_attempt.search_settings
+   index_name = search_settings.index_name

    # Only update cc-pair status for primary index jobs
    # Secondary index syncs at the end when swapping
-   is_primary = index_attempt.embedding_model.status == IndexModelStatus.PRESENT
+   is_primary = search_settings.status == IndexModelStatus.PRESENT

    # Indexing is only done into one index at a time
    document_index = get_default_document_index(
        primary_index_name=index_name, secondary_index_name=None
    )

-   embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
-       db_embedding_model
+   embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
+       search_settings=search_settings
    )

    indexing_pipeline = build_indexing_pipeline(
@@ -111,7 +111,7 @@ def _run_indexing(
        embedder=embedding_model,
        document_index=document_index,
        ignore_time_skip=index_attempt.from_beginning
-       or (db_embedding_model.status == IndexModelStatus.FUTURE),
+       or (search_settings.status == IndexModelStatus.FUTURE),
        db_session=db_session,
    )

@@ -128,7 +128,7 @@ def _run_indexing(
        else get_last_successful_attempt_time(
            connector_id=db_connector.id,
            credential_id=db_credential.id,
-           embedding_model=index_attempt.embedding_model,
+           search_settings=index_attempt.search_settings,
            db_session=db_session,
        )
    )
@@ -185,7 +185,7 @@ def _run_indexing(
            if (
                (
                    db_cc_pair.status == ConnectorCredentialPairStatus.PAUSED
-                   and db_embedding_model.status != IndexModelStatus.FUTURE
+                   and search_settings.status != IndexModelStatus.FUTURE
                )
                # if it's deleting, we don't care if this is a secondary index
                or db_cc_pair.status == ConnectorCredentialPairStatus.DELETING
@@ -20,8 +20,6 @@ from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
from danswer.configs.constants import POSTGRES_INDEXER_APP_NAME
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import fetch_connector_credential_pairs
-from danswer.db.embedding_model import get_current_db_embedding_model
-from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import init_sqlalchemy_engine
@@ -32,11 +30,14 @@ from danswer.db.index_attempt import get_last_attempt_for_cc_pair
from danswer.db.index_attempt import get_not_started_index_attempts
from danswer.db.index_attempt import mark_attempt_failed
from danswer.db.models import ConnectorCredentialPair
-from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
+from danswer.db.models import SearchSettings
+from danswer.db.search_settings import get_current_search_settings
+from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
+from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -60,7 +61,7 @@ _UNEXPECTED_STATE_FAILURE_REASON = (
def _should_create_new_indexing(
    cc_pair: ConnectorCredentialPair,
    last_index: IndexAttempt | None,
-   model: EmbeddingModel,
+   search_settings_instance: SearchSettings,
    secondary_index_building: bool,
    db_session: Session,
) -> bool:
@@ -69,11 +70,14 @@ def _should_create_new_indexing(
    # User can still manually create single indexing attempts via the UI for the
    # currently in use index
    if DISABLE_INDEX_UPDATE_ON_SWAP:
-       if model.status == IndexModelStatus.PRESENT and secondary_index_building:
+       if (
+           search_settings_instance.status == IndexModelStatus.PRESENT
+           and secondary_index_building
+       ):
            return False

    # When switching over models, always index at least once
-   if model.status == IndexModelStatus.FUTURE:
+   if search_settings_instance.status == IndexModelStatus.FUTURE:
        if last_index:
            # No new index if the last index attempt succeeded
            # Once is enough. The model will never be able to swap otherwise.
@@ -160,35 +164,42 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob]) -> None:
            ongoing.add(
                (
                    attempt.connector_credential_pair_id,
-                   attempt.embedding_model_id,
+                   attempt.search_settings_id,
                )
            )

-       embedding_models = [get_current_db_embedding_model(db_session)]
-       secondary_embedding_model = get_secondary_db_embedding_model(db_session)
-       if secondary_embedding_model is not None:
-           embedding_models.append(secondary_embedding_model)
+       # Get the primary search settings
+       primary_search_settings = get_current_search_settings(db_session)
+       search_settings = [primary_search_settings]
+
+       # Check for secondary search settings
+       secondary_search_settings = get_secondary_search_settings(db_session)
+       if secondary_search_settings is not None:
+           # If secondary settings exist, add them to the list
+           search_settings.append(secondary_search_settings)

        all_connector_credential_pairs = fetch_connector_credential_pairs(db_session)
        for cc_pair in all_connector_credential_pairs:
-           for model in embedding_models:
+           for search_settings_instance in search_settings:
                # Check if there is an ongoing indexing attempt for this connector credential pair
-               if (cc_pair.id, model.id) in ongoing:
+               if (cc_pair.id, search_settings_instance.id) in ongoing:
                    continue

                last_attempt = get_last_attempt_for_cc_pair(
-                   cc_pair.id, model.id, db_session
+                   cc_pair.id, search_settings_instance.id, db_session
                )
                if not _should_create_new_indexing(
                    cc_pair=cc_pair,
                    last_index=last_attempt,
-                   model=model,
-                   secondary_index_building=len(embedding_models) > 1,
+                   search_settings_instance=search_settings_instance,
+                   secondary_index_building=len(search_settings) > 1,
                    db_session=db_session,
                ):
                    continue

-               create_index_attempt(cc_pair.id, model.id, db_session)
+               create_index_attempt(
+                   cc_pair.id, search_settings_instance.id, db_session
+               )


def cleanup_indexing_jobs(
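The scheduling loop above dedupes work by keying in-flight attempts on the pair (connector-credential pair id, search settings id). A standalone illustration of that pattern, independent of the Danswer models and intended only as a sketch:

from typing import NamedTuple


class Attempt(NamedTuple):
    cc_pair_id: int
    search_settings_id: int


def pick_new_attempts(
    candidates: list[Attempt], in_flight: set[tuple[int, int]]
) -> list[Attempt]:
    # Skip any (cc_pair, search_settings) combination that is already running,
    # so a connector is indexed at most once per settings generation at a time.
    return [
        a
        for a in candidates
        if (a.cc_pair_id, a.search_settings_id) not in in_flight
    ]


print(pick_new_attempts([Attempt(1, 7), Attempt(2, 7)], {(1, 7)}))  # -> [Attempt(2, 7)]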
@@ -285,7 +296,7 @@ def kickoff_indexing_jobs(
    # get_not_started_index_attempts orders its returned results from oldest to newest
    # we must process attempts in a FIFO manner to prevent connector starvation
    new_indexing_attempts = [
-       (attempt, attempt.embedding_model)
+       (attempt, attempt.search_settings)
        for attempt in get_not_started_index_attempts(db_session)
        if attempt.id not in existing_jobs
    ]
@@ -297,10 +308,10 @@ def kickoff_indexing_jobs(

    indexing_attempt_count = 0

-   for attempt, embedding_model in new_indexing_attempts:
+   for attempt, search_settings in new_indexing_attempts:
        use_secondary_index = (
-           embedding_model.status == IndexModelStatus.FUTURE
-           if embedding_model is not None
+           search_settings.status == IndexModelStatus.FUTURE
+           if search_settings is not None
            else False
        )
        if attempt.connector_credential_pair.connector is None:
@@ -373,17 +384,21 @@ def update_loop(
    engine = get_sqlalchemy_engine()
    with Session(engine) as db_session:
        check_index_swap(db_session=db_session)
-       db_embedding_model = get_current_db_embedding_model(db_session)
+       search_settings = get_current_search_settings(db_session)

        # So that the first time users aren't surprised by really slow speed of first
        # batch of documents indexed

-       if db_embedding_model.provider_type is None:
+       if search_settings.provider_type is None:
            logger.notice("Running a first inference to warm up embedding model")
+           embedding_model = EmbeddingModel.from_db_model(
+               search_settings=search_settings,
+               server_host=INDEXING_MODEL_SERVER_HOST,
+               server_port=MODEL_SERVER_PORT,
+           )
+
            warm_up_bi_encoder(
-               embedding_model=db_embedding_model,
-               model_server_host=INDEXING_MODEL_SERVER_HOST,
-               model_server_port=MODEL_SERVER_PORT,
+               embedding_model=embedding_model,
            )

    client_primary: Client | SimpleJobClient
@@ -32,13 +32,13 @@ from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
-from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import SearchDoc as DbSearchDoc
from danswer.db.models import ToolCall
from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
+from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
@@ -331,9 +331,9 @@ def stream_chat_message_objects(
            Callable[[str], list[int]], llm_tokenizer.encode
        )

-       embedding_model = get_current_db_embedding_model(db_session)
+       search_settings = get_current_search_settings(db_session)
        document_index = get_default_document_index(
-           primary_index_name=embedding_model.index_name, secondary_index_name=None
+           primary_index_name=search_settings.index_name, secondary_index_name=None
        )

        # Every chat Session begins with an empty root message
@@ -37,6 +37,7 @@ from danswer.db.models import Persona
from danswer.db.models import SlackBotConfig
from danswer.db.models import SlackBotResponseType
from danswer.db.persona import fetch_persona_by_id
+from danswer.db.search_settings import get_current_search_settings
from danswer.llm.answering.prompts.citations_prompt import (
    compute_max_document_tokens_for_persona,
)
@@ -48,8 +49,8 @@ from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
from danswer.search.enums import OptionalSearchSetting
from danswer.search.models import BaseFilters
+from danswer.search.models import RerankingDetails
from danswer.search.models import RetrievalDetails
-from danswer.search.search_settings import get_search_settings
from danswer.utils.logger import DanswerLoggingAdapter


@@ -228,7 +229,8 @@ def handle_regular_answer(
    )

    # Always apply reranking settings if it exists, this is the non-streaming flow
-   saved_search_settings = get_search_settings()
+   with Session(get_sqlalchemy_engine()) as db_session:
+       saved_search_settings = get_current_search_settings(db_session)

    # This includes throwing out answer via reflexion
    answer = _get_answer(
@@ -241,7 +243,7 @@ def handle_regular_answer(
        persona_id=persona.id if persona is not None else 0,
        retrieval_options=retrieval_details,
        chain_of_thought=not disable_cot,
-       rerank_settings=saved_search_settings.to_reranking_detail()
+       rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
        if saved_search_settings
        else None,
    )
@@ -45,9 +45,10 @@ from danswer.danswerbot.slack.utils import read_slack_thread
from danswer.danswerbot.slack.utils import remove_danswer_bot_tag
from danswer.danswerbot.slack.utils import rephrase_slack_message
from danswer.danswerbot.slack.utils import respond_in_thread
-from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.search_settings import get_current_search_settings
from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.one_shot_answer.models import ThreadMessage
from danswer.search.retrieval.search_runner import download_nltk_data
@@ -468,13 +469,16 @@ if __name__ == "__main__":
        # This happens on the very first time the listener process comes up
        # or the tokens have updated (set up for the first time)
        with Session(get_sqlalchemy_engine()) as db_session:
-           embedding_model = get_current_db_embedding_model(db_session)
-           if embedding_model.provider_type is None:
-               warm_up_bi_encoder(
-                   embedding_model=embedding_model,
-                   model_server_host=MODEL_SERVER_HOST,
-                   model_server_port=MODEL_SERVER_PORT,
-               )
+           search_settings = get_current_search_settings(db_session)
+           embedding_model = EmbeddingModel.from_db_model(
+               search_settings=search_settings,
+               server_host=MODEL_SERVER_HOST,
+               server_port=MODEL_SERVER_PORT,
+           )
+
+           warm_up_bi_encoder(
+               embedding_model=embedding_model,
+           )

        slack_bot_tokens = latest_slack_bot_tokens
        # potentially may cause a message to be dropped, but it is complicated
@@ -14,10 +14,10 @@ from danswer.db.connector import fetch_connector_by_id
from danswer.db.credentials import fetch_credential_by_id
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.models import ConnectorCredentialPair
-from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
+from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.models import User__UserGroup
from danswer.db.models import UserGroup__ConnectorCredentialPair
@@ -159,12 +159,12 @@ def get_connector_credential_pair_from_id(
def get_last_successful_attempt_time(
    connector_id: int,
    credential_id: int,
-   embedding_model: EmbeddingModel,
+   search_settings: SearchSettings,
    db_session: Session,
) -> float:
    """Gets the timestamp of the last successful index run stored in
    the CC Pair row in the database"""
-   if embedding_model.status == IndexModelStatus.PRESENT:
+   if search_settings.status == IndexModelStatus.PRESENT:
        connector_credential_pair = get_connector_credential_pair(
            connector_id, credential_id, db_session
        )
@@ -186,7 +186,7 @@ def get_last_successful_attempt_time(
        .filter(
            ConnectorCredentialPair.connector_id == connector_id,
            ConnectorCredentialPair.credential_id == credential_id,
-           IndexAttempt.embedding_model_id == embedding_model.id,
+           IndexAttempt.search_settings_id == search_settings.id,
            IndexAttempt.status == IndexingStatus.SUCCESS,
        )
        .order_by(IndexAttempt.time_started.desc())
@@ -445,11 +445,11 @@ def resync_cc_pair(
            ConnectorCredentialPair,
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
        )
-       .join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
+       .join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
        .filter(
            ConnectorCredentialPair.connector_id == connector_id,
            ConnectorCredentialPair.credential_id == credential_id,
-           EmbeddingModel.status == IndexModelStatus.PRESENT,
+           SearchSettings.status == IndexModelStatus.PRESENT,
        )
    )

@@ -1,9 +1,9 @@
from sqlalchemy.orm import Session

-from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.index_attempt import get_last_attempt
from danswer.db.models import ConnectorCredentialPair
from danswer.db.models import IndexingStatus
+from danswer.db.search_settings import get_current_search_settings


def check_deletion_attempt_is_allowed(
@@ -28,12 +28,12 @@ def check_deletion_attempt_is_allowed(

    connector_id = connector_credential_pair.connector_id
    credential_id = connector_credential_pair.credential_id
-   current_embedding_model = get_current_db_embedding_model(db_session)
+   search_settings = get_current_search_settings(db_session)

    last_indexing = get_last_attempt(
        connector_id=connector_id,
        credential_id=credential_id,
-       embedding_model_id=current_embedding_model.id,
+       search_settings_id=search_settings.id,
        db_session=db_session,
    )

@@ -1,157 +0,0 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexModelStatus
from danswer.indexing.models import EmbeddingModelCreateRequest
from danswer.indexing.models import EmbeddingModelDetail
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.server.manage.embedding.models import (
    CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider

logger = setup_logger()


def create_embedding_model(
    create_embed_model_details: EmbeddingModelCreateRequest,
    db_session: Session,
    status: IndexModelStatus = IndexModelStatus.FUTURE,
) -> EmbeddingModel:
    embedding_model = EmbeddingModel(
        model_name=create_embed_model_details.model_name,
        model_dim=create_embed_model_details.model_dim,
        normalize=create_embed_model_details.normalize,
        query_prefix=create_embed_model_details.query_prefix,
        passage_prefix=create_embed_model_details.passage_prefix,
        status=status,
        provider_type=create_embed_model_details.provider_type,
        # Every single embedding model except the initial one from migrations has this name
        # The initial one from migration is called "danswer_chunk"
        index_name=create_embed_model_details.index_name,
    )

    db_session.add(embedding_model)
    db_session.commit()

    return embedding_model


def get_embedding_provider_from_provider_type(
    db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProvider | None:
    query = select(CloudEmbeddingProvider).where(
        CloudEmbeddingProvider.provider_type == provider_type
    )
    provider = db_session.execute(query).scalars().first()
    return provider if provider else None


def get_current_db_embedding_provider(
    db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
    current_embedding_model = EmbeddingModelDetail.from_model(
        get_current_db_embedding_model(db_session=db_session)
    )

    if current_embedding_model is None or current_embedding_model.provider_type is None:
        return None

    embedding_provider = fetch_embedding_provider(
        db_session=db_session,
        provider_type=current_embedding_model.provider_type,
    )
    if embedding_provider is None:
        raise RuntimeError("No embedding provider exists for this model.")

    current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
        cloud_provider_model=embedding_provider
    )

    return current_embedding_provider


def get_current_db_embedding_model(db_session: Session) -> EmbeddingModel:
    query = (
        select(EmbeddingModel)
        .where(EmbeddingModel.status == IndexModelStatus.PRESENT)
        .order_by(EmbeddingModel.id.desc())
    )
    result = db_session.execute(query)
    latest_model = result.scalars().first()

    if not latest_model:
        raise RuntimeError("No embedding model selected, DB is not in a valid state")

    return latest_model


def get_secondary_db_embedding_model(db_session: Session) -> EmbeddingModel | None:
    query = (
        select(EmbeddingModel)
        .where(EmbeddingModel.status == IndexModelStatus.FUTURE)
        .order_by(EmbeddingModel.id.desc())
    )
    result = db_session.execute(query)
    latest_model = result.scalars().first()

    return latest_model


def update_embedding_model_status(
    embedding_model: EmbeddingModel, new_status: IndexModelStatus, db_session: Session
) -> None:
    embedding_model.status = new_status
    db_session.commit()


def user_has_overridden_embedding_model() -> bool:
    return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL


def get_old_default_embedding_model() -> EmbeddingModel:
    is_overridden = user_has_overridden_embedding_model()
    return EmbeddingModel(
        model_name=(
            DOCUMENT_ENCODER_MODEL
            if is_overridden
            else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
        ),
        model_dim=(
            DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
        ),
        normalize=(
            NORMALIZE_EMBEDDINGS
            if is_overridden
            else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
        ),
        query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
        passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
        status=IndexModelStatus.PRESENT,
        index_name="danswer_chunk",
    )


def get_new_default_embedding_model(is_present: bool) -> EmbeddingModel:
    return EmbeddingModel(
        model_name=DOCUMENT_ENCODER_MODEL,
        model_dim=DOC_EMBEDDING_DIM,
        normalize=NORMALIZE_EMBEDDINGS,
        query_prefix=ASYM_QUERY_PREFIX,
        passage_prefix=ASYM_PASSAGE_PREFIX,
        status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
        index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
    )
@@ -11,11 +11,11 @@ from sqlalchemy.orm import Session

from danswer.connectors.models import Document
from danswer.connectors.models import DocumentErrorSummary
-from danswer.db.models import EmbeddingModel
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexAttemptError
from danswer.db.models import IndexingStatus
from danswer.db.models import IndexModelStatus
+from danswer.db.models import SearchSettings
from danswer.server.documents.models import ConnectorCredentialPair
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.utils.logger import setup_logger
@@ -27,14 +27,14 @@ logger = setup_logger()

def get_last_attempt_for_cc_pair(
    cc_pair_id: int,
-   embedding_model_id: int,
+   search_settings_id: int,
    db_session: Session,
) -> IndexAttempt | None:
    return (
        db_session.query(IndexAttempt)
        .filter(
            IndexAttempt.connector_credential_pair_id == cc_pair_id,
-           IndexAttempt.embedding_model_id == embedding_model_id,
+           IndexAttempt.search_settings_id == search_settings_id,
        )
        .order_by(IndexAttempt.time_updated.desc())
        .first()
@@ -50,13 +50,13 @@ def get_index_attempt(

def create_index_attempt(
    connector_credential_pair_id: int,
-   embedding_model_id: int,
+   search_settings_id: int,
    db_session: Session,
    from_beginning: bool = False,
) -> int:
    new_attempt = IndexAttempt(
        connector_credential_pair_id=connector_credential_pair_id,
-       embedding_model_id=embedding_model_id,
+       search_settings_id=search_settings_id,
        from_beginning=from_beginning,
        status=IndexingStatus.NOT_STARTED,
    )
@@ -162,7 +162,7 @@ def update_docs_indexed(
def get_last_attempt(
    connector_id: int,
    credential_id: int,
-   embedding_model_id: int | None,
+   search_settings_id: int | None,
    db_session: Session,
) -> IndexAttempt | None:
    stmt = (
@@ -171,7 +171,7 @@ def get_last_attempt(
        .where(
            ConnectorCredentialPair.connector_id == connector_id,
            ConnectorCredentialPair.credential_id == credential_id,
-           IndexAttempt.embedding_model_id == embedding_model_id,
+           IndexAttempt.search_settings_id == search_settings_id,
        )
    )

@@ -188,12 +188,12 @@ def get_latest_index_attempts(
    ids_stmt = select(
        IndexAttempt.connector_credential_pair_id,
        func.max(IndexAttempt.id).label("max_id"),
-   ).join(EmbeddingModel, IndexAttempt.embedding_model_id == EmbeddingModel.id)
+   ).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)

    if secondary_index:
-       ids_stmt = ids_stmt.where(EmbeddingModel.status == IndexModelStatus.FUTURE)
+       ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
    else:
-       ids_stmt = ids_stmt.where(EmbeddingModel.status == IndexModelStatus.PRESENT)
+       ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)

    ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
    ids_subquery = ids_stmt.subquery()
@@ -229,8 +229,8 @@ def get_index_attempts_for_connector(
        )
    )
    if only_current:
-       stmt = stmt.join(EmbeddingModel).where(
-           EmbeddingModel.status == IndexModelStatus.PRESENT
+       stmt = stmt.join(SearchSettings).where(
+           SearchSettings.status == IndexModelStatus.PRESENT
        )

    stmt = stmt.order_by(IndexAttempt.time_created.desc())
@@ -250,12 +250,12 @@ def get_latest_finished_index_attempt_for_cc_pair(
        ),
    )
    if secondary_index:
-       stmt = stmt.join(EmbeddingModel).where(
-           EmbeddingModel.status == IndexModelStatus.FUTURE
+       stmt = stmt.join(SearchSettings).where(
+           SearchSettings.status == IndexModelStatus.FUTURE
        )
    else:
-       stmt = stmt.join(EmbeddingModel).where(
-           EmbeddingModel.status == IndexModelStatus.PRESENT
+       stmt = stmt.join(SearchSettings).where(
+           SearchSettings.status == IndexModelStatus.PRESENT
        )
    stmt = stmt.order_by(desc(IndexAttempt.time_created))
    stmt = stmt.limit(1)
@@ -286,8 +286,8 @@ def get_index_attempts_for_cc_pair(
        )
    )
    if only_current:
-       stmt = stmt.join(EmbeddingModel).where(
-           EmbeddingModel.status == IndexModelStatus.PRESENT
+       stmt = stmt.join(SearchSettings).where(
+           SearchSettings.status == IndexModelStatus.PRESENT
        )

    stmt = stmt.order_by(IndexAttempt.time_created.desc())
@@ -309,19 +309,19 @@ def delete_index_attempts(


def expire_index_attempts(
-   embedding_model_id: int,
+   search_settings_id: int,
    db_session: Session,
) -> None:
    delete_query = (
        delete(IndexAttempt)
-       .where(IndexAttempt.embedding_model_id == embedding_model_id)
+       .where(IndexAttempt.search_settings_id == search_settings_id)
        .where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
    )
    db_session.execute(delete_query)

    update_query = (
        update(IndexAttempt)
-       .where(IndexAttempt.embedding_model_id == embedding_model_id)
+       .where(IndexAttempt.search_settings_id == search_settings_id)
        .where(IndexAttempt.status != IndexingStatus.SUCCESS)
        .values(
            status=IndexingStatus.FAILED,
@@ -345,10 +345,10 @@ def cancel_indexing_attempts_for_ccpair(
    )

    if not include_secondary_index:
-       subquery = select(EmbeddingModel.id).where(
-           EmbeddingModel.status != IndexModelStatus.FUTURE
+       subquery = select(SearchSettings.id).where(
+           SearchSettings.status != IndexModelStatus.FUTURE
        )
-       stmt = stmt.where(IndexAttempt.embedding_model_id.in_(subquery))
+       stmt = stmt.where(IndexAttempt.search_settings_id.in_(subquery))

    db_session.execute(stmt)

@@ -366,8 +366,8 @@ def cancel_indexing_attempts_past_model(
            IndexAttempt.status.in_(
                [IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
            ),
-           IndexAttempt.embedding_model_id == EmbeddingModel.id,
-           EmbeddingModel.status == IndexModelStatus.PAST,
+           IndexAttempt.search_settings_id == SearchSettings.id,
+           SearchSettings.status == IndexModelStatus.PAST,
        )
        .values(status=IndexingStatus.FAILED)
    )
@@ -376,7 +376,7 @@ def cancel_indexing_attempts_past_model(


def count_unique_cc_pairs_with_successful_index_attempts(
-   embedding_model_id: int | None,
+   search_settings_id: int | None,
    db_session: Session,
) -> int:
    """Collect all of the Index Attempts that are successful and for the specified embedding model
@@ -386,7 +386,7 @@ def count_unique_cc_pairs_with_successful_index_attempts(
        db_session.query(IndexAttempt.connector_credential_pair_id)
        .join(ConnectorCredentialPair)
        .filter(
-           IndexAttempt.embedding_model_id == embedding_model_id,
+           IndexAttempt.search_settings_id == search_settings_id,
            IndexAttempt.status == IndexingStatus.SUCCESS,
        )
        .distinct()
@@ -34,6 +34,7 @@ from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator

from danswer.auth.schemas import UserRole
+from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import FileOrigin
@@ -56,6 +57,7 @@ from danswer.search.enums import RecencyBiasSetting
from danswer.utils.encryption import decrypt_bytes_to_string
from danswer.utils.encryption import encrypt_string_to_bytes
from shared_configs.enums import EmbeddingProvider
+from shared_configs.enums import RerankerProvider


class Base(DeclarativeBase):
@@ -543,33 +545,47 @@ class Credential(Base):
    user: Mapped[User | None] = relationship("User", back_populates="credentials")


-class EmbeddingModel(Base):
-   __tablename__ = "embedding_model"
+class SearchSettings(Base):
+   __tablename__ = "search_settings"

    id: Mapped[int] = mapped_column(primary_key=True)
    model_name: Mapped[str] = mapped_column(String)
    model_dim: Mapped[int] = mapped_column(Integer)
    normalize: Mapped[bool] = mapped_column(Boolean)
-   query_prefix: Mapped[str] = mapped_column(String)
-   passage_prefix: Mapped[str] = mapped_column(String)
+   query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
+   passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
    status: Mapped[IndexModelStatus] = mapped_column(
        Enum(IndexModelStatus, native_enum=False)
    )
    index_name: Mapped[str] = mapped_column(String)

    # New field for cloud provider relationship
    provider_type: Mapped[EmbeddingProvider | None] = mapped_column(
        ForeignKey("embedding_provider.provider_type"), nullable=True
    )

+   # Mini and Large Chunks (large chunk also checks for model max context)
+   multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
+
+   multilingual_expansion: Mapped[list[str]] = mapped_column(
+       postgresql.ARRAY(String), default=[]
+   )
+
+   # Reranking settings
+   disable_rerank_for_streaming: Mapped[bool] = mapped_column(Boolean, default=False)
+   rerank_model_name: Mapped[str | None] = mapped_column(String, nullable=True)
+   rerank_provider_type: Mapped[RerankerProvider | None] = mapped_column(
+       Enum(RerankerProvider, native_enum=False), nullable=True
+   )
+   rerank_api_key: Mapped[str | None] = mapped_column(String, nullable=True)
+   num_rerank: Mapped[int] = mapped_column(Integer, default=NUM_POSTPROCESSED_RESULTS)
+
    cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
        "CloudEmbeddingProvider",
-       back_populates="embedding_models",
+       back_populates="search_settings",
        foreign_keys=[provider_type],
    )

    index_attempts: Mapped[list["IndexAttempt"]] = relationship(
-       "IndexAttempt", back_populates="embedding_model"
+       "IndexAttempt", back_populates="search_settings"
    )

    __table_args__ = (
@@ -628,8 +644,8 @@ class IndexAttempt(Base):
    # only filled if status = "failed" AND an unhandled exception caused the failure
    full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
    # Nullable because in the past, we didn't allow swapping out embedding models live
-   embedding_model_id: Mapped[int] = mapped_column(
-       ForeignKey("embedding_model.id"),
+   search_settings_id: Mapped[int] = mapped_column(
+       ForeignKey("search_settings.id"),
        nullable=False,
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
@@ -651,8 +667,8 @@ class IndexAttempt(Base):
        "ConnectorCredentialPair", back_populates="index_attempts"
    )

-   embedding_model: Mapped[EmbeddingModel] = relationship(
-       "EmbeddingModel", back_populates="index_attempts"
+   search_settings: Mapped[SearchSettings] = relationship(
+       "SearchSettings", back_populates="index_attempts"
    )

    error_rows = relationship("IndexAttemptError", back_populates="index_attempt")
@@ -1070,10 +1086,10 @@ class CloudEmbeddingProvider(Base):
        Enum(EmbeddingProvider), primary_key=True
    )
    api_key: Mapped[str | None] = mapped_column(EncryptedString())
-   embedding_models: Mapped[list["EmbeddingModel"]] = relationship(
-       "EmbeddingModel",
+   search_settings: Mapped[list["SearchSettings"]] = relationship(
+       "SearchSettings",
        back_populates="cloud_provider",
-       foreign_keys="EmbeddingModel.provider_type",
+       foreign_keys="SearchSettings.provider_type",
    )

    def __repr__(self) -> str:
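With the model renamed, related rows are navigated through the renamed relationships shown above. A minimal sketch of reading an index attempt's settings through the ORM (session wiring only, intended as an illustration):

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import IndexAttempt

with Session(get_sqlalchemy_engine()) as db_session:
    attempt = db_session.scalars(select(IndexAttempt).limit(1)).first()
    if attempt is not None:
        # The relationship formerly named "embedding_model" is now "search_settings".
        print(attempt.search_settings.index_name, attempt.search_settings.status)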
backend/danswer/db/search_settings.py (new file, 249 lines)
@@ -0,0 +1,249 @@
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
from danswer.configs.model_configs import ASYM_QUERY_PREFIX
from danswer.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import DOC_EMBEDDING_DIM
from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS
from danswer.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from danswer.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.llm import fetch_embedding_provider
from danswer.db.models import CloudEmbeddingProvider
from danswer.db.models import IndexModelStatus
from danswer.db.models import SearchSettings
from danswer.indexing.models import IndexingSetting
from danswer.natural_language_processing.search_nlp_models import clean_model_name
from danswer.search.models import SavedSearchSettings
from danswer.server.manage.embedding.models import (
    CloudEmbeddingProvider as ServerCloudEmbeddingProvider,
)
from danswer.utils.logger import setup_logger
from shared_configs.configs import PRESERVED_SEARCH_FIELDS
from shared_configs.enums import EmbeddingProvider

logger = setup_logger()


def create_search_settings(
    search_settings: SavedSearchSettings,
    db_session: Session,
    status: IndexModelStatus = IndexModelStatus.FUTURE,
) -> SearchSettings:
    embedding_model = SearchSettings(
        model_name=search_settings.model_name,
        model_dim=search_settings.model_dim,
        normalize=search_settings.normalize,
        query_prefix=search_settings.query_prefix,
        passage_prefix=search_settings.passage_prefix,
        status=status,
        index_name=search_settings.index_name,
        provider_type=search_settings.provider_type,
        multipass_indexing=search_settings.multipass_indexing,
        multilingual_expansion=search_settings.multilingual_expansion,
        disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
        rerank_model_name=search_settings.rerank_model_name,
        rerank_provider_type=search_settings.rerank_provider_type,
        rerank_api_key=search_settings.rerank_api_key,
        num_rerank=search_settings.num_rerank,
    )

    db_session.add(embedding_model)
    db_session.commit()

    return embedding_model


def get_embedding_provider_from_provider_type(
    db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProvider | None:
    query = select(CloudEmbeddingProvider).where(
        CloudEmbeddingProvider.provider_type == provider_type
    )
    provider = db_session.execute(query).scalars().first()
    return provider if provider else None


def get_current_db_embedding_provider(
    db_session: Session,
) -> ServerCloudEmbeddingProvider | None:
    search_settings = get_current_search_settings(db_session=db_session)

    if search_settings.provider_type is None:
        return None

    embedding_provider = fetch_embedding_provider(
        db_session=db_session,
        provider_type=search_settings.provider_type,
    )
    if embedding_provider is None:
        raise RuntimeError("No embedding provider exists for this model.")

    current_embedding_provider = ServerCloudEmbeddingProvider.from_request(
        cloud_provider_model=embedding_provider
    )

    return current_embedding_provider


def get_current_search_settings(db_session: Session) -> SearchSettings:
    query = (
        select(SearchSettings)
        .where(SearchSettings.status == IndexModelStatus.PRESENT)
        .order_by(SearchSettings.id.desc())
    )
    result = db_session.execute(query)
    latest_settings = result.scalars().first()

    if not latest_settings:
        raise RuntimeError("No search settings specified, DB is not in a valid state")
    return latest_settings


def get_secondary_search_settings(db_session: Session) -> SearchSettings | None:
    query = (
        select(SearchSettings)
        .where(SearchSettings.status == IndexModelStatus.FUTURE)
        .order_by(SearchSettings.id.desc())
    )
    result = db_session.execute(query)
    latest_settings = result.scalars().first()

    return latest_settings


def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
    if db_session is None:
        with Session(get_sqlalchemy_engine()) as db_session:
            search_settings = get_current_search_settings(db_session)
    else:
        search_settings = get_current_search_settings(db_session)
    if not search_settings:
        return []
    return search_settings.multilingual_expansion


def update_search_settings(
    current_settings: SearchSettings,
    updated_settings: SavedSearchSettings,
    preserved_fields: list[str],
) -> None:
    for field, value in updated_settings.dict().items():
        if field not in preserved_fields:
            setattr(current_settings, field, value)


def update_current_search_settings(
    db_session: Session,
    search_settings: SavedSearchSettings,
    preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS,
) -> None:
    current_settings = get_current_search_settings(db_session)
    if not current_settings:
        logger.warning("No current search settings found to update")
        return

    update_search_settings(current_settings, search_settings, preserved_fields)
    db_session.commit()
    logger.info("Current search settings updated successfully")


def update_secondary_search_settings(
    db_session: Session,
    search_settings: SavedSearchSettings,
    preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS,
) -> None:
    secondary_settings = get_secondary_search_settings(db_session)
    if not secondary_settings:
        logger.warning("No secondary search settings found to update")
        return

    preserved_fields = PRESERVED_SEARCH_FIELDS
    update_search_settings(secondary_settings, search_settings, preserved_fields)

    db_session.commit()
    logger.info("Secondary search settings updated successfully")


def update_search_settings_status(
    search_settings: SearchSettings, new_status: IndexModelStatus, db_session: Session
) -> None:
    search_settings.status = new_status
    db_session.commit()


def user_has_overridden_embedding_model() -> bool:
    return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL


def get_old_default_search_settings() -> SearchSettings:
    is_overridden = user_has_overridden_embedding_model()
    return SearchSettings(
        model_name=(
            DOCUMENT_ENCODER_MODEL
            if is_overridden
            else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
        ),
        model_dim=(
            DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
        ),
        normalize=(
            NORMALIZE_EMBEDDINGS
            if is_overridden
            else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
        ),
        query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
        passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
        status=IndexModelStatus.PRESENT,
        index_name="danswer_chunk",
    )


def get_new_default_search_settings(is_present: bool) -> SearchSettings:
    return SearchSettings(
        model_name=DOCUMENT_ENCODER_MODEL,
        model_dim=DOC_EMBEDDING_DIM,
        normalize=NORMALIZE_EMBEDDINGS,
        query_prefix=ASYM_QUERY_PREFIX,
        passage_prefix=ASYM_PASSAGE_PREFIX,
        status=IndexModelStatus.PRESENT if is_present else IndexModelStatus.FUTURE,
        index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
    )


def get_old_default_embedding_model() -> IndexingSetting:
    is_overridden = user_has_overridden_embedding_model()
    return IndexingSetting(
        model_name=(
            DOCUMENT_ENCODER_MODEL
            if is_overridden
            else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
        ),
        model_dim=(
            DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
        ),
        normalize=(
            NORMALIZE_EMBEDDINGS
            if is_overridden
            else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
        ),
        query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""),
        passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
        index_name="danswer_chunk",
        multipass_indexing=False,
    )


def get_new_default_embedding_model() -> IndexingSetting:
    return IndexingSetting(
        model_name=DOCUMENT_ENCODER_MODEL,
        model_dim=DOC_EMBEDDING_DIM,
        normalize=NORMALIZE_EMBEDDINGS,
        query_prefix=ASYM_QUERY_PREFIX,
        passage_prefix=ASYM_PASSAGE_PREFIX,
        index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
        multipass_indexing=False,
    )
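The new module is the single entry point for reading and mutating search settings. A short usage sketch of the accessors defined above, run inside an application context where the engine is configured:

from sqlalchemy.orm import Session

from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import IndexModelStatus
from danswer.db.search_settings import (
    get_current_search_settings,
    get_secondary_search_settings,
    update_search_settings_status,
)

with Session(get_sqlalchemy_engine()) as db_session:
    current = get_current_search_settings(db_session)
    print(current.index_name, current.num_rerank)

    secondary = get_secondary_search_settings(db_session)
    if secondary is not None:
        # Promote finished secondary settings; check_index_swap below also
        # demotes the current ones to PAST as part of the real swap.
        update_search_settings_status(secondary, IndexModelStatus.PRESENT, db_session)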
@@ -3,14 +3,14 @@ from sqlalchemy.orm import Session
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
-from danswer.db.embedding_model import get_current_db_embedding_model
-from danswer.db.embedding_model import get_secondary_db_embedding_model
-from danswer.db.embedding_model import update_embedding_model_status
from danswer.db.enums import IndexModelStatus
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import (
    count_unique_cc_pairs_with_successful_index_attempts,
)
+from danswer.db.search_settings import get_current_search_settings
+from danswer.db.search_settings import get_secondary_search_settings
+from danswer.db.search_settings import update_search_settings_status
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.utils.logger import setup_logger

@@ -24,13 +24,13 @@ def check_index_swap(db_session: Session) -> None:
    # Default CC-pair created for Ingestion API unused here
    all_cc_pairs = get_connector_credential_pairs(db_session)
    cc_pair_count = max(len(all_cc_pairs) - 1, 0)
-   embedding_model = get_secondary_db_embedding_model(db_session)
+   search_settings = get_secondary_search_settings(db_session)

-   if not embedding_model:
+   if not search_settings:
        return

    unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
-       embedding_model_id=embedding_model.id, db_session=db_session
+       search_settings_id=search_settings.id, db_session=db_session
    )

    # Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
@@ -40,15 +40,15 @@ def check_index_swap(db_session: Session) -> None:

    if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
        # Swap indices
-       now_old_embedding_model = get_current_db_embedding_model(db_session)
-       update_embedding_model_status(
-           embedding_model=now_old_embedding_model,
+       now_old_search_settings = get_current_search_settings(db_session)
+       update_search_settings_status(
+           search_settings=now_old_search_settings,
            new_status=IndexModelStatus.PAST,
            db_session=db_session,
        )

-       update_embedding_model_status(
-           embedding_model=embedding_model,
+       update_search_settings_status(
+           search_settings=search_settings,
            new_status=IndexModelStatus.PRESENT,
            db_session=db_session,
        )
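check_index_swap only flips the generations once every connector-credential pair has a successful attempt against the secondary settings. A compact restatement of that trigger condition, offered as an illustration rather than the exact production code:

def ready_to_swap(cc_pair_count: int, successful_cc_indexings: int) -> bool:
    # Either there is nothing to index yet, or every cc-pair has finished
    # indexing into the secondary (FUTURE) search settings.
    return cc_pair_count == 0 or cc_pair_count == successful_cc_indexings


assert ready_to_swap(0, 0) and ready_to_swap(3, 3) and not ready_to_swap(3, 2)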
@@ -3,8 +3,8 @@ import uuid

from sqlalchemy.orm import Session

-from danswer.db.embedding_model import get_current_db_embedding_model
-from danswer.db.embedding_model import get_secondary_db_embedding_model
+from danswer.db.search_settings import get_current_search_settings
+from danswer.db.search_settings import get_secondary_search_settings
from danswer.indexing.models import IndexChunk
from danswer.search.models import InferenceChunk

@ -14,13 +14,13 @@ DEFAULT_INDEX_NAME = "danswer_chunk"
|
||||
|
||||
|
||||
def get_both_index_names(db_session: Session) -> tuple[str, str | None]:
|
||||
model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
model_new = get_secondary_db_embedding_model(db_session)
|
||||
if not model_new:
|
||||
return model.index_name, None
|
||||
search_settings_new = get_secondary_search_settings(db_session)
|
||||
if not search_settings_new:
|
||||
return search_settings.index_name, None
|
||||
|
||||
return model.index_name, model_new.index_name
|
||||
return search_settings.index_name, search_settings_new.index_name
|
||||
|
||||
|
||||
def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
|
@@ -3,10 +3,10 @@ from abc import abstractmethod

 from sqlalchemy.orm import Session

-from danswer.db.embedding_model import get_current_db_embedding_model
-from danswer.db.embedding_model import get_secondary_db_embedding_model
-from danswer.db.models import EmbeddingModel as DbEmbeddingModel
 from danswer.db.models import IndexModelStatus
+from danswer.db.models import SearchSettings
+from danswer.db.search_settings import get_current_search_settings
+from danswer.db.search_settings import get_secondary_search_settings
 from danswer.indexing.models import ChunkEmbedding
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import IndexChunk
@@ -169,37 +169,37 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
         return embedded_chunks

     @classmethod
-    def from_db_embedding_model(
-        cls, embedding_model: DbEmbeddingModel
+    def from_db_search_settings(
+        cls, search_settings: SearchSettings
     ) -> "DefaultIndexingEmbedder":
         return cls(
-            model_name=embedding_model.model_name,
-            normalize=embedding_model.normalize,
-            query_prefix=embedding_model.query_prefix,
-            passage_prefix=embedding_model.passage_prefix,
-            provider_type=embedding_model.provider_type,
-            api_key=embedding_model.api_key,
+            model_name=search_settings.model_name,
+            normalize=search_settings.normalize,
+            query_prefix=search_settings.query_prefix,
+            passage_prefix=search_settings.passage_prefix,
+            provider_type=search_settings.provider_type,
+            api_key=search_settings.api_key,
         )


-def get_embedding_model_from_db_embedding_model(
+def get_embedding_model_from_search_settings(
     db_session: Session, index_model_status: IndexModelStatus = IndexModelStatus.PRESENT
 ) -> IndexingEmbedder:
-    db_embedding_model: DbEmbeddingModel | None
+    search_settings: SearchSettings | None
     if index_model_status == IndexModelStatus.PRESENT:
-        db_embedding_model = get_current_db_embedding_model(db_session)
+        search_settings = get_current_search_settings(db_session)
     elif index_model_status == IndexModelStatus.FUTURE:
-        db_embedding_model = get_secondary_db_embedding_model(db_session)
-        if not db_embedding_model:
+        search_settings = get_secondary_search_settings(db_session)
+        if not search_settings:
             raise RuntimeError("No secondary index configured")
     else:
         raise RuntimeError("Not supporting embedding model rollbacks")

     return DefaultIndexingEmbedder(
-        model_name=db_embedding_model.model_name,
-        normalize=db_embedding_model.normalize,
-        query_prefix=db_embedding_model.query_prefix,
-        passage_prefix=db_embedding_model.passage_prefix,
-        provider_type=db_embedding_model.provider_type,
-        api_key=db_embedding_model.api_key,
+        model_name=search_settings.model_name,
+        normalize=search_settings.normalize,
+        query_prefix=search_settings.query_prefix,
+        passage_prefix=search_settings.passage_prefix,
+        provider_type=search_settings.provider_type,
+        api_key=search_settings.api_key,
     )
@@ -21,6 +21,7 @@ from danswer.db.document import upsert_documents_complete
 from danswer.db.document_set import fetch_document_sets_for_documents
 from danswer.db.index_attempt import create_index_attempt_error
 from danswer.db.models import Document as DBDocument
+from danswer.db.search_settings import get_current_search_settings
 from danswer.db.tag import create_or_add_document_tag
 from danswer.db.tag import create_or_add_document_tag_list
 from danswer.document_index.interfaces import DocumentIndex
@@ -29,7 +30,6 @@ from danswer.indexing.chunker import Chunker
 from danswer.indexing.embedder import IndexingEmbedder
 from danswer.indexing.models import DocAwareChunk
 from danswer.indexing.models import DocMetadataAwareIndexChunk
-from danswer.search.search_settings import get_search_settings
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_function_time
 from shared_configs.enums import EmbeddingProvider
@@ -360,7 +360,7 @@ def build_indexing_pipeline(
     attempt_id: int | None = None,
 ) -> IndexingPipelineProtocol:
     """Builds a pipeline which takes in a list (batch) of docs and indexes them."""
-    search_settings = get_search_settings()
+    search_settings = get_current_search_settings(db_session)
     multipass = (
         search_settings.multipass_indexing
         if search_settings
@@ -9,7 +9,7 @@ from shared_configs.enums import EmbeddingProvider
 from shared_configs.model_server_models import Embedding

 if TYPE_CHECKING:
-    from danswer.db.models import EmbeddingModel
+    from danswer.db.models import SearchSettings


 logger = setup_logger()
@@ -96,26 +96,42 @@ class DocMetadataAwareIndexChunk(IndexChunk):

 class EmbeddingModelDetail(BaseModel):
     model_name: str
-    model_dim: int
     normalize: bool
     query_prefix: str | None
     passage_prefix: str | None
     provider_type: EmbeddingProvider | None = None
+    api_key: str | None = None

     @classmethod
-    def from_model(
+    def from_db_model(
         cls,
-        embedding_model: "EmbeddingModel",
+        search_settings: "SearchSettings",
     ) -> "EmbeddingModelDetail":
         return cls(
-            model_name=embedding_model.model_name,
-            model_dim=embedding_model.model_dim,
-            normalize=embedding_model.normalize,
-            query_prefix=embedding_model.query_prefix,
-            passage_prefix=embedding_model.passage_prefix,
-            provider_type=embedding_model.provider_type,
+            model_name=search_settings.model_name,
+            normalize=search_settings.normalize,
+            query_prefix=search_settings.query_prefix,
+            passage_prefix=search_settings.passage_prefix,
+            provider_type=search_settings.provider_type,
+            api_key=search_settings.api_key,
         )


-class EmbeddingModelCreateRequest(EmbeddingModelDetail):
-    index_name: str
+# Additional info needed for indexing time
+class IndexingSetting(EmbeddingModelDetail):
+    model_dim: int
+    index_name: str | None
+    multipass_indexing: bool
+
+    @classmethod
+    def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting":
+        return cls(
+            model_name=search_settings.model_name,
+            model_dim=search_settings.model_dim,
+            normalize=search_settings.normalize,
+            query_prefix=search_settings.query_prefix,
+            passage_prefix=search_settings.passage_prefix,
+            provider_type=search_settings.provider_type,
+            index_name=search_settings.index_name,
+            multipass_indexing=search_settings.multipass_indexing,
+        )
@ -5,6 +5,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.persona import get_default_prompt__read_only
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
@ -28,19 +29,17 @@ from danswer.prompts.token_counts import CITATION_REMINDER_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import CITATION_STATEMENT_TOKEN_CNT
|
||||
from danswer.prompts.token_counts import LANGUAGE_HINT_TOKEN_CNT
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.search_settings import get_multilingual_expansion
|
||||
|
||||
|
||||
def get_prompt_tokens(prompt_config: PromptConfig) -> int:
|
||||
# Note: currently custom prompts do not allow datetime aware, only default prompts
|
||||
multilingual_expansion = get_multilingual_expansion()
|
||||
return (
|
||||
check_number_of_tokens(prompt_config.system_prompt)
|
||||
+ check_number_of_tokens(prompt_config.task_prompt)
|
||||
+ CHAT_USER_PROMPT_WITH_CONTEXT_OVERHEAD_TOKEN_CNT
|
||||
+ CITATION_STATEMENT_TOKEN_CNT
|
||||
+ CITATION_REMINDER_TOKEN_CNT
|
||||
+ (LANGUAGE_HINT_TOKEN_CNT if multilingual_expansion else 0)
|
||||
+ (LANGUAGE_HINT_TOKEN_CNT if get_multilingual_expansion() else 0)
|
||||
+ (ADDITIONAL_INFO_TOKEN_CNT if prompt_config.datetime_aware else 0)
|
||||
)
|
||||
|
||||
|
@ -3,6 +3,7 @@ from langchain.schema.messages import HumanMessage
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
|
||||
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
@ -11,7 +12,6 @@ from danswer.prompts.direct_qa_prompts import WEAK_LLM_PROMPT
|
||||
from danswer.prompts.prompt_utils import add_date_time_to_prompt
|
||||
from danswer.prompts.prompt_utils import build_complete_context_str
|
||||
from danswer.search.models import InferenceChunk
|
||||
from danswer.search.search_settings import get_search_settings
|
||||
|
||||
|
||||
def _build_weak_llm_quotes_prompt(
|
||||
@ -48,10 +48,7 @@ def _build_strong_llm_quotes_prompt(
|
||||
history_str: str,
|
||||
prompt: PromptConfig,
|
||||
) -> HumanMessage:
|
||||
search_settings = get_search_settings()
|
||||
use_language_hint = (
|
||||
bool(search_settings.multilingual_expansion) if search_settings else False
|
||||
)
|
||||
use_language_hint = bool(get_multilingual_expansion())
|
||||
|
||||
context_block = ""
|
||||
if context_docs:
|
||||
|
@ -27,16 +27,14 @@ from danswer.configs.app_configs import APP_PORT
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
|
||||
from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.chat_configs import MULTILINGUAL_QUERY_EXPANSION
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from danswer.configs.constants import AuthType
|
||||
from danswer.configs.constants import KV_REINDEX_KEY
|
||||
from danswer.configs.constants import KV_SEARCH_SETTINGS
|
||||
from danswer.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
from danswer.db.connector import check_connectors_exist
|
||||
from danswer.db.connector import create_initial_default_connector
|
||||
@ -45,28 +43,29 @@ from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.db.document import check_docs_exist
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import init_sqlalchemy_engine
|
||||
from danswer.db.engine import warm_up_connections
|
||||
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.models import EmbeddingModel
|
||||
from danswer.db.persona import delete_old_default_personas
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_secondary_search_settings
|
||||
from danswer.db.standard_answer import create_initial_default_standard_answer_category
|
||||
from danswer.db.swap_index import check_index_swap
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.dynamic_configs.factory import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.llm.llm_initialization import load_llm_providers
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.search.retrieval.search_runner import download_nltk_data
|
||||
from danswer.search.search_settings import get_search_settings
|
||||
from danswer.search.search_settings import update_search_settings
|
||||
from danswer.server.auth_check import check_router_auth
|
||||
from danswer.server.danswer_api.ingestion import router as danswer_api_router
|
||||
from danswer.server.documents.cc_pair import router as cc_pair_router
|
||||
@ -116,13 +115,8 @@ from danswer.utils.telemetry import RecordType
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_API_KEY
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_MODEL_NAME
|
||||
from shared_configs.configs import DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
|
||||
from shared_configs.configs import DISABLE_RERANK_FOR_STREAMING
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@ -198,6 +192,49 @@ def setup_postgres(db_session: Session) -> None:
|
||||
auto_add_search_tool_to_personas(db_session)
|
||||
|
||||
|
||||
def translate_saved_search_settings(db_session: Session) -> None:
|
||||
kv_store = get_dynamic_config_store()
|
||||
|
||||
try:
|
||||
search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS)
|
||||
if isinstance(search_settings_dict, dict):
|
||||
# Update current search settings
|
||||
current_settings = get_current_search_settings(db_session)
|
||||
|
||||
# Update non-preserved fields
|
||||
if current_settings:
|
||||
current_settings_dict = SavedSearchSettings.from_db_model(
|
||||
current_settings
|
||||
).dict()
|
||||
|
||||
new_current_settings = SavedSearchSettings(
|
||||
**{**current_settings_dict, **search_settings_dict}
|
||||
)
|
||||
update_current_search_settings(db_session, new_current_settings)
|
||||
|
||||
# Update secondary search settings
|
||||
secondary_settings = get_secondary_search_settings(db_session)
|
||||
if secondary_settings:
|
||||
secondary_settings_dict = SavedSearchSettings.from_db_model(
|
||||
secondary_settings
|
||||
).dict()
|
||||
|
||||
new_secondary_settings = SavedSearchSettings(
|
||||
**{**secondary_settings_dict, **search_settings_dict}
|
||||
)
|
||||
update_secondary_search_settings(
|
||||
db_session,
|
||||
new_secondary_settings,
|
||||
)
|
||||
# Delete the KV store entry after successful update
|
||||
kv_store.delete(KV_SEARCH_SETTINGS)
|
||||
logger.notice("Search settings updated and KV store entry deleted.")
|
||||
else:
|
||||
logger.notice("KV store search settings is empty.")
|
||||
except ConfigNotFoundError:
|
||||
logger.notice("No search config found in KV store.")
|
||||
|
||||
|
||||
def mark_reindex_flag(db_session: Session) -> None:
|
||||
kv_store = get_dynamic_config_store()
|
||||
try:
|
||||
@ -221,17 +258,17 @@ def mark_reindex_flag(db_session: Session) -> None:
|
||||
|
||||
def setup_vespa(
|
||||
document_index: DocumentIndex,
|
||||
db_embedding_model: EmbeddingModel,
|
||||
secondary_db_embedding_model: EmbeddingModel | None,
|
||||
index_setting: IndexingSetting,
|
||||
secondary_index_setting: IndexingSetting | None,
|
||||
) -> None:
|
||||
# Vespa startup is a bit slow, so give it a few seconds
|
||||
wait_time = 5
|
||||
for _ in range(5):
|
||||
try:
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=db_embedding_model.model_dim,
|
||||
secondary_index_embedding_dim=secondary_db_embedding_model.model_dim
|
||||
if secondary_db_embedding_model
|
||||
index_embedding_dim=index_setting.model_dim,
|
||||
secondary_index_embedding_dim=secondary_index_setting.model_dim
|
||||
if secondary_index_setting
|
||||
else None,
|
||||
)
|
||||
break
|
||||
@ -262,13 +299,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
|
||||
with Session(engine) as db_session:
|
||||
check_index_swap(db_session=db_session)
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
# Break bad state for thrashing indexes
|
||||
if secondary_db_embedding_model and DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
embedding_model_id=db_embedding_model.id, db_session=db_session
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
@ -277,16 +314,13 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
# Expire all old embedding models indexing attempts, technically redundant
|
||||
cancel_indexing_attempts_past_model(db_session)
|
||||
|
||||
logger.notice(f'Using Embedding model: "{db_embedding_model.model_name}"')
|
||||
if db_embedding_model.query_prefix or db_embedding_model.passage_prefix:
|
||||
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
|
||||
if search_settings.query_prefix or search_settings.passage_prefix:
|
||||
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
|
||||
logger.notice(
|
||||
f'Query embedding prefix: "{db_embedding_model.query_prefix}"'
|
||||
)
|
||||
logger.notice(
|
||||
f'Passage embedding prefix: "{db_embedding_model.passage_prefix}"'
|
||||
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
|
||||
)
|
||||
|
||||
search_settings = get_search_settings()
|
||||
if search_settings:
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
logger.notice("Reranking is enabled.")
|
||||
@ -295,29 +329,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
logger.notice(
|
||||
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
|
||||
)
|
||||
else:
|
||||
if DEFAULT_CROSS_ENCODER_MODEL_NAME:
|
||||
logger.notice("Reranking is enabled.")
|
||||
if not DEFAULT_CROSS_ENCODER_MODEL_NAME:
|
||||
raise ValueError("No reranking model specified.")
|
||||
search_settings = SavedSearchSettings(
|
||||
rerank_model_name=DEFAULT_CROSS_ENCODER_MODEL_NAME,
|
||||
provider_type=RerankerProvider(DEFAULT_CROSS_ENCODER_PROVIDER_TYPE)
|
||||
if DEFAULT_CROSS_ENCODER_PROVIDER_TYPE
|
||||
else None,
|
||||
api_key=DEFAULT_CROSS_ENCODER_API_KEY,
|
||||
disable_rerank_for_streaming=DISABLE_RERANK_FOR_STREAMING,
|
||||
num_rerank=NUM_POSTPROCESSED_RESULTS,
|
||||
multilingual_expansion=[
|
||||
s.strip()
|
||||
for s in MULTILINGUAL_QUERY_EXPANSION.split(",")
|
||||
if s.strip()
|
||||
]
|
||||
if MULTILINGUAL_QUERY_EXPANSION
|
||||
else [],
|
||||
multipass_indexing=ENABLE_MULTIPASS_INDEXING,
|
||||
)
|
||||
update_search_settings(search_settings)
|
||||
|
||||
if search_settings.rerank_model_name and not search_settings.provider_type:
|
||||
warm_up_cross_encoder(search_settings.rerank_model_name)
|
||||
@ -328,6 +339,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
# setup Postgres with default credential, llm providers, etc.
|
||||
setup_postgres(db_session)
|
||||
|
||||
translate_saved_search_settings(db_session)
|
||||
|
||||
# Does the user need to trigger a reindexing to bring the document index
|
||||
# into a good state, marked in the kv store
|
||||
mark_reindex_flag(db_session)
|
||||
@ -335,19 +348,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
# ensure Vespa is setup correctly
|
||||
logger.notice("Verifying Document Index(s) is/are available.")
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=db_embedding_model.index_name,
|
||||
secondary_index_name=secondary_db_embedding_model.index_name
|
||||
if secondary_db_embedding_model
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=secondary_search_settings.index_name
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
setup_vespa(
|
||||
document_index,
|
||||
IndexingSetting.from_db_model(search_settings),
|
||||
IndexingSetting.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
setup_vespa(document_index, db_embedding_model, secondary_db_embedding_model)
|
||||
|
||||
logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}")
|
||||
if db_embedding_model.provider_type is None:
|
||||
if search_settings.provider_type is None:
|
||||
warm_up_bi_encoder(
|
||||
embedding_model=db_embedding_model,
|
||||
model_server_host=MODEL_SERVER_HOST,
|
||||
model_server_port=MODEL_SERVER_PORT,
|
||||
embedding_model=EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
),
|
||||
)
|
||||
|
||||
optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__})
|
||||
|
@ -15,7 +15,7 @@ from danswer.configs.model_configs import (
|
||||
BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
|
||||
)
|
||||
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from danswer.db.models import EmbeddingModel as DBEmbeddingModel
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.natural_language_processing.utils import tokenizer_trim_content
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -209,6 +209,26 @@ class EmbeddingModel:
|
||||
max_seq_length=max_seq_length,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_db_model(
|
||||
cls,
|
||||
search_settings: SearchSettings,
|
||||
server_host: str, # Changes depending on indexing or inference
|
||||
server_port: int,
|
||||
retrim_content: bool = False,
|
||||
) -> "EmbeddingModel":
|
||||
return cls(
|
||||
server_host=server_host,
|
||||
server_port=server_port,
|
||||
model_name=search_settings.model_name,
|
||||
normalize=search_settings.normalize,
|
||||
query_prefix=search_settings.query_prefix,
|
||||
passage_prefix=search_settings.passage_prefix,
|
||||
api_key=search_settings.api_key,
|
||||
provider_type=search_settings.provider_type,
|
||||
retrim_content=retrim_content,
|
||||
)
|
||||
|
||||
|
||||
class RerankingModel:
|
||||
def __init__(
|
||||
@ -302,47 +322,35 @@ def warm_up_retry(
|
||||
|
||||
|
||||
def warm_up_bi_encoder(
|
||||
embedding_model: DBEmbeddingModel,
|
||||
model_server_host: str = MODEL_SERVER_HOST,
|
||||
model_server_port: int = MODEL_SERVER_PORT,
|
||||
embedding_model: EmbeddingModel,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
model_name = embedding_model.model_name
|
||||
normalize = embedding_model.normalize
|
||||
provider_type = embedding_model.provider_type
|
||||
warm_up_str = " ".join(WARM_UP_STRINGS)
|
||||
|
||||
logger.debug(f"Warming up encoder model: {model_name}")
|
||||
get_tokenizer(model_name=model_name, provider_type=provider_type).encode(
|
||||
warm_up_str
|
||||
)
|
||||
|
||||
embed_model = EmbeddingModel(
|
||||
model_name=model_name,
|
||||
normalize=normalize,
|
||||
provider_type=provider_type,
|
||||
# Not a big deal if prefix is incorrect
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
server_host=model_server_host,
|
||||
server_port=model_server_port,
|
||||
api_key=None,
|
||||
)
|
||||
logger.debug(f"Warming up encoder model: {embedding_model.model_name}")
|
||||
get_tokenizer(
|
||||
model_name=embedding_model.model_name,
|
||||
provider_type=embedding_model.provider_type,
|
||||
).encode(warm_up_str)
|
||||
|
||||
def _warm_up() -> None:
|
||||
try:
|
||||
embed_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||
logger.debug(f"Warm-up complete for encoder model: {model_name}")
|
||||
embedding_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||
logger.debug(
|
||||
f"Warm-up complete for encoder model: {embedding_model.model_name}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Warm-up request failed for encoder model {model_name}: {e}"
|
||||
f"Warm-up request failed for encoder model {embedding_model.model_name}: {e}"
|
||||
)
|
||||
|
||||
if non_blocking:
|
||||
threading.Thread(target=_warm_up, daemon=True).start()
|
||||
logger.debug(f"Started non-blocking warm-up for encoder model: {model_name}")
|
||||
logger.debug(
|
||||
f"Started non-blocking warm-up for encoder model: {embedding_model.model_name}"
|
||||
)
|
||||
else:
|
||||
retry_encode = warm_up_retry(embed_model.encode)
|
||||
retry_encode = warm_up_retry(embedding_model.encode)
|
||||
retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
|
||||
|
||||
|
||||
|
@ -7,7 +7,9 @@ from pydantic import validator
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SearchSettings
|
||||
from danswer.indexing.models import BaseChunk
|
||||
from danswer.indexing.models import IndexingSetting
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import OptionalSearchSetting
|
||||
from danswer.search.enums import SearchType
|
||||
@ -22,28 +24,61 @@ MAX_METRICS_CONTENT = (
|
||||
class RerankingDetails(BaseModel):
|
||||
# If model is None (or num_rerank is 0), then reranking is turned off
|
||||
rerank_model_name: str | None
|
||||
provider_type: RerankerProvider | None
|
||||
api_key: str | None
|
||||
rerank_provider_type: RerankerProvider | None
|
||||
rerank_api_key: str | None
|
||||
|
||||
num_rerank: int
|
||||
|
||||
|
||||
class SavedSearchSettings(RerankingDetails):
|
||||
# Empty for no additional expansion
|
||||
multilingual_expansion: list[str]
|
||||
# Encompasses both mini and large chunks
|
||||
multipass_indexing: bool
|
||||
|
||||
# For faster flows where the results should start immediately
|
||||
# this more time intensive step can be skipped
|
||||
disable_rerank_for_streaming: bool
|
||||
disable_rerank_for_streaming: bool = False
|
||||
|
||||
def to_reranking_detail(self) -> RerankingDetails:
|
||||
return RerankingDetails(
|
||||
rerank_model_name=self.rerank_model_name,
|
||||
provider_type=self.provider_type,
|
||||
api_key=self.api_key,
|
||||
num_rerank=self.num_rerank,
|
||||
@classmethod
|
||||
def from_db_model(cls, search_settings: SearchSettings) -> "RerankingDetails":
|
||||
return cls(
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
rerank_api_key=search_settings.rerank_api_key,
|
||||
num_rerank=search_settings.num_rerank,
|
||||
)
|
||||
|
||||
|
||||
class InferenceSettings(RerankingDetails):
|
||||
# Empty for no additional expansion
|
||||
multilingual_expansion: list[str]
|
||||
|
||||
|
||||
class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
|
||||
@classmethod
|
||||
def from_db_model(
|
||||
cls, search_settings: SearchSettings
|
||||
) -> "SearchSettingsCreationRequest":
|
||||
inference_settings = InferenceSettings.from_db_model(search_settings)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
return cls(**inference_settings.dict(), **indexing_setting.dict())
|
||||
|
||||
|
||||
class SavedSearchSettings(InferenceSettings, IndexingSetting):
|
||||
@classmethod
|
||||
def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings":
|
||||
return cls(
|
||||
# Indexing Setting
|
||||
model_name=search_settings.model_name,
|
||||
model_dim=search_settings.model_dim,
|
||||
normalize=search_settings.normalize,
|
||||
query_prefix=search_settings.query_prefix,
|
||||
passage_prefix=search_settings.passage_prefix,
|
||||
provider_type=search_settings.provider_type,
|
||||
index_name=search_settings.index_name,
|
||||
multipass_indexing=search_settings.multipass_indexing,
|
||||
# Reranking Details
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
rerank_api_key=search_settings.rerank_api_key,
|
||||
num_rerank=search_settings.num_rerank,
|
||||
# Multilingual Expansion
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
)
|
||||
|
||||
|
||||
|
@ -7,8 +7,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import SectionRelevancePiece
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaChunkRequest
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
@ -65,9 +65,9 @@ class SearchPipeline:
|
||||
self.retrieval_metrics_callback = retrieval_metrics_callback
|
||||
self.rerank_metrics_callback = rerank_metrics_callback
|
||||
|
||||
self.embedding_model = get_current_db_embedding_model(db_session)
|
||||
self.search_settings = get_current_search_settings(db_session)
|
||||
self.document_index = get_default_document_index(
|
||||
primary_index_name=self.embedding_model.index_name,
|
||||
primary_index_name=self.search_settings.index_name,
|
||||
secondary_index_name=None,
|
||||
)
|
||||
self.prompt_config: PromptConfig | None = prompt_config
|
||||
|
@ -98,8 +98,8 @@ def semantic_reranking(
|
||||
|
||||
cross_encoder = RerankingModel(
|
||||
model_name=rerank_settings.rerank_model_name,
|
||||
provider_type=rerank_settings.provider_type,
|
||||
api_key=rerank_settings.api_key,
|
||||
provider_type=rerank_settings.rerank_provider_type,
|
||||
api_key=rerank_settings.rerank_api_key,
|
||||
)
|
||||
|
||||
passages = [
|
||||
|
@ -10,18 +10,19 @@ from danswer.configs.chat_configs import HYBRID_ALPHA_KEYWORD
|
||||
from danswer.configs.chat_configs import NUM_POSTPROCESSED_RESULTS
|
||||
from danswer.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.natural_language_processing.search_nlp_models import QueryAnalysisModel
|
||||
from danswer.search.enums import LLMEvaluationType
|
||||
from danswer.search.enums import RecencyBiasSetting
|
||||
from danswer.search.enums import SearchType
|
||||
from danswer.search.models import BaseFilters
|
||||
from danswer.search.models import IndexFilters
|
||||
from danswer.search.models import RerankingDetails
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.models import SearchRequest
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.search.preprocessing.access_filters import build_access_filters_for_user
|
||||
from danswer.search.retrieval.search_runner import remove_stop_words_and_punctuation
|
||||
from danswer.search.search_settings import get_search_settings
|
||||
from danswer.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from danswer.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -179,12 +180,11 @@ def retrieval_preprocessing(
|
||||
rerank_settings = search_request.rerank_settings
|
||||
# If not explicitly specified by the query, use the current settings
|
||||
if rerank_settings is None:
|
||||
saved_search_settings = get_search_settings()
|
||||
if not saved_search_settings:
|
||||
rerank_settings = None
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# For non-streaming flows, the rerank settings are applied at the search_request level
|
||||
elif not saved_search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = saved_search_settings.to_reranking_detail()
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Decays at 1 / (1 + (multiplier * num years))
|
||||
if persona and persona.recency_bias == RecencyBiasSetting.NO_DECAY:
|
||||
|
@ -7,7 +7,8 @@ from nltk.stem import WordNetLemmatizer # type:ignore
|
||||
from nltk.tokenize import word_tokenize # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.document_index.interfaces import DocumentIndex
|
||||
from danswer.document_index.interfaces import VespaChunkRequest
|
||||
from danswer.document_index.vespa.shared_utils.utils import (
|
||||
@ -23,7 +24,6 @@ from danswer.search.models import MAX_METRICS_CONTENT
|
||||
from danswer.search.models import RetrievalMetricsContainer
|
||||
from danswer.search.models import SearchQuery
|
||||
from danswer.search.postprocessing.postprocessing import cleanup_chunks
|
||||
from danswer.search.search_settings import get_multilingual_expansion
|
||||
from danswer.search.utils import inference_section_from_chunks
|
||||
from danswer.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
@ -121,15 +121,10 @@ def doc_index_retrieval(
|
||||
from the large chunks to the referenced chunks,
|
||||
dedupes the chunks, and cleans the chunks.
|
||||
"""
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
model = EmbeddingModel(
|
||||
model_name=db_embedding_model.model_name,
|
||||
query_prefix=db_embedding_model.query_prefix,
|
||||
passage_prefix=db_embedding_model.passage_prefix,
|
||||
normalize=db_embedding_model.normalize,
|
||||
api_key=db_embedding_model.api_key,
|
||||
provider_type=db_embedding_model.provider_type,
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
@ -230,7 +225,7 @@ def retrieve_chunks(
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
|
||||
multilingual_expansion = get_multilingual_expansion()
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
# Don't do query expansion on complex queries, rephrasings likely would not work well
|
||||
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
|
||||
top_chunks = doc_index_retrieval(
|
||||
|
@@ -9,12 +9,7 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()


-def get_multilingual_expansion() -> list[str]:
-    search_settings = get_search_settings()
-    return search_settings.multilingual_expansion if search_settings else []
-
-
-def get_search_settings() -> SavedSearchSettings | None:
+def get_kv_search_settings() -> SavedSearchSettings | None:
     """Get all user configured search settings which affect the search pipeline
     Note: KV store is used in this case since there is no need to rollback the value or any need to audit past values

@@ -33,8 +28,3 @@ def get_search_settings() -> SavedSearchSettings | None:
         # or the user can set it via the API/UI
         kv_store.delete(KV_SEARCH_SETTINGS)
         return None
-
-
-def update_search_settings(settings: SavedSearchSettings) -> None:
-    kv_store = get_dynamic_config_store()
-    kv_store.store(KV_SEARCH_SETTINGS, settings.dict())
@ -2,11 +2,11 @@ from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.chat_configs import LANGUAGE_CHAT_NAMING_HINT
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.search_settings import get_multilingual_expansion
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.chat_prompts import CHAT_NAMING
|
||||
from danswer.search.search_settings import get_multilingual_expansion
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
@ -9,10 +9,10 @@ from danswer.connectors.models import IndexAttemptMetadata
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import get_documents_by_cc_pair
|
||||
from danswer.db.document import get_ingestion_documents
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.document_index.document_index_utils import get_both_index_names
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.embedder import DefaultIndexingEmbedder
|
||||
@ -90,10 +90,10 @@ def upsert_ingestion_doc(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
db_embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
index_embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
|
||||
db_embedding_model
|
||||
index_embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
@ -117,16 +117,16 @@ def upsert_ingestion_doc(
|
||||
primary_index_name=curr_ind_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
sec_db_embedding_model = get_secondary_db_embedding_model(db_session)
|
||||
sec_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
if sec_db_embedding_model is None:
|
||||
if sec_search_settings is None:
|
||||
# Should not ever happen
|
||||
raise RuntimeError(
|
||||
"Secondary index exists but no embedding model configured"
|
||||
"Secondary index exists but no search settings configured"
|
||||
)
|
||||
|
||||
new_index_embedding_model = DefaultIndexingEmbedder.from_db_embedding_model(
|
||||
sec_db_embedding_model
|
||||
new_index_embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=sec_search_settings
|
||||
)
|
||||
|
||||
sec_ind_pipeline = build_indexing_pipeline(
|
||||
|
@ -63,7 +63,6 @@ from danswer.db.credentials import delete_google_drive_service_account_credentia
|
||||
from danswer.db.credentials import fetch_credential_by_id
|
||||
from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
|
||||
from danswer.db.document import get_document_cnts_for_cc_pairs
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
@ -71,6 +70,7 @@ from danswer.db.index_attempt import get_latest_finished_index_attempt_for_cc_pa
|
||||
from danswer.db.index_attempt import get_latest_index_attempts
|
||||
from danswer.db.models import User
|
||||
from danswer.db.models import UserRole
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.server.documents.models import AuthStatus
|
||||
@ -705,7 +705,7 @@ def connector_run_once(
|
||||
)
|
||||
]
|
||||
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
connector_credential_pairs = [
|
||||
get_connector_credential_pair(run_info.connector_id, credential_id, db_session)
|
||||
@ -716,7 +716,7 @@ def connector_run_once(
|
||||
index_attempt_ids = [
|
||||
create_index_attempt(
|
||||
connector_credential_pair_id=connector_credential_pair.id,
|
||||
embedding_model_id=embedding_model.id,
|
||||
search_settings_id=search_settings.id,
|
||||
from_beginning=run_info.from_beginning,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
@ -5,9 +5,9 @@ from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.interfaces import VespaChunkRequest
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
@ -29,10 +29,10 @@ def get_document_info(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> DocumentInfo:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=embedding_model.index_name, secondary_index_name=None
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
@ -51,8 +51,8 @@ def get_document_info(
|
||||
# get actual document context used for LLM
|
||||
first_chunk = inference_chunks[0]
|
||||
tokenizer_encode = get_tokenizer(
|
||||
provider_type=embedding_model.provider_type,
|
||||
model_name=embedding_model.model_name,
|
||||
provider_type=search_settings.provider_type,
|
||||
model_name=search_settings.model_name,
|
||||
).encode
|
||||
full_context_str = build_doc_context_str(
|
||||
semantic_identifier=first_chunk.semantic_identifier,
|
||||
@ -76,10 +76,10 @@ def get_chunk_info(
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChunkInfo:
|
||||
embedding_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=embedding_model.index_name, secondary_index_name=None
|
||||
primary_index_name=search_settings.index_name, secondary_index_name=None
|
||||
)
|
||||
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
@ -100,8 +100,8 @@ def get_chunk_info(
|
||||
chunk_content = inference_chunks[0].content
|
||||
|
||||
tokenizer_encode = get_tokenizer(
|
||||
provider_type=embedding_model.provider_type,
|
||||
model_name=embedding_model.model_name,
|
||||
provider_type=search_settings.provider_type,
|
||||
model_name=search_settings.model_name,
|
||||
).encode
|
||||
|
||||
return ChunkInfo(
|
||||
|
@ -4,12 +4,12 @@ from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.db.embedding_model import get_current_db_embedding_provider
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.llm import fetch_existing_embedding_providers
|
||||
from danswer.db.llm import remove_embedding_provider
|
||||
from danswer.db.llm import upsert_cloud_embedding_provider
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_db_embedding_provider
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from danswer.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
|
@ -17,7 +17,7 @@ from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import StandardAnswer as StandardAnswerModel
|
||||
from danswer.db.models import StandardAnswerCategory as StandardAnswerCategoryModel
|
||||
from danswer.db.models import User
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
@ -246,8 +246,8 @@ class SlackBotConfig(BaseModel):
|
||||
|
||||
|
||||
class FullModelVersionResponse(BaseModel):
|
||||
current_model: EmbeddingModelDetail
|
||||
secondary_model: EmbeddingModelDetail | None
|
||||
current_settings: SavedSearchSettings
|
||||
secondary_settings: SavedSearchSettings | None
|
||||
|
||||
|
||||
class AllUsersResponse(BaseModel):
|
||||
|
@ -9,111 +9,108 @@ from danswer.auth.users import current_user
|
||||
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from danswer.db.connector_credential_pair import resync_cc_pair
|
||||
from danswer.db.embedding_model import create_embedding_model
|
||||
from danswer.db.embedding_model import get_current_db_embedding_model
|
||||
from danswer.db.embedding_model import get_embedding_provider_from_provider_type
|
||||
from danswer.db.embedding_model import get_secondary_db_embedding_model
|
||||
from danswer.db.embedding_model import update_embedding_model_status
|
||||
from danswer.db.engine import get_session
|
||||
from danswer.db.index_attempt import expire_index_attempts
|
||||
from danswer.db.models import IndexModelStatus
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import create_search_settings
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_embedding_provider_from_provider_type
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.db.search_settings import update_current_search_settings
|
||||
from danswer.db.search_settings import update_search_settings_status
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.indexing.models import EmbeddingModelCreateRequest
|
||||
from danswer.indexing.models import EmbeddingModelDetail
|
||||
from danswer.natural_language_processing.search_nlp_models import clean_model_name
|
||||
from danswer.search.models import SavedSearchSettings
|
||||
from danswer.search.search_settings import get_search_settings
|
||||
from danswer.search.search_settings import update_search_settings
|
||||
from danswer.search.models import SearchSettingsCreationRequest
|
||||
from danswer.server.manage.models import FullModelVersionResponse
|
||||
from danswer.server.models import IdReturn
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
||||
|
||||
|
||||
router = APIRouter(prefix="/search-settings")
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@router.post("/set-new-embedding-model")
|
||||
def set_new_embedding_model(
|
||||
embed_model_details: EmbeddingModelDetail,
|
||||
@router.post("/set-new-search-settings")
|
||||
def set_new_search_settings(
|
||||
search_settings_new: SearchSettingsCreationRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> IdReturn:
|
||||
"""Creates a new EmbeddingModel row and cancels the previous secondary indexing if any
|
||||
Gives an error if the same model name is used as the current or secondary index
|
||||
"""
|
||||
if search_settings_new.index_name:
|
||||
logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
# Validate cloud provider exists
|
||||
if embed_model_details.provider_type is not None:
|
||||
if search_settings_new.provider_type is not None:
|
||||
cloud_provider = get_embedding_provider_from_provider_type(
|
||||
db_session, provider_type=embed_model_details.provider_type
|
||||
db_session, provider_type=search_settings_new.provider_type
|
||||
)
|
||||
|
||||
if cloud_provider is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"No embedding provider exists for cloud embedding type {embed_model_details.provider_type}",
|
||||
detail=f"No embedding provider exists for cloud embedding type {search_settings_new.provider_type}",
|
||||
)
|
||||
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# We define index name here
|
||||
index_name = f"danswer_chunk_{clean_model_name(embed_model_details.model_name)}"
|
||||
if (
|
||||
embed_model_details.model_name == current_model.model_name
|
||||
and not current_model.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
):
|
||||
index_name += ALT_INDEX_SUFFIX
|
||||
if search_settings_new.index_name is None:
|
||||
# We define index name here
|
||||
index_name = f"danswer_chunk_{clean_model_name(search_settings_new.model_name)}"
|
||||
if (
|
||||
search_settings_new.model_name == search_settings.model_name
|
||||
and not search_settings.index_name.endswith(ALT_INDEX_SUFFIX)
|
||||
):
|
||||
index_name += ALT_INDEX_SUFFIX
|
||||
search_values = search_settings_new.dict()
|
||||
search_values["index_name"] = index_name
|
||||
new_search_settings_request = SavedSearchSettings(**search_values)
|
||||
else:
|
||||
new_search_settings_request = SavedSearchSettings(**search_settings_new.dict())
|
||||
|
||||
create_embed_model_details = EmbeddingModelCreateRequest(
|
||||
**embed_model_details.dict(), index_name=index_name
|
||||
)
|
||||
|
||||
secondary_model = get_secondary_db_embedding_model(db_session)
|
||||
|
||||
if secondary_model:
|
||||
if embed_model_details.model_name == secondary_model.model_name:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Already reindexing with {secondary_model.model_name}",
|
||||
)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
if secondary_search_settings:
|
||||
# Cancel any background indexing jobs
|
||||
expire_index_attempts(
|
||||
embedding_model_id=secondary_model.id, db_session=db_session
|
||||
search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
# Mark previous model as a past model directly
|
||||
update_embedding_model_status(
|
||||
embedding_model=secondary_model,
|
||||
update_search_settings_status(
|
||||
search_settings=secondary_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
new_model = create_embedding_model(
|
||||
create_embed_model_details=create_embed_model_details, db_session=db_session
|
||||
new_search_settings = create_search_settings(
|
||||
search_settings=new_search_settings_request, db_session=db_session
|
||||
)
|
||||
|
||||
# Ensure Vespa has the new index immediately
|
||||
document_index = get_default_document_index(
|
||||
primary_index_name=current_model.index_name,
|
||||
secondary_index_name=new_model.index_name,
|
||||
primary_index_name=search_settings.index_name,
|
||||
secondary_index_name=new_search_settings.index_name,
|
||||
)
|
||||
document_index.ensure_indices_exist(
|
||||
index_embedding_dim=current_model.model_dim,
|
||||
secondary_index_embedding_dim=new_model.model_dim,
|
||||
index_embedding_dim=search_settings.model_dim,
|
||||
secondary_index_embedding_dim=new_search_settings.model_dim,
|
||||
)
|
||||
|
||||
# Pause index attempts for the currently in use index to preserve resources
|
||||
if DISABLE_INDEX_UPDATE_ON_SWAP:
|
||||
expire_index_attempts(
|
||||
embedding_model_id=current_model.id, db_session=db_session
|
||||
search_settings_id=search_settings.id, db_session=db_session
|
||||
)
|
||||
for cc_pair in get_connector_credential_pairs(db_session):
|
||||
resync_cc_pair(cc_pair, db_session=db_session)
|
||||
|
||||
return IdReturn(id=new_model.id)
|
||||
return IdReturn(id=new_search_settings.id)
|
||||
|
||||
|
||||
@router.post("/cancel-new-embedding")
|
||||
@ -121,66 +118,63 @@ def cancel_new_embedding(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
secondary_model = get_secondary_db_embedding_model(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
|
||||
if secondary_model:
|
||||
if secondary_search_settings:
|
||||
expire_index_attempts(
|
||||
embedding_model_id=secondary_model.id, db_session=db_session
|
||||
search_settings_id=secondary_search_settings.id, db_session=db_session
|
||||
)
|
||||
|
||||
update_embedding_model_status(
|
||||
embedding_model=secondary_model,
|
||||
update_search_settings_status(
|
||||
search_settings=secondary_search_settings,
|
||||
new_status=IndexModelStatus.PAST,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/get-current-embedding-model")
|
||||
def get_current_embedding_model(
|
||||
@router.get("/get-current-search-settings")
|
||||
def get_curr_search_settings(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> EmbeddingModelDetail:
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
return EmbeddingModelDetail.from_model(current_model)
|
||||
) -> SavedSearchSettings:
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
return SavedSearchSettings.from_db_model(current_search_settings)
|
||||
|
||||
|
||||
@router.get("/get-secondary-embedding-model")
|
||||
def get_secondary_embedding_model(
|
||||
@router.get("/get-secondary-search-settings")
|
||||
def get_sec_search_settings(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> EmbeddingModelDetail | None:
|
||||
next_model = get_secondary_db_embedding_model(db_session)
|
||||
if not next_model:
|
||||
) -> SavedSearchSettings | None:
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
if not secondary_search_settings:
|
||||
return None
|
||||
|
||||
return EmbeddingModelDetail.from_model(next_model)
|
||||
return SavedSearchSettings.from_db_model(secondary_search_settings)
|
||||
|
||||
|
||||
@router.get("/get-embedding-models")
|
||||
def get_embedding_models(
|
||||
@router.get("/get-all-search-settings")
|
||||
def get_all_search_settings(
|
||||
_: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> FullModelVersionResponse:
|
||||
current_model = get_current_db_embedding_model(db_session)
|
||||
next_model = get_secondary_db_embedding_model(db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
secondary_search_settings = get_secondary_search_settings(db_session)
|
||||
return FullModelVersionResponse(
|
||||
current_model=EmbeddingModelDetail.from_model(current_model),
|
||||
secondary_model=EmbeddingModelDetail.from_model(next_model)
|
||||
if next_model
|
||||
current_settings=SavedSearchSettings.from_db_model(current_search_settings),
|
||||
secondary_settings=SavedSearchSettings.from_db_model(secondary_search_settings)
|
||||
if secondary_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/get-search-settings")
|
||||
def get_saved_search_settings(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> SavedSearchSettings | None:
|
||||
return get_search_settings()
|
||||
|
||||
|
||||
@router.post("/update-search-settings")
|
||||
# Updates current non-reindex search settings
|
||||
@router.post("/update-inference-settings")
|
||||
def update_saved_search_settings(
|
||||
search_settings: SavedSearchSettings,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_search_settings(search_settings)
|
||||
update_current_search_settings(
|
||||
search_settings=search_settings, db_session=db_session
|
||||
)
|
||||
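Pulled together, the reads and the inference-only update above map to three renamed routes. A rough usage sketch follows; the base URL is assumed and admin authentication is omitted:

import requests

BASE = "http://localhost:8080/search-settings"  # assumed host/port

current = requests.get(f"{BASE}/get-current-search-settings").json()
secondary = requests.get(f"{BASE}/get-secondary-search-settings").json()  # null when no reindex is queued
all_settings = requests.get(f"{BASE}/get-all-search-settings").json()

# Turn on a cloud reranker without triggering a reindex.
current["rerank_provider_type"] = "cohere"
current["rerank_model_name"] = "rerank-english-v3.0"
requests.post(f"{BASE}/update-inference-settings", json=current).raise_for_status()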

@@ -15,9 +15,9 @@ from danswer.db.chat import get_first_messages_for_chat_sessions
from danswer.db.chat import get_search_docs_for_chat_message
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.tag import get_tags_by_value_prefix_for_source_types
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
@@ -63,9 +63,9 @@ def admin_search(
tags=question.filters.tags,
access_control_list=user_acl_filters,
)
embedding_model = get_current_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(
primary_index_name=embedding_model.index_name, secondary_index_name=None
primary_index_name=search_settings.index_name, secondary_index_name=None
)
if not isinstance(document_index, VespaIndex):
raise HTTPException(
@@ -2,8 +2,8 @@ from sqlalchemy.orm import Session

from danswer.access.access import get_access_for_documents
from danswer.db.document import prepare_to_modify_documents
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.embedding_model import get_secondary_db_embedding_model
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import DocumentIndex
from danswer.document_index.interfaces import UpdateRequest
@@ -47,13 +47,13 @@ def _sync_user_group_batch(

def sync_user_groups(user_group_id: int, db_session: Session) -> None:
"""Sync the status of Postgres for the specified user group"""
db_embedding_model = get_current_db_embedding_model(db_session)
secondary_db_embedding_model = get_secondary_db_embedding_model(db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)

document_index = get_default_document_index(
primary_index_name=db_embedding_model.index_name,
secondary_index_name=secondary_db_embedding_model.index_name
if secondary_db_embedding_model
primary_index_name=search_settings.index_name,
secondary_index_name=secondary_search_settings.index_name
if secondary_search_settings
else None,
)
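The same substitution shows up everywhere a document index is built: index names now come off the search-settings rows instead of the embedding-model rows. For readability, here is the new-side call from the hunk above gathered into one block, reconstructed from the diff lines and assuming an already open db_session:

from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.document_index.factory import get_default_document_index

search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)

document_index = get_default_document_index(
    primary_index_name=search_settings.index_name,
    secondary_index_name=secondary_search_settings.index_name
    if secondary_search_settings
    else None,
)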
@@ -54,3 +54,17 @@ LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "danswer"
DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
# notset, debug, info, notice, warning, error, or critical
LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")


# Fields which should only be set on new search setting
PRESERVED_SEARCH_FIELDS = [
"provider_type",
"api_key",
"model_name",
"index_name",
"multipass_indexing",
"model_dim",
"normalize",
"passage_prefix",
"query_prefix",
]
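PRESERVED_SEARCH_FIELDS is declared here, but its consumer is not part of the hunks shown. Purely as an illustration of the intent stated in the comment, a hypothetical helper (not code from this commit) could re-apply these fields from the active settings so that an inference-only update can never override them:

# Hypothetical sketch; the actual consumer of PRESERVED_SEARCH_FIELDS lives elsewhere in the codebase.
def apply_preserved_fields(current_settings, updated_values: dict) -> dict:
    for field in PRESERVED_SEARCH_FIELDS:
        # keep whatever the search-settings row was created with
        updated_values[field] = getattr(current_settings, field)
    return updated_values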

@@ -10,13 +10,14 @@ from danswer.configs.app_configs import POSTGRES_HOST
from danswer.configs.app_configs import POSTGRES_PASSWORD
from danswer.configs.app_configs import POSTGRES_PORT
from danswer.configs.app_configs import POSTGRES_USER
from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_context_manager
from danswer.db.engine import SYNC_DB_API
from danswer.db.search_settings import get_current_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from danswer.document_index.vespa.index import VespaIndex
from danswer.indexing.models import IndexingSetting
from danswer.main import setup_postgres
from danswer.main import setup_vespa
from tests.integration.common_utils.llm import seed_default_openai_provider
@@ -127,13 +128,13 @@ def reset_vespa() -> None:
# swap to the correct default model
check_index_swap(db_session)

current_model = get_current_db_embedding_model(db_session)
index_name = current_model.index_name
search_settings = get_current_search_settings(db_session)
index_name = search_settings.index_name

setup_vespa(
document_index=VespaIndex(index_name=index_name, secondary_index_name=None),
db_embedding_model=current_model,
secondary_db_embedding_model=None,
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)

for _ in range(5):

@@ -3,8 +3,8 @@ from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session

from danswer.db.embedding_model import get_current_db_embedding_model
from danswer.db.engine import get_session_context_manager
from danswer.db.search_settings import get_current_search_settings
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.vespa import TestVespaClient

@@ -17,8 +17,8 @@ def db_session() -> Generator[Session, None, None]:

@pytest.fixture
def vespa_client(db_session: Session) -> TestVespaClient:
current_model = get_current_db_embedding_model(db_session)
return TestVespaClient(index_name=current_model.index_name)
search_settings = get_current_search_settings(db_session)
return TestVespaClient(index_name=search_settings.index_name)


@pytest.fixture

@@ -40,7 +40,7 @@ export default function UpgradingPage({
method: "POST",
});
if (response.ok) {
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
} else {
alert(
`Failed to cancel embedding model update - ${await response.text()}`

@@ -36,14 +36,14 @@ function Main() {
isLoading: isLoadingCurrentModel,
error: currentEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-current-embedding-model",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);

const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
"/api/search-settings/get-search-settings",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@@ -53,7 +53,7 @@ function Main() {
isLoading: isLoadingFutureModel,
error: futureEmeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-secondary-embedding-model",
"/api/search-settings/get-secondary-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);

@@ -93,10 +93,10 @@ export function EmbeddingModelSelection({

const onConfirmSelection = async (model: EmbeddingModelDescriptor) => {
const response = await fetch(
"/api/search-settings/set-new-embedding-model",
"/api/search-settings/set-new-search-settings",
{
method: "POST",
body: JSON.stringify(model),
body: JSON.stringify({ ...model, index_name: null }),
headers: {
"Content-Type": "application/json",
},
@@ -104,7 +104,7 @@ export function EmbeddingModelSelection({
);
if (response.ok) {
setShowTentativeModel(null);
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
if (!connectors || !connectors.length) {
setShowAddConnectorPopup(true);
}

@@ -114,32 +114,36 @@ const RerankingDetailsForm = forwardRef<
)
).map((card) => {
const isSelected =
values.provider_type === card.provider &&
values.rerank_provider_type === card.rerank_provider_type &&
values.rerank_model_name === card.modelName;
return (
<div
key={`${card.provider}-${card.modelName}`}
key={`${card.rerank_provider_type}-${card.modelName}`}
className={`p-4 border rounded-lg cursor-pointer transition-all duration-200 ${
isSelected
? "border-blue-500 bg-blue-50 shadow-md"
: "border-gray-200 hover:border-blue-300 hover:shadow-sm"
}`}
onClick={() => {
if (card.provider) {
if (card.rerank_provider_type) {
setIsApiKeyModalOpen(true);
}
setRerankingDetails({
...values,
provider_type: card.provider!,
rerank_provider_type: card.rerank_provider_type!,
rerank_model_name: card.modelName,
});
setFieldValue("provider_type", card.provider);
setFieldValue(
"rerank_provider_type",
card.rerank_provider_type
);
setFieldValue("rerank_model_name", card.modelName);
}}
>
<div className="flex items-center justify-between mb-3">
<div className="flex items-center">
{card.provider === RerankerProvider.COHERE ? (
{card.rerank_provider_type ===
RerankerProvider.COHERE ? (
<CohereIcon size={24} className="mr-2" />
) : (
<MixedBreadIcon size={24} className="mr-2" />

@@ -1,6 +1,9 @@
import { EmbeddingProvider } from "@/components/embedding/interfaces";
import { NonNullChain } from "typescript";

export interface RerankingDetails {
rerank_model_name: string | null;
provider_type: RerankerProvider | null;
rerank_provider_type: RerankerProvider | null;
api_key: string | null;
num_rerank: number;
}
@@ -8,20 +11,33 @@ export interface RerankingDetails {
export enum RerankerProvider {
COHERE = "cohere",
}
export interface AdvancedDetails {
multilingual_expansion: string[];
export interface AdvancedSearchConfiguration {
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
index_name: string | null;
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
}

export interface SavedSearchSettings extends RerankingDetails {
multilingual_expansion: string[];
model_name: string;
model_dim: number;
normalize: boolean;
query_prefix: string;
passage_prefix: string;
index_name: string | null;
multipass_indexing: boolean;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
provider_type: EmbeddingProvider | null;
}

export interface RerankingModel {
provider?: RerankerProvider;
rerank_provider_type: RerankerProvider | null;
modelName: string;
displayName: string;
description: string;
@@ -31,6 +47,7 @@ export interface RerankingModel {

export const rerankingModels: RerankingModel[] = [
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-xsmall-v1",
displayName: "MixedBread XSmall",
@@ -38,6 +55,7 @@ export const rerankingModels: RerankingModel[] = [
link: "https://huggingface.co/mixedbread-ai/mxbai-rerank-xsmall-v1",
},
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-base-v1",
displayName: "MixedBread Base",
@@ -45,6 +63,7 @@ export const rerankingModels: RerankingModel[] = [
link: "https://huggingface.co/mixedbread-ai/mxbai-rerank-base-v1",
},
{
rerank_provider_type: null,
cloud: false,
modelName: "mixedbread-ai/mxbai-rerank-large-v1",
displayName: "MixedBread Large",
@@ -53,7 +72,7 @@ export const rerankingModels: RerankingModel[] = [
},
{
cloud: true,
provider: RerankerProvider.COHERE,
rerank_provider_type: RerankerProvider.COHERE,
modelName: "rerank-english-v3.0",
displayName: "Cohere English",
description: "High-performance English-focused reranking model.",
@@ -61,7 +80,7 @@ export const rerankingModels: RerankingModel[] = [
},
{
cloud: true,
provider: RerankerProvider.COHERE,
rerank_provider_type: RerankerProvider.COHERE,
modelName: "rerank-multilingual-v3.0",
displayName: "Cohere Multilingual",
description: "Powerful multilingual reranking model.",

@@ -5,14 +5,14 @@ import { EditingValue } from "@/components/credentials/EditingValue";
import CredentialSubText from "@/components/credentials/CredentialFields";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedDetails, RerankingDetails } from "../interfaces";
import { AdvancedSearchConfiguration, RerankingDetails } from "../interfaces";

interface AdvancedEmbeddingFormPageProps {
updateAdvancedEmbeddingDetails: (
key: keyof AdvancedDetails,
key: keyof AdvancedSearchConfiguration,
value: any
) => void;
advancedEmbeddingDetails: AdvancedDetails;
advancedEmbeddingDetails: AdvancedSearchConfiguration;
setRerankingDetails: Dispatch<SetStateAction<RerankingDetails>>;
numRerank: number;
}

@@ -8,7 +8,7 @@ import { Button, Card, Text } from "@tremor/react";
import { ArrowLeft, ArrowRight, WarningCircle } from "@phosphor-icons/react";
import {
CloudEmbeddingModel,
EmbeddingModelDescriptor,
EmbeddingProvider,
HostedEmbeddingModel,
} from "../../../../components/embedding/interfaces";
import { errorHandlingFetcher } from "@/lib/fetcher";
@@ -17,7 +17,8 @@ import useSWR, { mutate } from "swr";
import { ThreeDotsLoader } from "@/components/Loading";
import AdvancedEmbeddingFormPage from "./AdvancedEmbeddingFormPage";
import {
AdvancedDetails,
AdvancedSearchConfiguration,
RerankerProvider,
RerankingDetails,
SavedSearchSettings,
} from "../interfaces";
@@ -30,21 +31,27 @@ export default function EmbeddingForm() {
const { popup, setPopup } = usePopup();

const [advancedEmbeddingDetails, setAdvancedEmbeddingDetails] =
useState<AdvancedDetails>({
disable_rerank_for_streaming: false,
multilingual_expansion: [],
useState<AdvancedSearchConfiguration>({
model_name: "",
model_dim: 0,
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
multipass_indexing: true,
multilingual_expansion: [],
disable_rerank_for_streaming: false,
});

const [rerankingDetails, setRerankingDetails] = useState<RerankingDetails>({
api_key: "",
num_rerank: 0,
provider_type: null,
rerank_provider_type: null,
rerank_model_name: "",
});

const updateAdvancedEmbeddingDetails = (
key: keyof AdvancedDetails,
key: keyof AdvancedSearchConfiguration,
value: any
) => {
setAdvancedEmbeddingDetails((values) => ({ ...values, [key]: value }));
@@ -52,7 +59,7 @@ export default function EmbeddingForm() {

async function updateSearchSettings(searchSettings: SavedSearchSettings) {
const response = await fetch(
"/api/search-settings/update-search-settings",
"/api/search-settings/update-inference-settings",
{
method: "POST",
headers: {
@@ -80,7 +87,7 @@ export default function EmbeddingForm() {
isLoading: isLoadingCurrentModel,
error: currentEmbeddingModelError,
} = useSWR<CloudEmbeddingModel | HostedEmbeddingModel | null>(
"/api/search-settings/get-current-embedding-model",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@@ -91,7 +98,7 @@ export default function EmbeddingForm() {

const { data: searchSettings, isLoading: isLoadingSearchSettings } =
useSWR<SavedSearchSettings | null>(
"/api/search-settings/get-search-settings",
"/api/search-settings/get-current-search-settings",
errorHandlingFetcher,
{ refreshInterval: 5000 } // 5 seconds
);
@@ -99,31 +106,37 @@ export default function EmbeddingForm() {
useEffect(() => {
if (searchSettings) {
setAdvancedEmbeddingDetails({
model_name: searchSettings.model_name,
model_dim: searchSettings.model_dim,
normalize: searchSettings.normalize,
query_prefix: searchSettings.query_prefix,
passage_prefix: searchSettings.passage_prefix,
index_name: searchSettings.index_name,
multipass_indexing: searchSettings.multipass_indexing,
multilingual_expansion: searchSettings.multilingual_expansion,
disable_rerank_for_streaming:
searchSettings.disable_rerank_for_streaming,
multilingual_expansion: searchSettings.multilingual_expansion,
multipass_indexing: searchSettings.multipass_indexing,
});
setRerankingDetails({
api_key: searchSettings.api_key,
num_rerank: searchSettings.num_rerank,
provider_type: searchSettings.provider_type,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
});
}
}, [searchSettings]);

const originalRerankingDetails = searchSettings
const originalRerankingDetails: RerankingDetails = searchSettings
? {
api_key: searchSettings.api_key,
num_rerank: searchSettings.num_rerank,
provider_type: searchSettings.provider_type,
rerank_provider_type: searchSettings.rerank_provider_type,
rerank_model_name: searchSettings.rerank_model_name,
}
: {
api_key: "",
num_rerank: 0,
provider_type: null,
rerank_provider_type: null,
rerank_model_name: "",
};
@@ -149,14 +162,17 @@ export default function EmbeddingForm() {
let values: SavedSearchSettings = {
...rerankingDetails,
...advancedEmbeddingDetails,
provider_type:
selectedProvider.provider_type?.toLowerCase() as EmbeddingProvider | null,
};

const response = await updateSearchSettings(values);
if (response.ok) {
setPopup({
message: "Updated search settings succesffuly",
type: "success",
});
mutate("/api/search-settings/get-search-settings");
mutate("/api/search-settings/get-current-search-settings");
return true;
} else {
setPopup({ message: "Failed to update search settings", type: "error" });
@@ -165,29 +181,37 @@ export default function EmbeddingForm() {
};

const onConfirm = async () => {
let newModel: EmbeddingModelDescriptor;
if (!selectedProvider) {
return;
}
let newModel: SavedSearchSettings;

if ("provider_type" in selectedProvider) {
// This is a CloudEmbeddingModel
if (selectedProvider.provider_type != null) {
// This is a cloud model
newModel = {
...advancedEmbeddingDetails,
...selectedProvider,
...rerankingDetails,
model_name: selectedProvider.model_name,
provider_type: selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0],
provider_type:
(selectedProvider.provider_type
?.toLowerCase()
.split(" ")[0] as EmbeddingProvider) || null,
};
} else {
// This is an EmbeddingModelDescriptor
// This is a locally hosted model
newModel = {
...advancedEmbeddingDetails,
...selectedProvider,
...rerankingDetails,
model_name: selectedProvider.model_name!,
description: "",
provider_type: null,
};
}
newModel.index_name = null;

const response = await fetch(
"/api/search-settings/set-new-embedding-model",
"/api/search-settings/set-new-search-settings",
{
method: "POST",
body: JSON.stringify(newModel),
@@ -201,7 +225,7 @@ export default function EmbeddingForm() {
message: "Changed provider suceessfully. Redirecing to embedding page",
type: "success",
});
mutate("/api/search-settings/get-secondary-embedding-model");
mutate("/api/search-settings/get-secondary-search-settings");
setTimeout(() => {
window.open("/admin/configuration/search", "_self");
}, 2000);
@@ -217,14 +241,14 @@ export default function EmbeddingForm() {
searchSettings?.multipass_indexing !=
advancedEmbeddingDetails.multipass_indexing;

const ReIndxingButton = () => {
return (
const ReIndexingButton = ({ needsReIndex }: { needsReIndex: boolean }) => {
return needsReIndex ? (
<div className="flex mx-auto gap-x-1 ml-auto items-center">
<button
className="enabled:cursor-pointer disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
const updated = await updateSearch();
if (updated) {
const update = await updateSearch();
if (update) {
await onConfirm();
}
}}
@@ -251,6 +275,15 @@ export default function EmbeddingForm() {
</div>
</div>
</div>
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
);
};

@@ -361,18 +394,7 @@ export default function EmbeddingForm() {
Previous
</button>

{needsReIndex ? (
<ReIndxingButton />
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50 disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
)}
<ReIndexingButton needsReIndex={needsReIndex} />

<div className="flex w-full justify-end">
<button
@@ -410,20 +432,7 @@ export default function EmbeddingForm() {
Previous
</button>

{needsReIndex ? (
<ReIndxingButton />
) : (
<button
className="enabled:cursor-pointer ml-auto disabled:bg-accent/50
disabled:cursor-not-allowed bg-accent flex mx-auto gap-x-1 items-center
text-white py-2.5 px-3.5 text-sm font-regular rounded-sm"
onClick={async () => {
updateSearch();
}}
>
Update Search
</button>
)}
<ReIndexingButton needsReIndex={needsReIndex} />
</div>
</>
)}

@@ -49,7 +49,7 @@ export default async function Home() {
fetchSS("/manage/document-set"),
fetchAssistantsSS(),
fetchSS("/query/valid-tags"),
fetchSS("/search-settings/get-embedding-models"),
fetchSS("/search-settings/get-all-search-settings"),
fetchSS("/query/user-searches"),
];

@@ -35,7 +35,13 @@ export function CustomModelForm({
normalize: Yup.boolean().required(),
})}
onSubmit={async (values, formikHelpers) => {
onSubmit({ ...values, model_dim: parseInt(values.model_dim) });
onSubmit({
...values,
model_dim: parseInt(values.model_dim),
api_key: null,
provider_type: null,
index_name: null,
});
}}
>
{({ isSubmitting, setFieldValue }) => (

@@ -41,8 +41,10 @@ export interface EmbeddingModelDescriptor {
normalize: boolean;
query_prefix: string;
passage_prefix: string;
provider_type?: string | null;
provider_type: string | null;
description: string;
api_key: string | null;
index_name: string | null;
}

export interface CloudEmbeddingModel extends EmbeddingModelDescriptor {
@@ -82,6 +84,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/nomic-ai/nomic-embed-text-v1",
query_prefix: "search_query: ",
passage_prefix: "search_document: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/e5-base-v2",
@@ -92,6 +97,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/e5-base-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/e5-small-v2",
@@ -102,6 +110,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/e5-small-v2",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/multilingual-e5-base",
@@ -112,6 +123,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
{
model_name: "intfloat/multilingual-e5-small",
@@ -122,6 +136,9 @@ export const AVAILABLE_MODELS: HostedEmbeddingModel[] = [
link: "https://huggingface.co/intfloat/multilingual-e5-base",
query_prefix: "query: ",
passage_prefix: "passage: ",
index_name: "",
provider_type: null,
api_key: null,
},
];

@@ -150,6 +167,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
model_name: "embed-english-light-v3.0",
@@ -164,6 +183,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},
@@ -190,6 +211,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
mtebScore: 64.6,
maxContext: 8191,
enabled: false,
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.OPENAI,
@@ -204,6 +227,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
enabled: false,
mtebScore: 62.3,
maxContext: 8191,
index_name: "",
api_key: null,
},
],
},
@@ -231,6 +256,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.GOOGLE,
@@ -244,6 +271,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},
@@ -270,6 +299,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
{
provider_type: EmbeddingProvider.VOYAGE,
@@ -284,6 +315,8 @@ export const AVAILABLE_CLOUD_PROVIDERS: CloudEmbeddingProvider[] = [
normalize: false,
query_prefix: "",
passage_prefix: "",
index_name: "",
api_key: null,
},
],
},