mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-09-27 12:29:41 +02:00
Add tenant context (#2596)
* add proper tenant context to background tasks * update for new session logic * remove unnecessary functions * add additional tenant context * update ports * proper format / directory structure * update ports * ensure tenant context properly passed to ee bg tasks * add user provisioning * nit * validated for multi tenant * auth * nit * nit * nit * nit * validate pruning * evaluate integration tests * at long last, validated celery beat * nit: minor edge case patched * minor * validate update * nit
This commit is contained in:
@@ -1,12 +1,12 @@
|
||||
from datetime import timedelta
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.celery_app import celery_app
|
||||
from danswer.background.task_utils import build_celery_task_wrapper
|
||||
from danswer.background.update import get_all_tenant_ids
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.db.chat import delete_chat_sessions_older_than
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.server.settings.store import load_settings
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -32,6 +32,7 @@ from ee.danswer.external_permissions.permission_sync import (
|
||||
run_external_group_permission_sync,
|
||||
)
|
||||
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -41,22 +42,26 @@ global_version.set_ee()
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
def sync_external_group_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
|
||||
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(retention_limit_days: int) -> None:
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
|
||||
|
||||
@@ -67,16 +72,16 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
|
||||
name="check_sync_external_doc_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_doc_permissions_task() -> None:
|
||||
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external permissions"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_doc_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_doc_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id),
|
||||
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
@@ -84,16 +89,16 @@ def check_sync_external_doc_permissions_task() -> None:
|
||||
name="check_sync_external_group_permissions_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_sync_external_group_permissions_task() -> None:
|
||||
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
|
||||
"""Runs periodically to sync external group permissions"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
if should_perform_external_group_permissions_check(
|
||||
cc_pair=cc_pair, db_session=db_session
|
||||
):
|
||||
sync_external_group_permissions_task.apply_async(
|
||||
kwargs=dict(cc_pair_id=cc_pair.id),
|
||||
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
@@ -101,25 +106,33 @@ def check_sync_external_group_permissions_task() -> None:
|
||||
name="check_ttl_management_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task() -> None:
|
||||
def check_ttl_management_task(tenant_id: str | None) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
token = None
|
||||
if MULTI_TENANT and tenant_id is not None:
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(retention_limit_days=retention_limit_days),
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
name="autogenerate_usage_report_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task() -> None:
|
||||
def autogenerate_usage_report_task(tenant_id: str | None) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
@@ -130,22 +143,48 @@ def autogenerate_usage_report_task() -> None:
|
||||
#####
|
||||
# Celery Beat (Periodic Tasks) Settings
|
||||
#####
|
||||
celery_app.conf.beat_schedule = {
|
||||
"sync-external-doc-permissions": {
|
||||
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "sync-external-doc-permissions",
|
||||
"task": "check_sync_external_doc_permissions_task",
|
||||
"schedule": timedelta(seconds=5), # TODO: optimize this
|
||||
},
|
||||
"sync-external-group-permissions": {
|
||||
{
|
||||
"name": "sync-external-group-permissions",
|
||||
"task": "check_sync_external_group_permissions_task",
|
||||
"schedule": timedelta(seconds=5), # TODO: optimize this
|
||||
},
|
||||
"autogenerate_usage_report": {
|
||||
{
|
||||
"name": "autogenerate_usage_report",
|
||||
"task": "autogenerate_usage_report_task",
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
"check-ttl-management": {
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": "check_ttl_management_task",
|
||||
"schedule": timedelta(hours=1),
|
||||
},
|
||||
**(celery_app.conf.beat_schedule or {}),
|
||||
}
|
||||
]
|
||||
|
||||
# Build the celery beat schedule dynamically
|
||||
beat_schedule = {}
|
||||
|
||||
for tenant_id in tenant_ids:
|
||||
for task in tasks_to_schedule:
|
||||
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
|
||||
beat_schedule[task_name] = {
|
||||
"task": task["task"],
|
||||
"schedule": task["schedule"],
|
||||
"args": (tenant_id,), # Must pass tenant_id as an argument
|
||||
}
|
||||
|
||||
# Include any existing beat schedules
|
||||
existing_beat_schedule = celery_app.conf.beat_schedule or {}
|
||||
beat_schedule.update(existing_beat_schedule)
|
||||
|
||||
# Update the Celery app configuration
|
||||
celery_app.conf.beat_schedule = beat_schedule
|
||||
|
@@ -2,9 +2,13 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
|
||||
return f"chat_ttl_{retention_limit_days}_days"
|
||||
|
||||
|
||||
def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
|
||||
def name_sync_external_doc_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_doc_permissions_task__{cc_pair_id}"
|
||||
|
||||
|
||||
def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
|
||||
def name_sync_external_group_permissions_task(
|
||||
cc_pair_id: int, tenant_id: str | None = None
|
||||
) -> str:
|
||||
return f"sync_external_group_permissions_task__{cc_pair_id}"
|
||||
|
@@ -4,6 +4,7 @@ from httpx_oauth.clients.openid import OpenID
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import AUTH_TYPE
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
@@ -24,6 +25,7 @@ from ee.danswer.server.enterprise_settings.api import (
|
||||
basic_router as enterprise_settings_router,
|
||||
)
|
||||
from ee.danswer.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.danswer.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@@ -53,6 +55,9 @@ def get_application() -> FastAPI:
|
||||
|
||||
application = get_application_base()
|
||||
|
||||
if MULTI_TENANT:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
include_router_with_global_prefix_prepended(
|
||||
application,
|
||||
|
60
backend/ee/danswer/server/middleware/tenant_tracking.py
Normal file
60
backend/ee/danswer/server/middleware/tenant_tracking.py
Normal file
@@ -0,0 +1,60 @@
|
||||
import logging
|
||||
from collections.abc import Awaitable
|
||||
from collections.abc import Callable
|
||||
|
||||
import jwt
|
||||
from fastapi import FastAPI
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
|
||||
from danswer.configs.app_configs import MULTI_TENANT
|
||||
from danswer.configs.app_configs import SECRET_JWT_KEY
|
||||
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
|
||||
from danswer.db.engine import is_valid_schema_name
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
|
||||
def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
    """Register an HTTP middleware on ``app`` that selects the tenant per request.

    For every request the middleware determines a tenant id, stores it in
    ``current_tenant_id`` and then forwards the request to the next handler.

    Tenant resolution:
      * ``MULTI_TENANT`` disabled -> ``POSTGRES_DEFAULT_SCHEMA``.
      * No ``tenant_details`` cookie -> ``POSTGRES_DEFAULT_SCHEMA``.
      * Cookie present -> decoded as an HS256 JWT signed with
        ``SECRET_JWT_KEY``; the ``tenant_id`` claim is used (defaulting to
        ``POSTGRES_DEFAULT_SCHEMA`` when the claim is absent).
      * Cookie present but not a valid token -> ``POSTGRES_DEFAULT_SCHEMA``.

    Args:
        app: Application the middleware is attached to.
        logger: Logger used for per-request and error messages.

    Raises (from the registered middleware):
        HTTPException(400): the decoded tenant id fails ``is_valid_schema_name``.
        HTTPException(500): any other unexpected error while decoding the cookie.
        Any exception raised further down the stack is logged and re-raised.
    """

    @app.middleware("http")
    async def set_tenant_id(
        request: Request, call_next: Callable[[Request], Awaitable[Response]]
    ) -> Response:
        try:
            logger.info(f"Request route: {request.url.path}")

            if not MULTI_TENANT:
                tenant_id = POSTGRES_DEFAULT_SCHEMA
            else:
                token = request.cookies.get("tenant_details")
                if token:
                    try:
                        payload = jwt.decode(
                            token, SECRET_JWT_KEY, algorithms=["HS256"]
                        )
                        tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
                        if not is_valid_schema_name(tenant_id):
                            raise HTTPException(
                                status_code=400, detail="Invalid tenant ID format"
                            )
                    except jwt.InvalidTokenError:
                        # An undecodable or badly signed cookie falls back to
                        # the default schema instead of failing the request.
                        tenant_id = POSTGRES_DEFAULT_SCHEMA
                    except HTTPException:
                        # Let the deliberate 400 above through unchanged; the
                        # generic handler below would otherwise turn it into
                        # a 500.
                        raise
                    except Exception as e:
                        logger.error(
                            f"Unexpected error in set_tenant_id_middleware: {str(e)}"
                        )
                        raise HTTPException(
                            status_code=500, detail="Internal server error"
                        ) from e
                else:
                    tenant_id = POSTGRES_DEFAULT_SCHEMA

            current_tenant_id.set(tenant_id)
            logger.info(f"Middleware set current_tenant_id to: {tenant_id}")

            response = await call_next(request)
            return response

        except Exception as e:
            # NOTE(review): confirm how exceptions raised from an "http"
            # middleware are rendered to the client; they may bypass the
            # app's HTTPException handlers.
            logger.error(f"Error in tenant ID middleware: {str(e)}")
            raise
|
@@ -8,8 +8,11 @@ from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.server.tenants.models import CreateTenantRequest
|
||||
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
|
||||
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
|
||||
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
|
||||
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
|
||||
from shared_configs.configs import current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
@@ -19,9 +22,15 @@ router = APIRouter(prefix="/tenants")
|
||||
def create_tenant(
|
||||
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
|
||||
) -> dict[str, str]:
|
||||
try:
|
||||
tenant_id = create_tenant_request.tenant_id
|
||||
tenant_id = create_tenant_request.tenant_id
|
||||
email = create_tenant_request.initial_admin_email
|
||||
token = None
|
||||
if user_owns_a_tenant(email):
|
||||
raise HTTPException(
|
||||
status_code=409, detail="User already belongs to an organization"
|
||||
)
|
||||
|
||||
try:
|
||||
if not MULTI_TENANT:
|
||||
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
|
||||
|
||||
@@ -31,10 +40,14 @@ def create_tenant(
|
||||
logger.info(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
run_alembic_migrations(tenant_id)
|
||||
token = current_tenant_id.set(tenant_id)
|
||||
print("getting session", tenant_id)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
setup_danswer(db_session)
|
||||
|
||||
logger.info(f"Tenant {tenant_id} created successfully")
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Tenant {tenant_id} created successfully",
|
||||
@@ -44,3 +57,6 @@ def create_tenant(
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
current_tenant_id.reset(token)
|
||||
|
@@ -8,7 +8,9 @@ from sqlalchemy.schema import CreateSchema
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
from danswer.db.engine import build_connection_string
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.engine import get_sqlalchemy_engine
|
||||
from danswer.db.models import UserTenantMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -61,3 +63,48 @@ def ensure_schema_exists(tenant_id: str) -> bool:
|
||||
db_session.execute(stmt)
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
# For now, we're implementing a primitive mapping between users and tenants.
|
||||
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
|
||||
def user_owns_a_tenant(email: str) -> bool:
    """Report whether any ``UserTenantMapping`` row exists for ``email``.

    The lookup runs in a session opened with ``get_session_with_tenant("public")``.

    Args:
        email: Address to look up in ``UserTenantMapping.email``.

    Returns:
        ``True`` if at least one mapping row matches, otherwise ``False``.
    """
    with get_session_with_tenant("public") as db_session:
        mapping_query = db_session.query(UserTenantMapping).filter(
            UserTenantMapping.email == email
        )
        existing_mapping = mapping_query.first()
        return existing_mapping is not None
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
    """Create a ``UserTenantMapping`` row for each email, all in one transaction.

    The rows are written through a session opened with
    ``get_session_with_tenant("public")``. Staging and committing happen inside
    the same ``try`` so that a failure at either step is logged and rolled
    back rather than leaving a partial set of mappings committed.

    Args:
        emails: Addresses to map to the tenant.
        tenant_id: Tenant the addresses are mapped to.

    Raises:
        Exception: Whatever the session raised while adding or committing;
            it is logged, the transaction is rolled back, and the error is
            re-raised so the caller does not treat the operation as successful.
    """
    with get_session_with_tenant("public") as db_session:
        try:
            for email in emails:
                db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
            db_session.commit()
        except Exception as e:
            logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
            db_session.rollback()
            raise
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
    """Delete the ``UserTenantMapping`` rows linking ``emails`` to ``tenant_id``.

    Matching rows are loaded, deleted one by one and committed through a
    session opened with ``get_session_with_tenant("public")``. Any exception
    raised along the way is logged and the transaction is rolled back; the
    exception is not re-raised.

    Args:
        emails: Addresses whose mappings should be removed.
        tenant_id: Tenant the mappings must belong to.
    """
    with get_session_with_tenant("public") as db_session:
        try:
            matches_email = UserTenantMapping.email.in_(emails)
            matches_tenant = UserTenantMapping.tenant_id == tenant_id
            stale_mappings = (
                db_session.query(UserTenantMapping)
                .filter(matches_email, matches_tenant)
                .all()
            )

            for stale_mapping in stale_mappings:
                db_session.delete(stale_mapping)

            db_session.commit()
        except Exception as e:
            logger.exception(
                f"Failed to remove users from tenant {tenant_id}: {str(e)}"
            )
            db_session.rollback()
|
||||
|
Reference in New Issue
Block a user