From 0da736bed945da61e6e0d4dd6f653be29e2c58bf Mon Sep 17 00:00:00 2001 From: pablodanswer Date: Sat, 5 Oct 2024 21:08:35 -0700 Subject: [PATCH] Tenant provisioning in the dataplane (#2694) * add tenant provisioning to data plane * minor typing update * ensure tenant router included * proper auth check * update disabling logic * validated basic provisioning * use new kv store --- backend/danswer/auth/users.py | 28 ++ backend/danswer/chat/load_yamls.py | 227 ++++++------- backend/danswer/configs/app_configs.py | 4 + backend/danswer/db/engine.py | 3 +- backend/danswer/key_value_store/store.py | 9 +- backend/danswer/main.py | 300 +---------------- backend/danswer/server/auth_check.py | 2 + backend/danswer/setup.py | 303 ++++++++++++++++++ backend/ee/danswer/main.py | 3 + backend/ee/danswer/server/tenants/__init__.py | 0 backend/ee/danswer/server/tenants/access.py | 0 backend/ee/danswer/server/tenants/api.py | 46 +++ backend/ee/danswer/server/tenants/models.py | 6 + .../ee/danswer/server/tenants/provisioning.py | 63 ++++ .../tests/integration/common_utils/reset.py | 4 +- web/src/components/search/SearchSection.tsx | 40 ++- 16 files changed, 615 insertions(+), 423 deletions(-) create mode 100644 backend/danswer/setup.py create mode 100644 backend/ee/danswer/server/tenants/__init__.py create mode 100644 backend/ee/danswer/server/tenants/access.py create mode 100644 backend/ee/danswer/server/tenants/api.py create mode 100644 backend/ee/danswer/server/tenants/models.py create mode 100644 backend/ee/danswer/server/tenants/provisioning.py diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index a583a9323..81607aab8 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -8,6 +8,7 @@ from email.mime.text import MIMEText from typing import Optional from typing import Tuple +import jwt from email_validator import EmailNotValidError from email_validator import validate_email from fastapi import APIRouter @@ -37,8 +38,10 @@ from danswer.auth.schemas import UserCreate from danswer.auth.schemas import UserRole from danswer.auth.schemas import UserUpdate from danswer.configs.app_configs import AUTH_TYPE +from danswer.configs.app_configs import DATA_PLANE_SECRET from danswer.configs.app_configs import DISABLE_AUTH from danswer.configs.app_configs import EMAIL_FROM +from danswer.configs.app_configs import EXPECTED_API_KEY from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from danswer.configs.app_configs import SMTP_PASS @@ -504,3 +507,28 @@ async def current_admin_user(user: User | None = Depends(current_user)) -> User def get_default_admin_user_emails_() -> list[str]: # No default seeding available for Danswer MIT return [] + + +async def control_plane_dep(request: Request) -> None: + api_key = request.headers.get("X-API-KEY") + if api_key != EXPECTED_API_KEY: + logger.warning("Invalid API key") + raise HTTPException(status_code=401, detail="Invalid API key") + + auth_header = request.headers.get("Authorization") + if not auth_header or not auth_header.startswith("Bearer "): + logger.warning("Invalid authorization header") + raise HTTPException(status_code=401, detail="Invalid authorization header") + + token = auth_header.split(" ")[1] + try: + payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"]) + if payload.get("scope") != "tenant:create": + logger.warning("Insufficient permissions") + raise HTTPException(status_code=403, detail="Insufficient permissions") + 
except jwt.ExpiredSignatureError: + logger.warning("Token has expired") + raise HTTPException(status_code=401, detail="Token has expired") + except jwt.InvalidTokenError: + logger.warning("Invalid token") + raise HTTPException(status_code=401, detail="Invalid token") diff --git a/backend/danswer/chat/load_yamls.py b/backend/danswer/chat/load_yamls.py index 8d0fd34d8..e8a19c158 100644 --- a/backend/danswer/chat/load_yamls.py +++ b/backend/danswer/chat/load_yamls.py @@ -6,7 +6,6 @@ from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT from danswer.configs.chat_configs import PERSONAS_YAML from danswer.configs.chat_configs import PROMPTS_YAML from danswer.db.document_set import get_or_create_document_set_by_name -from danswer.db.engine import get_sqlalchemy_engine from danswer.db.input_prompt import insert_input_prompt_if_not_exists from danswer.db.models import DocumentSet as DocumentSetDBModel from danswer.db.models import Persona @@ -18,30 +17,32 @@ from danswer.db.persona import upsert_prompt from danswer.search.enums import RecencyBiasSetting -def load_prompts_from_yaml(prompts_yaml: str = PROMPTS_YAML) -> None: +def load_prompts_from_yaml( + db_session: Session, prompts_yaml: str = PROMPTS_YAML +) -> None: with open(prompts_yaml, "r") as file: data = yaml.safe_load(file) all_prompts = data.get("prompts", []) - with Session(get_sqlalchemy_engine()) as db_session: - for prompt in all_prompts: - upsert_prompt( - user=None, - prompt_id=prompt.get("id"), - name=prompt["name"], - description=prompt["description"].strip(), - system_prompt=prompt["system"].strip(), - task_prompt=prompt["task"].strip(), - include_citations=prompt["include_citations"], - datetime_aware=prompt.get("datetime_aware", True), - default_prompt=True, - personas=None, - db_session=db_session, - commit=True, - ) + for prompt in all_prompts: + upsert_prompt( + user=None, + prompt_id=prompt.get("id"), + name=prompt["name"], + description=prompt["description"].strip(), + system_prompt=prompt["system"].strip(), + task_prompt=prompt["task"].strip(), + include_citations=prompt["include_citations"], + datetime_aware=prompt.get("datetime_aware", True), + default_prompt=True, + personas=None, + db_session=db_session, + commit=True, + ) def load_personas_from_yaml( + db_session: Session, personas_yaml: str = PERSONAS_YAML, default_chunks: float = MAX_CHUNKS_FED_TO_CHAT, ) -> None: @@ -49,117 +50,117 @@ def load_personas_from_yaml( data = yaml.safe_load(file) all_personas = data.get("personas", []) - with Session(get_sqlalchemy_engine()) as db_session: - for persona in all_personas: - doc_set_names = persona["document_sets"] - doc_sets: list[DocumentSetDBModel] = [ - get_or_create_document_set_by_name(db_session, name) - for name in doc_set_names + for persona in all_personas: + doc_set_names = persona["document_sets"] + doc_sets: list[DocumentSetDBModel] = [ + get_or_create_document_set_by_name(db_session, name) + for name in doc_set_names + ] + + # Assume if user hasn't set any document sets for the persona, the user may want + # to later attach document sets to the persona manually, therefore, don't overwrite/reset + # the document sets for the persona + doc_set_ids: list[int] | None = None + if doc_sets: + doc_set_ids = [doc_set.id for doc_set in doc_sets] + else: + doc_set_ids = None + + prompt_ids: list[int] | None = None + prompt_set_names = persona["prompts"] + if prompt_set_names: + prompts: list[PromptDBModel | None] = [ + get_prompt_by_name(prompt_name, user=None, db_session=db_session) + for prompt_name in 
prompt_set_names ] + if any([prompt is None for prompt in prompts]): + raise ValueError("Invalid Persona configs, not all prompts exist") - # Assume if user hasn't set any document sets for the persona, the user may want - # to later attach document sets to the persona manually, therefore, don't overwrite/reset - # the document sets for the persona - doc_set_ids: list[int] | None = None - if doc_sets: - doc_set_ids = [doc_set.id for doc_set in doc_sets] - else: - doc_set_ids = None + if prompts: + prompt_ids = [prompt.id for prompt in prompts if prompt is not None] - prompt_ids: list[int] | None = None - prompt_set_names = persona["prompts"] - if prompt_set_names: - prompts: list[PromptDBModel | None] = [ - get_prompt_by_name(prompt_name, user=None, db_session=db_session) - for prompt_name in prompt_set_names - ] - if any([prompt is None for prompt in prompts]): - raise ValueError("Invalid Persona configs, not all prompts exist") - - if prompts: - prompt_ids = [prompt.id for prompt in prompts if prompt is not None] - - p_id = persona.get("id") - tool_ids = [] - if persona.get("image_generation"): - image_gen_tool = ( - db_session.query(ToolDBModel) - .filter(ToolDBModel.name == "ImageGenerationTool") - .first() - ) - if image_gen_tool: - tool_ids.append(image_gen_tool.id) - - llm_model_provider_override = persona.get("llm_model_provider_override") - llm_model_version_override = persona.get("llm_model_version_override") - - # Set specific overrides for image generation persona - if persona.get("image_generation"): - llm_model_version_override = "gpt-4o" - - existing_persona = ( - db_session.query(Persona) - .filter(Persona.name == persona["name"]) + p_id = persona.get("id") + tool_ids = [] + if persona.get("image_generation"): + image_gen_tool = ( + db_session.query(ToolDBModel) + .filter(ToolDBModel.name == "ImageGenerationTool") .first() ) + if image_gen_tool: + tool_ids.append(image_gen_tool.id) - upsert_persona( - user=None, - persona_id=(-1 * p_id) if p_id is not None else None, - name=persona["name"], - description=persona["description"], - num_chunks=persona.get("num_chunks") - if persona.get("num_chunks") is not None - else default_chunks, - llm_relevance_filter=persona.get("llm_relevance_filter"), - starter_messages=persona.get("starter_messages"), - llm_filter_extraction=persona.get("llm_filter_extraction"), - icon_shape=persona.get("icon_shape"), - icon_color=persona.get("icon_color"), - llm_model_provider_override=llm_model_provider_override, - llm_model_version_override=llm_model_version_override, - recency_bias=RecencyBiasSetting(persona["recency_bias"]), - prompt_ids=prompt_ids, - document_set_ids=doc_set_ids, - tool_ids=tool_ids, - builtin_persona=True, - is_public=True, - display_priority=existing_persona.display_priority - if existing_persona is not None - else persona.get("display_priority"), - is_visible=existing_persona.is_visible - if existing_persona is not None - else persona.get("is_visible"), - db_session=db_session, - ) + llm_model_provider_override = persona.get("llm_model_provider_override") + llm_model_version_override = persona.get("llm_model_version_override") + + # Set specific overrides for image generation persona + if persona.get("image_generation"): + llm_model_version_override = "gpt-4o" + + existing_persona = ( + db_session.query(Persona).filter(Persona.name == persona["name"]).first() + ) + + upsert_persona( + user=None, + persona_id=(-1 * p_id) if p_id is not None else None, + name=persona["name"], + description=persona["description"], + 
num_chunks=persona.get("num_chunks") + if persona.get("num_chunks") is not None + else default_chunks, + llm_relevance_filter=persona.get("llm_relevance_filter"), + starter_messages=persona.get("starter_messages"), + llm_filter_extraction=persona.get("llm_filter_extraction"), + icon_shape=persona.get("icon_shape"), + icon_color=persona.get("icon_color"), + llm_model_provider_override=llm_model_provider_override, + llm_model_version_override=llm_model_version_override, + recency_bias=RecencyBiasSetting(persona["recency_bias"]), + prompt_ids=prompt_ids, + document_set_ids=doc_set_ids, + tool_ids=tool_ids, + builtin_persona=True, + is_public=True, + display_priority=existing_persona.display_priority + if existing_persona is not None + else persona.get("display_priority"), + is_visible=existing_persona.is_visible + if existing_persona is not None + else persona.get("is_visible"), + db_session=db_session, + ) -def load_input_prompts_from_yaml(input_prompts_yaml: str = INPUT_PROMPT_YAML) -> None: +def load_input_prompts_from_yaml( + db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML +) -> None: with open(input_prompts_yaml, "r") as file: data = yaml.safe_load(file) all_input_prompts = data.get("input_prompts", []) - with Session(get_sqlalchemy_engine()) as db_session: - for input_prompt in all_input_prompts: - # If these prompts are deleted (which is a hard delete in the DB), on server startup - # they will be recreated, but the user can always just deactivate them, just a light inconvenience - insert_input_prompt_if_not_exists( - user=None, - input_prompt_id=input_prompt.get("id"), - prompt=input_prompt["prompt"], - content=input_prompt["content"], - is_public=input_prompt["is_public"], - active=input_prompt.get("active", True), - db_session=db_session, - commit=True, - ) + for input_prompt in all_input_prompts: + # If these prompts are deleted (which is a hard delete in the DB), on server startup + # they will be recreated, but the user can always just deactivate them, just a light inconvenience + + insert_input_prompt_if_not_exists( + user=None, + input_prompt_id=input_prompt.get("id"), + prompt=input_prompt["prompt"], + content=input_prompt["content"], + is_public=input_prompt["is_public"], + active=input_prompt.get("active", True), + db_session=db_session, + commit=True, + ) def load_chat_yamls( + db_session: Session, prompt_yaml: str = PROMPTS_YAML, personas_yaml: str = PERSONAS_YAML, input_prompts_yaml: str = INPUT_PROMPT_YAML, ) -> None: - load_prompts_from_yaml(prompt_yaml) - load_personas_from_yaml(personas_yaml) - load_input_prompts_from_yaml(input_prompts_yaml) + load_prompts_from_yaml(db_session, prompt_yaml) + load_personas_from_yaml(db_session, personas_yaml) + load_input_prompts_from_yaml(db_session, input_prompts_yaml) diff --git a/backend/danswer/configs/app_configs.py b/backend/danswer/configs/app_configs.py index 0ac3d6e76..4559fed6b 100644 --- a/backend/danswer/configs/app_configs.py +++ b/backend/danswer/configs/app_configs.py @@ -416,3 +416,7 @@ ENTERPRISE_EDITION_ENABLED = ( MULTI_TENANT = os.environ.get("MULTI_TENANT", "").lower() == "true" SECRET_JWT_KEY = os.environ.get("SECRET_JWT_KEY", "") + + +DATA_PLANE_SECRET = os.environ.get("DATA_PLANE_SECRET", "") +EXPECTED_API_KEY = os.environ.get("EXPECTED_API_KEY", "") diff --git a/backend/danswer/db/engine.py b/backend/danswer/db/engine.py index dcae5ae1f..af7aad236 100644 --- a/backend/danswer/db/engine.py +++ b/backend/danswer/db/engine.py @@ -117,7 +117,7 @@ def get_db_current_time(db_session: Session) -> 
datetime: # Regular expression to validate schema names to prevent SQL injection -SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") +SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$") def is_valid_schema_name(name: str) -> bool: @@ -281,6 +281,7 @@ def get_session_with_tenant(tenant_id: str | None = None) -> Session: tenant_id = current_tenant_id.get() if not is_valid_schema_name(tenant_id): + logger.error(f"Invalid tenant ID: {tenant_id}") raise Exception("Invalid tenant ID") engine = SqlEngine.get_engine() diff --git a/backend/danswer/key_value_store/store.py b/backend/danswer/key_value_store/store.py index 450056c40..4306743f8 100644 --- a/backend/danswer/key_value_store/store.py +++ b/backend/danswer/key_value_store/store.py @@ -5,7 +5,7 @@ from typing import cast from sqlalchemy.orm import Session -from danswer.db.engine import get_session_factory +from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import KVStore from danswer.key_value_store.interface import JSON_ro from danswer.key_value_store.interface import KeyValueStore @@ -26,12 +26,9 @@ class PgRedisKVStore(KeyValueStore): @contextmanager def get_session(self) -> Iterator[Session]: - factory = get_session_factory() - session: Session = factory() - try: + engine = get_sqlalchemy_engine() + with Session(engine, expire_on_commit=False) as session: yield session - finally: - session.close() def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None: # Not encrypted in Redis, but encrypted in Postgres diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 7727f25cc..b9231a9c5 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -1,4 +1,3 @@ -import time import traceback from collections.abc import AsyncGenerator from contextlib import asynccontextmanager @@ -23,13 +22,11 @@ from danswer.auth.schemas import UserRead from danswer.auth.schemas import UserUpdate from danswer.auth.users import auth_backend from danswer.auth.users import fastapi_users -from danswer.chat.load_yamls import load_chat_yamls from danswer.configs.app_configs import APP_API_PREFIX from danswer.configs.app_configs import APP_HOST from danswer.configs.app_configs import APP_PORT from danswer.configs.app_configs import AUTH_TYPE from danswer.configs.app_configs import DISABLE_GENERATIVE_AI -from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.app_configs import LOG_ENDPOINT_LATENCY from danswer.configs.app_configs import OAUTH_CLIENT_ID from danswer.configs.app_configs import OAUTH_CLIENT_SECRET @@ -38,42 +35,9 @@ from danswer.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE from danswer.configs.app_configs import USER_AUTH_SECRET from danswer.configs.app_configs import WEB_DOMAIN from danswer.configs.constants import AuthType -from danswer.configs.constants import KV_REINDEX_KEY -from danswer.configs.constants import KV_SEARCH_SETTINGS from danswer.configs.constants import POSTGRES_WEB_APP_NAME -from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION -from danswer.configs.model_configs import GEN_AI_API_KEY -from danswer.configs.model_configs import GEN_AI_MODEL_VERSION -from danswer.db.connector import check_connectors_exist -from danswer.db.connector import create_initial_default_connector -from danswer.db.connector_credential_pair import associate_default_cc_pair -from danswer.db.connector_credential_pair import get_connector_credential_pairs -from danswer.db.connector_credential_pair import resync_cc_pair -from 
danswer.db.credentials import create_initial_public_credential -from danswer.db.document import check_docs_exist from danswer.db.engine import SqlEngine from danswer.db.engine import warm_up_connections -from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import expire_index_attempts -from danswer.db.llm import fetch_default_provider -from danswer.db.llm import update_default_provider -from danswer.db.llm import upsert_llm_provider -from danswer.db.persona import delete_old_default_personas -from danswer.db.search_settings import get_current_search_settings -from danswer.db.search_settings import get_secondary_search_settings -from danswer.db.search_settings import update_current_search_settings -from danswer.db.search_settings import update_secondary_search_settings -from danswer.db.swap_index import check_index_swap -from danswer.document_index.factory import get_default_document_index -from danswer.document_index.interfaces import DocumentIndex -from danswer.indexing.models import IndexingSetting -from danswer.key_value_store.factory import get_kv_store -from danswer.key_value_store.interface import KvKeyNotFoundError -from danswer.natural_language_processing.search_nlp_models import EmbeddingModel -from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder -from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder -from danswer.search.models import SavedSearchSettings -from danswer.search.retrieval.search_runner import download_nltk_data from danswer.server.auth_check import check_router_auth from danswer.server.danswer_api.ingestion import router as danswer_api_router from danswer.server.documents.cc_pair import router as cc_pair_router @@ -99,7 +63,6 @@ from danswer.server.manage.embedding.api import basic_router as embedding_router from danswer.server.manage.get_state import router as state_router from danswer.server.manage.llm.api import admin_router as llm_admin_router from danswer.server.manage.llm.api import basic_router as llm_router -from danswer.server.manage.llm.models import LLMProviderUpsertRequest from danswer.server.manage.search_settings import router as search_settings_router from danswer.server.manage.slack_bot import router as slack_bot_management_router from danswer.server.manage.users import router as user_router @@ -111,15 +74,10 @@ from danswer.server.query_and_chat.query_backend import ( from danswer.server.query_and_chat.query_backend import basic_router as query_router from danswer.server.settings.api import admin_router as settings_admin_router from danswer.server.settings.api import basic_router as settings_router -from danswer.server.settings.store import load_settings -from danswer.server.settings.store import store_settings from danswer.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) -from danswer.tools.built_in_tools import auto_add_search_tool_to_personas -from danswer.tools.built_in_tools import load_builtin_tools -from danswer.tools.built_in_tools import refresh_built_in_tools_cache -from danswer.utils.gpu_utils import gpu_status_request +from danswer.setup import setup_danswer from danswer.utils.logger import setup_logger from danswer.utils.telemetry import get_or_generate_uuid from danswer.utils.telemetry import optional_telemetry @@ -128,8 +86,6 @@ from danswer.utils.variable_functionality import fetch_versioned_implementation from danswer.utils.variable_functionality import global_version from 
danswer.utils.variable_functionality import set_is_ee_based_on_env_variable from shared_configs.configs import CORS_ALLOWED_ORIGIN -from shared_configs.configs import MODEL_SERVER_HOST -from shared_configs.configs import MODEL_SERVER_PORT logger = setup_logger() @@ -182,176 +138,6 @@ def include_router_with_global_prefix_prepended( application.include_router(router, **final_kwargs) -def setup_postgres(db_session: Session) -> None: - logger.notice("Verifying default connector/credential exist.") - create_initial_public_credential(db_session) - create_initial_default_connector(db_session) - associate_default_cc_pair(db_session) - - logger.notice("Loading default Prompts and Personas") - delete_old_default_personas(db_session) - load_chat_yamls() - - logger.notice("Loading built-in tools") - load_builtin_tools(db_session) - refresh_built_in_tools_cache(db_session) - auto_add_search_tool_to_personas(db_session) - - if GEN_AI_API_KEY and fetch_default_provider(db_session) is None: - # Only for dev flows - logger.notice("Setting up default OpenAI LLM for dev.") - llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini" - fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini" - model_req = LLMProviderUpsertRequest( - name="DevEnvPresetOpenAI", - provider="openai", - api_key=GEN_AI_API_KEY, - api_base=None, - api_version=None, - custom_config=None, - default_model_name=llm_model, - fast_default_model_name=fast_model, - is_public=True, - groups=[], - display_model_names=[llm_model, fast_model], - model_names=[llm_model, fast_model], - ) - new_llm_provider = upsert_llm_provider( - llm_provider=model_req, db_session=db_session - ) - update_default_provider(provider_id=new_llm_provider.id, db_session=db_session) - - -def update_default_multipass_indexing(db_session: Session) -> None: - docs_exist = check_docs_exist(db_session) - connectors_exist = check_connectors_exist(db_session) - logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}") - - if not docs_exist and not connectors_exist: - logger.info( - "No existing docs or connectors found. Checking GPU availability for multipass indexing." - ) - gpu_available = gpu_status_request() - logger.info(f"GPU available: {gpu_available}") - - current_settings = get_current_search_settings(db_session) - - logger.notice(f"Updating multipass indexing setting to: {gpu_available}") - updated_settings = SavedSearchSettings.from_db_model(current_settings) - # Enable multipass indexing if GPU is available or if using a cloud provider - updated_settings.multipass_indexing = ( - gpu_available or current_settings.cloud_provider is not None - ) - update_current_search_settings(db_session, updated_settings) - - # Update settings with GPU availability - settings = load_settings() - settings.gpu_enabled = gpu_available - store_settings(settings) - logger.notice(f"Updated settings with GPU availability: {gpu_available}") - - else: - logger.debug( - "Existing docs or connectors found. Skipping multipass indexing update." 
- ) - - -def translate_saved_search_settings(db_session: Session) -> None: - kv_store = get_kv_store() - - try: - search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS) - if isinstance(search_settings_dict, dict): - # Update current search settings - current_settings = get_current_search_settings(db_session) - - # Update non-preserved fields - if current_settings: - current_settings_dict = SavedSearchSettings.from_db_model( - current_settings - ).dict() - - new_current_settings = SavedSearchSettings( - **{**current_settings_dict, **search_settings_dict} - ) - update_current_search_settings(db_session, new_current_settings) - - # Update secondary search settings - secondary_settings = get_secondary_search_settings(db_session) - if secondary_settings: - secondary_settings_dict = SavedSearchSettings.from_db_model( - secondary_settings - ).dict() - - new_secondary_settings = SavedSearchSettings( - **{**secondary_settings_dict, **search_settings_dict} - ) - update_secondary_search_settings( - db_session, - new_secondary_settings, - ) - # Delete the KV store entry after successful update - kv_store.delete(KV_SEARCH_SETTINGS) - logger.notice("Search settings updated and KV store entry deleted.") - else: - logger.notice("KV store search settings is empty.") - except KvKeyNotFoundError: - logger.notice("No search config found in KV store.") - - -def mark_reindex_flag(db_session: Session) -> None: - kv_store = get_kv_store() - try: - value = kv_store.load(KV_REINDEX_KEY) - logger.debug(f"Re-indexing flag has value {value}") - return - except KvKeyNotFoundError: - # Only need to update the flag if it hasn't been set - pass - - # If their first deployment is after the changes, it will - # enable this when the other changes go in, need to avoid - # this being set to False, then the user indexes things on the old version - docs_exist = check_docs_exist(db_session) - connectors_exist = check_connectors_exist(db_session) - if docs_exist or connectors_exist: - kv_store.store(KV_REINDEX_KEY, True) - else: - kv_store.store(KV_REINDEX_KEY, False) - - -def setup_vespa( - document_index: DocumentIndex, - index_setting: IndexingSetting, - secondary_index_setting: IndexingSetting | None, -) -> bool: - # Vespa startup is a bit slow, so give it a few seconds - WAIT_SECONDS = 5 - VESPA_ATTEMPTS = 5 - for x in range(VESPA_ATTEMPTS): - try: - logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...") - document_index.ensure_indices_exist( - index_embedding_dim=index_setting.model_dim, - secondary_index_embedding_dim=secondary_index_setting.model_dim - if secondary_index_setting - else None, - ) - - logger.notice("Vespa setup complete.") - return True - except Exception: - logger.notice( - f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds." - ) - time.sleep(WAIT_SECONDS) - - logger.error( - f"Vespa setup did not succeed. Attempt limit reached. 
({VESPA_ATTEMPTS})" - ) - return False - - @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator: SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME) @@ -380,89 +166,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: get_or_generate_uuid() with Session(engine) as db_session: - check_index_swap(db_session=db_session) - search_settings = get_current_search_settings(db_session) - secondary_search_settings = get_secondary_search_settings(db_session) - - # Break bad state for thrashing indexes - if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP: - expire_index_attempts( - search_settings_id=search_settings.id, db_session=db_session - ) - - for cc_pair in get_connector_credential_pairs(db_session): - resync_cc_pair(cc_pair, db_session=db_session) - - # Expire all old embedding models indexing attempts, technically redundant - cancel_indexing_attempts_past_model(db_session) - - logger.notice(f'Using Embedding model: "{search_settings.model_name}"') - if search_settings.query_prefix or search_settings.passage_prefix: - logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"') - logger.notice( - f'Passage embedding prefix: "{search_settings.passage_prefix}"' - ) - - if search_settings: - if not search_settings.disable_rerank_for_streaming: - logger.notice("Reranking is enabled.") - - if search_settings.multilingual_expansion: - logger.notice( - f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." - ) - if ( - search_settings.rerank_model_name - and not search_settings.provider_type - and not search_settings.rerank_provider_type - ): - warm_up_cross_encoder(search_settings.rerank_model_name) - - logger.notice("Verifying query preprocessing (NLTK) data is downloaded") - download_nltk_data() - - # setup Postgres with default credential, llm providers, etc. - setup_postgres(db_session) - - translate_saved_search_settings(db_session) - - # Does the user need to trigger a reindexing to bring the document index - # into a good state, marked in the kv store - mark_reindex_flag(db_session) - - # ensure Vespa is setup correctly - logger.notice("Verifying Document Index(s) is/are available.") - document_index = get_default_document_index( - primary_index_name=search_settings.index_name, - secondary_index_name=secondary_search_settings.index_name - if secondary_search_settings - else None, - ) - - success = setup_vespa( - document_index, - IndexingSetting.from_db_model(search_settings), - IndexingSetting.from_db_model(secondary_search_settings) - if secondary_search_settings - else None, - ) - if not success: - raise RuntimeError( - "Could not connect to Vespa within the specified timeout." 
- ) - - logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") - if search_settings.provider_type is None: - warm_up_bi_encoder( - embedding_model=EmbeddingModel.from_db_model( - search_settings=search_settings, - server_host=MODEL_SERVER_HOST, - server_port=MODEL_SERVER_PORT, - ), - ) - - # update multipass indexing setting based on GPU availability - update_default_multipass_indexing(db_session) + setup_danswer(db_session) optional_telemetry(record_type=RecordType.VERSION, data={"version": __version__}) yield diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py index 8a35a560a..c79b9ad09 100644 --- a/backend/danswer/server/auth_check.py +++ b/backend/danswer/server/auth_check.py @@ -4,6 +4,7 @@ from fastapi import FastAPI from fastapi.dependencies.models import Dependant from starlette.routing import BaseRoute +from danswer.auth.users import control_plane_dep from danswer.auth.users import current_admin_user from danswer.auth.users import current_curator_or_admin_user from danswer.auth.users import current_user @@ -98,6 +99,7 @@ def check_router_auth( or depends_fn == current_curator_or_admin_user or depends_fn == api_key_dep or depends_fn == current_user_with_expired_token + or depends_fn == control_plane_dep ): found_auth = True break diff --git a/backend/danswer/setup.py b/backend/danswer/setup.py new file mode 100644 index 000000000..2baeda4a8 --- /dev/null +++ b/backend/danswer/setup.py @@ -0,0 +1,303 @@ +import time + +from sqlalchemy.orm import Session + +from danswer.chat.load_yamls import load_chat_yamls +from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP +from danswer.configs.constants import KV_REINDEX_KEY +from danswer.configs.constants import KV_SEARCH_SETTINGS +from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION +from danswer.configs.model_configs import GEN_AI_API_KEY +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.db.connector import check_connectors_exist +from danswer.db.connector import create_initial_default_connector +from danswer.db.connector_credential_pair import associate_default_cc_pair +from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.connector_credential_pair import resync_cc_pair +from danswer.db.credentials import create_initial_public_credential +from danswer.db.document import check_docs_exist +from danswer.db.index_attempt import cancel_indexing_attempts_past_model +from danswer.db.index_attempt import expire_index_attempts +from danswer.db.llm import fetch_default_provider +from danswer.db.llm import update_default_provider +from danswer.db.llm import upsert_llm_provider +from danswer.db.persona import delete_old_default_personas +from danswer.db.search_settings import get_current_search_settings +from danswer.db.search_settings import get_secondary_search_settings +from danswer.db.search_settings import update_current_search_settings +from danswer.db.search_settings import update_secondary_search_settings +from danswer.db.swap_index import check_index_swap +from danswer.document_index.factory import get_default_document_index +from danswer.document_index.interfaces import DocumentIndex +from danswer.indexing.models import IndexingSetting +from danswer.key_value_store.factory import get_kv_store +from danswer.key_value_store.interface import KvKeyNotFoundError +from danswer.natural_language_processing.search_nlp_models import EmbeddingModel +from 
danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder +from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder +from danswer.search.models import SavedSearchSettings +from danswer.search.retrieval.search_runner import download_nltk_data +from danswer.server.manage.llm.models import LLMProviderUpsertRequest +from danswer.server.settings.store import load_settings +from danswer.server.settings.store import store_settings +from danswer.tools.built_in_tools import auto_add_search_tool_to_personas +from danswer.tools.built_in_tools import load_builtin_tools +from danswer.tools.built_in_tools import refresh_built_in_tools_cache +from danswer.utils.gpu_utils import gpu_status_request +from danswer.utils.logger import setup_logger +from shared_configs.configs import MODEL_SERVER_HOST +from shared_configs.configs import MODEL_SERVER_PORT + +logger = setup_logger() + + +def setup_danswer(db_session: Session) -> None: + check_index_swap(db_session=db_session) + search_settings = get_current_search_settings(db_session) + secondary_search_settings = get_secondary_search_settings(db_session) + + # Break bad state for thrashing indexes + if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP: + expire_index_attempts( + search_settings_id=search_settings.id, db_session=db_session + ) + + for cc_pair in get_connector_credential_pairs(db_session): + resync_cc_pair(cc_pair, db_session=db_session) + + # Expire all old embedding models indexing attempts, technically redundant + cancel_indexing_attempts_past_model(db_session) + + logger.notice(f'Using Embedding model: "{search_settings.model_name}"') + if search_settings.query_prefix or search_settings.passage_prefix: + logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"') + logger.notice(f'Passage embedding prefix: "{search_settings.passage_prefix}"') + + if search_settings: + if not search_settings.disable_rerank_for_streaming: + logger.notice("Reranking is enabled.") + + if search_settings.multilingual_expansion: + logger.notice( + f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." + ) + if ( + search_settings.rerank_model_name + and not search_settings.provider_type + and not search_settings.rerank_provider_type + ): + warm_up_cross_encoder(search_settings.rerank_model_name) + + logger.notice("Verifying query preprocessing (NLTK) data is downloaded") + download_nltk_data() + + # setup Postgres with default credential, llm providers, etc. 
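+    # setup_postgres (defined below) seeds the default connector/credential pair, the
+    # default prompts/personas from the chat YAML configs, the built-in tools, and, for
+    # dev flows only, a preset OpenAI provider when GEN_AI_API_KEY is set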
+ setup_postgres(db_session) + + translate_saved_search_settings(db_session) + + # Does the user need to trigger a reindexing to bring the document index + # into a good state, marked in the kv store + mark_reindex_flag(db_session) + + # ensure Vespa is setup correctly + logger.notice("Verifying Document Index(s) is/are available.") + document_index = get_default_document_index( + primary_index_name=search_settings.index_name, + secondary_index_name=secondary_search_settings.index_name + if secondary_search_settings + else None, + ) + + success = setup_vespa( + document_index, + IndexingSetting.from_db_model(search_settings), + IndexingSetting.from_db_model(secondary_search_settings) + if secondary_search_settings + else None, + ) + if not success: + raise RuntimeError("Could not connect to Vespa within the specified timeout.") + + logger.notice(f"Model Server: http://{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}") + if search_settings.provider_type is None: + warm_up_bi_encoder( + embedding_model=EmbeddingModel.from_db_model( + search_settings=search_settings, + server_host=MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ), + ) + + # update multipass indexing setting based on GPU availability + update_default_multipass_indexing(db_session) + + +def translate_saved_search_settings(db_session: Session) -> None: + kv_store = get_kv_store() + + try: + search_settings_dict = kv_store.load(KV_SEARCH_SETTINGS) + if isinstance(search_settings_dict, dict): + # Update current search settings + current_settings = get_current_search_settings(db_session) + + # Update non-preserved fields + if current_settings: + current_settings_dict = SavedSearchSettings.from_db_model( + current_settings + ).dict() + + new_current_settings = SavedSearchSettings( + **{**current_settings_dict, **search_settings_dict} + ) + update_current_search_settings(db_session, new_current_settings) + + # Update secondary search settings + secondary_settings = get_secondary_search_settings(db_session) + if secondary_settings: + secondary_settings_dict = SavedSearchSettings.from_db_model( + secondary_settings + ).dict() + + new_secondary_settings = SavedSearchSettings( + **{**secondary_settings_dict, **search_settings_dict} + ) + update_secondary_search_settings( + db_session, + new_secondary_settings, + ) + # Delete the KV store entry after successful update + kv_store.delete(KV_SEARCH_SETTINGS) + logger.notice("Search settings updated and KV store entry deleted.") + else: + logger.notice("KV store search settings is empty.") + except KvKeyNotFoundError: + logger.notice("No search config found in KV store.") + + +def mark_reindex_flag(db_session: Session) -> None: + kv_store = get_kv_store() + try: + value = kv_store.load(KV_REINDEX_KEY) + logger.debug(f"Re-indexing flag has value {value}") + return + except KvKeyNotFoundError: + # Only need to update the flag if it hasn't been set + pass + + # If their first deployment is after the changes, it will + # enable this when the other changes go in, need to avoid + # this being set to False, then the user indexes things on the old version + docs_exist = check_docs_exist(db_session) + connectors_exist = check_connectors_exist(db_session) + if docs_exist or connectors_exist: + kv_store.store(KV_REINDEX_KEY, True) + else: + kv_store.store(KV_REINDEX_KEY, False) + + +def setup_vespa( + document_index: DocumentIndex, + index_setting: IndexingSetting, + secondary_index_setting: IndexingSetting | None, +) -> bool: + # Vespa startup is a bit slow, so give it a few seconds + WAIT_SECONDS = 5 + 
VESPA_ATTEMPTS = 5 + for x in range(VESPA_ATTEMPTS): + try: + logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...") + document_index.ensure_indices_exist( + index_embedding_dim=index_setting.model_dim, + secondary_index_embedding_dim=secondary_index_setting.model_dim + if secondary_index_setting + else None, + ) + + logger.notice("Vespa setup complete.") + return True + except Exception: + logger.notice( + f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds." + ) + time.sleep(WAIT_SECONDS) + + logger.error( + f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})" + ) + return False + + +def setup_postgres(db_session: Session) -> None: + logger.notice("Verifying default connector/credential exist.") + create_initial_public_credential(db_session) + create_initial_default_connector(db_session) + associate_default_cc_pair(db_session) + + logger.notice("Loading default Prompts and Personas") + delete_old_default_personas(db_session) + load_chat_yamls(db_session) + + logger.notice("Loading built-in tools") + load_builtin_tools(db_session) + refresh_built_in_tools_cache(db_session) + auto_add_search_tool_to_personas(db_session) + + if GEN_AI_API_KEY and fetch_default_provider(db_session) is None: + # Only for dev flows + logger.notice("Setting up default OpenAI LLM for dev.") + llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini" + fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini" + model_req = LLMProviderUpsertRequest( + name="DevEnvPresetOpenAI", + provider="openai", + api_key=GEN_AI_API_KEY, + api_base=None, + api_version=None, + custom_config=None, + default_model_name=llm_model, + fast_default_model_name=fast_model, + is_public=True, + groups=[], + display_model_names=[llm_model, fast_model], + model_names=[llm_model, fast_model], + ) + new_llm_provider = upsert_llm_provider( + llm_provider=model_req, db_session=db_session + ) + update_default_provider(provider_id=new_llm_provider.id, db_session=db_session) + + +def update_default_multipass_indexing(db_session: Session) -> None: + docs_exist = check_docs_exist(db_session) + connectors_exist = check_connectors_exist(db_session) + logger.debug(f"Docs exist: {docs_exist}, Connectors exist: {connectors_exist}") + + if not docs_exist and not connectors_exist: + logger.info( + "No existing docs or connectors found. Checking GPU availability for multipass indexing." + ) + gpu_available = gpu_status_request() + logger.info(f"GPU available: {gpu_available}") + + current_settings = get_current_search_settings(db_session) + + logger.notice(f"Updating multipass indexing setting to: {gpu_available}") + updated_settings = SavedSearchSettings.from_db_model(current_settings) + # Enable multipass indexing if GPU is available or if using a cloud provider + updated_settings.multipass_indexing = ( + gpu_available or current_settings.cloud_provider is not None + ) + update_current_search_settings(db_session, updated_settings) + + # Update settings with GPU availability + settings = load_settings() + settings.gpu_enabled = gpu_available + store_settings(settings) + logger.notice(f"Updated settings with GPU availability: {gpu_available}") + + else: + logger.debug( + "Existing docs or connectors found. Skipping multipass indexing update." 
+ ) diff --git a/backend/ee/danswer/main.py b/backend/ee/danswer/main.py index 7d150107c..8422d5494 100644 --- a/backend/ee/danswer/main.py +++ b/backend/ee/danswer/main.py @@ -34,6 +34,7 @@ from ee.danswer.server.query_history.api import router as query_history_router from ee.danswer.server.reporting.usage_export_api import router as usage_export_router from ee.danswer.server.saml import router as saml_router from ee.danswer.server.seeding import seed_db +from ee.danswer.server.tenants.api import router as tenants_router from ee.danswer.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) @@ -79,6 +80,8 @@ def get_application() -> FastAPI: # RBAC / group access control include_router_with_global_prefix_prepended(application, user_group_router) + # Tenant management + include_router_with_global_prefix_prepended(application, tenants_router) # Analytics endpoints include_router_with_global_prefix_prepended(application, analytics_router) include_router_with_global_prefix_prepended(application, query_history_router) diff --git a/backend/ee/danswer/server/tenants/__init__.py b/backend/ee/danswer/server/tenants/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/ee/danswer/server/tenants/access.py b/backend/ee/danswer/server/tenants/access.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/ee/danswer/server/tenants/api.py b/backend/ee/danswer/server/tenants/api.py new file mode 100644 index 000000000..ec9635185 --- /dev/null +++ b/backend/ee/danswer/server/tenants/api.py @@ -0,0 +1,46 @@ +from fastapi import APIRouter +from fastapi import Depends +from fastapi import HTTPException + +from danswer.auth.users import control_plane_dep +from danswer.configs.app_configs import MULTI_TENANT +from danswer.db.engine import get_session_with_tenant +from danswer.setup import setup_danswer +from danswer.utils.logger import setup_logger +from ee.danswer.server.tenants.models import CreateTenantRequest +from ee.danswer.server.tenants.provisioning import ensure_schema_exists +from ee.danswer.server.tenants.provisioning import run_alembic_migrations + +logger = setup_logger() +router = APIRouter(prefix="/tenants") + + +@router.post("/create") +def create_tenant( + create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep) +) -> dict[str, str]: + try: + tenant_id = create_tenant_request.tenant_id + + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") + + if not ensure_schema_exists(tenant_id): + logger.info(f"Created schema for tenant {tenant_id}") + else: + logger.info(f"Schema already exists for tenant {tenant_id}") + + run_alembic_migrations(tenant_id) + with get_session_with_tenant(tenant_id) as db_session: + setup_danswer(db_session) + + logger.info(f"Tenant {tenant_id} created successfully") + return { + "status": "success", + "message": f"Tenant {tenant_id} created successfully", + } + except Exception as e: + logger.exception(f"Failed to create tenant {tenant_id}: {str(e)}") + raise HTTPException( + status_code=500, detail=f"Failed to create tenant: {str(e)}" + ) diff --git a/backend/ee/danswer/server/tenants/models.py b/backend/ee/danswer/server/tenants/models.py new file mode 100644 index 000000000..833650c42 --- /dev/null +++ b/backend/ee/danswer/server/tenants/models.py @@ -0,0 +1,6 @@ +from pydantic import BaseModel + + +class CreateTenantRequest(BaseModel): + tenant_id: str + initial_admin_email: str diff --git 
a/backend/ee/danswer/server/tenants/provisioning.py b/backend/ee/danswer/server/tenants/provisioning.py new file mode 100644 index 000000000..62436c92e --- /dev/null +++ b/backend/ee/danswer/server/tenants/provisioning.py @@ -0,0 +1,63 @@ +import os +from types import SimpleNamespace + +from sqlalchemy import text +from sqlalchemy.orm import Session +from sqlalchemy.schema import CreateSchema + +from alembic import command +from alembic.config import Config +from danswer.db.engine import build_connection_string +from danswer.db.engine import get_sqlalchemy_engine +from danswer.utils.logger import setup_logger + +logger = setup_logger() + + +def run_alembic_migrations(schema_name: str) -> None: + logger.info(f"Starting Alembic migrations for schema: {schema_name}") + + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) + alembic_ini_path = os.path.join(root_dir, "alembic.ini") + + # Configure Alembic + alembic_cfg = Config(alembic_ini_path) + alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) + alembic_cfg.set_main_option( + "script_location", os.path.join(root_dir, "alembic") + ) + + # Mimic command-line options by adding 'cmd_opts' to the config + alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore + alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore + + # Run migrations programmatically + command.upgrade(alembic_cfg, "head") + + # Run migrations programmatically + logger.info( + f"Alembic migrations completed successfully for schema: {schema_name}" + ) + + except Exception as e: + logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") + raise + + +def ensure_schema_exists(tenant_id: str) -> bool: + with Session(get_sqlalchemy_engine()) as db_session: + with db_session.begin(): + result = db_session.execute( + text( + "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" + ), + {"schema_name": tenant_id}, + ) + schema_exists = result.scalar() is not None + if not schema_exists: + stmt = CreateSchema(tenant_id) + db_session.execute(stmt) + return True + return False diff --git a/backend/tests/integration/common_utils/reset.py b/backend/tests/integration/common_utils/reset.py index 95b3f734e..a532406c4 100644 --- a/backend/tests/integration/common_utils/reset.py +++ b/backend/tests/integration/common_utils/reset.py @@ -18,8 +18,8 @@ from danswer.db.swap_index import check_index_swap from danswer.document_index.vespa.index import DOCUMENT_ID_ENDPOINT from danswer.document_index.vespa.index import VespaIndex from danswer.indexing.models import IndexingSetting -from danswer.main import setup_postgres -from danswer.main import setup_vespa +from danswer.setup import setup_postgres +from danswer.setup import setup_vespa from danswer.utils.logger import setup_logger logger = setup_logger() diff --git a/web/src/components/search/SearchSection.tsx b/web/src/components/search/SearchSection.tsx index b171b2bcb..4e8ea0abd 100644 --- a/web/src/components/search/SearchSection.tsx +++ b/web/src/components/search/SearchSection.tsx @@ -14,6 +14,7 @@ import { ValidQuestionResponse, Relevance, SearchDanswerDocument, + SourceMetadata, } from "@/lib/search/interfaces"; import { searchRequestStreamed } from "@/lib/search/streamingQa"; import { CancellationToken, cancellable } from "@/lib/search/cancellable"; @@ -40,6 +41,9 @@ import { ApiKeyModal } from "../llm/ApiKeyModal"; import { useSearchContext } from 
"../context/SearchContext"; import { useUser } from "../user/UserProvider"; import UnconfiguredProviderText from "../chat_search/UnconfiguredProviderText"; +import { DateRangePickerValue } from "@tremor/react"; +import { Tag } from "@/lib/types"; +import { isEqual } from "lodash"; export type searchState = | "input" @@ -370,8 +374,36 @@ export const SearchSection = ({ setSearchAnswerExpanded(false); }; - const [previousSearch, setPreviousSearch] = useState(""); + interface SearchDetails { + query: string; + sources: SourceMetadata[]; + agentic: boolean; + documentSets: string[]; + timeRange: DateRangePickerValue | null; + tags: Tag[]; + persona: Persona; + } + + const [previousSearch, setPreviousSearch] = useState( + null + ); const [agenticResults, setAgenticResults] = useState(null); + const currentSearch = (overrideMessage?: string): SearchDetails => { + return { + query: overrideMessage || query, + sources: filterManager.selectedSources, + agentic: agentic!, + documentSets: filterManager.selectedDocumentSets, + timeRange: filterManager.timeRange, + tags: filterManager.selectedTags, + persona: assistants.find( + (assistant) => assistant.id === selectedPersona + ) as Persona, + }; + }; + const isSearchChanged = () => { + return !isEqual(currentSearch(), previousSearch); + }; let lastSearchCancellationToken = useRef(null); const onSearch = async ({ @@ -394,7 +426,9 @@ export const SearchSection = ({ setIsFetching(true); setSearchResponse(initialSearchResponse); - setPreviousSearch(overrideMessage || query); + + setPreviousSearch(currentSearch(overrideMessage)); + const searchFnArgs = { query: overrideMessage || query, sources: filterManager.selectedSources, @@ -761,7 +795,7 @@ export const SearchSection = ({ />