validated functionality for single-tenant

pablodanswer 2024-09-27 10:21:08 -07:00
parent dc5a91fd85
commit e5f3f2d73a
13 changed files with 332 additions and 284 deletions

View File

@ -31,18 +31,15 @@ def get_schema_options() -> Tuple[str, bool]:
if '=' in pair:
key, value = pair.split('=', 1)
x_args[key] = value
print(f"x_args: {x_args}") # For debugging
schema_name = x_args.get('schema', 'public') # Default schema
create_schema = x_args.get('create_schema', 'false').lower() == 'true'
return schema_name, create_schema
schema_name = x_args.get('schema', 'public')
return schema_name
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode."""
url = build_connection_string()
schema, create_schema = get_schema_options()
schema = get_schema_options()
if create_schema:
raise RuntimeError("Cannot create schema in offline mode. Please run migrations online to create the schema.")
context.configure(
url=url,
@ -57,14 +54,11 @@ def run_migrations_offline() -> None:
context.run_migrations()
def do_run_migrations(connection: Connection) -> None:
schema, create_schema = get_schema_options()
schema = get_schema_options()
if create_schema:
# Use text() to create a proper SQL expression
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text('COMMIT'))
connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
connection.execute(text('COMMIT'))
# Set the search_path to the target schema
connection.execute(text(f'SET search_path TO "{schema}"'))
context.configure(
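Annotation (not from this commit): a minimal sketch of the simplified env.py helper, assuming the key=value pairs come from Alembic's context.get_x_argument() as in the rest of the file (not shown in this hunk):

from alembic import context

def get_schema_options() -> str:
    # -x arguments arrive as raw "key=value" strings,
    # e.g. `alembic -x schema=tenant_abc upgrade head`
    x_args: dict[str, str] = {}
    for pair in context.get_x_argument(as_dictionary=False):
        if "=" in pair:
            key, value = pair.split("=", 1)
            x_args[key] = value
    # fall back to the default "public" schema when no tenant schema is passed
    return x_args.get("schema", "public")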

View File

@ -241,7 +241,7 @@ async def get_or_create_user(email: str, user_id: str) -> User:
async def create_user_session(user: User, tenant_id: str) -> str:
# Create a payload with user information and tenant_id
# Create a payload user information and tenant_id
payload = {
"sub": str(user.id),
"email": user.email,
@ -261,7 +261,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async def sso_authenticate(
self,
email: str,
user_id: str,
tenant_id: str,
) -> User:
try:
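Annotation (not from this commit): a hedged sketch of how a session token like the one built in create_user_session can be issued with the tenant_id carried in the JWT claims. The secret constant and expiry below are placeholders, not values from this repository:

from datetime import datetime, timedelta, timezone

import jwt  # PyJWT

SESSION_SECRET = "replace-me"          # placeholder: real deployments load this from config
SESSION_EXPIRE_TIME_SECONDS = 86_400   # placeholder expiry

def create_session_token(user_id: str, email: str, tenant_id: str) -> str:
    # sign the user identity plus tenant so later requests can be routed to the right schema
    payload = {
        "sub": user_id,
        "email": email,
        "tenant_id": tenant_id,
        "exp": datetime.now(timezone.utc) + timedelta(seconds=SESSION_EXPIRE_TIME_SECONDS),
    }
    return jwt.encode(payload, SESSION_SECRET, algorithm="HS256")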

View File

@ -70,7 +70,7 @@ _SYNC_BATCH_SIZE = 100
def cleanup_connector_credential_pair_task(
connector_id: int,
credential_id: int,
tenant_id: str
tenant_id: str | None
) -> int:
"""Connector deletion task. This is run as an async task because it is a somewhat slow job.
Needs to potentially update a large number of Postgres and Vespa docs, including deleting them
@ -114,7 +114,7 @@ def cleanup_connector_credential_pair_task(
@build_celery_task_wrapper(name_cc_prune_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str) -> None:
def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str | None) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
@ -183,7 +183,7 @@ def prune_documents_task(connector_id: int, credential_id: int, tenant_id: str)
@build_celery_task_wrapper(name_document_set_sync_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_document_set_task(document_set_id: int, tenant_id: str) -> None:
def sync_document_set_task(document_set_id: int, tenant_id: str | None) -> None:
"""For document sets marked as not up to date, sync the state from postgres
into the datastore. Also handles deletions."""
@ -267,7 +267,7 @@ def sync_document_set_task(document_set_id: int, tenant_id: str) -> None:
name="check_for_document_sets_sync_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_document_sets_sync_task(tenant_id: str) -> None:
def check_for_document_sets_sync_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any sync tasks should be run and adds them
to the queue"""
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
@ -287,7 +287,7 @@ def check_for_document_sets_sync_task(tenant_id: str) -> None:
name="check_for_cc_pair_deletion_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_cc_pair_deletion_task(tenant_id: str) -> None:
def check_for_cc_pair_deletion_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any deletion tasks should be run"""
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
# check if any document sets are not synced
@ -310,7 +310,7 @@ def check_for_cc_pair_deletion_task(tenant_id: str) -> None:
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
@ -423,7 +423,7 @@ def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
name="check_for_prune_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_for_prune_task(tenant_id: str) -> None:
def check_for_prune_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any prune tasks should be run and adds them
to the queue"""
@ -455,7 +455,7 @@ def schedule_tenant_tasks():
if MULTI_TENANT:
tenants = get_all_tenant_ids()
else:
tenants = ['public'] # Default tenant in single-tenancy mode
tenants = [None]
# Filter out any invalid tenants if necessary
valid_tenants = [tenant for tenant in tenants if not tenant.startswith('pg_')]
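Annotation (sketch, not from this commit): once tenants can contain None in single-tenancy mode, the pg_ filter needs a None guard before calling startswith, roughly:

def filter_valid_tenants(tenants: list[str | None]) -> list[str | None]:
    # keep the single-tenant placeholder (None) and drop Postgres-internal schemas like pg_catalog
    return [t for t in tenants if t is None or not t.startswith("pg_")]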

View File

@ -5,9 +5,7 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from danswer.db.engine import current_tenant_id
from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.background.indexing.checkpointing import get_time_windows_for_index_attempt
from danswer.background.indexing.tracer import DanswerTracer
@ -46,7 +44,7 @@ def _get_connector_runner(
attempt: IndexAttempt,
start_time: datetime,
end_time: datetime,
tenant_id: str | None = None
tenant_id: str | None
) -> ConnectorRunner:
"""
NOTE: `start_time` and `end_time` are only used for poll connectors
@ -86,7 +84,7 @@ def _get_connector_runner(
def _run_indexing(
db_session: Session,
index_attempt: IndexAttempt,
tenant_id: str
tenant_id: str | None
) -> None:
"""
1. Get documents which are either new or updated from specified application
@ -176,6 +174,7 @@ def _run_indexing(
attempt=index_attempt,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id
)
all_connector_doc_ids: set[str] = set()
@ -390,12 +389,11 @@ def _prepare_index_attempt(db_session: Session, index_attempt_id: int) -> IndexA
return attempt
def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str, is_ee: bool = False) -> None:
def run_indexing_entrypoint(index_attempt_id: int, tenant_id: str | None, is_ee: bool = False) -> None:
try:
if is_ee:
global_version.set_ee()
current_tenant_id.set(tenant_id)
IndexAttemptSingleton.set_index_attempt_id(index_attempt_id)
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
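Annotation (for orientation, not from this commit): a minimal sketch of the tenant plumbing these signatures rely on, i.e. a ContextVar holding the active tenant and an engine scoped to that tenant's schema. The connection URL and search_path mechanism below are assumptions for illustration only.

from contextvars import ContextVar

from sqlalchemy import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session

# None means single-tenant mode / the default public schema
current_tenant_id: ContextVar[str | None] = ContextVar("current_tenant_id", default=None)

def get_sqlalchemy_engine(schema: str | None = None) -> Engine:
    # assumption: the real implementation caches engines; the key idea is pinning
    # search_path to the tenant schema on every new connection
    connect_args = {"options": f"-csearch_path={schema}"} if schema else {}
    return create_engine("postgresql+psycopg2://user:pass@localhost/danswer", connect_args=connect_args)

def run_indexing_for_tenant(index_attempt_id: int, tenant_id: str | None) -> None:
    current_tenant_id.set(tenant_id)
    with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
        pass  # the indexing attempt runs against the tenant's schema here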

View File

@ -14,8 +14,11 @@ from danswer.db.tasks import mark_task_start
from danswer.db.tasks import register_task
def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str = "") -> str:
return f"cleanup_connector_credential_pair_{connector_id}_{credential_id}_{tenant_id}"
def name_cc_cleanup_task(connector_id: int, credential_id: int, tenant_id: str | None) -> str:
task_name = f"cleanup_connector_credential_pair_{connector_id}_{credential_id}"
if tenant_id is not None:
task_name += f"_{tenant_id}"
return task_name
def name_document_set_sync_task(document_set_id: int) -> str:
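Annotation: expected behavior of the updated name_cc_cleanup_task above, as a quick usage check (illustrative only):

assert name_cc_cleanup_task(1, 2, tenant_id=None) == "cleanup_connector_credential_pair_1_2"
assert name_cc_cleanup_task(1, 2, tenant_id="tenant_abc") == "cleanup_connector_credential_pair_1_2_tenant_abc"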

View File

@ -1,15 +1,20 @@
from sqlalchemy import text
import logging
import time
from sqlalchemy.exc import ProgrammingError
from datetime import datetime
import dask
from dask.distributed import Client
from dask.distributed import Future
from distributed import LocalCluster
from sqlalchemy.orm import Session
from sqlalchemy import text
from danswer.background.indexing.dask_utils import ResourceLogger
from danswer.background.indexing.job_client import SimpleJob
from danswer.background.indexing.job_client import SimpleJobClient
from danswer.background.indexing.run_indexing import run_indexing_entrypoint
from danswer.configs.app_configs import CLEANUP_INDEXING_JOBS_TIMEOUT
from danswer.configs.app_configs import DASK_JOB_CLIENT_ENABLED
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import NUM_INDEXING_WORKERS
from danswer.configs.app_configs import NUM_SECONDARY_INDEXING_WORKERS
@ -33,10 +38,16 @@ from danswer.db.models import SearchSettings
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
from danswer.utils.variable_functionality import set_is_ee_based_on_env_variable
from danswer.db.engine import current_tenant_id
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import LOG_LEVEL
from shared_configs.configs import MODEL_SERVER_PORT
from danswer.configs.app_configs import MULTI_TENANT
from sqlalchemy.exc import ProgrammingError
logger = setup_logger()
@ -134,9 +145,7 @@ def _mark_run_failed(
"""Main funcs"""
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str) -> None:
current_tenant_id.set(tenant_id)
def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id: str | None) -> None:
"""Creates new indexing jobs for each connector / credential pair which is:
1. Enabled
@ -200,10 +209,9 @@ def create_indexing_jobs(existing_jobs: dict[int, Future | SimpleJob], tenant_id
def cleanup_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
tenant_id: str,
tenant_id: str | None,
timeout_hours: int = CLEANUP_INDEXING_JOBS_TIMEOUT,
) -> dict[int, Future | SimpleJob]:
current_tenant_id.set(tenant_id)
existing_jobs_copy = existing_jobs.copy()
# clean up completed jobs
@ -247,7 +255,6 @@ def cleanup_indexing_jobs(
# clean up in-progress jobs that were never completed
try:
connectors = fetch_connectors(db_session)
logger.info(f"len(connectors): {len(connectors)}")
for connector in connectors:
in_progress_indexing_attempts = get_inprogress_index_attempts(
connector.id, db_session
@ -264,7 +271,6 @@ def cleanup_indexing_jobs(
# batch of documents indexed
current_db_time = get_db_current_time(db_session=db_session)
time_since_update = current_db_time - index_attempt.time_updated
logger.info("ERRORS 1")
if time_since_update.total_seconds() > 60 * 60 * timeout_hours:
existing_jobs[index_attempt.id].cancel()
_mark_run_failed(
@ -274,8 +280,6 @@ def cleanup_indexing_jobs(
"The run will be re-attempted at next scheduled indexing time.",
)
else:
logger.info(f"ERRORS 2 {tenant_id} {len(existing_jobs)}")
continue
# If job isn't known, simply mark it as failed
_mark_run_failed(
db_session=db_session,
@ -292,9 +296,8 @@ def kickoff_indexing_jobs(
existing_jobs: dict[int, Future | SimpleJob],
client: Client | SimpleJobClient,
secondary_client: Client | SimpleJobClient,
tenant_id: str,
tenant_id: str | None,
) -> dict[int, Future | SimpleJob]:
current_tenant_id.set(tenant_id)
existing_jobs_copy = existing_jobs.copy()
engine = get_sqlalchemy_engine(schema=tenant_id)
@ -387,6 +390,7 @@ def kickoff_indexing_jobs(
return existing_jobs_copy
def get_all_tenant_ids() -> list[str]:
with Session(get_sqlalchemy_engine(schema='public')) as session:
result = session.execute(text("""
@ -397,63 +401,100 @@ def get_all_tenant_ids() -> list[str]:
tenant_ids = [row[0] for row in result]
return tenant_ids
def update_loop(
delay: int = 10,
num_workers: int = NUM_INDEXING_WORKERS,
num_secondary_workers: int = NUM_SECONDARY_INDEXING_WORKERS,
) -> None:
client_primary = Client(n_workers=num_workers)
client_secondary = Client(n_workers=num_secondary_workers)
try:
while True:
tenants = get_all_tenant_ids()
client_primary: Client | SimpleJobClient
client_secondary: Client | SimpleJobClient
if DASK_JOB_CLIENT_ENABLED:
cluster_primary = LocalCluster(
n_workers=num_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
cluster_secondary = LocalCluster(
n_workers=num_secondary_workers,
threads_per_worker=1,
silence_logs=logging.ERROR,
)
client_primary = Client(cluster_primary)
client_secondary = Client(cluster_secondary)
if LOG_LEVEL.lower() == "debug":
client_primary.register_worker_plugin(ResourceLogger())
else:
client_primary = SimpleJobClient(n_workers=num_workers)
client_secondary = SimpleJobClient(n_workers=num_secondary_workers)
valid_tenants = [tenant for tenant in tenants if not tenant.startswith('pg_')]
logger.info(f"Found valid tenants: {valid_tenants}")
tenants = valid_tenants
existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}
logger.notice("Startup complete. Waiting for indexing jobs...")
while True:
start = time.time()
start_time_utc = datetime.utcfromtimestamp(start).strftime("%Y-%m-%d %H:%M:%S")
logger.debug(f"Running update, current UTC time: {start_time_utc}")
if existing_jobs:
logger.debug(
"Found existing indexing jobs: "
f"{[(tenant_id, list(jobs.keys())) for tenant_id, jobs in existing_jobs.items()]}"
)
try:
tenants = get_all_tenant_ids() if MULTI_TENANT else [None]
tenants = [tenant for tenant in tenants if not tenant.startswith('pg_')] if MULTI_TENANT else tenants
if MULTI_TENANT:
logger.info(f"Found valid tenants: {tenants}")
for tenant_id in tenants:
try:
logger.debug(f"Processing tenant: {tenant_id}")
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
try:
check_index_swap(db_session)
except ProgrammingError:
pass
logger.debug(f"Processing {'index attempts' if tenant_id is None else f'tenant {tenant_id}'}")
engine = get_sqlalchemy_engine(schema=tenant_id)
with Session(engine) as db_session:
check_index_swap(db_session=db_session)
if not MULTI_TENANT:
search_settings = get_current_search_settings(db_session)
if search_settings.provider_type is None:
logger.notice("Running a first inference to warm up embedding model")
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=INDEXING_MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
warm_up_bi_encoder(embedding_model=embedding_model)
logger.notice("First inference complete.")
# Initialize or retrieve existing jobs per tenant
existing_jobs: dict[int, Future | SimpleJob] = {}
tenant_jobs = existing_jobs.get(tenant_id, {})
# Perform cleanup, job creation, and kickoff for this tenant
existing_jobs = cleanup_indexing_jobs(
existing_jobs=existing_jobs,
tenant_jobs = cleanup_indexing_jobs(
existing_jobs=tenant_jobs,
tenant_id=tenant_id
)
create_indexing_jobs(
existing_jobs=existing_jobs,
existing_jobs=tenant_jobs,
tenant_id=tenant_id
)
logger.debug(f"Indexing Jobs are {len(existing_jobs)} many")
existing_jobs = kickoff_indexing_jobs(
existing_jobs=existing_jobs,
tenant_jobs = kickoff_indexing_jobs(
existing_jobs=tenant_jobs,
client=client_primary,
secondary_client=client_secondary,
tenant_id=tenant_id,
)
existing_jobs[tenant_id] = tenant_jobs
except Exception as e:
logger.exception(f"Failed to process tenant {tenant_id}: {e}")
logger.exception(f"Failed to process tenant {tenant_id or 'default'}: {e}")
sleep_time = delay
except Exception as e:
logger.exception(f"Failed to run update due to {e}")
sleep_time = delay - (time.time() - start)
if sleep_time > 0:
time.sleep(sleep_time)
finally:
client_primary.close()
client_secondary.close()
def update__main() -> None:
set_is_ee_based_on_env_variable()
init_sqlalchemy_engine(POSTGRES_INDEXER_APP_NAME)
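Annotation (condensed reading of the restructured update loop above, not standalone code; helper bodies live in the file): job state is now a dict of per-tenant dicts, and each pass of the loop rebuilds one tenant's entry at a time.

existing_jobs: dict[str | None, dict[int, Future | SimpleJob]] = {}

for tenant_id in tenants:
    tenant_jobs = existing_jobs.get(tenant_id, {})
    tenant_jobs = cleanup_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
    create_indexing_jobs(existing_jobs=tenant_jobs, tenant_id=tenant_id)
    tenant_jobs = kickoff_indexing_jobs(
        existing_jobs=tenant_jobs,
        client=client_primary,
        secondary_client=client_secondary,
        tenant_id=tenant_id,
    )
    # keep this tenant's job handles for the next pass of the update loop
    existing_jobs[tenant_id] = tenant_jobs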

View File

@ -1,16 +1,4 @@
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from sqlalchemy.orm import Session
from danswer.llm.llm_initialization import load_llm_providers
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
@ -22,81 +10,10 @@ from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.logger import setup_logger
from sqlalchemy.orm import Session
logger = setup_logger()
def setup_postgres_and_initial_settings(db_session: Session) -> None:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
if search_settings.multilingual_expansion:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if search_settings.rerank_model_name and not search_settings.provider_type:
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
# setup Postgres with default credential, llm providers, etc.
setup_postgres(db_session)
# Does the user need to trigger a reindexing to bring the document index
# into a good state, marked in the kv store
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls(db_session)
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
def setup_postgres(db_session: Session) -> None:
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)

View File

@ -273,6 +273,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
@ -285,12 +286,14 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")

View File

@ -29,7 +29,8 @@ from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import FunctionCall
from danswer.utils.threadpool_concurrency import run_functions_in_parallel
from danswer.utils.timing import log_function_time
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import current_tenant_id
logger = setup_logger()
@ -152,14 +153,14 @@ def retrieval_preprocessing(
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
source_type=preset_filters.source_type or predicted_source_filters,
document_set=preset_filters.document_set,
time_cutoff=preset_filters.time_cutoff or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id="0a4bae55-27d4-406a-8c12-d3827af22e42"
# tenant_id="84d01c94-7032-4dfa-9a1e-0d80805c7826" # TODO FIX
tenant_id=current_tenant_id.get() if MULTI_TENANT else None,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@ -42,7 +42,7 @@ PUBLIC_ENDPOINT_SPECS = [
("/users/{id}", {"DELETE"}),
# oauth
("/auth/oauth/authorize", {"GET"}),
("/auth/oaute a tsth/callback", {"GET"}),
("/auth/oauth/callback", {"GET"}),
# tenant service related (must use API key)
("/tenants/create", {"POST"}),
]

View File

@ -1,14 +1,8 @@
from fastapi import APIRouter
from fastapi import Depends
from sqlalchemy.orm import Session
import os
from danswer.db_setup import setup_postgres_and_initial_settings
from fastapi import Body
from danswer.db.engine import get_sqlalchemy_engine
from typing import Any
from typing import Callable
from danswer.auth.users import create_user_session
from danswer.auth.users import get_user_manager
from danswer.auth.users import UserManager
@ -16,138 +10,45 @@ from danswer.auth.users import verify_sso_token
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from danswer.utils.logger import setup_logger
from fastapi.responses import JSONResponse
from danswer.db.engine import get_async_session
import contextlib
from fastapi import HTTPException, Request
from sqlalchemy import text
from alembic import command
from fastapi import HTTPException
from ee.danswer.auth.users import control_plane_dep
from danswer.server.tenants.provisioning import setup_postgres_and_initial_settings
from danswer.server.tenants.provisioning import check_schema_exists
from danswer.server.tenants.provisioning import run_alembic_migrations
from danswer.server.tenants.provisioning import create_tenant_schema
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.engine import build_connection_string
from alembic.config import Config
from functools import wraps
import jwt
DATA_PLANE_SECRET = "your_shared_secret_key"
EXPECTED_API_KEY = "your_control_plane_api_key"
logger = setup_logger()
basic_router = APIRouter(prefix="/tenants")
def run_alembic_migrations(schema_name: str, create_schema: bool = True) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
alembic_ini_path = os.path.join(root_dir, 'alembic.ini')
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option('sqlalchemy.url', build_connection_string())
# Prepare the x arguments
x_arg_schema = f"schema={schema_name}"
x_arg_create_schema = f"create_schema={'true' if create_schema else 'false'}"
x_arguments = [x_arg_schema, x_arg_create_schema]
# Set the x arguments into the Alembic context
alembic_cfg.cmd_opts = lambda: None # Create a dummy object
alembic_cfg.cmd_opts.x = x_arguments
# Run migrations programmatically
command.upgrade(alembic_cfg, 'head')
logger.info(f"Alembic migrations completed successfully for schema: {schema_name}")
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
def authenticate_request(func: Callable) -> Callable:
@wraps(func)
def wrapper(request: Request, *args: Any, **kwargs: Any) -> Any:
auth_header = request.headers.get("Authorization")
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
raise HTTPException(status_code=401, detail="Invalid API key")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Invalid token")
return func(request, *args, **kwargs)
return wrapper
@basic_router.post("/create")
@authenticate_request
def create_tenant(request: Request, tenant_id: str) -> dict[str, str]:
def create_tenant(tenant_id: str, _ = Depends(control_plane_dep)) -> dict[str, str]:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenant is not enabled")
if not tenant_id:
raise HTTPException(status_code=400, detail="tenant_id is required")
logger.info(f"Creating tenant schema: {tenant_id}")
create_tenant_schema(tenant_id)
run_alembic_migrations(tenant_id)
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
with db_session.begin():
result = db_session.execute(
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
{"schema_name": tenant_id}
)
schema_exists = result.scalar() is not None
if not schema_exists:
db_session.execute(text(f'CREATE SCHEMA "{tenant_id}"'))
logger.info(f"Schema {tenant_id} created")
else:
logger.info(f"Schema {tenant_id} already exists")
try:
run_alembic_migrations(tenant_id)
logger.info(f"Migrations completed for tenant: {tenant_id}")
except Exception as e:
logger.exception(f"Error running migrations for tenant {tenant_id}: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
try:
with Session(get_sqlalchemy_engine(schema=tenant_id)) as db_session:
setup_postgres_and_initial_settings(db_session)
except Exception as e:
logger.exception(f"Error setting up postgres for tenant {tenant_id}: {str(e)}")
raise
setup_postgres_and_initial_settings(db_session)
logger.info(f"Tenant {tenant_id} created successfully")
return {"status": "success", "message": f"Tenant {tenant_id} created successfully"}
async def check_schema_exists(tenant_id: str) -> bool:
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
)
async with get_async_session_context() as session:
result = await session.execute(
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
{"schema_name": tenant_id}
)
return result.scalar() is not None
@basic_router.post("/auth/sso-callback")
async def sso_callback(
request: Request,
sso_token: str = Body(..., embed=True),
user_manager: UserManager = Depends(get_user_manager),
) -> JSONResponse:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenant is not enabled")
payload = verify_sso_token(sso_token)
user = await user_manager.sso_authenticate(
payload["email"], payload["user_id"], payload["tenant_id"]
payload["email"], payload["tenant_id"]
)
tenant_id = payload["tenant_id"]
@ -164,7 +65,7 @@ async def sso_callback(
max_age=SESSION_EXPIRE_TIME_SECONDS,
expires=SESSION_EXPIRE_TIME_SECONDS,
path="/",
secure=False, # Set to True in production with HTTPS
secure=False,
httponly=True,
samesite="lax",
)
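Annotation (not from this commit): a sketch of what a control-plane call to the new endpoint could look like, given the headers and scope checked by control_plane_dep. The base URL is a placeholder, and the shared secrets must match the data plane's configuration.

import datetime

import jwt       # PyJWT
import requests

DATA_PLANE_SECRET = "your_shared_secret_key"     # must match the data plane
CONTROL_PLANE_API_KEY = "your_control_plane_api_key"

# mint a short-lived bearer token with the scope the dependency requires
token = jwt.encode(
    {"scope": "tenant:create", "exp": datetime.datetime.utcnow() + datetime.timedelta(minutes=5)},
    DATA_PLANE_SECRET,
    algorithm="HS256",
)

resp = requests.post(
    "http://localhost:8080/tenants/create",  # placeholder host/port
    params={"tenant_id": "tenant_abc"},
    headers={"X-API-KEY": CONTROL_PLANE_API_KEY, "Authorization": f"Bearer {token}"},
)
resp.raise_for_status()
print(resp.json())  # {"status": "success", "message": "Tenant tenant_abc created successfully"}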

View File

@ -0,0 +1,160 @@
import contextlib
from danswer.search.retrieval.search_runner import download_nltk_data
from danswer.db.engine import get_async_session
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.db.connector_credential_pair import get_connector_credential_pairs
from danswer.db.connector_credential_pair import resync_cc_pair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import expire_index_attempts
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.db.swap_index import check_index_swap
from sqlalchemy.orm import Session
from danswer.llm.llm_initialization import load_llm_providers
from danswer.db.connector import create_initial_default_connector
from danswer.db.connector_credential_pair import associate_default_cc_pair
from danswer.db.credentials import create_initial_public_credential
from danswer.db.standard_answer import create_initial_default_standard_answer_category
from danswer.db.persona import delete_old_default_personas
from danswer.chat.load_yamls import load_chat_yamls
from danswer.tools.built_in_tools import auto_add_search_tool_to_personas
from danswer.tools.built_in_tools import load_builtin_tools
from danswer.tools.built_in_tools import refresh_built_in_tools_cache
from danswer.utils.logger import setup_logger
from danswer.db.engine import get_sqlalchemy_engine
from sqlalchemy.schema import CreateSchema
from sqlalchemy import text
from alembic.config import Config
from alembic import command
from danswer.db.engine import build_connection_string
import os
from danswer.db_setup import setup_postgres
DATA_PLANE_SECRET = "your_shared_secret_key"
EXPECTED_API_KEY = "your_control_plane_api_key"
logger = setup_logger()
def run_alembic_migrations(schema_name: str) -> None:
logger.info(f"Starting Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, '..', '..', '..'))
alembic_ini_path = os.path.join(root_dir, 'alembic.ini')
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option('sqlalchemy.url', build_connection_string())
# Prepare the x arguments
x_arguments = [f"schema={schema_name}"]
alembic_cfg.cmd_opts.x = x_arguments
# Run migrations programmatically
command.upgrade(alembic_cfg, 'head')
logger.info(f"Alembic migrations completed successfully for schema: {schema_name}")
except Exception as e:
logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}")
raise
def create_tenant_schema(tenant_id: str) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():
result = db_session.execute(
text("""
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name = :schema_name
"""),
{"schema_name": tenant_id}
)
schema_exists = result.scalar() is not None
if not schema_exists:
db_session.execute(CreateSchema(tenant_id))
logger.info(f"Schema {tenant_id} created")
else:
logger.info(f"Schema {tenant_id} already exists")
def setup_postgres_and_initial_settings(db_session: Session) -> None:
check_index_swap(db_session=db_session)
search_settings = get_current_search_settings(db_session)
secondary_search_settings = get_secondary_search_settings(db_session)
# Break bad state for thrashing indexes
if secondary_search_settings and DISABLE_INDEX_UPDATE_ON_SWAP:
expire_index_attempts(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)
logger.notice(f'Using Embedding model: "{search_settings.model_name}"')
if search_settings.query_prefix or search_settings.passage_prefix:
logger.notice(f'Query embedding prefix: "{search_settings.query_prefix}"')
logger.notice(
f'Passage embedding prefix: "{search_settings.passage_prefix}"'
)
if search_settings:
if not search_settings.disable_rerank_for_streaming:
logger.notice("Reranking is enabled.")
if search_settings.multilingual_expansion:
logger.notice(
f"Multilingual query expansion is enabled with {search_settings.multilingual_expansion}."
)
if search_settings.rerank_model_name and not search_settings.provider_type:
warm_up_cross_encoder(search_settings.rerank_model_name)
logger.notice("Verifying query preprocessing (NLTK) data is downloaded")
download_nltk_data()
# setup Postgres with default credentials, llm providers, etc.
setup_postgres(db_session)
# ensure Vespa is setup correctly
logger.notice("Verifying Document Index(s) is/are available.")
logger.notice("Verifying default connector/credential exist.")
create_initial_public_credential(db_session)
create_initial_default_connector(db_session)
associate_default_cc_pair(db_session)
logger.notice("Verifying default standard answer category exists.")
create_initial_default_standard_answer_category(db_session)
logger.notice("Loading LLM providers from env variables")
load_llm_providers(db_session)
logger.notice("Loading default Prompts and Personas")
delete_old_default_personas(db_session)
load_chat_yamls(db_session)
logger.notice("Loading built-in tools")
load_builtin_tools(db_session)
refresh_built_in_tools_cache(db_session)
auto_add_search_tool_to_personas(db_session)
async def check_schema_exists(tenant_id: str) -> bool:
get_async_session_context = contextlib.asynccontextmanager(
get_async_session
)
async with get_async_session_context() as session:
result = await session.execute(
text("SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name"),
{"schema_name": tenant_id}
)
return result.scalar() is not None
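Annotation (sketch, under an assumption): a freshly constructed alembic Config leaves cmd_opts as None, so assigning to .x presumes cmd_opts was populated first; the earlier version of this code attached a dummy object for exactly that reason. A minimal way to pass the -x schema argument programmatically:

from types import SimpleNamespace

from alembic import command
from alembic.config import Config

def upgrade_schema(alembic_ini_path: str, db_url: str, schema_name: str) -> None:
    alembic_cfg = Config(alembic_ini_path)
    alembic_cfg.set_main_option("sqlalchemy.url", db_url)
    # Config() leaves cmd_opts as None; give it a namespace so env.py's
    # context.get_x_argument() can see the schema=... value
    alembic_cfg.cmd_opts = SimpleNamespace(x=[f"schema={schema_name}"])
    command.upgrade(alembic_cfg, "head")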

View File

@ -13,15 +13,44 @@ from ee.danswer.db.api_key import fetch_user_for_api_key
from ee.danswer.db.saml import get_saml_account
from ee.danswer.server.seeding import get_seed_config
from ee.danswer.utils.secrets import extract_hashed_cookie
import jwt
DATA_PLANE_SECRET = "your_shared_secret_key"
EXPECTED_API_KEY = "your_control_plane_api_key"
logger = setup_logger()
def verify_auth_setting() -> None:
# All the Auth flows are valid for EE version
logger.notice(f"Using Auth Type: {AUTH_TYPE.value}")
async def control_plane_dep(request: Request):
auth_header = request.headers.get("Authorization")
api_key = request.headers.get("X-API-KEY")
if api_key != EXPECTED_API_KEY:
logger.warning("Invalid API key")
raise HTTPException(status_code=401, detail="Invalid API key")
if not auth_header or not auth_header.startswith("Bearer "):
logger.warning("Invalid authorization header")
raise HTTPException(status_code=401, detail="Invalid authorization header")
token = auth_header.split(" ")[1]
try:
payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=["HS256"])
if payload.get("scope") != "tenant:create":
logger.warning("Insufficient permissions")
raise HTTPException(status_code=403, detail="Insufficient permissions")
except jwt.ExpiredSignatureError:
logger.warning("Token has expired")
raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError:
logger.warning("Invalid token")
raise HTTPException(status_code=401, detail="Invalid token")
async def optional_user_(
request: Request,
user: User | None,
@ -44,6 +73,7 @@ async def optional_user_(
return user
def api_key_dep(
request: Request, db_session: Session = Depends(get_session)
) -> User | None:
@ -63,6 +93,7 @@ def api_key_dep(
return user
def get_default_admin_user_emails_() -> list[str]:
seed_config = get_seed_config()
if seed_config and seed_config.admin_user_emails: