diff --git a/backend/alembic/env.py b/backend/alembic/env.py index f0b932f7e..556f31496 100644 --- a/backend/alembic/env.py +++ b/backend/alembic/env.py @@ -31,18 +31,15 @@ def get_schema_options() -> Tuple[str, bool]: if '=' in pair: key, value = pair.split('=', 1) x_args[key] = value - print(f"x_args: {x_args}") # For debugging - schema_name = x_args.get('schema', 'public') # Default schema - create_schema = x_args.get('create_schema', 'false').lower() == 'true' - return schema_name, create_schema + schema_name = x_args.get('schema', 'public') + return schema_name + def run_migrations_offline() -> None: """Run migrations in 'offline' mode.""" url = build_connection_string() - schema, create_schema = get_schema_options() + schema = get_schema_options() - if create_schema: - raise RuntimeError("Cannot create schema in offline mode. Please run migrations online to create the schema.") context.configure( url=url, @@ -57,14 +54,11 @@ def run_migrations_offline() -> None: context.run_migrations() def do_run_migrations(connection: Connection) -> None: - schema, create_schema = get_schema_options() + schema = get_schema_options() - if create_schema: - # Use text() to create a proper SQL expression - connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"')) - connection.execute(text('COMMIT')) + connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"')) + connection.execute(text('COMMIT')) - # Set the search_path to the target schema connection.execute(text(f'SET search_path TO "{schema}"')) context.configure( diff --git a/backend/danswer/auth/users.py b/backend/danswer/auth/users.py index 085afdc96..f0af5db4a 100644 --- a/backend/danswer/auth/users.py +++ b/backend/danswer/auth/users.py @@ -241,7 +241,7 @@ async def get_or_create_user(email: str, user_id: str) -> User: async def create_user_session(user: User, tenant_id: str) -> str: - # Create a payload with user information and tenant_id + # Create a session payload with user information and tenant_id payload = { "sub": str(user.id), "email": user.email, @@ -261,7 +261,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): async def sso_authenticate( self, email: str, - user_id: str, tenant_id: str, ) -> User: try: diff --git a/backend/danswer/background/celery/celery_app.py b/backend/danswer/background/celery/celery_app.py index 9d787fe12..a013fb9da 100644 --- a/backend/danswer/background/celery/celery_app.py +++ b/backend/danswer/background/celery/celery_app.py @@ -70,7 +70,7 @@ _SYNC_BATCH_SIZE = 100 def cleanup_connector_credential_pair_task( connector_id: int, credential_id: int, - tenant_id: str + tenant_id: str | None ) -> int: """Connector deletion task. This is run as an async task because it is a somewhat slow job. Needs to potentially update a large number of Postgres and Vespa docs, including deleting them @@ -114,7 +114,7 @@ def cleanup_connector_credential_pair_task( @build_celery_task_wrapper(name_cc_prune_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) -def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str) -> None: +def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None: """connector pruning task.
For a cc pair, this task pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing from the most recently pulled document ID list""" @@ -183,7 +183,7 @@ def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str) @build_celery_task_wrapper(name_document_set_sync_task) @celery_app.task(soft_time_limit=JOB_TIMEOUT) -def sync_document_set_task(document_set_id: int, tenant_id: str) -> None: +def sync_document_set_task(document_set_id: int, tenant_id: str | None) -> None: """For document sets marked as not up to date, sync the state from postgres into the datastore. Also handles deletions.""" @@ -267,7 +267,7 @@ def sync_document_set_task(document_set_id: int, tenant_id: str) -> None: name="check_for_document_sets_sync_task", soft_time_limit=JOB_TIMEOUT, ) -def check_for_document_sets_sync_task(tenant_id: str) -> None: +def check_for_document_sets_sync_task(tenant_id: str | None) -> None: """Runs periodically to check if any sync tasks should be run and adds them to the queue""" with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session: @@ -287,7 +287,7 @@ def check_for_document_sets_sync_task(tenant_id: str) -> None: name="check_for_cc_pair_deletion_task", soft_time_limit=JOB_TIMEOUT, ) -def check_for_cc_pair_deletion_task(tenant_id: str) -> None: +def check_for_cc_pair_deletion_task(tenant_id: str | None) -> None: """Runs periodically to check if any deletion tasks should be run""" with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session: # check if any document sets are not synced @@ -310,7 +310,7 @@ def check_for_cc_pair_deletion_task(tenant_id: str) -> None: bind=True, base=AbortableTask, ) -def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: +def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int: """Runs periodically to clean up the kombu_message table""" # we will select messages older than this amount to clean up @@ -423,7 +423,7 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: name="check_for_prune_task", soft_time_limit=JOB_TIMEOUT, ) -def check_for_prune_task(tenant_id: str) -> None: +def check_for_prune_task(tenant_id: str | None) -> None: """Runs periodically to check if any prune tasks should be run and adds them to the queue""" @@ -455,7 +455,7 @@ def schedule_tenant_tasks(): if MULTI_TENANT: tenants = get_all_tenant_ids() else: - tenants = ['public'] # Default tenant in single-tenancy mode + tenants = [None] # Filter out any invalid tenants if necessary valid_tenants = [tenant for tenant in tenants if not tenant.startswith('pg_')] diff --git a/backend/danswer/background/indexing/run_indexing.py b/backend/danswer/background/indexing/run_indexing.py index da38cd646..67536e0ed 100644 --- a/backend/danswer/background/indexing/run_indexing.py +++ b/backend/danswer/background/indexing/run_indexing.py @@ -5,9 +5,7 @@ from datetime import datetime from datetime import timedelta from datetime import timezone -from danswer.db.engine import current_tenant_id from sqlalchemy.orm import Session - from danswer.db.engine import get_sqlalchemy_engine from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt from danswer.background.indexing.tracer import DanswerTracer @@ -46,7 +44,7 @@ def _get_connector_runner( attempt: IndexAttempt, start_time: datetime, end_time: datetime, - tenant_id: str | None = None + tenant_id: str | None ) -> ConnectorRunner: """ NOTE: `start_time` and 
`end_time` are only used for poll connectors @@ -86,7 +84,7 @@ def _get_connector_runner( def _run_indexing( db_session: Session, index_attempt: IndexAttempt, - tenant_id: str + tenant_id: str | None ) -> None: """ 1. Get documents which are either new or updated from specified application @@ -176,6 +174,7 @@ def _run_indexing( attempt=index_attempt, start_time=window_start, end_time=window_end, + tenant_id=tenant_id ) all_connector_doc_ids: set[str] = set() @@ -390,12 +389,11 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA return attempt -def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str, is_ee: bool = False) -> None: +def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str | None, is_ee: bool = False) -> None: try: if is_ee: global_version.set_ee() - current_tenant_id.set(tenant_id) IndexAttemptSingleton.set_index_attempt_id(index_attempt_id) with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session: diff --git a/backend/danswer/background/task_utils.py b/backend/danswer/background/task_utils.py index 10f7f95d0..3a1369c94 100644 --- a/backend/danswer/background/task_utils.py +++ b/backend/danswer/background/task_utils.py @@ -14,8 +14,11 @@ from danswer.db.tasks import mark_task_start from danswer.db.tasks import register_task -def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str = "") -> str: - return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}_{tenant_id}" +def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str | None) -> str: + task_name = f"cleanup_connector_credential_pair_{connector_id}_{credential_id}" + if tenant_id is not None: + task_name += f"_{tenant_id}" + return task_name def name_document_set_sync_task(document_set_id: int) -> str: diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index 8504c3cb0..695de8481 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -1,15 +1,20 @@ -from sqlalchemy import text +import logging import time -from sqlalchemy.exc import ProgrammingError +from datetime import datetime + import dask from dask.distributed import Client from dask.distributed import Future +from distributed import LocalCluster from sqlalchemy.orm import Session +from sqlalchemy import text +from danswer.background.indexing.dask_utils import ResourceLogger from danswer.background.indexing.job_client import SimpleJob from danswer.background.indexing.job_client import SimpleJobClient from danswer.background.indexing.run_indexing import run_indexing_entrypoint from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT +from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from danswer.configs.app_configs import NUM_INDEXING_WORKERS from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS @@ -33,10 +38,16 @@ from danswer.db.models import SearchSettings from danswer.db.search_settings import get_current_search_settings from danswer.db.search_settings import get_secondary_search_settings from danswer.db.swap_index import check_index_swap +from danswer.natural_language_processing.search_nlp_models import EmbeddingModel +from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder from danswer.utils.logger import setup_logger from danswer.utils.variable_functionality import global_version from danswer.utils.variable_functionality 
import set_is_ee_based_on_env_variable -from danswer.db.engine import current_tenant_id +from shared_configs.configs import INDEXING_MODEL_SERVER_HOST +from shared_configs.configs import LOG_LEVEL +from shared_configs.configs import MODEL_SERVER_PORT +from danswer.configs.app_configs import MULTI_TENANT +from sqlalchemy.exc import ProgrammingError logger = setup_logger() @@ -134,9 +145,7 @@ def _mark_run_failed( """Main funcs""" -def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str) -> None: - current_tenant_id.set(tenant_id) - +def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None) -> None: """Creates new indexing jobs for each connector / credential pair which is: 1. Enabled @@ -200,10 +209,9 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id def cleanup_indexing_jobs( existing_jobs: dict[int, Future | SimpleJob], - tenant_id: str, + tenant_id: str | None, timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT, ) -> dict[int, Future | SimpleJob]: - current_tenant_id.set(tenant_id) existing_jobs_copy = existing_jobs.copy() # clean up completed jobs @@ -247,7 +255,6 @@ def cleanup_indexing_jobs( # clean up in-progress jobs that were never completed try: connectors = fetch_connectors(db_session) - logger.info(f"len(connectors): {len(connectors)}") for connector in connectors: in_progress_indexing_attempts = get_inprogress_index_attempts( connector.id, db_session @@ -264,7 +271,6 @@ def cleanup_indexing_jobs( # batch of documents indexed current_db_time = get_db_current_time(db_session=db_session) time_since_update = current_db_time - index_attempt.time_updated - logger.info("ERRORS 1") if time_since_update.total_seconds() > 60 * 60 * timeout_hours: existing_jobs[index_attempt.id].cancel() _mark_run_failed( @@ -274,8 +280,6 @@ def cleanup_indexing_jobs( "The run will be re-attempted at next scheduled indexing time.", ) else: - logger.info(f"ERRORS 2 {tenant_id} {len(existing_jobs)}") - continue # If job isn't known, simply mark it as failed _mark_run_failed( db_session=db_session, @@ -292,9 +296,8 @@ def kickoff_indexing_jobs( existing_jobs: dict[int, Future | SimpleJob], client: Client | SimpleJobClient, secondary_client: Client | SimpleJobClient, - tenant_id: str, + tenant_id: str | None, ) -> dict[int, Future | SimpleJob]: - current_tenant_id.set(tenant_id) existing_jobs_copy = existing_jobs.copy() engine = get_sqlalchemy_engine(schema=tenant_id) @@ -387,6 +390,7 @@ def kickoff_indexing_jobs( return existing_jobs_copy + def get_all_tenant_ids() -> list[str]: with Session(get_sqlalchemy_engine(schema='public')) as session: result = session.execute(text(""" @@ -397,63 +401,100 @@ def get_all_tenant_ids() -> list[str]: tenant_ids = [row[0] for row in result] return tenant_ids + def update_loop( delay: int = 10, num_workers: int = NUM_INDEXING_WORKERS, num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS, ) -> None: - client_primary = Client(n_workers=num_workers) - client_secondary = Client(n_workers=num_secondary_workers) - try: - while True: - tenants = get_all_tenant_ids() + client_primary: Client | SimpleJobClient + client_secondary: Client | SimpleJobClient + if DASK_JOB_CLIENT_ENABLED: + cluster_primary = LocalCluster( + n_workers=num_workers, + threads_per_worker=1, + silence_logs=logging.ERROR, + ) + cluster_secondary = LocalCluster( + n_workers=num_secondary_workers, + threads_per_worker=1, + silence_logs=logging.ERROR, + ) + client_primary = Client(cluster_primary) + 
client_secondary = Client(cluster_secondary) + if LOG_LEVEL.lower() == "debug": + client_primary.register_worker_plugin(ResourceLogger()) + else: + client_primary = SimpleJobClient(n_workers=num_workers) + client_secondary = SimpleJobClient(n_workers=num_secondary_workers) - valid_tenants = [tenant for tenant in tenants if not tenant.startswith('pg_')] - logger.info(f"Found valid tenants: {valid_tenants}") - tenants = valid_tenants + existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {} + + logger.notice("Startup complete. Waiting for indexing jobs...") + while True: + start = time.time() + start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S") + logger.debug(f"Running update, current UTC time: {start_time_utc}") + + if existing_jobs: + logger.debug( + "Found existing indexing jobs: " + f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}" + ) + + try: + tenants = get_all_tenant_ids() if MULTI_TENANT else [None] + tenants = [tenant for tenant in tenants if not tenant.startswith('pg_')] if MULTI_TENANT else tenants + if MULTI_TENANT: + logger.info(f"Found valid tenants: {tenants}") for tenant_id in tenants: try: - logger.debug(f"Processing tenant: {tenant_id}") - with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session: - try: - check_index_swap(db_session) - except ProgrammingError: - pass + logger.debug(f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}") + engine = get_sqlalchemy_engine(schema=tenant_id) + with Session(engine) as db_session: + check_index_swap(db_session=db_session) + if not MULTI_TENANT: + search_settings = get_current_search_settings(db_session) + if search_settings.provider_type is None: + logger.notice("Running a first inference to warm up embedding model") + embedding_model = EmbeddingModel.from_db_model( + search_settings=search_settings, + server_host=INDEXING_MODEL_SERVER_HOST, + server_port=MODEL_SERVER_PORT, + ) + warm_up_bi_encoder(embedding_model=embedding_model) + logger.notice("First inference complete.") - # Initialize or retrieve existing jobs per tenant - existing_jobs: dict[int, Future | SimpleJob] = {} + tenant_jobs = existing_jobs.get(tenant_id, {}) - # Perform cleanup, job creation, and kickoff for this tenant - existing_jobs = cleanup_indexing_jobs( - existing_jobs=existing_jobs, + tenant_jobs = cleanup_indexing_jobs( + existing_jobs=tenant_jobs, tenant_id=tenant_id ) create_indexing_jobs( - existing_jobs=existing_jobs, + existing_jobs=tenant_jobs, tenant_id=tenant_id ) - - logger.debug(f"Indexing Jobs are {len(existing_jobs)} many") - - existing_jobs = kickoff_indexing_jobs( - existing_jobs=existing_jobs, + tenant_jobs = kickoff_indexing_jobs( + existing_jobs=tenant_jobs, client=client_primary, secondary_client=client_secondary, tenant_id=tenant_id, ) + existing_jobs[tenant_id] = tenant_jobs + except Exception as e: - logger.exception(f"Failed to process tenant {tenant_id}: {e}") + logger.exception(f"Failed to process tenant {tenant_id or 'default'}: {e}") - sleep_time = delay + except Exception as e: + logger.exception(f"Failed to run update due to {e}") + + sleep_time = delay - (time.time() - start) + if sleep_time > 0: time.sleep(sleep_time) - finally: - client_primary.close() - client_secondary.close() - - def update__main() -> None: set_is_ee_based_on_env_variable() init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME) diff --git a/backend/danswer/db_setup.py b/backend/danswer/db_setup.py index aa47b2caa..b084a8956 100644 --- 
a/backend/danswer/db_setup.py +++ b/backend/danswer/db_setup.py @@ -1,16 +1,4 @@ -from danswer.search.retrieval.search_runner import download_nltk_data -from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder -from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP -from danswer.db.connector_credential_pair import get_connector_credential_pairs -from danswer.db.connector_credential_pair import resync_cc_pair -from danswer.db.index_attempt import cancel_indexing_attempts_past_model -from danswer.db.index_attempt import expire_index_attempts -from danswer.db.search_settings import get_current_search_settings -from danswer.db.search_settings import get_secondary_search_settings -from danswer.db.swap_index import check_index_swap - -from sqlalchemy.orm import Session from danswer.llm.llm_initialization import load_llm_providers from danswer.db.connector import create_initial_default_connector from danswer.db.connector_credential_pair import associate_default_cc_pair @@ -22,81 +10,10 @@ from danswer.tools.built_in_tools import auto_add_search_tool_to_personas from danswer.tools.built_in_tools import load_builtin_tools from danswer.tools.built_in_tools import refresh_built_in_tools_cache from danswer.utils.logger import setup_logger +from sqlalchemy.orm import Session logger = setup_logger() -def setup_postgres_and_initial_settings(db_session: Session) -> None: - - - check_index_swap(db_session=db_session) - search_settings = get_current_search_settings(db_session) - secondary_search_settings = get_secondary_search_settings(db_session) - - # Break bad state for thrashing indexes - if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP: - expire_index_attempts( - search_settings_id=search_settings.id, db_session=db_session - ) - - for cc_pair in get_connector_credential_pairs(db_session): - resync_cc_pair(cc_pair, db_session=db_session) - - # Expire all old embedding models indexing attempts, technically redundant - cancel_indexing_attempts_past_model(db_session) - - logger.notice(f'Using Embedding model: "{search_settings.model_name}"') - if search_settings.query_prefix or search_settings.passage_prefix: - logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"') - logger.notice( - f'Passage embedding prefix: "{search_settings.passage_prefix}"' - ) - - if search_settings: - if not search_settings.disable_rerank_for_streaming: - logger.notice("Reranking is enabled.") - - if search_settings.multilingual_expansion: - logger.notice( - f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." - ) - - if search_settings.rerank_model_name and not search_settings.provider_type: - warm_up_cross_encoder(search_settings.rerank_model_name) - - logger.notice("Verifying query preprocessing (NLTK) data is downloaded") - download_nltk_data() - - # setup Postgres with default credential, llm providers, etc. 
- setup_postgres(db_session) - - # Does the user need to trigger a reindexing to bring the document index - # into a good state, marked in the kv store - - # ensure Vespa is setup correctly - logger.notice("Verifying Document Index(s) is/are available.") - - - logger.notice("Verifying default connector/credential exist.") - create_initial_public_credential(db_session) - create_initial_default_connector(db_session) - associate_default_cc_pair(db_session) - - logger.notice("Verifying default standard answer category exists.") - create_initial_default_standard_answer_category(db_session) - - logger.notice("Loading LLM providers from env variables") - load_llm_providers(db_session) - - logger.notice("Loading default Prompts and Personas") - delete_old_default_personas(db_session) - load_chat_yamls(db_session) - - logger.notice("Loading built-in tools") - load_builtin_tools(db_session) - refresh_built_in_tools_cache(db_session) - auto_add_search_tool_to_personas(db_session) - - def setup_postgres(db_session: Session) -> None: logger.notice("Verifying default connector/credential exist.") create_initial_public_credential(db_session) diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 74ee23576..8a1a2f9e7 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -273,6 +273,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: # Break bad state for thrashing indexes if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP: + expire_index_attempts( search_settings_id=search_settings.id, db_session=db_session ) @@ -285,12 +286,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator: logger.notice(f'Using Embedding model: "{search_settings.model_name}"') if search_settings.query_prefix or search_settings.passage_prefix: + logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"') logger.notice( f'Passage embedding prefix: "{search_settings.passage_prefix}"' ) if search_settings: + if not search_settings.disable_rerank_for_streaming: logger.notice("Reranking is enabled.") diff --git a/backend/danswer/search/preprocessing/preprocessing.py b/backend/danswer/search/preprocessing/preprocessing.py index d592bafdd..506b0ebd5 100644 --- a/backend/danswer/search/preprocessing/preprocessing.py +++ b/backend/danswer/search/preprocessing/preprocessing.py @@ -29,7 +29,8 @@ from danswer.utils.logger import setup_logger from danswer.utils.threadpool_concurrency import FunctionCall from danswer.utils.threadpool_concurrency import run_functions_in_parallel from danswer.utils.timing import log_function_time - +from danswer.configs.app_configs import MULTI_TENANT +from danswer.db.engine import current_tenant_id logger = setup_logger() @@ -152,14 +153,14 @@ def retrieval_preprocessing( None if bypass_acl else build_access_filters_for_user(user, db_session) ) + final_filters = IndexFilters( source_type=preset_filters.source_type or predicted_source_filters, document_set=preset_filters.document_set, time_cutoff=preset_filters.time_cutoff or predicted_time_cutoff, tags=preset_filters.tags, # Tags are never auto-extracted access_control_list=user_acl_filters, - tenant_id="0a4bae55-27d4-406a-8c12-d3827af22e42" - # tenant_id="84d01c94-7032-4dfa-9a1e-0d80805c7826" # TODO FIX + tenant_id=current_tenant_id.get() if MULTI_TENANT else None, ) llm_evaluation_type = LLMEvaluationType.BASIC diff --git a/backend/danswer/server/auth_check.py b/backend/danswer/server/auth_check.py index 06dffdee9..6f0952f19 100644 --- a/backend/danswer/server/auth_check.py +++ 
b/backend/danswer/server/auth_check.py @@ -42,7 +42,7 @@ PUBLIC_ENDPOINT_SPECS = [ ("/users/{id}", {"DELETE"}), # oauth ("/auth/oauth/authorize", {"GET"}), - ("/auth/oaute a tsth/callback", {"GET"}), + ("/auth/oauth/callback", {"GET"}), # tenant service related (must use API key) ("/tenants/create", {"POST"}), ] diff --git a/backend/danswer/server/tenants/api.py b/backend/danswer/server/tenants/api.py index 1b2d3912c..e3813c6ed 100644 --- a/backend/danswer/server/tenants/api.py +++ b/backend/danswer/server/tenants/api.py @@ -1,14 +1,8 @@ from fastapi import APIRouter from fastapi import Depends from sqlalchemy.orm import Session - -import os -from danswer.db_setup import setup_postgres_and_initial_settings from fastapi import Body from danswer.db.engine import get_sqlalchemy_engine - -from typing import Any -from typing import Callable from danswer.auth.users import create_user_session from danswer.auth.users import get_user_manager from danswer.auth.users import UserManager @@ -16,138 +10,45 @@ from danswer.auth.users import verify_sso_token from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from danswer.utils.logger import setup_logger from fastapi.responses import JSONResponse -from danswer.db.engine import get_async_session -import contextlib -from fastapi import HTTPException, Request -from sqlalchemy import text -from alembic import command +from fastapi import HTTPException +from ee.danswer.auth.users import control_plane_dep +from danswer.server.tenants.provisioning import setup_postgres_and_initial_settings +from danswer.server.tenants.provisioning import check_schema_exists +from danswer.server.tenants.provisioning import run_alembic_migrations +from danswer.server.tenants.provisioning import create_tenant_schema +from danswer.configs.app_configs import MULTI_TENANT -from danswer.db.engine import build_connection_string - -from alembic.config import Config -from functools import wraps -import jwt -DATA_PLANE_SECRET = "your_shared_secret_key" -EXPECTED_API_KEY = "your_control_plane_api_key" logger = setup_logger() - basic_router = APIRouter(prefix="/tenants") -def run_alembic_migrations(schema_name: str, create_schema: bool = True) -> None: - logger.info(f"Starting Alembic migrations for schema: {schema_name}") - try: - current_dir = os.path.dirname(os.path.abspath(__file__)) - root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) - alembic_ini_path = os.path.join(root_dir, 'alembic.ini') - - # Configure Alembic - alembic_cfg = Config(alembic_ini_path) - alembic_cfg.set_main_option('sqlalchemy.url', build_connection_string()) - - # Prepare the x arguments - x_arg_schema = f"schema={schema_name}" - x_arg_create_schema = f"create_schema={'true' if create_schema else 'false'}" - x_arguments = [x_arg_schema, x_arg_create_schema] - - # Set the x arguments into the Alembic context - alembic_cfg.cmd_opts = lambda: None # Create a dummy object - alembic_cfg.cmd_opts.x = x_arguments - - # Run migrations programmatically - command.upgrade(alembic_cfg, 'head') - - logger.info(f"Alembic migrations completed successfully for schema: {schema_name}") - - except Exception as e: - logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") - raise - - -def authenticate_request(func: Callable) -> Callable: - @wraps(func) - def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any: - auth_header = request.headers.get("Authorization") - api_key = request.headers.get("X-API-KEY") - - if api_key != EXPECTED_API_KEY: - raise 
HTTPException(status_code=401, detail="Invalid API key") - - if not auth_header or not auth_header.startswith("Bearer "): - raise HTTPException(status_code=401, detail="Invalid authorization header") - - token = auth_header.split(" ")[1] - try: - payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"]) - if payload.get("scope") != "tenant:create": - raise HTTPException(status_code=403, detail="Insufficient permissions") - except jwt.ExpiredSignatureError: - raise HTTPException(status_code=401, detail="Token has expired") - except jwt.InvalidTokenError: - raise HTTPException(status_code=401, detail="Invalid token") - - return func(request, *args, **kwargs) - return wrapper - - @basic_router.post("/create") -@authenticate_request -def create_tenant(request: Request, tenant_id: str) -> dict[str, str]: +def create_tenant(tenant_id: str, _ = Depends(control_plane_dep)) -> dict[str, str]: + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenant is not enabled") + if not tenant_id: raise HTTPException(status_code=400, detail="tenant_id is required") - logger.info(f"Creating tenant schema: {tenant_id}") - + create_tenant_schema(tenant_id) + run_alembic_migrations(tenant_id) with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session: - with db_session.begin(): - result = db_session.execute( - text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"), - {"schema_name": tenant_id} - ) - schema_exists = result.scalar() is not None - - if not schema_exists: - db_session.execute(text(f'CREATE SCHEMA "{tenant_id}"')) - logger.info(f"Schema {tenant_id} created") - else: - logger.info(f"Schema {tenant_id} already exists") - - try: - run_alembic_migrations(tenant_id) - logger.info(f"Migrations completed for tenant: {tenant_id}") - except Exception as e: - logger.exception(f"Error running migrations for tenant {tenant_id}: {str(e)}") - raise HTTPException(status_code=500, detail=str(e)) - - try: - with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session: - setup_postgres_and_initial_settings(db_session) - except Exception as e: - logger.exception(f"Error setting up postgres for tenant {tenant_id}: {str(e)}") - raise + setup_postgres_and_initial_settings(db_session) logger.info(f"Tenant {tenant_id} created successfully") return {"status": "success", "message": f"Tenant {tenant_id} created successfully"} -async def check_schema_exists(tenant_id: str) -> bool: - get_async_session_context = contextlib.asynccontextmanager( - get_async_session - ) - async with get_async_session_context() as session: - result = await session.execute( - text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"), - {"schema_name": tenant_id} - ) - return result.scalar() is not None @basic_router.post("/auth/sso-callback") async def sso_callback( - request: Request, sso_token: str = Body(..., embed=True), user_manager: UserManager = Depends(get_user_manager), ) -> JSONResponse: + if not MULTI_TENANT: + raise HTTPException(status_code=403, detail="Multi-tenant is not enabled") + payload = verify_sso_token(sso_token) user = await user_manager.sso_authenticate( - payload["email"], payload["user_id"], payload["tenant_id"] + payload["email"], payload["tenant_id"] ) tenant_id = payload["tenant_id"] @@ -164,7 +65,7 @@ async def sso_callback( max_age=SESSION_EXPIRE_TIME_SECONDS, expires=SESSION_EXPIRE_TIME_SECONDS, path="/", - secure=False, # Set to True in production with HTTPS + secure=False, httponly=True, 
samesite="lax", ) diff --git a/backend/danswer/server/tenants/provisioning.py b/backend/danswer/server/tenants/provisioning.py new file mode 100644 index 000000000..49b1805e8 --- /dev/null +++ b/backend/danswer/server/tenants/provisioning.py @@ -0,0 +1,160 @@ +import contextlib +from danswer.search.retrieval.search_runner import download_nltk_data +from danswer.db.engine import get_async_session +from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder +from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP +from danswer.db.connector_credential_pair import get_connector_credential_pairs +from danswer.db.connector_credential_pair import resync_cc_pair +from danswer.db.index_attempt import cancel_indexing_attempts_past_model +from danswer.db.index_attempt import expire_index_attempts +from danswer.db.search_settings import get_current_search_settings +from danswer.db.search_settings import get_secondary_search_settings +from danswer.db.swap_index import check_index_swap + +from sqlalchemy.orm import Session +from danswer.llm.llm_initialization import load_llm_providers +from danswer.db.connector import create_initial_default_connector +from danswer.db.connector_credential_pair import associate_default_cc_pair +from danswer.db.credentials import create_initial_public_credential +from danswer.db.standard_answer import create_initial_default_standard_answer_category +from danswer.db.persona import delete_old_default_personas +from danswer.chat.load_yamls import load_chat_yamls +from danswer.tools.built_in_tools import auto_add_search_tool_to_personas +from danswer.tools.built_in_tools import load_builtin_tools +from danswer.tools.built_in_tools import refresh_built_in_tools_cache +from danswer.utils.logger import setup_logger +from danswer.db.engine import get_sqlalchemy_engine +from sqlalchemy.schema import CreateSchema +from sqlalchemy import text +from alembic.config import Config +from alembic import command +from danswer.db.engine import build_connection_string +import os +from danswer.db_setup import setup_postgres + +DATA_PLANE_SECRET = "your_shared_secret_key" +EXPECTED_API_KEY = "your_control_plane_api_key" + +logger = setup_logger() + +def run_alembic_migrations(schema_name: str) -> None: + logger.info(f"Starting Alembic migrations for schema: {schema_name}") + try: + current_dir = os.path.dirname(os.path.abspath(__file__)) + root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..')) + alembic_ini_path = os.path.join(root_dir, 'alembic.ini') + + # Configure Alembic + alembic_cfg = Config(alembic_ini_path) + alembic_cfg.set_main_option('sqlalchemy.url', build_connection_string()) + + # Prepare the x arguments + x_arguments = [f"schema={schema_name}"] + alembic_cfg.cmd_opts.x = x_arguments + + # Run migrations programmatically + command.upgrade(alembic_cfg, 'head') + + logger.info(f"Alembic migrations completed successfully for schema: {schema_name}") + + except Exception as e: + logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") + raise + +def create_tenant_schema(tenant_id: str) -> None: + with Session(get_sqlalchemy_engine()) as db_session: + with db_session.begin(): + result = db_session.execute( + text(""" + SELECT schema_name + FROM information_schema.schemata + WHERE schema_name = :schema_name + """), + {"schema_name": tenant_id} + ) + schema_exists = result.scalar() is not None + + if not schema_exists: + db_session.execute(CreateSchema(tenant_id)) + logger.info(f"Schema {tenant_id} created") 
+ else: + logger.info(f"Schema {tenant_id} already exists") + + +def setup_postgres_and_initial_settings(db_session: Session) -> None: + check_index_swap(db_session=db_session) + search_settings = get_current_search_settings(db_session) + secondary_search_settings = get_secondary_search_settings(db_session) + + # Break bad state for thrashing indexes + if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP: + expire_index_attempts( + search_settings_id=search_settings.id, db_session=db_session + ) + + for cc_pair in get_connector_credential_pairs(db_session): + resync_cc_pair(cc_pair, db_session=db_session) + + # Expire all old embedding models indexing attempts, technically redundant + cancel_indexing_attempts_past_model(db_session) + + logger.notice(f'Using Embedding model: "{search_settings.model_name}"') + if search_settings.query_prefix or search_settings.passage_prefix: + logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"') + logger.notice( + f'Passage embedding prefix: "{search_settings.passage_prefix}"' + ) + + if search_settings: + if not search_settings.disable_rerank_for_streaming: + logger.notice("Reranking is enabled.") + + if search_settings.multilingual_expansion: + logger.notice( + f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}." + ) + + if search_settings.rerank_model_name and not search_settings.provider_type: + warm_up_cross_encoder(search_settings.rerank_model_name) + + logger.notice("Verifying query preprocessing (NLTK) data is downloaded") + download_nltk_data() + + # setup Postgres with default credentials, llm providers, etc. + setup_postgres(db_session) + + # ensure Vespa is setup correctly + logger.notice("Verifying Document Index(s) is/are available.") + + + logger.notice("Verifying default connector/credential exist.") + create_initial_public_credential(db_session) + create_initial_default_connector(db_session) + associate_default_cc_pair(db_session) + + logger.notice("Verifying default standard answer category exists.") + create_initial_default_standard_answer_category(db_session) + + logger.notice("Loading LLM providers from env variables") + load_llm_providers(db_session) + + logger.notice("Loading default Prompts and Personas") + delete_old_default_personas(db_session) + load_chat_yamls(db_session) + + logger.notice("Loading built-in tools") + load_builtin_tools(db_session) + refresh_built_in_tools_cache(db_session) + auto_add_search_tool_to_personas(db_session) + + +async def check_schema_exists(tenant_id: str) -> bool: + get_async_session_context = contextlib.asynccontextmanager( + get_async_session + ) + async with get_async_session_context() as session: + result = await session.execute( + text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"), + {"schema_name": tenant_id} + ) + return result.scalar() is not None diff --git a/backend/ee/danswer/auth/users.py b/backend/ee/danswer/auth/users.py index 18dff6ab0..9edb0515b 100644 --- a/backend/ee/danswer/auth/users.py +++ b/backend/ee/danswer/auth/users.py @@ -13,15 +13,44 @@ from ee.danswer.db.api_key import fetch_user_for_api_key from ee.danswer.db.saml import get_saml_account from ee.danswer.server.seeding import get_seed_config from ee.danswer.utils.secrets import extract_hashed_cookie +import jwt + +DATA_PLANE_SECRET = "your_shared_secret_key" +EXPECTED_API_KEY = "your_control_plane_api_key" logger = setup_logger() - def verify_auth_setting() -> None: # All the Auth flows are valid for EE version 
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") +async def control_plane_dep(request: Request): + auth_header = request.headers.get("Authorization") + api_key = request.headers.get("X-API-KEY") + + if api_key != EXPECTED_API_KEY: + logger.warning("Invalid API key") + raise HTTPException(status_code=401, detail="Invalid API key") + + if not auth_header or not auth_header.startswith("Bearer "): + logger.warning("Invalid authorization header") + raise HTTPException(status_code=401, detail="Invalid authorization header") + + token = auth_header.split(" ")[1] + try: + payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"]) + if payload.get("scope") != "tenant:create": + logger.warning("Insufficient permissions") + raise HTTPException(status_code=403, detail="Insufficient permissions") + except jwt.ExpiredSignatureError: + logger.warning("Token has expired") + raise HTTPException(status_code=401, detail="Token has expired") + except jwt.InvalidTokenError: + logger.warning("Invalid token") + raise HTTPException(status_code=401, detail="Invalid token") + + async def optional_user_( request: Request, user: User | None, @@ -44,6 +73,7 @@ async def optional_user_( return user + def api_key_dep( request: Request, db_session: Session = Depends(get_session) ) -> User | None: @@ -63,6 +93,7 @@ def api_key_dep( return user + def get_default_admin_user_emails_() -> list[str]: seed_config = get_seed_config() if seed_config and seed_config.admin_user_emails: