Add tenant context (#2596)

* add proper tenant context to background tasks

* update for new session logic

* remove unnecessary functions

* add additional tenant context

* update ports

* proper format / directory structure

* update ports

* ensure tenant context properly passed to ee bg tasks

* add user provisioning

* nit

* validated for multi tenant

* auth

* nit

* nit

* nit

* nit

* validate pruning

* evaluate integration tests

* at long last, validated celery beat

* nit: minor edge case patched

* minor

* validate update

* nit
This commit is contained in:
pablodanswer
2024-10-10 09:34:32 -07:00
committed by GitHub
parent 9be54a2b4c
commit f40c5ca9bd
52 changed files with 1319 additions and 389 deletions

View File

@@ -1,12 +1,12 @@
from datetime import timedelta
from sqlalchemy.orm import Session
from danswer.background.celery.celery_app import celery_app
from danswer.background.task_utils import build_celery_task_wrapper
from danswer.background.update import get_all_tenant_ids
from danswer.configs.app_configs import JOB_TIMEOUT
from danswer.configs.app_configs import MULTI_TENANT
from danswer.db.chat import delete_chat_sessions_older_than
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.engine import get_session_with_tenant
from danswer.server.settings.store import load_settings
from danswer.utils.logger import setup_logger
from danswer.utils.variable_functionality import global_version
@@ -32,6 +32,7 @@ from ee.danswer.external_permissions.permission_sync import (
run_external_group_permission_sync,
)
from ee.danswer.server.reporting.usage_export_generation import create_new_usage_report
from shared_configs.configs import current_tenant_id
logger = setup_logger()
@@ -41,22 +42,26 @@ global_version.set_ee()
@build_celery_task_wrapper(name_sync_external_doc_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_doc_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def sync_external_doc_permissions_task(cc_pair_id: int, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_doc_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_sync_external_group_permissions_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def sync_external_group_permissions_task(cc_pair_id: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
run_external_group_permission_sync(db_session=db_session, cc_pair_id=cc_pair_id)
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int) -> None:
with Session(get_sqlalchemy_engine()) as db_session:
def perform_ttl_management_task(
retention_limit_days: int, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -67,16 +72,16 @@ def perform_ttl_management_task(retention_limit_days: int) -> None:
name="check_sync_external_doc_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_doc_permissions_task() -> None:
def check_sync_external_doc_permissions_task(tenant_id: str | None) -> None:
"""Runs periodically to sync external permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_doc_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_doc_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
)
@@ -84,16 +89,16 @@ def check_sync_external_doc_permissions_task() -> None:
name="check_sync_external_group_permissions_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_sync_external_group_permissions_task() -> None:
def check_sync_external_group_permissions_task(tenant_id: str | None) -> None:
"""Runs periodically to sync external group permissions"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
if should_perform_external_group_permissions_check(
cc_pair=cc_pair, db_session=db_session
):
sync_external_group_permissions_task.apply_async(
kwargs=dict(cc_pair_id=cc_pair.id),
kwargs=dict(cc_pair_id=cc_pair.id, tenant_id=tenant_id),
)
@@ -101,25 +106,33 @@ def check_sync_external_group_permissions_task() -> None:
name="check_ttl_management_task",
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task() -> None:
def check_ttl_management_task(tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = current_tenant_id.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(retention_limit_days=retention_limit_days),
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
current_tenant_id.reset(token)
@celery_app.task(
name="autogenerate_usage_report_task",
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task() -> None:
def autogenerate_usage_report_task(tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with Session(get_sqlalchemy_engine()) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,
@@ -130,22 +143,48 @@ def autogenerate_usage_report_task() -> None:
#####
# Celery Beat (Periodic Tasks) Settings
#####
celery_app.conf.beat_schedule = {
"sync-external-doc-permissions": {
tenant_ids = get_all_tenant_ids()
tasks_to_schedule = [
{
"name": "sync-external-doc-permissions",
"task": "check_sync_external_doc_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"sync-external-group-permissions": {
{
"name": "sync-external-group-permissions",
"task": "check_sync_external_group_permissions_task",
"schedule": timedelta(seconds=5), # TODO: optimize this
},
"autogenerate_usage_report": {
{
"name": "autogenerate_usage_report",
"task": "autogenerate_usage_report_task",
"schedule": timedelta(days=30), # TODO: change this to config flag
},
"check-ttl-management": {
{
"name": "check-ttl-management",
"task": "check_ttl_management_task",
"schedule": timedelta(hours=1),
},
**(celery_app.conf.beat_schedule or {}),
}
]
# Build the celery beat schedule dynamically
beat_schedule = {}
for tenant_id in tenant_ids:
for task in tasks_to_schedule:
task_name = f"{task['name']}-{tenant_id}" # Unique name for each scheduled task
beat_schedule[task_name] = {
"task": task["task"],
"schedule": task["schedule"],
"args": (tenant_id,), # Must pass tenant_id as an argument
}
# Include any existing beat schedules
existing_beat_schedule = celery_app.conf.beat_schedule or {}
beat_schedule.update(existing_beat_schedule)
# Update the Celery app configuration
celery_app.conf.beat_schedule = beat_schedule

View File

@@ -2,9 +2,13 @@ def name_chat_ttl_task(retention_limit_days: int) -> str:
return f"chat_ttl_{retention_limit_days}_days"
def name_sync_external_doc_permissions_task(cc_pair_id: int) -> str:
def name_sync_external_doc_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_doc_permissions_task__{cc_pair_id}"
def name_sync_external_group_permissions_task(cc_pair_id: int) -> str:
def name_sync_external_group_permissions_task(
cc_pair_id: int, tenant_id: str | None = None
) -> str:
return f"sync_external_group_permissions_task__{cc_pair_id}"

View File

@@ -4,6 +4,7 @@ from httpx_oauth.clients.openid import OpenID
from danswer.auth.users import auth_backend
from danswer.auth.users import fastapi_users
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import OAUTH_CLIENT_ID
from danswer.configs.app_configs import OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import USER_AUTH_SECRET
@@ -24,6 +25,7 @@ from ee.danswer.server.enterprise_settings.api import (
basic_router as enterprise_settings_router,
)
from ee.danswer.server.manage.standard_answer import router as standard_answer_router
from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.danswer.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -53,6 +55,9 @@ def get_application() -> FastAPI:
application = get_application_base()
if MULTI_TENANT:
add_tenant_id_middleware(application, logger)
if AUTH_TYPE == AuthType.OIDC:
include_router_with_global_prefix_prepended(
application,

View File

@@ -0,0 +1,60 @@
import logging
from collections.abc import Awaitable
from collections.abc import Callable
import jwt
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from danswer.configs.app_configs import MULTI_TENANT
from danswer.configs.app_configs import SECRET_JWT_KEY
from danswer.configs.constants import POSTGRES_DEFAULT_SCHEMA
from danswer.db.engine import is_valid_schema_name
from shared_configs.configs import current_tenant_id
def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> None:
@app.middleware("http")
async def set_tenant_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
try:
logger.info(f"Request route: {request.url.path}")
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
else:
token = request.cookies.get("tenant_details")
if token:
try:
payload = jwt.decode(
token, SECRET_JWT_KEY, algorithms=["HS256"]
)
tenant_id = payload.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
except jwt.InvalidTokenError:
tenant_id = POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(
f"Unexpected error in set_tenant_id_middleware: {str(e)}"
)
raise HTTPException(
status_code=500, detail="Internal server error"
)
else:
tenant_id = POSTGRES_DEFAULT_SCHEMA
current_tenant_id.set(tenant_id)
logger.info(f"Middleware set current_tenant_id to: {tenant_id}")
response = await call_next(request)
return response
except Exception as e:
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise

View File

@@ -8,8 +8,11 @@ from danswer.db.engine import get_session_with_tenant
from danswer.setup import setup_danswer
from danswer.utils.logger import setup_logger
from ee.danswer.server.tenants.models import CreateTenantRequest
from ee.danswer.server.tenants.provisioning import add_users_to_tenant
from ee.danswer.server.tenants.provisioning import ensure_schema_exists
from ee.danswer.server.tenants.provisioning import run_alembic_migrations
from ee.danswer.server.tenants.provisioning import user_owns_a_tenant
from shared_configs.configs import current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@@ -19,9 +22,15 @@ router = APIRouter(prefix="/tenants")
def create_tenant(
create_tenant_request: CreateTenantRequest, _: None = Depends(control_plane_dep)
) -> dict[str, str]:
try:
tenant_id = create_tenant_request.tenant_id
tenant_id = create_tenant_request.tenant_id
email = create_tenant_request.initial_admin_email
token = None
if user_owns_a_tenant(email):
raise HTTPException(
status_code=409, detail="User already belongs to an organization"
)
try:
if not MULTI_TENANT:
raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled")
@@ -31,10 +40,14 @@ def create_tenant(
logger.info(f"Schema already exists for tenant {tenant_id}")
run_alembic_migrations(tenant_id)
token = current_tenant_id.set(tenant_id)
print("getting session", tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
setup_danswer(db_session)
logger.info(f"Tenant {tenant_id} created successfully")
add_users_to_tenant([email], tenant_id)
return {
"status": "success",
"message": f"Tenant {tenant_id} created successfully",
@@ -44,3 +57,6 @@ def create_tenant(
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
current_tenant_id.reset(token)

View File

@@ -8,7 +8,9 @@ from sqlalchemy.schema import CreateSchema
from alembic import command
from alembic.config import Config
from danswer.db.engine import build_connection_string
from danswer.db.engine import get_session_with_tenant
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.models import UserTenantMapping
from danswer.utils.logger import setup_logger
logger = setup_logger()
@@ -61,3 +63,48 @@ def ensure_schema_exists(tenant_id: str) -> bool:
db_session.execute(stmt)
return True
return False
# For now, we're implementing a primitive mapping between users and tenants.
# This function is only used to determine a user's relationship to a tenant upon creation (implying ownership).
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant("public") as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
.first()
)
return result is not None
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant("public") as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception as e:
logger.exception(f"Failed to add users to tenant {tenant_id}: {str(e)}")
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant("public") as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email.in_(emails),
UserTenantMapping.tenant_id == tenant_id,
)
.all()
)
for mapping in mappings_to_delete:
db_session.delete(mapping)
db_session.commit()
except Exception as e:
logger.exception(
f"Failed to remove users from tenant {tenant_id}: {str(e)}"
)
db_session.rollback()